Skip to content

Commit 9b745c4

Browse files
mp911dechristophstrobl
authored andcommitted
Polishing
1 parent 4fdc4d9 commit 9b745c4

File tree

2 files changed

+91
-46
lines changed

2 files changed

+91
-46
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java

Lines changed: 67 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@
2323
import java.util.function.Consumer;
2424
import java.util.stream.Collectors;
2525

26+
import org.bson.BinaryVector;
2627
import org.bson.Document;
2728

2829
import org.springframework.data.domain.Limit;
30+
import org.springframework.data.domain.Vector;
31+
import org.springframework.data.mongodb.core.mapping.MongoVector;
2932
import org.springframework.data.mongodb.core.query.Criteria;
3033
import org.springframework.data.mongodb.core.query.CriteriaDefinition;
3134
import org.springframework.lang.Contract;
@@ -54,13 +57,13 @@ public class VectorSearchOperation implements AggregationOperation {
5457
private final Limit limit;
5558
private final @Nullable Integer numCandidates;
5659
private final QueryPaths path;
57-
private final List<? extends Number> vector;
60+
private final Vector vector;
5861
private final String score;
5962
private final Consumer<Criteria> scoreCriteria;
6063

6164
private VectorSearchOperation(SearchType searchType, @Nullable CriteriaDefinition filter, String indexName,
62-
Limit limit, @Nullable Integer numCandidates, QueryPaths path, List<? extends Number> vector,
63-
@Nullable String searchScore, Consumer<Criteria> scoreCriteria) {
65+
Limit limit, @Nullable Integer numCandidates, QueryPaths path, Vector vector, @Nullable String searchScore,
66+
Consumer<Criteria> scoreCriteria) {
6467

6568
this.searchType = searchType;
6669
this.filter = filter;
@@ -73,7 +76,7 @@ private VectorSearchOperation(SearchType searchType, @Nullable CriteriaDefinitio
7376
this.scoreCriteria = scoreCriteria;
7477
}
7578

76-
VectorSearchOperation(String indexName, QueryPaths path, Limit limit, List<? extends Number> vector) {
79+
VectorSearchOperation(String indexName, QueryPaths path, Limit limit, Vector vector) {
7780
this(SearchType.DEFAULT, null, indexName, limit, null, path, vector, null, null);
7881
}
7982

@@ -249,8 +252,18 @@ public Document toDocument(AggregationOperationContext context) {
249252
path = mappedObject.keySet().iterator().next();
250253
}
251254

255+
Object source = vector.getSource();
256+
257+
if (source instanceof float[]) {
258+
source = vector.toDoubleArray();
259+
}
260+
261+
if (source instanceof double[] ds) {
262+
source = Arrays.stream(ds).boxed().collect(Collectors.toList());
263+
}
264+
252265
$vectorSearch.append("path", path);
253-
$vectorSearch.append("queryVector", vector);
266+
$vectorSearch.append("queryVector", source);
254267

255268
return new Document(getOperator(), $vectorSearch);
256269
}
@@ -288,7 +301,7 @@ private static class VectorSearchBuilder implements PathContributor, VectorContr
288301

289302
String index;
290303
QueryPath<String> paths;
291-
private List<? extends Number> vector;
304+
Vector vector;
292305

293306
PathContributor index(String index) {
294307
this.index = index;
@@ -308,8 +321,8 @@ public VectorSearchOperation limit(Limit limit) {
308321
}
309322

310323
@Override
311-
public LimitContributor vector(List<? extends Number> vectors) {
312-
this.vector = vectors;
324+
public LimitContributor vector(Vector vector) {
325+
this.vector = vector;
313326
return this;
314327
}
315328
}
@@ -428,28 +441,63 @@ public interface PathContributor {
428441
public interface VectorContributor {
429442

430443
/**
431-
* Array of numbers of the BSON double, BSON BinData vector subtype float32, or BSON BinData vector subtype int1 or
432-
* int8 type that represent the query vector. The number type must match the indexed field value type. Otherwise,
433-
* Atlas Vector Search doesn't return any results or errors.
444+
* Array of float numbers that represent the query vector. The number type must match the indexed field value type.
445+
* Otherwise, Atlas Vector Search doesn't return any results or errors.
446+
*
447+
* @param vector the query vector.
448+
* @return
449+
*/
450+
@Contract("_ -> this")
451+
default LimitContributor vector(float... vector) {
452+
return vector(Vector.of(vector));
453+
}
454+
455+
/**
456+
* Array of double numbers that represent the query vector. The number type must match the indexed field value type.
457+
* Otherwise, Atlas Vector Search doesn't return any results or errors.
458+
*
459+
* @param vector the query vector.
460+
* @return
461+
*/
462+
@Contract("_ -> this")
463+
default LimitContributor vector(double... vector) {
464+
return vector(Vector.of(vector));
465+
}
466+
467+
/**
468+
* Array of numbers that represent the query vector. The number type must match the indexed field value type.
469+
* Otherwise, Atlas Vector Search doesn't return any results or errors.
470+
*
471+
* @param vector the query vector.
472+
* @return
473+
*/
474+
@Contract("_ -> this")
475+
default LimitContributor vector(List<? extends Number> vector) {
476+
return vector(Vector.of(vector));
477+
}
478+
479+
/**
480+
* Binary vector (BSON BinData vector subtype float32, or BSON BinData vector subtype int1 or int8 type) that
481+
* represent the query vector. The number type must match the indexed field value type. Otherwise, Atlas Vector
482+
* Search doesn't return any results or errors.
434483
*
435-
* @param vectors
484+
* @param vector the query vector.
436485
* @return
437486
*/
438487
@Contract("_ -> this")
439-
default LimitContributor vector(Double... vectors) {
440-
return vector(Arrays.asList(vectors));
488+
default LimitContributor vector(BinaryVector vector) {
489+
return vector(MongoVector.of(vector));
441490
}
442491

443492
/**
444-
* Array of numbers of the BSON double, BSON BinData vector subtype float32, or BSON BinData vector subtype int1 or
445-
* int8 type that represent the query vector. The number type must match the indexed field value type. Otherwise,
446-
* Atlas Vector Search doesn't return any results or errors.
493+
* The query vector. The number type must match the indexed field value type. Otherwise, Atlas Vector Search doesn't
494+
* return any results or errors.
447495
*
448-
* @param vectors
496+
* @param vector the query vector.
449497
* @return
450498
*/
451499
@Contract("_ -> this")
452-
LimitContributor vector(List<? extends Number> vectors);
500+
LimitContributor vector(Vector vector);
453501
}
454502

455503
public interface LimitContributor {

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -31,47 +31,44 @@
3131
@ExtendWith(MongoTemplateExtension.class)
3232
public class VectorSearchTests {
3333

34-
static final String COLLECTION_NAME = "embedded_movies";
34+
static final String COLLECTION_NAME = "movies";
3535

3636
@Template(database = "mflix") //
3737
static MongoTestTemplate template;
3838

3939
@Test
4040
void xxx() {
4141

42-
// boolean hasIndex = template.indexOps(COLLECTION_NAME).getIndexInfo().stream()
43-
// .anyMatch(it -> it.getName().endsWith("vector_index"));
42+
// boolean hasIndex = template.indexOps(COLLECTION_NAME).getIndexInfo().stream()
43+
// .anyMatch(it -> it.getName().endsWith("movie_vector_index"));
4444

4545
// TODO: index conversion etc. is missing - should we combine the index info listing?
46-
// boolean hasIndex = template.execute(db -> {
47-
//
48-
// Document doc = db.runCommand(new Document("listSearchIndexes", COLLECTION_NAME));
49-
// Object searchIndexes = BsonUtils.resolveValue(BsonUtils.asMap(doc), "cursor.firstBatch");
50-
// if(searchIndexes instanceof Collection<?> indexes) {
51-
// return indexes.stream().anyMatch(it -> it instanceof Document idx && idx.get("name", String.class).equalsIgnoreCase("vector_index"));
52-
// }
53-
// return false;
54-
// });
46+
// boolean hasIndex = template.execute(db -> {
47+
//
48+
// Document doc = db.runCommand(new Document("listSearchIndexes", COLLECTION_NAME));
49+
// Object searchIndexes = BsonUtils.resolveValue(BsonUtils.asMap(doc), "cursor.firstBatch");
50+
// if(searchIndexes instanceof Collection<?> indexes) {
51+
// return indexes.stream().anyMatch(it -> it instanceof Document idx && idx.get("name",
52+
// String.class).equalsIgnoreCase("vector_index"));
53+
// }
54+
// return false;
55+
// });
5556

56-
boolean hasIndex = template.searchIndexOps(COLLECTION_NAME).exists("vector_index");
57-
58-
if(hasIndex) {
59-
System.out.println("found the index: vector_index");
60-
System.out.println(template.searchIndexOps(COLLECTION_NAME).getIndexInfo());
61-
template.searchIndexOps(COLLECTION_NAME).updateIndex(new VectorIndex("vector_index").addVector("plot_embedding",
62-
field -> field.dimensions(1536).similarity("euclidean")));
63-
// template.indexOps(COLLECTION_NAME).vectorIndexOperations().dropIndex("vector_name");
57+
if (!template.collectionExists(COLLECTION_NAME)) {
58+
template.createCollection(COLLECTION_NAME);
6459
}
65-
else {
60+
61+
boolean hasIndex = template.searchIndexOps(COLLECTION_NAME).exists("movie_vector_index");
62+
63+
if (!hasIndex) {
6664

6765
System.out.print("Creating index: ");
68-
String s = template.searchIndexOps(COLLECTION_NAME).ensureIndex(
69-
new VectorIndex("vector_index").addVector("plot_embedding",
70-
field -> field.dimensions(1536).similarity("cosine")));
71-
System.out.println(s);
66+
VectorIndex vectorIndex = new VectorIndex("movie_vector_index").addVector("plot_embedding",
67+
field -> field.dimensions(1536).similarity(VectorIndex.SimilarityFunction.COSINE)).addFilter("language");
68+
String s = template.searchIndexOps(COLLECTION_NAME).ensureIndex(vectorIndex);
7269
}
7370

74-
VectorSearchOperation $vectorSearch = VectorSearchOperation.search("vector_index").path("plot_embedding")
71+
VectorSearchOperation $vectorSearch = VectorSearchOperation.search("movie_vector_index").path("plot_embedding")
7572
.vector(vectors).limit(10).numCandidates(150).withSearchScore();
7673

7774
Aggregation agg = Aggregation.newAggregation($vectorSearch, Aggregation.project("plot", "title"));
@@ -81,7 +78,7 @@ void xxx() {
8178
aggregate.forEach(System.out::println);
8279
}
8380

84-
static Double[] vectors = { -0.0016261312, -0.028070757, -0.011342932, -0.012775794, -0.0027440966, 0.008683807,
81+
static double[] vectors = { -0.0016261312, -0.028070757, -0.011342932, -0.012775794, -0.0027440966, 0.008683807,
8582
-0.02575152, -0.02020668, -0.010283281, -0.0041719596, 0.021392956, 0.028657231, -0.006634482, 0.007490867,
8683
0.018593878, 0.0038187427, 0.029590257, -0.01451522, 0.016061379, 0.00008528442, -0.008943722, 0.01627464,
8784
0.024311995, -0.025911469, 0.00022596726, -0.008863748, 0.008823762, -0.034921836, 0.007910728, -0.01515501,

0 commit comments

Comments
 (0)