diff --git a/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLConfiguration.java b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLConfiguration.java index f009889e..9877ce5a 100644 --- a/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLConfiguration.java +++ b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLConfiguration.java @@ -36,6 +36,7 @@ public class GraphQLConfiguration { private final ContextSetting contextSetting; private final GraphQLResponseCacheManager responseCacheManager; @Getter private final Executor asyncExecutor; + @Getter private final List allowedOrigins; private HttpRequestHandler requestHandler; private GraphQLConfiguration( @@ -49,9 +50,11 @@ private GraphQLConfiguration( ContextSetting contextSetting, Supplier batchInputPreProcessor, GraphQLResponseCacheManager responseCacheManager, - Executor asyncExecutor) { + Executor asyncExecutor, + List allowedOrigins) { this.invocationInputFactory = invocationInputFactory; this.asyncExecutor = asyncExecutor; + this.allowedOrigins = allowedOrigins; this.graphQLInvoker = graphQLInvoker != null ? graphQLInvoker : queryInvoker.toGraphQLInvoker(); this.objectMapper = objectMapper; this.listeners = listeners; @@ -148,6 +151,7 @@ public static class Builder { private int asyncMaxPoolSize = 200; private Executor asyncExecutor; private AsyncTaskDecorator asyncTaskDecorator; + private List allowedOrigins = new ArrayList<>(); private Builder(GraphQLInvocationInputFactory.Builder invocationInputFactoryBuilder) { this.invocationInputFactoryBuilder = invocationInputFactoryBuilder; @@ -249,6 +253,13 @@ public Builder with(AsyncTaskDecorator asyncTaskDecorator) { return this; } + public Builder allowedOrigins(List allowedOrigins) { + if (allowedOrigins != null) { + this.allowedOrigins.addAll(allowedOrigins); + } + return this; + } + private Executor getAsyncExecutor() { if (asyncExecutor != null) { return asyncExecutor; @@ -279,7 +290,8 @@ public GraphQLConfiguration build() { contextSetting, batchInputPreProcessorSupplier, responseCacheManager, - getAsyncTaskExecutor()); + getAsyncTaskExecutor(), + allowedOrigins); } } } diff --git a/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLWebsocketServlet.java b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLWebsocketServlet.java index cd22aede..ab88ac16 100644 --- a/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLWebsocketServlet.java +++ b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLWebsocketServlet.java @@ -1,6 +1,7 @@ package graphql.kickstart.servlet; import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; import static java.util.stream.Collectors.toList; @@ -65,6 +66,7 @@ public class GraphQLWebsocketServlet extends Endpoint { private final AtomicBoolean isShuttingDown = new AtomicBoolean(false); private final AtomicBoolean isShutDown = new AtomicBoolean(false); private final Object cacheLock = new Object(); + private final List allowedOrigins; public GraphQLWebsocketServlet(GraphQLConfiguration configuration) { this(configuration, null); @@ -77,21 +79,23 @@ public GraphQLWebsocketServlet( configuration.getGraphQLInvoker(), configuration.getInvocationInputFactory(), configuration.getObjectMapper(), - connectionListeners); + connectionListeners, + configuration.getAllowedOrigins()); } public GraphQLWebsocketServlet( GraphQLInvoker graphQLInvoker, GraphQLSubscriptionInvocationInputFactory invocationInputFactory, GraphQLObjectMapper graphQLObjectMapper) { - this(graphQLInvoker, invocationInputFactory, graphQLObjectMapper, null); + this(graphQLInvoker, invocationInputFactory, graphQLObjectMapper, null, emptyList()); } public GraphQLWebsocketServlet( GraphQLInvoker graphQLInvoker, GraphQLSubscriptionInvocationInputFactory invocationInputFactory, GraphQLObjectMapper graphQLObjectMapper, - Collection connectionListeners) { + Collection connectionListeners, + List allowedOrigins) { List listeners = new ArrayList<>(); if (connectionListeners != null) { connectionListeners.stream() @@ -114,12 +118,10 @@ public GraphQLWebsocketServlet( Stream.of(fallbackSubscriptionProtocolFactory)) .map(SubscriptionProtocolFactory::getProtocol) .collect(toList()); + this.allowedOrigins = allowedOrigins; } public GraphQLWebsocketServlet( - GraphQLInvoker graphQLInvoker, - GraphQLSubscriptionInvocationInputFactory invocationInputFactory, - GraphQLObjectMapper graphQLObjectMapper, List subscriptionProtocolFactory, SubscriptionProtocolFactory fallbackSubscriptionProtocolFactory) { @@ -132,6 +134,8 @@ public GraphQLWebsocketServlet( Stream.of(fallbackSubscriptionProtocolFactory)) .map(SubscriptionProtocolFactory::getProtocol) .collect(toList()); + + this.allowedOrigins = emptyList(); } @Override @@ -202,6 +206,26 @@ private void closeUnexpectedly(Session session, Throwable t) { } } + public boolean checkOrigin(String originHeaderValue) { + if (originHeaderValue == null || originHeaderValue.isBlank()) { + return allowedOrigins.isEmpty(); + } + String originToCheck = trimTrailingSlash(originHeaderValue); + if (!allowedOrigins.isEmpty()) { + if (allowedOrigins.contains("*")) { + return true; + } + return allowedOrigins.stream() + .map(this::trimTrailingSlash) + .anyMatch(originToCheck::equalsIgnoreCase); + } + return true; + } + + private String trimTrailingSlash(String origin) { + return (origin.endsWith("/") ? origin.substring(0, origin.length() - 1) : origin); + } + public void modifyHandshake( ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) { sec.getUserProperties().put(HANDSHAKE_REQUEST_KEY, request); diff --git a/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/GraphQLWebsocketServletSpec.groovy b/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/GraphQLWebsocketServletSpec.groovy new file mode 100644 index 00000000..e4074455 --- /dev/null +++ b/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/GraphQLWebsocketServletSpec.groovy @@ -0,0 +1,83 @@ +package graphql.kickstart.servlet + +import spock.lang.Specification + +class GraphQLWebsocketServletSpec extends Specification { + + def "checkOrigin without any allowed origins allows given origin"() { + given: "a websocket servlet with no allowed origins" + def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).build()) + + when: "we check origin http://localhost:8080" + def allowed = servlet.checkOrigin("http://localhost:8080") + + then: + allowed + } + + def "checkOrigin without any allowed origins allows when no origin given"() { + given: "a websocket servlet with no allowed origins" + def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).build()) + + when: "we check origin null" + def allowed = servlet.checkOrigin(null) + + then: + allowed + } + + def "checkOrigin without any allowed origins allows when origin is empty"() { + given: "a websocket servlet with no allowed origins" + def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).build()) + + when: "we check origin null" + def allowed = servlet.checkOrigin(" ") + + then: + allowed + } + + def "checkOrigin with allow all origins allows given origin"() { + given: "a websocket servlet with allow all origins" + def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).allowedOrigins(List.of("*")).build()) + + when: "we check origin http://localhost:8080" + def allowed = servlet.checkOrigin("http://localhost:8080") + + then: + allowed + } + + def "checkOrigin with specific allowed origins allows given origin"() { + given: "a websocket servlet with allow all origins" + def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).allowedOrigins(List.of("http://localhost:8080")).build()) + + when: "we check origin http://localhost:8080" + def allowed = servlet.checkOrigin("http://localhost:8080") + + then: + allowed + } + + def "checkOrigin with specific allowed origins allows given origin with trailing slash"() { + given: "a websocket servlet with allow all origins" + def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).allowedOrigins(List.of("http://localhost:8080")).build()) + + when: "we check origin http://localhost:8080/" + def allowed = servlet.checkOrigin("http://localhost:8080/") + + then: + allowed + } + + def "checkOrigin with specific allowed origins with trailing slash allows given origin without trailing slash"() { + given: "a websocket servlet with allow all origins" + def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).allowedOrigins(List.of("http://localhost:8080/")).build()) + + when: "we check origin http://localhost:8080" + def allowed = servlet.checkOrigin("http://localhost:8080") + + then: + allowed + } +}