import asyncio
import json
from collections.abc import AsyncIterator
from datetime import datetime
from typing import cast
from unittest.mock import AsyncMock, Mock
import pytest
from agents import (
Agent,
GuardrailFunctionOutput,
InputGuardrail,
InputGuardrailResult,
InputGuardrailTripwireTriggered,
OutputGuardrail,
OutputGuardrailResult,
OutputGuardrailTripwireTriggered,
RawResponsesStreamEvent,
RunContextWrapper,
RunItemStreamEvent,
Runner,
RunResultStreaming,
StreamEvent,
ToolCallItem,
)
from agents._run_impl import QueueCompleteSentinel
from openai.types.responses import (
EasyInputMessageParam,
ResponseFileSearchToolCall,
ResponseInputContentParam,
ResponseInputTextParam,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseReasoningItem,
)
from openai.types.responses.response_content_part_added_event import (
ResponseContentPartAddedEvent,
)
from openai.types.responses.response_file_search_tool_call import Result
from openai.types.responses.response_output_text import (
AnnotationFileCitation as ResponsesAnnotationFileCitation,
)
from openai.types.responses.response_output_text import (
AnnotationFilePath as ResponsesAnnotationFilePath,
)
from openai.types.responses.response_output_text import (
AnnotationURLCitation as ResponsesAnnotationURLCitation,
)
from openai.types.responses.response_output_text import (
ResponseOutputText,
)
from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
from openai.types.responses.response_text_done_event import ResponseTextDoneEvent
from chatkit.agents import (
AgentContext,
ThreadItemConverter,
accumulate_text,
simple_to_agent_input,
stream_agent_response,
)
from chatkit.types import (
Annotation,
AssistantMessageContent,
AssistantMessageContentPartAdded,
AssistantMessageContentPartDone,
AssistantMessageContentPartTextDelta,
AssistantMessageItem,
Attachment,
ClientToolCallItem,
CustomSummary,
CustomTask,
DurationSummary,
FileSource,
InferenceOptions,
Page,
TaskItem,
ThoughtTask,
Thread,
ThreadItemAddedEvent,
ThreadItemDoneEvent,
ThreadItemUpdated,
ThreadStreamEvent,
URLSource,
UserMessageItem,
UserMessageTagContent,
UserMessageTextContent,
WidgetItem,
Workflow,
WorkflowItem,
WorkflowTaskAdded,
WorkflowTaskUpdated,
)
from chatkit.widgets import Card, Text
# Shared fixtures for the tests in this module: one thread and one stubbed store.
thread = Thread(
    id="123",
    title="Test",
    created_at=datetime.now(),
    items=Page(),
)

# The store is a plain Mock. Its awaited methods are AsyncMocks, and item ids are
# derived from the item type alone so assertions can rely on them.
mock_store = Mock(
    generate_item_id=lambda item_type, thread, context: f"{item_type}_id",
    load_thread_items=AsyncMock(return_value=Page()),
    add_thread_item=AsyncMock(),
)
class RunResult(RunResultStreaming):
    """Test double for ``RunResultStreaming`` whose event queue is driven by hand."""

    def add_event(self, event: StreamEvent):
        """Put ``event`` on the internal event queue without waiting."""
        self._event_queue.put_nowait(event)

    def done(self):
        """Flag the run as complete and enqueue the completion sentinel."""
        self.is_complete = True
        self._event_queue.put_nowait(QueueCompleteSentinel())

    def throw_input_guardrails(self):
        """Store an input-guardrail tripwire exception, then complete the run."""
        tripped_output = GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )
        guardrail_result = InputGuardrailResult(
            guardrail=Mock(spec=InputGuardrail),
            output=tripped_output,
        )
        self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
        # Completing the run is the same two steps as ``done``.
        self.done()

    def throw_output_guardrails(self):
        """Store an output-guardrail tripwire exception, then complete the run."""
        tripped_output = GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )
        guardrail_result = OutputGuardrailResult(
            guardrail=Mock(spec=OutputGuardrail),
            output=tripped_output,
            agent=Mock(spec=Agent),
            agent_output=None,
        )
        self._stored_exception = OutputGuardrailTripwireTriggered(guardrail_result)
        # Completing the run is the same two steps as ``done``.
        self.done()
def make_result() -> RunResult:
    """Build a fresh, incomplete ``RunResult`` with empty queues and no output."""
    wrapper = Mock(spec=RunContextWrapper)
    agent = Agent(name="test")
    event_queue = asyncio.Queue()
    guardrail_queue = asyncio.Queue()

    return RunResult(
        # Run inputs and accumulated outputs: all empty.
        input=[],
        new_items=[],
        raw_responses=[],
        final_output=None,
        # Guardrail bookkeeping: nothing has been recorded.
        input_guardrail_results=[],
        output_guardrail_results=[],
        tool_input_guardrail_results=[],
        tool_output_guardrail_results=[],
        # Agent and turn state.
        context_wrapper=wrapper,
        current_agent=agent,
        current_turn=0,
        max_turns=10,
        trace=None,
        is_complete=False,
        # Private streaming internals: no tasks running, no stored exception.
        _current_agent_output_schema=None,
        _event_queue=event_queue,
        _input_guardrail_queue=guardrail_queue,
        _output_guardrails_task=None,
        _run_impl_task=None,
        _stored_exception=None,
    )
