Skip to content

Commit 01a75f5

Browse files
authored
Merge pull request #206 from Azure-Samples/pydanticai
Port to Pydantic-AI
2 parents 6ab9ea5 + 09e317e commit 01a75f5

21 files changed

+824
-406
lines changed

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ pytest-snapshot
1414
locust
1515
psycopg2
1616
dotenv-azd
17+
freezegun

src/backend/fastapi_app/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ class State(TypedDict):
3434
@asynccontextmanager
3535
async def lifespan(app: fastapi.FastAPI) -> AsyncIterator[State]:
3636
context = await common_parameters()
37-
azure_credential = await get_azure_credential()
37+
azure_credential = None
38+
if (
39+
os.getenv("OPENAI_CHAT_HOST") == "azure"
40+
or os.getenv("OPENAI_EMBED_HOST") == "azure"
41+
or os.getenv("POSTGRES_HOST", "").endswith(".database.azure.com")
42+
):
43+
azure_credential = await get_azure_credential()
3844
engine = await create_postgres_engine_from_env(azure_credential)
3945
sessionmaker = await create_async_sessionmaker(engine)
4046
chat_client = await create_openai_chat_client(azure_credential)
@@ -53,6 +59,7 @@ def create_app(testing: bool = False):
5359
if not testing:
5460
load_dotenv(override=True)
5561
logging.basicConfig(level=logging.INFO)
62+
5663
# Turn off particularly noisy INFO level logs from Azure Core SDK:
5764
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)
5865
logging.getLogger("azure.identity").setLevel(logging.WARNING)

src/backend/fastapi_app/api_models.py

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from enum import Enum
2-
from typing import Any, Optional
2+
from typing import Any, Optional, Union
33

44
from openai.types.chat import ChatCompletionMessageParam
5-
from pydantic import BaseModel
5+
from pydantic import BaseModel, Field
6+
from pydantic_ai.messages import ModelRequest, ModelResponse
67

78

89
class AIChatRoles(str, Enum):
@@ -41,14 +42,34 @@ class ChatRequest(BaseModel):
4142
sessionState: Optional[Any] = None
4243

4344

45+
class ItemPublic(BaseModel):
46+
id: int
47+
type: str
48+
brand: str
49+
name: str
50+
description: str
51+
price: float
52+
53+
def to_str_for_rag(self):
54+
return f"Name:{self.name} Description:{self.description} Price:{self.price} Brand:{self.brand} Type:{self.type}"
55+
56+
57+
class ItemWithDistance(ItemPublic):
58+
distance: float
59+
60+
def __init__(self, **data):
61+
super().__init__(**data)
62+
self.distance = round(self.distance, 2)
63+
64+
4465
class ThoughtStep(BaseModel):
4566
title: str
4667
description: Any
4768
props: dict = {}
4869

4970

5071
class RAGContext(BaseModel):
51-
data_points: dict[int, dict[str, Any]]
72+
data_points: dict[int, ItemPublic]
5273
thoughts: list[ThoughtStep]
5374
followup_questions: Optional[list[str]] = None
5475

@@ -69,27 +90,39 @@ class RetrievalResponseDelta(BaseModel):
6990
sessionState: Optional[Any] = None
7091

7192

72-
class ItemPublic(BaseModel):
73-
id: int
74-
type: str
75-
brand: str
76-
name: str
77-
description: str
78-
price: float
79-
80-
81-
class ItemWithDistance(ItemPublic):
82-
distance: float
83-
84-
def __init__(self, **data):
85-
super().__init__(**data)
86-
self.distance = round(self.distance, 2)
87-
88-
8993
class ChatParams(ChatRequestOverrides):
9094
prompt_template: str
9195
response_token_limit: int = 1024
9296
enable_text_search: bool
9397
enable_vector_search: bool
9498
original_user_query: str
95-
past_messages: list[ChatCompletionMessageParam]
99+
past_messages: list[Union[ModelRequest, ModelResponse]]
100+
101+
102+
class Filter(BaseModel):
103+
column: str
104+
comparison_operator: str
105+
value: Any
106+
107+
108+
class PriceFilter(Filter):
109+
column: str = Field(default="price", description="The column to filter on (always 'price' for this filter)")
110+
comparison_operator: str = Field(description="The operator for price comparison ('>', '<', '>=', '<=', '=')")
111+
value: float = Field(description="The price value to compare against (e.g., 30.00)")
112+
113+
114+
class BrandFilter(Filter):
115+
column: str = Field(default="brand", description="The column to filter on (always 'brand' for this filter)")
116+
comparison_operator: str = Field(description="The operator for brand comparison ('=' or '!=')")
117+
value: str = Field(description="The brand name to compare against (e.g., 'AirStrider')")
118+
119+
120+
class SearchResults(BaseModel):
121+
query: str
122+
"""The original search query"""
123+
124+
items: list[ItemPublic]
125+
"""List of items that match the search query and filters"""
126+
127+
filters: list[Filter]
128+
"""List of filters applied to the search results"""

