Skip to content

Commit 2f1c4ef

Browse files
authored
Merge pull request #70 from Azure-Samples/embedcolumn
Add embedding with ollama
2 parents d1e990c + c124a84 commit 2f1c4ef

15 files changed

+234140
-954
lines changed

.env.sample

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@ AZURE_OPENAI_CHAT_MODEL=gpt-35-turbo
1818
AZURE_OPENAI_EMBED_DEPLOYMENT=text-embedding-ada-002
1919
AZURE_OPENAI_EMBED_MODEL=text-embedding-ada-002
2020
AZURE_OPENAI_EMBED_MODEL_DIMENSIONS=1536
21+
AZURE_OPENAI_EMBEDDING_COLUMN=embedding_ada002
2122
# Only needed when using key-based Azure authentication:
2223
AZURE_OPENAI_KEY=
2324
# Needed for OpenAI.com:
2425
OPENAICOM_KEY=YOUR-OPENAI-API-KEY
2526
OPENAICOM_CHAT_MODEL=gpt-3.5-turbo
2627
OPENAICOM_EMBED_MODEL=text-embedding-ada-002
2728
OPENAICOM_EMBED_MODEL_DIMENSIONS=1536
29+
OPENAICOM_EMBEDDING_COLUMN=embedding_ada002
2830
# Needed for Ollama:
2931
OLLAMA_ENDPOINT=http://host.docker.internal:11434/v1
30-
OLLAMA_CHAT_MODEL=phi3:3.8b
32+
OLLAMA_CHAT_MODEL=llama3.1
33+
OLLAMA_EMBED_MODEL=nomic-embed-text
34+
OLLAMA_EMBEDDING_COLUMN=embedding_nomic

.vscode/settings.json

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,10 @@
2222
"ssl": true
2323
}
2424
}
25-
]
25+
],
26+
"python.testing.pytestArgs": [
27+
"tests"
28+
],
29+
"python.testing.unittestEnabled": false,
30+
"python.testing.pytestEnabled": true
2631
}

src/backend/fastapi_app/dependencies.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ class FastAPIAppContext(BaseModel):
2828

2929
openai_chat_model: str
3030
openai_embed_model: str
31-
openai_embed_dimensions: int
31+
openai_embed_dimensions: int | None
3232
openai_chat_deployment: str | None
3333
openai_embed_deployment: str | None
34+
embedding_column: str
3435

3536

3637
async def common_parameters():
@@ -43,16 +44,24 @@ async def common_parameters():
4344
openai_embed_deployment = os.getenv("AZURE_OPENAI_EMBED_DEPLOYMENT", "text-embedding-ada-002")
4445
openai_embed_model = os.getenv("AZURE_OPENAI_EMBED_MODEL", "text-embedding-ada-002")
4546
openai_embed_dimensions = int(os.getenv("AZURE_OPENAI_EMBED_DIMENSIONS", 1536))
47+
embedding_column = os.getenv("AZURE_OPENAI_EMBEDDING_COLUMN", "embedding_ada002")
48+
elif OPENAI_EMBED_HOST == "ollama":
49+
openai_embed_deployment = None
50+
openai_embed_model = os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text")
51+
openai_embed_dimensions = None
52+
embedding_column = os.getenv("OLLAMA_EMBEDDING_COLUMN", "embedding_nomic")
4653
else:
47-
openai_embed_deployment = "text-embedding-ada-002"
54+
openai_embed_deployment = None
4855
openai_embed_model = os.getenv("OPENAICOM_EMBED_MODEL", "text-embedding-ada-002")
4956
openai_embed_dimensions = int(os.getenv("OPENAICOM_EMBED_DIMENSIONS", 1536))
57+
embedding_column = os.getenv("OPENAICOM_EMBEDDING_COLUMN", "embedding_ada002")
5058
if OPENAI_CHAT_HOST == "azure":
5159
openai_chat_deployment = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT", "gpt-35-turbo")
5260
openai_chat_model = os.getenv("AZURE_OPENAI_CHAT_MODEL", "gpt-35-turbo")
5361
elif OPENAI_CHAT_HOST == "ollama":
5462
openai_chat_deployment = None
5563
openai_chat_model = os.getenv("OLLAMA_CHAT_MODEL", "phi3:3.8b")
64+
openai_embed_model = os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text")
5665
else:
5766
openai_chat_deployment = None
5867
openai_chat_model = os.getenv("OPENAICOM_CHAT_MODEL", "gpt-3.5-turbo")
@@ -62,6 +71,7 @@ async def common_parameters():
6271
openai_embed_dimensions=openai_embed_dimensions,
6372
openai_chat_deployment=openai_chat_deployment,
6473
openai_embed_deployment=openai_embed_deployment,
74+
embedding_column=embedding_column,
6575
)
6676

