Skip to content

Commit

Permalink
fix(agents-api): Remove sync s3 client
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Nov 29, 2024
1 parent 95def7c commit f6bf839
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 388 deletions.
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
TextOnlyDocSearchRequest,
VectorDocSearchRequest,
)
from ..autogen.Sessions import CreateSessionRequest
from ..autogen.Tools import SystemDef
from ..common.protocol.tasks import StepContext
from ..common.storage_handler import auto_blob_store, load_from_blob_store_if_remote
from ..env import testing
Expand Down
3 changes: 1 addition & 2 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
model_validator,
)

from ..common.storage_handler import RemoteObject
from ..common.utils.datetime import utcnow
from .Agents import *
from .Chat import *
Expand Down Expand Up @@ -355,8 +356,6 @@ def validate_subworkflows(self):
# Create models
# -------------

from ..common.storage_handler import RemoteObject


# Override of the autogenerated SystemDef: widens `arguments` so it may also
# hold a RemoteObject handle (a reference into the blob store) in place of the
# inline dict.  Deliberately shadows the imported SystemDef name so the rest
# of the codebase picks up the widened model transparently.
class SystemDef(SystemDef):
    # Inline arguments dict, a blob-store handle, or absent.
    arguments: dict[str, Any] | None | RemoteObject = None
Expand Down
156 changes: 1 addition & 155 deletions agents-api/agents_api/common/protocol/remote.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Iterator
from typing import Any

from temporalio import activity, workflow

Expand Down Expand Up @@ -89,157 +89,3 @@ async def unload_all(self) -> "BaseRemoteModel":
for name in list(self._remote_cache.keys()):
await self.unload_attribute(name)
return self


