Skip to content

Commit b65bcdb

Browse files
author
Cameron Hurst
committed
feat: flask asyncio support for dataloaders
1 parent c03e1a4 commit b65bcdb

File tree

1 file changed

+44
-32
lines changed

1 file changed

+44
-32
lines changed

graphql_server/flask/graphqlview.py

+44-32
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import copy
23
from collections.abc import MutableMapping
34
from functools import partial
@@ -6,6 +7,7 @@
67
from flask import Response, render_template_string, request
78
from flask.views import View
89
from graphql.error import GraphQLError
10+
from graphql.pyutils import is_awaitable
911
from graphql.type.schema import GraphQLSchema
1012

1113
from graphql_server import (
@@ -41,6 +43,7 @@ class GraphQLView(View):
4143
default_query = None
4244
header_editor_enabled = None
4345
should_persist_headers = None
46+
enable_async = False
4447

4548
methods = ["GET", "POST", "PUT", "DELETE"]
4649

@@ -53,26 +56,51 @@ def __init__(self, **kwargs):
5356
if hasattr(self, key):
5457
setattr(self, key, value)
5558

56-
assert isinstance(
57-
self.schema, GraphQLSchema
58-
), "A Schema is required to be provided to GraphQLView."
59+
assert isinstance(self.schema, GraphQLSchema), "A Schema is required to be provided to GraphQLView."
5960

6061
def get_root_value(self):
6162
return self.root_value
6263

6364
def get_context(self):
64-
context = (
65-
copy.copy(self.context)
66-
if self.context and isinstance(self.context, MutableMapping)
67-
else {}
68-
)
65+
context = copy.copy(self.context) if self.context and isinstance(self.context, MutableMapping) else {}
6966
if isinstance(context, MutableMapping) and "request" not in context:
7067
context.update({"request": request})
7168
return context
7269

7370
def get_middleware(self):
7471
return self.middleware
7572

73+
def result_results(self, request_method, data, catch):
74+
return run_http_query(
75+
self.schema,
76+
request_method,
77+
data,
78+
query_data=request.args,
79+
batch_enabled=self.batch,
80+
catch=catch,
81+
# Execute options
82+
root_value=self.get_root_value(),
83+
context_value=self.get_context(),
84+
middleware=self.get_middleware(),
85+
run_sync=not self.enable_async,
86+
)
87+
88+
async def resolve_results_async(self, request_method, data, catch):
89+
execution_results, all_params = run_http_query(
90+
self.schema,
91+
request_method,
92+
data,
93+
query_data=request.args,
94+
batch_enabled=self.batch,
95+
catch=catch,
96+
# Execute options
97+
root_value=self.get_root_value(),
98+
context_value=self.get_context(),
99+
middleware=self.get_middleware(),
100+
run_sync=not self.enable_async,
101+
)
102+
return [await ex if is_awaitable(ex) else ex for ex in execution_results], all_params
103+
76104
def dispatch_request(self):
77105
try:
78106
request_method = request.method.lower()
@@ -84,18 +112,11 @@ def dispatch_request(self):
84112
pretty = self.pretty or show_graphiql or request.args.get("pretty")
85113

86114
all_params: List[GraphQLParams]
87-
execution_results, all_params = run_http_query(
88-
self.schema,
89-
request_method,
90-
data,
91-
query_data=request.args,
92-
batch_enabled=self.batch,
93-
catch=catch,
94-
# Execute options
95-
root_value=self.get_root_value(),
96-
context_value=self.get_context(),
97-
middleware=self.get_middleware(),
98-
)
115+
if self.enable_async:
116+
execution_results, all_params = asyncio.run(self.resolve_results_async(request_method, data, catch))
117+
else:
118+
execution_results, all_params = self.result_results(request_method, data, catch)
119+
99120
result, status_code = encode_execution_results(
100121
execution_results,
101122
is_batch=isinstance(data, list),
@@ -123,9 +144,7 @@ def dispatch_request(self):
123144
header_editor_enabled=self.header_editor_enabled,
124145
should_persist_headers=self.should_persist_headers,
125146
)
126-
source = render_graphiql_sync(
127-
data=graphiql_data, config=graphiql_config, options=graphiql_options
128-
)
147+
source = render_graphiql_sync(data=graphiql_data, config=graphiql_config, options=graphiql_options)
129148
return render_template_string(source)
130149

131150
return Response(result, status=status_code, content_type="application/json")
@@ -150,10 +169,7 @@ def parse_body():
150169
elif content_type == "application/json":
151170
return load_json_body(request.data.decode("utf8"))
152171

153-
elif content_type in (
154-
"application/x-www-form-urlencoded",
155-
"multipart/form-data",
156-
):
172+
elif content_type in ("application/x-www-form-urlencoded", "multipart/form-data",):
157173
return request.form
158174

159175
return {}
@@ -167,8 +183,4 @@ def should_display_graphiql(self):
167183
@staticmethod
168184
def request_wants_html():
169185
best = request.accept_mimetypes.best_match(["application/json", "text/html"])
170-
return (
171-
best == "text/html"
172-
and request.accept_mimetypes[best]
173-
> request.accept_mimetypes["application/json"]
174-
)
186+
return best == "text/html" and request.accept_mimetypes[best] > request.accept_mimetypes["application/json"]

0 commit comments

Comments
 (0)