Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Add retry policies to temporal workflows/activities #551

Merged
merged 12 commits into from
Oct 5, 2024
Prev Previous commit
Next Next commit
Add retry policies to more activities/workflows
  • Loading branch information
HamadaSalhab committed Oct 2, 2024
commit 93dfee1eb16f3ca35dc90155cf22701e12bdfac0
2 changes: 2 additions & 0 deletions agents-api/agents_api/routers/docs/create_doc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Annotated
from uuid import UUID, uuid4

from ...common.retry_policies import DEFAULT_RETRY_POLICY
from fastapi import BackgroundTasks, Depends
from starlette.status import HTTP_201_CREATED
from temporalio.client import Client as TemporalClient
Expand Down Expand Up @@ -41,6 +42,7 @@ async def run_embed_docs_task(
embed_payload,
task_queue=temporal_task_queue,
id=str(job_id),
retry_policy=DEFAULT_RETRY_POLICY
)

# TODO: Remove this conditional once we have a way to run workflows in
Expand Down
3 changes: 3 additions & 0 deletions agents-api/agents_api/workflows/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from temporalio import workflow

from ..common.retry_policies import DEFAULT_RETRY_POLICY

with workflow.unsafe.imports_passed_through():
from ..activities.demo import demo_activity

Expand All @@ -14,4 +16,5 @@ async def run(self, a: int, b: int) -> int:
demo_activity,
args=[a, b],
start_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/embed_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
with workflow.unsafe.imports_passed_through():
from ..activities.embed_docs import embed_docs
from ..activities.types import EmbedDocsPayload
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -18,4 +19,5 @@ async def run(self, embed_payload: EmbedDocsPayload) -> None:
embed_docs,
embed_payload,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/mem_rating.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.mem_rating import mem_rating
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, memory: str) -> None:
mem_rating,
memory,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.summarization import summarization
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, session_id: str) -> None:
summarization,
session_id,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY
)
5 changes: 5 additions & 0 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, output],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY
)

state = PartialTransition(type="resume", output=result)
Expand Down Expand Up @@ -419,6 +420,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, tool_calls_input],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY
)

# Feed the tool call results back to the model
Expand All @@ -430,6 +432,7 @@ async def run(
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY
)
state = PartialTransition(output=new_response.output, type="resume")

Expand Down Expand Up @@ -473,6 +476,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, tool_call],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY
)

state = PartialTransition(output=tool_call_response, type="resume")
Expand Down Expand Up @@ -503,6 +507,7 @@ async def run(
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY
)

state = PartialTransition(output=tool_call_response)
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/task_execution/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ async def execute_map_reduce_step(
task_steps.base_evaluate,
args=[reduce, {"results": result, "_": output}],
schedule_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY
)

return result
Expand Down Expand Up @@ -245,6 +246,7 @@ async def execute_map_reduce_step_parallel(
extra_lambda_strs,
],
schedule_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY
)

except BaseException as e:
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.truncation import truncation
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, session_id: str, token_count_threshold: int) -> None:
truncation,
args=[session_id, token_count_threshold],
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY
)
2 changes: 2 additions & 0 deletions agents-api/tests/test_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from agents_api.clients import temporal
from agents_api.env import temporal_task_queue
from agents_api.workflows.demo import DemoWorkflow
from agents_api.workflows.task_execution.helpers import DEFAULT_RETRY_POLICY

from .fixtures import (
cozo_client,
Expand Down Expand Up @@ -49,6 +50,7 @@ async def _():
args=[1, 2],
id=str(uuid4()),
task_queue=temporal_task_queue,
retry_policy=DEFAULT_RETRY_POLICY
)

assert result == 3
Expand Down