Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: protect model repository creation with a lock to avoid a race condition #5095

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions src/bentoml/_internal/cloud/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
from concurrent.futures import ThreadPoolExecutor
from tempfile import NamedTemporaryFile
from threading import Lock

import attrs
import fs
Expand Down Expand Up @@ -47,6 +48,7 @@
class ModelAPI:
_client: RestApiClient = attrs.field(repr=False)
spinner: Spinner = attrs.field(repr=False, factory=Spinner)
_lock: Lock = attrs.field(repr=False, init=False, factory=Lock)

def push(
self,
Expand Down Expand Up @@ -88,17 +90,22 @@ def _do_push_model(
if version is None:
raise BentoMLException(f'Model "{model}" version cannot be None')

with self.spinner.spin(text=f'Fetching model repository "{name}"'):
model_repository = rest_client.v1.get_model_repository(
model_repository_name=name
)
if not model_repository:
with self.spinner.spin(
text=f'Model repository "{name}" not found, creating now..'
):
model_repository = rest_client.v1.create_model_repository(
req=CreateModelRepositorySchema(name=name, description="")
with self._lock:
# Models might be pushed by multiple threads at the same time
# when they are under the same model repository, race condition
# might happen when creating the model repository. So we need to
# protect it with a lock.
with self.spinner.spin(text=f'Fetching model repository "{name}"'):
model_repository = rest_client.v1.get_model_repository(
model_repository_name=name
)
if not model_repository:
with self.spinner.spin(
text=f'Model repository "{name}" not found, creating now..'
):
model_repository = rest_client.v1.create_model_repository(
req=CreateModelRepositorySchema(name=name, description="")
)
with self.spinner.spin(
text=f'Try fetching model "{model}" from remote model store..'
):
Expand Down