Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agents-api,typespec): Fix chat/entry typespec models to include tool_calls #489

Merged
merged 1 commit into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
fix(agents-api,typespec): Fix chat/entry typespec models to include t…
…ool_calls

Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Sep 4, 2024
commit fa05cd1e1d1330a4bacf2f45f706d5d9fed61b65
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from beartype import beartype
from temporalio import activity

Expand Down
47 changes: 43 additions & 4 deletions agents-api/agents_api/autogen/Chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@

from .Common import LogitBias
from .Docs import DocReference
from .Tools import NamedToolChoice, Tool
from .Tools import ChosenToolCall, NamedToolChoice, Tool


class BaseChatOutput(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
index: int
finish_reason: Literal["stop", "length", "content_filter", "tool_calls"]
finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"
"""
The reason the model stopped generating tokens
"""
Expand Down Expand Up @@ -260,6 +260,45 @@ class MessageChatResponse(BaseChatResponse):
"""


class MessageModel(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
role: Literal[
"user",
"assistant",
"system",
"function",
"function_response",
"function_call",
"auto",
]
"""
The role of the message
"""
content: str | list[str] | list[Content | ContentModel]
"""
The content parts of the message
"""
name: str | None = None
"""
Name
"""
tool_calls: Annotated[
HamadaSalhab marked this conversation as resolved.
Show resolved Hide resolved
list[ChosenToolCall] | None, Field([], json_schema_extra={"readOnly": True})
]
"""
Tool calls generated by the model.
"""
created_at: Annotated[
AwareDatetime | None, Field(None, json_schema_extra={"readOnly": True})
]
"""
When this resource was created as UTC date-time
"""
id: Annotated[UUID | None, Field(None, json_schema_extra={"readOnly": True})]


class MultipleChatOutput(BaseChatOutput):
"""
The output returned by the model. Note that, depending on the model provider, they might return more than one message.
Expand All @@ -269,7 +308,7 @@ class MultipleChatOutput(BaseChatOutput):
populate_by_name=True,
)
messages: Annotated[
list[Message], Field(json_schema_extra={"readOnly": True}, min_length=1)
list[MessageModel], Field(json_schema_extra={"readOnly": True}, min_length=1)
]


Expand Down Expand Up @@ -327,7 +366,7 @@ class SingleChatOutput(BaseChatOutput):
model_config = ConfigDict(
populate_by_name=True,
)
message: Message
message: MessageModel


class TokenLogProb(BaseTokenLogProb):
Expand Down
7 changes: 7 additions & 0 deletions agents-api/agents_api/autogen/Common.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,10 @@ class ResourceUpdatedResponse(BaseModel):
"""
IDs (if any) of jobs created as part of this request
"""


class Uuid(RootModel[UUID]):
model_config = ConfigDict(
populate_by_name=True,
)
root: UUID
198 changes: 99 additions & 99 deletions agents-api/poetry.lock

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions agents-api/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from unittest.mock import patch

from fastapi.testclient import TestClient
from litellm.types.utils import Choices, ModelResponse
from litellm.types.utils import ModelResponse
from temporalio.testing import WorkflowEnvironment

from agents_api.worker.codec import pydantic_data_converter
Expand Down Expand Up @@ -72,7 +72,14 @@ def make_request(method, url, **kwargs):
def patch_embed_acompletion(output={"role": "assistant", "content": "Hello, world!"}):
mock_model_response = ModelResponse(
id="fake_id",
choices=[Choices(message=output)],
choices=[
dict(
message=output,
tool_calls=[],
created_at=1,
# finish_reason="stop",
)
],
created=0,
object="text_completion",
)
Expand Down
4 changes: 2 additions & 2 deletions scripts/generate_openapi_code.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ cd -
# fern generate

cd sdks/python && \
poetry update && \
# poetry update && \
poetry run poe format
cd -

