Skip to content

Commit

Permalink
Merge pull request #910 from julep-ai/f/async-transition
Browse files Browse the repository at this point in the history
Make transitions queries async
  • Loading branch information
creatorrr authored Nov 29, 2024
2 parents 6fd42a6 + 0487f98 commit d3b37db
Show file tree
Hide file tree
Showing 13 changed files with 404 additions and 73 deletions.
1 change: 0 additions & 1 deletion agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
)
from ..autogen.Sessions import CreateSessionRequest
from ..autogen.Tools import SystemDef
from ..common.protocol.remote import RemoteObject
from ..common.protocol.tasks import StepContext
from ..common.storage_handler import auto_blob_store, load_from_blob_store_if_remote
from ..env import testing
Expand Down
6 changes: 0 additions & 6 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
from typing import Callable

from anthropic.types.beta.beta_message import BetaMessage
from beartype import beartype
from langchain_core.tools import BaseTool
from langchain_core.tools.convert import tool as tool_decorator
from litellm.types.utils import ModelResponse
from temporalio import activity
from temporalio.exceptions import ApplicationError
Expand All @@ -16,7 +11,6 @@
from ...common.storage_handler import auto_blob_store
from ...common.utils.template import render_template
from ...env import debug
from ..utils import get_handler_with_filtered_params
from .base_evaluate import base_evaluate

COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from ...common.protocol.tasks import StepContext
from ...common.storage_handler import load_from_blob_store_if_remote
from ...env import testing
from ...models.execution.create_execution_transition import create_execution_transition
from ...models.execution.create_execution_transition import (
create_execution_transition_async,
)


