Skip to content

Add support for ConfigurationSource and Dynamic Projections #3289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -293,4 +293,22 @@ public String getLimitParameterName() {
return getParameterName(queryMethod.getParameters().getLimitIndex());
}

/**
* @return the parameter name for the {@link org.springframework.data.domain.ScrollPosition scroll position parameter}
* or {@code null} if the method does not declare a scroll position parameter.
*/
@Nullable
public String getScrollPositionParameterName() {
return getParameterName(queryMethod.getParameters().getScrollPositionIndex());
}

/**
* @return the parameter name for the {@link Class dynamic projection parameter} or {@code null} if the method does
* not declare a dynamic projection parameter.
*/
@Nullable
public String getDynamicProjectionParameterName() {
return getParameterName(queryMethod.getParameters().getDynamicProjectionIndex());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,25 @@ class MethodMetadata {
MethodMetadata(RepositoryInformation repositoryInformation, Method method) {

this.returnType = repositoryInformation.getReturnType(method).toResolvableType();
this.actualReturnType = ResolvableType.forType(repositoryInformation.getReturnedDomainClass(method));
this.actualReturnType = repositoryInformation.getReturnedDomainTypeInformation(method).toResolvableType();
this.initParameters(repositoryInformation, method, new DefaultParameterNameDiscoverer());
}

@Nullable
public String getParameterNameOf(Class<?> type) {
for (Entry<String, ParameterSpec> entry : methodArguments.entrySet()) {
if (entry.getValue().type.equals(TypeName.get(type))) {
return entry.getKey();
}
private void initParameters(RepositoryInformation repositoryInformation, Method method,
ParameterNameDiscoverer nameDiscoverer) {

ResolvableType repositoryInterface = ResolvableType.forClass(repositoryInformation.getRepositoryInterface());

for (java.lang.reflect.Parameter parameter : method.getParameters()) {

MethodParameter methodParameter = MethodParameter.forParameter(parameter);
methodParameter.initParameterNameDiscovery(nameDiscoverer);
ResolvableType resolvableParameterType = ResolvableType.forMethodParameter(methodParameter, repositoryInterface);

TypeName parameterType = TypeName.get(resolvableParameterType.getType());

addParameter(ParameterSpec.builder(parameterType, methodParameter.getParameterName()).build());
}
return null;
}

ResolvableType getReturnType() {
Expand Down Expand Up @@ -96,20 +103,4 @@ Map<String, String> getLocalVariables() {
return localVariables;
}

private void initParameters(RepositoryInformation repositoryInformation, Method method,
ParameterNameDiscoverer nameDiscoverer) {

ResolvableType repositoryInterface = ResolvableType.forClass(repositoryInformation.getRepositoryInterface());

for (java.lang.reflect.Parameter parameter : method.getParameters()) {

MethodParameter methodParameter = MethodParameter.forParameter(parameter);
methodParameter.initParameterNameDiscovery(nameDiscoverer);
ResolvableType resolvableParameterType = ResolvableType.forMethodParameter(methodParameter, repositoryInterface);

TypeName parameterType = TypeName.get(resolvableParameterType.getType());

addParameter(ParameterSpec.builder(parameterType, methodParameter.getParameterName()).build());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,17 @@ public interface AotRepositoryContext extends AotContext {
*/
String getModuleName();

/**
* @return the repository configuration source.
*/
RepositoryConfigurationSource getConfigurationSource();

/**
* @return a {@link Set} of {@link String base packages} to search for repositories.
*/
Set<String> getBasePackages();
default Set<String> getBasePackages() {
return getConfigurationSource().getBasePackages().toSet();
}

/**
* @return the {@link Annotation} types used to identify domain types.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class DefaultAotRepositoryContext implements AotRepositoryContext {

private final RegisteredBean bean;
private final String moduleName;
private final RepositoryConfigurationSource configurationSource;
private final AotContext aotContext;
private final RepositoryInformation repositoryInformation;
private final Lazy<Set<MergedAnnotation<Annotation>>> resolvedAnnotations = Lazy.of(this::discoverAnnotations);
Expand All @@ -56,12 +57,14 @@ class DefaultAotRepositoryContext implements AotRepositoryContext {
private String beanName;

public DefaultAotRepositoryContext(RegisteredBean bean, RepositoryInformation repositoryInformation,
String moduleName, AotContext aotContext) {
String moduleName, AotContext aotContext, RepositoryConfigurationSource configurationSource) {
this.bean = bean;
this.repositoryInformation = repositoryInformation;
this.moduleName = moduleName;
this.configurationSource = configurationSource;
this.aotContext = aotContext;
this.beanName = bean.getBeanName();
this.basePackages = configurationSource.getBasePackages().toSet();
}

public AotContext getAotContext() {
Expand All @@ -73,6 +76,11 @@ public String getModuleName() {
return moduleName;
}

@Override
public RepositoryConfigurationSource getConfigurationSource() {
return configurationSource;
}

@Override
public ConfigurableListableBeanFactory getBeanFactory() {
return getAotContext().getBeanFactory();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,9 @@ private void logTrace(String message, Object... arguments) {
}
RepositoryInformation repositoryInformation = reader.getRepositoryInformation();
DefaultAotRepositoryContext repositoryContext = new DefaultAotRepositoryContext(bean, repositoryInformation,
extension.getModuleName(), AotContext.from(bean.getBeanFactory(), environment));
extension.getModuleName(), AotContext.from(bean.getBeanFactory(), environment),
configuration.getConfigurationSource());

repositoryContext.setBasePackages(repositoryConfiguration.getBasePackages().toSet());
repositoryContext.setIdentifyingAnnotations(extension.getIdentifyingAnnotations());

return repositoryContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ public Class<?> getReturnedDomainClass(Method method) {
return getMetadata().getReturnedDomainClass(method);
}

@Override
public TypeInformation<?> getReturnedDomainTypeInformation(Method method) {
return getMetadata().getReturnedDomainTypeInformation(method);
}

@Override
public CrudMethods getCrudMethods() {
return getMetadata().getCrudMethods();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,26 @@ default Class<?> getDomainType() {
*
* @param method
* @return
* @see #getReturnedDomainTypeInformation(Method)
* @see #getReturnType(Method)
*/
Class<?> getReturnedDomainClass(Method method);

/**
* Returns the domain type information returned by the given {@link Method}. In contrast to
* {@link #getReturnType(Method)}, this method extracts the type from {@link Collection}s and
* {@link org.springframework.data.domain.Page} as well.
*
* @param method
* @return
* @see #getReturnedDomainClass(Method)
* @see #getReturnType(Method)
* @since 4.0
*/
default TypeInformation<?> getReturnedDomainTypeInformation(Method method) {
return TypeInformation.of(getReturnedDomainClass(method));
}

/**
* Returns {@link CrudMethods} meta information for the repository.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,16 @@ public TypeInformation<?> getReturnType(Method method) {

@Override
public Class<?> getReturnedDomainClass(Method method) {
return getReturnedDomainTypeInformation(method).getType();
}

@Override
public TypeInformation<?> getReturnedDomainTypeInformation(Method method) {

TypeInformation<?> returnType = getReturnType(method);
returnType = ReactiveWrapperConverters.unwrapWrapperTypes(returnType);

return QueryExecutionConverters.unwrapWrapperTypes(returnType, getDomainTypeInformation()).getType();
return QueryExecutionConverters.unwrapWrapperTypes(returnType, getDomainTypeInformation());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ private static void assertReturnTypeAssignable(Method method, Set<Class<?>> type
}
}

throw new IllegalStateException("Method has to have one of the following return types " + types);
throw new IllegalStateException(
"Method '%s' has to have one of the following return types: %s".formatted(method, types));
}

static class QueryMethodValidator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,16 @@
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.ScrollPosition;
import org.springframework.data.domain.Window;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
import org.springframework.data.repository.Repository;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.AbstractRepositoryMetadata;
import org.springframework.data.repository.query.DefaultParameters;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.util.TypeInformation;

Expand All @@ -45,6 +53,28 @@ void suggestLocalVariableNameConsidersMethodArguments() throws NoSuchMethodExcep
assertThat(ctx.localVariable("arg0")).isNotIn("arg0", "arg1", "arg2");
}

@Test // GH-3270
void returnsCorrectParameterNames() throws NoSuchMethodException {

AotQueryMethodGenerationContext ctx = ctxFor("limitScrollPositionDynamicProjection");

assertThat(ctx.getLimitParameterName()).isEqualTo("l");
assertThat(ctx.getPageableParameterName()).isNull();
assertThat(ctx.getScrollPositionParameterName()).isEqualTo("sp");
assertThat(ctx.getDynamicProjectionParameterName()).isEqualTo("projection");
}

@Test // GH-3270
void returnsCorrectParameterNameForPageable() throws NoSuchMethodException {

AotQueryMethodGenerationContext ctx = ctxFor("pageable");

assertThat(ctx.getLimitParameterName()).isNull();
assertThat(ctx.getPageableParameterName()).isEqualTo("p");
assertThat(ctx.getScrollPositionParameterName()).isNull();
assertThat(ctx.getDynamicProjectionParameterName()).isNull();
}

AotQueryMethodGenerationContext ctxFor(String methodName) throws NoSuchMethodException {

Method target = null;
Expand All @@ -60,13 +90,21 @@ AotQueryMethodGenerationContext ctxFor(String methodName) throws NoSuchMethodExc
}

RepositoryInformation ri = Mockito.mock(RepositoryInformation.class);
Mockito.doReturn(TypeInformation.of(target.getReturnType())).when(ri).getReturnType(eq(target));
Mockito.doReturn(TypeInformation.of(String.class)).when(ri).getReturnType(eq(target));
Mockito.doReturn(TypeInformation.of(String.class)).when(ri).getReturnedDomainTypeInformation(eq(target));

return new AotQueryMethodGenerationContext(ri, target, Mockito.mock(QueryMethod.class),
return new AotQueryMethodGenerationContext(ri, target,
new QueryMethod(target, AbstractRepositoryMetadata.getMetadata(DummyRepo.class),
new SpelAwareProxyProjectionFactory(), DefaultParameters::new),
Mockito.mock(AotRepositoryFragmentMetadata.class));
}

private interface DummyRepo {
String reservedParameterMethod(Object arg0, Pageable arg1, Object arg2);
private interface DummyRepo extends Repository<String, Long> {

Page<String> reservedParameterMethod(Object arg0, Pageable arg1, Object arg2);

<T> Window<T> limitScrollPositionDynamicProjection(Limit l, ScrollPosition sp, Class<T> projection);

Page<String> pageable(Pageable p);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
*/
package org.springframework.data.repository.aot.generate;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.when;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;

import example.UserRepository;
import example.UserRepository.User;
Expand All @@ -29,6 +28,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import org.springframework.core.ResolvableType;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.util.TypeInformation;
Expand Down Expand Up @@ -58,6 +58,7 @@ void generatesMethodSkeletonBasedOnGenerationMetadata() throws NoSuchMethodExcep
when(methodGenerationContext.getMethod()).thenReturn(method);
when(methodGenerationContext.getReturnType()).thenReturn(ResolvableType.forClass(User.class));
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any());
MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method);
methodMetadata.addParameter(ParameterSpec.builder(String.class, "firstname").build());
when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata);
Expand All @@ -75,6 +76,7 @@ void generatesMethodWithGenerics() throws NoSuchMethodException {
when(methodGenerationContext.getReturnType())
.thenReturn(ResolvableType.forClassWithGenerics(List.class, User.class));
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any());
MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method);
methodMetadata
.addParameter(ParameterSpec.builder(ParameterizedTypeName.get(List.class, String.class), "firstnames").build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.test.tools.ClassFile;
import org.springframework.data.repository.config.AotRepositoryContext;
import org.springframework.data.repository.config.RepositoryConfigurationSource;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.RepositoryComposition;
import org.springframework.lang.Nullable;
Expand All @@ -48,6 +49,11 @@ public String getModuleName() {
return "Commons";
}

@Override
public RepositoryConfigurationSource getConfigurationSource() {
return null;
}

@Override
public ConfigurableListableBeanFactory getBeanFactory() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
*/
package org.springframework.data.repository.aot.generate;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.ArgumentMatchers.*;

import java.lang.reflect.Method;

import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import org.springframework.data.domain.Pageable;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.util.TypeInformation;
Expand Down Expand Up @@ -73,6 +74,7 @@ static MethodMetadata methodMetadataFor(String methodName) throws NoSuchMethodEx

RepositoryInformation ri = Mockito.mock(RepositoryInformation.class);
Mockito.doReturn(TypeInformation.of(target.getReturnType())).when(ri).getReturnType(eq(target));
Mockito.doReturn(TypeInformation.of(target.getReturnType())).when(ri).getReturnedDomainTypeInformation(eq(target));
return new MethodMetadata(ri, target);
}

Expand Down
Loading