Skip to content

Commit 4fdc4d9

Browse files
mp911de authored and christophstrobl committed
Add support for Vector abstraction.
1 parent 10fbf4b commit 4fdc4d9

File tree

5 files changed

+346
-7
lines changed

5 files changed

+346
-7
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
import java.util.concurrent.atomic.AtomicInteger;
3232
import java.util.concurrent.atomic.AtomicLong;
3333

34+
import org.bson.BinaryVector;
35+
import org.bson.BsonArray;
36+
import org.bson.BsonDouble;
3437
import org.bson.BsonReader;
3538
import org.bson.BsonTimestamp;
3639
import org.bson.BsonUndefined;
@@ -44,14 +47,17 @@
4447
import org.bson.types.Code;
4548
import org.bson.types.Decimal128;
4649
import org.bson.types.ObjectId;
50+
4751
import org.springframework.core.convert.ConversionFailedException;
4852
import org.springframework.core.convert.TypeDescriptor;
4953
import org.springframework.core.convert.converter.ConditionalConverter;
5054
import org.springframework.core.convert.converter.Converter;
5155
import org.springframework.core.convert.converter.ConverterFactory;
5256
import org.springframework.data.convert.ReadingConverter;
5357
import org.springframework.data.convert.WritingConverter;
58+
import org.springframework.data.domain.Vector;
5459
import org.springframework.data.mongodb.core.mapping.FieldName;
60+
import org.springframework.data.mongodb.core.mapping.MongoVector;
5561
import org.springframework.data.mongodb.core.query.Term;
5662
import org.springframework.data.mongodb.core.script.NamedMongoScript;
5763
import org.springframework.util.Assert;
@@ -106,6 +112,10 @@ static Collection<Object> getConvertersToRegister() {
106112
converters.add(BinaryToByteArrayConverter.INSTANCE);
107113
converters.add(BsonTimestampToInstantConverter.INSTANCE);
108114

115+
converters.add(VectorToBsonArrayConverter.INSTANCE);
116+
converters.add(ListToVectorConverter.INSTANCE);
117+
converters.add(BinaryVectorToMongoVectorConverter.INSTANCE);
118+
109119
converters.add(reading(BsonUndefined.class, Object.class, it -> null));
110120
converters.add(reading(String.class, URI.class, URI::create).andWriting(URI::toString));
111121

@@ -417,6 +427,52 @@ public T convert(Number source) {
417427
}
418428
}
419429