@beartype
Expand All @@ -17,7 +19,7 @@ async def transition_step(
transition_info.output = load_from_blob_store_if_remote(transition_info.output)

# Create transition
transition = create_execution_transition(
transition = await create_execution_transition_async(
developer_id=context.execution_input.developer_id,
execution_id=context.execution_input.execution.id,
task_id=context.execution_input.task.id,
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/autogen/Tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, StrictBool

from .Chat import ChatSettings
from .Common import JinjaTemplate
from .Tools import (
ChosenBash20241022,
ChosenComputer20241022,
Expand Down
11 changes: 11 additions & 0 deletions agents-api/agents_api/clients/cozo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict

from pycozo.client import Client
from pycozo_async import Client as AsyncClient

from ..env import cozo_auth, cozo_host
from ..web import app
Expand All @@ -16,3 +17,13 @@ def get_cozo_client() -> Client:
app.state.cozo_client = client

return client


def get_async_cozo_client() -> AsyncClient:
client = getattr(
app.state, "async_cozo_client", AsyncClient("http", options=options)
)
if not hasattr(app.state, "async_cozo_client"):
app.state.async_cozo_client = client

return client
12 changes: 9 additions & 3 deletions agents-api/agents_api/metrics/counters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
from functools import wraps
from typing import Callable, ParamSpec, TypeVar
from typing import Awaitable, Callable, ParamSpec, TypeVar

from prometheus_client import Counter

Expand All @@ -8,7 +9,7 @@


def increase_counter(metric_label: str, id_field_name: str = "developer_id"):
def decor(func: Callable[P, T]):
def decor(func: Callable[P, T | Awaitable[T]]):
metric = Counter(
metric_label,
f"Number of {metric_label} calls",
Expand All @@ -20,6 +21,11 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
metric.labels(kwargs.get(id_field_name, "not_set")).inc()
return func(*args, **kwargs)

return wrapper
@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
metric.labels(kwargs.get(id_field_name, "not_set")).inc()
return await func(*args, **kwargs)

return async_wrapper if inspect.iscoroutinefunction(func) else wrapper

return decor
5 changes: 4 additions & 1 deletion agents-api/agents_api/models/execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from .count_executions import count_executions
from .create_execution import create_execution
from .create_execution_transition import create_execution_transition
from .create_execution_transition import (
create_execution_transition,
create_execution_transition_async,
)
from .get_execution import get_execution
from .get_execution_transition import get_execution_transition
from .list_execution_transitions import list_execution_transitions
Expand Down
167 changes: 167 additions & 0 deletions agents-api/agents_api/models/execution/create_execution_transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ...metrics.counters import increase_counter
from ..utils import (
cozo_query,
cozo_query_async,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
Expand Down Expand Up @@ -222,3 +223,169 @@ def create_execution_transition(
**update_execution_params,
},
)


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(
Transition,
transform=lambda d: {
**d,
"id": d["transition_id"],
"current": {"workflow": d["current"][0], "step": d["current"][1]},
"next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
},
one=True,
_kind="inserted",
)
@cozo_query_async
@increase_counter("create_execution_transition_async")
@beartype
async def create_execution_transition_async(
*,
developer_id: UUID,
execution_id: UUID,
data: CreateTransitionRequest,
# Only one of these needed
transition_id: UUID | None = None,
task_token: str | None = None,
# Only required for updating the execution status as well
update_execution_status: bool = False,
task_id: UUID | None = None,
) -> tuple[list[str | None], dict]:
transition_id = transition_id or uuid4()
data.metadata = data.metadata or {}
data.execution_id = execution_id

# Dump to json
if isinstance(data.output, list):
data.output = [
item.model_dump(mode="json") if hasattr(item, "model_dump") else item
for item in data.output
]

elif hasattr(data.output, "model_dump"):
data.output = data.output.model_dump(mode="json")

# TODO: This is a hack to make sure the transition is valid
# (parallel transitions are whack, we should do something better)
is_parallel = data.current.workflow.startswith("PAR:")

# Prepare the transition data
transition_data = data.model_dump(exclude_unset=True, exclude={"id"})

# Parse the current and next targets
validate_transition_targets(data)
current_target = transition_data.pop("current")
next_target = transition_data.pop("next")

transition_data["current"] = (current_target["workflow"], current_target["step"])
transition_data["next"] = next_target and (
next_target["workflow"],
next_target["step"],
)

columns, transition_values = cozo_process_mutate_data(
{
**transition_data,
"task_token": str(task_token), # Converting to str for JSON serialisation
"transition_id": str(transition_id),
"execution_id": str(execution_id),
}
)

# Make sure the transition is valid
check_last_transition_query = f"""
valid_transition[start, end] <- [
{", ".join(f'["{start}", "{end}"]' for start, ends in valid_transitions.items() for end in ends)}
]
last_transition_type[min_cost(type_created_at)] :=
*transitions {{
execution_id: to_uuid("{str(execution_id)}"),
type,
created_at,
}},
type_created_at = [type, -created_at]
matched[collect(last_type)] :=
last_transition_type[data],
last_type_data = first(data),
last_type = if(is_null(last_type_data), "init", last_type_data),
valid_transition[last_type, $next_type]
?[valid] :=
matched[prev_transitions],
found = length(prev_transitions),
valid = if($next_type == "init", found == 0, found > 0),
assert(valid, "Invalid transition"),
:limit 1
"""

# Prepare the insert query
insert_query = f"""
?[{columns}] <- $transition_values
:insert transitions {{
{columns}
}}
:returning
"""

validate_status_query, update_execution_query, update_execution_params = (
"",
"",
{},
)

if update_execution_status:
assert (
task_id is not None
), "task_id is required for updating the execution status"

# Prepare the execution update query
[*_, validate_status_query, update_execution_query], update_execution_params = (
update_execution.__wrapped__(
developer_id=developer_id,
task_id=task_id,
execution_id=execution_id,
data=UpdateExecutionRequest(
status=transition_to_execution_status[data.type]
),
output=data.output if data.type != "error" else None,
error=str(data.output)
if data.type == "error" and data.output
else None,
)
)

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(
developer_id,
"executions",
execution_id=execution_id,
parents=[("agents", "agent_id"), ("tasks", "task_id")],
),
validate_status_query if not is_parallel else None,
update_execution_query if not is_parallel else None,
check_last_transition_query if not is_parallel else None,
insert_query,
]

return (
queries,
{
"transition_values": transition_values,
"next_type": data.type,
"valid_transitions": valid_transitions,
**update_execution_params,
},
)
Loading

0 comments on commit d3b37db

Please sign in to comment.