Commit d66343d

Merge pull request #489 from julep-ai/x/fix-chat-model-tool-calls
fix(agents-api,typespec): Fix chat/entry typespec models to include tool_calls
HamadaSalhab authored Sep 4, 2024
2 parents 3cd5157 + fa05cd1 commit d66343d
Showing 14 changed files with 243 additions and 127 deletions.
@@ -1,4 +1,3 @@
-
 from beartype import beartype
 from temporalio import activity
 
47 changes: 43 additions & 4 deletions agents-api/agents_api/autogen/Chat.py
@@ -10,15 +10,15 @@
 
 from .Common import LogitBias
 from .Docs import DocReference
-from .Tools import NamedToolChoice, Tool
+from .Tools import ChosenToolCall, NamedToolChoice, Tool
 
 
 class BaseChatOutput(BaseModel):
     model_config = ConfigDict(
         populate_by_name=True,
     )
     index: int
-    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"]
+    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"
     """
     The reason the model stopped generating tokens
     """
@@ -260,6 +260,45 @@ class MessageChatResponse(BaseChatResponse):
     """
 
 
+class MessageModel(BaseModel):
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+    role: Literal[
+        "user",
+        "assistant",
+        "system",
+        "function",
+        "function_response",
+        "function_call",
+        "auto",
+    ]
+    """
+    The role of the message
+    """
+    content: str | list[str] | list[Content | ContentModel]
+    """
+    The content parts of the message
+    """
+    name: str | None = None
+    """
+    Name
+    """
+    tool_calls: Annotated[
+        list[ChosenToolCall] | None, Field([], json_schema_extra={"readOnly": True})
+    ]
+    """
+    Tool calls generated by the model.
+    """
+    created_at: Annotated[
+        AwareDatetime | None, Field(None, json_schema_extra={"readOnly": True})
+    ]
+    """
+    When this resource was created as UTC date-time
+    """
+    id: Annotated[UUID | None, Field(None, json_schema_extra={"readOnly": True})]
+
+
 class MultipleChatOutput(BaseChatOutput):
     """
     The output returned by the model. Note that, depending on the model provider, they might return more than one message.
@@ -269,7 +308,7 @@ class MultipleChatOutput(BaseChatOutput):
         populate_by_name=True,
     )
     messages: Annotated[
-        list[Message], Field(json_schema_extra={"readOnly": True}, min_length=1)
+        list[MessageModel], Field(json_schema_extra={"readOnly": True}, min_length=1)
     ]
 
 
@@ -327,7 +366,7 @@ class SingleChatOutput(BaseChatOutput):
     model_config = ConfigDict(
         populate_by_name=True,
     )
-    message: Message
+    message: MessageModel
 
 
 class TokenLogProb(BaseTokenLogProb):
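
To make the effect of these Chat.py changes concrete, here is a minimal usage sketch. It assumes the regenerated agents_api.autogen.Chat module from this commit; the values are illustrative and not taken from the codebase.

    # Sketch only: assumes the autogen Chat module regenerated by this commit.
    from agents_api.autogen.Chat import MessageModel, SingleChatOutput

    # finish_reason now defaults to "stop", and the output message can carry tool_calls.
    out = SingleChatOutput(
        index=0,
        message=MessageModel(role="assistant", content="Hello, world!"),
    )
    assert out.finish_reason == "stop"
    assert out.message.tool_calls == []  # default supplied by Field([], ...)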
7 changes: 7 additions & 0 deletions agents-api/agents_api/autogen/Common.py
@@ -98,3 +98,10 @@ class ResourceUpdatedResponse(BaseModel):
     """
     IDs (if any) of jobs created as part of this request
     """
+
+
+class Uuid(RootModel[UUID]):
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+    root: UUID
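
The new Uuid root model is a thin wrapper around a UUID value. A short, self-contained sketch of how a pydantic v2 RootModel like this behaves (illustrative; it mirrors the class added above):

    # Sketch only: pydantic v2 RootModel wrapping a UUID.
    from uuid import UUID
    from pydantic import ConfigDict, RootModel

    class Uuid(RootModel[UUID]):
        model_config = ConfigDict(populate_by_name=True)
        root: UUID

    u = Uuid.model_validate("123e4567-e89b-12d3-a456-426614174000")
    assert isinstance(u.root, UUID)
    print(u.model_dump_json())  # serializes back to the plain UUID string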
198 changes: 99 additions & 99 deletions agents-api/poetry.lock

(large diff not rendered)

11 changes: 9 additions & 2 deletions agents-api/tests/utils.py
@@ -4,7 +4,7 @@
 from unittest.mock import patch
 
 from fastapi.testclient import TestClient
-from litellm.types.utils import Choices, ModelResponse
+from litellm.types.utils import ModelResponse
 from temporalio.testing import WorkflowEnvironment
 
 from agents_api.worker.codec import pydantic_data_converter
@@ -72,7 +72,14 @@ def make_request(method, url, **kwargs):
 def patch_embed_acompletion(output={"role": "assistant", "content": "Hello, world!"}):
     mock_model_response = ModelResponse(
         id="fake_id",
-        choices=[Choices(message=output)],
+        choices=[
+            dict(
+                message=output,
+                tool_calls=[],
+                created_at=1,
+                # finish_reason="stop",
+            )
+        ],
         created=0,
         object="text_completion",
     )
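
The helper above builds a fake ModelResponse and now passes the choices as plain dicts, which lets the mock include the extra tool_calls and created_at fields. For readers unfamiliar with the mocking pattern it relies on, here is a minimal, hypothetical sketch; the patch target path is an assumption, not taken from the repo.

    # Hypothetical sketch of the patching pattern; the dotted path below is illustrative.
    from contextlib import contextmanager
    from unittest.mock import AsyncMock, patch

    @contextmanager
    def patch_acompletion_with(mock_response):
        # AsyncMock so the patched coroutine can be awaited by the code under test.
        with patch("agents_api.clients.litellm.acompletion", new_callable=AsyncMock) as acompletion:
            acompletion.return_value = mock_response
            yield acompletion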
4 changes: 2 additions & 2 deletions scripts/generate_openapi_code.sh
@@ -13,12 +13,12 @@ cd -
 # fern generate
 
 cd sdks/python && \
-poetry update && \
+# poetry update && \
 poetry run poe format
 cd -
 
 cd agents-api && \
-poetry update && \
+# poetry update && \
 poetry run poe codegen && \
 poetry run poe format
 cd -
6 changes: 3 additions & 3 deletions sdks/python/poetry.lock

(generated file diff not rendered)

11 changes: 9 additions & 2 deletions sdks/ts/src/api/models/Chat_MultipleChatOutput.ts
@@ -3,7 +3,9 @@
 /* tslint:disable */
 /* eslint-disable */
 import type { Chat_BaseChatOutput } from "./Chat_BaseChatOutput";
+import type { Common_uuid } from "./Common_uuid";
 import type { Entries_ChatMLRole } from "./Entries_ChatMLRole";
+import type { Tools_ChosenToolCall } from "./Tools_ChosenToolCall";
 /**
  * The output returned by the model. Note that, depending on the model provider, they might return more than one message.
  */
@@ -22,8 +24,13 @@ export type Chat_MultipleChatOutput = Chat_BaseChatOutput & {
      */
     name?: string;
     /**
-     * Whether to continue this message or return a new one
+     * Tool calls generated by the model.
      */
-    continue?: boolean;
+    readonly tool_calls?: Array<Tools_ChosenToolCall> | null;
+    /**
+     * When this resource was created as UTC date-time
+     */
+    readonly created_at?: string;
+    readonly id?: Common_uuid;
   }>;
 };
11 changes: 9 additions & 2 deletions sdks/ts/src/api/models/Chat_SingleChatOutput.ts
@@ -3,7 +3,9 @@
 /* tslint:disable */
 /* eslint-disable */
 import type { Chat_BaseChatOutput } from "./Chat_BaseChatOutput";
+import type { Common_uuid } from "./Common_uuid";
 import type { Entries_ChatMLRole } from "./Entries_ChatMLRole";
+import type { Tools_ChosenToolCall } from "./Tools_ChosenToolCall";
 /**
  * The output returned by the model. Note that, depending on the model provider, they might return more than one message.
  */
@@ -22,8 +24,13 @@ export type Chat_SingleChatOutput = Chat_BaseChatOutput & {
      */
     name?: string;
     /**
-     * Whether to continue this message or return a new one
+     * Tool calls generated by the model.
      */
-    continue?: boolean;
+    readonly tool_calls?: Array<Tools_ChosenToolCall> | null;
+    /**
+     * When this resource was created as UTC date-time
+     */
+    readonly created_at?: string;
+    readonly id?: Common_uuid;
   };
 };
25 changes: 22 additions & 3 deletions sdks/ts/src/api/schemas/$Chat_MultipleChatOutput.ts
@@ -45,9 +45,28 @@ export const $Chat_MultipleChatOutput = {
         type: "string",
         description: `Name`,
       },
-      continue: {
-        type: "boolean",
-        description: `Whether to continue this message or return a new one`,
+      tool_calls: {
+        type: "array",
+        contains: {
+          type: "Tools_ChosenToolCall",
+        },
+        isReadOnly: true,
+        isNullable: true,
+      },
+      created_at: {
+        type: "string",
+        description: `When this resource was created as UTC date-time`,
+        isReadOnly: true,
+        format: "date-time",
+      },
+      id: {
+        type: "all-of",
+        contains: [
+          {
+            type: "Common_uuid",
+          },
+        ],
+        isReadOnly: true,
       },
     },
   },
25 changes: 22 additions & 3 deletions sdks/ts/src/api/schemas/$Chat_SingleChatOutput.ts
@@ -43,9 +43,28 @@ export const $Chat_SingleChatOutput = {
         type: "string",
         description: `Name`,
       },
-      continue: {
-        type: "boolean",
-        description: `Whether to continue this message or return a new one`,
+      tool_calls: {
+        type: "array",
+        contains: {
+          type: "Tools_ChosenToolCall",
+        },
+        isReadOnly: true,
+        isNullable: true,
+      },
+      created_at: {
+        type: "string",
+        description: `When this resource was created as UTC date-time`,
+        isReadOnly: true,
+        format: "date-time",
+      },
+      id: {
+        type: "all-of",
+        contains: [
+          {
+            type: "Common_uuid",
+          },
+        ],
+        isReadOnly: true,
       },
     },
     isRequired: true,
7 changes: 4 additions & 3 deletions typespec/chat/models.tsp
@@ -195,22 +195,23 @@ model BaseChatOutput {
   index: uint32;
 
   /** The reason the model stopped generating tokens */
-  finish_reason: FinishReason;
+  finish_reason: FinishReason = FinishReason.stop;
 
   /** The log probabilities of tokens */
   logprobs?: LogProbResponse;
 }
 
 /** The output returned by the model. Note that, depending on the model provider, they might return more than one message. */
+// TODO: Need to add support for tool calls
 model SingleChatOutput extends BaseChatOutput {
-  message: InputChatMLMessage;
+  message: ChatMLMessage;
 }
 
 /** The output returned by the model. Note that, depending on the model provider, they might return more than one message. */
 model MultipleChatOutput extends BaseChatOutput {
   @visibility("read")
   @minItems(1)
-  messages: InputChatMLMessage[];
+  messages: ChatMLMessage[];
 }
 
 alias ChatOutput = SingleChatOutput | MultipleChatOutput;
11 changes: 11 additions & 0 deletions typespec/common/mixins.tsp
@@ -7,6 +7,12 @@ namespace Common;
 // COMMON MIXINS
 //
 
+model HasCreatedAtOptional {
+  /** When this resource was created as UTC date-time */
+  @visibility("read")
+  created_at?: utcDateTime;
+}
+
 model HasCreatedAt {
   /** When this resource was created as UTC date-time */
   @visibility("read")
@@ -30,6 +36,11 @@ model HasTimestamps {
   ...HasUpdatedAt;
 }
 
+model HasIdOptional {
+  @visibility("read")
+  id?: uuid;
+}
+
 model HasId {
   @visibility("read")
   @key
6 changes: 3 additions & 3 deletions typespec/entries/models.tsp
@@ -71,10 +71,10 @@ model ChatMLMessage<T extends string = string> {
 
   /** Tool calls generated by the model. */
   @visibility("read")
-  tool_calls: ChosenToolCall[] = #[];
+  tool_calls?: ChosenToolCall[] | null = #[];
 
-  ...HasCreatedAt;
-  ...HasId;
+  ...HasCreatedAtOptional;
+  ...HasIdOptional;
 }
 
 @withVisibility("create")
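
With created_at, id and tool_calls now optional or nullable on ChatMLMessage, a bare provider message should validate against the regenerated model without those fields. A hedged sketch, assuming the MessageModel generated from this typespec (shown in the Chat.py diff above):

    # Sketch only: validating a minimal message against the regenerated model.
    from agents_api.autogen.Chat import MessageModel

    raw = {"role": "assistant", "content": "Hi there"}  # no id / created_at / tool_calls
    msg = MessageModel.model_validate(raw)
    assert msg.id is None and msg.created_at is None
    assert msg.tool_calls == []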