class RemoteList(list):
    """A ``list`` whose items may live in the blob store as RemoteObject handles.

    Reads (``__getitem__`` / ``__iter__``) lazily fetch remote items and keep
    the deserialized values in ``_remote_cache``; writes (``append`` /
    ``insert`` / ``__setitem__``) offload large values.  Loading/saving only
    happens inside a Temporal activity (``activity.in_activity()``); outside
    one, items pass through unchanged.

    NOTE(review): the read cache is keyed by the raw index the caller passed,
    so a negative index (e.g. ``-1``) caches under a different key than its
    positive alias — verify callers never mix the two forms for one slot.
    """

    # index -> already-loaded (deserialized) value for RemoteObject slots
    _remote_cache: dict[int, Any]

    def __init__(self, iterable: list[Any] | None = None):
        super().__init__()
        self._remote_cache: dict[int, Any] = {}
        if iterable:
            # Route through append so each item goes through __save_item.
            for item in iterable:
                self.append(item)

    def __load_item(self, item: Any | RemoteObject) -> Any:
        # Outside an activity there is no storage context; return as-is.
        if not activity.in_activity():
            return item

        # Imported lazily to avoid a circular import with storage_handler.
        from ..storage_handler import load_from_blob_store_if_remote

        return load_from_blob_store_if_remote(item)

    def __save_item(self, item: Any) -> Any:
        if not activity.in_activity():
            return item

        # Imported lazily to avoid a circular import with storage_handler.
        from ..storage_handler import store_in_blob_store_if_large

        return store_in_blob_store_if_large(item)

    def __getitem__(
        self, index: int | slice
    ) -> Any:  # pytype: disable=signature-mismatch
        if isinstance(index, slice):
            # Slice through the parent class so the stored handles are
            # copied verbatim without loading every element.
            sliced_items = super().__getitem__(index)
            return RemoteList._from_existing_items(sliced_items)
        else:
            value = super().__getitem__(index)

            if isinstance(value, RemoteObject):
                if index in self._remote_cache:
                    return self._remote_cache[index]
                loaded_data = self.__load_item(value)
                self._remote_cache[index] = loaded_data
                return loaded_data
            return value

    @classmethod
    def _from_existing_items(cls, items: list[Any]) -> "RemoteList":
        """
        Create a RemoteList from existing items without processing them again.
        This method ensures that slicing does not trigger loading of items.
        """
        # Create a new instance without calling __init__ (which would
        # re-save every item through __save_item).
        new_remote_list = cls.__new__(cls)
        list.__init__(new_remote_list)  # Initialize as an empty list
        new_remote_list._remote_cache = {}
        new_remote_list._extend_without_processing(items)
        return new_remote_list

    def _extend_without_processing(self, items: list[Any]) -> None:
        """Extend the list without re-storing the items in the blob store."""
        super().extend(items)

    def __setitem__(
        self, index: int | slice, value: Any
    ) -> None:  # pytype: disable=signature-mismatch
        if isinstance(index, slice):
            # Handle slice assignment without processing existing RemoteObjects
            processed_values = [self.__save_item(v) for v in value]
            super().__setitem__(index, processed_values)
            # Invalidate cache entries for every affected index
            for i in range(*index.indices(len(self))):
                self._remote_cache.pop(i, None)
        else:
            stored_value = self.__save_item(value)
            super().__setitem__(index, stored_value)
            self._remote_cache.pop(index, None)

    def append(self, value: Any) -> None:
        stored_value = self.__save_item(value)
        super().append(stored_value)
        # No need to cache immediately; __getitem__ loads on demand.

    def insert(self, index: int, value: Any) -> None:
        # BUGFIX: normalize to the effective non-negative position that
        # list.insert actually uses (negative counts from the end, both
        # directions clamp), otherwise _shift_cache_on_insert shifts the
        # wrong entries for negative/out-of-range indices.
        if index < 0:
            index = max(0, len(self) + index)
        else:
            index = min(index, len(self))
        stored_value = self.__save_item(value)
        super().insert(index, stored_value)
        # Adjust cache indices
        self._shift_cache_on_insert(index)

    def _shift_cache_on_insert(self, index: int) -> None:
        """Shift cached entries at or after *index* up by one slot."""
        new_cache = {}
        for i, v in self._remote_cache.items():
            if i >= index:
                new_cache[i + 1] = v
            else:
                new_cache[i] = v
        self._remote_cache = new_cache

    def remove(self, value: Any) -> None:
        # Find the index of the first match; list.index/remove compare the
        # STORED items (possibly RemoteObject handles), not loaded values.
        index = self.index(value)
        super().remove(value)
        self._remote_cache.pop(index, None)
        # Adjust cache indices
        self._shift_cache_on_remove(index)

    def _shift_cache_on_remove(self, index: int) -> None:
        """Drop the cache entry at *index* and shift later entries down."""
        new_cache = {}
        for i, v in self._remote_cache.items():
            if i > index:
                new_cache[i - 1] = v
            elif i < index:
                new_cache[i] = v
            # Else: i == index, already removed
        self._remote_cache = new_cache

    def pop(self, index: int = -1) -> Any:
        # BUGFIX: compute the effective position BEFORE popping.  The old
        # code normalized after super().pop(), when len(self) had already
        # shrunk by one, so e.g. pop(-1) on a 3-item list invalidated cache
        # slot 1 instead of slot 2 and shifted the wrong entries.
        effective = len(self) + index if index < 0 else index
        value = super().pop(index)  # list raises IndexError as before
        self._remote_cache.pop(effective, None)
        # Adjust cache indices
        self._shift_cache_on_remove(effective)
        return value

    def clear(self) -> None:
        super().clear()
        self._remote_cache.clear()

    def extend(self, iterable: list[Any]) -> None:
        # Route through append so every new item goes through __save_item.
        for item in iterable:
            self.append(item)

    def __iter__(self) -> Iterator[Any]:  # pytype: disable=signature-mismatch
        # Iterate via __getitem__ so remote items are loaded lazily.
        for index in range(len(self)):
            yield self.__getitem__(index)

    def unload_item(self, index: int) -> None:
        """Unload a specific item and replace it with a RemoteObject."""
        if index in self._remote_cache:
            data = self._remote_cache.pop(index)
            remote_obj = self.__save_item(data)
            super().__setitem__(index, remote_obj)

    def unload_all(self) -> None:
        """Unload all cached items."""
        for index in list(self._remote_cache.keys()):
            self.unload_item(index)
