+import json
 from collections.abc import AsyncGenerator
 from typing import Optional, Union

+from agents import (
+    Agent,
+    ItemHelpers,
+    ModelSettings,
+    OpenAIChatCompletionsModel,
+    Runner,
+    ToolCallOutputItem,
+    function_tool,
+    set_tracing_disabled,
+)
 from openai import AsyncAzureOpenAI, AsyncOpenAI
-from openai.types.chat import ChatCompletionMessageParam
-from pydantic_ai import Agent, RunContext
-from pydantic_ai.messages import ModelMessagesTypeAdapter
-from pydantic_ai.models.openai import OpenAIModel
-from pydantic_ai.providers.openai import OpenAIProvider
-from pydantic_ai.settings import ModelSettings
+from openai.types.responses import EasyInputMessageParam, ResponseInputItemParam, ResponseTextDeltaEvent

 from fastapi_app.api_models import (
     AIChatRoles,
     ...
     ThoughtStep,
 )
 from fastapi_app.postgres_searcher import PostgresSearcher
-from fastapi_app.rag_base import ChatParams, RAGChatBase
+from fastapi_app.rag_base import RAGChatBase
+
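+# The Agents SDK exports traces to OpenAI by default; disable that globally.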
+set_tracing_disabled(disabled=True)


 class AdvancedRAGChat(RAGChatBase):
@@ -34,7 +42,7 @@ class AdvancedRAGChat(RAGChatBase):
     def __init__(
         self,
         *,
-        messages: list[ChatCompletionMessageParam],
+        messages: list[ResponseInputItemParam],
         overrides: ChatRequestOverrides,
         searcher: PostgresSearcher,
         openai_chat_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
@@ -46,34 +54,29 @@ def __init__(
         self.model_for_thoughts = (
             {"model": chat_model, "deployment": chat_deployment} if chat_deployment else {"model": chat_model}
         )
-        pydantic_chat_model = OpenAIModel(
-            chat_model if chat_deployment is None else chat_deployment,
-            provider=OpenAIProvider(openai_client=openai_chat_client),
+        openai_agents_model = OpenAIChatCompletionsModel(
+            model=chat_model if chat_deployment is None else chat_deployment, openai_client=openai_chat_client
         )
-        self.search_agent = Agent[ChatParams, SearchResults](
-            pydantic_chat_model,
-            model_settings=ModelSettings(
-                temperature=0.0,
-                max_tokens=500,
-                **({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
-            ),
-            system_prompt=self.query_prompt_template,
-            tools=[self.search_database],
-            output_type=SearchResults,
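+        # With tool_use_behavior="stop_on_first_tool", the run ends as soon as
+        # the first tool call returns; its output becomes the run's final output.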
+        self.search_agent = Agent(
+            name="Searcher",
+            instructions=self.query_prompt_template,
+            tools=[function_tool(self.search_database)],
+            tool_use_behavior="stop_on_first_tool",
+            model=openai_agents_model,
         )
         self.answer_agent = Agent(
-            pydantic_chat_model,
-            system_prompt=self.answer_prompt_template,
+            name="Answerer",
+            instructions=self.answer_prompt_template,
+            model=openai_agents_model,
             model_settings=ModelSettings(
                 temperature=self.chat_params.temperature,
                 max_tokens=self.chat_params.response_token_limit,
-                **({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
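+                # extra_body merges extra parameters (here, seed) into the raw API request.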
+                extra_body={"seed": self.chat_params.seed} if self.chat_params.seed is not None else {},
             ),
         )

     async def search_database(
         self,
-        ctx: RunContext[ChatParams],
         search_query: str,
         price_filter: Optional[PriceFilter] = None,
         brand_filter: Optional[BrandFilter] = None,
@@ -97,66 +100,73 @@ async def search_database(
             filters.append(brand_filter)
         results = await self.searcher.search_and_embed(
             search_query,
-            top=ctx.deps.top,
-            enable_vector_search=ctx.deps.enable_vector_search,
-            enable_text_search=ctx.deps.enable_text_search,
+            top=self.chat_params.top,
+            enable_vector_search=self.chat_params.enable_vector_search,
+            enable_text_search=self.chat_params.enable_text_search,
             filters=filters,
         )
         return SearchResults(
             query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
         )

     async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
-        few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
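+        # Few-shot examples are stored as JSON already shaped like Responses API
+        # input items, so json.loads is enough (no pydantic-ai adapter needed).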
+        few_shots: list[ResponseInputItemParam] = json.loads(self.query_fewshots)
         user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
-        results = await self.search_agent.run(
-            user_query,
-            message_history=few_shots + self.chat_params.past_messages,
-            deps=self.chat_params,
-        )
-        items = results.output.items
+        new_user_message = EasyInputMessageParam(role="user", content=user_query)
+        all_messages = few_shots + self.chat_params.past_messages + [new_user_message]
+
+        run_results = await Runner.run(self.search_agent, input=all_messages)
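+        # stop_on_first_tool ends the run right after the tool call, so the last
+        # run item should be the tool's output; anything else means no tool call happened.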
+        most_recent_response = run_results.new_items[-1]
+        if isinstance(most_recent_response, ToolCallOutputItem):
+            search_results = most_recent_response.output
+        else:
+            raise ValueError("Error retrieving search results, model did not call tool properly")
+
         thoughts = [
             ThoughtStep(
                 title="Prompt to generate search arguments",
-                description=results.all_messages(),
+                description=[{"content": self.query_prompt_template}]
+                + ItemHelpers.input_to_new_input_list(run_results.input),
                 props=self.model_for_thoughts,
             ),
             ThoughtStep(
                 title="Search using generated search arguments",
-                description=results.output.query,
+                description=search_results.query,
                 props={
                     "top": self.chat_params.top,
                     "vector_search": self.chat_params.enable_vector_search,
                     "text_search": self.chat_params.enable_text_search,
-                    "filters": results.output.filters,
+                    "filters": search_results.filters,
                 },
             ),
             ThoughtStep(
                 title="Search results",
-                description=items,
+                description=search_results.items,
             ),
         ]
-        return items, thoughts
+        return search_results.items, thoughts

     async def answer(
         self,
         items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> RetrievalResponse:
-        response = await self.answer_agent.run(
-            user_prompt=self.prepare_rag_request(self.chat_params.original_user_query, items),
-            message_history=self.chat_params.past_messages,
+        run_results = await Runner.run(
+            self.answer_agent,
+            input=self.chat_params.past_messages
+            + [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],
         )

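+        # With no output_type set on the agent, final_output is the plain text answer.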
         return RetrievalResponse(
-            message=Message(content=str(response.output), role=AIChatRoles.ASSISTANT),
+            message=Message(content=str(run_results.final_output), role=AIChatRoles.ASSISTANT),
             context=RAGContext(
                 data_points={item.id: item for item in items},
                 thoughts=earlier_thoughts
                 + [
                     ThoughtStep(
                         title="Prompt to generate answer",
-                        description=response.all_messages(),
+                        description=[{"content": self.answer_prompt_template}]
+                        + ItemHelpers.input_to_new_input_list(run_results.input),
                         props=self.model_for_thoughts,
                     ),
                 ],
@@ -168,24 +178,28 @@ async def answer_stream(
         items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> AsyncGenerator[RetrievalResponseDelta, None]:
-        async with self.answer_agent.run_stream(
-            self.prepare_rag_request(self.chat_params.original_user_query, items),
-            message_history=self.chat_params.past_messages,
-        ) as agent_stream_runner:
-            yield RetrievalResponseDelta(
-                context=RAGContext(
-                    data_points={item.id: item for item in items},
-                    thoughts=earlier_thoughts
-                    + [
-                        ThoughtStep(
-                            title="Prompt to generate answer",
-                            description=agent_stream_runner.all_messages(),
-                            props=self.model_for_thoughts,
-                        ),
-                    ],
-                ),
-            )
-
-        async for message in agent_stream_runner.stream_text(delta=True, debounce_by=None):
-            yield RetrievalResponseDelta(delta=Message(content=str(message), role=AIChatRoles.ASSISTANT))
-        return
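+        # run_streamed returns immediately; the model runs as stream_events() is consumed below.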
+        run_results = Runner.run_streamed(
+            self.answer_agent,
+            input=self.chat_params.past_messages
+            + [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],  # noqa
+        )
+
+        yield RetrievalResponseDelta(
+            context=RAGContext(
+                data_points={item.id: item for item in items},
+                thoughts=earlier_thoughts
+                + [
+                    ThoughtStep(
+                        title="Prompt to generate answer",
+                        description=[{"content": self.answer_prompt_template}]
+                        + ItemHelpers.input_to_new_input_list(run_results.input),
+                        props=self.model_for_thoughts,
+                    ),
+                ],
+            ),
+        )
+
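+        # Surface only raw text deltas from the Responses API stream; other event
+        # types (tool calls, run item updates) are ignored.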
+        async for event in run_results.stream_events():
+            if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
+                yield RetrievalResponseDelta(delta=Message(content=str(event.data.delta), role=AIChatRoles.ASSISTANT))
+        return