Skip to content

Commit

Permalink
feat(agents-api): Add asynchronous boto3
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahmad-mtos committed Nov 29, 2024
1 parent 70cb496 commit be87865
Show file tree
Hide file tree
Showing 27 changed files with 399 additions and 89 deletions.
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def execute_system(
arguments: dict[str, Any] = system.arguments or {}

if set(arguments.keys()) == {"bucket", "key"}:
arguments = load_from_blob_store_if_remote(arguments)
arguments = await load_from_blob_store_if_remote(arguments)

arguments["developer_id"] = context.execution_input.developer_id

Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/activities/sync_items_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
from temporalio import activity

from ..common.protocol.remote import RemoteObject

import asyncio

@beartype
async def save_inputs_remote_fn(inputs: list[Any]) -> list[Any | RemoteObject]:
from ..common.storage_handler import store_in_blob_store_if_large

return [store_in_blob_store_if_large(input) for input in inputs]
return await asyncio.gather(*[store_in_blob_store_if_large(input) for input in inputs])


@beartype
async def load_inputs_remote_fn(inputs: list[Any | RemoteObject]) -> list[Any]:
from ..common.storage_handler import load_from_blob_store_if_remote

return [load_from_blob_store_if_remote(input) for input in inputs]
return await asyncio.gather(*[load_from_blob_store_if_remote(input) for input in inputs])


save_inputs_remote = activity.defn(name="save_inputs_remote")(save_inputs_remote_fn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def evaluate_step(
else context.current_step.evaluate
)

values = context.prepare_for_step(include_remote=True) | additional_values
values = await context.prepare_for_step(include_remote=True) | additional_values

output = simple_eval_dict(expr, values)
result = StepOutcome(output=output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def for_each_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, ForeachStep)

output = await base_evaluate(
context.current_step.foreach.in_, context.prepare_for_step()
context.current_step.foreach.in_, await context.prepare_for_step()
)
return StepOutcome(output=output)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def if_else_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, IfElseWorkflowStep)

expr: str = context.current_step.if_
output = await base_evaluate(expr, context.prepare_for_step())
output = await base_evaluate(expr, await context.prepare_for_step())
output: bool = bool(output)

result = StepOutcome(output=output)
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/task_steps/log_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def log_step(context: StepContext) -> StepOutcome:
template: str = context.current_step.log
output = await render_template(
template,
context.prepare_for_step(include_remote=True),
await context.prepare_for_step(include_remote=True),
skip_vars=["developer_id"],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def map_reduce_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, MapReduceStep)

output = await base_evaluate(
context.current_step.over, context.prepare_for_step()
context.current_step.over, await context.prepare_for_step()
)

return StepOutcome(output=output)
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def format_tool(tool: Tool) -> dict:
async def prompt_step(context: StepContext) -> StepOutcome:
# Get context data
prompt: str | list[dict] = context.current_step.model_dump()["prompt"]
context_data: dict = context.prepare_for_step(include_remote=True)
context_data: dict = await context.prepare_for_step(include_remote=True)

# If the prompt is a string and starts with $_ then we need to evaluate it
should_evaluate_prompt = isinstance(prompt, str) and prompt.startswith(
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/task_steps/return_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def return_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, ReturnStep)

exprs: dict[str, str] = context.current_step.return_
output = await base_evaluate(exprs, context.prepare_for_step())
output = await base_evaluate(exprs, await context.prepare_for_step())

result = StepOutcome(output=output)
return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def set_value_step(
try:
expr = override_expr if override_expr is not None else context.current_step.set

values = context.prepare_for_step() | additional_values
values = await context.prepare_for_step() | additional_values
output = simple_eval_dict(expr, values)
result = StepOutcome(output=output)

Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/task_steps/switch_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def switch_step(context: StepContext) -> StepOutcome:
output: int = -1
cases: list[str] = [c.case for c in context.current_step.switch]

evaluator = get_evaluator(names=context.prepare_for_step())
evaluator = get_evaluator(names=await context.prepare_for_step())

for i, case in enumerate(cases):
result = evaluator.eval(case)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def tool_call_step(context: StepContext) -> StepOutcome:
raise ApplicationError(f"Tool {tool_name} not found in the toolset")

arguments = await base_evaluate(
context.current_step.arguments, context.prepare_for_step()
context.current_step.arguments, await context.prepare_for_step()
)

call_id = generate_call_id()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def transition_step(
transition_info: CreateTransitionRequest,
) -> Transition:
# Load output from blob store if it is a remote object
transition_info.output = load_from_blob_store_if_remote(transition_info.output)
transition_info.output = await load_from_blob_store_if_remote(transition_info.output)

# Create transition
transition = create_execution_transition(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async def wait_for_input_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, WaitForInputStep)

exprs = context.current_step.wait_for_input.info
output = await base_evaluate(exprs, context.prepare_for_step())
output = await base_evaluate(exprs, await context.prepare_for_step())

result = StepOutcome(output=output)
return result
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/task_steps/yield_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def yield_step(context: StepContext) -> StepOutcome:
], f"Workflow {workflow} not found in task"

# Evaluate the expressions in the arguments
arguments = await base_evaluate(exprs, context.prepare_for_step())
arguments = await base_evaluate(exprs, await context.prepare_for_step())

# Transition to the first step of that workflow
transition_target = TransitionTarget(
Expand Down
92 changes: 92 additions & 0 deletions agents-api/agents_api/clients/async_s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from functools import cache, lru_cache

from beartype import beartype
from temporalio import workflow

with workflow.unsafe.imports_passed_through():
import aioboto3
import botocore
from xxhash import xxh3_64_hexdigest as xxhash_key

from ..env import (
blob_store_bucket,
blob_store_cutoff_kb,
s3_access_key,
s3_endpoint,
s3_secret_key,
)


@cache
async def get_s3_client():
return await aioboto3.session.client(
"s3",
endpoint_url=s3_endpoint,
aws_access_key_id=s3_access_key,
aws_secret_access_key=s3_secret_key,
)


async def list_buckets() -> list[str]:
client = await get_s3_client()
data = await client.list_buckets()
buckets = [bucket["Name"] for bucket in data["Buckets"]]

return buckets


@cache
async def setup():
client = await get_s3_client()
if blob_store_bucket not in await list_buckets():
await client.create_bucket(Bucket=blob_store_bucket)


@lru_cache(maxsize=10_000)
async def exists(key: str) -> bool:
client = await get_s3_client()

try:
client.head_object(Bucket=blob_store_bucket, Key=key)
return True

except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "404":
return False
else:
raise e


@beartype
async def add_object(key: str, body: bytes, replace: bool = False) -> None:
client = await get_s3_client()

if replace:
client.put_object(Bucket=blob_store_bucket, Key=key, Body=body)
return

if exists(key):
return

client.put_object(Bucket=blob_store_bucket, Key=key, Body=body)


@lru_cache(maxsize=256 * 1024 // max(1, blob_store_cutoff_kb)) # 256mb in cache
@beartype
async def get_object(key: str) -> bytes:
client = await get_s3_client()
return (await client.get_object(Bucket=blob_store_bucket, Key=key))["Body"].read()


@beartype
async def delete_object(key: str) -> None:
client = await get_s3_client()
await client.delete_object(Bucket=blob_store_bucket, Key=key)


@beartype
async def add_object_with_hash(body: bytes, replace: bool = False) -> str:
key = xxhash_key(body)
await add_object(key, body, replace=replace)

return key
4 changes: 3 additions & 1 deletion agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ async def run_task_execution_workflow(
client = client or (await get_client())
execution_id = execution_input.execution.id
execution_id_key = SearchAttributeKey.for_keyword("CustomStringField")
execution_input.arguments = store_in_blob_store_if_large(execution_input.arguments)
execution_input.arguments = await store_in_blob_store_if_large(
execution_input.arguments
)

return await client.start_workflow(
TaskExecutionWorkflow.run,
Expand Down
36 changes: 14 additions & 22 deletions agents-api/agents_api/common/protocol/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,23 @@ def __init__(self, **data: Any):
super().__init__(**data)
self._remote_cache = {}

def __load_item(self, item: Any | RemoteObject) -> Any:
async def load_item(self, item: Any | RemoteObject) -> Any:
if not activity.in_activity():
return item

from ..storage_handler import load_from_blob_store_if_remote

return load_from_blob_store_if_remote(item)
return await load_from_blob_store_if_remote(item)

def __save_item(self, item: Any) -> Any:
async def save_item(self, item: Any) -> Any:
if not activity.in_activity():
return item

from ..storage_handler import store_in_blob_store_if_large

return store_in_blob_store_if_large(item)
return await store_in_blob_store_if_large(item)

def __getattribute__(self, name: str) -> Any:
async def get_attribute(self, name: str) -> Any:
if name.startswith("_"):
return super().__getattribute__(name)

Expand All @@ -57,45 +57,37 @@ def __getattribute__(self, name: str) -> Any:
if name in cache:
return cache[name]

loaded_data = self.__load_item(value)
loaded_data = await self.load_item(value)
cache[name] = loaded_data
return loaded_data

return value

def __setattr__(self, name: str, value: Any) -> None:
async def set_attribute(self, name: str, value: Any) -> None:
if name.startswith("_"):
super().__setattr__(name, value)
return

stored_value = self.__save_item(value)
stored_value = await self.save_item(value)
super().__setattr__(name, stored_value)

if isinstance(stored_value, RemoteObject):
cache = self.__dict__.get("_remote_cache", {})
cache.pop(name, None)

def load_all(self) -> None:
async def load_all(self) -> None:
for name in self.model_fields_set:
self.__getattribute__(name)

def model_dump(
self, *args, include_remote: bool = False, **kwargs
) -> dict[str, Any]:
if include_remote:
self.load_all()
await self.get_attribute(name)

return super().model_dump(*args, **kwargs)

def unload_attribute(self, name: str) -> None:
async def unload_attribute(self, name: str) -> None:
if name in self._remote_cache:
data = self._remote_cache.pop(name)
remote_obj = self.__save_item(data)
remote_obj = await self.save_item(data)
super().__setattr__(name, remote_obj)

def unload_all(self) -> "BaseRemoteModel":
async def unload_all(self) -> "BaseRemoteModel":
for name in list(self._remote_cache.keys()):
self.unload_attribute(name)
await self.unload_attribute(name)
return self


Expand Down
14 changes: 10 additions & 4 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Annotated, Any, Literal
from uuid import UUID

Expand Down Expand Up @@ -241,12 +242,17 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]:

return dump | execution_input

def prepare_for_step(self, *args, **kwargs) -> dict[str, Any]:
async def prepare_for_step(
self, *args, include_remote: bool = False, **kwargs
) -> dict[str, Any]:
current_input = self.current_input
inputs = self.inputs
if activity.in_activity():
inputs = [load_from_blob_store_if_remote(input) for input in inputs]
current_input = load_from_blob_store_if_remote(current_input)
if activity.in_activity() and include_remote:
await self.load_all()
inputs = await asyncio.gather(
*[load_from_blob_store_if_remote(input) for input in inputs]
)
current_input = await load_from_blob_store_if_remote(current_input)

# Merge execution inputs into the dump dict
dump = self.model_dump(*args, **kwargs)
Expand Down
Loading

0 comments on commit be87865

Please sign in to comment.