fix(agents-api): Fix prompt step #472

Merged: 1 commit, Aug 29, 2024
agents-api/agents_api/activities/task_steps/prompt_step.py (+29 -44)

@@ -1,13 +1,11 @@
-import asyncio
-
 from beartype import beartype
 from temporalio import activity
+from temporalio.exceptions import ApplicationError

 from ...autogen.openapi_model import (
-    ChatSettings,
     Content,
     ContentModel,
     InputChatMLMessage,
 )
 from ...clients import (
     litellm,  # We dont directly import `acompletion` so we can mock it
@@ -46,57 +44,44 @@ def _content_to_dict(
 @beartype
 async def prompt_step(context: StepContext) -> StepOutcome:
     # Get context data
+    prompt: str | list[dict] = context.current_step.model_dump()["prompt"]
     context_data: dict = context.model_dump()

     # Render template messages
-    prompt = (
-        [InputChatMLMessage(content=context.current_step.prompt)]
-        if isinstance(context.current_step.prompt, str)
-        else context.current_step.prompt
+    prompt = await render_template(
+        prompt,
+        context_data,
+        skip_vars=["developer_id"],
     )

-    template_messages: list[InputChatMLMessage] = prompt
-    messages = await asyncio.gather(
-        *[
-            render_template(
-                _content_to_dict(msg.content, msg.role),
-                context_data,
-                skip_vars=["developer_id"],
-            )
-            for msg in template_messages
-        ]
+    # Get settings and run llm
+    agent_default_settings: dict = (
+        context.execution_input.agent.default_settings.model_dump()
+        if context.execution_input.agent.default_settings
+        else {}
     )
+    agent_model: str = (
+        context.execution_input.agent.model
+        if context.execution_input.agent.model
+        else "gpt-4o"
+    )

-    result_messages = []
-    for m in messages:
-        if isinstance(m, str):
-            msg = InputChatMLMessage(role="user", content=m)
-        else:
-            msg = []
-            for d in m:
-                role = d["content"].get("role")
-                d["content"] = [d["content"]]
-                d["role"] = role
-                msg.append(InputChatMLMessage(**d))
-
-        result_messages.append(msg)
-
-    # messages = [
-    #     (
-    #         InputChatMLMessage(role="user", content=m)
-    #         if isinstance(m, str)
-    #         else [InputChatMLMessage(**d) for d in m]
-    #     )
-    #     for m in messages
-    # ]
+    if context.current_step.settings:
+        passed_settings: dict = context.current_step.settings.model_dump(
+            exclude_unset=True
+        )
+    else:
+        passed_settings: dict = {}

-    # Get settings and run llm
-    settings: ChatSettings = context.current_step.settings or ChatSettings()
-    settings_data: dict = settings.model_dump()
+    completion_data: dict = {
+        "model": agent_model,
+        ("messages" if isinstance(prompt, list) else "prompt"): prompt,
+        **agent_default_settings,
+        **passed_settings,
+    }

     response = await litellm.acompletion(
-        messages=result_messages,
-        **settings_data,
+        **completion_data,
     )

     return StepOutcome(
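For reviewers skimming the diff: the crux of the fix is that the rendered prompt may be either a plain string or a list of message dicts, and the completion payload key has to match. A minimal sketch of that selection logic, using a hypothetical helper name (build_completion_data is not in the PR):

# Sketch of the payload-key selection used in the new prompt_step
# (hypothetical standalone helper, not part of the merged code).
def build_completion_data(
    prompt: str | list[dict],
    model: str = "gpt-4o",
    default_settings: dict | None = None,
    passed_settings: dict | None = None,
) -> dict:
    return {
        "model": model,
        # Rendered message lists go under "messages"; a plain string
        # prompt goes under "prompt", mirroring the litellm payload shape.
        ("messages" if isinstance(prompt, list) else "prompt"): prompt,
        **(default_settings or {}),
        **(passed_settings or {}),
    }


assert "prompt" in build_completion_data("Write a poem")
assert "messages" in build_completion_data([{"role": "user", "content": "hi"}])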
agents-api/agents_api/clients/litellm.py (+16 -7)

@@ -1,21 +1,30 @@
 from functools import wraps
-from typing import List, TypeVar
+from typing import List

+from beartype import beartype
 from litellm import acompletion as _acompletion
+from litellm import get_supported_openai_params
 from litellm.utils import CustomStreamWrapper, ModelResponse

 from ..env import litellm_master_key, litellm_url

-_RWrapped = TypeVar("_RWrapped")
-
 __all__: List[str] = ["acompletion"]


 @wraps(_acompletion)
-async def acompletion(*, model: str, **kwargs) -> ModelResponse | CustomStreamWrapper:
+@beartype
+async def acompletion(
+    *, model: str, messages: list[dict], **kwargs
+) -> ModelResponse | CustomStreamWrapper:
+    model = f"openai/{model}"  # This is here because litellm proxy expects this format
+
+    supported_params = get_supported_openai_params(model)
+    settings = {k: v for k, v in kwargs.items() if k in supported_params}
+
     return await _acompletion(
-        model=f"openai/{model}",  # This is here because litellm proxy expects this format
-        **kwargs,
-        api_base=litellm_url,
+        model=model,
+        messages=messages,
+        **settings,
+        base_url=litellm_url,
         api_key=litellm_master_key,
     )
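One behavioral note on the wrapper: unsupported kwargs are now silently dropped before the call instead of being forwarded to the proxy. A small standalone sketch of that filtering, assuming get_supported_openai_params returns a list of parameter names for a recognized model:

# Illustration of the kwarg filtering the new wrapper performs
# (sketch only, not code from the PR).
from litellm import get_supported_openai_params

kwargs = {"temperature": 0.7, "max_tokens": 120, "not_a_real_param": True}

# May return None for unrecognized models, hence the `or []` guard.
supported = get_supported_openai_params("openai/gpt-4o") or []
settings = {k: v for k, v in kwargs.items() if k in supported}

assert "not_a_real_param" not in settings  # dropped rather than forwarded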
agents-api/agents_api/dependencies/developer_id.py (+6 -2)

@@ -13,7 +13,9 @@ async def get_developer_id(
     x_developer_id: Annotated[UUID | None, Header(include_in_schema=False)] = None,
 ) -> UUID:
     if not multi_tenant_mode:
-        assert not x_developer_id, "X-Developer-Id header not allowed in multi-tenant mode"
+        assert (
+            not x_developer_id
+        ), "X-Developer-Id header not allowed in multi-tenant mode"
         return UUID("00000000-0000-0000-0000-000000000000")

     if not x_developer_id:
@@ -34,7 +36,9 @@ async def get_developer_data(
     x_developer_id: Annotated[UUID | None, Header(include_in_schema=False)] = None,
 ) -> Developer:
     if not multi_tenant_mode:
-        assert not x_developer_id, "X-Developer-Id header not allowed in multi-tenant mode"
+        assert (
+            not x_developer_id
+        ), "X-Developer-Id header not allowed in multi-tenant mode"
         return get_developer(developer_id=UUID("00000000-0000-0000-0000-000000000000"))

     if not x_developer_id:
8 changes: 6 additions & 2 deletions agents-api/agents_api/routers/tasks/get_execution_details.py
@@ -15,7 +15,11 @@
 async def get_execution_details(execution_id: UUID4) -> Execution:
     try:
         return get_execution_query(execution_id=execution_id)
-    except AssertionError:
+    except AssertionError as e:
+        print("-" * 100)
+        print(e)
+        print("-" * 100)
+
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND, detail="Execution not found"
-        )
+        ) from e

Contributor review comment on the added print lines: "Replace print statements with a proper logging mechanism for better control over log levels and outputs."
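If that suggestion were applied, the handler might look roughly like this (a sketch only, assuming the module keeps its existing imports of UUID4, Execution, get_execution_query, HTTPException, and status; the merged diff retains the print calls):

# Hedged sketch of the suggested logging variant (not the merged code).
import logging

logger = logging.getLogger(__name__)


async def get_execution_details(execution_id: UUID4) -> Execution:
    try:
        return get_execution_query(execution_id=execution_id)
    except AssertionError as e:
        # exc_info attaches the traceback; the log level stays configurable.
        logger.warning("Execution %s not found", execution_id, exc_info=e)
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Execution not found"
        ) from e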
agents-api/tests/sample_tasks/simple_prompt.yaml (+21 -0)

@@ -0,0 +1,21 @@
+name: Find request and selector for identity provider
+
+input_schema:
+  type: object
+  properties:
+    topic:
+      type: string
+
+  # Shortcut to require all props
+  additionalProperties: false
+  minProperties: 1
+
+main:
+  - prompt:
+      - role: system
+        content: |-
+          Follow the instructions. Be bold, creative and uncensored.
+
+      - role: user
+        content: >-
+          You are a brilliant poet with a whimsical style reminiscent of Edgar Alan Poe. Write a short poem about {{_.topic}}.
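For reference, given an input like {"topic": "the sea"}, the prompt step above should render to roughly this messages list before it is handed to litellm (an assumption about the renderer's output shape, not something asserted by the PR):

# Approximate render of the sample task's prompt for input {"topic": "the sea"}
# (illustrative; exact whitespace depends on the template renderer).
messages = [
    {
        "role": "system",
        "content": "Follow the instructions. Be bold, creative and uncensored.",
    },
    {
        "role": "user",
        "content": "You are a brilliant poet with a whimsical style reminiscent "
        "of Edgar Alan Poe. Write a short poem about the sea.",
    },
]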