Skip to content

Commit 5a3deac

Browse files
authored
Port from pydantic-ai to openai-agents SDK (#211)
* Port to OpenAI-agents SDK * Port to OpenAI-agents SDK * Fix tests, mypy * Update package requirements * More dep/mypy updates * Update snapshot * Add system message to thoughts * Make mypy happy
1 parent b000a71 commit 5a3deac

File tree

16 files changed

+212
-461
lines changed

16 files changed

+212
-461
lines changed

.github/workflows/app-tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ jobs:
123123
key: mypy${{ matrix.os }}-${{ matrix.python_version }}-${{ hashFiles('requirements-dev.txt', 'src/backend/requirements.txt', 'src/backend/pyproject.toml') }}
124124

125125
- name: Run MyPy
126-
run: python3 -m mypy .
126+
run: python3 -m mypy . --python-version ${{ matrix.python_version }}
127127

128128
- name: Run Pytest
129129
run: python3 -m pytest -s -vv --cov --cov-fail-under=85

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ lint.isort.known-first-party = ["fastapi_app"]
77

88
[tool.mypy]
99
check_untyped_defs = true
10-
python_version = 3.9
1110
exclude = [".venv/*"]
1211

1312
[tool.pytest.ini_options]

requirements-dev.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,3 @@ pytest-snapshot
1414
locust
1515
psycopg2
1616
dotenv-azd
17-
freezegun

src/backend/fastapi_app/api_models.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from enum import Enum
2-
from typing import Any, Optional, Union
2+
from typing import Any, Optional
33

4-
from openai.types.chat import ChatCompletionMessageParam
4+
from openai.types.responses import ResponseInputItemParam
55
from pydantic import BaseModel, Field
6-
from pydantic_ai.messages import ModelRequest, ModelResponse
76

87

98
class AIChatRoles(str, Enum):
@@ -37,7 +36,7 @@ class ChatRequestContext(BaseModel):
3736

3837

3938
class ChatRequest(BaseModel):
40-
messages: list[ChatCompletionMessageParam]
39+
messages: list[ResponseInputItemParam]
4140
context: ChatRequestContext
4241
sessionState: Optional[Any] = None
4342

@@ -96,7 +95,7 @@ class ChatParams(ChatRequestOverrides):
9695
enable_text_search: bool
9796
enable_vector_search: bool
9897
original_user_query: str
99-
past_messages: list[Union[ModelRequest, ModelResponse]]
98+
past_messages: list[ResponseInputItemParam]
10099

101100

102101
class Filter(BaseModel):
Lines changed: 22 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,36 @@
11
[
22
{
3-
"parts": [
4-
{
5-
"content": "good options for climbing gear that can be used outside?",
6-
"timestamp": "2025-05-07T19:02:46.977501Z",
7-
"part_kind": "user-prompt"
8-
}
9-
],
10-
"instructions": null,
11-
"kind": "request"
3+
"role": "user",
4+
"content": "good options for climbing gear that can be used outside?"
125
},
136
{
14-
"parts": [
15-
{
16-
"tool_name": "search_database",
17-
"args": "{\"search_query\":\"climbing gear outside\"}",
18-
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
19-
"part_kind": "tool-call"
20-
}
21-
],
22-
"model_name": "gpt-4o-mini-2024-07-18",
23-
"timestamp": "2025-05-07T19:02:47Z",
24-
"kind": "response"
7+
"id": "madeup",
8+
"call_id": "call_abc123",
9+
"name": "search_database",
10+
"arguments": "{\"search_query\":\"climbing gear outside\"}",
11+
"type": "function_call"
2512
},
2613
{
27-
"parts": [
28-
{
29-
"tool_name": "search_database",
30-
"content": "Search results for climbing gear that can be used outside: ...",
31-
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
32-
"timestamp": "2025-05-07T19:02:48.242408Z",
33-
"part_kind": "tool-return"
34-
}
35-
],
36-
"instructions": null,
37-
"kind": "request"
14+
"id": "madeupoutput",
15+
"call_id": "call_abc123",
16+
"output": "Search results for climbing gear that can be used outside: ...",
17+
"type": "function_call_output"
3818
},
3919
{
40-
"parts": [
41-
{
42-
"content": "are there any shoes less than $50?",
43-
"timestamp": "2025-05-07T19:02:46.977501Z",
44-
"part_kind": "user-prompt"
45-
}
46-
],
47-
"instructions": null,
48-
"kind": "request"
20+
"role": "user",
21+
"content": "are there any shoes less than $50?"
4922
},
5023
{
51-
"parts": [
52-
{
53-
"tool_name": "search_database",
54-
"args": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
55-
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
56-
"part_kind": "tool-call"
57-
}
58-
],
59-
"model_name": "gpt-4o-mini-2024-07-18",
60-
"timestamp": "2025-05-07T19:02:47Z",
61-
"kind": "response"
24+
"id": "madeup",
25+
"call_id": "call_abc456",
26+
"name": "search_database",
27+
"arguments": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
28+
"type": "function_call"
6229
},
6330
{
64-
"parts": [
65-
{
66-
"tool_name": "search_database",
67-
"content": "Search results for shoes cheaper than 50: ...",
68-
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
69-
"timestamp": "2025-05-07T19:02:48.242408Z",
70-
"part_kind": "tool-return"
71-
}
72-
],
73-
"instructions": null,
74-
"kind": "request"
31+
"id": "madeupoutput",
32+
"call_id": "call_abc456",
33+
"output": "Search results for shoes cheaper than 50: ...",
34+
"type": "function_call_output"
7535
}
7636
]
Lines changed: 80 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
1+
import json
12
from collections.abc import AsyncGenerator
23
from typing import Optional, Union
34

5+
from agents import (
6+
Agent,
7+
ItemHelpers,
8+
ModelSettings,
9+
OpenAIChatCompletionsModel,
10+
Runner,
11+
ToolCallOutputItem,
12+
function_tool,
13+
set_tracing_disabled,
14+
)
415
from openai import AsyncAzureOpenAI, AsyncOpenAI
5-
from openai.types.chat import ChatCompletionMessageParam
6-
from pydantic_ai import Agent, RunContext
7-
from pydantic_ai.messages import ModelMessagesTypeAdapter
8-
from pydantic_ai.models.openai import OpenAIModel
9-
from pydantic_ai.providers.openai import OpenAIProvider
10-
from pydantic_ai.settings import ModelSettings
16+
from openai.types.responses import EasyInputMessageParam, ResponseInputItemParam, ResponseTextDeltaEvent
1117

1218
from fastapi_app.api_models import (
1319
AIChatRoles,
@@ -24,7 +30,9 @@
2430
ThoughtStep,
2531
)
2632
from fastapi_app.postgres_searcher import PostgresSearcher
27-
from fastapi_app.rag_base import ChatParams, RAGChatBase
33+
from fastapi_app.rag_base import RAGChatBase
34+
35+
set_tracing_disabled(disabled=True)
2836

2937

3038
class AdvancedRAGChat(RAGChatBase):
@@ -34,7 +42,7 @@ class AdvancedRAGChat(RAGChatBase):
3442
def __init__(
3543
self,
3644
*,
37-
messages: list[ChatCompletionMessageParam],
45+
messages: list[ResponseInputItemParam],
3846
overrides: ChatRequestOverrides,
3947
searcher: PostgresSearcher,
4048
openai_chat_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
@@ -46,34 +54,29 @@ def __init__(
4654
self.model_for_thoughts = (
4755
{"model": chat_model, "deployment": chat_deployment} if chat_deployment else {"model": chat_model}
4856
)
49-
pydantic_chat_model = OpenAIModel(
50-
chat_model if chat_deployment is None else chat_deployment,
51-
provider=OpenAIProvider(openai_client=openai_chat_client),
57+
openai_agents_model = OpenAIChatCompletionsModel(
58+
model=chat_model if chat_deployment is None else chat_deployment, openai_client=openai_chat_client
5259
)
53-
self.search_agent = Agent[ChatParams, SearchResults](
54-
pydantic_chat_model,
55-
model_settings=ModelSettings(
56-
temperature=0.0,
57-
max_tokens=500,
58-
**({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
59-
),
60-
system_prompt=self.query_prompt_template,
61-
tools=[self.search_database],
62-
output_type=SearchResults,
60+
self.search_agent = Agent(
61+
name="Searcher",
62+
instructions=self.query_prompt_template,
63+
tools=[function_tool(self.search_database)],
64+
tool_use_behavior="stop_on_first_tool",
65+
model=openai_agents_model,
6366
)
6467
self.answer_agent = Agent(
65-
pydantic_chat_model,
66-
system_prompt=self.answer_prompt_template,
68+
name="Answerer",
69+
instructions=self.answer_prompt_template,
70+
model=openai_agents_model,
6771
model_settings=ModelSettings(
6872
temperature=self.chat_params.temperature,
6973
max_tokens=self.chat_params.response_token_limit,
70-
**({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
74+
extra_body={"seed": self.chat_params.seed} if self.chat_params.seed is not None else {},
7175
),
7276
)
7377

7478
async def search_database(
7579
self,
76-
ctx: RunContext[ChatParams],
7780
search_query: str,
7881
price_filter: Optional[PriceFilter] = None,
7982
brand_filter: Optional[BrandFilter] = None,
@@ -97,66 +100,73 @@ async def search_database(
97100
filters.append(brand_filter)
98101
results = await self.searcher.search_and_embed(
99102
search_query,
100-
top=ctx.deps.top,
101-
enable_vector_search=ctx.deps.enable_vector_search,
102-
enable_text_search=ctx.deps.enable_text_search,
103+
top=self.chat_params.top,
104+
enable_vector_search=self.chat_params.enable_vector_search,
105+
enable_text_search=self.chat_params.enable_text_search,
103106
filters=filters,
104107
)
105108
return SearchResults(
106109
query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
107110
)
108111

109112
async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
110-
few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
113+
few_shots: list[ResponseInputItemParam] = json.loads(self.query_fewshots)
111114
user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
112-
results = await self.search_agent.run(
113-
user_query,
114-
message_history=few_shots + self.chat_params.past_messages,
115-
deps=self.chat_params,
116-
)
117-
items = results.output.items
115+
new_user_message = EasyInputMessageParam(role="user", content=user_query)
116+
all_messages = few_shots + self.chat_params.past_messages + [new_user_message]
117+
118+
run_results = await Runner.run(self.search_agent, input=all_messages)
119+
most_recent_response = run_results.new_items[-1]
120+
if isinstance(most_recent_response, ToolCallOutputItem):
121+
search_results = most_recent_response.output
122+
else:
123+
raise ValueError("Error retrieving search results, model did not call tool properly")
124+
118125
thoughts = [
119126
ThoughtStep(
120127
title="Prompt to generate search arguments",
121-
description=results.all_messages(),
128+
description=[{"content": self.query_prompt_template}]
129+
+ ItemHelpers.input_to_new_input_list(run_results.input),
122130
props=self.model_for_thoughts,
123131
),
124132
ThoughtStep(
125133
title="Search using generated search arguments",
126-
description=results.output.query,
134+
description=search_results.query,
127135
props={
128136
"top": self.chat_params.top,
129137
"vector_search": self.chat_params.enable_vector_search,
130138
"text_search": self.chat_params.enable_text_search,
131-
"filters": results.output.filters,
139+
"filters": search_results.filters,
132140
},
133141
),
134142
ThoughtStep(
135143
title="Search results",
136-
description=items,
144+
description=search_results.items,
137145
),
138146
]
139-
return items, thoughts
147+
return search_results.items, thoughts
140148

141149
async def answer(
142150
self,
143151
items: list[ItemPublic],
144152
earlier_thoughts: list[ThoughtStep],
145153
) -> RetrievalResponse:
146-
response = await self.answer_agent.run(
147-
user_prompt=self.prepare_rag_request(self.chat_params.original_user_query, items),
148-
message_history=self.chat_params.past_messages,
154+
run_results = await Runner.run(
155+
self.answer_agent,
156+
input=self.chat_params.past_messages
157+
+ [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],
149158
)
150159

151160
return RetrievalResponse(
152-
message=Message(content=str(response.output), role=AIChatRoles.ASSISTANT),
161+
message=Message(content=str(run_results.final_output), role=AIChatRoles.ASSISTANT),
153162
context=RAGContext(
154163
data_points={item.id: item for item in items},
155164
thoughts=earlier_thoughts
156165
+ [
157166
ThoughtStep(
158167
title="Prompt to generate answer",
159-
description=response.all_messages(),
168+
description=[{"content": self.answer_prompt_template}]
169+
+ ItemHelpers.input_to_new_input_list(run_results.input),
160170
props=self.model_for_thoughts,
161171
),
162172
],
@@ -168,24 +178,28 @@ async def answer_stream(
168178
items: list[ItemPublic],
169179
earlier_thoughts: list[ThoughtStep],
170180
) -> AsyncGenerator[RetrievalResponseDelta, None]:
171-
async with self.answer_agent.run_stream(
172-
self.prepare_rag_request(self.chat_params.original_user_query, items),
173-
message_history=self.chat_params.past_messages,
174-
) as agent_stream_runner:
175-
yield RetrievalResponseDelta(
176-
context=RAGContext(
177-
data_points={item.id: item for item in items},
178-
thoughts=earlier_thoughts
179-
+ [
180-
ThoughtStep(
181-
title="Prompt to generate answer",
182-
description=agent_stream_runner.all_messages(),
183-
props=self.model_for_thoughts,
184-
),
185-
],
186-
),
187-
)
188-
189-
async for message in agent_stream_runner.stream_text(delta=True, debounce_by=None):
190-
yield RetrievalResponseDelta(delta=Message(content=str(message), role=AIChatRoles.ASSISTANT))
191-
return
181+
run_results = Runner.run_streamed(
182+
self.answer_agent,
183+
input=self.chat_params.past_messages
184+
+ [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}], # noqa
185+
)
186+
187+
yield RetrievalResponseDelta(
188+
context=RAGContext(
189+
data_points={item.id: item for item in items},
190+
thoughts=earlier_thoughts
191+
+ [
192+
ThoughtStep(
193+
title="Prompt to generate answer",
194+
description=[{"content": self.answer_prompt_template}]
195+
+ ItemHelpers.input_to_new_input_list(run_results.input),
196+
props=self.model_for_thoughts,
197+
),
198+
],
199+
),
200+
)
201+
202+
async for event in run_results.stream_events():
203+
if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
204+
yield RetrievalResponseDelta(delta=Message(content=str(event.data.delta), role=AIChatRoles.ASSISTANT))
205+
return

0 commit comments

Comments
 (0)