diff --git a/src/main/java/graphql/servlet/GraphQLQueryInvoker.java b/src/main/java/graphql/servlet/GraphQLQueryInvoker.java index 400ee499..63c589c0 100644 --- a/src/main/java/graphql/servlet/GraphQLQueryInvoker.java +++ b/src/main/java/graphql/servlet/GraphQLQueryInvoker.java @@ -63,17 +63,27 @@ protected Instrumentation getInstrumentation(Object context) { if (context instanceof GraphQLContext) { return ((GraphQLContext) context).getDataLoaderRegistry() .map(registry -> { - List instrumentations = new ArrayList<>(); - instrumentations.add(getInstrumentation.get()); - instrumentations.add(new DataLoaderDispatcherInstrumentation(dataLoaderDispatcherInstrumentationOptionsSupplier.get())); - return new ChainedInstrumentation(instrumentations); + Instrumentation instrumentation = getInstrumentation.get(); + if (!containsDispatchInstrumentation(instrumentation)) { + List instrumentations = new ArrayList<>(); + instrumentations.add(instrumentation); + instrumentations.add(new DataLoaderDispatcherInstrumentation(dataLoaderDispatcherInstrumentationOptionsSupplier.get())); + instrumentation = new ChainedInstrumentation(instrumentations); + } + return instrumentation; }) - .map(Instrumentation.class::cast) .orElse(getInstrumentation.get()); } return getInstrumentation.get(); } + private boolean containsDispatchInstrumentation(Instrumentation instrumentation) { + if (instrumentation instanceof ChainedInstrumentation) { + return ((ChainedInstrumentation)instrumentation).getInstrumentations().stream().anyMatch(this::containsDispatchInstrumentation); + } + return instrumentation instanceof DataLoaderDispatcherInstrumentation; + } + private ExecutionResult query(GraphQLInvocationInput invocationInput, ExecutionInput executionInput) { if (Subject.getSubject(AccessController.getContext()) == null && invocationInput.getSubject().isPresent()) { return Subject.doAs(invocationInput.getSubject().get(), (PrivilegedAction) () -> { diff --git a/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy index 88a4238a..7b1baf80 100644 --- a/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy +++ b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy @@ -5,6 +5,7 @@ import graphql.Scalars import graphql.execution.ExecutionStepInfo import graphql.execution.instrumentation.ChainedInstrumentation import graphql.execution.instrumentation.Instrumentation +import graphql.execution.instrumentation.dataloader.DataLoaderDispatcherInstrumentation import graphql.schema.DataFetcher import graphql.execution.reactive.SingleSubscriberPublisher import graphql.schema.GraphQLNonNull @@ -1214,4 +1215,23 @@ class AbstractGraphQLHttpServletSpec extends Specification { actualInstrumentation instanceof ChainedInstrumentation actualInstrumentation != servletInstrumentation } + + def "getInstrumentation does not add dataloader dispatch instrumentation if one is provided"() { + setup: + Instrumentation servletInstrumentation = Mock() + DataLoaderDispatcherInstrumentation mockDispatchInstrumentation = Mock() + ChainedInstrumentation chainedInstrumentation = new ChainedInstrumentation(Arrays.asList(servletInstrumentation, + mockDispatchInstrumentation)) + GraphQLContext context = new GraphQLContext(request, response, null, null, null) + DataLoaderRegistry dlr = Mock() + context.setDataLoaderRegistry(dlr) + SimpleGraphQLHttpServlet simpleGraphQLServlet = SimpleGraphQLHttpServlet + .newBuilder(TestUtils.createGraphQlSchema()) + .withQueryInvoker(GraphQLQueryInvoker.newBuilder().withInstrumentation(chainedInstrumentation).build()) + .build(); + when: + Instrumentation actualInstrumentation = simpleGraphQLServlet.getQueryInvoker().getInstrumentation(context) + then: + actualInstrumentation == chainedInstrumentation + } }