23
23
import java .util .function .Consumer ;
24
24
import java .util .stream .Collectors ;
25
25
26
+ import org .bson .BinaryVector ;
26
27
import org .bson .Document ;
27
28
28
29
import org .springframework .data .domain .Limit ;
30
+ import org .springframework .data .domain .Vector ;
31
+ import org .springframework .data .mongodb .core .mapping .MongoVector ;
29
32
import org .springframework .data .mongodb .core .query .Criteria ;
30
33
import org .springframework .data .mongodb .core .query .CriteriaDefinition ;
31
34
import org .springframework .lang .Contract ;
@@ -54,13 +57,13 @@ public class VectorSearchOperation implements AggregationOperation {
54
57
private final Limit limit ;
55
58
private final @ Nullable Integer numCandidates ;
56
59
private final QueryPaths path ;
57
- private final List <? extends Number > vector ;
60
+ private final Vector vector ;
58
61
private final String score ;
59
62
private final Consumer <Criteria > scoreCriteria ;
60
63
61
64
private VectorSearchOperation (SearchType searchType , @ Nullable CriteriaDefinition filter , String indexName ,
62
- Limit limit , @ Nullable Integer numCandidates , QueryPaths path , List <? extends Number > vector ,
63
- @ Nullable String searchScore , Consumer <Criteria > scoreCriteria ) {
65
+ Limit limit , @ Nullable Integer numCandidates , QueryPaths path , Vector vector , @ Nullable String searchScore ,
66
+ Consumer <Criteria > scoreCriteria ) {
64
67
65
68
this .searchType = searchType ;
66
69
this .filter = filter ;
@@ -73,7 +76,7 @@ private VectorSearchOperation(SearchType searchType, @Nullable CriteriaDefinitio
73
76
this .scoreCriteria = scoreCriteria ;
74
77
}
75
78
76
- VectorSearchOperation (String indexName , QueryPaths path , Limit limit , List <? extends Number > vector ) {
79
+ VectorSearchOperation (String indexName , QueryPaths path , Limit limit , Vector vector ) {
77
80
this (SearchType .DEFAULT , null , indexName , limit , null , path , vector , null , null );
78
81
}
79
82
@@ -249,8 +252,18 @@ public Document toDocument(AggregationOperationContext context) {
249
252
path = mappedObject .keySet ().iterator ().next ();
250
253
}
251
254
255
+ Object source = vector .getSource ();
256
+
257
+ if (source instanceof float []) {
258
+ source = vector .toDoubleArray ();
259
+ }
260
+
261
+ if (source instanceof double [] ds ) {
262
+ source = Arrays .stream (ds ).boxed ().collect (Collectors .toList ());
263
+ }
264
+
252
265
$vectorSearch .append ("path" , path );
253
- $vectorSearch .append ("queryVector" , vector );
266
+ $vectorSearch .append ("queryVector" , source );
254
267
255
268
return new Document (getOperator (), $vectorSearch );
256
269
}
@@ -288,7 +301,7 @@ private static class VectorSearchBuilder implements PathContributor, VectorContr
288
301
289
302
String index ;
290
303
QueryPath <String > paths ;
291
- private List <? extends Number > vector ;
304
+ Vector vector ;
292
305
293
306
PathContributor index (String index ) {
294
307
this .index = index ;
@@ -308,8 +321,8 @@ public VectorSearchOperation limit(Limit limit) {
308
321
}
309
322
310
323
@ Override
311
- public LimitContributor vector (List <? extends Number > vectors ) {
312
- this .vector = vectors ;
324
+ public LimitContributor vector (Vector vector ) {
325
+ this .vector = vector ;
313
326
return this ;
314
327
}
315
328
}
@@ -428,28 +441,63 @@ public interface PathContributor {
428
441
public interface VectorContributor {
429
442
430
443
/**
431
- * Array of numbers of the BSON double, BSON BinData vector subtype float32, or BSON BinData vector subtype int1 or
432
- * int8 type that represent the query vector. The number type must match the indexed field value type. Otherwise,
433
- * Atlas Vector Search doesn't return any results or errors.
444
+ * Array of float numbers that represent the query vector. The number type must match the indexed field value type.
445
+ * Otherwise, Atlas Vector Search doesn't return any results or errors.
446
+ *
447
+ * @param vector the query vector.
448
+ * @return
449
+ */
450
+ @ Contract ("_ -> this" )
451
+ default LimitContributor vector (float ... vector ) {
452
+ return vector (Vector .of (vector ));
453
+ }
454
+
455
+ /**
456
+ * Array of double numbers that represent the query vector. The number type must match the indexed field value type.
457
+ * Otherwise, Atlas Vector Search doesn't return any results or errors.
458
+ *
459
+ * @param vector the query vector.
460
+ * @return
461
+ */
462
+ @ Contract ("_ -> this" )
463
+ default LimitContributor vector (double ... vector ) {
464
+ return vector (Vector .of (vector ));
465
+ }
466
+
467
+ /**
468
+ * Array of numbers that represent the query vector. The number type must match the indexed field value type.
469
+ * Otherwise, Atlas Vector Search doesn't return any results or errors.
470
+ *
471
+ * @param vector the query vector.
472
+ * @return
473
+ */
474
+ @ Contract ("_ -> this" )
475
+ default LimitContributor vector (List <? extends Number > vector ) {
476
+ return vector (Vector .of (vector ));
477
+ }
478
+
479
+ /**
480
+ * Binary vector (BSON BinData vector subtype float32, or BSON BinData vector subtype int1 or int8 type) that
481
+ * represent the query vector. The number type must match the indexed field value type. Otherwise, Atlas Vector
482
+ * Search doesn't return any results or errors.
434
483
*
435
- * @param vectors
484
+ * @param vector the query vector.
436
485
* @return
437
486
*/
438
487
@ Contract ("_ -> this" )
439
- default LimitContributor vector (Double ... vectors ) {
440
- return vector (Arrays . asList ( vectors ));
488
+ default LimitContributor vector (BinaryVector vector ) {
489
+ return vector (MongoVector . of ( vector ));
441
490
}
442
491
443
492
/**
444
- * Array of numbers of the BSON double, BSON BinData vector subtype float32, or BSON BinData vector subtype int1 or
445
- * int8 type that represent the query vector. The number type must match the indexed field value type. Otherwise,
446
- * Atlas Vector Search doesn't return any results or errors.
493
+ * The query vector. The number type must match the indexed field value type. Otherwise, Atlas Vector Search doesn't
494
+ * return any results or errors.
447
495
*
448
- * @param vectors
496
+ * @param vector the query vector.
449
497
* @return
450
498
*/
451
499
@ Contract ("_ -> this" )
452
- LimitContributor vector (List <? extends Number > vectors );
500
+ LimitContributor vector (Vector vector );
453
501
}
454
502
455
503
public interface LimitContributor {
0 commit comments