Skip to content

Commit

Permalink
feat(agents-api): Add retry policies to temporal workflows/activities (
Browse files Browse the repository at this point in the history
…#551)

<!-- ELLIPSIS_HIDDEN -->


> [!IMPORTANT]
> Introduces `DEFAULT_RETRY_POLICY` for consistent retry behavior in
Temporal workflows and activities, updating workflows, activities, and
tests accordingly.
> 
> - **Retry Policy**:
>   - Introduces `DEFAULT_RETRY_POLICY` in `retry_policies.py` with
>     specific retry configurations.
>   - Applies `DEFAULT_RETRY_POLICY` to `run_task_execution_workflow()` in
>     `temporal.py` and `run_embed_docs_task()` in `create_doc.py`.
> - **Workflows**:
>   - Adds `retry_policy=DEFAULT_RETRY_POLICY` to `DemoWorkflow`,
>     `EmbedDocsWorkflow`, `MemRatingWorkflow`, `SummarizationWorkflow`,
>     `TruncationWorkflow`.
>   - Updates `TaskExecutionWorkflow` in `task_execution/__init__.py` to
>     use `DEFAULT_RETRY_POLICY` for activities.
> - **Activities**:
>   - Updates `raise_complete_async()` in `raise_complete_async.py` to use
>     consistent string formatting.
>   - Updates `transition()` in `transition.py` to use
>     `DEFAULT_RETRY_POLICY`.
> - **Tests**:
>   - Updates `test_activities.py` to use `DEFAULT_RETRY_POLICY` in
>     workflow execution tests.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=julep-ai%2Fjulep&utm_source=github&utm_medium=referral)<sup>
for 2d945d3. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->

---------

Signed-off-by: Diwank Singh Tomer <[email protected]>
Co-authored-by: Diwank Singh Tomer <[email protected]>
Co-authored-by: creatorrr <[email protected]>
  • Loading branch information
3 people authored Oct 5, 2024
1 parent 8033549 commit 2e0108e
Show file tree
Hide file tree
Showing 12 changed files with 102 additions and 2 deletions.
2 changes: 2 additions & 0 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ..autogen.openapi_model import TransitionTarget
from ..common.protocol.tasks import ExecutionInput
from ..common.retry_policies import DEFAULT_RETRY_POLICY
from ..env import (
temporal_client_cert,
temporal_namespace,
Expand Down Expand Up @@ -54,6 +55,7 @@ async def run_task_execution_workflow(
task_queue=temporal_task_queue,
id=str(job_id),
run_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
# TODO: Should add search_attributes for queryability
)

Expand Down
63 changes: 63 additions & 0 deletions agents-api/agents_api/common/retry_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from datetime import timedelta

from temporalio.common import RetryPolicy

# Shared retry policy for Temporal workflows and activities in this service.
#
# Backoff: the first retry waits 1 second and each subsequent wait doubles,
# capped at 300 seconds (5 minutes), for at most 25 attempts in total.
#
# `non_retryable_error_types` lists failures that are treated as permanent
# (programming errors, validation errors, malformed input) so they fail fast
# instead of burning through all attempts.
#
# NOTE: Temporal compares each entry against the failure's `type` string by
# exact match; class inheritance is not considered. The Python SDK reports an
# uncaught exception using its *bare* class name (`type(exc).__name__`), so a
# module-qualified entry such as "pydantic.ValidationError" does not match an
# exception raised by that library. Each qualified entry is therefore paired
# with its bare class name. The qualified spellings are kept for code that
# raises `ApplicationError(type="<qualified name>")` explicitly.
DEFAULT_RETRY_POLICY = RetryPolicy(
    initial_interval=timedelta(seconds=1),
    backoff_coefficient=2,
    maximum_attempts=25,
    maximum_interval=timedelta(seconds=300),
    non_retryable_error_types=[
        # Temporal-specific errors
        "WorkflowExecutionAlreadyStarted",
        "temporalio.exceptions.TerminalFailure",
        "TerminalFailure",
        "temporalio.exceptions.CanceledError",
        "CanceledError",
        #
        # Built-in Python exceptions
        "TypeError",
        "AssertionError",
        "SyntaxError",
        "ValueError",
        "ZeroDivisionError",
        "IndexError",
        "AttributeError",
        "LookupError",
        "BufferError",
        "ArithmeticError",
        "KeyError",
        "NameError",
        "NotImplementedError",
        "RecursionError",
        "RuntimeError",
        "StopIteration",
        "StopAsyncIteration",
        "IndentationError",
        "TabError",
        #
        # Unicode-related errors
        "UnicodeError",
        "UnicodeEncodeError",
        "UnicodeDecodeError",
        "UnicodeTranslateError",
        #
        # HTTP and API-related errors
        "HTTPException",
        "fastapi.exceptions.HTTPException",
        "fastapi.exceptions.RequestValidationError",
        "RequestValidationError",
        "httpx.RequestError",
        "RequestError",
        "httpx.HTTPStatusError",
        "HTTPStatusError",
        #
        # Asynchronous programming errors
        "asyncio.CancelledError",
        "CancelledError",
        "asyncio.InvalidStateError",
        "InvalidStateError",
        "GeneratorExit",
        #
        # Third-party library exceptions
        "jinja2.exceptions.TemplateSyntaxError",
        "TemplateSyntaxError",
        "jinja2.exceptions.TemplateNotFound",
        "TemplateNotFound",
        "jsonschema.exceptions.ValidationError",
        "pydantic.ValidationError",
        # Bare name shared by jsonschema's and pydantic's ValidationError.
        "ValidationError",
        "requests.exceptions.InvalidURL",
        "InvalidURL",
        "requests.exceptions.MissingSchema",
        "MissingSchema",
    ],
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/routers/docs/create_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ...activities.types import EmbedDocsPayload
from ...autogen.openapi_model import CreateDocRequest, ResourceCreatedResponse
from ...clients import temporal
from ...common.retry_policies import DEFAULT_RETRY_POLICY
from ...dependencies.developer_id import get_developer_id
from ...env import temporal_task_queue, testing
from ...models.docs.create_doc import create_doc as create_doc_query
Expand Down Expand Up @@ -41,6 +42,7 @@ async def run_embed_docs_task(
embed_payload,
task_queue=temporal_task_queue,
id=str(job_id),
retry_policy=DEFAULT_RETRY_POLICY,
)

# TODO: Remove this conditional once we have a way to run workflows in
Expand Down
3 changes: 3 additions & 0 deletions agents-api/agents_api/workflows/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from temporalio import workflow

from ..common.retry_policies import DEFAULT_RETRY_POLICY

with workflow.unsafe.imports_passed_through():
from ..activities.demo import demo_activity

Expand All @@ -14,4 +16,5 @@ async def run(self, a: int, b: int) -> int:
demo_activity,
args=[a, b],
start_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY,
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/embed_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
with workflow.unsafe.imports_passed_through():
from ..activities.embed_docs import embed_docs
from ..activities.types import EmbedDocsPayload
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -18,4 +19,5 @@ async def run(self, embed_payload: EmbedDocsPayload) -> None:
embed_docs,
embed_payload,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY,
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/mem_rating.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.mem_rating import mem_rating
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, memory: str) -> None:
mem_rating,
memory,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY,
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.summarization import summarization
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, session_id: str) -> None:
summarization,
session_id,
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY,
)
8 changes: 8 additions & 0 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
StepContext,
StepOutcome,
)
from ...common.retry_policies import DEFAULT_RETRY_POLICY
from ...env import debug, testing
from .helpers import (
continue_as_child,
Expand All @@ -58,6 +59,7 @@
)
from .transition import transition


