From f8ae8d640c393dcaf162c5d34291442893664f8d Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 25 Apr 2025 15:52:30 +0200 Subject: [PATCH 1/6] Prepare issue branch. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index a6dc167a03..2607252e7c 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-commons - 4.0.0-SNAPSHOT + 4.0.x-GH-3279-SNAPSHOT Spring Data Core Core Spring concepts underpinning every Spring Data module. From 8e1b87aed771ec9501fcfc32816618096a1a5bcd Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Thu, 24 Apr 2025 07:36:55 +0200 Subject: [PATCH 2/6] Refine Repository Composition retrieval during AOT Add module identifier and base repository implementation properties. Fix fragment function previously overriding already set property due to name clash. Extend tests for bean definition resolution and code block creation. --- .../aot/generate/AotRepositoryBuilder.java | 7 +- .../aot/generate/MethodContributor.java | 2 +- ...toryBeanDefinitionPropertiesDecorator.java | 2 +- .../config/AotRepositoryInformation.java | 36 +++- .../RepositoryBeanDefinitionReader.java | 133 ++++++++++---- ...RepositoryRegistrationAotContribution.java | 21 +-- .../core/RepositoryInformation.java | 5 + .../core/RepositoryInformationSupport.java | 2 +- .../support/RepositoryFactoryBeanSupport.java | 16 +- .../core/support/RepositoryFragment.java | 2 +- src/test/java/example/UserRepository.java | 2 +- .../java/example/UserRepositoryExtension.java | 25 +++ .../example/UserRepositoryExtensionImpl.java | 29 +++ .../AotRepositoryBuilderUnitTests.java | 157 ++++++++++++++++ .../AotRepositoryMethodBuilderUnitTests.java | 88 +++++++++ .../MethodCapturingRepositoryContributor.java | 57 ++++++ .../RepositoryContributorUnitTests.java | 167 +++++++++++++++++- .../RepositoryBeanDefinitionReaderTests.java | 122 +++++++++++++ 18 files changed, 800 insertions(+), 73 deletions(-) create mode 100644 src/test/java/example/UserRepositoryExtension.java create mode 100644 src/test/java/example/UserRepositoryExtensionImpl.java create mode 100644 src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java create mode 100644 src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java create mode 100644 src/test/java/org/springframework/data/repository/aot/generate/MethodCapturingRepositoryContributor.java create mode 100644 src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java index c1ea88e7b1..199ca89f63 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java @@ -138,9 +138,8 @@ public AotBundle build() { this.customizer.customize(repositoryInformation, generationMetadata, builder); JavaFile javaFile = JavaFile.builder(packageName(), builder.build()).build(); - // TODO: module identifier AotRepositoryMetadata metadata = new AotRepositoryMetadata(repositoryInformation.getRepositoryInterface().getName(), - "", repositoryType, methodMetadata); + repositoryInformation.moduleName() != null ? repositoryInformation.moduleName() : "", repositoryType, methodMetadata); return new AotBundle(javaFile, metadata.toJson()); } @@ -148,14 +147,14 @@ public AotBundle build() { private void contributeMethod(Method method, RepositoryComposition repositoryComposition, List methodMetadata, TypeSpec.Builder builder) { - if (repositoryInformation.isCustomMethod(method) || repositoryInformation.isBaseClassMethod(method)) { + if (repositoryInformation.isCustomMethod(method) || (repositoryInformation.isBaseClassMethod(method) && !repositoryInformation.isQueryMethod(method))) { RepositoryFragment fragment = repositoryComposition.findFragment(method); if (fragment != null) { methodMetadata.add(getFragmentMetadata(method, fragment)); + return; } - return; } if (method.isBridge() || method.isDefault() || java.lang.reflect.Modifier.isStatic(method.getModifiers())) { diff --git a/src/main/java/org/springframework/data/repository/aot/generate/MethodContributor.java b/src/main/java/org/springframework/data/repository/aot/generate/MethodContributor.java index cfd29faf02..b30b2fa5ab 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/MethodContributor.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/MethodContributor.java @@ -36,7 +36,7 @@ public abstract class MethodContributor { private final M queryMethod; private final QueryMetadata metadata; - private MethodContributor(M queryMethod, QueryMetadata metadata) { + MethodContributor(M queryMethod, QueryMetadata metadata) { this.queryMethod = queryMethod; this.metadata = metadata; } diff --git a/src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java b/src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java index 1326ac4370..d25e0f1cb3 100644 --- a/src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java +++ b/src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java @@ -55,7 +55,7 @@ public CodeBlock decorate() { // bring in properties as usual builder.add(inheritedProperties.get()); - builder.add("beanDefinition.getPropertyValues().addPropertyValue(\"repositoryFragments\", new $T() {\n", + builder.add("beanDefinition.getPropertyValues().addPropertyValue(\"repositoryFragmentsFunction\", new $T() {\n", RepositoryFactoryBeanSupport.RepositoryFragmentsFunction.class); builder.indent(); builder.add("public $T getRepositoryFragments($T beanFactory, $T context) {\n", diff --git a/src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java b/src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java index c4ea580ab8..1ddcbde9a4 100644 --- a/src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java +++ b/src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java @@ -21,10 +21,12 @@ import java.util.Set; import java.util.function.Supplier; +import org.jspecify.annotations.Nullable; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.RepositoryInformationSupport; import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.core.support.RepositoryComposition; +import org.springframework.data.repository.core.support.RepositoryComposition.RepositoryFragments; import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.util.Lazy; @@ -36,16 +38,31 @@ */ class AotRepositoryInformation extends RepositoryInformationSupport implements RepositoryInformation { + private final @Nullable String moduleName; private final Supplier>> fragments; - private Lazy baseComposition = Lazy.of(() -> { - return RepositoryComposition.of(RepositoryFragment.structural(getRepositoryBaseClass())); - }); - AotRepositoryInformation(Supplier repositoryMetadata, Supplier> repositoryBaseClass, - Supplier>> fragments) { + private final Lazy repositoryComposition; + private final Lazy baseComposition; + + AotRepositoryInformation(@Nullable String moduleName, Supplier repositoryMetadata, + Supplier> repositoryBaseClass, Supplier>> fragments) { super(repositoryMetadata, repositoryBaseClass); + + this.moduleName = moduleName; this.fragments = fragments; + + this.repositoryComposition = Lazy + .of(() -> RepositoryComposition.fromMetadata(getMetadata()).append(RepositoryFragments.from(getFragments()))); + + this.baseComposition = Lazy.of(() -> { + + RepositoryComposition targetRepoComposition = repositoryComposition.get(); + + return RepositoryComposition.of(RepositoryFragment.structural(getRepositoryBaseClass())) // + .withArgumentConverter(targetRepoComposition.getArgumentConverter()) // + .withMethodLookup(targetRepoComposition.getMethodLookup()); + }); } /** @@ -57,10 +74,9 @@ public Set> getFragments() { return new LinkedHashSet<>(fragments.get()); } - // Not required during AOT processing. @Override public boolean isCustomMethod(Method method) { - return false; + return repositoryComposition.get().findMethod(method).isPresent(); } @Override @@ -75,7 +91,11 @@ public Method getTargetClassMethod(Method method) { @Override public RepositoryComposition getRepositoryComposition() { - return baseComposition.get().append(RepositoryComposition.RepositoryFragments.from(fragments.get())); + return repositoryComposition.get(); } + @Override + public @Nullable String moduleName() { + return moduleName; + } } diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java b/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java index 0ac1ae991a..1209903d3b 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java @@ -16,17 +16,21 @@ package org.springframework.data.repository.config; import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; import java.util.List; -import java.util.function.Supplier; -import java.util.stream.Collectors; +import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; +import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.core.ResolvableType; +import org.springframework.data.repository.CrudRepository; +import org.springframework.data.repository.PagingAndSortingRepository; import org.springframework.data.repository.core.RepositoryInformation; -import org.springframework.data.repository.core.support.DefaultRepositoryMetadata; +import org.springframework.data.repository.core.RepositoryMetadata; +import org.springframework.data.repository.core.support.AbstractRepositoryMetadata; import org.springframework.data.repository.core.support.RepositoryFragment; -import org.springframework.data.util.Lazy; +import org.springframework.data.repository.core.support.RepositoryFragment.ImplementedRepositoryFragment; import org.springframework.util.ClassUtils; /** @@ -38,49 +42,108 @@ */ class RepositoryBeanDefinitionReader { - static RepositoryInformation readRepositoryInformation(RepositoryConfiguration metadata, - ConfigurableListableBeanFactory beanFactory) { - - return new AotRepositoryInformation(metadataSupplier(metadata, beanFactory), - repositoryBaseClass(metadata, beanFactory), fragments(metadata, beanFactory)); + /** + * @return + */ + static RepositoryInformation repositoryInformation(RepositoryConfiguration repoConfig, RegisteredBean repoBean) { + return repositoryInformation(repoConfig, repoBean.getMergedBeanDefinition(), repoBean.getBeanFactory()); } - private static Supplier>> fragments(RepositoryConfiguration metadata, + /** + * @param source the RepositoryFactoryBeanSupport bean definition. + * @param beanFactory + * @return + */ + @SuppressWarnings("NullAway") + static RepositoryInformation repositoryInformation(RepositoryConfiguration repoConfig, BeanDefinition source, ConfigurableListableBeanFactory beanFactory) { - if (metadata instanceof RepositoryFragmentConfigurationProvider provider) { - - return Lazy.of(() -> { - return provider.getFragmentConfiguration().stream().flatMap(it -> { - - List> fragments = new ArrayList<>(1); + RepositoryMetadata metadata = AbstractRepositoryMetadata + .getMetadata(forName(repoConfig.getRepositoryInterface(), beanFactory)); + Class repositoryBaseClass = readRepositoryBaseClass(source, beanFactory); + List> fragmentList = readRepositoryFragments(source, beanFactory); + if (source.getPropertyValues().contains("customImplementation")) { + + Object o = source.getPropertyValues().get("customImplementation"); + if (o instanceof RuntimeBeanReference rbr) { + BeanDefinition customImplBeanDefintion = beanFactory.getBeanDefinition(rbr.getBeanName()); + Class beanType = forName(customImplBeanDefintion.getBeanClassName(), beanFactory); + ResolvableType[] interfaces = ResolvableType.forClass(beanType).getInterfaces(); + if (interfaces.length == 1) { + fragmentList.add(new ImplementedRepositoryFragment(interfaces[0].toClass(), beanType)); + } else { + boolean found = false; + for (ResolvableType i : interfaces) { + if (beanType.getSimpleName().contains(i.resolve().getSimpleName())) { + fragmentList.add(new ImplementedRepositoryFragment(interfaces[0].toClass(), beanType)); + found = true; + break; + } + } + if (!found) { + fragmentList.add(RepositoryFragment.implemented(beanType)); + } + } + } + } - fragments.add(RepositoryFragment.implemented(forName(it.getClassName(), beanFactory))); + String moduleName = (String) source.getPropertyValues().get("moduleName"); + AotRepositoryInformation repositoryInformation = new AotRepositoryInformation(moduleName, () -> metadata, + () -> repositoryBaseClass, () -> fragmentList); + return repositoryInformation; + } - if (it.getInterfaceName() != null) { - fragments.add(RepositoryFragment.structural(forName(it.getInterfaceName(), beanFactory))); - } + @SuppressWarnings("NullAway") + private static Class readRepositoryBaseClass(BeanDefinition source, ConfigurableListableBeanFactory beanFactory) { - return fragments.stream(); - }).collect(Collectors.toList()); - }); + Object repoBaseClassName = source.getPropertyValues().get("repositoryBaseClass"); + if (repoBaseClassName != null) { + return forName(repoBaseClassName.toString(), beanFactory); } - - return Lazy.of(Collections::emptyList); + if (source.getPropertyValues().contains("moduleBaseClass")) { + return forName((String) source.getPropertyValues().get("moduleBaseClass"), beanFactory); + } + return Dummy.class; } - @SuppressWarnings({ "rawtypes", "unchecked" }) - private static Supplier> repositoryBaseClass(RepositoryConfiguration metadata, + @SuppressWarnings("NullAway") + private static List> readRepositoryFragments(BeanDefinition source, ConfigurableListableBeanFactory beanFactory) { - return Lazy.of(() -> (Class) metadata.getRepositoryBaseClassName().map(it -> forName(it.toString(), beanFactory)) - .orElse(Object.class)); + RuntimeBeanReference beanReference = (RuntimeBeanReference) source.getPropertyValues().get("repositoryFragments"); + BeanDefinition fragments = beanFactory.getBeanDefinition(beanReference.getBeanName()); + + ValueHolder fragmentBeanNameList = fragments.getConstructorArgumentValues().getArgumentValue(0, List.class); + List fragmentBeanNames = (List) fragmentBeanNameList.getValue(); + + List> fragmentList = new ArrayList<>(); + for (String beanName : fragmentBeanNames) { + + BeanDefinition fragmentBeanDefinition = beanFactory.getBeanDefinition(beanName); + ValueHolder argumentValue = fragmentBeanDefinition.getConstructorArgumentValues().getArgumentValue(0, + String.class); + ValueHolder argumentValue1 = fragmentBeanDefinition.getConstructorArgumentValues().getArgumentValue(1, null, null, + null); + Object fragmentClassName = argumentValue.getValue(); + + try { + Class type = ClassUtils.forName(fragmentClassName.toString(), beanFactory.getBeanClassLoader()); + + if (argumentValue1 != null && argumentValue1.getValue() instanceof RuntimeBeanReference rbf) { + BeanDefinition implBeanDef = beanFactory.getBeanDefinition(rbf.getBeanName()); + Class implClass = ClassUtils.forName(implBeanDef.getBeanClassName(), beanFactory.getBeanClassLoader()); + fragmentList.add(new RepositoryFragment.ImplementedRepositoryFragment(type, implClass)); + } else { + fragmentList.add(RepositoryFragment.structural(type)); + } + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + return fragmentList; } - private static Supplier metadataSupplier( - RepositoryConfiguration metadata, ConfigurableListableBeanFactory beanFactory) { - return Lazy.of(() -> new DefaultRepositoryMetadata(forName(metadata.getRepositoryInterface(), beanFactory))); - } + static abstract class Dummy implements CrudRepository, PagingAndSortingRepository {} static Class forName(String name, ConfigurableListableBeanFactory beanFactory) { try { diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java b/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java index 92405a0aeb..40b2cc43a7 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java @@ -28,7 +28,6 @@ import java.util.function.Predicate; import org.jspecify.annotations.Nullable; - import org.springframework.aop.SpringProxy; import org.springframework.aop.framework.Advised; import org.springframework.aot.generate.GenerationContext; @@ -49,7 +48,6 @@ import org.springframework.data.repository.Repository; import org.springframework.data.repository.aot.generate.RepositoryContributor; import org.springframework.data.repository.core.RepositoryInformation; -import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.util.Predicates; import org.springframework.data.util.QTypeContributor; @@ -90,8 +88,7 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo * @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}. * @see RepositoryRegistrationAotProcessor */ - protected RepositoryRegistrationAotContribution( - RepositoryRegistrationAotProcessor processor) { + protected RepositoryRegistrationAotContribution(RepositoryRegistrationAotProcessor processor) { Assert.notNull(processor, "RepositoryRegistrationAotProcessor must not be null"); @@ -108,8 +105,7 @@ protected RepositoryRegistrationAotContribution( * @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}. * @see RepositoryRegistrationAotProcessor */ - public static RepositoryRegistrationAotContribution fromProcessor( - RepositoryRegistrationAotProcessor processor) { + public static RepositoryRegistrationAotContribution fromProcessor(RepositoryRegistrationAotProcessor processor) { return new RepositoryRegistrationAotContribution(processor); } @@ -255,7 +251,8 @@ private void contributeRepositoryInfo(AotRepositoryContext repositoryContext, Ge }); implementation.ifPresent(impl -> { - contribution.getRuntimeHints().reflection().registerType(impl.getClass(), hint -> { + Class typeToRegister = impl instanceof Class c ? c : impl.getClass(); + contribution.getRuntimeHints().reflection().registerType(typeToRegister, hint -> { hint.withMembers(MemberCategory.INVOKE_PUBLIC_METHODS); @@ -365,18 +362,16 @@ public Predicate> typeFilter() { // like only document ones. // TODO: A @SuppressWarnings("rawtypes") private DefaultAotRepositoryContext buildAotRepositoryContext(RegisteredBean bean, - RepositoryConfiguration repositoryMetadata) { + RepositoryConfiguration repositoryConfiguration) { DefaultAotRepositoryContext repositoryContext = new DefaultAotRepositoryContext( AotContext.from(getBeanFactory(), getRepositoryRegistrationAotProcessor().getEnvironment())); - RepositoryFactoryBeanSupport rfbs = bean.getBeanFactory().getBean("&" + bean.getBeanName(), - RepositoryFactoryBeanSupport.class); - repositoryContext.setBeanName(bean.getBeanName()); - repositoryContext.setBasePackages(repositoryMetadata.getBasePackages().toSet()); + repositoryContext.setBasePackages(repositoryConfiguration.getBasePackages().toSet()); repositoryContext.setIdentifyingAnnotations(resolveIdentifyingAnnotations()); - repositoryContext.setRepositoryInformation(rfbs.getRepositoryInformation()); + repositoryContext + .setRepositoryInformation(RepositoryBeanDefinitionReader.repositoryInformation(repositoryConfiguration, bean)); return repositoryContext; } diff --git a/src/main/java/org/springframework/data/repository/core/RepositoryInformation.java b/src/main/java/org/springframework/data/repository/core/RepositoryInformation.java index e3f77cc339..3ebee41f24 100644 --- a/src/main/java/org/springframework/data/repository/core/RepositoryInformation.java +++ b/src/main/java/org/springframework/data/repository/core/RepositoryInformation.java @@ -18,6 +18,7 @@ import java.lang.reflect.Method; import java.util.List; +import org.jspecify.annotations.Nullable; import org.springframework.data.repository.core.support.RepositoryComposition; /** @@ -105,4 +106,8 @@ default boolean hasQueryMethods() { */ RepositoryComposition getRepositoryComposition(); + default @Nullable String moduleName() { + return null; + } + } diff --git a/src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java b/src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java index 269563dc1f..94660dec28 100644 --- a/src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java +++ b/src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java @@ -184,7 +184,7 @@ protected boolean isQueryMethodCandidate(Method method) { return true; } - private RepositoryMetadata getMetadata() { + protected RepositoryMetadata getMetadata() { return metadata.get(); } diff --git a/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java b/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java index d2f449c6e5..a0f19c5fc2 100644 --- a/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java +++ b/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java @@ -95,6 +95,10 @@ public abstract class RepositoryFactoryBeanSupport, private @Nullable Lazy repository; private @Nullable RepositoryMetadata repositoryMetadata; + // AOT bean factory hint? + private @Nullable String moduleBaseClass; + private @Nullable String moduleName; + /** * Creates a new {@link RepositoryFactoryBeanSupport} for the given repository interface. * @@ -155,7 +159,7 @@ public void setCustomImplementation(Object customImplementation) { * @param repositoryFragments */ public void setRepositoryFragments(RepositoryFragments repositoryFragments) { - setRepositoryFragments(RepositoryFragmentsFunction.just(repositoryFragments)); + setRepositoryFragmentsFunction(RepositoryFragmentsFunction.just(repositoryFragments)); } /** @@ -165,7 +169,7 @@ public void setRepositoryFragments(RepositoryFragments repositoryFragments) { * @param fragmentsFunction * @since 4.0 */ - public void setRepositoryFragments(RepositoryFragmentsFunction fragmentsFunction) { + public void setRepositoryFragmentsFunction(RepositoryFragmentsFunction fragmentsFunction) { this.fragments.add(fragmentsFunction); } @@ -257,6 +261,14 @@ public void setApplicationEventPublisher(ApplicationEventPublisher publisher) { this.publisher = publisher; } + public void setModuleBaseClass(String moduleBaseClass) { + this.moduleBaseClass = moduleBaseClass; + } + + public void setModuleName(String moduleName) { + this.moduleName = moduleName; + } + @Override @SuppressWarnings("unchecked") public EntityInformation getEntityInformation() { diff --git a/src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java b/src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java index f89b80b847..7b34326a8a 100644 --- a/src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java +++ b/src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java @@ -265,7 +265,7 @@ public ImplementedRepositoryFragment(@Nullable Class interfaceClass, T implem Assert.notNull(implementation, "Implementation object must not be null"); - if (interfaceClass != null) { + if (interfaceClass != null && !(implementation instanceof Class)) { Assert .isTrue(ClassUtils.isAssignableValue(interfaceClass, implementation), diff --git a/src/test/java/example/UserRepository.java b/src/test/java/example/UserRepository.java index d87b9237ad..d9b35863ef 100644 --- a/src/test/java/example/UserRepository.java +++ b/src/test/java/example/UserRepository.java @@ -24,7 +24,7 @@ /** * @author Christoph Strobl */ -public interface UserRepository extends CrudRepository { +public interface UserRepository extends CrudRepository, UserRepositoryExtension { User findByFirstname(String firstname); diff --git a/src/test/java/example/UserRepositoryExtension.java b/src/test/java/example/UserRepositoryExtension.java new file mode 100644 index 0000000000..6123aed839 --- /dev/null +++ b/src/test/java/example/UserRepositoryExtension.java @@ -0,0 +1,25 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package example; + +import example.UserRepository.User; + +/** + * @author Christoph Strobl + */ +public interface UserRepositoryExtension { + User findUserByExtensionMethod(); +} diff --git a/src/test/java/example/UserRepositoryExtensionImpl.java b/src/test/java/example/UserRepositoryExtensionImpl.java new file mode 100644 index 0000000000..8e6ccb2419 --- /dev/null +++ b/src/test/java/example/UserRepositoryExtensionImpl.java @@ -0,0 +1,29 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package example; + +import example.UserRepository.User; + +/** + * @author Christoph Strobl + */ +public class UserRepositoryExtensionImpl implements UserRepositoryExtension { + + @Override + public User findUserByExtensionMethod() { + return null; + } +} diff --git a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java new file mode 100644 index 0000000000..f57dc41c13 --- /dev/null +++ b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java @@ -0,0 +1,157 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.aot.generate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import example.UserRepository; +import example.UserRepository.User; + +import java.util.TimeZone; + +import javax.lang.model.element.Modifier; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.data.geo.Metric; +import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.query.QueryMethod; +import org.springframework.data.util.TypeInformation; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.TypeName; +import org.springframework.stereotype.Repository; + +/** + * @author Christoph Strobl + */ +class AotRepositoryBuilderUnitTests { + + RepositoryInformation repositoryInformation; + + @BeforeEach + void beforeEach() { + + repositoryInformation = mock(RepositoryInformation.class); + doReturn(UserRepository.class).when(repositoryInformation).getRepositoryInterface(); + } + + @Test // GH-3279 + void writesClassSkeleton() { + + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + new SpelAwareProxyProjectionFactory()); + assertThat(repoBuilder.build().javaFile().toString()) + .contains("package %s;".formatted(UserRepository.class.getPackageName())) // same package as source repo + .contains("@Generated") // marked as generated source + .contains("public class %sImpl__Aot".formatted(UserRepository.class.getSimpleName())) // target name + .contains("public UserRepositoryImpl__Aot()"); // default constructor if not arguments to wire + } + + @Test // GH-3279 + void appliesCtorArguments() { + + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + new SpelAwareProxyProjectionFactory()); + repoBuilder.withConstructorCustomizer(ctor -> { + ctor.addParameter("param1", Metric.class); + ctor.addParameter("param2", String.class); + ctor.addParameter("ctorScoped", TypeName.OBJECT, false); + }); + assertThat(repoBuilder.build().javaFile().toString()) // + .contains("private final Metric param1;") // + .contains("private final String param2;") // + .doesNotContain("private final Object ctorScoped;") // + .contains("public UserRepositoryImpl__Aot(Metric param1, String param2, Object ctorScoped)") // + .contains("this.param1 = param1") // + .contains("this.param2 = param2") // + .doesNotContain("this.ctorScoped = ctorScoped"); + } + + @Test // GH-3279 + void appliesCtorCodeBlock() { + + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + new SpelAwareProxyProjectionFactory()); + repoBuilder.withConstructorCustomizer(ctor -> { + ctor.customize((info, code) -> { + code.addStatement("throw new $T($S)", IllegalStateException.class, "initialization error"); + }); + }); + assertThat(repoBuilder.build().javaFile().toString()).containsIgnoringWhitespaces( + "UserRepositoryImpl__Aot() { throw new IllegalStateException(\"initialization error\"); }"); + } + + @Test // GH-3279 + void appliesClassCustomizations() { + + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + new SpelAwareProxyProjectionFactory()); + + repoBuilder.withClassCustomizer((info, metadata, clazz) -> { + + clazz.addField(Float.class, "f", Modifier.PRIVATE, Modifier.STATIC); + clazz.addField(Double.class, "d", Modifier.PUBLIC); + clazz.addField(TimeZone.class, "t", Modifier.FINAL); + + clazz.addAnnotation(Repository.class); + + clazz.addMethod(MethodSpec.methodBuilder("oops").build()); + }); + + assertThat(repoBuilder.build().javaFile().toString()) // + .contains("@Repository") // + .contains("private static Float f;") // + .contains("public Double d;") // + .contains("final TimeZone t;") // + .containsIgnoringWhitespaces("void oops() { }"); + } + + @Test // GH-3279 + void appliesQueryMethodContributor() { + + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + new SpelAwareProxyProjectionFactory()); + + when(repositoryInformation.isQueryMethod(Mockito.argThat(arg -> arg.getName().equals("findByFirstname")))) + .thenReturn(true); + doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); + + repoBuilder.withQueryMethodContributor((method, info) -> { + + return new MethodContributor<>(mock(QueryMethod.class), null) { + + @Override + public MethodSpec contribute(AotQueryMethodGenerationContext context) { + return MethodSpec.methodBuilder("oops").build(); + } + + @Override + public boolean contributesMethodSpec() { + return true; + } + }; + }); + + assertThat(repoBuilder.build().javaFile().toString()) // + .containsIgnoringWhitespaces("void oops() { }"); + } +} diff --git a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java new file mode 100644 index 0000000000..b0f19b807e --- /dev/null +++ b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java @@ -0,0 +1,88 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.aot.generate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.when; + +import example.UserRepository; +import example.UserRepository.User; + +import java.lang.reflect.Method; +import java.util.List; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.core.ResolvableType; +import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.util.TypeInformation; +import org.springframework.javapoet.ParameterSpec; +import org.springframework.javapoet.ParameterizedTypeName; + +/** + * @author Christoph Strobl + */ +class AotRepositoryMethodBuilderUnitTests { + + RepositoryInformation repositoryInformation; + AotQueryMethodGenerationContext methodGenerationContext; + + @BeforeEach + void beforeEach() { + repositoryInformation = Mockito.mock(RepositoryInformation.class); + methodGenerationContext = Mockito.mock(AotQueryMethodGenerationContext.class); + + when(methodGenerationContext.getRepositoryInformation()).thenReturn(repositoryInformation); + } + + @Test // GH-3279 + void generatesMethodSkeletonBasedOnGenerationMetadata() throws NoSuchMethodException { + + Method method = UserRepository.class.getMethod("findByFirstname", String.class); + when(methodGenerationContext.getMethod()).thenReturn(method); + when(methodGenerationContext.getReturnType()).thenReturn(ResolvableType.forClass(User.class)); + doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); + MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method); + methodMetadata.addParameter(ParameterSpec.builder(String.class, "firstname").build()); + when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata); + + AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext); + assertThat(builder.buildMethod().toString()) // + .containsPattern("public .*User findByFirstname\\(.*String firstname\\)"); + } + + @Test // GH-3279 + void generatesMethodWithGenerics() throws NoSuchMethodException { + + Method method = UserRepository.class.getMethod("findByFirstnameIn", List.class); + when(methodGenerationContext.getMethod()).thenReturn(method); + when(methodGenerationContext.getReturnType()) + .thenReturn(ResolvableType.forClassWithGenerics(List.class, User.class)); + doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); + MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method); + methodMetadata + .addParameter(ParameterSpec.builder(ParameterizedTypeName.get(List.class, String.class), "firstnames").build()); + when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata); + + AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext); + assertThat(builder.buildMethod().toString()) // + .containsPattern("public .*List<.*User> findByFirstnameIn\\(") // + .containsPattern(".*List<.*String> firstnames\\)"); + } +} diff --git a/src/test/java/org/springframework/data/repository/aot/generate/MethodCapturingRepositoryContributor.java b/src/test/java/org/springframework/data/repository/aot/generate/MethodCapturingRepositoryContributor.java new file mode 100644 index 0000000000..033c7fbe18 --- /dev/null +++ b/src/test/java/org/springframework/data/repository/aot/generate/MethodCapturingRepositoryContributor.java @@ -0,0 +1,57 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.aot.generate; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.lang.reflect.Method; +import java.util.List; + +import org.assertj.core.api.MapAssert; +import org.jspecify.annotations.Nullable; +import org.springframework.data.repository.config.AotRepositoryContext; +import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.query.QueryMethod; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * @author Christoph Strobl + */ +public class MethodCapturingRepositoryContributor extends RepositoryContributor { + + MultiValueMap capturedInvocations; + + public MethodCapturingRepositoryContributor(AotRepositoryContext repositoryContext) { + super(repositoryContext); + this.capturedInvocations = new LinkedMultiValueMap<>(3); + } + + @Override + protected @Nullable MethodContributor contributeQueryMethod(Method method, + RepositoryInformation repositoryInformation) { + capturedInvocations.add(method.getName(), method); + return null; + } + + void verifyContributionFor(String methodName) { + assertThat(capturedInvocations).containsKey(methodName); + } + + MapAssert> verifyContributedMethods() { + return assertThat(capturedInvocations); + } +} diff --git a/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java index b77ac6346e..133281fe0c 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java @@ -15,31 +15,45 @@ */ package org.springframework.data.repository.aot.generate; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.when; import example.UserRepository; +import example.UserRepositoryExtension; +import example.UserRepositoryExtensionImpl; import java.lang.reflect.Method; import java.util.Map; +import java.util.Optional; +import java.util.Set; import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.Test; - +import org.mockito.Mockito; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.core.test.tools.TestCompiler; import org.springframework.data.aot.CodeContributionAssert; +import org.springframework.data.repository.CrudRepository; +import org.springframework.data.repository.config.AotRepositoryContext; import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.core.support.RepositoryComposition; +import org.springframework.data.repository.core.support.RepositoryComposition.RepositoryFragments; +import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.repository.query.QueryMethod; import org.springframework.javapoet.CodeBlock; import org.springframework.util.ClassUtils; /** + * Unit tests targeting {@link RepositoryContributor}. + * * @author Christoph Strobl */ class RepositoryContributorUnitTests { - @Test - void testCompile() { + @Test // GH-3279 + void createsCompilableClassStub() { DummyModuleAotRepositoryContext aotContext = new DummyModuleAotRepositoryContext(UserRepository.class, null); RepositoryContributor repositoryContributor = new RepositoryContributor(aotContext) { @@ -55,8 +69,7 @@ void testCompile() { public Map serialize() { return Map.of(); } - }) - .contribute(context -> { + }).contribute(context -> { CodeBlock.Builder builder = CodeBlock.builder(); if (!ClassUtils.isVoidType(method.getReturnType())) { @@ -81,4 +94,146 @@ public Map serialize() { new CodeContributionAssert(generationContext).contributesReflectionFor(expectedTypeName); } + @Test // GH-3279 + void callsMethodContributionForQueryMethod() { + + AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class); + RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class); + + when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation); + when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class); + when(repositoryInformation.isQueryMethod(argThat(it -> it.getName().equals("findByFirstname")))).thenReturn(true); + + MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext); + contributor.contribute(new TestGenerationContext(UserRepository.class)); + + contributor.verifyContributionFor("findByFirstname"); + } + + @Test // GH-3279 + void doesNotContributeBaseClassMethods() { + + AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class); + RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class); + + when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation); + when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class); + when(repositoryInformation.getRepositoryComposition()) + .thenReturn(RepositoryComposition.of(RepositoryFragment.structural(RepoBaseClass.class))); + when(repositoryInformation.isBaseClassMethod(argThat(it -> it.getName().equals("findByFirstname")))) + .thenReturn(true); + when(repositoryInformation.isQueryMethod(argThat(it -> !it.getName().equals("findByFirstname")))).thenReturn(true); + + MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext); + contributor.contribute(new TestGenerationContext(UserRepository.class)); + + contributor.verifyContributedMethods().isNotEmpty().doesNotContainKey("findByFirstname"); + } + + @Test // GH-3279 + void doesNotContributeFragmentMethod() { + + AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class); + RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class); + + when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation); + when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class); + when(repositoryInformation.getRepositoryComposition()) + .thenReturn(RepositoryComposition.of(RepositoryFragment.structural(UserRepository.class)) + .append(RepositoryFragments + .from(Set.of(new RepositoryFragment.ImplementedRepositoryFragment(UserRepositoryExtension.class, + UserRepositoryExtensionImpl.class))))); + + when(repositoryInformation.isCustomMethod(argThat(it -> it.getName().equals("findUserByExtensionMethod")))) + .thenReturn(true); + when(repositoryInformation.isQueryMethod(argThat(it -> it.getName().equals("findByFirstname")))).thenReturn(true); + + MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext); + contributor.contribute(new TestGenerationContext(UserRepository.class)); + + contributor.verifyContributedMethods().isNotEmpty().doesNotContainKey("findUserByExtensionMethod"); + } + + @Test // GH-3279 + void contributesBaseClassMethodIfQueryMethod() { + + AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class); + RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class); + + when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation); + when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class); + when(repositoryInformation.getRepositoryComposition()) + .thenReturn(RepositoryComposition.of(RepositoryFragment.structural(RepoBaseClass.class))); + when(repositoryInformation.isBaseClassMethod(argThat(it -> it.getName().equals("findByFirstname")))) + .thenReturn(true); + when(repositoryInformation.isQueryMethod(any())).thenReturn(true); + + MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext); + contributor.contribute(new TestGenerationContext(UserRepository.class)); + + contributor.verifyContributedMethods().containsKey("findByFirstname").hasSizeGreaterThan(1); + } + + static class RepoBaseClass implements CrudRepository { + + private CrudRepository delegate; + + public S save(S entity) { + return this.delegate.save(entity); + } + + @Override + public Iterable saveAll(Iterable entities) { + return this.delegate.saveAll(entities); + } + + public Optional findById(ID id) { + return this.delegate.findById(id); + } + + @Override + public boolean existsById(ID id) { + return this.delegate.existsById(id); + } + + @Override + public Iterable findAll() { + return this.delegate.findAll(); + } + + @Override + public Iterable findAllById(Iterable ids) { + return this.delegate.findAllById(ids); + } + + @Override + public long count() { + return this.delegate.count(); + } + + @Override + public void deleteById(ID id) { + this.delegate.deleteById(id); + } + + @Override + public void delete(T entity) { + this.delegate.delete(entity); + } + + @Override + public void deleteAllById(Iterable ids) { + this.delegate.deleteAllById(ids); + } + + @Override + public void deleteAll(Iterable entities) { + this.delegate.deleteAll(entities); + } + + @Override + public void deleteAll() { + this.delegate.deleteAll(); + } + } } diff --git a/src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java b/src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java new file mode 100644 index 0000000000..54379865c1 --- /dev/null +++ b/src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java @@ -0,0 +1,122 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.config; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.data.aot.sample.ConfigWithCustomImplementation; +import org.springframework.data.aot.sample.ConfigWithCustomRepositoryBaseClass; +import org.springframework.data.aot.sample.ConfigWithCustomRepositoryBaseClass.CustomerRepositoryWithCustomBaseRepo; +import org.springframework.data.aot.sample.ConfigWithSimpleCrudRepository; +import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; + +/** + * @author Christoph Strobl + */ +class RepositoryBeanDefinitionReaderTests { + + @Test // GH-3279 + void readsSimpleConfigFromBeanFactory() { + + RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithSimpleCrudRepository.class); + + RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); + Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(ConfigWithSimpleCrudRepository.MyRepo.class.getName()); + + RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig, + repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory()); + + assertThat(repositoryInformation.getRepositoryInterface()).isEqualTo(ConfigWithSimpleCrudRepository.MyRepo.class); + assertThat(repositoryInformation.getDomainType()).isEqualTo(ConfigWithSimpleCrudRepository.Person.class); + assertThat(repositoryInformation.getFragments()).isEmpty(); + } + + @Test // GH-3279 + void readsCustomRepoBaseClassFromBeanFactory() { + + RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithCustomRepositoryBaseClass.class); + + RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); + Class repositoryInterfaceType = CustomerRepositoryWithCustomBaseRepo.class; + Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName()); + + RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig, + repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory()); + + assertThat(repositoryInformation.getRepositoryBaseClass()) + .isEqualTo(ConfigWithCustomRepositoryBaseClass.RepoBaseClass.class); + } + + @Test // GH-3279 + void readsFragmentsFromBeanFactory() { + + RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithCustomImplementation.class); + + RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); + Class repositoryInterfaceType = ConfigWithCustomImplementation.RepositoryWithCustomImplementation.class; + Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName()); + + RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig, + repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory()); + + assertThat(repositoryInformation.getFragments()).satisfiesExactly(fragment -> { + assertThat(fragment.getSignatureContributor()) + .isEqualTo(ConfigWithCustomImplementation.CustomImplInterface.class); + }); + } + + @Test // GH-3279 + void fallsBackToModuleBaseClassIfSetAndNoRepoBaseDefined() { + + RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithSimpleCrudRepository.class); + RootBeanDefinition rootBeanDefinition = repoFactoryBean.getMergedBeanDefinition().cloneBeanDefinition(); + // need to unset because its defined as non default + rootBeanDefinition.getPropertyValues().removePropertyValue("repositoryBaseClass"); + rootBeanDefinition.getPropertyValues().add("moduleBaseClass", ModuleBase.class.getName()); + + RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); + Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(ConfigWithSimpleCrudRepository.MyRepo.class.getName()); + + RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig, + rootBeanDefinition, repoFactoryBean.getBeanFactory()); + + assertThat(repositoryInformation.getRepositoryBaseClass()).isEqualTo(ModuleBase.class); + } + + static RegisteredBean repositoryFactory(Class configClass) { + + AnnotationConfigApplicationContext applicationContext = new AnnotationConfigApplicationContext(); + applicationContext.register(configClass); + applicationContext.refreshForAotProcessing(new RuntimeHints()); + + String[] beanNamesForType = applicationContext.getBeanNamesForType(RepositoryFactoryBeanSupport.class); + if (beanNamesForType.length != 1) { + throw new IllegalStateException("Unable to find repository FactoryBean"); + } + + return RegisteredBean.of(applicationContext.getBeanFactory(), beanNamesForType[0]); + } + + static class ModuleBase {} +} From 9f76af389a187849ca0cad84fc2512814e4fcbbc Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Tue, 29 Apr 2025 10:30:23 +0200 Subject: [PATCH 3/6] Make NullAway go away. Ignore warnings for already checked constructs null away does not understand. --- src/main/java/org/springframework/data/mapping/Parameter.java | 1 + .../org/springframework/data/repository/query/Parameter.java | 1 + 2 files changed, 2 insertions(+) diff --git a/src/main/java/org/springframework/data/mapping/Parameter.java b/src/main/java/org/springframework/data/mapping/Parameter.java index bf6221faad..ca4afea4e6 100644 --- a/src/main/java/org/springframework/data/mapping/Parameter.java +++ b/src/main/java/org/springframework/data/mapping/Parameter.java @@ -118,6 +118,7 @@ public boolean hasName() { * @since 3.5 * @see org.springframework.core.ParameterNameDiscoverer */ + @SuppressWarnings("NullAway") public String getRequiredName() { if (!hasName()) { diff --git a/src/main/java/org/springframework/data/repository/query/Parameter.java b/src/main/java/org/springframework/data/repository/query/Parameter.java index 0907d0f035..b52cbb3df1 100644 --- a/src/main/java/org/springframework/data/repository/query/Parameter.java +++ b/src/main/java/org/springframework/data/repository/query/Parameter.java @@ -125,6 +125,7 @@ public boolean isDynamicProjectionParameter() { * * @return */ + @SuppressWarnings("NullAway") public String getPlaceholder() { if (isNamedParameter()) { From 22af17f432bd787e266c8971b524ca44ab3decbf Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 7 May 2025 12:21:15 +0200 Subject: [PATCH 4/6] Refine AOT composition detection. Associate Repository Bean Definition with RepositoryConfiguration and RepositoryConfigurationExtension attributes to capture configuration details such as the module name or the configuration source. Introduce RepositoryFragmentsContributor to provide an abstraction for structural fragment implementation allowing to describe the implementation type instead of requiring the implementation object. Obtain repository fragments from a RepositoryFragmentsContributor (either the configured one or one from a RepositoryFactoryBean). --- .../aot/generate/AotRepositoryBuilder.java | 103 ++++++--- .../aot/generate/RepositoryContributor.java | 31 ++- ...notationRepositoryConfigurationSource.java | 11 + .../config/AotRepositoryContext.java | 14 +- .../config/AotRepositoryInformation.java | 54 ++--- .../config/DefaultAotRepositoryContext.java | 58 ++--- .../DefaultRepositoryConfiguration.java | 9 +- .../RepositoryBeanDefinitionBuilder.java | 9 + .../RepositoryBeanDefinitionReader.java | 200 +++++++++++------- .../config/RepositoryConfiguration.java | 12 +- .../RepositoryConfigurationAdapter.java | 5 + .../RepositoryConfigurationExtension.java | 13 +- .../config/RepositoryConfigurationSource.java | 9 + ...RepositoryRegistrationAotContribution.java | 161 +++++++------- .../RepositoryRegistrationAotProcessor.java | 12 +- .../XmlRepositoryConfigurationSource.java | 5 + .../core/RepositoryInformation.java | 5 - .../support/DefaultRepositoryInformation.java | 7 +- .../support/RepositoryFactoryBeanSupport.java | 17 +- .../support/RepositoryFactoryInformation.java | 9 + .../core/support/RepositoryFragment.java | 57 +++-- .../RepositoryFragmentsContributor.java | 56 +++++ .../data/repository/support/Repositories.java | 6 + .../ConfigWithCustomRepositoryBaseClass.java | 4 +- .../ConfigWithSimpleCrudRepository.java | 4 +- ...istrationAotProcessorIntegrationTests.java | 4 +- .../AotRepositoryBuilderUnitTests.java | 68 ++++-- .../DummyModuleAotRepositoryContext.java | 5 + .../RepositoryContributorUnitTests.java | 28 +-- ...epositoryConfigurationSourceUnitTests.java | 38 +++- .../config/DummyRegistrarWithContributor.java | 40 ++++ .../EnableRepositoriesWithContributor.java | 61 ++++++ .../RepositoryBeanDefinitionReaderTests.java | 101 ++++++--- .../SampleRepositoryFragmentsContributor.java | 33 +++ .../support/DummyRepositoryFactoryBean.java | 14 ++ .../support/RepositoriesUnitTests.java | 6 + 36 files changed, 913 insertions(+), 356 deletions(-) create mode 100644 src/main/java/org/springframework/data/repository/core/support/RepositoryFragmentsContributor.java create mode 100644 src/test/java/org/springframework/data/repository/config/DummyRegistrarWithContributor.java create mode 100644 src/test/java/org/springframework/data/repository/config/EnableRepositoriesWithContributor.java create mode 100644 src/test/java/org/springframework/data/repository/config/SampleRepositoryFragmentsContributor.java diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java index 199ca89f63..7ca6b536e4 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java @@ -41,6 +41,7 @@ import org.springframework.javapoet.ClassName; import org.springframework.javapoet.FieldSpec; import org.springframework.javapoet.JavaFile; +import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.TypeName; import org.springframework.javapoet.TypeSpec; @@ -53,6 +54,7 @@ class AotRepositoryBuilder { private final RepositoryInformation repositoryInformation; + private final String moduleName; private final ProjectionFactory projectionFactory; private final AotRepositoryFragmentMetadata generationMetadata; @@ -60,9 +62,11 @@ class AotRepositoryBuilder { private @Nullable BiFunction> methodContributorFunction; private ClassCustomizer customizer; - private AotRepositoryBuilder(RepositoryInformation repositoryInformation, ProjectionFactory projectionFactory) { + private AotRepositoryBuilder(RepositoryInformation repositoryInformation, String moduleName, + ProjectionFactory projectionFactory) { this.repositoryInformation = repositoryInformation; + this.moduleName = moduleName; this.projectionFactory = projectionFactory; this.generationMetadata = new AotRepositoryFragmentMetadata(className()); @@ -74,11 +78,37 @@ private AotRepositoryBuilder(RepositoryInformation repositoryInformation, Projec this.customizer = (info, metadata, builder) -> {}; } - public static AotRepositoryBuilder forRepository(RepositoryInformation repositoryInformation, + /** + * Create a new {@code AotRepositoryBuilder} for the given {@link RepositoryInformation}. + * + * @param information must not be {@literal null}. + * @param moduleName must not be {@literal null}. + * @param projectionFactory must not be {@literal null}. + * @return + */ + public static AotRepositoryBuilder forRepository(RepositoryInformation information, String moduleName, ProjectionFactory projectionFactory) { - return new AotRepositoryBuilder(repositoryInformation, projectionFactory); + return new AotRepositoryBuilder(information, moduleName, projectionFactory); } + /** + * Configure a {@link ClassCustomizer} customizer. + * + * @param classCustomizer must not be {@literal null}. + * @return {@code this}. + */ + public AotRepositoryBuilder withClassCustomizer(ClassCustomizer classCustomizer) { + + this.customizer = classCustomizer; + return this; + } + + /** + * Configure a {@link AotRepositoryConstructorBuilder} customizer. + * + * @param constructorCustomizer must not be {@literal null}. + * @return {@code this}. + */ public AotRepositoryBuilder withConstructorCustomizer( Consumer constructorCustomizer) { @@ -86,42 +116,33 @@ public AotRepositoryBuilder withConstructorCustomizer( return this; } + /** + * Configure a {@link MethodContributor}. + * + * @param methodContributorFunction must not be {@literal null}. + * @return {@code this}. + */ public AotRepositoryBuilder withQueryMethodContributor( BiFunction> methodContributorFunction) { - this.methodContributorFunction = methodContributorFunction; - return this; - } - public AotRepositoryBuilder withClassCustomizer(ClassCustomizer classCustomizer) { - - this.customizer = classCustomizer; + this.methodContributorFunction = methodContributorFunction; return this; } public AotBundle build() { + List methodMetadata = new ArrayList<>(); + RepositoryComposition repositoryComposition = repositoryInformation.getRepositoryComposition(); + // start creating the type TypeSpec.Builder builder = TypeSpec.classBuilder(this.generationMetadata.getTargetTypeName()) // .addModifiers(Modifier.PUBLIC) // .addAnnotation(Generated.class) // - .addJavadoc("AOT generated repository implementation for {@link $T}.\n", + .addJavadoc("AOT generated $L repository implementation for {@link $T}.\n", moduleName, repositoryInformation.getRepositoryInterface()); // create the constructor - AotRepositoryConstructorBuilder constructorBuilder = new AotRepositoryConstructorBuilder(repositoryInformation, - generationMetadata); - if (constructorCustomizer != null) { - constructorCustomizer.accept(constructorBuilder); - } - - builder.addMethod(constructorBuilder.buildConstructor()); - - List methodMetadata = new ArrayList<>(); - AotRepositoryMetadata.RepositoryType repositoryType = repositoryInformation.isReactiveRepository() - ? AotRepositoryMetadata.RepositoryType.REACTIVE - : AotRepositoryMetadata.RepositoryType.IMPERATIVE; - - RepositoryComposition repositoryComposition = repositoryInformation.getRepositoryComposition(); + builder.addMethod(buildConstructor()); Arrays.stream(repositoryInformation.getRepositoryInterface().getMethods()) .sorted(Comparator. comparing(it -> { @@ -136,12 +157,35 @@ public AotBundle build() { // finally customize the file itself this.customizer.customize(repositoryInformation, generationMetadata, builder); + JavaFile javaFile = JavaFile.builder(packageName(), builder.build()).build(); + AotRepositoryMetadata metadata = getAotRepositoryMetadata(methodMetadata); + + return new AotBundle(javaFile, metadata); + } + + private MethodSpec buildConstructor() { + + AotRepositoryConstructorBuilder constructorBuilder = new AotRepositoryConstructorBuilder(repositoryInformation, + generationMetadata); + + if (constructorCustomizer != null) { + constructorCustomizer.accept(constructorBuilder); + } + + return constructorBuilder.buildConstructor(); + } - AotRepositoryMetadata metadata = new AotRepositoryMetadata(repositoryInformation.getRepositoryInterface().getName(), - repositoryInformation.moduleName() != null ? repositoryInformation.moduleName() : "", repositoryType, methodMetadata); + private AotRepositoryMetadata getAotRepositoryMetadata(List methodMetadata) { - return new AotBundle(javaFile, metadata.toJson()); + AotRepositoryMetadata.RepositoryType repositoryType = repositoryInformation.isReactiveRepository() + ? AotRepositoryMetadata.RepositoryType.REACTIVE + : AotRepositoryMetadata.RepositoryType.IMPERATIVE; + + String jsonModuleName = moduleName.replaceAll("Reactive", "").trim(); + + return new AotRepositoryMetadata(repositoryInformation.getRepositoryInterface().getName(), jsonModuleName, + repositoryType, methodMetadata); } private void contributeMethod(Method method, RepositoryComposition repositoryComposition, @@ -185,8 +229,7 @@ private void contributeMethod(Method method, RepositoryComposition repositoryCom private AotRepositoryMethod getFragmentMetadata(Method method, RepositoryFragment fragment) { String signature = fragment.getSignatureContributor().getName(); - String implementation = fragment.getImplementation().map(it -> it.getClass().getName()).orElse(null); - + String implementation = fragment.getImplementationClass().map(Class::getName).orElse(null); AotFragmentTarget fragmentTarget = new AotFragmentTarget(signature, implementation); return new AotRepositoryMethod(method.getName(), method.toGenericString(), null, fragmentTarget); @@ -240,7 +283,7 @@ public interface ClassCustomizer { } - record AotBundle(JavaFile javaFile, JSONObject metadata) { + record AotBundle(JavaFile javaFile, AotRepositoryMetadata metadata) { } } diff --git a/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java b/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java index 9fffacb9c8..bcfc9d7a16 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java @@ -39,6 +39,7 @@ * * @author Christoph Strobl * @author Mark Paluch + * @since 4.0 */ public class RepositoryContributor { @@ -46,19 +47,34 @@ public class RepositoryContributor { private final AotRepositoryBuilder builder; + /** + * Create a new {@code RepositoryContributor} for the given {@link AotRepositoryContext}. + * + * @param repositoryContext + */ public RepositoryContributor(AotRepositoryContext repositoryContext) { this.builder = AotRepositoryBuilder.forRepository(repositoryContext.getRepositoryInformation(), - createProjectionFactory()); + repositoryContext.getModuleName(), createProjectionFactory()); } + /** + * @return a new {@link ProjectionFactory} to be used with the AOT repository builder. The actual instance should be + * accessed through {@link #getProjectionFactory()}. + */ protected ProjectionFactory createProjectionFactory() { return new SpelAwareProxyProjectionFactory(); } + /** + * @return the used {@link ProjectionFactory}. + */ protected ProjectionFactory getProjectionFactory() { return builder.getProjectionFactory(); } + /** + * @return the used {@link RepositoryInformation}. + */ protected RepositoryInformation getRepositoryInformation() { return builder.getRepositoryInformation(); } @@ -73,13 +89,10 @@ public java.util.Map requiredArgs() { public void contribute(GenerationContext generationContext) { - // TODO: do we need - generationContext.withName("spring-data"); - - builder.withClassCustomizer(this::customizeClass); - builder.withConstructorCustomizer(this::customizeConstructor); - builder.withQueryMethodContributor(this::contributeQueryMethod); - - AotRepositoryBuilder.AotBundle aotBundle = builder.build(); + AotRepositoryBuilder.AotBundle aotBundle = builder.withClassCustomizer(this::customizeClass) // + .withConstructorCustomizer(this::customizeConstructor) // + .withQueryMethodContributor(this::contributeQueryMethod) // + .build(); Class repositoryInterface = getRepositoryInformation().getRepositoryInterface(); String repositoryJsonFileName = getRepositoryJsonFileName(repositoryInterface); @@ -89,7 +102,7 @@ public void contribute(GenerationContext generationContext) { String repositoryJson; try { - repositoryJson = aotBundle.metadata().toString(2); + repositoryJson = aotBundle.metadata().toJson().toString(2); } catch (JSONException e) { throw new RuntimeException(e); } diff --git a/src/main/java/org/springframework/data/repository/config/AnnotationRepositoryConfigurationSource.java b/src/main/java/org/springframework/data/repository/config/AnnotationRepositoryConfigurationSource.java index f143cbb2a1..48f2d42f9d 100644 --- a/src/main/java/org/springframework/data/repository/config/AnnotationRepositoryConfigurationSource.java +++ b/src/main/java/org/springframework/data/repository/config/AnnotationRepositoryConfigurationSource.java @@ -65,6 +65,7 @@ public class AnnotationRepositoryConfigurationSource extends RepositoryConfigura private static final String QUERY_LOOKUP_STRATEGY = "queryLookupStrategy"; private static final String REPOSITORY_FACTORY_BEAN_CLASS = "repositoryFactoryBeanClass"; private static final String REPOSITORY_BASE_CLASS = "repositoryBaseClass"; + private static final String REPOSITORY_FRAGMENTS_CONTRIBUTOR_CLASS = "fragmentsContributor"; private static final String CONSIDER_NESTED_REPOSITORIES = "considerNestedRepositories"; private static final String BOOTSTRAP_MODE = "bootstrapMode"; private static final String BEAN_NAME_GENERATOR = "nameGenerator"; @@ -187,6 +188,16 @@ public Optional getRepositoryBaseClassName() { : Optional.of(repositoryBaseClass.getName()); } + @Override + public Optional getRepositoryFragmentsContributorClassName() { + + if (!attributes.containsKey(REPOSITORY_FRAGMENTS_CONTRIBUTOR_CLASS)) { + return Optional.empty(); + } + + return Optional.of(attributes.getClass(REPOSITORY_FRAGMENTS_CONTRIBUTOR_CLASS).getName()); + } + /** * Returns the {@link AnnotationAttributes} of the annotation configured. * diff --git a/src/main/java/org/springframework/data/repository/config/AotRepositoryContext.java b/src/main/java/org/springframework/data/repository/config/AotRepositoryContext.java index 995aa04084..231e7bba18 100644 --- a/src/main/java/org/springframework/data/repository/config/AotRepositoryContext.java +++ b/src/main/java/org/springframework/data/repository/config/AotRepositoryContext.java @@ -16,9 +16,9 @@ package org.springframework.data.repository.config; import java.lang.annotation.Annotation; +import java.util.Collection; import java.util.Set; -import org.springframework.core.SpringProperties; import org.springframework.core.annotation.MergedAnnotation; import org.springframework.data.aot.AotContext; import org.springframework.data.repository.core.RepositoryInformation; @@ -28,8 +28,9 @@ * * @author Christoph Strobl * @author John Blum - * @see AotContext + * @author Mark Paluch * @since 3.0 + * @see AotContext */ public interface AotRepositoryContext extends AotContext { @@ -38,6 +39,12 @@ public interface AotRepositoryContext extends AotContext { */ String getBeanName(); + /** + * @return the Spring Data module name, see {@link RepositoryConfigurationExtension#getModuleName()}. + * @since 4.0 + */ + String getModuleName(); + /** * @return a {@link Set} of {@link String base packages} to search for repositories. */ @@ -46,7 +53,7 @@ public interface AotRepositoryContext extends AotContext { /** * @return the {@link Annotation} types used to identify domain types. */ - Set> getIdentifyingAnnotations(); + Collection> getIdentifyingAnnotations(); /** * @return {@link RepositoryInformation metadata} about the repository itself. @@ -64,4 +71,5 @@ public interface AotRepositoryContext extends AotContext { * @return all {@link Class types} reachable from the repository. */ Set> getResolvedTypes(); + } diff --git a/src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java b/src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java index 1ddcbde9a4..0237d51361 100644 --- a/src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java +++ b/src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java @@ -17,52 +17,40 @@ import java.lang.reflect.Method; import java.util.Collection; -import java.util.LinkedHashSet; import java.util.Set; -import java.util.function.Supplier; -import org.jspecify.annotations.Nullable; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.RepositoryInformationSupport; import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.core.support.RepositoryComposition; import org.springframework.data.repository.core.support.RepositoryComposition.RepositoryFragments; import org.springframework.data.repository.core.support.RepositoryFragment; -import org.springframework.data.util.Lazy; /** * {@link RepositoryInformation} based on {@link RepositoryMetadata} collected at build time. * * @author Christoph Strobl + * @author Mark Paluch * @since 3.0 */ -class AotRepositoryInformation extends RepositoryInformationSupport implements RepositoryInformation { +public class AotRepositoryInformation extends RepositoryInformationSupport implements RepositoryInformation { - private final @Nullable String moduleName; - private final Supplier>> fragments; + private final RepositoryComposition fragmentsComposition; + private final RepositoryComposition baseComposition; + private final RepositoryComposition composition; - private final Lazy repositoryComposition; - private final Lazy baseComposition; + public AotRepositoryInformation(RepositoryMetadata repositoryMetadata, Class repositoryBaseClass, + Collection> fragments) { - AotRepositoryInformation(@Nullable String moduleName, Supplier repositoryMetadata, - Supplier> repositoryBaseClass, Supplier>> fragments) { + super(() -> repositoryMetadata, () -> repositoryBaseClass); - super(repositoryMetadata, repositoryBaseClass); + this.fragmentsComposition = RepositoryComposition.fromMetadata(getMetadata()) + .append(RepositoryFragments.from(fragments)); + this.baseComposition = RepositoryComposition.of(RepositoryFragment.structural(getRepositoryBaseClass())) // + .withArgumentConverter(this.fragmentsComposition.getArgumentConverter()) // + .withMethodLookup(this.fragmentsComposition.getMethodLookup()); - this.moduleName = moduleName; - this.fragments = fragments; - - this.repositoryComposition = Lazy - .of(() -> RepositoryComposition.fromMetadata(getMetadata()).append(RepositoryFragments.from(getFragments()))); - - this.baseComposition = Lazy.of(() -> { - - RepositoryComposition targetRepoComposition = repositoryComposition.get(); - - return RepositoryComposition.of(RepositoryFragment.structural(getRepositoryBaseClass())) // - .withArgumentConverter(targetRepoComposition.getArgumentConverter()) // - .withMethodLookup(targetRepoComposition.getMethodLookup()); - }); + this.composition = this.fragmentsComposition.append(this.baseComposition.getFragments()); } /** @@ -71,31 +59,27 @@ class AotRepositoryInformation extends RepositoryInformationSupport implements R */ @Override public Set> getFragments() { - return new LinkedHashSet<>(fragments.get()); + return fragmentsComposition.getFragments().toSet(); } @Override public boolean isCustomMethod(Method method) { - return repositoryComposition.get().findMethod(method).isPresent(); + return fragmentsComposition.findMethod(method).isPresent(); } @Override public boolean isBaseClassMethod(Method method) { - return baseComposition.get().findMethod(method).isPresent(); + return baseComposition.findMethod(method).isPresent(); } @Override public Method getTargetClassMethod(Method method) { - return baseComposition.get().findMethod(method).orElse(method); + return baseComposition.findMethod(method).orElse(method); } @Override public RepositoryComposition getRepositoryComposition() { - return repositoryComposition.get(); + return composition; } - @Override - public @Nullable String moduleName() { - return moduleName; - } } diff --git a/src/main/java/org/springframework/data/repository/config/DefaultAotRepositoryContext.java b/src/main/java/org/springframework/data/repository/config/DefaultAotRepositoryContext.java index f40985b272..5f695f6276 100644 --- a/src/main/java/org/springframework/data/repository/config/DefaultAotRepositoryContext.java +++ b/src/main/java/org/springframework/data/repository/config/DefaultAotRepositoryContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2022. the original author or authors. + * Copyright 2022-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,14 +16,14 @@ package org.springframework.data.repository.config; import java.lang.annotation.Annotation; +import java.util.Collection; import java.util.Collections; import java.util.LinkedHashSet; import java.util.Set; import java.util.stream.Collectors; -import org.jspecify.annotations.Nullable; - import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.core.annotation.MergedAnnotation; import org.springframework.core.env.Environment; import org.springframework.data.aot.AotContext; @@ -37,29 +37,42 @@ * * @author Christoph Strobl * @author John Blum + * @author Mark Paluch * @see AotRepositoryContext * @since 3.0 */ @SuppressWarnings("NullAway") // TODO class DefaultAotRepositoryContext implements AotRepositoryContext { + private final RegisteredBean bean; + private final String moduleName; private final AotContext aotContext; + private final RepositoryInformation repositoryInformation; private final Lazy>> resolvedAnnotations = Lazy.of(this::discoverAnnotations); private final Lazy>> managedTypes = Lazy.of(this::discoverTypes); - private @Nullable RepositoryInformation repositoryInformation; - private @Nullable Set basePackages; - private @Nullable Set> identifyingAnnotations; - private @Nullable String beanName; + private Set basePackages = Collections.emptySet(); + private Collection> identifyingAnnotations = Collections.emptySet(); + private String beanName; - public DefaultAotRepositoryContext(AotContext aotContext) { + public DefaultAotRepositoryContext(RegisteredBean bean, RepositoryInformation repositoryInformation, + String moduleName, AotContext aotContext) { + this.bean = bean; + this.repositoryInformation = repositoryInformation; + this.moduleName = moduleName; this.aotContext = aotContext; + this.beanName = bean.getBeanName(); } public AotContext getAotContext() { return aotContext; } + @Override + public String getModuleName() { + return moduleName; + } + @Override public ConfigurableListableBeanFactory getBeanFactory() { return getAotContext().getBeanFactory(); @@ -72,7 +85,7 @@ public Environment getEnvironment() { @Override public Set getBasePackages() { - return basePackages == null ? Collections.emptySet() : basePackages; + return basePackages; } public void setBasePackages(Set basePackages) { @@ -89,11 +102,11 @@ public void setBeanName(String beanName) { } @Override - public Set> getIdentifyingAnnotations() { - return identifyingAnnotations == null ? Collections.emptySet() : identifyingAnnotations; + public Collection> getIdentifyingAnnotations() { + return identifyingAnnotations; } - public void setIdentifyingAnnotations(Set> identifyingAnnotations) { + public void setIdentifyingAnnotations(Collection> identifyingAnnotations) { this.identifyingAnnotations = identifyingAnnotations; } @@ -102,10 +115,6 @@ public RepositoryInformation getRepositoryInformation() { return repositoryInformation; } - public void setRepositoryInformation(RepositoryInformation repositoryInformation) { - this.repositoryInformation = repositoryInformation; - } - @Override public Set> getResolvedAnnotations() { return resolvedAnnotations.get(); @@ -132,24 +141,18 @@ protected Set> discoverAnnotations() { .flatMap(type -> TypeUtils.resolveUsedAnnotations(type).stream()) .collect(Collectors.toCollection(LinkedHashSet::new)); - if (repositoryInformation != null) { - annotations.addAll(TypeUtils.resolveUsedAnnotations(repositoryInformation.getRepositoryInterface())); - } + annotations.addAll(TypeUtils.resolveUsedAnnotations(repositoryInformation.getRepositoryInterface())); return annotations; } protected Set> discoverTypes() { - Set> types = new LinkedHashSet<>(); + Set> types = new LinkedHashSet<>(TypeCollector.inspect(repositoryInformation.getDomainType()).list()); - if (repositoryInformation != null) { - types.addAll(TypeCollector.inspect(repositoryInformation.getDomainType()).list()); - - repositoryInformation.getQueryMethods().stream() - .flatMap(it -> TypeUtils.resolveTypesInSignature(repositoryInformation.getRepositoryInterface(), it).stream()) - .flatMap(it -> TypeCollector.inspect(it).list().stream()).forEach(types::add); - } + repositoryInformation.getQueryMethods().stream() + .flatMap(it -> TypeUtils.resolveTypesInSignature(repositoryInformation.getRepositoryInterface(), it).stream()) + .flatMap(it -> TypeCollector.inspect(it).list().stream()).forEach(types::add); if (!getIdentifyingAnnotations().isEmpty()) { @@ -160,4 +163,5 @@ protected Set> discoverTypes() { return types; } + } diff --git a/src/main/java/org/springframework/data/repository/config/DefaultRepositoryConfiguration.java b/src/main/java/org/springframework/data/repository/config/DefaultRepositoryConfiguration.java index 1ebdb1c907..eefcdf4043 100644 --- a/src/main/java/org/springframework/data/repository/config/DefaultRepositoryConfiguration.java +++ b/src/main/java/org/springframework/data/repository/config/DefaultRepositoryConfiguration.java @@ -114,12 +114,17 @@ public T getConfigurationSource() { @Override public Optional getRepositoryBaseClassName() { - return configurationSource.getRepositoryBaseClassName(); + return configurationSource.getRepositoryBaseClassName() + .or(() -> Optional.ofNullable(extension.getRepositoryBaseClassName())); } @Override - public String getRepositoryFactoryBeanClassName() { + public Optional getRepositoryFragmentsContributorClassName() { + return configurationSource.getRepositoryFragmentsContributorClassName(); + } + @Override + public String getRepositoryFactoryBeanClassName() { return configurationSource.getRepositoryFactoryBeanClassName() .orElseGet(extension::getRepositoryFactoryBeanClassName); } diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionBuilder.java b/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionBuilder.java index ada478eb9b..994ae1e0c5 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionBuilder.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionBuilder.java @@ -116,6 +116,11 @@ public BeanDefinitionBuilder build(RepositoryConfiguration configuration) { .rootBeanDefinition(configuration.getRepositoryFactoryBeanClassName()); builder.getRawBeanDefinition().setSource(configuration.getSource()); + + // AOT Repository hints + builder.getRawBeanDefinition().setAttribute(RepositoryConfiguration.class.getName(), configuration); + builder.getRawBeanDefinition().setAttribute(RepositoryConfigurationExtension.class.getName(), extension); + builder.addConstructorArgValue(configuration.getRepositoryInterface()); builder.addPropertyValue("queryLookupStrategyKey", configuration.getQueryLookupStrategyKey()); builder.addPropertyValue("lazyInit", configuration.isLazyInit()); @@ -125,6 +130,10 @@ public BeanDefinitionBuilder build(RepositoryConfiguration configuration) { configuration.getRepositoryBaseClassName()// .ifPresent(it -> builder.addPropertyValue("repositoryBaseClass", it)); + configuration.getRepositoryFragmentsContributorClassName()// + .ifPresent(it -> builder.addPropertyValue("repositoryFragmentsContributor", + BeanDefinitionBuilder.genericBeanDefinition(it).getRawBeanDefinition())); + NamedQueriesBeanDefinitionBuilder definitionBuilder = new NamedQueriesBeanDefinitionBuilder( extension.getDefaultNamedQueryLocation()); configuration.getNamedQueriesLocation().ifPresent(definitionBuilder::setLocations); diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java b/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java index 1209903d3b..6d064a294e 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java @@ -15,139 +15,197 @@ */ package org.springframework.data.repository.config; +import java.lang.reflect.Constructor; import java.util.ArrayList; import java.util.List; +import org.jspecify.annotations.Nullable; + +import org.springframework.beans.BeanUtils; +import org.springframework.beans.PropertyValue; +import org.springframework.beans.PropertyValues; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.config.ConstructorArgumentValues; import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.RegisteredBean; -import org.springframework.core.ResolvableType; +import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.data.repository.CrudRepository; import org.springframework.data.repository.PagingAndSortingRepository; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.core.support.AbstractRepositoryMetadata; +import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; import org.springframework.data.repository.core.support.RepositoryFragment; -import org.springframework.data.repository.core.support.RepositoryFragment.ImplementedRepositoryFragment; +import org.springframework.data.repository.core.support.RepositoryFragmentsContributor; import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; /** * Reader used to extract {@link RepositoryInformation} from {@link RepositoryConfiguration}. * * @author Christoph Strobl * @author John Blum + * @author Mark Paluch * @since 3.0 */ class RepositoryBeanDefinitionReader { - /** - * @return - */ - static RepositoryInformation repositoryInformation(RepositoryConfiguration repoConfig, RegisteredBean repoBean) { - return repositoryInformation(repoConfig, repoBean.getMergedBeanDefinition(), repoBean.getBeanFactory()); + private final RootBeanDefinition beanDefinition; + private final ConfigurableListableBeanFactory beanFactory; + private final ClassLoader beanClassLoader; + private final @Nullable RepositoryConfiguration configuration; + private final @Nullable RepositoryConfigurationExtensionSupport extension; + + public RepositoryBeanDefinitionReader(RegisteredBean bean) { + + this.beanDefinition = bean.getMergedBeanDefinition(); + this.beanFactory = bean.getBeanFactory(); + this.beanClassLoader = bean.getBeanClass().getClassLoader(); + this.configuration = (RepositoryConfiguration) beanDefinition + .getAttribute(RepositoryConfiguration.class.getName()); + this.extension = (RepositoryConfigurationExtensionSupport) beanDefinition + .getAttribute(RepositoryConfigurationExtension.class.getName()); + } + + public @Nullable RepositoryConfiguration getConfiguration() { + return this.configuration; + } + + public @Nullable RepositoryConfigurationExtensionSupport getConfigurationExtension() { + return this.extension; } /** - * @param source the RepositoryFactoryBeanSupport bean definition. - * @param beanFactory - * @return + * @return the {@link RepositoryInformation} derived from the repository bean. */ - @SuppressWarnings("NullAway") - static RepositoryInformation repositoryInformation(RepositoryConfiguration repoConfig, BeanDefinition source, - ConfigurableListableBeanFactory beanFactory) { + public RepositoryInformation getRepositoryInformation() { RepositoryMetadata metadata = AbstractRepositoryMetadata - .getMetadata(forName(repoConfig.getRepositoryInterface(), beanFactory)); - Class repositoryBaseClass = readRepositoryBaseClass(source, beanFactory); - List> fragmentList = readRepositoryFragments(source, beanFactory); - if (source.getPropertyValues().contains("customImplementation")) { - - Object o = source.getPropertyValues().get("customImplementation"); - if (o instanceof RuntimeBeanReference rbr) { - BeanDefinition customImplBeanDefintion = beanFactory.getBeanDefinition(rbr.getBeanName()); - Class beanType = forName(customImplBeanDefintion.getBeanClassName(), beanFactory); - ResolvableType[] interfaces = ResolvableType.forClass(beanType).getInterfaces(); - if (interfaces.length == 1) { - fragmentList.add(new ImplementedRepositoryFragment(interfaces[0].toClass(), beanType)); - } else { - boolean found = false; - for (ResolvableType i : interfaces) { - if (beanType.getSimpleName().contains(i.resolve().getSimpleName())) { - fragmentList.add(new ImplementedRepositoryFragment(interfaces[0].toClass(), beanType)); - found = true; - break; - } - } - if (!found) { - fragmentList.add(RepositoryFragment.implemented(beanType)); - } - } + .getMetadata(forName(configuration.getRepositoryInterface())); + Class repositoryBaseClass = getRepositoryBaseClass(); + + List> fragments = new ArrayList<>(); + fragments.addAll(readRepositoryFragments()); + fragments.addAll(readContributedRepositoryFragments(metadata)); + + RepositoryFragment customImplementation = getCustomImplementation(); + if (customImplementation != null) { + fragments.add(0, customImplementation); + } + + return new AotRepositoryInformation(metadata, repositoryBaseClass, fragments); + } + + private @Nullable RepositoryFragment getCustomImplementation() { + + PropertyValues mpv = beanDefinition.getPropertyValues(); + PropertyValue customImplementation = mpv.getPropertyValue("customImplementation"); + + if (customImplementation != null) { + + if (customImplementation.getValue() instanceof RuntimeBeanReference rbr) { + BeanDefinition customImplementationBean = beanFactory.getBeanDefinition(rbr.getBeanName()); + Class beanType = getClass(customImplementationBean); + return RepositoryFragment.structural(beanType); + } else if (customImplementation.getValue() instanceof BeanDefinition bd) { + Class beanType = getClass(bd); + return RepositoryFragment.structural(beanType); } } - String moduleName = (String) source.getPropertyValues().get("moduleName"); - AotRepositoryInformation repositoryInformation = new AotRepositoryInformation(moduleName, () -> metadata, - () -> repositoryBaseClass, () -> fragmentList); - return repositoryInformation; + return null; } @SuppressWarnings("NullAway") - private static Class readRepositoryBaseClass(BeanDefinition source, ConfigurableListableBeanFactory beanFactory) { + private Class getRepositoryBaseClass() { + + Object repoBaseClassName = beanDefinition.getPropertyValues().get("repositoryBaseClass"); - Object repoBaseClassName = source.getPropertyValues().get("repositoryBaseClass"); if (repoBaseClassName != null) { - return forName(repoBaseClassName.toString(), beanFactory); - } - if (source.getPropertyValues().contains("moduleBaseClass")) { - return forName((String) source.getPropertyValues().get("moduleBaseClass"), beanFactory); + return forName(repoBaseClassName.toString()); } + return Dummy.class; } @SuppressWarnings("NullAway") - private static List> readRepositoryFragments(BeanDefinition source, - ConfigurableListableBeanFactory beanFactory) { + private List> readRepositoryFragments() { - RuntimeBeanReference beanReference = (RuntimeBeanReference) source.getPropertyValues().get("repositoryFragments"); + RuntimeBeanReference beanReference = (RuntimeBeanReference) beanDefinition.getPropertyValues() + .get("repositoryFragments"); BeanDefinition fragments = beanFactory.getBeanDefinition(beanReference.getBeanName()); ValueHolder fragmentBeanNameList = fragments.getConstructorArgumentValues().getArgumentValue(0, List.class); List fragmentBeanNames = (List) fragmentBeanNameList.getValue(); List> fragmentList = new ArrayList<>(); + for (String beanName : fragmentBeanNames) { BeanDefinition fragmentBeanDefinition = beanFactory.getBeanDefinition(beanName); - ValueHolder argumentValue = fragmentBeanDefinition.getConstructorArgumentValues().getArgumentValue(0, - String.class); - ValueHolder argumentValue1 = fragmentBeanDefinition.getConstructorArgumentValues().getArgumentValue(1, null, null, - null); - Object fragmentClassName = argumentValue.getValue(); - - try { - Class type = ClassUtils.forName(fragmentClassName.toString(), beanFactory.getBeanClassLoader()); - - if (argumentValue1 != null && argumentValue1.getValue() instanceof RuntimeBeanReference rbf) { - BeanDefinition implBeanDef = beanFactory.getBeanDefinition(rbf.getBeanName()); - Class implClass = ClassUtils.forName(implBeanDef.getBeanClassName(), beanFactory.getBeanClassLoader()); - fragmentList.add(new RepositoryFragment.ImplementedRepositoryFragment(type, implClass)); - } else { - fragmentList.add(RepositoryFragment.structural(type)); - } - } catch (ClassNotFoundException e) { - throw new RuntimeException(e); + ConstructorArgumentValues cv = fragmentBeanDefinition.getConstructorArgumentValues(); + ValueHolder interfaceClassVh = cv.getArgumentValue(0, String.class); + ValueHolder implementationVh = cv.getArgumentValue(1, null, null, null); + + Object fragmentClassName = interfaceClassVh.getValue(); + Class interfaceClass = forName(fragmentClassName.toString()); + + if (implementationVh != null && implementationVh.getValue() instanceof RuntimeBeanReference rbf) { + BeanDefinition implBeanDef = beanFactory.getBeanDefinition(rbf.getBeanName()); + Class implClass = getClass(implBeanDef); + fragmentList.add(RepositoryFragment.structural(interfaceClass, implClass)); + } else { + fragmentList.add(RepositoryFragment.structural(interfaceClass)); } } + return fragmentList; } + private List> readContributedRepositoryFragments(RepositoryMetadata metadata) { + + RepositoryFragmentsContributor contributor = getFragmentsContributor(metadata.getRepositoryInterface()); + return contributor.describe(metadata).stream().toList(); + } + + private RepositoryFragmentsContributor getFragmentsContributor(Class repositoryInterface) { + + Object repositoryFragmentsContributor = beanDefinition.getPropertyValues().get("repositoryFragmentsContributor"); + + if (repositoryFragmentsContributor instanceof BeanDefinition bd) { + return (RepositoryFragmentsContributor) BeanUtils.instantiateClass(getClass(bd)); + } + + Class repositoryFactoryBean = forName(beanDefinition.getBeanClassName()); + Constructor constructor = ClassUtils.getConstructorIfAvailable(repositoryFactoryBean, Class.class); + + if (constructor == null) { + throw new IllegalStateException("No constructor accepting Class in " + repositoryFactoryBean.getName()); + } + RepositoryFactoryBeanSupport factoryBean = (RepositoryFactoryBeanSupport) BeanUtils + .instantiateClass(constructor, repositoryInterface); + + return factoryBean.getRepositoryFragmentsContributor(); + } + + private Class getClass(BeanDefinition definition) { + + String beanClassName = definition.getBeanClassName(); + + if (ObjectUtils.isEmpty(beanClassName)) { + throw new IllegalStateException("No bean class name specified for %s".formatted(definition)); + } + + return forName(beanClassName); + } + static abstract class Dummy implements CrudRepository, PagingAndSortingRepository {} - static Class forName(String name, ConfigurableListableBeanFactory beanFactory) { + private Class forName(String name) { try { - return ClassUtils.forName(name, beanFactory.getBeanClassLoader()); + return ClassUtils.forName(name, beanClassLoader); } catch (ClassNotFoundException cause) { throw new TypeNotPresentException(name, cause); } diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryConfiguration.java b/src/main/java/org/springframework/data/repository/config/RepositoryConfiguration.java index 0a42dcd108..4ee3ceb4a3 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryConfiguration.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryConfiguration.java @@ -78,6 +78,15 @@ public interface RepositoryConfiguration getRepositoryBaseClassName(); + /** + * Returns the name of the repository fragments contributor class to be used or {@link Optional#empty()} if the store + * specific defaults shall be applied. + * + * @return + * @since 4.0 + */ + Optional getRepositoryFragmentsContributorClassName(); + /** * Returns the name of the repository factory bean class to be used. * @@ -157,11 +166,12 @@ public interface RepositoryConfiguration getRepositoryBaseClassName() { return repositoryConfiguration.getRepositoryBaseClassName(); } + @Override + public Optional getRepositoryFragmentsContributorClassName() { + return repositoryConfiguration.getRepositoryFragmentsContributorClassName(); + } + @Override public String getRepositoryFactoryBeanClassName() { return repositoryConfiguration.getRepositoryFactoryBeanClassName(); diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryConfigurationExtension.java b/src/main/java/org/springframework/data/repository/config/RepositoryConfigurationExtension.java index 1b9531da35..1c5a5530f9 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryConfigurationExtension.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryConfigurationExtension.java @@ -18,7 +18,7 @@ import java.util.Collection; import java.util.Locale; -import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor; import org.springframework.beans.factory.config.BeanDefinition; @@ -63,7 +63,6 @@ default String getModuleIdentifier() { * @see org.springframework.beans.factory.aot.BeanRegistrationAotProcessor * @since 3.0 */ - @NonNull default Class getRepositoryAotProcessor() { return RepositoryRegistrationAotProcessor.class; } @@ -90,6 +89,16 @@ Collection> */ String getDefaultNamedQueryLocation(); + /** + * Returns the {@link String name} of the repository base class to be used. + * + * @return can be {@literal null} if the base class cannot be provided. + * @since 4.0 + */ + default @Nullable String getRepositoryBaseClassName() { + return null; + } + /** * Returns the {@link String name} of the repository factory class to be used. * diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryConfigurationSource.java b/src/main/java/org/springframework/data/repository/config/RepositoryConfigurationSource.java index af1dec7a65..7d750f6cff 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryConfigurationSource.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryConfigurationSource.java @@ -81,6 +81,15 @@ public interface RepositoryConfigurationSource { */ Optional getRepositoryBaseClassName(); + /** + * Returns the name of the repository fragments contributor class to be used or {@link Optional#empty()} if the store + * specific defaults shall be applied. + * + * @return + * @since 4.0 + */ + Optional getRepositoryFragmentsContributorClassName(); + /** * Returns the name of the repository factory bean class or {@link Optional#empty()} if not defined in the source. * diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java b/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java index 40b2cc43a7..feddf13e5c 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java @@ -16,18 +16,18 @@ package org.springframework.data.repository.config; import java.io.Serializable; -import java.lang.annotation.Annotation; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Optional; -import java.util.Set; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Predicate; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.jspecify.annotations.Nullable; + import org.springframework.aop.SpringProxy; import org.springframework.aop.framework.Advised; import org.springframework.aot.generate.GenerationContext; @@ -37,11 +37,11 @@ import org.springframework.beans.factory.aot.BeanRegistrationCode; import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments; import org.springframework.beans.factory.aot.BeanRegistrationCodeFragmentsDecorator; -import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.core.DecoratingProxy; import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.core.env.Environment; import org.springframework.data.aot.AotContext; import org.springframework.data.projection.EntityProjectionIntrospector; import org.springframework.data.projection.TargetAware; @@ -66,18 +66,19 @@ * @author Mark Paluch * @since 3.0 */ -// TODO: Consider moving to data.repository.aot public class RepositoryRegistrationAotContribution implements BeanRegistrationAotContribution { + private static final Log logger = LogFactory.getLog(RepositoryRegistrationAotContribution.class); + private static final String KOTLIN_COROUTINE_REPOSITORY_TYPE_NAME = "org.springframework.data.repository.kotlin.CoroutineCrudRepository"; - private @Nullable RepositoryContributor repositoryContributor; + private final RepositoryRegistrationAotProcessor aotProcessor; - private @Nullable AotRepositoryContext repositoryContext; + private final AotRepositoryContext repositoryContext; - private @Nullable BiFunction moduleContribution; + private @Nullable RepositoryContributor repositoryContributor; - private final RepositoryRegistrationAotProcessor aotProcessor; + private @Nullable BiFunction moduleContribution; /** * Constructs a new instance of the {@link RepositoryRegistrationAotContribution} initialized with the given, required @@ -85,14 +86,18 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo * * @param processor reference back to the {@link RepositoryRegistrationAotProcessor} from which this contribution was * created. + * @param context reference back to the {@link AotRepositoryContext} from which this contribution was created. * @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}. * @see RepositoryRegistrationAotProcessor */ - protected RepositoryRegistrationAotContribution(RepositoryRegistrationAotProcessor processor) { + protected RepositoryRegistrationAotContribution(RepositoryRegistrationAotProcessor processor, + AotRepositoryContext context) { Assert.notNull(processor, "RepositoryRegistrationAotProcessor must not be null"); + Assert.notNull(context, "AotRepositoryContext must not be null"); this.aotProcessor = processor; + this.repositoryContext = context; } /** @@ -101,16 +106,57 @@ protected RepositoryRegistrationAotContribution(RepositoryRegistrationAotProcess * * @param processor reference back to the {@link RepositoryRegistrationAotProcessor} from which this contribution was * created. - * @return a new instance of {@link RepositoryRegistrationAotContribution}. - * @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}. + * @return a new instance of {@link RepositoryRegistrationAotContribution} if a contribution can be made; + * {@literal null} if no contribution can be made. * @see RepositoryRegistrationAotProcessor */ - public static RepositoryRegistrationAotContribution fromProcessor(RepositoryRegistrationAotProcessor processor) { - return new RepositoryRegistrationAotContribution(processor); + public static @Nullable RepositoryRegistrationAotContribution load(RepositoryRegistrationAotProcessor processor, + RegisteredBean repositoryBean) { + + RepositoryConfiguration repositoryMetadata = processor.getRepositoryMetadata(repositoryBean); + + if (repositoryMetadata == null) { + return null; + } + + AotRepositoryContext repositoryContext = buildAotRepositoryContext(processor.getEnvironment(), repositoryBean, + repositoryMetadata); + + if (repositoryContext == null) { + return null; + } + + return new RepositoryRegistrationAotContribution(processor, repositoryContext); } - protected ConfigurableListableBeanFactory getBeanFactory() { - return getRepositoryRegistrationAotProcessor().getBeanFactory(); + /** + * Builds a {@link RepositoryRegistrationAotContribution} for given, required {@link RegisteredBean} representing the + * {@link Repository} registered in the bean registry. + * + * @param repositoryBean {@link RegisteredBean} for the {@link Repository}; must not be {@literal null}. + * @return a {@link RepositoryRegistrationAotContribution} to contribute AOT metadata and code for the + * {@link Repository} {@link RegisteredBean}. + * @throws IllegalArgumentException if the {@link RegisteredBean} is {@literal null}. + * @deprecated since 4.0. + */ + @Deprecated(since = "4.0", forRemoval = true) + public @Nullable RepositoryRegistrationAotContribution forBean(RegisteredBean repositoryBean) { + + RepositoryConfiguration repositoryMetadata = getRepositoryRegistrationAotProcessor() + .getRepositoryMetadata(repositoryBean); + + if (repositoryMetadata == null) { + return null; + } + + AotRepositoryContext repositoryContext = buildAotRepositoryContext(aotProcessor.getEnvironment(), repositoryBean, + repositoryMetadata); + + if (repositoryContext == null) { + return null; + } + + return new RepositoryRegistrationAotContribution(getRepositoryRegistrationAotProcessor(), repositoryContext); } protected @Nullable BiFunction getModuleContribution() { @@ -118,10 +164,6 @@ protected ConfigurableListableBeanFactory getBeanFactory() { } protected AotRepositoryContext getRepositoryContext() { - - Assert.state(this.repositoryContext != null, - "The AOT RepositoryContext was not properly initialized; did you call the forBean(:RegisteredBean) method"); - return this.repositoryContext; } @@ -137,28 +179,27 @@ private void logTrace(String message, Object... arguments) { getRepositoryRegistrationAotProcessor().logTrace(message, arguments); } - /** - * Builds a {@link RepositoryRegistrationAotContribution} for given, required {@link RegisteredBean} representing the - * {@link Repository} registered in the bean registry. - * - * @param repositoryBean {@link RegisteredBean} for the {@link Repository}; must not be {@literal null}. - * @return a {@link RepositoryRegistrationAotContribution} to contribute AOT metadata and code for the - * {@link Repository} {@link RegisteredBean}. - * @throws IllegalArgumentException if the {@link RegisteredBean} is {@literal null}. - * @see org.springframework.beans.factory.support.RegisteredBean - */ - public RepositoryRegistrationAotContribution forBean(RegisteredBean repositoryBean) { - - Assert.notNull(repositoryBean, "The RegisteredBean for the repository must not be null"); + private static @Nullable AotRepositoryContext buildAotRepositoryContext(Environment environment, RegisteredBean bean, + RepositoryConfiguration repositoryConfiguration) { - RepositoryConfiguration repositoryMetadata = getRepositoryRegistrationAotProcessor() - .getRepositoryMetadata(repositoryBean); + RepositoryBeanDefinitionReader reader = new RepositoryBeanDefinitionReader(bean); + RepositoryConfiguration configuration = reader.getConfiguration(); + RepositoryConfigurationExtensionSupport extension = reader.getConfigurationExtension(); - Assert.state(repositoryMetadata != null, "The RepositoryConfiguration for the repository must not be null"); + if (configuration == null || extension == null) { + logger.warn( + "Cannot create AotRepositoryContext for bean [%s]. No RepositoryConfiguration/RepositoryConfigurationExtension. Please make sure to register the repository bean through @Enable…Repositories." + .formatted(bean.getBeanName())); + return null; + } + RepositoryInformation repositoryInformation = reader.getRepositoryInformation(); + DefaultAotRepositoryContext repositoryContext = new DefaultAotRepositoryContext(bean, repositoryInformation, + extension.getModuleName(), AotContext.from(bean.getBeanFactory(), environment)); - this.repositoryContext = buildAotRepositoryContext(repositoryBean, repositoryMetadata); + repositoryContext.setBasePackages(repositoryConfiguration.getBasePackages().toSet()); + repositoryContext.setIdentifyingAnnotations(extension.getIdentifyingAnnotations()); - return this; + return repositoryContext; } /** @@ -176,9 +217,6 @@ public RepositoryRegistrationAotContribution withModuleContribution( @Override public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) { - Assert.state(this.repositoryContext != null, - "RepositoryContext cannot be null. Make sure to initialize this class with forBean(…)."); - contributeRepositoryInfo(this.repositoryContext, generationContext); var moduleContribution = getModuleContribution(); @@ -219,6 +257,10 @@ public CodeBlock generateSetBeanDefinitionPropertiesCode(GenerationContext gener }; } + public Predicate> typeFilter() { // like only document ones. // TODO: As in MongoDB? + return Predicates.isTrue(); + } + private void contributeRepositoryInfo(AotRepositoryContext repositoryContext, GenerationContext contribution) { RepositoryInformation repositoryInformation = getRepositoryInformation(); @@ -239,7 +281,7 @@ private void contributeRepositoryInfo(AotRepositoryContext repositoryContext, Ge for (RepositoryFragment fragment : getRepositoryInformation().getFragments()) { Class repositoryFragmentType = fragment.getSignatureContributor(); - Optional implementation = fragment.getImplementation(); + Optional> implementation = fragment.getImplementationClass(); contribution.getRuntimeHints().reflection().registerType(repositoryFragmentType, hint -> { @@ -250,13 +292,12 @@ private void contributeRepositoryInfo(AotRepositoryContext repositoryContext, Ge } }); - implementation.ifPresent(impl -> { - Class typeToRegister = impl instanceof Class c ? c : impl.getClass(); + implementation.ifPresent(typeToRegister -> { contribution.getRuntimeHints().reflection().registerType(typeToRegister, hint -> { hint.withMembers(MemberCategory.INVOKE_PUBLIC_METHODS); - if (!impl.getClass().isInterface()) { + if (!typeToRegister.isInterface()) { hint.withMembers(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS); } }); @@ -289,12 +330,6 @@ private void contributeRepositoryInfo(AotRepositoryContext repositoryContext, Ge } // }); - // Reactive Repositories - if (repositoryInformation.isReactiveRepository()) { - // TODO: do we still need this and how to configure it? - // registry.initialization().add(NativeInitializationEntry.ofBuildTimeType(configuration.getRepositoryInterface())); - } - // Kotlin if (isKotlinCoroutineRepository(repositoryContext, repositoryInformation)) { contribution.getRuntimeHints().reflection().registerTypes(kotlinRepositoryReflectionTypeReferences(), @@ -356,29 +391,5 @@ static boolean isJavaOrPrimitiveType(Class type) { || ClassUtils.isPrimitiveArray(type); // } - public Predicate> typeFilter() { // like only document ones. // TODO: As in MongoDB? - return Predicates.isTrue(); - } - - @SuppressWarnings("rawtypes") - private DefaultAotRepositoryContext buildAotRepositoryContext(RegisteredBean bean, - RepositoryConfiguration repositoryConfiguration) { - - DefaultAotRepositoryContext repositoryContext = new DefaultAotRepositoryContext( - AotContext.from(getBeanFactory(), getRepositoryRegistrationAotProcessor().getEnvironment())); - - repositoryContext.setBeanName(bean.getBeanName()); - repositoryContext.setBasePackages(repositoryConfiguration.getBasePackages().toSet()); - repositoryContext.setIdentifyingAnnotations(resolveIdentifyingAnnotations()); - repositoryContext - .setRepositoryInformation(RepositoryBeanDefinitionReader.repositoryInformation(repositoryConfiguration, bean)); - - return repositoryContext; - } - - // TODO: Capture Repository Config - private Set> resolveIdentifyingAnnotations() { - return Collections.emptySet(); - } } diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotProcessor.java b/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotProcessor.java index 7bed43d305..4fbb086106 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotProcessor.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotProcessor.java @@ -69,6 +69,7 @@ * * @author Christoph Strobl * @author John Blum + * @author Mark Paluch * @since 3.0 */ public class RepositoryRegistrationAotProcessor @@ -123,11 +124,16 @@ private boolean isRepositoryBean(RegisteredBean bean) { return getConfigMap().containsKey(bean.getBeanName()); } - protected RepositoryRegistrationAotContribution newRepositoryRegistrationAotContribution( + protected @Nullable RepositoryRegistrationAotContribution newRepositoryRegistrationAotContribution( RegisteredBean repositoryBean) { - RepositoryRegistrationAotContribution contribution = RepositoryRegistrationAotContribution.fromProcessor(this) - .forBean(repositoryBean); + RepositoryRegistrationAotContribution contribution = RepositoryRegistrationAotContribution.load(this, + repositoryBean); + + // cannot contribute a repository bean. + if (contribution == null) { + return null; + } //TODO: add the hook for customizing bean initialization code here! diff --git a/src/main/java/org/springframework/data/repository/config/XmlRepositoryConfigurationSource.java b/src/main/java/org/springframework/data/repository/config/XmlRepositoryConfigurationSource.java index 5573613c7d..61dd29a4a0 100644 --- a/src/main/java/org/springframework/data/repository/config/XmlRepositoryConfigurationSource.java +++ b/src/main/java/org/springframework/data/repository/config/XmlRepositoryConfigurationSource.java @@ -144,6 +144,11 @@ public Optional getRepositoryBaseClassName() { return getNullDefaultedAttribute(element, REPOSITORY_BASE_CLASS_NAME); } + @Override + public Optional getRepositoryFragmentsContributorClassName() { + return Optional.empty(); + } + @Override public Optional getRepositoryFactoryBeanClassName() { return getNullDefaultedAttribute(element, REPOSITORY_FACTORY_BEAN_CLASS_NAME); diff --git a/src/main/java/org/springframework/data/repository/core/RepositoryInformation.java b/src/main/java/org/springframework/data/repository/core/RepositoryInformation.java index 3ebee41f24..e3f77cc339 100644 --- a/src/main/java/org/springframework/data/repository/core/RepositoryInformation.java +++ b/src/main/java/org/springframework/data/repository/core/RepositoryInformation.java @@ -18,7 +18,6 @@ import java.lang.reflect.Method; import java.util.List; -import org.jspecify.annotations.Nullable; import org.springframework.data.repository.core.support.RepositoryComposition; /** @@ -106,8 +105,4 @@ default boolean hasQueryMethods() { */ RepositoryComposition getRepositoryComposition(); - default @Nullable String moduleName() { - return null; - } - } diff --git a/src/main/java/org/springframework/data/repository/core/support/DefaultRepositoryInformation.java b/src/main/java/org/springframework/data/repository/core/support/DefaultRepositoryInformation.java index 79e2078b42..42074d3ea1 100644 --- a/src/main/java/org/springframework/data/repository/core/support/DefaultRepositoryInformation.java +++ b/src/main/java/org/springframework/data/repository/core/support/DefaultRepositoryInformation.java @@ -25,6 +25,7 @@ import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.RepositoryInformationSupport; import org.springframework.data.repository.core.RepositoryMetadata; +import org.springframework.data.util.Lazy; import org.springframework.lang.Contract; import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; @@ -44,6 +45,7 @@ class DefaultRepositoryInformation extends RepositoryInformationSupport implemen private final RepositoryComposition composition; private final RepositoryComposition baseComposition; + private final Lazy fullComposition; /** * Creates a new {@link DefaultRepositoryMetadata} for the given repository interface and repository base class. @@ -62,6 +64,8 @@ public DefaultRepositoryInformation(RepositoryMetadata metadata, Class reposi this.baseComposition = RepositoryComposition.of(RepositoryFragment.structural(repositoryBaseClass)) // .withArgumentConverter(composition.getArgumentConverter()) // .withMethodLookup(composition.getMethodLookup()); + + this.fullComposition = Lazy.of(() -> composition.append(baseComposition.getFragments())); } @Override @@ -106,7 +110,6 @@ public boolean isBaseClassMethod(Method method) { @Override protected boolean isQueryMethodCandidate(Method method) { - // FIXME - that should be simplified boolean queryMethodCandidate = super.isQueryMethodCandidate(method); if(!isQueryAnnotationPresentOn(method)) { return queryMethodCandidate; @@ -133,7 +136,7 @@ public Set> getFragments() { @Override public RepositoryComposition getRepositoryComposition() { - return composition.append(baseComposition.getFragments()); + return fullComposition.get(); } } diff --git a/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java b/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java index a0f19c5fc2..4ccaba6c53 100644 --- a/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java +++ b/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java @@ -95,10 +95,6 @@ public abstract class RepositoryFactoryBeanSupport, private @Nullable Lazy repository; private @Nullable RepositoryMetadata repositoryMetadata; - // AOT bean factory hint? - private @Nullable String moduleBaseClass; - private @Nullable String moduleName; - /** * Creates a new {@link RepositoryFactoryBeanSupport} for the given repository interface. * @@ -261,14 +257,6 @@ public void setApplicationEventPublisher(ApplicationEventPublisher publisher) { this.publisher = publisher; } - public void setModuleBaseClass(String moduleBaseClass) { - this.moduleBaseClass = moduleBaseClass; - } - - public void setModuleName(String moduleName) { - this.moduleName = moduleName; - } - @Override @SuppressWarnings("unchecked") public EntityInformation getEntityInformation() { @@ -281,6 +269,11 @@ public RepositoryInformation getRepositoryInformation() { return getRequiredFactory().getRepositoryInformation(getRequiredRepositoryMetadata(), cachedFragments); } + @Override + public RepositoryFragmentsContributor getRepositoryFragmentsContributor() { + return RepositoryFragmentsContributor.empty(); + } + @Override public PersistentEntity getPersistentEntity() { diff --git a/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryInformation.java b/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryInformation.java index 6ee3adbbf9..75ee8ac65d 100644 --- a/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryInformation.java +++ b/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryInformation.java @@ -46,6 +46,15 @@ public interface RepositoryFactoryInformation { */ RepositoryInformation getRepositoryInformation(); + /** + * Returns the {@link RepositoryFragmentsContributor} that is used to contribute additional fragments based on the + * repository declaration. + * + * @return + * @since 4.0 + */ + RepositoryFragmentsContributor getRepositoryFragmentsContributor(); + /** * Returns the {@link PersistentEntity} managed by the underlying repository. Can be {@literal null} in case the * underlying persistence mechanism does not expose a {@link MappingContext}. diff --git a/src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java b/src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java index 7b34326a8a..98a6b60735 100644 --- a/src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java +++ b/src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java @@ -82,6 +82,18 @@ static RepositoryFragment structural(Class interfaceOrImplementation) return new StructuralRepositoryFragment<>(interfaceOrImplementation); } + /** + * Create a structural {@link RepositoryFragment} given {@code interfaceClass} and {@code implementationClass}. + * + * @param interfaceClass must not be {@literal null}. + * @param implementationClass must not be {@literal null}. + * @return + * @since 4.0 + */ + static RepositoryFragment structural(Class interfaceClass, Class implementationClass) { + return new StructuralRepositoryFragment<>(interfaceClass, implementationClass); + } + /** * Attempt to find the {@link Method} by name and exact parameters. Returns {@literal true} if the method was found or * {@literal false} otherwise. @@ -103,6 +115,15 @@ default Optional getImplementation() { return Optional.empty(); } + /** + * @return the optional implementation class. Only available for fragments that ship an implementation descriptor. + * Structural (interface-only) fragments return always {@link Optional#empty()}. + * @since 4.0 + */ + default Optional> getImplementationClass() { + return getImplementation().map(it -> it.getClass()); + } + /** * @return a {@link Stream} of methods exposed by this {@link RepositoryFragment}. */ @@ -186,17 +207,30 @@ private static boolean hasMethod(Method method, Method[] candidates) { class StructuralRepositoryFragment implements RepositoryFragment { - private final Class interfaceOrImplementation; + private final Class interfaceClass; + private final Class implementationClass; private final Method[] methods; public StructuralRepositoryFragment(Class interfaceOrImplementation) { - this.interfaceOrImplementation = interfaceOrImplementation; - this.methods = getSignatureContributor().getMethods(); + this.interfaceClass = interfaceOrImplementation; + this.implementationClass = interfaceOrImplementation; + this.methods = interfaceOrImplementation.getMethods(); + } + + public StructuralRepositoryFragment(Class interfaceClass, Class implementationClass) { + this.interfaceClass = interfaceClass; + this.implementationClass = implementationClass; + this.methods = interfaceClass.getMethods(); } @Override public Class getSignatureContributor() { - return interfaceOrImplementation; + return interfaceClass; + } + + @Override + public Optional> getImplementationClass() { + return Optional.of(implementationClass); } @Override @@ -221,31 +255,30 @@ public boolean hasMethod(Method method) { @Override public RepositoryFragment withImplementation(T implementation) { - return new ImplementedRepositoryFragment<>(interfaceOrImplementation, implementation); + return new ImplementedRepositoryFragment<>(interfaceClass, implementation); } @Override public String toString() { - return String.format("StructuralRepositoryFragment %s", ClassUtils.getShortName(interfaceOrImplementation)); + return String.format("StructuralRepositoryFragment %s", ClassUtils.getShortName(interfaceClass)); } @Override public boolean equals(Object o) { - - if (this == o) { - return true; + if (!(o instanceof StructuralRepositoryFragment that)) { + return false; } - if (!(o instanceof StructuralRepositoryFragment that)) { + if (!ObjectUtils.nullSafeEquals(interfaceClass, that.interfaceClass)) { return false; } - return ObjectUtils.nullSafeEquals(interfaceOrImplementation, that.interfaceOrImplementation); + return ObjectUtils.nullSafeEquals(implementationClass, that.implementationClass); } @Override public int hashCode() { - return ObjectUtils.nullSafeHashCode(interfaceOrImplementation); + return ObjectUtils.nullSafeHash(interfaceClass, implementationClass); } } diff --git a/src/main/java/org/springframework/data/repository/core/support/RepositoryFragmentsContributor.java b/src/main/java/org/springframework/data/repository/core/support/RepositoryFragmentsContributor.java new file mode 100644 index 0000000000..782b8356c5 --- /dev/null +++ b/src/main/java/org/springframework/data/repository/core/support/RepositoryFragmentsContributor.java @@ -0,0 +1,56 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.core.support; + +import org.springframework.data.repository.core.RepositoryMetadata; +import org.springframework.data.repository.core.support.RepositoryComposition.RepositoryFragments; + +/** + * Strategy interface support allowing to contribute a {@link RepositoryFragments} based on {@link RepositoryMetadata}. + *

+ * Fragments contributors enhance repository functionality based on a repository declaration and activate additional + * fragments if a repository defines them, such as extending a built-in fragment interface (e.g. + * {@code QuerydslPredicateExecutor}, {@code QueryByExampleExecutor}). + *

+ * This interface is a base-interface serving as a contract for repository fragment introspection. The actual + * implementation and methods to contribute fragments to be used within the repository instance are store-specific and + * require typically access to infrastructure such as a database connection hence those methods must be defined within + * the particular store module. + * + * @author Mark Paluch + * @since 4.0 + */ +public interface RepositoryFragmentsContributor { + + /** + * Empty {@code RepositoryFragmentsContributor} that does not contribute any fragments. + * + * @return empty {@code RepositoryFragmentsContributor} that does not contribute any fragments. + */ + public static RepositoryFragmentsContributor empty() { + return metadata -> RepositoryFragments.empty(); + } + + /** + * Describe fragments that are contributed by {@link RepositoryMetadata}. Fragment description reports typically + * structural fragments that are not suitable for invocation but can be used to introspect the repository structure. + * + * @param metadata the repository metadata describing the repository interface. + * @return fragments to be (structurally) contributed to the repository. + */ + RepositoryFragments describe(RepositoryMetadata metadata); + +} diff --git a/src/main/java/org/springframework/data/repository/support/Repositories.java b/src/main/java/org/springframework/data/repository/support/Repositories.java index 430139305f..4b4b4ca38f 100644 --- a/src/main/java/org/springframework/data/repository/support/Repositories.java +++ b/src/main/java/org/springframework/data/repository/support/Repositories.java @@ -35,6 +35,7 @@ import org.springframework.data.repository.core.EntityInformation; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.support.RepositoryFactoryInformation; +import org.springframework.data.repository.core.support.RepositoryFragmentsContributor; import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.util.ProxyUtils; import org.springframework.util.Assert; @@ -365,6 +366,11 @@ public RepositoryInformation getRepositoryInformation() { throw new UnsupportedOperationException(); } + @Override + public RepositoryFragmentsContributor getRepositoryFragmentsContributor() { + throw new UnsupportedOperationException(); + } + @Override public PersistentEntity getPersistentEntity() { throw new UnsupportedOperationException(); diff --git a/src/test/java/org/springframework/data/aot/sample/ConfigWithCustomRepositoryBaseClass.java b/src/test/java/org/springframework/data/aot/sample/ConfigWithCustomRepositoryBaseClass.java index 29d7471593..790f660b9b 100644 --- a/src/test/java/org/springframework/data/aot/sample/ConfigWithCustomRepositoryBaseClass.java +++ b/src/test/java/org/springframework/data/aot/sample/ConfigWithCustomRepositoryBaseClass.java @@ -22,13 +22,13 @@ import org.springframework.context.annotation.FilterType; import org.springframework.data.aot.sample.ConfigWithCustomRepositoryBaseClass.RepoBaseClass; import org.springframework.data.repository.CrudRepository; -import org.springframework.data.repository.config.EnableRepositories; +import org.springframework.data.repository.config.EnableRepositoriesWithContributor; /** * @author Christoph Strobl */ @Configuration -@EnableRepositories(repositoryBaseClass = RepoBaseClass.class, considerNestedRepositories = true, +@EnableRepositoriesWithContributor(repositoryBaseClass = RepoBaseClass.class, considerNestedRepositories = true, includeFilters = { @Filter(type = FilterType.REGEX, pattern = ".*CustomerRepositoryWithCustomBaseRepo$") }) public class ConfigWithCustomRepositoryBaseClass { diff --git a/src/test/java/org/springframework/data/aot/sample/ConfigWithSimpleCrudRepository.java b/src/test/java/org/springframework/data/aot/sample/ConfigWithSimpleCrudRepository.java index 09236c4418..25e35d9248 100644 --- a/src/test/java/org/springframework/data/aot/sample/ConfigWithSimpleCrudRepository.java +++ b/src/test/java/org/springframework/data/aot/sample/ConfigWithSimpleCrudRepository.java @@ -15,6 +15,8 @@ */ package org.springframework.data.aot.sample; +import org.jspecify.annotations.Nullable; + import org.springframework.context.annotation.ComponentScan.Filter; import org.springframework.context.annotation.FilterType; import org.springframework.data.aot.sample.ConfigWithSimpleCrudRepository.MyRepo; @@ -34,7 +36,7 @@ public interface MyRepo extends CrudRepository { public static class Person { - @javax.annotation.Nullable + @Nullable Address address; } diff --git a/src/test/java/org/springframework/data/repository/aot/RepositoryRegistrationAotProcessorIntegrationTests.java b/src/test/java/org/springframework/data/repository/aot/RepositoryRegistrationAotProcessorIntegrationTests.java index 39bc545541..bb71245359 100644 --- a/src/test/java/org/springframework/data/repository/aot/RepositoryRegistrationAotProcessorIntegrationTests.java +++ b/src/test/java/org/springframework/data/repository/aot/RepositoryRegistrationAotProcessorIntegrationTests.java @@ -55,6 +55,7 @@ import org.springframework.data.repository.config.EnableRepositories; import org.springframework.data.repository.config.RepositoryRegistrationAotContribution; import org.springframework.data.repository.config.RepositoryRegistrationAotProcessor; +import org.springframework.data.repository.config.SampleRepositoryFragmentsContributor; import org.springframework.data.repository.reactive.ReactiveSortingRepository; import org.springframework.transaction.interceptor.TransactionalProxy; @@ -237,10 +238,11 @@ void contributesRepositoryBaseClassCorrectly() { assertThatContribution(repositoryBeanContribution) // .targetRepositoryTypeIs(ConfigWithCustomRepositoryBaseClass.CustomerRepositoryWithCustomBaseRepo.class) // - .hasNoFragments() // + .hasFragments() // .codeContributionSatisfies(contribution -> { // // interface contribution + .contributesReflectionFor(SampleRepositoryFragmentsContributor.class) // repository structural fragment .contributesReflectionFor(ConfigWithCustomRepositoryBaseClass.CustomerRepositoryWithCustomBaseRepo.class) // repository .contributesReflectionFor(ConfigWithCustomRepositoryBaseClass.RepoBaseClass.class) // base repo class .contributesReflectionFor(ConfigWithCustomRepositoryBaseClass.Person.class); // repository domain type diff --git a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java index f57dc41c13..1ac8d043b3 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java @@ -15,33 +15,37 @@ */ package org.springframework.data.repository.aot.generate; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; -import example.UserRepository; import example.UserRepository.User; +import java.util.List; import java.util.TimeZone; import javax.lang.model.element.Modifier; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.Mockito; + import org.springframework.data.geo.Metric; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.querydsl.QuerydslPredicateExecutor; +import org.springframework.data.repository.CrudRepository; +import org.springframework.data.repository.config.AotRepositoryInformation; import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.core.support.AnnotationRepositoryMetadata; +import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.repository.query.QueryMethod; -import org.springframework.data.util.TypeInformation; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.TypeName; import org.springframework.stereotype.Repository; /** + * Unit tests for {@link AotRepositoryBuilder}. + * * @author Christoph Strobl + * @author Mark Paluch */ class AotRepositoryBuilderUnitTests { @@ -57,7 +61,7 @@ void beforeEach() { @Test // GH-3279 void writesClassSkeleton() { - AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, "Commons", new SpelAwareProxyProjectionFactory()); assertThat(repoBuilder.build().javaFile().toString()) .contains("package %s;".formatted(UserRepository.class.getPackageName())) // same package as source repo @@ -69,7 +73,7 @@ void writesClassSkeleton() { @Test // GH-3279 void appliesCtorArguments() { - AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, "Commons", new SpelAwareProxyProjectionFactory()); repoBuilder.withConstructorCustomizer(ctor -> { ctor.addParameter("param1", Metric.class); @@ -89,7 +93,7 @@ void appliesCtorArguments() { @Test // GH-3279 void appliesCtorCodeBlock() { - AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, "Commons", new SpelAwareProxyProjectionFactory()); repoBuilder.withConstructorCustomizer(ctor -> { ctor.customize((info, code) -> { @@ -103,7 +107,7 @@ void appliesCtorCodeBlock() { @Test // GH-3279 void appliesClassCustomizations() { - AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, "Commons", new SpelAwareProxyProjectionFactory()); repoBuilder.withClassCustomizer((info, metadata, clazz) -> { @@ -128,12 +132,11 @@ void appliesClassCustomizations() { @Test // GH-3279 void appliesQueryMethodContributor() { - AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, - new SpelAwareProxyProjectionFactory()); + AotRepositoryInformation repositoryInformation = new AotRepositoryInformation( + AnnotationRepositoryMetadata.getMetadata(UserRepository.class), CrudRepository.class, List.of()); - when(repositoryInformation.isQueryMethod(Mockito.argThat(arg -> arg.getName().equals("findByFirstname")))) - .thenReturn(true); - doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, "Commons", + new SpelAwareProxyProjectionFactory()); repoBuilder.withQueryMethodContributor((method, info) -> { @@ -154,4 +157,35 @@ public boolean contributesMethodSpec() { assertThat(repoBuilder.build().javaFile().toString()) // .containsIgnoringWhitespaces("void oops() { }"); } + + @Test // GH-3279 + void shouldContributeFragmentImplementationMetadata() { + + AotRepositoryInformation repositoryInformation = new AotRepositoryInformation( + AnnotationRepositoryMetadata.getMetadata(QuerydslUserRepository.class), CrudRepository.class, + List.of(RepositoryFragment.structural(QuerydslPredicateExecutor.class, DummyQuerydslPredicateExecutor.class))); + + AotRepositoryBuilder builder = AotRepositoryBuilder.forRepository(repositoryInformation, "Commons", + new SpelAwareProxyProjectionFactory()); + AotRepositoryBuilder.AotBundle bundle = builder.build(); + + AotRepositoryMethod method = bundle.metadata().methods().stream().filter(it -> it.name().equals("findBy")) + .findFirst().get(); + + assertThat(method.fragment()).isNotNull(); + assertThat(method.fragment().signature()).isEqualTo(QuerydslPredicateExecutor.class.getName()); + assertThat(method.fragment().implementation()).isEqualTo(DummyQuerydslPredicateExecutor.class.getName()); + } + + interface UserRepository extends org.springframework.data.repository.Repository { + + String someMethod(); + } + + interface QuerydslUserRepository + extends org.springframework.data.repository.Repository, QuerydslPredicateExecutor { + + } + + interface DummyQuerydslPredicateExecutor extends QuerydslPredicateExecutor {} } diff --git a/src/test/java/org/springframework/data/repository/aot/generate/DummyModuleAotRepositoryContext.java b/src/test/java/org/springframework/data/repository/aot/generate/DummyModuleAotRepositoryContext.java index 05b058f8e5..8c05276a9a 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/DummyModuleAotRepositoryContext.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/DummyModuleAotRepositoryContext.java @@ -43,6 +43,11 @@ public DummyModuleAotRepositoryContext(Class repositoryInterface, @Nullable R this.repositoryInformation = new StubRepositoryInformation(repositoryInterface, composition); } + @Override + public String getModuleName() { + return "Commons"; + } + @Override public ConfigurableListableBeanFactory getBeanFactory() { return null; diff --git a/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java index 133281fe0c..9156704008 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java @@ -15,10 +15,9 @@ */ package org.springframework.data.repository.aot.generate; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.Mockito.when; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; import example.UserRepository; import example.UserRepositoryExtension; @@ -31,7 +30,7 @@ import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.Test; -import org.mockito.Mockito; + import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.core.test.tools.TestCompiler; import org.springframework.data.aot.CodeContributionAssert; @@ -97,8 +96,8 @@ public Map serialize() { @Test // GH-3279 void callsMethodContributionForQueryMethod() { - AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class); - RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class); + AotRepositoryContext repositoryContext = mock(AotRepositoryContext.class); + RepositoryInformation repositoryInformation = mock(RepositoryInformation.class); when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation); when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class); @@ -113,8 +112,9 @@ void callsMethodContributionForQueryMethod() { @Test // GH-3279 void doesNotContributeBaseClassMethods() { - AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class); - RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class); + AotRepositoryContext repositoryContext = mock(AotRepositoryContext.class); + when(repositoryContext.getModuleName()).thenReturn("Commons"); + RepositoryInformation repositoryInformation = mock(RepositoryInformation.class); when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation); when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class); @@ -133,8 +133,9 @@ void doesNotContributeBaseClassMethods() { @Test // GH-3279 void doesNotContributeFragmentMethod() { - AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class); - RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class); + AotRepositoryContext repositoryContext = mock(AotRepositoryContext.class); + when(repositoryContext.getModuleName()).thenReturn("Commons"); + RepositoryInformation repositoryInformation = mock(RepositoryInformation.class); when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation); when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class); @@ -157,8 +158,9 @@ void doesNotContributeFragmentMethod() { @Test // GH-3279 void contributesBaseClassMethodIfQueryMethod() { - AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class); - RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class); + AotRepositoryContext repositoryContext = mock(AotRepositoryContext.class); + when(repositoryContext.getModuleName()).thenReturn("Commons"); + RepositoryInformation repositoryInformation = mock(RepositoryInformation.class); when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation); when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class); diff --git a/src/test/java/org/springframework/data/repository/config/AnnotationRepositoryConfigurationSourceUnitTests.java b/src/test/java/org/springframework/data/repository/config/AnnotationRepositoryConfigurationSourceUnitTests.java index 8917668d0c..1b1c657cb7 100755 --- a/src/test/java/org/springframework/data/repository/config/AnnotationRepositoryConfigurationSourceUnitTests.java +++ b/src/test/java/org/springframework/data/repository/config/AnnotationRepositoryConfigurationSourceUnitTests.java @@ -15,15 +15,15 @@ */ package org.springframework.data.repository.config; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.mockito.Mockito.mock; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.annotation.ComponentScan.Filter; @@ -184,6 +184,34 @@ void considerBeanNameGenerator() { assertThat(getConfigSource(DefaultConfiguration.class).generateBeanName(bd)).isEqualTo("personRepository"); } + @Test // GH-3279 + void considersDefaultFragmentsContributor() { + + RootBeanDefinition bd = new RootBeanDefinition(DummyRepositoryFactory.class); + bd.getConstructorArgumentValues().addGenericArgumentValue(PersonRepository.class); + + AnnotationMetadata metadata = new StandardAnnotationMetadata(ConfigurationWithFragmentsContributor.class, true); + AnnotationRepositoryConfigurationSource configurationSource = new AnnotationRepositoryConfigurationSource(metadata, + EnableRepositoriesWithContributor.class, resourceLoader, environment, registry, null); + + assertThat(configurationSource.getRepositoryFragmentsContributorClassName()) + .contains(SampleRepositoryFragmentsContributor.class.getName()); + } + + @Test // GH-3279 + void omitsUnspecifiedFragmentsContributor() { + + RootBeanDefinition bd = new RootBeanDefinition(DummyRepositoryFactory.class); + bd.getConstructorArgumentValues().addGenericArgumentValue(PersonRepository.class); + + AnnotationMetadata metadata = new StandardAnnotationMetadata(ReactiveConfigurationWithBeanNameGenerator.class, + true); + AnnotationRepositoryConfigurationSource configurationSource = new AnnotationRepositoryConfigurationSource(metadata, + EnableReactiveRepositories.class, resourceLoader, environment, registry, null); + + assertThat(configurationSource.getRepositoryFragmentsContributorClassName()).isEmpty(); + } + @Test // GH-3082 void considerBeanNameGeneratorForReactiveRepos() { @@ -219,6 +247,9 @@ static class ConfigurationWithExplicitFilter {} @EnableRepositories(nameGenerator = FullyQualifiedAnnotationBeanNameGenerator.class) static class ConfigurationWithBeanNameGenerator {} + @EnableRepositoriesWithContributor() + static class ConfigurationWithFragmentsContributor {} + @EnableReactiveRepositories(nameGenerator = FullyQualifiedAnnotationBeanNameGenerator.class) static class ReactiveConfigurationWithBeanNameGenerator {} @@ -234,4 +265,5 @@ static class ReactiveConfigurationWithBeanNameGenerator {} static class ConfigWithSampleAnnotation {} interface ReactivePersonRepository extends ReactiveCrudRepository {} + } diff --git a/src/test/java/org/springframework/data/repository/config/DummyRegistrarWithContributor.java b/src/test/java/org/springframework/data/repository/config/DummyRegistrarWithContributor.java new file mode 100644 index 0000000000..85708eb8fa --- /dev/null +++ b/src/test/java/org/springframework/data/repository/config/DummyRegistrarWithContributor.java @@ -0,0 +1,40 @@ +/* + * Copyright 2022-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.config; + +import java.lang.annotation.Annotation; + +import org.springframework.core.io.DefaultResourceLoader; + +/** + * @author Mark Paluch + */ +class DummyRegistrarWithContributor extends RepositoryBeanDefinitionRegistrarSupport { + + DummyRegistrarWithContributor() { + setResourceLoader(new DefaultResourceLoader()); + } + + @Override + protected Class getAnnotation() { + return EnableRepositoriesWithContributor.class; + } + + @Override + protected RepositoryConfigurationExtension getExtension() { + return new DummyConfigurationExtension(); + } +} diff --git a/src/test/java/org/springframework/data/repository/config/EnableRepositoriesWithContributor.java b/src/test/java/org/springframework/data/repository/config/EnableRepositoriesWithContributor.java new file mode 100644 index 0000000000..2c38047e00 --- /dev/null +++ b/src/test/java/org/springframework/data/repository/config/EnableRepositoriesWithContributor.java @@ -0,0 +1,61 @@ +/* + * Copyright 2012-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.config; + +import java.lang.annotation.Inherited; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import org.springframework.beans.factory.support.BeanNameGenerator; +import org.springframework.context.annotation.ComponentScan.Filter; +import org.springframework.context.annotation.Import; +import org.springframework.data.repository.PagingAndSortingRepository; +import org.springframework.data.repository.core.support.DummyRepositoryFactoryBean; +import org.springframework.data.repository.core.support.RepositoryFragmentsContributor; + +@Retention(RetentionPolicy.RUNTIME) +@Import(DummyRegistrarWithContributor.class) +@Inherited +public @interface EnableRepositoriesWithContributor { + + String[] value() default {}; + + String[] basePackages() default {}; + + Class[] basePackageClasses() default {}; + + Filter[] includeFilters() default {}; + + Filter[] excludeFilters() default {}; + + Class repositoryFactoryBeanClass() default DummyRepositoryFactoryBean.class; + + Class fragmentsContributor() default SampleRepositoryFragmentsContributor.class; + + Class repositoryBaseClass() default PagingAndSortingRepository.class; + + Class nameGenerator() default BeanNameGenerator.class; + + String namedQueriesLocation() default ""; + + String repositoryImplementationPostfix() default "Impl"; + + boolean considerNestedRepositories() default false; + + boolean limitImplementationBasePackages() default true; + + BootstrapMode bootstrapMode() default BootstrapMode.DEFAULT; +} diff --git a/src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java b/src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java index 54379865c1..13482da3f8 100644 --- a/src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java +++ b/src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java @@ -15,24 +15,29 @@ */ package org.springframework.data.repository.config; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; import org.junit.jupiter.api.Test; -import org.mockito.Mockito; + import org.springframework.aot.hint.RuntimeHints; import org.springframework.beans.factory.support.RegisteredBean; -import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.data.aot.sample.ConfigWithCustomImplementation; import org.springframework.data.aot.sample.ConfigWithCustomRepositoryBaseClass; import org.springframework.data.aot.sample.ConfigWithCustomRepositoryBaseClass.CustomerRepositoryWithCustomBaseRepo; +import org.springframework.data.aot.sample.ConfigWithFragments; import org.springframework.data.aot.sample.ConfigWithSimpleCrudRepository; +import org.springframework.data.aot.sample.ReactiveConfig; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; +import org.springframework.data.repository.core.support.RepositoryFragment; /** + * Unit tests for {@link RepositoryBeanDefinitionReader}. + * * @author Christoph Strobl + * @author Mark Paluch */ class RepositoryBeanDefinitionReaderTests { @@ -42,10 +47,10 @@ void readsSimpleConfigFromBeanFactory() { RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithSimpleCrudRepository.class); RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); - Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(ConfigWithSimpleCrudRepository.MyRepo.class.getName()); + when(repoConfig.getRepositoryInterface()).thenReturn(ConfigWithSimpleCrudRepository.MyRepo.class.getName()); - RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig, - repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory()); + RepositoryBeanDefinitionReader reader = new RepositoryBeanDefinitionReader(repoFactoryBean); + RepositoryInformation repositoryInformation = reader.getRepositoryInformation(); assertThat(repositoryInformation.getRepositoryInterface()).isEqualTo(ConfigWithSimpleCrudRepository.MyRepo.class); assertThat(repositoryInformation.getDomainType()).isEqualTo(ConfigWithSimpleCrudRepository.Person.class); @@ -59,49 +64,86 @@ void readsCustomRepoBaseClassFromBeanFactory() { RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); Class repositoryInterfaceType = CustomerRepositoryWithCustomBaseRepo.class; - Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName()); + when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName()); - RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig, - repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory()); + RepositoryBeanDefinitionReader reader = new RepositoryBeanDefinitionReader(repoFactoryBean); + RepositoryInformation repositoryInformation = reader.getRepositoryInformation(); assertThat(repositoryInformation.getRepositoryBaseClass()) .isEqualTo(ConfigWithCustomRepositoryBaseClass.RepoBaseClass.class); } @Test // GH-3279 - void readsFragmentsFromBeanFactory() { + void readsFragmentsContributorFromBeanDefinition() { - RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithCustomImplementation.class); + RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithCustomRepositoryBaseClass.class); + + RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); + Class repositoryInterfaceType = CustomerRepositoryWithCustomBaseRepo.class; + when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName()); + + RepositoryBeanDefinitionReader reader = new RepositoryBeanDefinitionReader(repoFactoryBean); + RepositoryInformation repositoryInformation = reader.getRepositoryInformation(); + + assertThat(repositoryInformation.getFragments()) + .contains(RepositoryFragment.structural(SampleRepositoryFragmentsContributor.class)); + } + + @Test // GH-3279 + void readsFragmentsContributorFromBeanFactory() { + + RegisteredBean repoFactoryBean = repositoryFactory(ReactiveConfig.class); + + RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); + Class repositoryInterfaceType = ReactiveConfig.CustomerRepositoryReactive.class; + when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName()); + RepositoryBeanDefinitionReader reader = new RepositoryBeanDefinitionReader(repoFactoryBean); + RepositoryInformation repositoryInformation = reader.getRepositoryInformation(); + + assertThat(repositoryInformation.getFragments()).isEmpty(); + } + + @Test // GH-3279, GH-3282 + void readsCustomImplementationFromBeanFactory() { + + RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithCustomImplementation.class); RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); + Class repositoryInterfaceType = ConfigWithCustomImplementation.RepositoryWithCustomImplementation.class; - Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName()); + when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName()); - RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig, - repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory()); + RepositoryBeanDefinitionReader reader = new RepositoryBeanDefinitionReader(repoFactoryBean); + RepositoryInformation repositoryInformation = reader.getRepositoryInformation(); assertThat(repositoryInformation.getFragments()).satisfiesExactly(fragment -> { - assertThat(fragment.getSignatureContributor()) - .isEqualTo(ConfigWithCustomImplementation.CustomImplInterface.class); + assertThat(fragment.getImplementationClass()) + .contains(ConfigWithCustomImplementation.RepositoryWithCustomImplementationImpl.class); }); } - @Test // GH-3279 - void fallsBackToModuleBaseClassIfSetAndNoRepoBaseDefined() { - - RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithSimpleCrudRepository.class); - RootBeanDefinition rootBeanDefinition = repoFactoryBean.getMergedBeanDefinition().cloneBeanDefinition(); - // need to unset because its defined as non default - rootBeanDefinition.getPropertyValues().removePropertyValue("repositoryBaseClass"); - rootBeanDefinition.getPropertyValues().add("moduleBaseClass", ModuleBase.class.getName()); + @Test // GH-3279, GH-3282 + void readsFragmentsFromBeanFactory() { + RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithFragments.class); RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); - Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(ConfigWithSimpleCrudRepository.MyRepo.class.getName()); - RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig, - rootBeanDefinition, repoFactoryBean.getBeanFactory()); + Class repositoryInterfaceType = ConfigWithFragments.RepositoryWithFragments.class; + when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName()); + + RepositoryBeanDefinitionReader reader = new RepositoryBeanDefinitionReader(repoFactoryBean); + RepositoryInformation repositoryInformation = reader.getRepositoryInformation(); - assertThat(repositoryInformation.getRepositoryBaseClass()).isEqualTo(ModuleBase.class); + assertThat(repositoryInformation.getFragments()).hasSize(2); + + for (RepositoryFragment fragment : repositoryInformation.getFragments()) { + + assertThat(fragment.getSignatureContributor()).isIn(ConfigWithFragments.CustomImplInterface1.class, + ConfigWithFragments.CustomImplInterface2.class); + + assertThat(fragment.getImplementationClass().get()).isIn(ConfigWithFragments.CustomImplInterface1Impl.class, + ConfigWithFragments.CustomImplInterface2Impl.class); + } } static RegisteredBean repositoryFactory(Class configClass) { @@ -118,5 +160,4 @@ static RegisteredBean repositoryFactory(Class configClass) { return RegisteredBean.of(applicationContext.getBeanFactory(), beanNamesForType[0]); } - static class ModuleBase {} } diff --git a/src/test/java/org/springframework/data/repository/config/SampleRepositoryFragmentsContributor.java b/src/test/java/org/springframework/data/repository/config/SampleRepositoryFragmentsContributor.java new file mode 100644 index 0000000000..a22db03b6a --- /dev/null +++ b/src/test/java/org/springframework/data/repository/config/SampleRepositoryFragmentsContributor.java @@ -0,0 +1,33 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.config; + +import org.springframework.data.repository.core.RepositoryMetadata; +import org.springframework.data.repository.core.support.RepositoryComposition; +import org.springframework.data.repository.core.support.RepositoryFragment; +import org.springframework.data.repository.core.support.RepositoryFragmentsContributor; + +/** + * @author Mark Paluch + */ +public class SampleRepositoryFragmentsContributor implements RepositoryFragmentsContributor { + + @Override + public RepositoryComposition.RepositoryFragments describe(RepositoryMetadata metadata) { + return RepositoryComposition.RepositoryFragments + .of(RepositoryFragment.structural(SampleRepositoryFragmentsContributor.class)); + } +} diff --git a/src/test/java/org/springframework/data/repository/core/support/DummyRepositoryFactoryBean.java b/src/test/java/org/springframework/data/repository/core/support/DummyRepositoryFactoryBean.java index 8a5b9b6d6d..4121d6f0c3 100644 --- a/src/test/java/org/springframework/data/repository/core/support/DummyRepositoryFactoryBean.java +++ b/src/test/java/org/springframework/data/repository/core/support/DummyRepositoryFactoryBean.java @@ -29,6 +29,7 @@ public class DummyRepositoryFactoryBean, S, ID exten extends RepositoryFactoryBeanSupport { private final T repository; + private RepositoryFragmentsContributor repositoryFragmentsContributor = RepositoryFragmentsContributor.empty(); public DummyRepositoryFactoryBean(Class repositoryInterface) { @@ -38,6 +39,19 @@ public DummyRepositoryFactoryBean(Class repositoryInterface) { setMappingContext(new SampleMappingContext()); } + public T getRepository() { + return repository; + } + + @Override + public RepositoryFragmentsContributor getRepositoryFragmentsContributor() { + return repositoryFragmentsContributor; + } + + public void setRepositoryFragmentsContributor(RepositoryFragmentsContributor repositoryFragmentsContributor) { + this.repositoryFragmentsContributor = repositoryFragmentsContributor; + } + @Override protected RepositoryFactorySupport createRepositoryFactory() { return new DummyRepositoryFactory(repository); diff --git a/src/test/java/org/springframework/data/repository/support/RepositoriesUnitTests.java b/src/test/java/org/springframework/data/repository/support/RepositoriesUnitTests.java index 22c0959811..cbb08cd940 100755 --- a/src/test/java/org/springframework/data/repository/support/RepositoriesUnitTests.java +++ b/src/test/java/org/springframework/data/repository/support/RepositoriesUnitTests.java @@ -46,6 +46,7 @@ import org.springframework.data.repository.core.support.DummyRepositoryFactoryBean; import org.springframework.data.repository.core.support.DummyRepositoryInformation; import org.springframework.data.repository.core.support.RepositoryFactoryInformation; +import org.springframework.data.repository.core.support.RepositoryFragmentsContributor; import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.util.TypeInformation; import org.springframework.util.ClassUtils; @@ -290,6 +291,11 @@ public RepositoryInformation getRepositoryInformation() { return new DummyRepositoryInformation(repositoryMetadata); } + @Override + public RepositoryFragmentsContributor getRepositoryFragmentsContributor() { + return RepositoryFragmentsContributor.empty(); + } + @Override public PersistentEntity getPersistentEntity() { return mappingContext.getRequiredPersistentEntity(repositoryMetadata.getDomainType()); From 01c3d212c619662631972a161dc6b46edb14dabe Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 7 May 2025 12:29:14 +0200 Subject: [PATCH 5/6] Polishing. Introduce dedicated contributor interfaces to AotRepositoryBuilder. --- .../aot/generate/AotRepositoryBuilder.java | 70 ++++++++++++++----- 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java index 7ca6b536e4..ebf1760cf1 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java @@ -22,14 +22,13 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; -import java.util.function.Consumer; import javax.lang.model.element.Modifier; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.jspecify.annotations.Nullable; + import org.springframework.aot.generate.ClassNameGenerator; import org.springframework.aot.generate.Generated; import org.springframework.data.projection.ProjectionFactory; @@ -58,8 +57,8 @@ class AotRepositoryBuilder { private final ProjectionFactory projectionFactory; private final AotRepositoryFragmentMetadata generationMetadata; - private @Nullable Consumer constructorCustomizer; - private @Nullable BiFunction> methodContributorFunction; + private @Nullable ConstructorCustomizer constructorCustomizer; + private @Nullable MethodContributorFactory methodContributorFactory; private ClassCustomizer customizer; private AotRepositoryBuilder(RepositoryInformation repositoryInformation, String moduleName, @@ -109,23 +108,21 @@ public AotRepositoryBuilder withClassCustomizer(ClassCustomizer classCustomizer) * @param constructorCustomizer must not be {@literal null}. * @return {@code this}. */ - public AotRepositoryBuilder withConstructorCustomizer( - Consumer constructorCustomizer) { + public AotRepositoryBuilder withConstructorCustomizer(ConstructorCustomizer constructorCustomizer) { this.constructorCustomizer = constructorCustomizer; return this; } /** - * Configure a {@link MethodContributor}. + * Configure a {@link MethodContributor} factory. * - * @param methodContributorFunction must not be {@literal null}. + * @param methodContributorFactory must not be {@literal null}. * @return {@code this}. */ - public AotRepositoryBuilder withQueryMethodContributor( - BiFunction> methodContributorFunction) { + public AotRepositoryBuilder withQueryMethodContributor(MethodContributorFactory methodContributorFactory) { - this.methodContributorFunction = methodContributorFunction; + this.methodContributorFactory = methodContributorFactory; return this; } @@ -170,7 +167,7 @@ private MethodSpec buildConstructor() { generationMetadata); if (constructorCustomizer != null) { - constructorCustomizer.accept(constructorBuilder); + constructorCustomizer.customize(constructorBuilder); } return constructorBuilder.buildConstructor(); @@ -191,7 +188,8 @@ private AotRepositoryMetadata getAotRepositoryMetadata(List private void contributeMethod(Method method, RepositoryComposition repositoryComposition, List methodMetadata, TypeSpec.Builder builder) { - if (repositoryInformation.isCustomMethod(method) || (repositoryInformation.isBaseClassMethod(method) && !repositoryInformation.isQueryMethod(method))) { + if (repositoryInformation.isCustomMethod(method) + || (repositoryInformation.isBaseClassMethod(method) && !repositoryInformation.isQueryMethod(method))) { RepositoryFragment fragment = repositoryComposition.findFragment(method); @@ -205,9 +203,9 @@ private void contributeMethod(Method method, RepositoryComposition repositoryCom return; } - if (repositoryInformation.isQueryMethod(method) && methodContributorFunction != null) { + if (repositoryInformation.isQueryMethod(method) && methodContributorFactory != null) { - MethodContributor contributor = methodContributorFunction.apply(method, + MethodContributor contributor = methodContributorFactory.create(method, repositoryInformation); if (contributor != null) { @@ -273,16 +271,50 @@ public ProjectionFactory getProjectionFactory() { public interface ClassCustomizer { /** - * Apply customization ot the AOT repository fragment class after it has been defined.. + * Apply customization ot the AOT repository fragment class after it has been defined. * - * @param information - * @param metadata - * @param builder + * @param information repository information. + * @param metadata metadata of the AOT repository fragment. + * @param builder the actual builder. */ void customize(RepositoryInformation information, AotRepositoryFragmentMetadata metadata, TypeSpec.Builder builder); } + /** + * Customizer interface to customize the AOT repository fragment constructor through + * {@link AotRepositoryConstructorBuilder}. + */ + public interface ConstructorCustomizer { + + /** + * Apply customization ot the AOT repository fragment constructor. + * + * @param constructorBuilder the builder to be customized. + */ + void customize(AotRepositoryConstructorBuilder constructorBuilder); + + } + + /** + * Factory interface to conditionally create {@link MethodContributor} instances. An implementation may decide whether + * to return a {@link MethodContributor} or {@literal null}, if no method (code or metadata) should be contributed. + */ + public interface MethodContributorFactory { + + /** + * Apply customization ot the AOT repository fragment constructor. + * + * @param method the method to be contributed. + * @param information repository information. + * @return the {@link MethodContributor} to be used. Can be {@literal null} if the method and method metadata should + * not be contributed. + */ + @Nullable + MethodContributor create(Method method, RepositoryInformation information); + + } + record AotBundle(JavaFile javaFile, AotRepositoryMetadata metadata) { } From cbed5c3ee00ddbec27f40ac40172fa8dbf7a9139 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 7 May 2025 15:52:23 +0200 Subject: [PATCH 6/6] Polishing. --- .../data/repository/aot/generate/AotRepositoryBuilder.java | 2 +- .../repository/aot/generate/RepositoryContributor.java | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java index ebf1760cf1..d26fd21f37 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java @@ -179,7 +179,7 @@ private AotRepositoryMetadata getAotRepositoryMetadata(List ? AotRepositoryMetadata.RepositoryType.REACTIVE : AotRepositoryMetadata.RepositoryType.IMPERATIVE; - String jsonModuleName = moduleName.replaceAll("Reactive", "").trim(); + String jsonModuleName = moduleName != null ? moduleName.replaceAll("Reactive", "").trim() : null; return new AotRepositoryMetadata(repositoryInformation.getRepositoryInterface().getName(), jsonModuleName, repositoryType, methodMetadata); diff --git a/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java b/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java index bcfc9d7a16..6c66c7ffee 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java @@ -32,7 +32,6 @@ import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.TypeName; import org.springframework.javapoet.TypeSpec; -import org.springframework.util.StringUtils; /** * Contributor for AOT repository fragments. @@ -131,11 +130,7 @@ public void contribute(GenerationContext generationContext) { } private static String getRepositoryJsonFileName(Class repositoryInterface) { - - String repositoryJsonName = repositoryInterface.getSimpleName() + ".json"; - String repositoryJsonPath = repositoryInterface.getPackageName().replace('.', '/'); - - return StringUtils.hasText(repositoryJsonPath) ? repositoryJsonPath + "/" + repositoryJsonName : repositoryJsonName; + return repositoryInterface.getName().replace('.', '/') + ".json"; } /**