diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContext.java b/src/main/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContext.java index 076a66c8c8..4a2fe26999 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContext.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContext.java @@ -293,4 +293,22 @@ public String getLimitParameterName() { return getParameterName(queryMethod.getParameters().getLimitIndex()); } + /** + * @return the parameter name for the {@link org.springframework.data.domain.ScrollPosition scroll position parameter} + * or {@code null} if the method does not declare a scroll position parameter. + */ + @Nullable + public String getScrollPositionParameterName() { + return getParameterName(queryMethod.getParameters().getScrollPositionIndex()); + } + + /** + * @return the parameter name for the {@link Class dynamic projection parameter} or {@code null} if the method does + * not declare a dynamic projection parameter. + */ + @Nullable + public String getDynamicProjectionParameterName() { + return getParameterName(queryMethod.getParameters().getDynamicProjectionIndex()); + } + } diff --git a/src/main/java/org/springframework/data/repository/aot/generate/MethodMetadata.java b/src/main/java/org/springframework/data/repository/aot/generate/MethodMetadata.java index dd9885933c..a169b02629 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/MethodMetadata.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/MethodMetadata.java @@ -48,18 +48,25 @@ class MethodMetadata { MethodMetadata(RepositoryInformation repositoryInformation, Method method) { this.returnType = repositoryInformation.getReturnType(method).toResolvableType(); - this.actualReturnType = ResolvableType.forType(repositoryInformation.getReturnedDomainClass(method)); + this.actualReturnType = repositoryInformation.getReturnedDomainTypeInformation(method).toResolvableType(); this.initParameters(repositoryInformation, method, new DefaultParameterNameDiscoverer()); } - @Nullable - public String getParameterNameOf(Class type) { - for (Entry entry : methodArguments.entrySet()) { - if (entry.getValue().type.equals(TypeName.get(type))) { - return entry.getKey(); - } + private void initParameters(RepositoryInformation repositoryInformation, Method method, + ParameterNameDiscoverer nameDiscoverer) { + + ResolvableType repositoryInterface = ResolvableType.forClass(repositoryInformation.getRepositoryInterface()); + + for (java.lang.reflect.Parameter parameter : method.getParameters()) { + + MethodParameter methodParameter = MethodParameter.forParameter(parameter); + methodParameter.initParameterNameDiscovery(nameDiscoverer); + ResolvableType resolvableParameterType = ResolvableType.forMethodParameter(methodParameter, repositoryInterface); + + TypeName parameterType = TypeName.get(resolvableParameterType.getType()); + + addParameter(ParameterSpec.builder(parameterType, methodParameter.getParameterName()).build()); } - return null; } ResolvableType getReturnType() { @@ -96,20 +103,4 @@ Map getLocalVariables() { return localVariables; } - private void initParameters(RepositoryInformation repositoryInformation, Method method, - ParameterNameDiscoverer nameDiscoverer) { - - ResolvableType repositoryInterface = ResolvableType.forClass(repositoryInformation.getRepositoryInterface()); - - for (java.lang.reflect.Parameter parameter : method.getParameters()) { - - MethodParameter methodParameter = MethodParameter.forParameter(parameter); - methodParameter.initParameterNameDiscovery(nameDiscoverer); - ResolvableType resolvableParameterType = ResolvableType.forMethodParameter(methodParameter, repositoryInterface); - - TypeName parameterType = TypeName.get(resolvableParameterType.getType()); - - addParameter(ParameterSpec.builder(parameterType, methodParameter.getParameterName()).build()); - } - } } diff --git a/src/main/java/org/springframework/data/repository/config/AotRepositoryContext.java b/src/main/java/org/springframework/data/repository/config/AotRepositoryContext.java index 231e7bba18..0a89486886 100644 --- a/src/main/java/org/springframework/data/repository/config/AotRepositoryContext.java +++ b/src/main/java/org/springframework/data/repository/config/AotRepositoryContext.java @@ -45,10 +45,17 @@ public interface AotRepositoryContext extends AotContext { */ String getModuleName(); + /** + * @return the repository configuration source. + */ + RepositoryConfigurationSource getConfigurationSource(); + /** * @return a {@link Set} of {@link String base packages} to search for repositories. */ - Set getBasePackages(); + default Set getBasePackages() { + return getConfigurationSource().getBasePackages().toSet(); + } /** * @return the {@link Annotation} types used to identify domain types. diff --git a/src/main/java/org/springframework/data/repository/config/DefaultAotRepositoryContext.java b/src/main/java/org/springframework/data/repository/config/DefaultAotRepositoryContext.java index 5f695f6276..ccf60a01e5 100644 --- a/src/main/java/org/springframework/data/repository/config/DefaultAotRepositoryContext.java +++ b/src/main/java/org/springframework/data/repository/config/DefaultAotRepositoryContext.java @@ -46,6 +46,7 @@ class DefaultAotRepositoryContext implements AotRepositoryContext { private final RegisteredBean bean; private final String moduleName; + private final RepositoryConfigurationSource configurationSource; private final AotContext aotContext; private final RepositoryInformation repositoryInformation; private final Lazy>> resolvedAnnotations = Lazy.of(this::discoverAnnotations); @@ -56,12 +57,14 @@ class DefaultAotRepositoryContext implements AotRepositoryContext { private String beanName; public DefaultAotRepositoryContext(RegisteredBean bean, RepositoryInformation repositoryInformation, - String moduleName, AotContext aotContext) { + String moduleName, AotContext aotContext, RepositoryConfigurationSource configurationSource) { this.bean = bean; this.repositoryInformation = repositoryInformation; this.moduleName = moduleName; + this.configurationSource = configurationSource; this.aotContext = aotContext; this.beanName = bean.getBeanName(); + this.basePackages = configurationSource.getBasePackages().toSet(); } public AotContext getAotContext() { @@ -73,6 +76,11 @@ public String getModuleName() { return moduleName; } + @Override + public RepositoryConfigurationSource getConfigurationSource() { + return configurationSource; + } + @Override public ConfigurableListableBeanFactory getBeanFactory() { return getAotContext().getBeanFactory(); diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java b/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java index feddf13e5c..34b0caff52 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java @@ -194,9 +194,9 @@ private void logTrace(String message, Object... arguments) { } RepositoryInformation repositoryInformation = reader.getRepositoryInformation(); DefaultAotRepositoryContext repositoryContext = new DefaultAotRepositoryContext(bean, repositoryInformation, - extension.getModuleName(), AotContext.from(bean.getBeanFactory(), environment)); + extension.getModuleName(), AotContext.from(bean.getBeanFactory(), environment), + configuration.getConfigurationSource()); - repositoryContext.setBasePackages(repositoryConfiguration.getBasePackages().toSet()); repositoryContext.setIdentifyingAnnotations(extension.getIdentifyingAnnotations()); return repositoryContext; diff --git a/src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java b/src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java index 94660dec28..6bd719335e 100644 --- a/src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java +++ b/src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java @@ -89,6 +89,11 @@ public Class getReturnedDomainClass(Method method) { return getMetadata().getReturnedDomainClass(method); } + @Override + public TypeInformation getReturnedDomainTypeInformation(Method method) { + return getMetadata().getReturnedDomainTypeInformation(method); + } + @Override public CrudMethods getCrudMethods() { return getMetadata().getCrudMethods(); diff --git a/src/main/java/org/springframework/data/repository/core/RepositoryMetadata.java b/src/main/java/org/springframework/data/repository/core/RepositoryMetadata.java index 43c89a6763..bca328fb56 100644 --- a/src/main/java/org/springframework/data/repository/core/RepositoryMetadata.java +++ b/src/main/java/org/springframework/data/repository/core/RepositoryMetadata.java @@ -91,10 +91,26 @@ default Class getDomainType() { * * @param method * @return + * @see #getReturnedDomainTypeInformation(Method) * @see #getReturnType(Method) */ Class getReturnedDomainClass(Method method); + /** + * Returns the domain type information returned by the given {@link Method}. In contrast to + * {@link #getReturnType(Method)}, this method extracts the type from {@link Collection}s and + * {@link org.springframework.data.domain.Page} as well. + * + * @param method + * @return + * @see #getReturnedDomainClass(Method) + * @see #getReturnType(Method) + * @since 4.0 + */ + default TypeInformation getReturnedDomainTypeInformation(Method method) { + return TypeInformation.of(getReturnedDomainClass(method)); + } + /** * Returns {@link CrudMethods} meta information for the repository. * diff --git a/src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java b/src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java index cb91cc41d1..a843053844 100644 --- a/src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java +++ b/src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java @@ -100,11 +100,16 @@ public TypeInformation getReturnType(Method method) { @Override public Class getReturnedDomainClass(Method method) { + return getReturnedDomainTypeInformation(method).getType(); + } + + @Override + public TypeInformation getReturnedDomainTypeInformation(Method method) { TypeInformation returnType = getReturnType(method); returnType = ReactiveWrapperConverters.unwrapWrapperTypes(returnType); - return QueryExecutionConverters.unwrapWrapperTypes(returnType, getDomainTypeInformation()).getType(); + return QueryExecutionConverters.unwrapWrapperTypes(returnType, getDomainTypeInformation()); } @Override diff --git a/src/main/java/org/springframework/data/repository/query/QueryMethod.java b/src/main/java/org/springframework/data/repository/query/QueryMethod.java index 25fd18f12a..b09c706345 100644 --- a/src/main/java/org/springframework/data/repository/query/QueryMethod.java +++ b/src/main/java/org/springframework/data/repository/query/QueryMethod.java @@ -403,7 +403,8 @@ private static void assertReturnTypeAssignable(Method method, Set> type } } - throw new IllegalStateException("Method has to have one of the following return types " + types); + throw new IllegalStateException( + "Method '%s' has to have one of the following return types: %s".formatted(method, types)); } static class QueryMethodValidator { diff --git a/src/test/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContextUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContextUnitTests.java index 8f63192006..e44f5c40fa 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContextUnitTests.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContextUnitTests.java @@ -23,8 +23,16 @@ import org.junit.jupiter.api.Test; import org.mockito.Mockito; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.Window; +import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.repository.Repository; import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.core.support.AbstractRepositoryMetadata; +import org.springframework.data.repository.query.DefaultParameters; import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.util.TypeInformation; @@ -45,6 +53,28 @@ void suggestLocalVariableNameConsidersMethodArguments() throws NoSuchMethodExcep assertThat(ctx.localVariable("arg0")).isNotIn("arg0", "arg1", "arg2"); } + @Test // GH-3270 + void returnsCorrectParameterNames() throws NoSuchMethodException { + + AotQueryMethodGenerationContext ctx = ctxFor("limitScrollPositionDynamicProjection"); + + assertThat(ctx.getLimitParameterName()).isEqualTo("l"); + assertThat(ctx.getPageableParameterName()).isNull(); + assertThat(ctx.getScrollPositionParameterName()).isEqualTo("sp"); + assertThat(ctx.getDynamicProjectionParameterName()).isEqualTo("projection"); + } + + @Test // GH-3270 + void returnsCorrectParameterNameForPageable() throws NoSuchMethodException { + + AotQueryMethodGenerationContext ctx = ctxFor("pageable"); + + assertThat(ctx.getLimitParameterName()).isNull(); + assertThat(ctx.getPageableParameterName()).isEqualTo("p"); + assertThat(ctx.getScrollPositionParameterName()).isNull(); + assertThat(ctx.getDynamicProjectionParameterName()).isNull(); + } + AotQueryMethodGenerationContext ctxFor(String methodName) throws NoSuchMethodException { Method target = null; @@ -60,13 +90,21 @@ AotQueryMethodGenerationContext ctxFor(String methodName) throws NoSuchMethodExc } RepositoryInformation ri = Mockito.mock(RepositoryInformation.class); - Mockito.doReturn(TypeInformation.of(target.getReturnType())).when(ri).getReturnType(eq(target)); + Mockito.doReturn(TypeInformation.of(String.class)).when(ri).getReturnType(eq(target)); + Mockito.doReturn(TypeInformation.of(String.class)).when(ri).getReturnedDomainTypeInformation(eq(target)); - return new AotQueryMethodGenerationContext(ri, target, Mockito.mock(QueryMethod.class), + return new AotQueryMethodGenerationContext(ri, target, + new QueryMethod(target, AbstractRepositoryMetadata.getMetadata(DummyRepo.class), + new SpelAwareProxyProjectionFactory(), DefaultParameters::new), Mockito.mock(AotRepositoryFragmentMetadata.class)); } - private interface DummyRepo { - String reservedParameterMethod(Object arg0, Pageable arg1, Object arg2); + private interface DummyRepo extends Repository { + + Page reservedParameterMethod(Object arg0, Pageable arg1, Object arg2); + + Window limitScrollPositionDynamicProjection(Limit l, ScrollPosition sp, Class projection); + + Page pageable(Pageable p); } } diff --git a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java index b0f19b807e..a80559ebe8 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java @@ -15,10 +15,9 @@ */ package org.springframework.data.repository.aot.generate; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.when; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; import example.UserRepository; import example.UserRepository.User; @@ -29,6 +28,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mockito; + import org.springframework.core.ResolvableType; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.util.TypeInformation; @@ -58,6 +58,7 @@ void generatesMethodSkeletonBasedOnGenerationMetadata() throws NoSuchMethodExcep when(methodGenerationContext.getMethod()).thenReturn(method); when(methodGenerationContext.getReturnType()).thenReturn(ResolvableType.forClass(User.class)); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); + doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any()); MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method); methodMetadata.addParameter(ParameterSpec.builder(String.class, "firstname").build()); when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata); @@ -75,6 +76,7 @@ void generatesMethodWithGenerics() throws NoSuchMethodException { when(methodGenerationContext.getReturnType()) .thenReturn(ResolvableType.forClassWithGenerics(List.class, User.class)); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); + doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any()); MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method); methodMetadata .addParameter(ParameterSpec.builder(ParameterizedTypeName.get(List.class, String.class), "firstnames").build()); diff --git a/src/test/java/org/springframework/data/repository/aot/generate/DummyModuleAotRepositoryContext.java b/src/test/java/org/springframework/data/repository/aot/generate/DummyModuleAotRepositoryContext.java index 8c05276a9a..cebffa4a94 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/DummyModuleAotRepositoryContext.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/DummyModuleAotRepositoryContext.java @@ -26,6 +26,7 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.test.tools.ClassFile; import org.springframework.data.repository.config.AotRepositoryContext; +import org.springframework.data.repository.config.RepositoryConfigurationSource; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.support.RepositoryComposition; import org.springframework.lang.Nullable; @@ -48,6 +49,11 @@ public String getModuleName() { return "Commons"; } + @Override + public RepositoryConfigurationSource getConfigurationSource() { + return null; + } + @Override public ConfigurableListableBeanFactory getBeanFactory() { return null; diff --git a/src/test/java/org/springframework/data/repository/aot/generate/MethodMetadataUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/MethodMetadataUnitTests.java index 8cc981251a..1e7161d99c 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/MethodMetadataUnitTests.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/MethodMetadataUnitTests.java @@ -15,13 +15,14 @@ */ package org.springframework.data.repository.aot.generate; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.eq; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; import java.lang.reflect.Method; import org.junit.jupiter.api.Test; import org.mockito.Mockito; + import org.springframework.data.domain.Pageable; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.util.TypeInformation; @@ -73,6 +74,7 @@ static MethodMetadata methodMetadataFor(String methodName) throws NoSuchMethodEx RepositoryInformation ri = Mockito.mock(RepositoryInformation.class); Mockito.doReturn(TypeInformation.of(target.getReturnType())).when(ri).getReturnType(eq(target)); + Mockito.doReturn(TypeInformation.of(target.getReturnType())).when(ri).getReturnedDomainTypeInformation(eq(target)); return new MethodMetadata(ri, target); } diff --git a/src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java b/src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java index 71d9419159..733eba5a78 100755 --- a/src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java +++ b/src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java @@ -27,6 +27,8 @@ import org.junit.jupiter.api.DynamicTest; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestFactory; + +import org.springframework.core.ResolvableType; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.querydsl.User; @@ -60,6 +62,16 @@ void resolvesTypeParameterReturnType() throws Exception { assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(User.class); } + @Test // GH-3270 + void detectsProjectionTypeCorrectly() throws Exception { + + RepositoryMetadata metadata = new DefaultRepositoryMetadata(ExtendingRepository.class); + Method method = ExtendingRepository.class.getMethod("findByFirstname", Pageable.class, String.class, Class.class); + + ResolvableType resolvableType = metadata.getReturnedDomainTypeInformation(method).toResolvableType(); + assertThat(resolvableType.getType()).hasToString("T"); + } + @Test // DATACMNS-98 void determinesReturnTypeFromPageable() throws Exception { @@ -153,6 +165,8 @@ interface ExtendingRepository extends Serializable, UserRepository { Page findByFirstname(Pageable pageable, String firstname); + Page findByFirstname(Pageable pageable, String firstname, Class projectionType); + GenericType someMethod(); List> anotherMethod();