Skip to content

Commit 79a8a2d

Browse files
committed
use app state to store global vars
there is only one engine, sessionmaker, azure_credentials, context, chat_client, and embed_client created during the lifespan of the fastapi app
1 parent da5220c commit 79a8a2d

File tree

3 files changed

+83
-28
lines changed

3 files changed

+83
-28
lines changed

src/fastapi_app/__init__.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,46 @@
11
import logging
22
import os
3+
from collections.abc import AsyncIterator
4+
from contextlib import asynccontextmanager
5+
from typing import TypedDict
36

47
from dotenv import load_dotenv
58
from fastapi import FastAPI
9+
from openai import AsyncAzureOpenAI, AsyncOpenAI
10+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
11+
12+
from fastapi_app.dependencies import (
13+
FastAPIAppContext,
14+
common_parameters,
15+
create_async_sessionmaker,
16+
get_azure_credentials,
17+
)
18+
from fastapi_app.openai_clients import create_openai_chat_client, create_openai_embed_client
19+
from fastapi_app.postgres_engine import create_postgres_engine_from_env
620

721
logger = logging.getLogger("ragapp")
822

923

24+
class State(TypedDict):
25+
sessionmaker: async_sessionmaker[AsyncSession]
26+
context: FastAPIAppContext
27+
chat_client: AsyncOpenAI | AsyncAzureOpenAI
28+
embed_client: AsyncOpenAI | AsyncAzureOpenAI
29+
30+
31+
@asynccontextmanager
32+
async def lifespan(app: FastAPI) -> AsyncIterator[State]:
33+
context = await common_parameters()
34+
azure_credential = await get_azure_credentials()
35+
engine = await create_postgres_engine_from_env(azure_credential)
36+
sessionmaker = await create_async_sessionmaker(engine)
37+
chat_client = await create_openai_chat_client(azure_credential)
38+
embed_client = await create_openai_embed_client(azure_credential)
39+
40+
yield {"sessionmaker": sessionmaker, "context": context, "chat_client": chat_client, "embed_client": embed_client}
41+
await engine.dispose()
42+
43+
1044
def create_app(testing: bool = False):
1145
if os.getenv("RUNNING_IN_PRODUCTION"):
1246
logging.basicConfig(level=logging.WARNING)
@@ -15,7 +49,7 @@ def create_app(testing: bool = False):
1549
load_dotenv(override=True)
1650
logging.basicConfig(level=logging.INFO)
1751

18-
app = FastAPI(docs_url="/docs")
52+
app = FastAPI(docs_url="/docs", lifespan=lifespan)
1953

2054
from fastapi_app.routes import api_routes, frontend_routes
2155

src/fastapi_app/dependencies.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
import logging
22
import os
3+
from collections.abc import AsyncGenerator
34
from typing import Annotated
45

56
import azure.identity
6-
from fastapi import Depends
7+
from fastapi import Depends, Request
78
from openai import AsyncAzureOpenAI, AsyncOpenAI
89
from pydantic import BaseModel
910
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
1011

11-
from fastapi_app.openai_clients import create_openai_chat_client, create_openai_embed_client
12-
from fastapi_app.postgres_engine import create_postgres_engine_from_env
13-
1412
logger = logging.getLogger("ragapp")
1513

1614

@@ -67,7 +65,7 @@ async def common_parameters():
6765
)
6866

6967

70-
def get_azure_credentials() -> azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential:
68+
async def get_azure_credentials() -> azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential:
7169
azure_credential: azure.identity.DefaultAzureCredential | azure.identity.ManagedIdentityCredential
7270
try:
7371
if client_id := os.getenv("APP_IDENTITY_ID"):
@@ -86,35 +84,55 @@ def get_azure_credentials() -> azure.identity.DefaultAzureCredential | azure.ide
8684
raise e
8785

8886

