Skip to content

Commit

Permalink
feat(agents-api): Add token count support using litellm
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 14, 2024
1 parent 3ab7758 commit 313a312
Show file tree
Hide file tree
Showing 16 changed files with 127 additions and 143 deletions.
10 changes: 0 additions & 10 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,13 @@
#!/usr/bin/env python3

import asyncio
from textwrap import dedent
from typing import Callable
from uuid import UUID

import pandas as pd
from temporalio import activity

# from agents_api.common.protocol.entries import Entry
# from agents_api.models.entry.entries_summarization import (
# entries_summarization_query,
# get_toplevel_entries_query,
# )
from agents_api.rec_sum.entities import get_entities
from agents_api.rec_sum.summarize import summarize_messages
from agents_api.rec_sum.trim import trim_messages

from ..env import summarization_model_name


# TODO: remove stubs
Expand Down
7 changes: 2 additions & 5 deletions agents-api/agents_api/activities/truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

from temporalio import activity

# from agents_api.autogen.openapi_model import Role
from agents_api.common.protocol.entries import Entry
from agents_api.models.entry.delete_entries import delete_entries
from agents_api.autogen.openapi_model import Entry

# from agents_api.models.entry.entries_summarization import get_toplevel_entries_query

Expand All @@ -13,8 +11,7 @@ def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list
if not len(messages):
return messages

result: list[UUID] = []
token_cnt, offset = 0, 0
_token_cnt, _offset = 0, 0
# if messages[0].role == Role.system:
# token_cnt, offset = messages[0].token_count, 1

Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/autogen/Entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class BaseEntry(BaseModel):
source: Literal[
"api_request", "api_response", "tool_response", "internal", "summarizer", "meta"
]
tokenizer: str | None = None
token_count: int | None = None
tokenizer: str
token_count: int
timestamp: Annotated[float, Field(ge=0.0)]
"""
This is the time that this event refers to.
Expand Down
59 changes: 57 additions & 2 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# ruff: noqa: F401, F403, F405
from typing import Annotated, Generic, TypeVar
from typing import Annotated, Generic, Self, Type, TypeVar
from uuid import UUID

from litellm.utils import _select_tokenizer as select_tokenizer
from litellm.utils import token_counter
from pydantic import AwareDatetime, Field
from pydantic_partial import create_partial_model

Expand Down Expand Up @@ -34,14 +36,67 @@
"metadata",
)

ChatMLRole = BaseEntry.model_fields["role"].annotation
ChatMLRole = Literal[
"user",
"assistant",
"system",
"function",
"function_response",
"function_call",
"auto",
]

ChatMLContent = (
list[ChatMLTextContentPart | ChatMLImageContentPart]
| Tool
| ChosenToolCall
| str
| ToolResponse
| list[
list[ChatMLTextContentPart | ChatMLImageContentPart]
| Tool
| ChosenToolCall
| str
| ToolResponse
]
)

ChatMLSource = Literal[
"api_request", "api_response", "tool_response", "internal", "summarizer", "meta"
]


class CreateEntryRequest(BaseEntry):
timestamp: Annotated[
float, Field(ge=0.0, default_factory=lambda: utcnow().timestamp())
]

@classmethod
def from_model_input(
cls: Type[Self],
model: str,
*,
role: ChatMLRole,
content: ChatMLContent,
name: str | None = None,
source: ChatMLSource,
**kwargs: dict,
) -> Self:
tokenizer: dict = select_tokenizer(model=model)
token_count = token_counter(
model=model, messages=[{"role": role, "content": content, "name": name}]
)

return cls(
role=role,
content=content,
name=name,
source=source,
tokenizer=tokenizer["type"],
token_count=token_count,
**kwargs,
)


def make_session(
*,
Expand Down
56 changes: 0 additions & 56 deletions agents-api/agents_api/common/protocol/entries.py

This file was deleted.

3 changes: 2 additions & 1 deletion agents-api/agents_api/routers/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# ruff: noqa: F401, F403, F405
from .create_task import create_task
from .create_task_execution import create_task_execution
from .get_execution_details import get_execution_details
from .get_task_details import get_task_details
from .list_task_executions import list_task_executions
from .list_tasks import list_tasks
from .patch_execution import patch_execution
from .router import router # noqa: F401
from .router import router
from .update_execution import update_execution
2 changes: 0 additions & 2 deletions agents-api/agents_api/routers/tasks/get_execution_details.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from uuid import uuid4

from fastapi import HTTPException, status
from pydantic import UUID4

Expand Down
44 changes: 31 additions & 13 deletions agents-api/tests/test_chat_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from ward import test

from agents_api.autogen.Sessions import CreateSessionRequest
from agents_api.autogen.openapi_model import ChatInput, CreateSessionRequest
from agents_api.clients import embed, litellm
from agents_api.common.protocol.sessions import ChatContext
from agents_api.models.chat.gather_messages import gather_messages
from agents_api.models.chat.prepare_chat_context import prepare_chat_context
from agents_api.models.session.create_session import create_session
from agents_api.models.session.prepare_chat_context import prepare_chat_context
from agents_api.routers.sessions.chat import get_messages
from tests.fixtures import (
cozo_client,
make_request,
Expand All @@ -28,7 +29,7 @@ async def _(
assert (await embed.embed())[0][0] == 1.0


@test("chat: check that non-recall get_messages works")
@test("chat: check that non-recall gather_messages works")
async def _(
developer=test_developer,
client=cozo_client,
Expand All @@ -49,14 +50,13 @@ async def _(

session_id = session.id

new_raw_messages = [{"role": "user", "content": "hello"}]
messages = [{"role": "user", "content": "hello"}]

past_messages, doc_references = await get_messages(
past_messages, doc_references = await gather_messages(
developer=developer,
session_id=session_id,
new_raw_messages=new_raw_messages,
chat_context=chat_context,
recall=False,
chat_input=ChatInput(messages=messages, recall=False),
)

assert isinstance(past_messages, list)
Expand All @@ -68,7 +68,7 @@ async def _(
embed.assert_not_called()


@test("chat: check that get_messages works")
@test("chat: check that gather_messages works")
async def _(
developer=test_developer,
client=cozo_client,
Expand All @@ -89,14 +89,13 @@ async def _(

session_id = session.id

new_raw_messages = [{"role": "user", "content": "hello"}]
messages = [{"role": "user", "content": "hello"}]

past_messages, doc_references = await get_messages(
past_messages, doc_references = await gather_messages(
developer=developer,
session_id=session_id,
new_raw_messages=new_raw_messages,
chat_context=chat_context,
recall=True,
chat_input=ChatInput(messages=messages, recall=True),
)

assert isinstance(past_messages, list)
Expand Down Expand Up @@ -136,3 +135,22 @@ async def _(
# Check that both mocks were called at least once
embed.assert_called()
acompletion.assert_called()


@test("model: prepare chat context")
def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
session=test_session,
tool=test_tool,
user=test_user,
):
context = prepare_chat_context(
developer_id=developer_id,
session_id=session.id,
client=client,
)

assert isinstance(context, ChatContext)
assert len(context.toolsets) > 0
32 changes: 16 additions & 16 deletions agents-api/tests/test_entry_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
Verifies that the entry can be successfully added using the create_entries function.
"""