430+
@WritingConverter
431+
enum VectorToBsonArrayConverter implements Converter<Vector, Object> {
432+
433+
INSTANCE;
434+
435+
@Override
436+
public Object convert(Vector source) {
437+
438+
if (source instanceof MongoVector mv) {
439+
return mv.getSource();
440+
}
441+
442+
double[] doubleArray = source.toDoubleArray();
443+
444+
BsonArray array = new BsonArray(doubleArray.length);
445+
446+
for (double v : doubleArray) {
447+
array.add(new BsonDouble(v));
448+
}
449+
450+
return array;
451+
}
452+
}
453+
454+
@ReadingConverter
455+
enum ListToVectorConverter implements Converter<List<Number>, Vector> {
456+
457+
INSTANCE;
458+
459+
@Override
460+
public Vector convert(List<Number> source) {
461+
return Vector.of(source);
462+
}
463+
}
464+
465+
@ReadingConverter
466+
enum BinaryVectorToMongoVectorConverter implements Converter<BinaryVector, Vector> {
467+
468+
INSTANCE;
469+
470+
@Override
471+
public Vector convert(BinaryVector source) {
472+
return MongoVector.of(source);
473+
}
474+
}
475+
420476
/**
421477
* {@link ConverterFactory} implementation converting {@link AtomicLong} into {@link Long}.
422478
*

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,13 @@ public abstract class MongoSimpleTypes {
5353
public static final Set<Class<?>> AUTOGENERATED_ID_TYPES = Set.of(ObjectId.class, String.class, BigInteger.class);
5454
private static final Set<Class<?>> MONGO_SIMPLE_TYPES = Set.of(Binary.class, DBRef.class, Decimal128.class,
5555
org.bson.Document.class, Code.class, CodeWScope.class, CodeWithScope.class, ObjectId.class, Pattern.class,
56-
Symbol.class, UUID.class, Instant.class, BsonValue.class, BsonNumber.class, BsonType.class, BsonArray.class,
57-
BsonSymbol.class, BsonUndefined.class, BsonMinKey.class, BsonMaxKey.class, BsonNull.class, BsonBinary.class,
58-
BsonBoolean.class, BsonDateTime.class, BsonDbPointer.class, BsonDecimal128.class, BsonDocument.class,
59-
BsonDouble.class, BsonInt32.class, BsonInt64.class, BsonJavaScript.class, BsonJavaScriptWithScope.class,
60-
BsonObjectId.class, BsonRegularExpression.class, BsonString.class, BsonTimestamp.class, Geometry.class,
61-
GeometryCollection.class, LineString.class, MultiLineString.class, MultiPoint.class, MultiPolygon.class,
62-
Point.class, Polygon.class);
56+
Symbol.class, UUID.class, Instant.class, BinaryVector.class, BsonValue.class, BsonNumber.class, BsonType.class,
57+
BsonArray.class, BsonSymbol.class, BsonUndefined.class, BsonMinKey.class, BsonMaxKey.class, BsonNull.class,
58+
BsonBinary.class, BsonBoolean.class, BsonDateTime.class, BsonDbPointer.class, BsonDecimal128.class,
59+
BsonDocument.class, BsonDouble.class, BsonInt32.class, BsonInt64.class, BsonJavaScript.class,
60+
BsonJavaScriptWithScope.class, BsonObjectId.class, BsonRegularExpression.class, BsonString.class,
61+
BsonTimestamp.class, Geometry.class, GeometryCollection.class, LineString.class, MultiLineString.class,
62+
MultiPoint.class, MultiPolygon.class, Point.class, Polygon.class);
6363

6464
public static final SimpleTypeHolder HOLDER = new SimpleTypeHolder(MONGO_SIMPLE_TYPES, true) {
6565

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.mongodb.core.mapping;
17+
18+
import org.bson.BinaryVector;
19+
import org.bson.Float32BinaryVector;
20+
import org.bson.Int8BinaryVector;
21+
import org.bson.PackedBitBinaryVector;
22+
23+
import org.springframework.data.domain.Vector;
24+
import org.springframework.util.ObjectUtils;
25+
26+
/**
27+
* MongoDB-specific extension to {@link Vector} based on Mongo's {@link Binary}. Note that only float32 and int8
28+
* variants can be represented as floating-point numbers. int1 returns an all-zero array for {@link #toFloatArray()} and
29+
* {@link #toDoubleArray()}.
30+
*
31+
* @author Mark Paluch
32+
* @since 4.5
33+
*/
34+
public class MongoVector implements Vector {
35+
36+
private final BinaryVector v;
37+
38+
MongoVector(BinaryVector v) {
39+
this.v = v;
40+
}
41+
42+
/**
43+
* Creates a new {@link MongoVector} from the given {@link BinaryVector}.
44+
*
45+
* @param v binary vector representation.
46+
* @return the {@link MongoVector} for the given vector values.
47+
*/
48+
public static MongoVector of(BinaryVector v) {
49+
return new MongoVector(v);
50+
}
51+
52+
@Override
53+
public Class<? extends Number> getType() {
54+
55+
if (v instanceof Float32BinaryVector) {
56+
return Float.class;
57+
}
58+
59+
if (v instanceof Int8BinaryVector) {
60+
return Byte.class;
61+
}
62+
63+
if (v instanceof PackedBitBinaryVector) {
64+
return Byte.class;
65+
}
66+
67+
return Number.class;
68+
}
69+
70+
@Override
71+
public BinaryVector getSource() {
72+
return v;
73+
}
74+
75+
@Override
76+
public int size() {
77+
78+
if (v instanceof Float32BinaryVector f) {
79+
return f.getData().length;
80+
}
81+
82+
if (v instanceof Int8BinaryVector i) {
83+
return i.getData().length;
84+
}
85+
86+
if (v instanceof PackedBitBinaryVector p) {
87+
return p.getData().length;
88+
}
89+
90+
return 0;
91+
}
92+
93+
@Override
94+
public float[] toFloatArray() {
95+
96+
if (v instanceof Float32BinaryVector f) {
97+
98+
float[] result = new float[f.getData().length];
99+
System.arraycopy(f.getData(), 0, result, 0, result.length);
100+
return result;
101+
}
102+
103+
if (v instanceof Int8BinaryVector i) {
104+
105+
float[] result = new float[i.getData().length];
106+
System.arraycopy(i.getData(), 0, result, 0, result.length);
107+
return result;
108+
}
109+
110+
return new float[size()];
111+
}
112+
113+
@Override
114+
public double[] toDoubleArray() {
115+
116+
if (v instanceof Float32BinaryVector f) {
117+
118+
float[] data = f.getData();
119+
double[] result = new double[data.length];
120+
for (int i = 0; i < data.length; i++) {
121+
result[i] = data[i];
122+
}
123+
124+
return result;
125+
}
126+
127+
if (v instanceof Int8BinaryVector i) {
128+
129+
double[] result = new double[i.getData().length];
130+
System.arraycopy(i.getData(), 0, result, 0, result.length);
131+
return result;
132+
}
133+
134+
return new double[size()];
135+
}
136+
137+
@Override
138+
public boolean equals(Object o) {
139+
if (!(o instanceof MongoVector that)) {
140+
return false;
141+
}
142+
return ObjectUtils.nullSafeEquals(v, that.v);
143+
}
144+
145+
@Override
146+
public int hashCode() {
147+
return ObjectUtils.nullSafeHashCode(v);
148+
}
149+
150+
@Override
151+
public String toString() {
152+
return "MV[" + v + "]";
153+
}
154+
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import java.util.stream.Stream;
3434

3535
import org.assertj.core.data.Percentage;
36+
import org.bson.BsonDouble;
3637
import org.bson.BsonUndefined;
3738
import org.bson.types.Binary;
3839
import org.bson.types.Code;
@@ -70,6 +71,7 @@
7071
import org.springframework.data.convert.ReadingConverter;
7172
import org.springframework.data.convert.ValueConverter;
7273
import org.springframework.data.convert.WritingConverter;
74+
import org.springframework.data.domain.Vector;
7375
import org.springframework.data.geo.Box;
7476
import org.springframework.data.geo.Circle;
7577
import org.springframework.data.geo.Distance;
@@ -3328,13 +3330,36 @@ void shouldReadNonIdFieldCalledIdFromSource() {
33283330
assertThat(target.id).isEqualTo(source.id);
33293331
}
33303332

3333+
@Test // GH-4706
3334+
void shouldWriteVectorValues() {
3335+
3336+
WithVector source = new WithVector();
3337+
source.embeddings = Vector.of(1.1d, 2.2d, 3.3d);
3338+
3339+
org.bson.Document document = write(source);
3340+
assertThat(document.getList("embeddings", BsonDouble.class)).hasSize(3);
3341+
}
3342+
3343+
@Test // GH-4706
3344+
void shouldReadVectorValues() {
3345+
3346+
org.bson.Document document = new org.bson.Document("embeddings", List.of(1.1d, 2.2d, 3.3d));
3347+
WithVector withVector = converter.read(WithVector.class, document);
3348+
assertThat(withVector.embeddings.toDoubleArray()).contains(1.1d, 2.2d, 3.3d);
3349+
}
3350+
33313351
org.bson.Document write(Object source) {
33323352

33333353
org.bson.Document target = new org.bson.Document();
33343354
converter.write(source, target);
33353355
return target;
33363356
}
33373357

3358+
static class WithVector {
3359+
3360+
Vector embeddings;
3361+
}
3362+
33383363
static class GenericType<T> {
33393364
T content;
33403365
}

0 commit comments

Comments
 (0)