1
- from typing import Any
2
-
3
1
import fastapi
4
2
from fastapi import HTTPException
5
3
from sqlalchemy import select
6
4
7
- from fastapi_app .api_models import ChatRequest , RetrievalResponse
5
+ from fastapi_app .api_models import ChatRequest , ItemPublic , ItemWithDistance , RetrievalResponse
8
6
from fastapi_app .dependencies import ChatClient , CommonDeps , DBSession , EmbeddingsClient
9
7
from fastapi_app .postgres_models import Item
10
8
from fastapi_app .postgres_searcher import PostgresSearcher
14
12
router = fastapi .APIRouter ()
15
13
16
14
17
- @router .get ("/items/{id}" , response_model = dict [ str , Any ] )
18
- async def item_handler (id : int , database_session : DBSession ) -> dict [ str , Any ] :
15
+ @router .get ("/items/{id}" , response_model = ItemPublic )
16
+ async def item_handler (id : int , database_session : DBSession ) -> ItemPublic :
19
17
"""A simple API to get an item by ID."""
20
18
item = (await database_session .scalars (select (Item ).where (Item .id == id ))).first ()
21
19
if not item :
22
20
raise HTTPException (detail = f"Item with ID { id } not found." , status_code = 404 )
23
- return item .to_dict ()
21
+ return ItemPublic . model_validate ( item .to_dict () )
24
22
25
23
26
- @router .get ("/similar" , response_model = list [dict [ str , Any ] ])
27
- async def similar_handler (database_session : DBSession , id : int , n : int = 5 ) -> list [dict [ str , Any ] ]:
24
+ @router .get ("/similar" , response_model = list [ItemWithDistance ])
25
+ async def similar_handler (database_session : DBSession , id : int , n : int = 5 ) -> list [ItemWithDistance ]:
28
26
"""A similarity API to find items similar to items with given ID."""
29
27
item = (await database_session .scalars (select (Item ).where (Item .id == id ))).first ()
30
28
if not item :
@@ -35,10 +33,12 @@ async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> l
35
33
.order_by (Item .embedding .l2_distance (item .embedding ))
36
34
.limit (n )
37
35
)
38
- return [item .to_dict () | {"distance" : round (distance , 2 )} for item , distance in closest ]
36
+ return [
37
+ ItemWithDistance .model_validate (item .to_dict () | {"distance" : round (distance , 2 )}) for item , distance in closest
38
+ ]
39
39
40
40
41
- @router .get ("/search" , response_model = list [dict [ str , Any ] ])
41
+ @router .get ("/search" , response_model = list [ItemPublic ])
42
42
async def search_handler (
43
43
context : CommonDeps ,
44
44
database_session : DBSession ,
@@ -47,7 +47,7 @@ async def search_handler(
47
47
top : int = 5 ,
48
48
enable_vector_search : bool = True ,
49
49
enable_text_search : bool = True ,
50
- ) -> list [dict [ str , Any ] ]:
50
+ ) -> list [ItemPublic ]:
51
51
"""A search API to find items based on a query."""
52
52
searcher = PostgresSearcher (
53
53
db_session = database_session ,
@@ -59,7 +59,7 @@ async def search_handler(
59
59
results = await searcher .search_and_embed (
60
60
query , top = top , enable_vector_search = enable_vector_search , enable_text_search = enable_text_search
61
61
)
62
- return [item .to_dict () for item in results ]
62
+ return [ItemPublic . model_validate ( item .to_dict () ) for item in results ]
63
63
64
64
65
65
@router .post ("/chat" , response_model = RetrievalResponse )
0 commit comments