Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Add retry policies to temporal workflows/activities #551

Merged
merged 12 commits into from
Oct 5, 2024
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64

from temporalio import activity

from ...autogen.openapi_model import CreateTransitionRequest
Expand All @@ -11,10 +12,9 @@

@activity.defn
async def raise_complete_async(context: StepContext, output: StepOutcome) -> None:

activity_info = activity.info()

captured_token = base64.b64encode(activity_info.task_token).decode('ascii')
captured_token = base64.b64encode(activity_info.task_token).decode("ascii")
activity_id = activity_info.activity_id
workflow_run_id = activity_info.workflow_run_id
workflow_id = activity_info.workflow_id
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ..autogen.openapi_model import TransitionTarget
from ..common.protocol.tasks import ExecutionInput
from ..common.retry_policies import DEFAULT_RETRY_POLICY
from ..env import (
temporal_client_cert,
temporal_namespace,
Expand Down Expand Up @@ -54,6 +55,7 @@ async def run_task_execution_workflow(
task_queue=temporal_task_queue,
id=str(job_id),
run_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
# TODO: Should add search_attributes for queryability
)

Expand Down
18 changes: 18 additions & 0 deletions agents-api/agents_api/common/retry_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from datetime import timedelta

from temporalio.common import RetryPolicy

DEFAULT_RETRY_POLICY = RetryPolicy(
initial_interval=timedelta(seconds=1),
backoff_coefficient=2,
maximum_attempts=2,
HamadaSalhab marked this conversation as resolved.
Show resolved Hide resolved
maximum_interval=timedelta(seconds=10),
HamadaSalhab marked this conversation as resolved.
Show resolved Hide resolved
non_retryable_error_types=[
"WorkflowExecutionAlreadyStarted",
"TypeError",
"AssertionError",
"HTTPException",
"SyntaxError",
"ValueError",
HamadaSalhab marked this conversation as resolved.
Show resolved Hide resolved
],
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/routers/docs/create_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ...activities.types import EmbedDocsPayload
from ...autogen.openapi_model import CreateDocRequest, ResourceCreatedResponse
from ...clients import temporal
from ...common.retry_policies import DEFAULT_RETRY_POLICY
from ...dependencies.developer_id import get_developer_id
from ...env import temporal_task_queue, testing
from ...models.docs.create_doc import create_doc as create_doc_query
Expand Down Expand Up @@ -41,6 +42,7 @@ async def run_embed_docs_task(
embed_payload,
task_queue=temporal_task_queue,
id=str(job_id),
retry_policy=DEFAULT_RETRY_POLICY,
)

# TODO: Remove this conditional once we have a way to run workflows in
Expand Down
8 changes: 6 additions & 2 deletions agents-api/agents_api/routers/tasks/update_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ async def update_execution(
workflow_id = token_data["metadata"].get("x-workflow-id", None)
if activity_id is None or run_id is None or workflow_id is None:
act_handle = temporal_client.get_async_activity_handle(
task_token=base64.b64decode(token_data["task_token"].encode('ascii')),
task_token=base64.b64decode(
token_data["task_token"].encode("ascii")
),
)

else:
Expand All @@ -59,6 +61,8 @@ async def update_execution(
try:
await act_handle.complete(data.input)
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to resume execution")
raise HTTPException(
status_code=500, detail="Failed to resume execution"
)
case _:
raise HTTPException(status_code=400, detail="Invalid request data")
3 changes: 3 additions & 0 deletions agents-api/agents_api/workflows/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from temporalio import workflow

from ..common.retry_policies import DEFAULT_RETRY_POLICY

with workflow.unsafe.imports_passed_through():
from ..activities.demo import demo_activity

Expand All @@ -14,4 +16,5 @@ async def run(self, a: int, b: int) -> int:
demo_activity,
args=[a, b],
start_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY,
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/embed_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
with workflow.unsafe.imports_passed_through():
from ..activities.embed_docs import embed_docs
from ..activities.types import EmbedDocsPayload
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -18,4 +19,5 @@ async def run(self, embed_payload: EmbedDocsPayload) -> None:
embed_docs,
embed_payload,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY,
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/mem_rating.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.mem_rating import mem_rating
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, memory: str) -> None:
mem_rating,
memory,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY,
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.summarization import summarization
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, session_id: str) -> None:
summarization,
session_id,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY,
)
8 changes: 8 additions & 0 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from temporalio import workflow
from temporalio.exceptions import ApplicationError

from ...common.retry_policies import DEFAULT_RETRY_POLICY

# Import necessary modules and types
with workflow.unsafe.imports_passed_through():
from ...activities import task_steps
Expand Down Expand Up @@ -200,6 +202,7 @@ async def run(
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY,
)
workflow.logger.debug(
f"Step {context.cursor.step} completed successfully"
Expand Down Expand Up @@ -385,6 +388,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, output],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
)

state = PartialTransition(type="resume", output=result)
Expand Down Expand Up @@ -417,6 +421,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, tool_calls_input],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
)

# Feed the tool call results back to the model
Expand All @@ -428,6 +433,7 @@ async def run(
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY,
)
state = PartialTransition(output=new_response.output, type="resume")

Expand Down Expand Up @@ -471,6 +477,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, tool_call],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
)

state = PartialTransition(output=tool_call_response, type="resume")
Expand Down Expand Up @@ -501,6 +508,7 @@ async def run(
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY,
)

state = PartialTransition(output=tool_call_response)
Expand Down
5 changes: 5 additions & 0 deletions agents-api/agents_api/workflows/task_execution/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from temporalio import workflow
from temporalio.exceptions import ApplicationError

from ...common.retry_policies import DEFAULT_RETRY_POLICY

with workflow.unsafe.imports_passed_through():
from ...activities import task_steps
from ...autogen.openapi_model import (
Expand Down Expand Up @@ -33,6 +35,7 @@ async def continue_as_child(
previous_inputs,
user_state,
],
retry_policy=DEFAULT_RETRY_POLICY,
)


Expand Down Expand Up @@ -169,6 +172,7 @@ async def execute_map_reduce_step(
task_steps.base_evaluate,
args=[reduce, {"results": result, "_": output}],
schedule_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY,
)

return result
Expand Down Expand Up @@ -244,6 +248,7 @@ async def execute_map_reduce_step_parallel(
extra_lambda_strs,
],
schedule_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY,
)

except BaseException as e:
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/task_execution/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TransitionTarget,
)
from ...common.protocol.tasks import PartialTransition, StepContext
from ...common.retry_policies import DEFAULT_RETRY_POLICY


async def transition(
Expand Down Expand Up @@ -44,6 +45,7 @@ async def transition(
task_steps.transition_step,
args=[context, transition_request],
schedule_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY,
)

except Exception as e:
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.truncation import truncation
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, session_id: str, token_count_threshold: int) -> None:
truncation,
args=[session_id, token_count_threshold],
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY,
)
4 changes: 3 additions & 1 deletion agents-api/notebooks/03-summarise.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,9 @@
" messages.append(user(start_message))\n",
"\n",
" print(\"Starting chatml generation\")\n",
" trim_result = generate(messages, model=\"gpt-4-turbo\", temperature=0.1, stop=[\"</ct\"])\n",
" trim_result = generate(\n",
" messages, model=\"gpt-4-turbo\", temperature=0.1, stop=[\"</ct\"]\n",
" )\n",
" print(\"End chatml generation\")\n",
" messages.append(trim_result)\n",
"\n",
Expand Down
Loading
Loading