test_entry = CreateEntryRequest(
session_id=session.id,
test_entry = CreateEntryRequest.from_model_input(
model=MODEL,
role="user",
source="internal",
content="test entry content",
Expand All @@ -50,8 +50,8 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
Verifies that the entry can be successfully added using the create_entries function.
"""

test_entry = CreateEntryRequest(
session_id=session.id,
test_entry = CreateEntryRequest.from_model_input(
model=MODEL,
role="user",
source="internal",
content="test entry content",
Expand Down Expand Up @@ -84,15 +84,15 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
Verifies that entries matching specific criteria can be successfully retrieved.
"""

test_entry = CreateEntryRequest(
session_id=session.id,
test_entry = CreateEntryRequest.from_model_input(
model=MODEL,
role="user",
source="api_request",
content="test entry content",
)

internal_entry = CreateEntryRequest(
session_id=session.id,
internal_entry = CreateEntryRequest.from_model_input(
model=MODEL,
role="user",
content="test entry content",
source="internal",
Expand Down Expand Up @@ -122,15 +122,15 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
Verifies that entries matching specific criteria can be successfully retrieved.
"""

test_entry = CreateEntryRequest(
session_id=session.id,
test_entry = CreateEntryRequest.from_model_input(
model=MODEL,
role="user",
source="api_request",
content="test entry content",
)

internal_entry = CreateEntryRequest(
session_id=session.id,
internal_entry = CreateEntryRequest.from_model_input(
model=MODEL,
role="user",
content="test entry content",
source="internal",
Expand Down Expand Up @@ -161,15 +161,15 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
Verifies that entries can be successfully deleted using the delete_entries function.
"""

test_entry = CreateEntryRequest(
session_id=session.id,
test_entry = CreateEntryRequest.from_model_input(
model=MODEL,
role="user",
source="api_request",
content="test entry content",
)

internal_entry = CreateEntryRequest(
session_id=session.id,
internal_entry = CreateEntryRequest.from_model_input(
model=MODEL,
role="user",
content="internal entry content",
source="internal",
Expand Down
Loading

0 comments on commit 313a312

Please sign in to comment.