Skip to content

Commit ea722d7

Browse files
committed
Typing improvements
1 parent b36a260 commit ea722d7

File tree

4 files changed

+41
-23
lines changed

4 files changed

+41
-23
lines changed

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ pip-compile-cross-platform
66
pytest
77
pytest-cov
88
pytest-asyncio
9+
mypy

src/fastapi_app/openai_clients.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,31 @@
88

99

1010
async def create_openai_chat_client(azure_credential):
11+
openai_chat_client: openai.AsyncAzureOpenAI | openai.AsyncOpenAI
1112
OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
1213
if OPENAI_CHAT_HOST == "azure":
13-
client_args = {}
14+
api_version = os.environ["AZURE_OPENAI_VERSION"]
15+
azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
16+
azure_deployment = os.environ["AZURE_OPENAI_EMBED_DEPLOYMENT"]
1417
if api_key := os.getenv("AZURE_OPENAI_KEY"):
1518
logger.info("Authenticating to Azure OpenAI using API key...")
16-
client_args["api_key"] = api_key
19+
openai_chat_client = openai.AsyncAzureOpenAI(
20+
api_version=api_version,
21+
azure_endpoint=azure_endpoint,
22+
azure_deployment=azure_deployment,
23+
api_key=api_key,
24+
)
1725
else:
1826
logger.info("Authenticating to Azure OpenAI using Azure Identity...")
1927
token_provider = azure.identity.get_bearer_token_provider(
2028
azure_credential, "https://cognitiveservices.azure.com/.default"
2129
)
22-
client_args["azure_ad_token_provider"] = token_provider
23-
openai_chat_client = openai.AsyncAzureOpenAI(
24-
api_version=os.getenv("AZURE_OPENAI_VERSION"),
25-
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
26-
azure_deployment=os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT"),
27-
**client_args,
28-
)
30+
openai_chat_client = openai.AsyncAzureOpenAI(
31+
api_version=api_version,
32+
azure_endpoint=azure_endpoint,
33+
azure_deployment=azure_deployment,
34+
azure_ad_token_provider=token_provider,
35+
)
2936
openai_chat_model = os.getenv("AZURE_OPENAI_CHAT_MODEL")
3037
elif OPENAI_CHAT_HOST == "ollama":
3138
logger.info("Authenticating to OpenAI using Ollama...")
@@ -43,24 +50,32 @@ async def create_openai_chat_client(azure_credential):
4350

4451

4552
async def create_openai_embed_client(azure_credential):
53+
openai_embed_client: openai.AsyncAzureOpenAI | openai.AsyncOpenAI
4654
OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST")
4755
if OPENAI_EMBED_HOST == "azure":
48-
client_args = {}
56+
api_version = os.environ["AZURE_OPENAI_VERSION"]
57+
azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
58+
azure_deployment = os.environ["AZURE_OPENAI_EMBED_DEPLOYMENT"]
4959
if api_key := os.getenv("AZURE_OPENAI_KEY"):
5060
logger.info("Authenticating to Azure OpenAI using API key...")
51-
client_args["api_key"] = api_key
61+
openai_embed_client = openai.AsyncAzureOpenAI(
62+
api_version=api_version,
63+
azure_endpoint=azure_endpoint,
64+
azure_deployment=azure_deployment,
65+
api_key=api_key,
66+
)
5267
else:
5368
logger.info("Authenticating to Azure OpenAI using Azure Identity...")
5469
token_provider = azure.identity.get_bearer_token_provider(
5570
azure_credential, "https://cognitiveservices.azure.com/.default"
5671
)
57-
client_args["azure_ad_token_provider"] = token_provider
58-
openai_embed_client = openai.AsyncAzureOpenAI(
59-
api_version=os.getenv("AZURE_OPENAI_VERSION"),
60-
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
61-
azure_deployment=os.getenv("AZURE_OPENAI_EMBED_DEPLOYMENT"),
62-
**client_args,
63-
)
72+
openai_embed_client = openai.AsyncAzureOpenAI(
73+
api_version=api_version,
74+
azure_endpoint=azure_endpoint,
75+
azure_deployment=azure_deployment,
76+
azure_ad_token_provider=token_provider,
77+
)
78+
6479
openai_embed_model = os.getenv("AZURE_OPENAI_EMBED_MODEL")
6580
openai_embed_dimensions = os.getenv("AZURE_OPENAI_EMBED_DIMENSIONS")
6681
else:

src/fastapi_app/rag_advanced.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@ def __init__(
3232
self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
3333

3434
async def run(
35-
self, messages: list[dict], overrides: dict[str, Any] = {}
35+
self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {}
3636
) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]:
3737
text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
3838
vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
3939
top = overrides.get("top", 3)
4040

4141
original_user_query = messages[-1]["content"]
42+
if not isinstance(original_user_query, str):
43+
raise ValueError("The most recent message content must be a string.")
4244
past_messages = messages[:-1]
4345

4446
# Generate an optimized keyword search query based on the chat history and the last question

src/fastapi_app/rag_simple.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import pathlib
22
from collections.abc import AsyncGenerator
3-
from typing import (
4-
Any,
5-
)
3+
from typing import Any
64

75
from openai import AsyncOpenAI
86
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
@@ -30,13 +28,15 @@ def __init__(
3028
self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
3129

3230
async def run(
33-
self, messages: list[dict], overrides: dict[str, Any] = {}
31+
self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {}
3432
) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]:
3533
text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
3634
vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
3735
top = overrides.get("top", 3)
3836

3937
original_user_query = messages[-1]["content"]
38+
if not isinstance(original_user_query, str):
39+
raise ValueError("The most recent message content must be a string.")
4040
past_messages = messages[:-1]
4141

4242
# Retrieve relevant items from the database

0 commit comments

Comments
 (0)