62 changes: 24 additions & 38 deletions agents-api/agents_api/common/storage_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import inspect
import sys
from datetime import timedelta
from functools import wraps
Expand All @@ -10,9 +9,8 @@

from ..activities.sync_items_remote import load_inputs_remote
from ..clients import async_s3
from ..common.protocol.remote import BaseRemoteModel, RemoteList, RemoteObject
from ..common.protocol.remote import BaseRemoteModel, RemoteObject
from ..common.retry_policies import DEFAULT_RETRY_POLICY
from ..common.sync_storage_handler import sync_load_args
from ..env import (
blob_store_cutoff_kb,
debug,
Expand Down Expand Up @@ -49,9 +47,6 @@ async def load_from_blob_store_if_remote(x: Any | RemoteObject) -> Any:
fetched = await async_s3.get_object(x.key)
return deserialize(fetched)

elif isinstance(x, RemoteList):
x = list(x)

elif isinstance(x, dict) and set(x.keys()) == {"bucket", "key"}:
fetched = await async_s3.get_object(x["key"])
return deserialize(fetched)
Expand Down Expand Up @@ -109,14 +104,16 @@ async def load_args(
getattr(arg, field)
),
)
elif isinstance(getattr(arg, field), RemoteList):
elif isinstance(getattr(arg, field), list):
setattr(
arg,
field,
[
await load_from_blob_store_if_remote(item)
for item in getattr(arg, field)
],
await asyncio.gather(
*[
await load_from_blob_store_if_remote(item)
for item in getattr(arg, field)
]
),
)
elif isinstance(getattr(arg, field), BaseRemoteModel):
setattr(
Expand Down Expand Up @@ -157,14 +154,16 @@ async def load_args(
getattr(v, field)
),
)
elif isinstance(getattr(v, field), RemoteList):
elif isinstance(getattr(v, field), list):
setattr(
v,
field,
[
await load_from_blob_store_if_remote(item)
for item in getattr(v, field)
],
await asyncio.gather(
*[
await load_from_blob_store_if_remote(item)
for item in getattr(v, field)
]
),
)
elif isinstance(getattr(v, field), BaseRemoteModel):
setattr(
Expand All @@ -179,33 +178,20 @@ async def load_args(

return new_args, new_kwargs

async def unload_return_value(x: Any | BaseRemoteModel | RemoteList) -> Any:
if isinstance(x, (BaseRemoteModel, RemoteList)):
x.unload_all()
async def unload_return_value(x: Any | BaseRemoteModel) -> Any:
    """Prepare a return value for transport: unload cached remote attributes,
    then offload the value itself to the blob store if it is large.

    For a BaseRemoteModel, unload_all() first swaps loaded attributes back to
    RemoteObject handles so only handles get serialized.
    """
    if isinstance(x, BaseRemoteModel):
        await x.unload_all()

    # May return x unchanged or a RemoteObject, depending on the size cutoff.
    return await store_in_blob_store_if_large(x)

if inspect.iscoroutinefunction(f):

@wraps(f)
async def async_wrapper(*args, **kwargs) -> Any:
new_args, new_kwargs = await load_args(args, kwargs)
output = await f(*new_args, **new_kwargs)

return await unload_return_value(output)

return async_wrapper if use_blob_store_for_temporal else f

else:
# FIXME: Remove sync wrapper
@wraps(f)
def wrapper(*args, **kwargs) -> Any:
new_args, new_kwargs = sync_load_args(deep, args, kwargs)
output = f(*new_args, **new_kwargs)
@wraps(f)
async def async_wrapper(*args, **kwargs) -> Any:
new_args, new_kwargs = await load_args(args, kwargs)
output = await f(*new_args, **new_kwargs)

return unload_return_value(output)
return await unload_return_value(output)

return wrapper if use_blob_store_for_temporal else f
return async_wrapper if use_blob_store_for_temporal else f

return auto_blob_store_decorator(f) if f else auto_blob_store_decorator

Expand Down
Loading

0 comments on commit f6bf839

Please sign in to comment.