Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Add retry policies to temporal workflows/activities #551

Merged
merged 12 commits into from
Oct 5, 2024
2 changes: 2 additions & 0 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ..autogen.openapi_model import TransitionTarget
from ..common.protocol.tasks import ExecutionInput
from ..common.retry_policies import DEFAULT_RETRY_POLICY
from ..env import (
temporal_client_cert,
temporal_namespace,
Expand Down Expand Up @@ -54,6 +55,7 @@ async def run_task_execution_workflow(
task_queue=temporal_task_queue,
id=str(job_id),
run_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
# TODO: Should add search_attributes for queryability
)

Expand Down
63 changes: 63 additions & 0 deletions agents-api/agents_api/common/retry_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from datetime import timedelta

from temporalio.common import RetryPolicy

# Default retry policy intended to be passed as `retry_policy=` when starting
# Temporal workflows and activities.
#
# Timing: retries start at a 1 second interval, the interval is multiplied by
# 2 after each attempt and capped at 300 seconds, and at most 25 attempts are
# made.
#
# Errors whose type appears in `non_retryable_error_types` fail immediately
# instead of being retried.
#
# NOTE(review): the Temporal Python SDK appears to report an exception's type
# as its bare class name (e.g. "ValidationError"), so the fully qualified
# entries below (e.g. "pydantic.ValidationError", "httpx.RequestError") may
# never match -- confirm against the SDK's failure converter before relying
# on them.
# NOTE(review): matching is presumably by exact type name rather than by
# inheritance, so base-class entries such as "LookupError" may not cover
# their subclasses -- verify.
DEFAULT_RETRY_POLICY = RetryPolicy(
    initial_interval=timedelta(seconds=1),
    backoff_coefficient=2,
    maximum_attempts=25,
    maximum_interval=timedelta(seconds=300),
    non_retryable_error_types=[
        # Temporal-specific errors
        "WorkflowExecutionAlreadyStarted",
        "temporalio.exceptions.TerminalFailure",
        "temporalio.exceptions.CanceledError",
        #
        # Built-in Python exceptions
        "TypeError",
        "AssertionError",
        "SyntaxError",
        "ValueError",
        "ZeroDivisionError",
        "IndexError",
        "AttributeError",
        "LookupError",
        "BufferError",
        "ArithmeticError",
        "KeyError",
        "NameError",
        "NotImplementedError",
        "RecursionError",
        "RuntimeError",
        "StopIteration",
        "StopAsyncIteration",
        "IndentationError",
        "TabError",
        #
        # Unicode-related errors
        "UnicodeError",
        "UnicodeEncodeError",
        "UnicodeDecodeError",
        "UnicodeTranslateError",
        #
        # HTTP and API-related errors
        "HTTPException",
        "fastapi.exceptions.HTTPException",
        "fastapi.exceptions.RequestValidationError",
        "httpx.RequestError",
        "httpx.HTTPStatusError",
        #
        # Asynchronous programming errors
        "asyncio.CancelledError",
        "asyncio.InvalidStateError",
        "GeneratorExit",
        #
        # Third-party library exceptions
        "jinja2.exceptions.TemplateSyntaxError",
        "jinja2.exceptions.TemplateNotFound",
        "jsonschema.exceptions.ValidationError",
        "pydantic.ValidationError",
        "requests.exceptions.InvalidURL",
        "requests.exceptions.MissingSchema",
    ],
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/routers/docs/create_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ...activities.types import EmbedDocsPayload
from ...autogen.openapi_model import CreateDocRequest, ResourceCreatedResponse
from ...clients import temporal
from ...common.retry_policies import DEFAULT_RETRY_POLICY
from ...dependencies.developer_id import get_developer_id
from ...env import temporal_task_queue, testing
from ...models.docs.create_doc import create_doc as create_doc_query
Expand Down Expand Up @@ -41,6 +42,7 @@ async def run_embed_docs_task(
embed_payload,
task_queue=temporal_task_queue,
id=str(job_id),
retry_policy=DEFAULT_RETRY_POLICY,
)

# TODO: Remove this conditional once we have a way to run workflows in
Expand Down
3 changes: 3 additions & 0 deletions agents-api/agents_api/workflows/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from temporalio import workflow

from ..common.retry_policies import DEFAULT_RETRY_POLICY

with workflow.unsafe.imports_passed_through():
from ..activities.demo import demo_activity

Expand All @@ -14,4 +16,5 @@ async def run(self, a: int, b: int) -> int:
demo_activity,
args=[a, b],
start_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY,
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/embed_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
with workflow.unsafe.imports_passed_through():
from ..activities.embed_docs import embed_docs
from ..activities.types import EmbedDocsPayload
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -18,4 +19,5 @@ async def run(self, embed_payload: EmbedDocsPayload) -> None:
embed_docs,
embed_payload,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY,
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/mem_rating.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.mem_rating import mem_rating
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, memory: str) -> None:
mem_rating,
memory,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY,
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.summarization import summarization
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, session_id: str) -> None:
summarization,
session_id,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY,
)
8 changes: 8 additions & 0 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
StepContext,
StepOutcome,
)
from ...common.retry_policies import DEFAULT_RETRY_POLICY
from ...env import debug, testing
from .helpers import (
continue_as_child,
Expand All @@ -58,6 +59,7 @@
)
from .transition import transition