async def all_events(
    events: AsyncIterator[ThreadStreamEvent],
) -> list[ThreadStreamEvent]:
    """Drain ``events`` and return everything it yielded, in order."""
    collected: list[ThreadStreamEvent] = []
    async for item in events:
        collected.append(item)
    return collected
async def test_returns_widget_item():
    """A widget streamed as a single value yields exactly one done event for it."""
    ctx = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    run.add_event(tool_called)

    await ctx.stream_widget(Card(children=[Text(value="Hello, world!")]))
    run.done()

    emitted = await all_events(stream_agent_response(context=ctx, result=run))

    assert len(emitted) == 1
    only_event = emitted[0]
    assert isinstance(only_event, ThreadItemDoneEvent)
    assert isinstance(only_event.item, WidgetItem)
    assert only_event.item.widget == Card(children=[Text(value="Hello, world!")])
async def test_returns_widget_item_generator():
    """A generator of widgets that only grows a text value streams as added, delta, done."""
    ctx = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    run.add_event(tool_called)

    def render_widget(i: int) -> Card:
        # Render the card with the first ``i`` characters of the message.
        return Card(children=[Text(id="text", value="Hello, world"[:i])])

    async def widget_generator():
        # First an empty text, then the complete 12-character message.
        for length in (0, 12):
            yield render_widget(length)

    await ctx.stream_widget(widget_generator())
    run.done()

    emitted = await all_events(stream_agent_response(context=ctx, result=run))

    assert len(emitted) == 3
    added, updated, finished = emitted

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    assert isinstance(updated, ThreadItemUpdated)
    assert updated.update.type == "widget.streaming_text.value_delta"
    assert updated.update.component_id == "text"
    assert updated.update.delta == "Hello, world"

    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
async def test_returns_widget_full_replace_generator():
    """When the second widget differs structurally, the update replaces the root widget."""
    ctx = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    run.add_event(tool_called)

    async def widget_generator():
        yield Card(children=[Text(id="text", value="Hello!")])
        yield Card(children=[Text(key="other text", value="World!", streaming=False)])

    await ctx.stream_widget(widget_generator())
    run.done()

    emitted = await all_events(stream_agent_response(context=ctx, result=run))

    assert len(emitted) == 3
    added, updated, finished = emitted

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="Hello!")])

    assert isinstance(updated, ThreadItemUpdated)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == Card(
        children=[Text(key="other text", value="World!", streaming=False)]
    )

    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == Card(
        children=[Text(key="other text", value="World!", streaming=False)]
    )
async def test_accumulate_text():
    """``accumulate_text`` yields the growing text widget and a final non-streaming copy.

    The stream is fed from a hand-built ``RunResult`` only. No real agent run is
    started, so the test does not depend on a model backend.
    """

    def delta(text: str) -> RawResponsesStreamEvent:
        # Wrap ``text`` in a raw output-text delta event, as the stream would carry it.
        return RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseTextDeltaEvent(
                type="response.output_text.delta",
                delta=text,
                content_index=0,
                item_id="123",
                logprobs=[],
                output_index=0,
                sequence_number=0,
            ),
        )

    result = make_result()
    result.add_event(delta("Hello, "))
    result.add_event(delta("world!"))
    result.done()

    events = [
        event
        async for event in accumulate_text(
            result.stream_events(), Text(key="text", value="", streaming=True)
        )
    ]

    # Initial widget, one snapshot per delta, then the final non-streaming widget.
    assert events == [
        Text(key="text", value="", streaming=True),
        Text(key="text", value="Hello, ", streaming=True),
        Text(key="text", value="Hello, world!", streaming=True),
        Text(key="text", value="Hello, world!", streaming=False),
    ]