cd agents-api && \
poetry update && \
# poetry update && \
poetry run poe codegen && \
poetry run poe format
cd -
Expand Down
6 changes: 3 additions & 3 deletions sdks/python/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 9 additions & 2 deletions sdks/ts/src/api/models/Chat_MultipleChatOutput.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
/* tslint:disable */
/* eslint-disable */
import type { Chat_BaseChatOutput } from "./Chat_BaseChatOutput";
import type { Common_uuid } from "./Common_uuid";
import type { Entries_ChatMLRole } from "./Entries_ChatMLRole";
import type { Tools_ChosenToolCall } from "./Tools_ChosenToolCall";
/**
* The output returned by the model. Note that, depending on the model provider, they might return more than one message.
*/
Expand All @@ -22,8 +24,13 @@ export type Chat_MultipleChatOutput = Chat_BaseChatOutput & {
*/
name?: string;
/**
* Whether to continue this message or return a new one
* Tool calls generated by the model.
*/
continue?: boolean;
readonly tool_calls?: Array<Tools_ChosenToolCall> | null;
/**
* When this resource was created as UTC date-time
*/
readonly created_at?: string;
readonly id?: Common_uuid;
}>;
};
11 changes: 9 additions & 2 deletions sdks/ts/src/api/models/Chat_SingleChatOutput.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
/* tslint:disable */
/* eslint-disable */
import type { Chat_BaseChatOutput } from "./Chat_BaseChatOutput";
import type { Common_uuid } from "./Common_uuid";
import type { Entries_ChatMLRole } from "./Entries_ChatMLRole";
import type { Tools_ChosenToolCall } from "./Tools_ChosenToolCall";
/**
* The output returned by the model. Note that, depending on the model provider, they might return more than one message.
*/
Expand All @@ -22,8 +24,13 @@ export type Chat_SingleChatOutput = Chat_BaseChatOutput & {
*/
name?: string;
/**
* Whether to continue this message or return a new one
* Tool calls generated by the model.
*/
continue?: boolean;
readonly tool_calls?: Array<Tools_ChosenToolCall> | null;
/**
* When this resource was created as UTC date-time
*/
readonly created_at?: string;
readonly id?: Common_uuid;
};
};
25 changes: 22 additions & 3 deletions sdks/ts/src/api/schemas/$Chat_MultipleChatOutput.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,28 @@ export const $Chat_MultipleChatOutput = {
type: "string",
description: `Name`,
},
continue: {
type: "boolean",
description: `Whether to continue this message or return a new one`,
tool_calls: {
type: "array",
contains: {
type: "Tools_ChosenToolCall",
},
isReadOnly: true,
isNullable: true,
},
created_at: {
type: "string",
description: `When this resource was created as UTC date-time`,
isReadOnly: true,
format: "date-time",
},
id: {
type: "all-of",
contains: [
{
type: "Common_uuid",
},
],
isReadOnly: true,
},
},
},
Expand Down
25 changes: 22 additions & 3 deletions sdks/ts/src/api/schemas/$Chat_SingleChatOutput.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,28 @@ export const $Chat_SingleChatOutput = {
type: "string",
description: `Name`,
},
continue: {
type: "boolean",
description: `Whether to continue this message or return a new one`,
tool_calls: {
type: "array",
contains: {
type: "Tools_ChosenToolCall",
},
isReadOnly: true,
isNullable: true,
},
created_at: {
type: "string",
description: `When this resource was created as UTC date-time`,
isReadOnly: true,
format: "date-time",
},
id: {
type: "all-of",
contains: [
{
type: "Common_uuid",
},
],
isReadOnly: true,
},
},
isRequired: true,
Expand Down
7 changes: 4 additions & 3 deletions typespec/chat/models.tsp
Original file line number Diff line number Diff line change
Expand Up @@ -195,22 +195,23 @@ model BaseChatOutput {
index: uint32;

/** The reason the model stopped generating tokens */
finish_reason: FinishReason;
finish_reason: FinishReason = FinishReason.stop;

/** The log probabilities of tokens */
logprobs?: LogProbResponse;
}

/** The output returned by the model. Note that, depending on the model provider, they might return more than one message. */
// TODO: Need to add support for tool calls
model SingleChatOutput extends BaseChatOutput {
message: InputChatMLMessage;
message: ChatMLMessage;
}

/** The output returned by the model. Note that, depending on the model provider, they might return more than one message. */
model MultipleChatOutput extends BaseChatOutput {
@visibility("read")
@minItems(1)
messages: InputChatMLMessage[];
messages: ChatMLMessage[];
}

alias ChatOutput = SingleChatOutput | MultipleChatOutput;
Expand Down
11 changes: 11 additions & 0 deletions typespec/common/mixins.tsp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ namespace Common;
// COMMON MIXINS
//

model HasCreatedAtOptional {
/** When this resource was created as UTC date-time */
@visibility("read")
created_at?: utcDateTime;
}

model HasCreatedAt {
/** When this resource was created as UTC date-time */
@visibility("read")
Expand All @@ -30,6 +36,11 @@ model HasTimestamps {
...HasUpdatedAt;
}

model HasIdOptional {
@visibility("read")
id?: uuid;
}

model HasId {
@visibility("read")
@key
Expand Down
6 changes: 3 additions & 3 deletions typespec/entries/models.tsp
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ model ChatMLMessage<T extends string = string> {

/** Tool calls generated by the model. */
@visibility("read")
tool_calls: ChosenToolCall[] = #[];
tool_calls?: ChosenToolCall[] | null = #[];

...HasCreatedAt;
...HasId;
...HasCreatedAtOptional;
...HasIdOptional;
}

@withVisibility("create")
Expand Down
Loading