Skip to content

Commit b2bb121

Browse files
committed
add postgres searcher tests
1 parent 17fc97f commit b2bb121

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

tests/conftest.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sqlalchemy.ext.asyncio import async_sessionmaker
1818

1919
from fastapi_app import create_app
20+
from fastapi_app.openai_clients import create_openai_embed_client
2021
from fastapi_app.postgres_engine import create_postgres_engine_from_env
2122
from fastapi_app.setup_postgres_database import create_db_schema
2223
from fastapi_app.setup_postgres_seeddata import seed_data
@@ -235,7 +236,7 @@ async def test_client(app, mock_default_azure_credential, mock_openai_embedding,
235236

236237

237238
@pytest_asyncio.fixture(scope="function")
238-
async def db_session():
239+
async def db_session(mock_session_env, mock_default_azure_credential):
239240
"""Create a new database session with a rollback at the end of the test."""
240241
engine = await create_postgres_engine_from_env()
241242
async_sesion = async_sessionmaker(autocommit=False, autoflush=False, bind=engine)
@@ -245,3 +246,18 @@ async def db_session():
245246
await session.rollback()
246247
await session.close()
247248
await engine.dispose()
249+
250+
251+
@pytest_asyncio.fixture(scope="function")
252+
async def postgres_searcher(mock_session_env, mock_default_azure_credential, db_session, mock_openai_embedding):
253+
from fastapi_app.postgres_searcher import PostgresSearcher
254+
255+
openai_embed_client = await create_openai_embed_client(mock_default_azure_credential)
256+
257+
yield PostgresSearcher(
258+
db_session=db_session,
259+
openai_embed_client=openai_embed_client,
260+
embed_deployment="text-embedding-ada-002",
261+
embed_model="text-embedding-ada-002",
262+
embed_dimensions=1536,
263+
)

tests/test_postgres_searcher.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import pytest
2+
3+
from fastapi_app.api_models import ItemPublic
4+
from tests.data import test_data
5+
6+
7+
def test_postgres_build_filter_clause_without_filters(postgres_searcher):
8+
assert postgres_searcher.build_filter_clause(None) == ("", "")
9+
assert postgres_searcher.build_filter_clause([]) == ("", "")
10+
11+
12+
def test_postgres_build_filter_clause_with_filters(postgres_searcher):
13+
assert postgres_searcher.build_filter_clause([{"column": "id", "comparison_operator": "=", "value": 1}]) == (
14+
"WHERE id = 1",
15+
"AND id = 1",
16+
)
17+
18+
19+
@pytest.mark.asyncio
20+
async def test_postgres_searcher_search_empty_text_search(postgres_searcher):
21+
assert await postgres_searcher.search("", [], 5, None) == []
22+
23+
24+
@pytest.mark.asyncio
25+
async def test_postgres_searcher_search(postgres_searcher):
26+
assert (await postgres_searcher.search(test_data.name, test_data.embeddings, 5, None))[0].to_dict() == ItemPublic(
27+
**test_data.model_dump()
28+
).model_dump()
29+
30+
31+
@pytest.mark.asyncio
32+
async def test_postgres_searcher_search_and_embed_empty_text_search(postgres_searcher):
33+
assert await postgres_searcher.search_and_embed("", 5, False, True) == []
34+
35+
36+
@pytest.mark.asyncio
37+
async def test_postgres_searcher_search_and_embed(postgres_searcher):
38+
assert await postgres_searcher.search_and_embed("", 5, False, True) == []
39+
assert (await postgres_searcher.search_and_embed(test_data.name, 5, True))[0].to_dict() == ItemPublic(
40+
**test_data.model_dump()
41+
).model_dump()

0 commit comments

Comments
 (0)