# Supported steps
# ---------------

Expand Down Expand Up @@ -204,6 +206,7 @@ async def run(
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY,
)
workflow.logger.debug(
f"Step {context.cursor.step} completed successfully"
Expand Down Expand Up @@ -389,6 +392,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, output],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
)

state = PartialTransition(type="resume", output=result)
Expand Down Expand Up @@ -421,6 +425,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, tool_calls_input],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
)

# Feed the tool call results back to the model
Expand All @@ -432,6 +437,7 @@ async def run(
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY,
)
state = PartialTransition(output=new_response.output, type="resume")

Expand Down Expand Up @@ -475,6 +481,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, tool_call],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
)

state = PartialTransition(output=tool_call_response, type="resume")
Expand Down Expand Up @@ -505,6 +512,7 @@ async def run(
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY,
)

state = PartialTransition(output=tool_call_response)
Expand Down
5 changes: 5 additions & 0 deletions agents-api/agents_api/workflows/task_execution/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from temporalio import workflow
from temporalio.exceptions import ApplicationError

from ...common.retry_policies import DEFAULT_RETRY_POLICY

with workflow.unsafe.imports_passed_through():
from ...activities import task_steps
from ...autogen.openapi_model import (
Expand Down Expand Up @@ -33,6 +35,7 @@ async def continue_as_child(
previous_inputs,
user_state,
],
retry_policy=DEFAULT_RETRY_POLICY,
)


Expand Down Expand Up @@ -169,6 +172,7 @@ async def execute_map_reduce_step(
task_steps.base_evaluate,
args=[reduce, {"results": result, "_": output}],
schedule_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY,
)

return result
Expand Down Expand Up @@ -244,6 +248,7 @@ async def execute_map_reduce_step_parallel(
extra_lambda_strs,
],
schedule_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY,
)

except BaseException as e:
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/task_execution/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TransitionTarget,
)
from ...common.protocol.tasks import PartialTransition, StepContext
from ...common.retry_policies import DEFAULT_RETRY_POLICY


async def transition(
Expand Down Expand Up @@ -44,6 +45,7 @@ async def transition(
task_steps.transition_step,
args=[context, transition_request],
schedule_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY,
)

except Exception as e:
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.truncation import truncation
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, session_id: str, token_count_threshold: int) -> None:
truncation,
args=[session_id, token_count_threshold],
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY,
)
11 changes: 9 additions & 2 deletions agents-api/tests/test_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@
from agents_api.clients import temporal
from agents_api.env import temporal_task_queue
from agents_api.workflows.demo import DemoWorkflow
from tests.fixtures import cozo_client, test_developer_id, test_doc
from tests.utils import patch_testing_temporal
from agents_api.workflows.task_execution.helpers import DEFAULT_RETRY_POLICY

from .fixtures import (
cozo_client,
test_developer_id,
test_doc,
)
from .utils import patch_testing_temporal


@test("activity: call direct embed_docs")
Expand Down Expand Up @@ -44,6 +50,7 @@ async def _():
args=[1, 2],
id=str(uuid4()),
task_queue=temporal_task_queue,
retry_policy=DEFAULT_RETRY_POLICY,
)

assert result == 3
Expand Down