1
1
import logging
2
2
import os
3
+ from collections .abc import AsyncGenerator
3
4
from typing import Annotated
4
5
5
6
import azure .identity
6
- from fastapi import Depends
7
+ from fastapi import Depends , Request
7
8
from openai import AsyncAzureOpenAI , AsyncOpenAI
8
9
from pydantic import BaseModel
9
10
from sqlalchemy .ext .asyncio import AsyncEngine , AsyncSession , async_sessionmaker
10
11
11
- from fastapi_app .openai_clients import create_openai_chat_client , create_openai_embed_client
12
- from fastapi_app .postgres_engine import create_postgres_engine_from_env
13
-
14
12
logger = logging .getLogger ("ragapp" )
15
13
16
14
@@ -67,7 +65,7 @@ async def common_parameters():
67
65
)
68
66
69
67
70
- def get_azure_credentials () -> azure .identity .DefaultAzureCredential | azure .identity .ManagedIdentityCredential :
68
+ async def get_azure_credentials () -> azure .identity .DefaultAzureCredential | azure .identity .ManagedIdentityCredential :
71
69
azure_credential : azure .identity .DefaultAzureCredential | azure .identity .ManagedIdentityCredential
72
70
try :
73
71
if client_id := os .getenv ("APP_IDENTITY_ID" ):
@@ -86,35 +84,55 @@ def get_azure_credentials() -> azure.identity.DefaultAzureCredential | azure.ide
86
84
raise e
87
85
88
86
89
- azure_credentials = get_azure_credentials ()
87
+ async def create_async_sessionmaker (engine : AsyncEngine ) -> async_sessionmaker [AsyncSession ]:
88
+ """Get the agent database"""
89
+ return async_sessionmaker (
90
+ engine ,
91
+ expire_on_commit = False ,
92
+ autoflush = False ,
93
+ )
90
94
91
95
92
- async def get_engine ():
93
- """Get the agent database engine"""
94
- engine = await create_postgres_engine_from_env ( azure_credentials )
95
- return engine
96
+ async def get_async_sessionmaker (
97
+ request : Request ,
98
+ ) -> AsyncGenerator [ async_sessionmaker [ AsyncSession ], None ]:
99
+ yield request . state . sessionmaker
96
100
97
101
98
- async def get_async_session (engine : Annotated [AsyncEngine , Depends (get_engine )]):
99
- """Get the agent database"""
100
- async_session_maker = async_sessionmaker (engine , expire_on_commit = False )
101
- async with async_session_maker () as async_session :
102
- yield async_session
102
+ async def get_context (
103
+ request : Request ,
104
+ ) -> FastAPIAppContext :
105
+ return request .state .context
106
+
107
+
108
+ async def get_async_db_session (
109
+ sessionmaker : Annotated [async_sessionmaker [AsyncSession ], Depends (get_async_sessionmaker )],
110
+ ) -> AsyncGenerator [AsyncSession , None ]:
111
+ async with sessionmaker () as session :
112
+ try :
113
+ yield session
114
+ except :
115
+ await session .rollback ()
116
+ raise
117
+ else :
118
+ await session .commit ()
103
119
104
120
105
- async def get_openai_chat_client ():
121
+ async def get_openai_chat_client (
122
+ request : Request ,
123
+ ) -> OpenAIClient :
106
124
"""Get the OpenAI chat client"""
107
- chat_client = await create_openai_chat_client (azure_credentials )
108
- return OpenAIClient (client = chat_client )
125
+ return OpenAIClient (client = request .state .chat_client )
109
126
110
127
111
- async def get_openai_embed_client ():
128
+ async def get_openai_embed_client (
129
+ request : Request ,
130
+ ) -> OpenAIClient :
112
131
"""Get the OpenAI embed client"""
113
- embed_client = await create_openai_embed_client (azure_credentials )
114
- return OpenAIClient (client = embed_client )
132
+ return OpenAIClient (client = request .state .embed_client )
115
133
116
134
117
- CommonDeps = Annotated [FastAPIAppContext , Depends (common_parameters )]
118
- DBSession = Annotated [AsyncSession , Depends (get_async_session )]
135
+ CommonDeps = Annotated [FastAPIAppContext , Depends (get_context )]
136
+ DBSession = Annotated [AsyncSession , Depends (get_async_db_session )]
119
137
ChatClient = Annotated [OpenAIClient , Depends (get_openai_chat_client )]
120
138
EmbeddingsClient = Annotated [OpenAIClient , Depends (get_openai_embed_client )]
0 commit comments