diff --git a/graphql_server/flask/graphqlview.py b/graphql_server/flask/graphqlview.py index 6132b61..2278049 100644 --- a/graphql_server/flask/graphqlview.py +++ b/graphql_server/flask/graphqlview.py @@ -1,12 +1,14 @@ +import asyncio import copy from collections.abc import MutableMapping from functools import partial from typing import List -from flask import Response, render_template_string, request +from flask import Response, current_app, render_template_string, request from flask.views import View from graphql import specified_rules from graphql.error import GraphQLError +from graphql.pyutils import is_awaitable from graphql.type.schema import GraphQLSchema from graphql_server import ( @@ -25,6 +27,7 @@ GraphiQLOptions, render_graphiql_sync, ) +from graphql_server.utils import wrap_in_async class GraphQLView(View): @@ -41,6 +44,7 @@ class GraphQLView(View): execution_context_class = None batch = False jinja_env = None + enable_async = False subscriptions = None headers = None default_query = None @@ -110,12 +114,29 @@ def dispatch_request(self): batch_enabled=self.batch, catch=catch, # Execute options + run_sync=not self.enable_async, root_value=self.get_root_value(), context_value=self.get_context(), middleware=self.get_middleware(), validation_rules=self.get_validation_rules(), execution_context_class=self.get_execution_context_class(), ) + + async def get_async_execution_results(): + return await asyncio.gather( + *( + ex + if ex is not None and is_awaitable(ex) + else wrap_in_async(lambda: ex)() + for ex in execution_results + ) + ) + + if self.enable_async: + execution_results = current_app.ensure_sync( + get_async_execution_results + )() + result, status_code = encode_execution_results( execution_results, is_batch=isinstance(data, list), diff --git a/setup.py b/setup.py index bf3f24f..cba07e7 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,10 @@ "flask>=1,<3", ] +install_flask_async_requires = [ + "flask[async]>=2,<3", +] + install_sanic_requires = [ "sanic>=21.12,<23", ] @@ -43,7 +47,7 @@ install_all_requires = ( install_requires - + install_flask_requires + + install_flask_async_requires + install_sanic_requires + install_webob_requires + install_aiohttp_requires diff --git a/tests/flask/app.py b/tests/flask/app.py index ec9e9d0..1297446 100644 --- a/tests/flask/app.py +++ b/tests/flask/app.py @@ -4,11 +4,13 @@ from tests.flask.schema import Schema -def create_app(path="/graphql", **kwargs): +def create_app(path="/graphql", schema=Schema, **kwargs): server = Flask(__name__) server.debug = True + view_cls = GraphQLView server.add_url_rule( - path, view_func=GraphQLView.as_view("graphql", schema=Schema, **kwargs) + path, + view_func=view_cls.as_view("graphql", schema=schema, **kwargs), ) return server diff --git a/tests/flask/schema.py b/tests/flask/schema.py index eb51e26..23cf2d2 100644 --- a/tests/flask/schema.py +++ b/tests/flask/schema.py @@ -1,3 +1,5 @@ +import asyncio + from graphql.type.definition import ( GraphQLArgument, GraphQLField, @@ -49,3 +51,29 @@ def resolve_raises(*_): ) Schema = GraphQLSchema(QueryRootType, MutationRootType) + + +async def resolver_field_async_1(_obj, info): + await asyncio.sleep(0.001) + return "hey" + + +async def resolver_field_async_2(_obj, info): + await asyncio.sleep(0.003) + return "hey2" + + +def resolver_field_sync(_obj, info): + return "hey3" + + +AsyncQueryType = GraphQLObjectType( + name="AsyncQueryType", + fields={ + "a": GraphQLField(GraphQLString, resolve=resolver_field_async_1), + "b": GraphQLField(GraphQLString, resolve=resolver_field_async_2), + "c": GraphQLField(GraphQLString, resolve=resolver_field_sync), + }, +) + +AsyncSchema = GraphQLSchema(AsyncQueryType) diff --git a/tests/flask/test_graphqlview.py b/tests/flask/test_graphqlview.py index 358ec3b..9aab3df 100644 --- a/tests/flask/test_graphqlview.py +++ b/tests/flask/test_graphqlview.py @@ -7,6 +7,7 @@ from ..utils import RepeatExecutionContext from .app import create_app +from .schema import AsyncSchema def url_string(app, **url_params): @@ -574,3 +575,11 @@ def test_custom_execution_context_class(app, client): assert response.status_code == 200 assert response_json(response) == {"data": {"test": "Hello WorldHello World"}} + + +@pytest.mark.parametrize("app", [create_app(schema=AsyncSchema, enable_async=True)]) +def test_async_schema(app, client): + response = client.get(url_string(app, query="{a,b,c}")) + + assert response.status_code == 200 + assert response_json(response) == {"data": {"a": "hey", "b": "hey2", "c": "hey3"}}