Skip to content

Commit 3f0286b

Browse files
committed
add pydatnic types for Item table
1 parent 79a8a2d commit 3f0286b

File tree

3 files changed

+27
-21
lines changed

3 files changed

+27
-21
lines changed

src/fastapi_app/api_models.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,16 @@ class RetrievalResponse(BaseModel):
3030
message: Message
3131
context: RAGContext
3232
session_state: Any | None = None
33+
34+
35+
class ItemPublic(BaseModel):
36+
id: int
37+
type: str
38+
brand: str
39+
name: str
40+
description: str
41+
price: float
42+
43+
44+
class ItemWithDistance(ItemPublic):
45+
distance: float

src/fastapi_app/routes/api_routes.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
from typing import Any
2-
31
import fastapi
42
from fastapi import HTTPException
53
from sqlalchemy import select
64

7-
from fastapi_app.api_models import ChatRequest, RetrievalResponse
5+
from fastapi_app.api_models import ChatRequest, ItemPublic, ItemWithDistance, RetrievalResponse
86
from fastapi_app.dependencies import ChatClient, CommonDeps, DBSession, EmbeddingsClient
97
from fastapi_app.postgres_models import Item
108
from fastapi_app.postgres_searcher import PostgresSearcher
@@ -14,17 +12,17 @@
1412
router = fastapi.APIRouter()
1513

1614

17-
@router.get("/items/{id}", response_model=dict[str, Any])
18-
async def item_handler(id: int, database_session: DBSession) -> dict[str, Any]:
15+
@router.get("/items/{id}", response_model=ItemPublic)
16+
async def item_handler(id: int, database_session: DBSession) -> ItemPublic:
1917
"""A simple API to get an item by ID."""
2018
item = (await database_session.scalars(select(Item).where(Item.id == id))).first()
2119
if not item:
2220
raise HTTPException(detail=f"Item with ID {id} not found.", status_code=404)
23-
return item.to_dict()
21+
return ItemPublic.model_validate(item.to_dict())
2422

2523

26-
@router.get("/similar", response_model=list[dict[str, Any]])
27-
async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> list[dict[str, Any]]:
24+
@router.get("/similar", response_model=list[ItemWithDistance])
25+
async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> list[ItemWithDistance]:
2826
"""A similarity API to find items similar to items with given ID."""
2927
item = (await database_session.scalars(select(Item).where(Item.id == id))).first()
3028
if not item:
@@ -35,10 +33,12 @@ async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> l
3533
.order_by(Item.embedding.l2_distance(item.embedding))
3634
.limit(n)
3735
)
38-
return [item.to_dict() | {"distance": round(distance, 2)} for item, distance in closest]
36+
return [
37+
ItemWithDistance.model_validate(item.to_dict() | {"distance": round(distance, 2)}) for item, distance in closest
38+
]
3939

4040

41-
@router.get("/search", response_model=list[dict[str, Any]])
41+
@router.get("/search", response_model=list[ItemPublic])
4242
async def search_handler(
4343
context: CommonDeps,
4444
database_session: DBSession,
@@ -47,7 +47,7 @@ async def search_handler(
4747
top: int = 5,
4848
enable_vector_search: bool = True,
4949
enable_text_search: bool = True,
50-
) -> list[dict[str, Any]]:
50+
) -> list[ItemPublic]:
5151
"""A search API to find items based on a query."""
5252
searcher = PostgresSearcher(
5353
db_session=database_session,
@@ -59,7 +59,7 @@ async def search_handler(
5959
results = await searcher.search_and_embed(
6060
query, top=top, enable_vector_search=enable_vector_search, enable_text_search=enable_text_search
6161
)
62-
return [item.to_dict() for item in results]
62+
return [ItemPublic.model_validate(item.to_dict()) for item in results]
6363

6464

6565
@router.post("/chat", response_model=RetrievalResponse)

tests/data.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,7 @@
1-
from dataclasses import dataclass
1+
from fastapi_app.api_models import ItemPublic
22

33

4-
@dataclass
5-
class TestData:
6-
id: int
7-
type: str
8-
brand: str
9-
name: str
10-
description: str
11-
price: float
4+
class TestData(ItemPublic):
125
embeddings: list[float]
136

147

0 commit comments

Comments
 (0)