Commit 62502b1

Finish refactoring of rag flows

1 parent 076f367

4 files changed, +164 −209 lines changed
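The commit replaces hand-rolled OpenAI chat-completion streaming with pydantic-ai agents that are built once in `__init__` and reused across requests. As a minimal standalone sketch of the streaming pattern the new `answer_stream` relies on (the model string and prompt here are illustrative, not taken from the commit):

    import asyncio

    from pydantic_ai import Agent

    # Illustrative agent; requires OPENAI_API_KEY in the environment.
    agent = Agent("openai:gpt-4o-mini", system_prompt="You are a helpful assistant.")

    async def main():
        # run_stream() opens an async context; stream_text(delta=True) yields
        # incremental text chunks rather than the accumulated response so far.
        async with agent.run_stream("Say hello") as stream:
            async for chunk in stream.stream_text(delta=True, debounce_by=None):
                print(chunk, end="", flush=True)

    asyncio.run(main())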
Lines changed: 65 additions & 96 deletions
@@ -1,10 +1,8 @@
-import os
 from collections.abc import AsyncGenerator
 from typing import Optional, TypedDict, Union
 
-from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream
-from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
-from openai_messages_token_helper import get_token_limit
+from openai import AsyncAzureOpenAI, AsyncOpenAI
+from openai.types.chat import ChatCompletionMessageParam
 from pydantic_ai import Agent, RunContext
 from pydantic_ai.messages import ModelMessagesTypeAdapter
 from pydantic_ai.models.openai import OpenAIModel
@@ -13,22 +11,17 @@
 
 from fastapi_app.api_models import (
     AIChatRoles,
+    ChatRequestOverrides,
     ItemPublic,
     Message,
     RAGContext,
     RetrievalResponse,
     RetrievalResponseDelta,
     ThoughtStep,
 )
-from fastapi_app.postgres_models import Item
 from fastapi_app.postgres_searcher import PostgresSearcher
 from fastapi_app.rag_base import ChatParams, RAGChatBase
 
-# Experiment #1: Annotated did not work!
-# Experiment #2: Function-level docstring, Inline docstrings next to attributes
-# Function -level docstring leads to XML like this: <summary>Search ...
-# Experiment #3: Move the docstrings below the attributes in triple-quoted strings - SUCCESS!!!
-
 
 class PriceFilter(TypedDict):
     column: str = "price"
@@ -64,19 +57,44 @@ class SearchResults(TypedDict):
 
 
 class AdvancedRAGChat(RAGChatBase):
+    query_prompt_template = open(RAGChatBase.prompts_dir / "query.txt").read()
+    query_fewshots = open(RAGChatBase.prompts_dir / "query_fewshots.json").read()
+
     def __init__(
         self,
         *,
+        messages: list[ChatCompletionMessageParam],
+        overrides: ChatRequestOverrides,
         searcher: PostgresSearcher,
         openai_chat_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
         chat_model: str,
         chat_deployment: Optional[str],  # Not needed for non-Azure OpenAI
     ):
         self.searcher = searcher
-        self.openai_chat_client = openai_chat_client
-        self.chat_model = chat_model
-        self.chat_deployment = chat_deployment
-        self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
+        self.chat_params = self.get_chat_params(messages, overrides)
+        self.model_for_thoughts = (
+            {"model": chat_model, "deployment": chat_deployment} if chat_deployment else {"model": chat_model}
+        )
+        pydantic_chat_model = OpenAIModel(
+            chat_model if chat_deployment is None else chat_deployment,
+            provider=OpenAIProvider(openai_client=openai_chat_client),
+        )
+        self.search_agent = Agent(
+            pydantic_chat_model,
+            model_settings=ModelSettings(temperature=0.0, max_tokens=500, seed=self.chat_params.seed),
+            system_prompt=self.query_prompt_template,
+            tools=[self.search_database],
+            output_type=SearchResults,
+        )
+        self.answer_agent = Agent(
+            pydantic_chat_model,
+            system_prompt=self.answer_prompt_template,
+            model_settings=ModelSettings(
+                temperature=self.chat_params.temperature,
+                max_tokens=self.chat_params.response_token_limit,
+                seed=self.chat_params.seed,
+            ),
+        )
 
     async def search_database(
         self,
@@ -113,42 +131,28 @@ async def search_database(
             query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
         )
 
-    async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPublic], list[ThoughtStep]]:
-        model = OpenAIModel(
-            os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=self.openai_chat_client)
-        )
-        agent = Agent(
-            model,
-            model_settings=ModelSettings(temperature=0.0, max_tokens=500, seed=chat_params.seed),
-            system_prompt=self.query_prompt_template,
-            tools=[self.search_database],
-            output_type=SearchResults,
-        )
+    async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
         few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
-        user_query = f"Find search results for user query: {chat_params.original_user_query}"
-        results = await agent.run(
+        user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
+        results = await self.search_agent.run(
             user_query,
-            message_history=few_shots + chat_params.past_messages,
-            deps=chat_params,
+            message_history=few_shots + self.chat_params.past_messages,
+            deps=self.chat_params,
         )
         items = results.output["items"]
         thoughts = [
             ThoughtStep(
                 title="Prompt to generate search arguments",
                 description=results.all_messages(),
-                props=(
-                    {"model": self.chat_model, "deployment": self.chat_deployment}
-                    if self.chat_deployment
-                    else {"model": self.chat_model}  # TODO
-                ),
+                props=self.model_for_thoughts,
             ),
             ThoughtStep(
                 title="Search using generated search arguments",
                 description=results.output["query"],
                 props={
-                    "top": chat_params.top,
-                    "vector_search": chat_params.enable_vector_search,
-                    "text_search": chat_params.enable_text_search,
+                    "top": self.chat_params.top,
+                    "vector_search": self.chat_params.enable_vector_search,
+                    "text_search": self.chat_params.enable_text_search,
                     "filters": results.output["filters"],
                 },
             ),
@@ -161,25 +165,12 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPubli
 
     async def answer(
         self,
-        chat_params: ChatParams,
         items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> RetrievalResponse:
-        agent = Agent(
-            OpenAIModel(
-                os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
-                provider=OpenAIProvider(openai_client=self.openai_chat_client),
-            ),
-            system_prompt=self.answer_prompt_template,
-            model_settings=ModelSettings(
-                temperature=chat_params.temperature, max_tokens=chat_params.response_token_limit, seed=chat_params.seed
-            ),
-        )
-
-        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in items]
-        response = await agent.run(
-            user_prompt=chat_params.original_user_query + "Sources:\n" + "\n".join(sources_content),
-            message_history=chat_params.past_messages,
+        response = await self.answer_agent.run(
+            user_prompt=self.prepare_rag_request(self.chat_params.original_user_query, items),
+            message_history=self.chat_params.past_messages,
         )
 
         return RetrievalResponse(
@@ -191,57 +182,35 @@ async def answer(
                 ThoughtStep(
                     title="Prompt to generate answer",
                     description=response.all_messages(),
-                    props=(
-                        {"model": self.chat_model, "deployment": self.chat_deployment}
-                        if self.chat_deployment
-                        else {"model": self.chat_model}
-                    ),
+                    props=self.model_for_thoughts,
                 ),
             ],
         ),
     )
 
     async def answer_stream(
         self,
-        chat_params: ChatParams,
-        contextual_messages: list[ChatCompletionMessageParam],
-        results: list[Item],
+        items: list[ItemPublic],
        earlier_thoughts: list[ThoughtStep],
    ) -> AsyncGenerator[RetrievalResponseDelta, None]:
-        chat_completion_async_stream: AsyncStream[
-            ChatCompletionChunk
-        ] = await self.openai_chat_client.chat.completions.create(
-            # Azure OpenAI takes the deployment name as the model name
-            model=self.chat_deployment if self.chat_deployment else self.chat_model,
-            messages=contextual_messages,
-            temperature=chat_params.temperature,
-            max_tokens=chat_params.response_token_limit,
-            n=1,
-            stream=True,
-        )
-
-        yield RetrievalResponseDelta(
-            context=RAGContext(
-                data_points={item.id: item.to_dict() for item in results},
-                thoughts=earlier_thoughts
-                + [
-                    ThoughtStep(
-                        title="Prompt to generate answer",
-                        description=contextual_messages,
-                        props=(
-                            {"model": self.chat_model, "deployment": self.chat_deployment}
-                            if self.chat_deployment
-                            else {"model": self.chat_model}
+        async with self.answer_agent.run_stream(
+            self.prepare_rag_request(self.chat_params.original_user_query, items),
+            message_history=self.chat_params.past_messages,
+        ) as agent_stream_runner:
+            yield RetrievalResponseDelta(
+                context=RAGContext(
+                    data_points={item.id: item for item in items},
+                    thoughts=earlier_thoughts
+                    + [
+                        ThoughtStep(
+                            title="Prompt to generate answer",
+                            description=agent_stream_runner.all_messages(),
+                            props=self.model_for_thoughts,
                         ),
-                    ),
-                ],
-            ),
-        )
+                    ],
+                ),
+            )
 
-        async for response_chunk in chat_completion_async_stream:
-            # first response has empty choices and last response has empty content
-            if response_chunk.choices and response_chunk.choices[0].delta.content:
-                yield RetrievalResponseDelta(
-                    delta=Message(content=str(response_chunk.choices[0].delta.content), role=AIChatRoles.ASSISTANT)
-                )
-        return
+            async for message in agent_stream_runner.stream_text(delta=True, debounce_by=None):
+                yield RetrievalResponseDelta(delta=Message(content=str(message), role=AIChatRoles.ASSISTANT))
+            return
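For orientation, a minimal sketch of how a caller might drive the refactored AdvancedRAGChat after this commit; `request`, `searcher`, `openai_client`, and the `streaming` flag are hypothetical stand-ins for route wiring that is not part of this diff, and the import of AdvancedRAGChat is assumed from the module shown above:

    async def handle_chat(request, searcher, openai_client, streaming: bool):
        # Hypothetical route handler. Messages and overrides now go to the
        # constructor, so each method reads self.chat_params instead of having
        # chat_params threaded through every call.
        rag_flow = AdvancedRAGChat(
            messages=request.messages,            # list[ChatCompletionMessageParam]
            overrides=request.context.overrides,  # ChatRequestOverrides
            searcher=searcher,                    # PostgresSearcher
            openai_chat_client=openai_client,
            chat_model="gpt-4o-mini",             # illustrative model name
            chat_deployment=None,                 # deployment name when on Azure OpenAI
        )
        items, thoughts = await rag_flow.prepare_context()  # no arguments anymore
        if streaming:
            async for delta in rag_flow.answer_stream(items, earlier_thoughts=thoughts):
                ...  # forward each RetrievalResponseDelta to the client
        else:
            return await rag_flow.answer(items, earlier_thoughts=thoughts)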

src/backend/fastapi_app/rag_base.py

Lines changed: 10 additions & 5 deletions
@@ -7,6 +7,7 @@
 from fastapi_app.api_models import (
     ChatParams,
     ChatRequestOverrides,
+    ItemPublic,
     RetrievalResponse,
     RetrievalResponseDelta,
     ThoughtStep,
@@ -15,12 +16,12 @@
 
 
 class RAGChatBase(ABC):
-    current_dir = pathlib.Path(__file__).parent
-    query_prompt_template = open(current_dir / "prompts/query.txt").read()
-    query_fewshots = open(current_dir / "prompts/query_fewshots.json").read()
-    answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
+    prompts_dir = pathlib.Path(__file__).parent / "prompts/"
+    answer_prompt_template = open(prompts_dir / "answer.txt").read()
 
-    def get_params(self, messages: list[ChatCompletionMessageParam], overrides: ChatRequestOverrides) -> ChatParams:
+    def get_chat_params(
+        self, messages: list[ChatCompletionMessageParam], overrides: ChatRequestOverrides
+    ) -> ChatParams:
         response_token_limit = 1024
         prompt_template = overrides.prompt_template or self.answer_prompt_template
 
@@ -52,6 +53,10 @@ async def prepare_context(
     ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]:
         raise NotImplementedError
 
+    def prepare_rag_request(self, user_query, items: list[ItemPublic]) -> str:
+        sources_str = "\n".join([f"[{item.id}]:{item.to_str_for_rag()}" for item in items])
+        return f"{user_query}Sources:\n{sources_str}"
+
     @abstractmethod
     async def answer(
         self,
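To make the new prepare_rag_request helper concrete, here is a small self-contained rehearsal of its logic; _FakeItem is a stand-in for ItemPublic (not from the commit), and the function body mirrors the diff above with `self` dropped:

    def prepare_rag_request(user_query: str, items: list) -> str:
        # One "[id]:text" line per retrieved source, appended after the query.
        sources_str = "\n".join([f"[{item.id}]:{item.to_str_for_rag()}" for item in items])
        return f"{user_query}Sources:\n{sources_str}"

    class _FakeItem:
        """Stand-in for ItemPublic with just the two members the helper uses."""
        id = 42

        def to_str_for_rag(self) -> str:
            return "Name: Trailblazer Boots, Price: 89.99"

    print(prepare_rag_request("What boots do you sell? ", [_FakeItem()]))
    # What boots do you sell? Sources:
    # [42]:Name: Trailblazer Boots, Price: 89.99

Note that, as in the inline version it replaces, the helper inserts no separator between the user query and "Sources:", so the query string needs its own trailing space or newline.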
