From 1869e41ec3d3eab2568cbab9d350ebf70749fd33 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Mon, 4 Nov 2024 14:16:54 +0100 Subject: [PATCH 1/5] Prepare issue branch. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index cdd68fe2d4..ec9dad98ee 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-commons - 3.5.0-SNAPSHOT + 3.5.0-GH-3193-SNAPSHOT Spring Data Core Core Spring concepts underpinning every Spring Data module. From f47d4b01fa5d3a345e4e80010078ba06328be7d2 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Mon, 4 Nov 2024 14:20:19 +0100 Subject: [PATCH 2/5] Introduce `Vector` abstraction. --- .../data/domain/DoubleVector.java | 103 +++++++++++ .../data/domain/FloatVector.java | 103 +++++++++++ .../data/domain/NumberVector.java | 132 ++++++++++++++ .../springframework/data/domain/Vector.java | 171 ++++++++++++++++++ .../data/domain/DoubleVectorUnitTests.java | 97 ++++++++++ .../data/domain/FloatVectorUnitTests.java | 97 ++++++++++ .../data/domain/NumberVectorUnitTests.java | 99 ++++++++++ 7 files changed, 802 insertions(+) create mode 100644 src/main/java/org/springframework/data/domain/DoubleVector.java create mode 100644 src/main/java/org/springframework/data/domain/FloatVector.java create mode 100644 src/main/java/org/springframework/data/domain/NumberVector.java create mode 100644 src/main/java/org/springframework/data/domain/Vector.java create mode 100644 src/test/java/org/springframework/data/domain/DoubleVectorUnitTests.java create mode 100644 src/test/java/org/springframework/data/domain/FloatVectorUnitTests.java create mode 100644 src/test/java/org/springframework/data/domain/NumberVectorUnitTests.java diff --git a/src/main/java/org/springframework/data/domain/DoubleVector.java b/src/main/java/org/springframework/data/domain/DoubleVector.java new file mode 100644 index 0000000000..b2dc93c99c --- /dev/null +++ b/src/main/java/org/springframework/data/domain/DoubleVector.java @@ -0,0 +1,103 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import java.util.Arrays; + +import org.springframework.util.ObjectUtils; + +/** + * {@link Vector} implementation based on {@code double} array. + * + * @author Mark Paluch + * @since 3.5 + */ +class DoubleVector implements Vector { + + private final double[] v; + + public DoubleVector(double[] v) { + this.v = v; + } + + /** + * Copy the given {@code double} array and wrap it within a Vector. + */ + static Vector copy(double[] v) { + + double[] copy = new double[v.length]; + System.arraycopy(v, 0, copy, 0, copy.length); + + return new DoubleVector(copy); + } + + @Override + public Class getType() { + return Double.TYPE; + } + + @Override + public Object getSource() { + return v; + } + + @Override + public int size() { + return v.length; + } + + @Override + public float[] toFloatArray() { + + float[] copy = new float[this.v.length]; + for (int i = 0; i < this.v.length; i++) { + copy[i] = (float) this.v[i]; + } + + return copy; + } + + @Override + public double[] toDoubleArray() { + + double[] copy = new double[this.v.length]; + System.arraycopy(this.v, 0, copy, 0, copy.length); + + return copy; + } + + @Override + public boolean equals(Object o) { + + if (this == o) { + return true; + } + if (!(o instanceof DoubleVector that)) { + return false; + } + return ObjectUtils.nullSafeEquals(v, that.v); + } + + @Override + public int hashCode() { + return Arrays.hashCode(v); + } + + @Override + public String toString() { + return "D" + Arrays.toString(v); + } +} diff --git a/src/main/java/org/springframework/data/domain/FloatVector.java b/src/main/java/org/springframework/data/domain/FloatVector.java new file mode 100644 index 0000000000..4f3c425804 --- /dev/null +++ b/src/main/java/org/springframework/data/domain/FloatVector.java @@ -0,0 +1,103 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import java.util.Arrays; + +import org.springframework.util.ObjectUtils; + +/** + * {@link Vector} implementation based on {@code float} array. + * + * @author Mark Paluch + * @since 3.5 + */ +class FloatVector implements Vector { + + private final float[] v; + + public FloatVector(float[] v) { + this.v = v; + } + + /** + * Copy the given {@code float} array and wrap it within a Vector. + */ + static Vector copy(float[] v) { + + float[] copy = new float[v.length]; + System.arraycopy(v, 0, copy, 0, copy.length); + + return new FloatVector(copy); + } + + @Override + public Class getType() { + return Float.TYPE; + } + + @Override + public Object getSource() { + return v; + } + + @Override + public int size() { + return v.length; + } + + @Override + public float[] toFloatArray() { + + float[] copy = new float[this.v.length]; + System.arraycopy(this.v, 0, copy, 0, copy.length); + + return copy; + } + + @Override + public double[] toDoubleArray() { + + double[] copy = new double[this.v.length]; + for (int i = 0; i < this.v.length; i++) { + copy[i] = this.v[i]; + } + + return copy; + } + + @Override + public boolean equals(Object o) { + + if (this == o) { + return true; + } + if (!(o instanceof FloatVector that)) { + return false; + } + return ObjectUtils.nullSafeEquals(v, that.v); + } + + @Override + public int hashCode() { + return Arrays.hashCode(v); + } + + @Override + public String toString() { + return "F" + Arrays.toString(v); + } +} diff --git a/src/main/java/org/springframework/data/domain/NumberVector.java b/src/main/java/org/springframework/data/domain/NumberVector.java new file mode 100644 index 0000000000..eda0f9f064 --- /dev/null +++ b/src/main/java/org/springframework/data/domain/NumberVector.java @@ -0,0 +1,132 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import java.util.Arrays; +import java.util.Collection; + +import org.springframework.util.ObjectUtils; + +/** + * {@link Vector} implementation based on {@link Number} array. + * + * @author Mark Paluch + * @since 3.5 + */ +class NumberVector implements Vector { + + private final Number[] v; + + public NumberVector(Number[] v) { + this.v = v; + } + + /** + * Copy the given {@link Number} array and wrap it within a Vector. + */ + static Vector copy(Number[] v) { + + Number[] copy = new Number[v.length]; + System.arraycopy(v, 0, copy, 0, copy.length); + + return new NumberVector(copy); + } + + /** + * Copy the given {@link Number} and wrap it within a Vector. + */ + static Vector copy(Collection numbers) { + + Number[] copy = new Number[numbers.size()]; + + int i = 0; + for (Number number : numbers) { + copy[i++] = number; + } + + return new NumberVector(copy); + } + + @Override + public Class getType() { + + Class candidate = null; + for (Object val : v) { + if (val != null) { + if (candidate == null) { + candidate = val.getClass(); + } else if (candidate != val.getClass()) { + return Number.class; + } + } + } + return (Class) candidate; + } + + @Override + public Object getSource() { + return v; + } + + @Override + public int size() { + return v.length; + } + + @Override + public float[] toFloatArray() { + + float[] copy = new float[this.v.length]; + for (int i = 0; i < this.v.length; i++) { + copy[i] = this.v[i].floatValue(); + } + + return copy; + } + + @Override + public double[] toDoubleArray() { + + double[] copy = new double[this.v.length]; + for (int i = 0; i < this.v.length; i++) { + copy[i] = this.v[i].doubleValue(); + } + + return copy; + } + + @Override + public boolean equals(Object o) { + + if (this == o) { + return true; + } + if (!(o instanceof NumberVector that)) { + return false; + } + return ObjectUtils.nullSafeEquals(v, that.v); + } + + @Override + public int hashCode() { + return Arrays.hashCode(v); + } + + @Override + public String toString() { + return "N" + Arrays.toString(v); + } +} diff --git a/src/main/java/org/springframework/data/domain/Vector.java b/src/main/java/org/springframework/data/domain/Vector.java new file mode 100644 index 0000000000..2e44d793e7 --- /dev/null +++ b/src/main/java/org/springframework/data/domain/Vector.java @@ -0,0 +1,171 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import java.util.Collection; + +import org.springframework.util.Assert; + +/** + * A vector is a fixed-length array of non-null numeric values. Vectors are represent a point in a multidimensional + * space that is commonly used in machine learning and statistics. + *

+ * Vector properties do not map cleanly to an existing class in the standard JDK Collections hierarchy. Vectors when + * used with embeddings (machine learning) represent an opaque point in the vector space that does not expose meaningful + * properties nor guarantees computational values to the outside world. + *

+ * Vectors should be treated as opaque values and should not be modified. They can be created from an array of numbers + * (typically {@code double} or {@code float} values) and used by components that need to provide the vector for storage + * or computation. + * + * @author Mark Paluch + * @since 3.5 + */ +public interface Vector { + + /** + * Creates a new {@link Vector} from the given float {@code values}. Vector values are duplicated to avoid capturing a + * mutable array instance and to prevent mutability. + * + * @param values float vector values. + * @return the {@link Vector} for the given vector values. + */ + static Vector of(float... values) { + + Assert.notNull(values, "float vector values must not be null"); + + return FloatVector.copy(values); + } + + /** + * Creates a new {@link Vector} from the given double {@code values}. Vector values are duplicated to avoid capturing + * a mutable array instance and to prevent mutability. + * + * @param values double vector values. + * @return the {@link Vector} for the given vector values. + */ + static Vector of(double... values) { + + Assert.notNull(values, "double vector values must not be null"); + + return DoubleVector.copy(values); + } + + /** + * Creates a new {@link Vector} from the given number {@code values}. Vector values are duplicated to avoid capturing + * a mutable array instance and to prevent mutability. + * + * @param values number vector values. + * @return the {@link Vector} for the given vector values. + */ + static Vector of(Number... values) { + + Assert.notNull(values, "Vector values must not be null"); + + return NumberVector.copy(values); + } + + /** + * Creates a new {@link Vector} from the given number {@code values}. Vector values are duplicated to avoid capturing + * a mutable collection instance and to prevent mutability. + * + * @param values number vector values. + * @return the {@link Vector} for the given vector values. + */ + static Vector of(Collection values) { + + Assert.notNull(values, "Vector values must not be null"); + + return NumberVector.copy(values); + } + + /** + * Creates a new unsafe {@link Vector} wrapper from the given {@code values}. Unsafe wrappers do not duplicate array + * values and are merely a view on the source array. + *

+ * Supported source type + * + * @param values vector values. + * @return the {@link Vector} for the given vector values. + */ + static Vector unsafe(float[] values) { + + Assert.notNull(values, "float vector values must not be null"); + + return new FloatVector(values); + } + + /** + * Creates a new unsafe {@link Vector} wrapper from the given {@code values}. Unsafe wrappers do not duplicate array + * values and are merely a view on the source array. + *

+ * Supported source type + * + * @param values vector values. + * @return the {@link Vector} for the given vector values. + */ + static Vector unsafe(double[] values) { + + Assert.notNull(values, "double vector values must not be null"); + + return new DoubleVector(values); + } + + /** + * Returns the type of the underlying vector source. + * + * @return the type of the underlying vector source. + */ + Class getType(); + + /** + * Returns the source array of the vector. The source array is not copied and should not be modified to avoid + * mutability issues. This method should be used for performance access. + * + * @return the source array of the vector. + */ + Object getSource(); + + /** + * Returns the number of dimensions. + * + * @return the number of dimensions. + */ + int size(); + + /** + * Convert the vector to a {@code float} array. The returned array is a copy of the {@link #getSource() source} array + * and can be modified safely. + *

+ * Conversion to {@code float} can incorporate loss of precision or result in values with a slight offset due to data + * type conversion if the source is not a {@code float} array. + * + * @return a new {@code float} array representing the vector point. + */ + float[] toFloatArray(); + + /** + * Convert the vector to a {@code double} array. The returned array is a copy of the {@link #getSource() source} array + * and can be modified safely. + *

+ * Conversion to {@code double} can incorporate loss of precision or result in values with a slight offset due to data + * type conversion if the source is not a {@code double} array. + * + * @return a new {@code double} array representing the vector point. + */ + double[] toDoubleArray(); + +} diff --git a/src/test/java/org/springframework/data/domain/DoubleVectorUnitTests.java b/src/test/java/org/springframework/data/domain/DoubleVectorUnitTests.java new file mode 100644 index 0000000000..b87c320d72 --- /dev/null +++ b/src/test/java/org/springframework/data/domain/DoubleVectorUnitTests.java @@ -0,0 +1,97 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import static org.assertj.core.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for {@link DoubleVector}. + * + * @author Mark Paluch + */ +class DoubleVectorUnitTests { + + double[] values = new double[] { 1.1, 2.2, 3.3, 4.4, 5.5 }; + float[] floats = new float[] { (float) 1.1d, (float) 2.2d, (float) 3.3d, (float) 4.4d, (float) 5.5d }; + + @Test // GH-3193 + void shouldCreateVector() { + + Vector vector = Vector.of(values); + + assertThat(vector.size()).isEqualTo(5); + assertThat(vector.getType()).isEqualTo(Double.TYPE); + } + + @Test // GH-3193 + void shouldCreateUnsafeVector() { + + Vector vector = Vector.unsafe(values); + + assertThat(vector.getSource()).isSameAs(values); + } + + @Test // GH-3193 + void shouldCopyVectorValues() { + + Vector vector = Vector.of(values); + + assertThat(vector.getSource()).isNotSameAs(vector).isEqualTo(values); + } + + @Test // GH-3193 + void shouldRenderToString() { + + Vector vector = Vector.of(values); + + assertThat(vector).hasToString("D[1.1, 2.2, 3.3, 4.4, 5.5]"); + } + + @Test // GH-3193 + void shouldCompareVector() { + + Vector vector = Vector.of(values); + + assertThat(vector).isEqualTo(Vector.of(values)); + assertThat(vector).hasSameHashCodeAs(Vector.of(values)); + } + + @Test // GH-3193 + void sourceShouldReturnSource() { + + Vector vector = new DoubleVector(values); + + assertThat(vector.getSource()).isSameAs(values); + } + + @Test // GH-3193 + void shouldCreateFloatArray() { + + Vector vector = Vector.of(values); + + assertThat(vector.toFloatArray()).isEqualTo(floats).isNotSameAs(floats); + } + + @Test // GH-3193 + void shouldCreateDoubleArray() { + + Vector vector = Vector.of(values); + + assertThat(vector.toDoubleArray()).isEqualTo(values).isNotSameAs(values); + } +} diff --git a/src/test/java/org/springframework/data/domain/FloatVectorUnitTests.java b/src/test/java/org/springframework/data/domain/FloatVectorUnitTests.java new file mode 100644 index 0000000000..bef96f4e81 --- /dev/null +++ b/src/test/java/org/springframework/data/domain/FloatVectorUnitTests.java @@ -0,0 +1,97 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import static org.assertj.core.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for {@link FloatVector}. + * + * @author Mark Paluch + */ +class FloatVectorUnitTests { + + float[] values = new float[] { 1.1f, 2.2f, 3.3f, 4.4f, 5.5f }; + double[] doubles = new double[] { 1.1f, 2.2f, 3.3f, 4.4f, 5.5f }; + + @Test // GH-3193 + void shouldCreateVector() { + + Vector vector = Vector.of(values); + + assertThat(vector.size()).isEqualTo(5); + assertThat(vector.getType()).isEqualTo(Float.TYPE); + } + + @Test // GH-3193 + void shouldCreateUnsafeVector() { + + Vector vector = Vector.unsafe(values); + + assertThat(vector.getSource()).isSameAs(values); + } + + @Test // GH-3193 + void shouldCopyVectorValues() { + + Vector vector = Vector.of(values); + + assertThat(vector.getSource()).isNotSameAs(vector).isEqualTo(values); + } + + @Test // GH-3193 + void shouldRenderToString() { + + Vector vector = Vector.of(values); + + assertThat(vector).hasToString("F[1.1, 2.2, 3.3, 4.4, 5.5]"); + } + + @Test // GH-3193 + void shouldCompareVector() { + + Vector vector = Vector.of(values); + + assertThat(vector).isEqualTo(Vector.of(values)); + assertThat(vector).hasSameHashCodeAs(Vector.of(values)); + } + + @Test // GH-3193 + void sourceShouldReturnSource() { + + Vector vector = new FloatVector(values); + + assertThat(vector.getSource()).isSameAs(values); + } + + @Test // GH-3193 + void shouldCreateFloatArray() { + + Vector vector = Vector.of(values); + + assertThat(vector.toFloatArray()).isEqualTo(values).isNotSameAs(values); + } + + @Test // GH-3193 + void shouldCreateDoubleArray() { + + Vector vector = Vector.of(values); + + assertThat(vector.toDoubleArray()).isEqualTo(doubles).isNotSameAs(doubles); + } +} diff --git a/src/test/java/org/springframework/data/domain/NumberVectorUnitTests.java b/src/test/java/org/springframework/data/domain/NumberVectorUnitTests.java new file mode 100644 index 0000000000..a9629a5f44 --- /dev/null +++ b/src/test/java/org/springframework/data/domain/NumberVectorUnitTests.java @@ -0,0 +1,99 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import static org.assertj.core.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for {@link NumberVector}. + * + * @author Mark Paluch + */ +class NumberVectorUnitTests { + + Number[] values = new Number[] { 1.1, 2.2, 3.3, 4.4, 5.5 }; + Number[] floats = new Number[] { (float) 1.1d, (float) 2.2d, (float) 3.3d, (float) 4.4d, (float) 5.5d }; + + @Test // GH-3193 + void shouldCreateVector() { + + Vector vector = Vector.of(values); + + assertThat(vector.size()).isEqualTo(5); + assertThat(vector.getType()).isEqualTo(Double.class); + } + + @Test // GH-3193 + void shouldCopyVectorValues() { + + Vector vector = Vector.of(values); + + assertThat(vector.getSource()).isNotSameAs(vector).isEqualTo(values); + } + + @Test // GH-3193 + void shouldRenderToString() { + + Vector vector = Vector.of(values); + + assertThat(vector).hasToString("N[1.1, 2.2, 3.3, 4.4, 5.5]"); + } + + @Test // GH-3193 + void shouldCompareVector() { + + Vector vector = Vector.of(values); + + assertThat(vector).isEqualTo(Vector.of(values)); + assertThat(vector).hasSameHashCodeAs(Vector.of(values)); + } + + @Test // GH-3193 + void sourceShouldReturnSource() { + + Vector vector = new NumberVector(values); + + assertThat(vector.getSource()).isSameAs(values); + } + + @Test // GH-3193 + void shouldCreateFloatArray() { + + Vector vector = Vector.of(values); + + float[] values = new float[this.floats.length]; + for (int i = 0; i < values.length; i++) { + values[i] = this.floats[i].floatValue(); + } + + assertThat(vector.toFloatArray()).isEqualTo(values).isNotSameAs(floats); + } + + @Test // GH-3193 + void shouldCreateDoubleArray() { + + Vector vector = Vector.of(values); + + double[] values = new double[this.values.length]; + for (int i = 0; i < values.length; i++) { + values[i] = this.values[i].doubleValue(); + } + + assertThat(vector.toDoubleArray()).isEqualTo(values).isNotSameAs(values); + } +} From c97a6990e543078945835296e9842c88472be179 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Mon, 20 Jan 2025 14:22:18 +0100 Subject: [PATCH 3/5] Polishing. Refine object creation based on common element type to preserve comparison. --- .../data/domain/DoubleVector.java | 15 ++++++++++ .../data/domain/FloatVector.java | 15 ++++++++++ .../springframework/data/domain/Vector.java | 23 +++++++-------- .../data/domain/NumberVectorUnitTests.java | 28 ++++++++++--------- 4 files changed, 55 insertions(+), 26 deletions(-) diff --git a/src/main/java/org/springframework/data/domain/DoubleVector.java b/src/main/java/org/springframework/data/domain/DoubleVector.java index b2dc93c99c..50a6bed6f1 100644 --- a/src/main/java/org/springframework/data/domain/DoubleVector.java +++ b/src/main/java/org/springframework/data/domain/DoubleVector.java @@ -16,6 +16,7 @@ package org.springframework.data.domain; import java.util.Arrays; +import java.util.Collection; import org.springframework.util.ObjectUtils; @@ -44,6 +45,20 @@ static Vector copy(double[] v) { return new DoubleVector(copy); } + /** + * Copy the given numeric values and wrap within a Vector. + */ + static Vector copy(Collection v) { + + double[] copy = new double[v.size()]; + int i = 0; + for (Number number : v) { + copy[i++] = number.doubleValue(); + } + + return new DoubleVector(copy); + } + @Override public Class getType() { return Double.TYPE; diff --git a/src/main/java/org/springframework/data/domain/FloatVector.java b/src/main/java/org/springframework/data/domain/FloatVector.java index 4f3c425804..c1d6f1aa64 100644 --- a/src/main/java/org/springframework/data/domain/FloatVector.java +++ b/src/main/java/org/springframework/data/domain/FloatVector.java @@ -16,6 +16,7 @@ package org.springframework.data.domain; import java.util.Arrays; +import java.util.Collection; import org.springframework.util.ObjectUtils; @@ -44,6 +45,20 @@ static Vector copy(float[] v) { return new FloatVector(copy); } + /** + * Copy the given numeric values and wrap within a Vector. + */ + static Vector copy(Collection v) { + + float[] copy = new float[v.size()]; + int i = 0; + for (Number number : v) { + copy[i++] = number.floatValue(); + } + + return new FloatVector(copy); + } + @Override public Class getType() { return Float.TYPE; diff --git a/src/main/java/org/springframework/data/domain/Vector.java b/src/main/java/org/springframework/data/domain/Vector.java index 2e44d793e7..6d43fdcdd4 100644 --- a/src/main/java/org/springframework/data/domain/Vector.java +++ b/src/main/java/org/springframework/data/domain/Vector.java @@ -18,6 +18,7 @@ import java.util.Collection; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * A vector is a fixed-length array of non-null numeric values. Vectors are represent a point in a multidimensional @@ -66,28 +67,24 @@ static Vector of(double... values) { /** * Creates a new {@link Vector} from the given number {@code values}. Vector values are duplicated to avoid capturing - * a mutable array instance and to prevent mutability. + * a mutable collection instance and to prevent mutability. * * @param values number vector values. * @return the {@link Vector} for the given vector values. */ - static Vector of(Number... values) { + static Vector of(Collection values) { Assert.notNull(values, "Vector values must not be null"); - return NumberVector.copy(values); - } + Class cet = CollectionUtils.findCommonElementType(values); - /** - * Creates a new {@link Vector} from the given number {@code values}. Vector values are duplicated to avoid capturing - * a mutable collection instance and to prevent mutability. - * - * @param values number vector values. - * @return the {@link Vector} for the given vector values. - */ - static Vector of(Collection values) { + if (cet == Double.class) { + return DoubleVector.copy(values); + } - Assert.notNull(values, "Vector values must not be null"); + if (cet == Float.class) { + return FloatVector.copy(values); + } return NumberVector.copy(values); } diff --git a/src/test/java/org/springframework/data/domain/NumberVectorUnitTests.java b/src/test/java/org/springframework/data/domain/NumberVectorUnitTests.java index a9629a5f44..644113f41d 100644 --- a/src/test/java/org/springframework/data/domain/NumberVectorUnitTests.java +++ b/src/test/java/org/springframework/data/domain/NumberVectorUnitTests.java @@ -17,6 +17,8 @@ import static org.assertj.core.api.Assertions.*; +import java.util.Arrays; + import org.junit.jupiter.api.Test; /** @@ -26,22 +28,22 @@ */ class NumberVectorUnitTests { - Number[] values = new Number[] { 1.1, 2.2, 3.3, 4.4, 5.5 }; - Number[] floats = new Number[] { (float) 1.1d, (float) 2.2d, (float) 3.3d, (float) 4.4d, (float) 5.5d }; + Number[] values = new Number[] { 1.1, 2.2, 3.3, 4.4, 5.5, 6.6f }; + Number[] floats = new Number[] { (float) 1.1d, (float) 2.2d, (float) 3.3d, (float) 4.4d, (float) 5.5, 6.6 }; @Test // GH-3193 void shouldCreateVector() { - Vector vector = Vector.of(values); + Vector vector = Vector.of(Arrays.asList(values)); - assertThat(vector.size()).isEqualTo(5); - assertThat(vector.getType()).isEqualTo(Double.class); + assertThat(vector.size()).isEqualTo(6); + assertThat(vector.getType()).isEqualTo(Number.class); } @Test // GH-3193 void shouldCopyVectorValues() { - Vector vector = Vector.of(values); + Vector vector = Vector.of(Arrays.asList(values)); assertThat(vector.getSource()).isNotSameAs(vector).isEqualTo(values); } @@ -49,18 +51,18 @@ void shouldCopyVectorValues() { @Test // GH-3193 void shouldRenderToString() { - Vector vector = Vector.of(values); + Vector vector = Vector.of(Arrays.asList(values)); - assertThat(vector).hasToString("N[1.1, 2.2, 3.3, 4.4, 5.5]"); + assertThat(vector).hasToString("N[1.1, 2.2, 3.3, 4.4, 5.5, 6.6]"); } @Test // GH-3193 void shouldCompareVector() { - Vector vector = Vector.of(values); + Vector vector = Vector.of(Arrays.asList(values)); - assertThat(vector).isEqualTo(Vector.of(values)); - assertThat(vector).hasSameHashCodeAs(Vector.of(values)); + assertThat(vector).isEqualTo(Vector.of(Arrays.asList(values))); + assertThat(vector).hasSameHashCodeAs(Vector.of(Arrays.asList(values))); } @Test // GH-3193 @@ -74,7 +76,7 @@ void sourceShouldReturnSource() { @Test // GH-3193 void shouldCreateFloatArray() { - Vector vector = Vector.of(values); + Vector vector = Vector.of(Arrays.asList(values)); float[] values = new float[this.floats.length]; for (int i = 0; i < values.length; i++) { @@ -87,7 +89,7 @@ void shouldCreateFloatArray() { @Test // GH-3193 void shouldCreateDoubleArray() { - Vector vector = Vector.of(values); + Vector vector = Vector.of(Arrays.asList(values)); double[] values = new double[this.values.length]; for (int i = 0; i < values.length; i++) { From 85df69d05bd6f7505d7b38ae3796b12205e9a478 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Mon, 20 Jan 2025 15:02:20 +0100 Subject: [PATCH 4/5] Extend copy signature for Number subclasses. --- src/main/java/org/springframework/data/domain/DoubleVector.java | 2 +- src/main/java/org/springframework/data/domain/FloatVector.java | 2 +- src/main/java/org/springframework/data/domain/NumberVector.java | 2 +- src/main/java/org/springframework/data/domain/Vector.java | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/springframework/data/domain/DoubleVector.java b/src/main/java/org/springframework/data/domain/DoubleVector.java index 50a6bed6f1..623d6a0118 100644 --- a/src/main/java/org/springframework/data/domain/DoubleVector.java +++ b/src/main/java/org/springframework/data/domain/DoubleVector.java @@ -48,7 +48,7 @@ static Vector copy(double[] v) { /** * Copy the given numeric values and wrap within a Vector. */ - static Vector copy(Collection v) { + static Vector copy(Collection v) { double[] copy = new double[v.size()]; int i = 0; diff --git a/src/main/java/org/springframework/data/domain/FloatVector.java b/src/main/java/org/springframework/data/domain/FloatVector.java index c1d6f1aa64..bb07df19cd 100644 --- a/src/main/java/org/springframework/data/domain/FloatVector.java +++ b/src/main/java/org/springframework/data/domain/FloatVector.java @@ -48,7 +48,7 @@ static Vector copy(float[] v) { /** * Copy the given numeric values and wrap within a Vector. */ - static Vector copy(Collection v) { + static Vector copy(Collection v) { float[] copy = new float[v.size()]; int i = 0; diff --git a/src/main/java/org/springframework/data/domain/NumberVector.java b/src/main/java/org/springframework/data/domain/NumberVector.java index eda0f9f064..b71dd7081e 100644 --- a/src/main/java/org/springframework/data/domain/NumberVector.java +++ b/src/main/java/org/springframework/data/domain/NumberVector.java @@ -48,7 +48,7 @@ static Vector copy(Number[] v) { /** * Copy the given {@link Number} and wrap it within a Vector. */ - static Vector copy(Collection numbers) { + static Vector copy(Collection numbers) { Number[] copy = new Number[numbers.size()]; diff --git a/src/main/java/org/springframework/data/domain/Vector.java b/src/main/java/org/springframework/data/domain/Vector.java index 6d43fdcdd4..d1e0cfa251 100644 --- a/src/main/java/org/springframework/data/domain/Vector.java +++ b/src/main/java/org/springframework/data/domain/Vector.java @@ -72,7 +72,7 @@ static Vector of(double... values) { * @param values number vector values. * @return the {@link Vector} for the given vector values. */ - static Vector of(Collection values) { + static Vector of(Collection values) { Assert.notNull(values, "Vector values must not be null"); From 2ed2f9213af32f4d0c9170c9d733e7f86f7ee272 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Tue, 28 Jan 2025 09:51:46 +0100 Subject: [PATCH 5/5] Enforce non null contract on vector elements Also add shortcuts for empty sources and simplify copy calls that do not transform source values. --- .../data/domain/DoubleVector.java | 21 ++++---- .../data/domain/FloatVector.java | 21 ++++---- .../data/domain/NumberVector.java | 43 +++++++++-------- .../springframework/data/domain/Vector.java | 5 +- .../data/domain/FloatVectorUnitTests.java | 2 +- .../data/domain/NumberVectorUnitTests.java | 48 +++++++++++++++++-- .../data/domain/VectorUnitTests.java | 36 ++++++++++++++ 7 files changed, 130 insertions(+), 46 deletions(-) create mode 100644 src/test/java/org/springframework/data/domain/VectorUnitTests.java diff --git a/src/main/java/org/springframework/data/domain/DoubleVector.java b/src/main/java/org/springframework/data/domain/DoubleVector.java index 623d6a0118..b3669c4f74 100644 --- a/src/main/java/org/springframework/data/domain/DoubleVector.java +++ b/src/main/java/org/springframework/data/domain/DoubleVector.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ class DoubleVector implements Vector { private final double[] v; - public DoubleVector(double[] v) { + DoubleVector(double[] v) { this.v = v; } @@ -39,10 +39,11 @@ public DoubleVector(double[] v) { */ static Vector copy(double[] v) { - double[] copy = new double[v.length]; - System.arraycopy(v, 0, copy, 0, copy.length); + if (v.length == 0) { + return new DoubleVector(new double[0]); + } - return new DoubleVector(copy); + return new DoubleVector(Arrays.copyOf(v, v.length)); } /** @@ -50,6 +51,10 @@ static Vector copy(double[] v) { */ static Vector copy(Collection v) { + if (v.isEmpty()) { + return new DoubleVector(new double[0]); + } + double[] copy = new double[v.size()]; int i = 0; for (Number number : v) { @@ -87,11 +92,7 @@ public float[] toFloatArray() { @Override public double[] toDoubleArray() { - - double[] copy = new double[this.v.length]; - System.arraycopy(this.v, 0, copy, 0, copy.length); - - return copy; + return Arrays.copyOf(this.v, this.v.length); } @Override diff --git a/src/main/java/org/springframework/data/domain/FloatVector.java b/src/main/java/org/springframework/data/domain/FloatVector.java index bb07df19cd..d2d85f4388 100644 --- a/src/main/java/org/springframework/data/domain/FloatVector.java +++ b/src/main/java/org/springframework/data/domain/FloatVector.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ class FloatVector implements Vector { private final float[] v; - public FloatVector(float[] v) { + FloatVector(float[] v) { this.v = v; } @@ -39,10 +39,11 @@ public FloatVector(float[] v) { */ static Vector copy(float[] v) { - float[] copy = new float[v.length]; - System.arraycopy(v, 0, copy, 0, copy.length); + if (v.length == 0) { + return new FloatVector(new float[0]); + } - return new FloatVector(copy); + return new FloatVector(Arrays.copyOf(v, v.length)); } /** @@ -50,6 +51,10 @@ static Vector copy(float[] v) { */ static Vector copy(Collection v) { + if (v.isEmpty()) { + return new FloatVector(new float[0]); + } + float[] copy = new float[v.size()]; int i = 0; for (Number number : v) { @@ -76,11 +81,7 @@ public int size() { @Override public float[] toFloatArray() { - - float[] copy = new float[this.v.length]; - System.arraycopy(this.v, 0, copy, 0, copy.length); - - return copy; + return Arrays.copyOf(this.v, this.v.length); } @Override diff --git a/src/main/java/org/springframework/data/domain/NumberVector.java b/src/main/java/org/springframework/data/domain/NumberVector.java index b71dd7081e..528916e865 100644 --- a/src/main/java/org/springframework/data/domain/NumberVector.java +++ b/src/main/java/org/springframework/data/domain/NumberVector.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ import java.util.Arrays; import java.util.Collection; +import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; /** @@ -30,7 +31,9 @@ class NumberVector implements Vector { private final Number[] v; - public NumberVector(Number[] v) { + NumberVector(Number[] v) { + + Assert.noNullElements(v, "Vector [v] must not contain null elements"); this.v = v; } @@ -39,41 +42,39 @@ public NumberVector(Number[] v) { */ static Vector copy(Number[] v) { - Number[] copy = new Number[v.length]; - System.arraycopy(v, 0, copy, 0, copy.length); + if (v.length == 0) { + return new NumberVector(new Number[0]); + } - return new NumberVector(copy); + return new NumberVector(Arrays.copyOf(v, v.length)); } /** * Copy the given {@link Number} and wrap it within a Vector. */ - static Vector copy(Collection numbers) { - - Number[] copy = new Number[numbers.size()]; + static Vector copy(Collection v) { - int i = 0; - for (Number number : numbers) { - copy[i++] = number; + if (v.isEmpty()) { + return new NumberVector(new Number[0]); } - return new NumberVector(copy); + return new NumberVector(v.toArray(Number[]::new)); } @Override public Class getType() { - Class candidate = null; - for (Object val : v) { - if (val != null) { - if (candidate == null) { - candidate = val.getClass(); - } else if (candidate != val.getClass()) { - return Number.class; - } + if (this.v.length == 0) { + return Number.class; + } + + Class candidate = this.v[0].getClass(); + for (int i = 1; i < this.v.length; i++) { + if (candidate != this.v[i].getClass()) { + return Number.class; } } - return (Class) candidate; + return candidate; } @Override diff --git a/src/main/java/org/springframework/data/domain/Vector.java b/src/main/java/org/springframework/data/domain/Vector.java index d1e0cfa251..db434e0f75 100644 --- a/src/main/java/org/springframework/data/domain/Vector.java +++ b/src/main/java/org/springframework/data/domain/Vector.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -75,6 +75,9 @@ static Vector of(double... values) { static Vector of(Collection values) { Assert.notNull(values, "Vector values must not be null"); + if(values.isEmpty()) { + return NumberVector.copy(new Number[0]); + } Class cet = CollectionUtils.findCommonElementType(values); diff --git a/src/test/java/org/springframework/data/domain/FloatVectorUnitTests.java b/src/test/java/org/springframework/data/domain/FloatVectorUnitTests.java index bef96f4e81..c58d5d047b 100644 --- a/src/test/java/org/springframework/data/domain/FloatVectorUnitTests.java +++ b/src/test/java/org/springframework/data/domain/FloatVectorUnitTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/test/java/org/springframework/data/domain/NumberVectorUnitTests.java b/src/test/java/org/springframework/data/domain/NumberVectorUnitTests.java index 644113f41d..ba730de0fc 100644 --- a/src/test/java/org/springframework/data/domain/NumberVectorUnitTests.java +++ b/src/test/java/org/springframework/data/domain/NumberVectorUnitTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,9 +15,12 @@ */ package org.springframework.data.domain; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; @@ -25,12 +28,40 @@ * Unit tests for {@link NumberVector}. * * @author Mark Paluch + * @author Christoph Strobl */ class NumberVectorUnitTests { Number[] values = new Number[] { 1.1, 2.2, 3.3, 4.4, 5.5, 6.6f }; Number[] floats = new Number[] { (float) 1.1d, (float) 2.2d, (float) 3.3d, (float) 4.4d, (float) 5.5, 6.6 }; + @Test // GH-3193 + void shouldErrorOnNullElements() { + + List source = new ArrayList<>(3); + source.add(1L); + source.add(null); + source.add(3L); + + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> NumberVector.copy(source)); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> NumberVector.copy(new Number[] { 1L, null, 3L })); + } + + @Test // GH-3193 + void shouldAcceptEmptySource() { + + Vector vector = NumberVector.copy(List.of()); + + assertThat(vector.size()).isEqualTo(0); + assertThat(vector.getType()).isEqualTo(Number.class); + + vector = NumberVector.copy(new Number[] {}); + + assertThat(vector.size()).isEqualTo(0); + assertThat(vector.getType()).isEqualTo(Number.class); + } + @Test // GH-3193 void shouldCreateVector() { @@ -48,6 +79,17 @@ void shouldCopyVectorValues() { assertThat(vector.getSource()).isNotSameAs(vector).isEqualTo(values); } + @Test // GH-3193 + void shouldFigureOutCommonType() { + + assertThat(NumberVector.copy(List.of()).getType()).isEqualTo(Number.class); + assertThat(NumberVector.copy(List.of(1)).getType()).isEqualTo(Integer.class); + assertThat(NumberVector.copy(List.of(1L, 2L)).getType()).isEqualTo(Long.class); + assertThat(NumberVector.copy(List.of(1F, 2F)).getType()).isEqualTo(Float.class); + assertThat(NumberVector.copy(List.of(1D, 2D)).getType()).isEqualTo(Double.class); + assertThat(NumberVector.copy(List.of(1D, 2F, 3F)).getType()).isEqualTo(Number.class); + } + @Test // GH-3193 void shouldRenderToString() { @@ -66,7 +108,7 @@ void shouldCompareVector() { } @Test // GH-3193 - void sourceShouldReturnSource() { + void sourceShouldReturnSource() { // this one is questionable Vector vector = new NumberVector(values); diff --git a/src/test/java/org/springframework/data/domain/VectorUnitTests.java b/src/test/java/org/springframework/data/domain/VectorUnitTests.java new file mode 100644 index 0000000000..4f5f3b0248 --- /dev/null +++ b/src/test/java/org/springframework/data/domain/VectorUnitTests.java @@ -0,0 +1,36 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +import java.util.Collection; + +import org.junit.jupiter.api.Test; + +/** + * @author Christoph Strobl + */ +public class VectorUnitTests { + + @Test // GH-3193 + void staticInitializersErrorOnNull() { + + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> Vector.of((double[]) null)); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> Vector.of((float[]) null)); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> Vector.of((Collection) null)); + } +}