-import os
 from collections.abc import AsyncGenerator
 from typing import Optional, TypedDict, Union

-from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream
-from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
-from openai_messages_token_helper import get_token_limit
+from openai import AsyncAzureOpenAI, AsyncOpenAI
+from openai.types.chat import ChatCompletionMessageParam

 from pydantic_ai import Agent, RunContext
 from pydantic_ai.messages import ModelMessagesTypeAdapter
 from pydantic_ai.models.openai import OpenAIModel
 from pydantic_ai.providers.openai import OpenAIProvider
 from pydantic_ai.settings import ModelSettings

 from fastapi_app.api_models import (
     AIChatRoles,
+    ChatRequestOverrides,
     ItemPublic,
     Message,
     RAGContext,
     RetrievalResponse,
     RetrievalResponseDelta,
     ThoughtStep,
 )
-from fastapi_app.postgres_models import Item
 from fastapi_app.postgres_searcher import PostgresSearcher
 from fastapi_app.rag_base import ChatParams, RAGChatBase

-# Experiment #1: Annotated did not work!
-# Experiment #2: Function-level docstring, Inline docstrings next to attributes
-#   Function-level docstring leads to XML like this: <summary>Search ...
-# Experiment #3: Move the docstrings below the attributes in triple-quoted strings - SUCCESS!!!
-

 class PriceFilter(TypedDict):
     column: str = "price"
@@ -64,19 +57,44 @@ class SearchResults(TypedDict):


 class AdvancedRAGChat(RAGChatBase):
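+    # Query prompt template and few-shot examples are read once, at class-definition time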
+    query_prompt_template = open(RAGChatBase.prompts_dir / "query.txt").read()
+    query_fewshots = open(RAGChatBase.prompts_dir / "query_fewshots.json").read()
+
     def __init__(
         self,
         *,
+        messages: list[ChatCompletionMessageParam],
+        overrides: ChatRequestOverrides,
         searcher: PostgresSearcher,
         openai_chat_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
         chat_model: str,
         chat_deployment: Optional[str],  # Not needed for non-Azure OpenAI
     ):
         self.searcher = searcher
-        self.openai_chat_client = openai_chat_client
-        self.chat_model = chat_model
-        self.chat_deployment = chat_deployment
-        self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
+        self.chat_params = self.get_chat_params(messages, overrides)
+        self.model_for_thoughts = (
+            {"model": chat_model, "deployment": chat_deployment} if chat_deployment else {"model": chat_model}
+        )
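+        # Azure OpenAI identifies the model by its deployment name, so the deployment (when set) is used as the model name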
+        pydantic_chat_model = OpenAIModel(
+            chat_model if chat_deployment is None else chat_deployment,
+            provider=OpenAIProvider(openai_client=openai_chat_client),
+        )
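+        # Query-rewriting agent: temperature 0.0 keeps search-argument generation near-deterministic,
+        # and output_type=SearchResults constrains the agent to return the typed search payload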
+        self.search_agent = Agent(
+            pydantic_chat_model,
+            model_settings=ModelSettings(temperature=0.0, max_tokens=500, seed=self.chat_params.seed),
+            system_prompt=self.query_prompt_template,
+            tools=[self.search_database],
+            output_type=SearchResults,
+        )
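+        # Answering agent: generation settings come from the per-request overrides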
+        self.answer_agent = Agent(
+            pydantic_chat_model,
+            system_prompt=self.answer_prompt_template,
+            model_settings=ModelSettings(
+                temperature=self.chat_params.temperature,
+                max_tokens=self.chat_params.response_token_limit,
+                seed=self.chat_params.seed,
+            ),
+        )

     async def search_database(
         self,
@@ -113,42 +131,28 @@ async def search_database(
             query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
         )

-    async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPublic], list[ThoughtStep]]:
-        model = OpenAIModel(
-            os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=self.openai_chat_client)
-        )
-        agent = Agent(
-            model,
-            model_settings=ModelSettings(temperature=0.0, max_tokens=500, seed=chat_params.seed),
-            system_prompt=self.query_prompt_template,
-            tools=[self.search_database],
-            output_type=SearchResults,
-        )
+    async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
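+        # Few-shot examples are parsed into pydantic-ai messages and prepended to the chat history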
         few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
-        user_query = f"Find search results for user query: {chat_params.original_user_query}"
-        results = await agent.run(
+        user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
+        results = await self.search_agent.run(
             user_query,
-            message_history=few_shots + chat_params.past_messages,
-            deps=chat_params,
+            message_history=few_shots + self.chat_params.past_messages,
+            deps=self.chat_params,
         )
         items = results.output["items"]
         thoughts = [
             ThoughtStep(
                 title="Prompt to generate search arguments",
                 description=results.all_messages(),
-                props=(
-                    {"model": self.chat_model, "deployment": self.chat_deployment}
-                    if self.chat_deployment
-                    else {"model": self.chat_model}  # TODO
-                ),
+                props=self.model_for_thoughts,
             ),
             ThoughtStep(
                 title="Search using generated search arguments",
                 description=results.output["query"],
                 props={
-                    "top": chat_params.top,
-                    "vector_search": chat_params.enable_vector_search,
-                    "text_search": chat_params.enable_text_search,
+                    "top": self.chat_params.top,
+                    "vector_search": self.chat_params.enable_vector_search,
+                    "text_search": self.chat_params.enable_text_search,
                     "filters": results.output["filters"],
                 },
             ),
@@ -161,25 +165,12 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPubli

     async def answer(
         self,
-        chat_params: ChatParams,
         items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> RetrievalResponse:
-        agent = Agent(
-            OpenAIModel(
-                os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
-                provider=OpenAIProvider(openai_client=self.openai_chat_client),
-            ),
-            system_prompt=self.answer_prompt_template,
-            model_settings=ModelSettings(
-                temperature=chat_params.temperature, max_tokens=chat_params.response_token_limit, seed=chat_params.seed
-            ),
-        )
-
-        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in items]
-        response = await agent.run(
-            user_prompt=chat_params.original_user_query + "Sources:\n" + "\n".join(sources_content),
-            message_history=chat_params.past_messages,
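+        # prepare_rag_request bundles the user question and the retrieved sources into a single prompt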
+        response = await self.answer_agent.run(
+            user_prompt=self.prepare_rag_request(self.chat_params.original_user_query, items),
+            message_history=self.chat_params.past_messages,
         )

         return RetrievalResponse(
@@ -191,57 +182,35 @@ async def answer(
                     ThoughtStep(
                         title="Prompt to generate answer",
                         description=response.all_messages(),
-                        props=(
-                            {"model": self.chat_model, "deployment": self.chat_deployment}
-                            if self.chat_deployment
-                            else {"model": self.chat_model}
-                        ),
+                        props=self.model_for_thoughts,
                     ),
                 ],
             ),
         )

     async def answer_stream(
         self,
-        chat_params: ChatParams,
-        contextual_messages: list[ChatCompletionMessageParam],
-        results: list[Item],
+        items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> AsyncGenerator[RetrievalResponseDelta, None]:
-        chat_completion_async_stream: AsyncStream[
-            ChatCompletionChunk
-        ] = await self.openai_chat_client.chat.completions.create(
-            # Azure OpenAI takes the deployment name as the model name
-            model=self.chat_deployment if self.chat_deployment else self.chat_model,
-            messages=contextual_messages,
-            temperature=chat_params.temperature,
-            max_tokens=chat_params.response_token_limit,
-            n=1,
-            stream=True,
-        )
-
-        yield RetrievalResponseDelta(
-            context=RAGContext(
-                data_points={item.id: item.to_dict() for item in results},
-                thoughts=earlier_thoughts
-                + [
-                    ThoughtStep(
-                        title="Prompt to generate answer",
-                        description=contextual_messages,
-                        props=(
-                            {"model": self.chat_model, "deployment": self.chat_deployment}
-                            if self.chat_deployment
-                            else {"model": self.chat_model}
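+        # run_stream opens a streaming run; messages and text become available incrementally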
+        async with self.answer_agent.run_stream(
+            self.prepare_rag_request(self.chat_params.original_user_query, items),
+            message_history=self.chat_params.past_messages,
+        ) as agent_stream_runner:
+            yield RetrievalResponseDelta(
+                context=RAGContext(
+                    data_points={item.id: item for item in items},
+                    thoughts=earlier_thoughts
+                    + [
+                        ThoughtStep(
+                            title="Prompt to generate answer",
+                            description=agent_stream_runner.all_messages(),
+                            props=self.model_for_thoughts,
                         ),
-                    ),
-                ],
-            ),
-        )
+                    ],
+                ),
+            )

-        async for response_chunk in chat_completion_async_stream:
-            # first response has empty choices and last response has empty content
-            if response_chunk.choices and response_chunk.choices[0].delta.content:
-                yield RetrievalResponseDelta(
-                    delta=Message(content=str(response_chunk.choices[0].delta.content), role=AIChatRoles.ASSISTANT)
-                )
-        return
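+            # stream_text(delta=True) yields only the new text in each chunk; debounce_by=None disables buffering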
+            async for message in agent_stream_runner.stream_text(delta=True, debounce_by=None):
+                yield RetrievalResponseDelta(delta=Message(content=str(message), role=AIChatRoles.ASSISTANT))
+            return