Skip to content

Commit

Permalink
feat(agents-api): Transitions stream SSE endpoint
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Sep 3, 2024
1 parent 3b933b5 commit ab427f0
Show file tree
Hide file tree
Showing 11 changed files with 201 additions and 38 deletions.
14 changes: 14 additions & 0 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,17 @@ async def run_task_execution_workflow(
run_timeout=timedelta(days=31),
# TODO: Should add search_attributes for queryability
)


async def get_workflow_handle(
*,
handle_id: str,
client: Client | None = None,
):
client = client or (await get_client())

handle = client.get_workflow_handle(
handle_id,
)

return handle
1 change: 1 addition & 0 deletions agents-api/agents_api/models/execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
from .get_execution_transition import get_execution_transition
from .list_execution_transitions import list_execution_transitions
from .list_executions import list_executions
from .lookup_temporal_data import lookup_temporal_data
from .prepare_execution_input import prepare_execution_input
from .update_execution import update_execution
15 changes: 6 additions & 9 deletions agents-api/agents_api/models/execution/create_temporal_lookup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import TypeVar
from uuid import UUID, uuid4
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
Expand All @@ -21,6 +21,7 @@

@rewrap_exceptions(
{
AssertionError: partialclass(HTTPException, status_code=404),
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
Expand All @@ -31,14 +32,10 @@
def create_temporal_lookup(
*,
developer_id: UUID,
task_id: UUID,
execution_id: UUID | None = None,
execution_id: UUID,
workflow_handle: WorkflowHandle,
) -> tuple[list[str], dict]:
execution_id = execution_id or uuid4()

developer_id = str(developer_id)
task_id = str(task_id)
execution_id = str(execution_id)

temporal_columns, temporal_values = cozo_process_mutate_data(
Expand All @@ -63,9 +60,9 @@ def create_temporal_lookup(
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(
developer_id,
"tasks",
task_id=task_id,
parents=[("agents", "agent_id")],
"executions",
execution_id=execution_id,
parents=[("agents", "agent_id"), ("tasks", "task_id")],
),
temporal_executions_lookup_query,
]
Expand Down
64 changes: 64 additions & 0 deletions agents-api/agents_api/models/execution/lookup_temporal_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Any, TypeVar
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ..utils import (
cozo_query,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
verify_developer_owns_resource_query,
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(dict, one=True)
@cozo_query
@beartype
def lookup_temporal_data(
*,
developer_id: UUID,
execution_id: UUID,
) -> tuple[list[str], dict]:
developer_id = str(developer_id)
execution_id = str(execution_id)

temporal_query = """
?[id] :=
execution_id = to_uuid($execution_id),
*temporal_executions_lookup {
id, execution_id, run_id, first_execution_run_id, result_run_id
}
"""

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(
developer_id,
"executions",
execution_id=execution_id,
parents=[("agents", "agent_id"), ("tasks", "task_id")],
),
temporal_query,
]

return (
queries,
{
"execution_id": str(execution_id),
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ async def create_task_execution(
create_temporal_lookup,
#
developer_id=x_developer_id,
task_id=task_id,
execution_id=execution.id,
workflow_handle=handle,
)
Expand Down
112 changes: 94 additions & 18 deletions agents-api/agents_api/routers/tasks/stream_transitions_events.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,117 @@
import logging
from base64 import b64decode, b64encode
from functools import partial
from typing import Annotated

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import Depends
from fastapi import Depends, Query
from pydantic import UUID4
from sse_starlette.sse import EventSourceResponse
from starlette.requests import Request
from temporalio.api.enums.v1 import EventType
from temporalio.client import (
WorkflowHistoryEventFilterType,
WorkflowHistoryEventAsyncIterator,
)

from ...autogen.openapi_model import TransitionEvent
from ...clients.temporal import get_workflow_handle
from ...dependencies.developer_id import get_developer_id
from ...models.execution.lookup_temporal_data import lookup_temporal_data
from ...worker.codec import from_payload_data
from .router import router

STREAM_TIMEOUT = 10 * 60 # 10 minutes


# Create a function to publish events to the client
# TODO: Unnest and simplify this function
async def event_publisher(
inner_send_chan: MemoryObjectSendStream,
history_events: WorkflowHistoryEventAsyncIterator,
):
async with inner_send_chan:
try:
async for event in history_events:
if event.event_type == EventType.EVENT_TYPE_ACTIVITY_TASK_COMPLETED:
payloads = (
event.activity_task_completed_event_attributes.result.payloads
)

for payload in payloads:
try:
data_item = from_payload_data(payload.data)

except Exception as e:
logging.warning(f"Could not decode payload: {e}")
continue

if not isinstance(data_item, TransitionEvent):
continue

# FIXME: This does NOT return the last event (and maybe other events)
# Need to fix this. I think we need to grab events from child workflows too
transition_event_dict = dict(
type=data_item.type,
output=data_item.output,
created_at=data_item.created_at.isoformat(),
)

next_page_token = (
b64encode(history_events.next_page_token).decode("ascii")
if history_events.next_page_token
else None
)

await inner_send_chan.send(
dict(
data=dict(
transition=transition_event_dict,
next_page_token=next_page_token,
),
)
)

except anyio.get_cancelled_exc_class() as e:
with anyio.move_on_after(STREAM_TIMEOUT, shield=True):
await inner_send_chan.send(dict(closing=True))
raise e


@router.get("/executions/{execution_id}/transitions.stream", tags=["executions"])
async def stream_transitions_events(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
execution_id: UUID4,
req: Request,
# FIXME: add support for page token
next_page_token: Annotated[str | None, Query()] = None,
):
# Get temporal id
temporal_data = lookup_temporal_data(
developer_id=x_developer_id,
execution_id=execution_id,
)

workflow_handle = await get_workflow_handle(
handle_id=temporal_data["id"],
)

next_page_token: bytes | None = (
b64decode(next_page_token) if next_page_token else None
)

history_events = workflow_handle.fetch_history_events(
page_size=1,
next_page_token=next_page_token,
wait_new_event=True,
event_filter_type=WorkflowHistoryEventFilterType.ALL_EVENT,
skip_archival=True,
)

# Create a channel to send events to the client
send_chan, recv_chan = anyio.create_memory_object_stream(10)

# Create a function to publish events to the client
async def event_publisher(inner_send_chan: MemoryObjectSendStream):
async with inner_send_chan:
try:
i = 0
while True:
i += 1
await inner_send_chan.send(dict(data=i))
await anyio.sleep(1.0)
except anyio.get_cancelled_exc_class() as e:
with anyio.move_on_after(1, shield=True):
await inner_send_chan.send(dict(closing=True))
raise e
send_chan, recv_chan = anyio.create_memory_object_stream(max_buffer_size=100)

return EventSourceResponse(
recv_chan, data_sender_callable=partial(event_publisher, send_chan)
recv_chan,
data_sender_callable=partial(event_publisher, send_chan, history_events),
)
2 changes: 0 additions & 2 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from temporalio import workflow
from temporalio.exceptions import ApplicationError


with workflow.unsafe.imports_passed_through():
from ...activities import task_steps
from ...autogen.openapi_model import (
Expand Down Expand Up @@ -102,7 +101,6 @@
# TODO: find a way to transition to error if workflow or activity times out.



async def continue_as_child(
execution_input: ExecutionInput,
start: TransitionTarget,
Expand Down
7 changes: 4 additions & 3 deletions agents-api/agents_api/workflows/task_execution/transition.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from datetime import timedelta

from temporalio import workflow
from temporalio.exceptions import ApplicationError

from ...activities import task_steps
from ...autogen.openapi_model import (
CreateTransitionRequest,
TransitionTarget,
Transition,
TransitionTarget,
)
from ...common.protocol.tasks import StepContext, PartialTransition
from ...activities import task_steps
from ...common.protocol.tasks import PartialTransition, StepContext


async def transition(
Expand Down
4 changes: 2 additions & 2 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def test_execution(
)
create_temporal_lookup(
developer_id=developer_id,
task_id=task.id,
execution_id=execution.id,
workflow_handle=workflow_handle,
client=client,
)
Expand Down Expand Up @@ -294,7 +294,7 @@ def test_execution_started(
)
create_temporal_lookup(
developer_id=developer_id,
task_id=task.id,
execution_id=execution.id,
workflow_handle=workflow_handle,
client=client,
)
Expand Down
18 changes: 16 additions & 2 deletions agents-api/tests/test_execution_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from agents_api.models.execution.create_temporal_lookup import create_temporal_lookup
from agents_api.models.execution.get_execution import get_execution
from agents_api.models.execution.list_executions import list_executions
from agents_api.models.execution.lookup_temporal_data import lookup_temporal_data

from .fixtures import (
cozo_client,
Expand All @@ -33,15 +34,16 @@ def _(client=cozo_client, developer_id=test_developer_id, task=test_task):
id="blah",
)

create_execution(
execution = create_execution(
developer_id=developer_id,
task_id=task.id,
data=CreateExecutionRequest(input={"test": "test"}),
client=client,
)

create_temporal_lookup(
developer_id=developer_id,
task_id=task.id,
execution_id=execution.id,
workflow_handle=workflow_handle,
client=client,
)
Expand All @@ -59,6 +61,18 @@ def _(client=cozo_client, developer_id=test_developer_id, execution=test_executi
assert result.status == "queued"


@test("model: lookup temporal id")
def _(client=cozo_client, developer_id=test_developer_id, execution=test_execution):
result = lookup_temporal_data(
execution_id=execution.id,
developer_id=developer_id,
client=client,
)

assert result is not None
assert result["id"]


@test("model: list executions")
def _(
client=cozo_client,
Expand Down
1 change: 0 additions & 1 deletion agents-api/tests/test_set_get_workflow.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@

0 comments on commit ab427f0

Please sign in to comment.