async def test_input_item_converter_quotes_last_user_message():
    """Only the last user message's quoted text is turned into an extra input message.

    Both user messages carry ``quoted_text``, but the converted input holds three
    messages: the two user texts, then one message quoting the second item.
    """
    items = [
        UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text="Hello!")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text="Hi!",
            created_at=datetime.now(),
        ),
        UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text="I'm well, thank you!")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text="How are you doing?",
            created_at=datetime.now(),
        ),
    ]

    input_items = await simple_to_agent_input(items)

    assert len(input_items) == 3
    assert input_items[0] == {
        "content": [
            {
                "text": "Hello!",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    assert input_items[1] == {
        "content": [
            {
                "text": "I'm well, thank you!",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    # The first message's quoted text ("Hi!") does not appear; only the last one's does.
    assert input_items[2] == {
        "content": [
            {
                "text": "The user is referring to this in particular: \nHow are you doing?",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
async def test_input_item_converter_to_input_items_mixed():
    """User, assistant and widget items each convert to one agent input message."""

    def expected_user_message(text: str) -> dict:
        # Shape of a converted plain-text user message.
        return {
            "type": "message",
            "role": "user",
            "content": [{"type": "input_text", "text": text}],
        }

    def expected_output_part(text: str) -> dict:
        # Shape of one converted assistant content part.
        return {
            "type": "output_text",
            "text": text,
            "annotations": [],
            "logprobs": None,
        }

    first_user = UserMessageItem(
        id="123",
        content=[UserMessageTextContent(text="Hello!")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        quoted_text="Hi!",
        created_at=datetime.now(),
    )
    second_user = UserMessageItem(
        id="123",
        content=[UserMessageTextContent(text="I'm well, thank you!")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        quoted_text="How are you doing?",
        created_at=datetime.now(),
    )
    assistant = AssistantMessageItem(
        id="123",
        content=[
            AssistantMessageContent(text="How are you doing?"),
            AssistantMessageContent(text="Can't do that"),
        ],
        thread_id=thread.id,
        created_at=datetime.now(),
    )
    widget = WidgetItem(
        id="wd_123",
        widget=Card(children=[Text(value="Hello, world!")]),
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    input_items = await simple_to_agent_input(
        [first_user, second_user, assistant, widget]
    )

    assert len(input_items) == 4
    assert input_items[0] == expected_user_message("Hello!")
    assert input_items[1] == expected_user_message("I'm well, thank you!")
    assert input_items[2] == {
        "type": "message",
        "role": "assistant",
        "content": [
            expected_output_part("How are you doing?"),
            expected_output_part("Can't do that"),
        ],
    }

    # The widget becomes a user message describing what was displayed.
    assert "type" in input_items[3]
    widget_item = cast(EasyInputMessageParam, input_items[3])
    assert widget_item.get("type") == "message"
    assert widget_item.get("role") == "user"
    text = widget_item.get("content")[0]["text"]  # type: ignore
    assert (
        "The following graphical UI widget (id: wd_123) was displayed to the user"
        in text
    )
    assert "Hello, world!" in text
    assert "created_at" not in text
async def test_input_item_converter_user_input_with_tags():
    """A converter overriding ``tag_to_message_content`` adds a tag-context message."""

    class MyThreadItemConverter(ThreadItemConverter):
        def tag_to_message_content(self, tag):
            # Render the tag as its text followed by the value stored under "key".
            rendered = tag.text + " " + tag.data["key"]
            return ResponseInputTextParam(type="input_text", text=rendered)

    tag = UserMessageTagContent(
        text="Hello!", type="input_tag", id="hello", data={"key": "value"}
    )
    message = UserMessageItem(
        id="123",
        content=[tag],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    converted = await MyThreadItemConverter().to_agent_input([message])

    assert len(converted) == 2
    user_message, tag_context = converted

    # The tag shows up in the user text with an '@' prefix.
    assert user_message == {
        "content": [
            {
                "text": "@Hello!",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    # A second message carries the @-mention preamble and the converter's tag content.
    assert tag_context == {
        "content": [
            {
                "text": "# User-provided context for @-mentions\n- When referencing resolved entities, use their canonical names **without** '@'.\n"
                + "- The '@' form appears only in user text and should not be echoed.",
                "type": "input_text",
            },
            {
                "text": "Hello! value",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
async def test_input_item_converter_user_input_with_tags_throws_by_default():
    """``simple_to_agent_input`` raises ``NotImplementedError`` for tag content."""
    tag = UserMessageTagContent(text="Hello!", type="input_tag", id="hello", data={})
    tagged_message = UserMessageItem(
        id="123",
        content=[tag],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    with pytest.raises(NotImplementedError):
        await simple_to_agent_input([tagged_message])
async def test_input_item_converter_with_client_tool_call():
items = [
UserMessageItem(
id="123",
content=[UserMessageTextContent(text="Call a client tool call xyz")],
attachments=[],
inference_options=InferenceOptions(),
thread_id=thread.id,
quoted_text="Hi!",
created_at=datetime.now(),
),
TaskItem(
id="tsk_123",
created_at=datetime.now(),
task=CustomTask(title="Called xyx"),
thread_id=thread.id,
),
ClientToolCallItem(
id="ctc_123",
thread_id=thread.id,
created_at=datetime.now(),
name="xyz",
arguments={"foo": "bar"},
call_id="ctc_123",
),
ClientToolCallItem(
id="ctc_123_done",
thread_id=thread.id,
created_at=datetime.now(),
name="xyz",
arguments={"foo": "bar"},
call_id="ctc_123",
status="completed",
output={"success": True},
),
]
input_items = await simple_to_agent_input(items)
assert len(input_items) == 4
assert input_items[0] == {
"content": [
{
"text": "Call a client tool call xyz",
"type": "input_text",
},
],
"role": "user",
"type": "message",
}
assert input_items[1] == {
"content": [
{
"text": "A message was displayed to the user that the following task was performed:\n