89-
azure_credentials = get_azure_credentials()
87+
async def create_async_sessionmaker(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
88+
"""Get the agent database"""
89+
return async_sessionmaker(
90+
engine,
91+
expire_on_commit=False,
92+
autoflush=False,
93+
)
9094

9195

92-
async def get_engine():
93-
"""Get the agent database engine"""
94-
engine = await create_postgres_engine_from_env(azure_credentials)
95-
return engine
96+
async def get_async_sessionmaker(
97+
request: Request,
98+
) -> AsyncGenerator[async_sessionmaker[AsyncSession], None]:
99+
yield request.state.sessionmaker
96100

97101

98-
async def get_async_session(engine: Annotated[AsyncEngine, Depends(get_engine)]):
99-
"""Get the agent database"""
100-
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
101-
async with async_session_maker() as async_session:
102-
yield async_session
102+
async def get_context(
103+
request: Request,
104+
) -> FastAPIAppContext:
105+
return request.state.context
106+
107+
108+
async def get_async_db_session(
109+
sessionmaker: Annotated[async_sessionmaker[AsyncSession], Depends(get_async_sessionmaker)],
110+
) -> AsyncGenerator[AsyncSession, None]:
111+
async with sessionmaker() as session:
112+
try:
113+
yield session
114+
except:
115+
await session.rollback()
116+
raise
117+
else:
118+
await session.commit()
103119

104120

105-
async def get_openai_chat_client():
121+
async def get_openai_chat_client(
122+
request: Request,
123+
) -> OpenAIClient:
106124
"""Get the OpenAI chat client"""
107-
chat_client = await create_openai_chat_client(azure_credentials)
108-
return OpenAIClient(client=chat_client)
125+
return OpenAIClient(client=request.state.chat_client)
109126

110127

111-
async def get_openai_embed_client():
128+
async def get_openai_embed_client(
129+
request: Request,
130+
) -> OpenAIClient:
112131
"""Get the OpenAI embed client"""
113-
embed_client = await create_openai_embed_client(azure_credentials)
114-
return OpenAIClient(client=embed_client)
132+
return OpenAIClient(client=request.state.embed_client)
115133

116134

117-
CommonDeps = Annotated[FastAPIAppContext, Depends(common_parameters)]
118-
DBSession = Annotated[AsyncSession, Depends(get_async_session)]
135+
CommonDeps = Annotated[FastAPIAppContext, Depends(get_context)]
136+
DBSession = Annotated[AsyncSession, Depends(get_async_db_session)]
119137
ChatClient = Annotated[OpenAIClient, Depends(get_openai_chat_client)]
120138
EmbeddingsClient = Annotated[OpenAIClient, Depends(get_openai_embed_client)]

src/fastapi_app/update_embeddings.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,17 @@
44
from sqlalchemy import select
55
from sqlalchemy.ext.asyncio import async_sessionmaker
66

7-
from fastapi_app.dependencies import common_parameters, get_engine, get_openai_embed_client
7+
from fastapi_app.dependencies import common_parameters, get_azure_credentials
88
from fastapi_app.embeddings import compute_text_embedding
9+
from fastapi_app.openai_clients import create_openai_embed_client
10+
from fastapi_app.postgres_engine import create_postgres_engine_from_env
911
from fastapi_app.postgres_models import Item
1012

1113

1214
async def update_embeddings():
13-
engine = await get_engine()
14-
openai_embed = await get_openai_embed_client()
15+
azure_credential = await get_azure_credentials()
16+
engine = await create_postgres_engine_from_env(azure_credential)
17+
openai_embed_client = await create_openai_embed_client(azure_credential)
1518
common_params = await common_parameters()
1619

1720
async with async_sessionmaker(engine, expire_on_commit=False)() as session:
@@ -21,7 +24,7 @@ async def update_embeddings():
2124
for item in items:
2225
item.embedding = await compute_text_embedding(
2326
item.to_str_for_embedding(),
24-
openai_client=openai_embed.client,
27+
openai_client=openai_embed_client,
2528
embed_model=common_params.openai_embed_model,
2629
embedding_dimensions=common_params.openai_embed_dimensions,
2730
)

0 commit comments

Comments
 (0)