Commit
feat(agents-api): Transitions stream SSE endpoint
Signed-off-by: Diwank Singh Tomer <[email protected]>
Showing 11 changed files with 201 additions and 38 deletions.
agents-api/agents_api/models/execution/lookup_temporal_data.py
64 changes: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
from typing import Any, TypeVar
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ..utils import (
    cozo_query,
    partialclass,
    rewrap_exceptions,
    verify_developer_id_query,
    verify_developer_owns_resource_query,
    wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


@rewrap_exceptions(
    {
        QueryException: partialclass(HTTPException, status_code=400),
        ValidationError: partialclass(HTTPException, status_code=400),
        TypeError: partialclass(HTTPException, status_code=400),
    }
)
@wrap_in_class(dict, one=True)
@cozo_query
@beartype
def lookup_temporal_data(
    *,
    developer_id: UUID,
    execution_id: UUID,
) -> tuple[list[str], dict]:
    developer_id = str(developer_id)
    execution_id = str(execution_id)

    temporal_query = """
        ?[id] :=
            execution_id = to_uuid($execution_id),
            *temporal_executions_lookup {
                id, execution_id, run_id, first_execution_run_id, result_run_id
            }
    """

    queries = [
        verify_developer_id_query(developer_id),
        verify_developer_owns_resource_query(
            developer_id,
            "executions",
            execution_id=execution_id,
            parents=[("agents", "agent_id"), ("tasks", "task_id")],
        ),
        temporal_query,
    ]

    return (
        queries,
        {
            "execution_id": str(execution_id),
        },
    )
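For context, a minimal usage sketch of the new query module follows. It mirrors the call the updated router below makes; the placeholder UUIDs are illustrative, and it assumes the decorator stack (rewrap_exceptions, wrap_in_class(dict, one=True), cozo_query) runs the CozoDB queries and returns a single dict row rather than the raw (queries, params) tuple.

# Illustrative sketch, not part of the commit: placeholder UUIDs, and the
# assumption that the decorators execute the queries and return one dict.
from uuid import UUID

from agents_api.models.execution.lookup_temporal_data import lookup_temporal_data

temporal_data = lookup_temporal_data(
    developer_id=UUID("00000000-0000-0000-0000-000000000000"),  # placeholder developer id
    execution_id=UUID("00000000-0000-0000-0000-000000000001"),  # placeholder execution id
)

# The router below uses the "id" column of temporal_executions_lookup as the
# Temporal workflow handle id.
workflow_id = temporal_data["id"]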
agents-api/agents_api/routers/tasks/stream_transitions_events.py
112 changes: 94 additions & 18 deletions
@@ -1,41 +1,117 @@
+import logging
+from base64 import b64decode, b64encode
from functools import partial
from typing import Annotated

import anyio
from anyio.streams.memory import MemoryObjectSendStream
-from fastapi import Depends
+from fastapi import Depends, Query
from pydantic import UUID4
from sse_starlette.sse import EventSourceResponse
from starlette.requests import Request
+from temporalio.api.enums.v1 import EventType
+from temporalio.client import (
+    WorkflowHistoryEventFilterType,
+    WorkflowHistoryEventAsyncIterator,
+)

+from ...autogen.openapi_model import TransitionEvent
+from ...clients.temporal import get_workflow_handle
from ...dependencies.developer_id import get_developer_id
+from ...models.execution.lookup_temporal_data import lookup_temporal_data
+from ...worker.codec import from_payload_data
from .router import router

+STREAM_TIMEOUT = 10 * 60  # 10 minutes


+# Create a function to publish events to the client
+# TODO: Unnest and simplify this function
+async def event_publisher(
+    inner_send_chan: MemoryObjectSendStream,
+    history_events: WorkflowHistoryEventAsyncIterator,
+):
+    async with inner_send_chan:
+        try:
+            async for event in history_events:
+                if event.event_type == EventType.EVENT_TYPE_ACTIVITY_TASK_COMPLETED:
+                    payloads = (
+                        event.activity_task_completed_event_attributes.result.payloads
+                    )

+                    for payload in payloads:
+                        try:
+                            data_item = from_payload_data(payload.data)

+                        except Exception as e:
+                            logging.warning(f"Could not decode payload: {e}")
+                            continue

+                        if not isinstance(data_item, TransitionEvent):
+                            continue

+                        # FIXME: This does NOT return the last event (and maybe other events)
+                        # Need to fix this. I think we need to grab events from child workflows too
+                        transition_event_dict = dict(
+                            type=data_item.type,
+                            output=data_item.output,
+                            created_at=data_item.created_at.isoformat(),
+                        )

+                        next_page_token = (
+                            b64encode(history_events.next_page_token).decode("ascii")
+                            if history_events.next_page_token
+                            else None
+                        )

+                        await inner_send_chan.send(
+                            dict(
+                                data=dict(
+                                    transition=transition_event_dict,
+                                    next_page_token=next_page_token,
+                                ),
+                            )
+                        )

+        except anyio.get_cancelled_exc_class() as e:
+            with anyio.move_on_after(STREAM_TIMEOUT, shield=True):
+                await inner_send_chan.send(dict(closing=True))
+                raise e


@router.get("/executions/{execution_id}/transitions.stream", tags=["executions"])
async def stream_transitions_events(
    x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
    execution_id: UUID4,
    req: Request,
+    # FIXME: add support for page token
+    next_page_token: Annotated[str | None, Query()] = None,
):
+    # Get temporal id
+    temporal_data = lookup_temporal_data(
+        developer_id=x_developer_id,
+        execution_id=execution_id,
+    )

+    workflow_handle = await get_workflow_handle(
+        handle_id=temporal_data["id"],
+    )

+    next_page_token: bytes | None = (
+        b64decode(next_page_token) if next_page_token else None
+    )

+    history_events = workflow_handle.fetch_history_events(
+        page_size=1,
+        next_page_token=next_page_token,
+        wait_new_event=True,
+        event_filter_type=WorkflowHistoryEventFilterType.ALL_EVENT,
+        skip_archival=True,
+    )

    # Create a channel to send events to the client
-    send_chan, recv_chan = anyio.create_memory_object_stream(10)

-    # Create a function to publish events to the client
-    async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-        async with inner_send_chan:
-            try:
-                i = 0
-                while True:
-                    i += 1
-                    await inner_send_chan.send(dict(data=i))
-                    await anyio.sleep(1.0)
-            except anyio.get_cancelled_exc_class() as e:
-                with anyio.move_on_after(1, shield=True):
-                    await inner_send_chan.send(dict(closing=True))
-                    raise e
+    send_chan, recv_chan = anyio.create_memory_object_stream(max_buffer_size=100)

    return EventSourceResponse(
-        recv_chan, data_sender_callable=partial(event_publisher, send_chan)
+        recv_chan,
+        data_sender_callable=partial(event_publisher, send_chan, history_events),
    )
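A hedged client-side sketch for consuming the new SSE endpoint follows. httpx, the base URL, and the omission of whatever auth header get_developer_id expects are assumptions, not part of this commit; each data: frame is expected to carry the transition and next_page_token fields that event_publisher sends, with the exact wire serialization left to sse_starlette.

# Illustrative client sketch, not part of the commit.
import asyncio

import httpx

BASE_URL = "http://localhost:8080"  # placeholder base URL
EXECUTION_ID = "00000000-0000-0000-0000-000000000001"  # placeholder execution id


async def follow_transitions() -> None:
    url = f"{BASE_URL}/executions/{EXECUTION_ID}/transitions.stream"
    async with httpx.AsyncClient(timeout=None) as client:
        # Stream the response instead of buffering it; the server holds it open
        async with client.stream(
            "GET", url, headers={"Accept": "text/event-stream"}
        ) as response:
            async for line in response.aiter_lines():
                # SSE frames arrive as "data: ..." lines separated by blank lines
                if line.startswith("data:"):
                    print(line.removeprefix("data:").strip())


asyncio.run(follow_transitions())

Feeding a received next_page_token back as the endpoint's next_page_token query parameter appears intended to resume the stream from that point, per the b64decode handling in the route, although the FIXME notes that page-token support is still incomplete.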