# Supported steps
# ---------------

Expand Down Expand Up @@ -204,6 +206,7 @@ async def run(
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY,
)
workflow.logger.debug(
f"Step {context.cursor.step} completed successfully"
Expand Down Expand Up @@ -389,6 +392,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, output],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
)

state = PartialTransition(type="resume", output=result)
Expand Down Expand Up @@ -421,6 +425,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, tool_calls_input],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
)

# Feed the tool call results back to the model
Expand All @@ -432,6 +437,7 @@ async def run(
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY,
)
state = PartialTransition(output=new_response.output, type="resume")

Expand Down Expand Up @@ -475,6 +481,7 @@ async def run(
task_steps.raise_complete_async,
args=[context, tool_call],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
)

state = PartialTransition(output=tool_call_response, type="resume")
Expand Down Expand Up @@ -505,6 +512,7 @@ async def run(
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY,
)

state = PartialTransition(output=tool_call_response)
Expand Down
5 changes: 5 additions & 0 deletions agents-api/agents_api/workflows/task_execution/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from temporalio import workflow
from temporalio.exceptions import ApplicationError

from ...common.retry_policies import DEFAULT_RETRY_POLICY

with workflow.unsafe.imports_passed_through():
from ...activities import task_steps
from ...autogen.openapi_model import (
Expand Down Expand Up @@ -33,6 +35,7 @@ async def continue_as_child(
previous_inputs,
user_state,
],
retry_policy=DEFAULT_RETRY_POLICY,
)


Expand Down Expand Up @@ -169,6 +172,7 @@ async def execute_map_reduce_step(
task_steps.base_evaluate,
args=[reduce, {"results": result, "_": output}],
schedule_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY,
)

return result
Expand Down Expand Up @@ -244,6 +248,7 @@ async def execute_map_reduce_step_parallel(
extra_lambda_strs,
],
schedule_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY,
)

except BaseException as e:
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/task_execution/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TransitionTarget,
)
from ...common.protocol.tasks import PartialTransition, StepContext
from ...common.retry_policies import DEFAULT_RETRY_POLICY


async def transition(
Expand Down Expand Up @@ -44,6 +45,7 @@ async def transition(
task_steps.transition_step,
args=[context, transition_request],
schedule_to_close_timeout=timedelta(seconds=30),
retry_policy=DEFAULT_RETRY_POLICY,
)

except Exception as e:
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

with workflow.unsafe.imports_passed_through():
from ..activities.truncation import truncation
from ..common.retry_policies import DEFAULT_RETRY_POLICY


@workflow.defn
Expand All @@ -17,4 +18,5 @@ async def run(self, session_id: str, token_count_threshold: int) -> None:
truncation,
args=[session_id, token_count_threshold],
schedule_to_close_timeout=timedelta(seconds=600),
retry_policy=DEFAULT_RETRY_POLICY,
)
11 changes: 9 additions & 2 deletions agents-api/tests/test_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@
from agents_api.clients import temporal
from agents_api.env import temporal_task_queue
from agents_api.workflows.demo import DemoWorkflow
from tests.fixtures import cozo_client, test_developer_id, test_doc
from tests.utils import patch_testing_temporal
from agents_api.workflows.task_execution.helpers import DEFAULT_RETRY_POLICY

from .fixtures import (
cozo_client,
test_developer_id,
test_doc,
)
from .utils import patch_testing_temporal


@test("activity: call direct embed_docs")
Expand Down Expand Up @@ -44,6 +50,7 @@ async def _():
args=[1, 2],
id=str(uuid4()),
task_queue=temporal_task_queue,
retry_policy=DEFAULT_RETRY_POLICY,
)

assert result == 3
Expand Down

0 comments on commit 2e0108e

Please sign in to comment.