Skip to content

Commit 02d0b1b

Browse files
committed
add more tests and fix azure credentials mocking
AsyncTokenCredentials is not the correct class to inherit from as we are not using the async credentials
1 parent 3f0286b commit 02d0b1b

File tree

7 files changed

+136
-22
lines changed

7 files changed

+136
-22
lines changed

src/fastapi_app/embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ async def compute_text_embedding(
1111
embed_model: str,
1212
embed_deployment: str | None = None,
1313
embedding_dimensions: int = 1536,
14-
):
14+
) -> list[float]:
1515
SUPPORTED_DIMENSIONS_MODEL = {
1616
"text-embedding-ada-002": False,
1717
"text-embedding-3-small": True,

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def mock_default_azure_credential(mock_session_env):
224224
"""Mock the Azure credential for testing."""
225225
with mock.patch("azure.identity.DefaultAzureCredential") as mock_default_azure_credential:
226226
mock_default_azure_credential.return_value = MockAzureCredential()
227-
yield
227+
yield mock_default_azure_credential
228228

229229

230230
@pytest_asyncio.fixture(scope="function")

tests/mocks.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,22 @@
11
from collections import namedtuple
2-
from types import TracebackType
32

4-
from azure.core.credentials_async import AsyncTokenCredential
3+
from azure.core.credentials import TokenCredential
54

6-
MockToken = namedtuple("MockToken", ["token", "expires_on", "value"])
5+
MockToken = namedtuple("MockToken", ["token", "expires_on"])
76

87

9-
class MockAzureCredential(AsyncTokenCredential):
10-
async def get_token(self, uri):
11-
return MockToken("", 9999999999, "")
12-
13-
async def close(self) -> None:
14-
pass
15-
16-
async def __aexit__(
17-
self,
18-
exc_type: type[BaseException] | None = None,
19-
exc_value: BaseException | None = None,
20-
traceback: TracebackType | None = None,
21-
) -> None:
22-
pass
8+
class MockAzureCredential(TokenCredential):
9+
def get_token(self, uri):
10+
return MockToken("", 9999999999)
2311

2412

25-
class MockAzureCredentialExpired(AsyncTokenCredential):
13+
class MockAzureCredentialExpired(TokenCredential):
2614
def __init__(self):
2715
self.access_number = 0
2816

2917
async def get_token(self, uri):
3018
self.access_number += 1
3119
if self.access_number == 1:
32-
return MockToken("", 0, "")
20+
return MockToken("", 0)
3321
else:
34-
return MockToken("", 9999999999, "")
22+
return MockToken("", 9999999999)

tests/test_dependencies.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import pytest
2+
3+
from fastapi_app.dependencies import common_parameters, get_azure_credentials
4+
5+
6+
@pytest.mark.asyncio
7+
async def test_get_common_parameters(mock_session_env):
8+
result = await common_parameters()
9+
assert result.openai_chat_model == "gpt-35-turbo"
10+
assert result.openai_embed_model == "text-embedding-ada-002"
11+
assert result.openai_embed_dimensions == 1536
12+
assert result.openai_chat_deployment == "gpt-35-turbo"
13+
assert result.openai_embed_deployment == "text-embedding-ada-002"
14+
15+
16+
@pytest.mark.asyncio
17+
async def test_get_azure_credentials(mock_session_env, mock_default_azure_credential):
18+
result = await get_azure_credentials()
19+
token = result.get_token("https://vault.azure.net")
20+
assert token.expires_on == 9999999999
21+
assert token.token == ""

tests/test_embeddings.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
3+
from fastapi_app.embeddings import compute_text_embedding
4+
from fastapi_app.openai_clients import create_openai_embed_client
5+
from tests.data import test_data
6+
7+
8+
@pytest.mark.asyncio
9+
async def test_compute_text_embedding(mock_default_azure_credential, mock_openai_embedding):
10+
openai_embed_client = await create_openai_embed_client(mock_default_azure_credential)
11+
result = await compute_text_embedding(
12+
q="test",
13+
openai_client=openai_embed_client,
14+
embed_model="text-embedding-ada-002",
15+
embed_deployment="text-embedding-ada-002",
16+
embedding_dimensions=1536,
17+
)
18+
assert result == test_data.embeddings

tests/test_openai_clients.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import pytest
2+
3+
from fastapi_app.openai_clients import create_openai_chat_client, create_openai_embed_client
4+
from tests.data import test_data
5+
6+
7+
@pytest.mark.asyncio
8+
async def test_create_openai_embed_client(mock_default_azure_credential, mock_openai_embedding):
9+
openai_embed_client = await create_openai_embed_client(mock_default_azure_credential)
10+
assert openai_embed_client.embeddings.create is not None
11+
embeddings = await openai_embed_client.embeddings.create(
12+
model="text-embedding-ada-002", input="test", dimensions=1536
13+
)
14+
assert embeddings.data[0].embedding == test_data.embeddings
15+
16+
17+
@pytest.mark.asyncio
18+
async def test_create_openai_chat_client(mock_default_azure_credential, mock_openai_chatcompletion):
19+
openai_chat_client = await create_openai_chat_client(mock_default_azure_credential)
20+
assert openai_chat_client.chat.completions.create is not None
21+
response = await openai_chat_client.chat.completions.create(
22+
model="gpt-35-turbo", messages=[{"content": "test", "role": "user"}]
23+
)
24+
assert response.choices[0].message.content == "The capital of France is Paris. [Benefit_Options-2.pdf]."

tests/test_postgres_engine.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import os
2+
3+
import pytest
4+
5+
from fastapi_app.postgres_engine import (
6+
create_postgres_engine,
7+
create_postgres_engine_from_args,
8+
create_postgres_engine_from_env,
9+
)
10+
from tests.conftest import POSTGRES_DATABASE, POSTGRES_HOST, POSTGRES_PASSWORD, POSTGRES_SSL, POSTGRES_USERNAME
11+
12+
13+
@pytest.mark.asyncio
14+
async def test_create_postgres_engine(mock_session_env, mock_default_azure_credential):
15+
engine = await create_postgres_engine(
16+
host=os.environ["POSTGRES_HOST"],
17+
username=os.environ["POSTGRES_USERNAME"],
18+
database=os.environ["POSTGRES_DATABASE"],
19+
password=os.environ.get("POSTGRES_PASSWORD"),
20+
sslmode=os.environ.get("POSTGRES_SSL"),
21+
azure_credential=mock_default_azure_credential,
22+
)
23+
assert engine.url.host == "localhost"
24+
assert engine.url.username == "admin"
25+
assert engine.url.database == "postgres"
26+
assert engine.url.password == "postgres"
27+
assert engine.url.query["ssl"] == "prefer"
28+
29+
30+
@pytest.mark.asyncio
31+
async def test_create_postgres_engine_from_env(mock_session_env, mock_default_azure_credential):
32+
engine = await create_postgres_engine_from_env(
33+
azure_credential=mock_default_azure_credential,
34+
)
35+
assert engine.url.host == "localhost"
36+
assert engine.url.username == "admin"
37+
assert engine.url.database == "postgres"
38+
assert engine.url.password == "postgres"
39+
assert engine.url.query["ssl"] == "prefer"
40+
41+
42+
@pytest.mark.asyncio
43+
async def test_create_postgres_engine_from_args(mock_default_azure_credential):
44+
args = type(
45+
"Args",
46+
(),
47+
{
48+
"host": POSTGRES_HOST,
49+
"username": POSTGRES_USERNAME,
50+
"database": POSTGRES_DATABASE,
51+
"password": POSTGRES_PASSWORD,
52+
"sslmode": POSTGRES_SSL,
53+
},
54+
)
55+
engine = await create_postgres_engine_from_args(
56+
args=args,
57+
azure_credential=mock_default_azure_credential,
58+
)
59+
assert engine.url.host == "localhost"
60+
assert engine.url.username == "admin"
61+
assert engine.url.database == "postgres"
62+
assert engine.url.password == "postgres"
63+
assert engine.url.query["ssl"] == "prefer"

0 commit comments

Comments
 (0)