Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Make chat route tests pass #454

Merged
merged 3 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions agents-api/agents_api/autogen/Chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .Common import LogitBias
from .Docs import DocReference
from .Entries import ChatMLMessage, InputChatMLMessage
from .Entries import InputChatMLMessage
from .Tools import FunctionTool, NamedToolChoice


Expand All @@ -23,7 +23,7 @@ class BaseChatOutput(BaseModel):
"""
The reason the model stopped generating tokens
"""
logprobs: Annotated[LogProbResponse | None, Field(...)]
logprobs: LogProbResponse | None = None
"""
The log probabilities of tokens
"""
Expand All @@ -33,7 +33,7 @@ class BaseChatResponse(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
usage: Annotated[CompetionUsage | None, Field(...)]
usage: CompetionUsage | None = None
"""
Usage statistics for the completion request
"""
Expand Down Expand Up @@ -61,7 +61,7 @@ class BaseTokenLogProb(BaseModel):
"""
The log probability of the token
"""
bytes: Annotated[list[int] | None, Field(...)]
bytes: list[int] | None = None


class ChatInputData(BaseModel):
Expand Down Expand Up @@ -90,7 +90,7 @@ class ChatOutputChunk(BaseChatOutput):
model_config = ConfigDict(
populate_by_name=True,
)
delta: ChatMLMessage
delta: InputChatMLMessage
"""
The message generated by the model
"""
Expand Down Expand Up @@ -166,7 +166,7 @@ class MultipleChatOutput(BaseChatOutput):
model_config = ConfigDict(
populate_by_name=True,
)
messages: list[ChatMLMessage]
messages: list[InputChatMLMessage]


class OpenAISettings(BaseModel):
Expand Down Expand Up @@ -199,7 +199,7 @@ class SingleChatOutput(BaseChatOutput):
model_config = ConfigDict(
populate_by_name=True,
)
message: ChatMLMessage
message: InputChatMLMessage


class TokenLogProb(BaseTokenLogProb):
Expand Down
41 changes: 2 additions & 39 deletions agents-api/agents_api/autogen/Entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class BaseEntry(BaseModel):
)
role: Literal[
"user",
"agent",
"assistant",
"system",
"function",
"function_response",
Expand Down Expand Up @@ -67,43 +67,6 @@ class ChatMLImageContentPart(BaseModel):
"""


class ChatMLMessage(BaseModel):
    """A fully-materialized ChatML message as stored/returned by the API.

    Unlike the input variant, this model carries server-assigned, read-only
    metadata (``tool_calls`` default, ``created_at``, ``id``).
    NOTE(review): this appears to be generated code (pydantic v2 style);
    field order is schema-significant, so the declarations are left untouched.
    """

    # Allow population by field name as well as by alias.
    model_config = ConfigDict(
        populate_by_name=True,
    )
    # NOTE(review): "agent" here vs "assistant" in InputChatMLMessage — the
    # two role vocabularies look intentionally divergent in this revision;
    # confirm against the OpenAPI spec this file is generated from.
    role: Literal[
        "user",
        "agent",
        "system",
        "function",
        "function_response",
        "function_call",
        "auto",
    ]
    """
    The role of the message
    """
    content: str | list[str] | list[ChatMLTextContentPart | ChatMLImageContentPart]
    """
    The content parts of the message
    """
    name: str | None = None
    """
    Name
    """
    # readOnly: produced by the server/model, defaults to an empty list.
    tool_calls: Annotated[
        list[ChosenToolCall], Field([], json_schema_extra={"readOnly": True})
    ]
    """
    Tool calls generated by the model.
    """
    # readOnly: set by the server at creation time.
    created_at: Annotated[AwareDatetime, Field(json_schema_extra={"readOnly": True})]
    """
    When this resource was created as UTC date-time
    """
    # readOnly: server-assigned identifier.
    id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})]


class ChatMLTextContentPart(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
Expand Down Expand Up @@ -159,7 +122,7 @@ class InputChatMLMessage(BaseModel):
)
role: Literal[
"user",
"agent",
"assistant",
"system",
"function",
"function_response",
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_active_agent(self) -> Agent:
"""
Get the active agent from the session data.
"""
requested_agent: UUID | None = self.settings.agent
requested_agent: UUID | None = self.settings and self.settings.agent

if requested_agent:
assert requested_agent in [agent.id for agent in self.agents], (
Expand All @@ -67,15 +67,15 @@ def get_active_agent(self) -> Agent:
return self.agents[0]

def merge_settings(self, chat_input: ChatInput) -> ChatSettings:
request_settings = ChatSettings.model_validate(chat_input)
request_settings = chat_input.model_dump(exclude_unset=True)
active_agent = self.get_active_agent()
default_settings = active_agent.default_settings

self.settings = settings = ChatSettings(
**{
"model": active_agent.model,
**default_settings.model_dump(),
**request_settings.model_dump(exclude_unset=True),
**request_settings,
}
)

Expand Down
39 changes: 34 additions & 5 deletions agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,28 @@ async def render_template_string(
return rendered


async def render_template_chatml(
    messages: list[dict], variables: dict, check: bool = False
) -> list[dict]:
    """Render the jinja ``content`` field of each chatml message dict.

    Returns new dicts (shallow copies); the input messages are not mutated.
    When *check* is true, *variables* is first validated against the variable
    schema inferred from every template, raising on a mismatch.
    """
    # Compile one jinja template per message from its "content" field.
    # FIXME: should template_strings contain a list of ChatMLTextContentPart? Should we handle it somehow?
    compiled = [jinja_env.from_string(message["content"]) for message in messages]

    # Optionally verify that the supplied variables satisfy each template's
    # inferred requirements before rendering anything.
    if check:
        for tpl in compiled:
            validate(instance=variables, schema=to_json_schema(infer(tpl)))

    # Produce a copy of every message with its content rendered.
    rendered_messages: list[dict] = []
    for tpl, message in zip(compiled, messages):
        rendered_messages.append(
            {**message, "content": await tpl.render_async(**variables)}
        )

    return rendered_messages


async def render_template_parts(
template_strings: list[dict], variables: dict, check: bool = False
) -> list[dict]:
Expand Down Expand Up @@ -73,7 +95,7 @@ async def render_template_parts(


async def render_template(
template_string: str | list[dict],
input: str | list[dict],
variables: dict,
check: bool = False,
skip_vars: list[str] | None = None,
Expand All @@ -83,8 +105,15 @@ async def render_template(
for name, val in variables.items()
if not (skip_vars is not None and isinstance(name, str) and name in skip_vars)
}
if isinstance(template_string, str):
return await render_template_string(template_string, variables, check)

elif isinstance(template_string, list):
return await render_template_parts(template_string, variables, check)
match input:
case str():
future = render_template_string(input, variables, check)

case [{"content": str()}, *_]:
future = render_template_chatml(input, variables, check)

case _:
future = render_template_parts(input, variables, check)

return await future
3 changes: 3 additions & 0 deletions agents-api/agents_api/models/docs/search_docs_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def dbsf_normalize(scores: list[float]) -> list[float]:
Scores scaled using minmax scaler with our custom feature range
(extremes indicated as 3 standard deviations from the mean)
"""
if len(scores) < 2:
return scores

sd = stdev(scores)
if sd == 0:
return scores
Expand Down
105 changes: 63 additions & 42 deletions agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from ...autogen.openapi_model import (
ChatInput,
ChatResponse,
ChunkChatResponse,
CreateEntryRequest,
DocReference,
History,
MessageChatResponse,
)
from ...clients.embed import embed
from ...clients.litellm import acompletion
from ...clients import embed, litellm
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ...common.utils.datetime import utcnow
from ...common.utils.template import render_template
from ...dependencies.developer_id import get_developer_data
from ...models.docs.search_docs_hybrid import search_docs_hybrid
Expand All @@ -24,28 +26,14 @@
from .router import router


@router.post(
"/sessions/{session_id}/chat",
status_code=HTTP_201_CREATED,
tags=["sessions", "chat"],
)
async def chat(
developer: Annotated[Developer, Depends(get_developer_data)],
async def get_messages(
*,
developer: Developer,
session_id: UUID,
data: ChatInput,
background_tasks: BackgroundTasks,
) -> ChatResponse:
# First get the chat context
chat_context: ChatContext = prepare_chat_context(
developer_id=developer.id,
session_id=session_id,
)
assert isinstance(chat_context, ChatContext)

# Merge the settings and prepare environment
chat_context.merge_settings(data)
settings: dict = chat_context.settings.model_dump()
env: dict = chat_context.get_chat_environment()
new_raw_messages: list[dict],
chat_context: ChatContext,
):
assert len(new_raw_messages) > 0

# Get the session history
history: History = get_history(
Expand All @@ -62,10 +50,8 @@ async def chat(
if entry.id not in {r.head for r in relations}
]

new_raw_messages = [msg.model_dump() for msg in data.messages]

# Search matching docs
[query_embedding, *_] = await embed(
[query_embedding, *_] = await embed.embed(
inputs=[
f"{msg.get('name') or msg['role']}: {msg['content']}"
for msg in new_raw_messages
Expand All @@ -82,39 +68,74 @@ async def chat(
query_embedding=query_embedding,
)

return past_messages, doc_references


@router.post(
    "/sessions/{session_id}/chat",
    status_code=HTTP_201_CREATED,
    tags=["sessions", "chat"],
)
async def chat(
    developer: Annotated[Developer, Depends(get_developer_data)],
    session_id: UUID,
    data: ChatInput,
    background_tasks: BackgroundTasks,
) -> ChatResponse:
    """Chat with the session's active agent.

    Merges the request settings with the session/agent defaults, renders the
    incoming messages against the chat environment (history + matching docs),
    calls the model, optionally persists the exchange, and returns a
    chunk/message chat response depending on ``data.stream``.

    NOTE(review): this span contained both the pre- and post-refactor variant
    of several statements (including a duplicate ``created_at=`` keyword, a
    SyntaxError); only the coherent post-refactor version is kept below.
    """
    # First get the chat context for this developer/session.
    chat_context: ChatContext = prepare_chat_context(
        developer_id=developer.id,
        session_id=session_id,
    )

    # Merge the request settings with the defaults and prepare the
    # template-rendering environment.
    chat_context.merge_settings(data)
    settings: dict = chat_context.settings.model_dump()
    env: dict = chat_context.get_chat_environment()
    new_raw_messages = [msg.model_dump() for msg in data.messages]

    # Gather prior history and doc references relevant to the new input.
    past_messages, doc_references = await get_messages(
        developer=developer,
        session_id=session_id,
        new_raw_messages=new_raw_messages,
        chat_context=chat_context,
    )

    # Render the new messages with the matched docs exposed to the templates.
    env["docs"] = doc_references
    new_messages = await render_template(new_raw_messages, variables=env)
    messages = past_messages + new_messages

    # Get the response from the model.
    model_response = await litellm.acompletion(
        messages=messages,
        **settings,
        user=str(developer.id),
        tags=developer.tags,
    )

    # Persist the input to the session history in the background, but only
    # when the caller asked for it.
    if data.save:
        new_entries = [
            CreateEntryRequest(**msg, source="api_request") for msg in new_messages
        ]
        background_tasks.add_task(
            create_entries,
            developer_id=developer.id,
            session_id=session_id,
            data=new_entries,
            mark_session_as_updated=True,
        )

    # Build the response: a chunk response for streaming, a message response
    # otherwise. A fresh id/timestamp is assigned server-side rather than
    # echoing the provider's.
    chat_response_class = ChunkChatResponse if data.stream else MessageChatResponse
    chat_response: ChatResponse = chat_response_class(
        id=uuid4(),
        created_at=utcnow(),
        jobs=[],
        docs=doc_references,
        usage=model_response.usage.model_dump(),
        choices=[choice.model_dump() for choice in model_response.choices],
    )

    return chat_response
Loading
Loading