src/backend/fastapi_app/openai_clients.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99

1010

1111
async def create_openai_chat_client(
12-
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
12+
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
1313
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
1414
openai_chat_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
1515
OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
1616
if OPENAI_CHAT_HOST == "azure":
17-
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-03-01-preview"
17+
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-10-21"
1818
azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
1919
azure_deployment = os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"]
2020
if api_key := os.getenv("AZURE_OPENAI_KEY"):
@@ -29,7 +29,7 @@ async def create_openai_chat_client(
2929
azure_deployment=azure_deployment,
3030
api_key=api_key,
3131
)
32-
else:
32+
elif azure_credential:
3333
logger.info(
3434
"Setting up Azure OpenAI client for chat completions using Azure Identity, endpoint %s, deployment %s",
3535
azure_endpoint,
@@ -44,6 +44,8 @@ async def create_openai_chat_client(
4444
azure_deployment=azure_deployment,
4545
azure_ad_token_provider=token_provider,
4646
)
47+
else:
48+
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
4749
elif OPENAI_CHAT_HOST == "ollama":
4850
logger.info("Setting up OpenAI client for chat completions using Ollama")
4951
openai_chat_client = openai.AsyncOpenAI(
@@ -67,7 +69,7 @@ async def create_openai_chat_client(
6769

6870

6971
async def create_openai_embed_client(
70-
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
72+
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
7173
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
7274
openai_embed_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
7375
OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST")
@@ -87,7 +89,7 @@ async def create_openai_embed_client(
8789
azure_deployment=azure_deployment,
8890
api_key=api_key,
8991
)
90-
else:
92+
elif azure_credential:
9193
logger.info(
9294
"Setting up Azure OpenAI client for embeddings using Azure Identity, endpoint %s, deployment %s",
9395
azure_endpoint,
@@ -102,6 +104,8 @@ async def create_openai_embed_client(
102104
azure_deployment=azure_deployment,
103105
azure_ad_token_provider=token_provider,
104106
)
107+
else:
108+
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
105109
elif OPENAI_EMBED_HOST == "ollama":
106110
logger.info("Setting up OpenAI client for embeddings using Ollama")
107111
openai_embed_client = openai.AsyncOpenAI(

src/backend/fastapi_app/postgres_searcher.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sqlalchemy import Float, Integer, column, select, text
66
from sqlalchemy.ext.asyncio import AsyncSession
77

8+
from fastapi_app.api_models import Filter
89
from fastapi_app.embeddings import compute_text_embedding
910
from fastapi_app.postgres_models import Item
1011

@@ -26,21 +27,24 @@ def __init__(
2627
self.embed_dimensions = embed_dimensions
2728
self.embedding_column = embedding_column
2829

29-
def build_filter_clause(self, filters) -> tuple[str, str]:
30+
def build_filter_clause(self, filters: Optional[list[Filter]]) -> tuple[str, str]:
3031
if filters is None:
3132
return "", ""
3233
filter_clauses = []
3334
for filter in filters:
34-
if isinstance(filter["value"], str):
35-
filter["value"] = f"'{filter['value']}'"
36-
filter_clauses.append(f"{filter['column']} {filter['comparison_operator']} {filter['value']}")
35+
filter_value = f"'{filter.value}'" if isinstance(filter.value, str) else filter.value
36+
filter_clauses.append(f"{filter.column} {filter.comparison_operator} {filter_value}")
3737
filter_clause = " AND ".join(filter_clauses)
3838
if len(filter_clause) > 0:
3939
return f"WHERE {filter_clause}", f"AND {filter_clause}"
4040
return "", ""
4141

4242
async def search(
43-
self, query_text: Optional[str], query_vector: list[float], top: int = 5, filters: Optional[list[dict]] = None
43+
self,
44+
query_text: Optional[str],
45+
query_vector: list[float],
46+
top: int = 5,
47+
filters: Optional[list[Filter]] = None,
4448
):
4549
filter_clause_where, filter_clause_and = self.build_filter_clause(filters)
4650
table_name = Item.__tablename__
@@ -106,7 +110,7 @@ async def search_and_embed(
106110
top: int = 5,
107111
enable_vector_search: bool = False,
108112
enable_text_search: bool = False,
109-
filters: Optional[list[dict]] = None,
113+
filters: Optional[list[Filter]] = None,
110114
) -> list[Item]:
111115
"""
112116
Search rows by query text. Optionally converts the query text to a vector if enable_vector_search is True.
Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching database rows.
2-
You have access to an Azure PostgreSQL database with an items table that has columns for title, description, brand, price, and type.
3-
Generate a search query based on the conversation and the new question.
4-
If the question is not in English, translate the question to English before generating the search query.
5-
If you cannot generate a search query, return the original user question.
6-
DO NOT return anything besides the query.
1+
Your job is to find search results based off the user's question and past messages.
2+
You have access to only these tools:
3+
1. **search_database**: This tool allows you to search a table for items based on a query.
4+
You can pass in a search query and optional filters.
5+
Once you get the search results, you're done.
Lines changed: 74 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,76 @@
11
[
2-
{"role": "user", "content": "good options for climbing gear that can be used outside?"},
3-
{"role": "assistant", "tool_calls": [
4-
{
5-
"id": "call_abc123",
6-
"type": "function",
7-
"function": {
8-
"arguments": "{\"search_query\":\"climbing gear outside\"}",
9-
"name": "search_database"
10-
}
11-
}
12-
]},
13-
{
14-
"role": "tool",
15-
"tool_call_id": "call_abc123",
16-
"content": "Search results for climbing gear that can be used outside: ..."
17-
},
18-
{"role": "user", "content": "are there any shoes less than $50?"},
19-
{"role": "assistant", "tool_calls": [
20-
{
21-
"id": "call_abc456",
22-
"type": "function",
23-
"function": {
24-
"arguments": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
25-
"name": "search_database"
26-
}
27-
}
28-
]},
29-
{
30-
"role": "tool",
31-
"tool_call_id": "call_abc456",
32-
"content": "Search results for shoes cheaper than 50: ..."
33-
}
2+
{
3+
"parts": [
4+
{
5+
"content": "good options for climbing gear that can be used outside?",
6+
"timestamp": "2025-05-07T19:02:46.977501Z",
7+
"part_kind": "user-prompt"
8+
}
9+
],
10+
"instructions": null,
11+
"kind": "request"
12+
},
13+
{
14+
"parts": [
15+
{
16+
"tool_name": "search_database",
17+
"args": "{\"search_query\":\"climbing gear outside\"}",
18+
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
19+
"part_kind": "tool-call"
20+
}
21+
],
22+
"model_name": "gpt-4o-mini-2024-07-18",
23+
"timestamp": "2025-05-07T19:02:47Z",
24+
"kind": "response"
25+
},
26+
{
27+
"parts": [
28+
{
29+
"tool_name": "search_database",
30+
"content": "Search results for climbing gear that can be used outside: ...",
31+
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
32+
"timestamp": "2025-05-07T19:02:48.242408Z",
33+
"part_kind": "tool-return"
34+
}
35+
],
36+
"instructions": null,
37+
"kind": "request"
38+
},
39+
{
40+
"parts": [
41+
{
42+
"content": "are there any shoes less than $50?",
43+
"timestamp": "2025-05-07T19:02:46.977501Z",
44+
"part_kind": "user-prompt"
45+
}
46+
],
47+
"instructions": null,
48+
"kind": "request"
49+
},
50+
{
51+
"parts": [
52+
{
53+
"tool_name": "search_database",
54+
"args": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
55+
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
56+
"part_kind": "tool-call"
57+
}
58+
],
59+
"model_name": "gpt-4o-mini-2024-07-18",
60+
"timestamp": "2025-05-07T19:02:47Z",
61+
"kind": "response"
62+
},
63+
{
64+
"parts": [
65+
{
66+
"tool_name": "search_database",
67+
"content": "Search results for shoes cheaper than 50: ...",
68+
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
69+
"timestamp": "2025-05-07T19:02:48.242408Z",
70+
"part_kind": "tool-return"
71+
}
72+
],
73+
"instructions": null,
74+
"kind": "request"
75+
}
3476
]

0 commit comments

Comments
 (0)