6777

src/backend/fastapi_app/embeddings.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ async def compute_text_embedding(
1010
openai_client: AsyncOpenAI | AsyncAzureOpenAI,
1111
embed_model: str,
1212
embed_deployment: str | None = None,
13-
embedding_dimensions: int = 1536,
13+
embedding_dimensions: int | None = None,
1414
) -> list[float]:
1515
SUPPORTED_DIMENSIONS_MODEL = {
1616
"text-embedding-ada-002": False,
@@ -21,7 +21,12 @@ async def compute_text_embedding(
2121
class ExtraArgs(TypedDict, total=False):
2222
dimensions: int
2323

24-
dimensions_args: ExtraArgs = {"dimensions": embedding_dimensions} if SUPPORTED_DIMENSIONS_MODEL[embed_model] else {}
24+
dimensions_args: ExtraArgs = {}
25+
if SUPPORTED_DIMENSIONS_MODEL.get(embed_model):
26+
if embedding_dimensions is None:
27+
raise ValueError(f"Model {embed_model} requires embedding dimensions")
28+
else:
29+
dimensions_args = {"dimensions": embedding_dimensions}
2530

2631
embedding = await openai_client.embeddings.create(
2732
# Azure OpenAI takes the deployment name as the model name

src/backend/fastapi_app/openai_clients.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,13 @@ async def create_openai_embed_client(
7676
azure_deployment=azure_deployment,
7777
azure_ad_token_provider=token_provider,
7878
)
79-
79+
elif OPENAI_EMBED_HOST == "ollama":
80+
logger.info("Authenticating to OpenAI using Ollama...")
81+
openai_embed_client = openai.AsyncOpenAI(
82+
base_url=os.getenv("OLLAMA_ENDPOINT"),
83+
api_key="nokeyneeded",
84+
)
8085
else:
86+
logger.info("Authenticating to OpenAI using OpenAI.com API key...")
8187
openai_embed_client = openai.AsyncOpenAI(api_key=os.getenv("OPENAICOM_KEY"))
8288
return openai_embed_client

src/backend/fastapi_app/postgres_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def get_password_from_azure_credential():
3030

3131
engine = create_async_engine(
3232
DATABASE_URI,
33-
echo=False,
33+
echo=True,
3434
)
3535

3636
@event.listens_for(engine.sync_engine, "do_connect")

src/backend/fastapi_app/postgres_models.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,17 @@ class Item(Base):
2020
name: Mapped[str] = mapped_column()
2121
description: Mapped[str] = mapped_column()
2222
price: Mapped[float] = mapped_column()
23-
embedding: Mapped[Vector] = mapped_column(Vector(1536)) # ada-002
23+
embedding_ada002: Mapped[Vector] = mapped_column(Vector(1536)) # ada-002
24+
embedding_nomic: Mapped[Vector] = mapped_column(Vector(768)) # nomic-embed-text
2425

2526
def to_dict(self, include_embedding: bool = False):
2627
model_dict = asdict(self)
2728
if include_embedding:
28-
model_dict["embedding"] = model_dict["embedding"].tolist()
29+
model_dict["embedding_ada002"] = model_dict.get("embedding_ada002", [])
30+
model_dict["embedding_nomic"] = model_dict.get("embedding_nomic", [])
2931
else:
30-
del model_dict["embedding"]
32+
del model_dict["embedding_ada002"]
33+
del model_dict["embedding_nomic"]
3134
return model_dict
3235

3336
def to_str_for_rag(self):
@@ -38,10 +41,18 @@ def to_str_for_embedding(self):
3841

3942

4043
# Define HNSW index to support vector similarity search through the vector_ip_ops access method (inner product).
41-
index = Index(
42-
"hnsw_index_for_innerproduct_item_embedding",
43-
Item.embedding,
44+
index_ada002 = Index(
45+
"hnsw_index_for_innerproduct_item_embedding_ada002",
46+
Item.embedding_ada002,
4447
postgresql_using="hnsw",
4548
postgresql_with={"m": 16, "ef_construction": 64},
46-
postgresql_ops={"embedding": "vector_ip_ops"},
49+
postgresql_ops={"embedding_ada002": "vector_ip_ops"},
50+
)
51+
52+
index_nomic = Index(
53+
"hnsw_index_for_innerproduct_item_embedding_nomic",
54+
Item.embedding_nomic,
55+
postgresql_using="hnsw",
56+
postgresql_with={"m": 16, "ef_construction": 64},
57+
postgresql_ops={"embedding_nomic": "vector_ip_ops"},
4758
)

src/backend/fastapi_app/postgres_searcher.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ def __init__(
1414
openai_embed_client: AsyncOpenAI | AsyncAzureOpenAI,
1515
embed_deployment: str | None, # Not needed for non-Azure OpenAI or for retrieval_mode="text"
1616
embed_model: str,
17-
embed_dimensions: int,
17+
embed_dimensions: int | None,
18+
embedding_column: str,
1819
):
1920
self.db_session = db_session
2021
self.openai_embed_client = openai_embed_client
2122
self.embed_model = embed_model
2223
self.embed_deployment = embed_deployment
2324
self.embed_dimensions = embed_dimensions
25+
self.embedding_column = embedding_column
2426

2527
def build_filter_clause(self, filters) -> tuple[str, str]:
2628
if filters is None:
@@ -36,19 +38,15 @@ def build_filter_clause(self, filters) -> tuple[str, str]:
3638
return "", ""
3739

3840
async def search(
39-
self,
40-
query_text: str | None,
41-
query_vector: list[float] | list,
42-
top: int = 5,
43-
filters: list[dict] | None = None,
41+
self, query_text: str | None, query_vector: list[float] | list, top: int = 5, filters: list[dict] | None = None
4442
):
4543
filter_clause_where, filter_clause_and = self.build_filter_clause(filters)
4644

4745
vector_query = f"""
48-
SELECT id, RANK () OVER (ORDER BY embedding <=> :embedding) AS rank
46+
SELECT id, RANK () OVER (ORDER BY {self.embedding_column} <=> :embedding) AS rank
4947
FROM items
5048
{filter_clause_where}
51-
ORDER BY embedding <=> :embedding
49+
ORDER BY {self.embedding_column} <=> :embedding
5250
LIMIT 20
5351
"""
5452

src/backend/fastapi_app/routes/api_routes.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,18 @@ async def item_handler(database_session: DBSession, id: int) -> ItemPublic:
4545

4646

4747
@router.get("/similar", response_model=list[ItemWithDistance])
48-
async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> list[ItemWithDistance]:
48+
async def similar_handler(
49+
context: CommonDeps, database_session: DBSession, id: int, n: int = 5
50+
) -> list[ItemWithDistance]:
4951
"""A similarity API to find items similar to items with given ID."""
5052
item = (await database_session.scalars(select(Item).where(Item.id == id))).first()
5153
if not item:
5254
raise HTTPException(detail=f"Item with ID {id} not found.", status_code=404)
55+
5356
closest = await database_session.execute(
54-
select(Item, Item.embedding.l2_distance(item.embedding))
57+
select(Item, Item.embedding_ada002.l2_distance(item.embedding_ada002))
5558
.filter(Item.id != id)
56-
.order_by(Item.embedding.l2_distance(item.embedding))
59+
.order_by(Item.embedding_ada002.l2_distance(item.embedding_ada002))
5760
.limit(n)
5861
)
5962
return [
@@ -78,6 +81,7 @@ async def search_handler(
7881
embed_deployment=context.openai_embed_deployment,
7982
embed_model=context.openai_embed_model,
8083
embed_dimensions=context.openai_embed_dimensions,
84+
embedding_column=context.embedding_column,
8185
)
8286
results = await searcher.search_and_embed(
8387
query, top=top, enable_vector_search=enable_vector_search, enable_text_search=enable_text_search
@@ -99,6 +103,7 @@ async def chat_handler(
99103
embed_deployment=context.openai_embed_deployment,
100104
embed_model=context.openai_embed_model,
101105
embed_dimensions=context.openai_embed_dimensions,
106+
embedding_column=context.embedding_column,
102107
)
103108
rag_flow: SimpleRAGChat | AdvancedRAGChat
104109
if chat_request.context.overrides.use_advanced_flow:
@@ -139,6 +144,7 @@ async def chat_stream_handler(
139144
embed_deployment=context.openai_embed_deployment,
140145
embed_model=context.openai_embed_model,
141146
embed_dimensions=context.openai_embed_dimensions,
147+
embedding_column=context.embedding_column,
142148
)
143149

144150
rag_flow: SimpleRAGChat | AdvancedRAGChat

src/backend/fastapi_app/seed_data.json

Lines changed: 233915 additions & 908 deletions
Large diffs are not rendered by default.

src/backend/fastapi_app/setup_postgres_seeddata.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,18 @@ async def seed_data(engine):
3636
with open(os.path.join(current_dir, "seed_data.json")) as f:
3737
catalog_items = json.load(f)
3838
for catalog_item in catalog_items:
39-
db_item = await session.execute(select(Item).filter(Item.id == catalog_item["Id"]))
39+
db_item = await session.execute(select(Item).filter(Item.id == catalog_item["id"]))
4040
if db_item.scalars().first():
4141
continue
4242
item = Item(
43-
id=catalog_item["Id"],
44-
type=catalog_item["Type"],
45-
brand=catalog_item["Brand"],
46-
name=catalog_item["Name"],
47-
description=catalog_item["Description"],
48-
price=catalog_item["Price"],
49-
embedding=catalog_item["Embedding"],
43+
id=catalog_item["id"],
44+
type=catalog_item["type"],
45+
brand=catalog_item["brand"],
46+
name=catalog_item["name"],
47+
description=catalog_item["description"],
48+
price=catalog_item["price"],
49+
embedding_ada002=catalog_item["embedding_ada002"],
50+
embedding_nomic=catalog_item.get("embedding_nomic"),
5051
)
5152
session.add(item)
5253
try:
Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import asyncio
2+
import json
3+
import logging
4+
import os
25

36
from dotenv import load_dotenv
47
from sqlalchemy import select
@@ -10,28 +13,75 @@
1013
from fastapi_app.postgres_engine import create_postgres_engine_from_env
1114
from fastapi_app.postgres_models import Item
1215

16+
logger = logging.getLogger("ragapp")
1317

14-
async def update_embeddings():
18+
19+
async def update_embeddings(in_seed_data=False):
1520
azure_credential = await get_azure_credentials()
1621
engine = await create_postgres_engine_from_env(azure_credential)
1722
openai_embed_client = await create_openai_embed_client(azure_credential)
1823
common_params = await common_parameters()
1924

20-
async with async_sessionmaker(engine, expire_on_commit=False)() as session:
21-
async with session.begin():
22-
items = (await session.scalars(select(Item))).all()
23-
24-
for item in items:
25-
item.embedding = await compute_text_embedding(
25+
embedding_column = ""
26+
OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST")
27+
if OPENAI_EMBED_HOST == "azure":
28+
embedding_column = os.getenv("AZURE_OPENAI_EMBEDDING_COLUMN", "embedding_ada002")
29+
elif OPENAI_EMBED_HOST == "ollama":
30+
embedding_column = os.getenv("OLLAMA_EMBEDDING_COLUMN", "embedding_nomic")
31+
else:
32+
embedding_column = os.getenv("OPENAICOM_EMBEDDING_COLUMN", "embedding_ada002")
33+
logger.info(f"Updating embeddings in column: {embedding_column}")
34+
if in_seed_data:
35+
current_dir = os.path.dirname(os.path.realpath(__file__))
36+
items = []
37+
with open(os.path.join(current_dir, "seed_data.json")) as f:
38+
catalog_items = json.load(f)
39+
for catalog_item in catalog_items:
40+
item = Item(
41+
id=catalog_item["id"],
42+
type=catalog_item["type"],
43+
brand=catalog_item["brand"],
44+
name=catalog_item["name"],
45+
description=catalog_item["description"],
46+
price=catalog_item["price"],
47+
embedding_ada002=catalog_item["embedding_ada002"],
48+
embedding_nomic=catalog_item.get("embedding_nomic"),
49+
)
50+
embedding = await compute_text_embedding(
2651
item.to_str_for_embedding(),
2752
openai_client=openai_embed_client,
2853
embed_model=common_params.openai_embed_model,
54+
embed_deployment=common_params.openai_embed_deployment,
2955
embedding_dimensions=common_params.openai_embed_dimensions,
3056
)
57+
setattr(item, embedding_column, embedding)
58+
items.append(item)
59+
# write to the file
60+
with open(os.path.join(current_dir, "seed_data.json"), "w") as f:
61+
json.dump([item.to_dict(include_embedding=True) for item in items], f, indent=4)
62+
return
3163

64+
async with async_sessionmaker(engine, expire_on_commit=False)() as session:
65+
async with session.begin():
66+
items_to_update = (await session.scalars(select(Item))).all()
67+
68+
for item in items_to_update:
69+
setattr(
70+
item,
71+
embedding_column,
72+
await compute_text_embedding(
73+
item.to_str_for_embedding(),
74+
openai_client=openai_embed_client,
75+
embed_model=common_params.openai_embed_model,
76+
embed_deployment=common_params.openai_embed_deployment,
77+
embedding_dimensions=common_params.openai_embed_dimensions,
78+
),
79+
)
3280
await session.commit()
3381

3482

3583
if __name__ == "__main__":
84+
logging.basicConfig(level=logging.WARNING)
85+
logger.setLevel(logging.INFO)
3686
load_dotenv(override=True)
3787
asyncio.run(update_embeddings())

0 commit comments

Comments
 (0)