Skip to content

Commit f4b79fd

Browse files
committed
Fix credential
1 parent 9cc444c commit f4b79fd

7 files changed

+27
-27
lines changed

evals/generate_ground_truth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Generator
44
from pathlib import Path
55

6-
from azure.identity import DefaultAzureCredential
6+
from azure.identity import AzureDeveloperCliCredential
77
from dotenv import load_dotenv
88
from evaltools.gen.generate import generate_test_qa_data
99
from sqlalchemy import create_engine, select
@@ -56,7 +56,7 @@ def get_openai_config_dict() -> dict:
5656
api_key = os.environ["AZURE_OPENAI_KEY"]
5757
else:
5858
logger.info("Using Azure OpenAI Service with Azure Developer CLI Credential")
59-
azure_credential = DefaultAzureCredential(process_timeout=60)
59+
azure_credential = AzureDeveloperCliCredential(process_timeout=60)
6060
api_key = azure_credential.get_token("https://cognitiveservices.azure.com/.default").token
6161
openai_config = {
6262
"api_type": "azure",

src/backend/fastapi_app/postgres_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import os
33

4-
from azure.identity import DefaultAzureCredential
4+
from azure.identity import AzureDeveloperCliCredential
55
from sqlalchemy import event
66
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
77

@@ -60,7 +60,7 @@ async def create_postgres_engine_from_env(azure_credential=None) -> AsyncEngine:
6060

6161
async def create_postgres_engine_from_args(args, azure_credential=None) -> AsyncEngine:
6262
if azure_credential is None and args.host.endswith(".database.azure.com"):
63-
azure_credential = DefaultAzureCredential(process_timeout=60)
63+
azure_credential = AzureDeveloperCliCredential(process_timeout=60)
6464

6565
return await create_postgres_engine(
6666
host=args.host,

tests/conftest.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -274,22 +274,22 @@ async def mock_acreate(*args, **kwargs):
274274

275275

276276
@pytest.fixture(scope="function")
277-
def mock_default_azure_credential(mock_session_env):
277+
def mock_azure_credential(mock_session_env):
278278
"""Mock the Azure credential for testing."""
279-
with mock.patch("azure.identity.DefaultAzureCredential") as mock_default_azure_credential:
280-
mock_default_azure_credential.return_value = MockAzureCredential()
281-
yield mock_default_azure_credential
279+
with mock.patch("azure.identity.AzureDeveloperCliCredential") as mock_azure_credential:
280+
mock_azure_credential.return_value = MockAzureCredential()
281+
yield mock_azure_credential
282282

283283

284284
@pytest_asyncio.fixture(scope="function")
285-
async def test_client(app, mock_default_azure_credential, mock_openai_embedding, mock_openai_chatcompletion):
285+
async def test_client(app, mock_azure_credential, mock_openai_embedding, mock_openai_chatcompletion):
286286
"""Create a test client."""
287287
with TestClient(app) as test_client:
288288
yield test_client
289289

290290

291291
@pytest_asyncio.fixture(scope="function")
292-
async def db_session(mock_session_env, mock_default_azure_credential):
292+
async def db_session(mock_session_env, mock_azure_credential):
293293
"""Create a new database session with a rollback at the end of the test."""
294294
engine = await create_postgres_engine_from_env()
295295
async_sesion = async_sessionmaker(autocommit=False, autoflush=False, bind=engine)
@@ -302,10 +302,10 @@ async def db_session(mock_session_env, mock_default_azure_credential):
302302

303303

304304
@pytest_asyncio.fixture(scope="function")
305-
async def postgres_searcher(mock_session_env, mock_default_azure_credential, db_session, mock_openai_embedding):
305+
async def postgres_searcher(mock_session_env, mock_azure_credential, db_session, mock_openai_embedding):
306306
from fastapi_app.postgres_searcher import PostgresSearcher
307307

308-
openai_embed_client = await create_openai_embed_client(mock_default_azure_credential)
308+
openai_embed_client = await create_openai_embed_client(mock_azure_credential)
309309

310310
yield PostgresSearcher(
311311
db_session=db_session,

tests/test_dependencies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def test_get_common_parameters_openai(mock_session_env_openai):
3434

3535

3636
@pytest.mark.asyncio
37-
async def test_get_azure_credential(mock_session_env, mock_default_azure_credential):
37+
async def test_get_azure_credential(mock_session_env, mock_azure_credential):
3838
result = await get_azure_credential()
3939
token = result.get_token("https://vault.azure.net")
4040
assert token.expires_on == 9999999999

tests/test_embeddings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77

88
@pytest.mark.asyncio
9-
async def test_compute_text_embedding(mock_default_azure_credential, mock_openai_embedding):
10-
openai_embed_client = await create_openai_embed_client(mock_default_azure_credential)
9+
async def test_compute_text_embedding(mock_azure_credential, mock_openai_embedding):
10+
openai_embed_client = await create_openai_embed_client(mock_azure_credential)
1111
result = await compute_text_embedding(
1212
q="test",
1313
openai_client=openai_embed_client,
@@ -18,8 +18,8 @@ async def test_compute_text_embedding(mock_default_azure_credential, mock_openai
1818

1919

2020
@pytest.mark.asyncio
21-
async def test_compute_text_embedding_dimensions(mock_default_azure_credential, mock_openai_embedding):
22-
openai_embed_client = await create_openai_embed_client(mock_default_azure_credential)
21+
async def test_compute_text_embedding_dimensions(mock_azure_credential, mock_openai_embedding):
22+
openai_embed_client = await create_openai_embed_client(mock_azure_credential)
2323
result = await compute_text_embedding(
2424
q="test",
2525
openai_client=openai_embed_client,

tests/test_openai_clients.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66

77
@pytest.mark.asyncio
8-
async def test_create_openai_embed_client(mock_default_azure_credential, mock_openai_embedding):
9-
openai_embed_client = await create_openai_embed_client(mock_default_azure_credential)
8+
async def test_create_openai_embed_client(mock_azure_credential, mock_openai_embedding):
9+
openai_embed_client = await create_openai_embed_client(mock_azure_credential)
1010
assert openai_embed_client.embeddings.create is not None
1111
embeddings = await openai_embed_client.embeddings.create(
1212
model="text-embedding-ada-002", input="test", dimensions=1536
@@ -15,8 +15,8 @@ async def test_create_openai_embed_client(mock_default_azure_credential, mock_op
1515

1616

1717
@pytest.mark.asyncio
18-
async def test_create_openai_chat_client(mock_default_azure_credential, mock_openai_chatcompletion):
19-
openai_chat_client = await create_openai_chat_client(mock_default_azure_credential)
18+
async def test_create_openai_chat_client(mock_azure_credential, mock_openai_chatcompletion):
19+
openai_chat_client = await create_openai_chat_client(mock_azure_credential)
2020
assert openai_chat_client.chat.completions.create is not None
2121
response = await openai_chat_client.chat.completions.create(
2222
model="gpt-4o-mini", messages=[{"content": "test", "role": "user"}]

tests/test_postgres_engine.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111

1212

1313
@pytest.mark.asyncio
14-
async def test_create_postgres_engine(mock_session_env, mock_default_azure_credential):
14+
async def test_create_postgres_engine(mock_session_env, mock_azure_credential):
1515
engine = await create_postgres_engine(
1616
host=os.environ["POSTGRES_HOST"],
1717
username=os.environ["POSTGRES_USERNAME"],
1818
database=os.environ["POSTGRES_DATABASE"],
1919
password=os.environ.get("POSTGRES_PASSWORD"),
2020
sslmode=os.environ.get("POSTGRES_SSL"),
21-
azure_credential=mock_default_azure_credential,
21+
azure_credential=mock_azure_credential,
2222
)
2323
assert engine.url.host == "localhost"
2424
assert engine.url.username == "admin"
@@ -28,9 +28,9 @@ async def test_create_postgres_engine(mock_session_env, mock_default_azure_crede
2828

2929

3030
@pytest.mark.asyncio
31-
async def test_create_postgres_engine_from_env(mock_session_env, mock_default_azure_credential):
31+
async def test_create_postgres_engine_from_env(mock_session_env, mock_azure_credential):
3232
engine = await create_postgres_engine_from_env(
33-
azure_credential=mock_default_azure_credential,
33+
azure_credential=mock_azure_credential,
3434
)
3535
assert engine.url.host == "localhost"
3636
assert engine.url.username == "admin"
@@ -40,7 +40,7 @@ async def test_create_postgres_engine_from_env(mock_session_env, mock_default_az
4040

4141

4242
@pytest.mark.asyncio
43-
async def test_create_postgres_engine_from_args(mock_default_azure_credential):
43+
async def test_create_postgres_engine_from_args(mock_azure_credential):
4444
args = type(
4545
"Args",
4646
(),
@@ -54,7 +54,7 @@ async def test_create_postgres_engine_from_args(mock_default_azure_credential):
5454
)
5555
engine = await create_postgres_engine_from_args(
5656
args=args,
57-
azure_credential=mock_default_azure_credential,
57+
azure_credential=mock_azure_credential,
5858
)
5959
assert engine.url.host == "localhost"
6060
assert engine.url.username == "admin"

0 commit comments

Comments
 (0)