1
+ import asyncio
1
2
import copy
2
3
from collections .abc import MutableMapping
3
4
from functools import partial
6
7
from flask import Response , render_template_string , request
7
8
from flask .views import View
8
9
from graphql .error import GraphQLError
10
+ from graphql .pyutils import is_awaitable
9
11
from graphql .type .schema import GraphQLSchema
10
12
11
13
from graphql_server import (
@@ -41,6 +43,7 @@ class GraphQLView(View):
41
43
default_query = None
42
44
header_editor_enabled = None
43
45
should_persist_headers = None
46
+ enable_async = False
44
47
45
48
methods = ["GET" , "POST" , "PUT" , "DELETE" ]
46
49
@@ -53,26 +56,51 @@ def __init__(self, **kwargs):
53
56
if hasattr (self , key ):
54
57
setattr (self , key , value )
55
58
56
- assert isinstance (
57
- self .schema , GraphQLSchema
58
- ), "A Schema is required to be provided to GraphQLView."
59
+ assert isinstance (self .schema , GraphQLSchema ), "A Schema is required to be provided to GraphQLView."
59
60
60
61
def get_root_value (self ):
61
62
return self .root_value
62
63
63
64
def get_context (self ):
64
- context = (
65
- copy .copy (self .context )
66
- if self .context and isinstance (self .context , MutableMapping )
67
- else {}
68
- )
65
+ context = copy .copy (self .context ) if self .context and isinstance (self .context , MutableMapping ) else {}
69
66
if isinstance (context , MutableMapping ) and "request" not in context :
70
67
context .update ({"request" : request })
71
68
return context
72
69
73
70
def get_middleware (self ):
74
71
return self .middleware
75
72
73
+ def result_results (self , request_method , data , catch ):
74
+ return run_http_query (
75
+ self .schema ,
76
+ request_method ,
77
+ data ,
78
+ query_data = request .args ,
79
+ batch_enabled = self .batch ,
80
+ catch = catch ,
81
+ # Execute options
82
+ root_value = self .get_root_value (),
83
+ context_value = self .get_context (),
84
+ middleware = self .get_middleware (),
85
+ run_sync = not self .enable_async ,
86
+ )
87
+
88
+ async def resolve_results_async (self , request_method , data , catch ):
89
+ execution_results , all_params = run_http_query (
90
+ self .schema ,
91
+ request_method ,
92
+ data ,
93
+ query_data = request .args ,
94
+ batch_enabled = self .batch ,
95
+ catch = catch ,
96
+ # Execute options
97
+ root_value = self .get_root_value (),
98
+ context_value = self .get_context (),
99
+ middleware = self .get_middleware (),
100
+ run_sync = not self .enable_async ,
101
+ )
102
+ return [await ex if is_awaitable (ex ) else ex for ex in execution_results ], all_params
103
+
76
104
def dispatch_request (self ):
77
105
try :
78
106
request_method = request .method .lower ()
@@ -84,18 +112,11 @@ def dispatch_request(self):
84
112
pretty = self .pretty or show_graphiql or request .args .get ("pretty" )
85
113
86
114
all_params : List [GraphQLParams ]
87
- execution_results , all_params = run_http_query (
88
- self .schema ,
89
- request_method ,
90
- data ,
91
- query_data = request .args ,
92
- batch_enabled = self .batch ,
93
- catch = catch ,
94
- # Execute options
95
- root_value = self .get_root_value (),
96
- context_value = self .get_context (),
97
- middleware = self .get_middleware (),
98
- )
115
+ if self .enable_async :
116
+ execution_results , all_params = asyncio .run (self .resolve_results_async (request_method , data , catch ))
117
+ else :
118
+ execution_results , all_params = self .result_results (request_method , data , catch )
119
+
99
120
result , status_code = encode_execution_results (
100
121
execution_results ,
101
122
is_batch = isinstance (data , list ),
@@ -123,9 +144,7 @@ def dispatch_request(self):
123
144
header_editor_enabled = self .header_editor_enabled ,
124
145
should_persist_headers = self .should_persist_headers ,
125
146
)
126
- source = render_graphiql_sync (
127
- data = graphiql_data , config = graphiql_config , options = graphiql_options
128
- )
147
+ source = render_graphiql_sync (data = graphiql_data , config = graphiql_config , options = graphiql_options )
129
148
return render_template_string (source )
130
149
131
150
return Response (result , status = status_code , content_type = "application/json" )
@@ -150,10 +169,7 @@ def parse_body():
150
169
elif content_type == "application/json" :
151
170
return load_json_body (request .data .decode ("utf8" ))
152
171
153
- elif content_type in (
154
- "application/x-www-form-urlencoded" ,
155
- "multipart/form-data" ,
156
- ):
172
+ elif content_type in ("application/x-www-form-urlencoded" , "multipart/form-data" ,):
157
173
return request .form
158
174
159
175
return {}
@@ -167,8 +183,4 @@ def should_display_graphiql(self):
167
183
@staticmethod
168
184
def request_wants_html ():
169
185
best = request .accept_mimetypes .best_match (["application/json" , "text/html" ])
170
- return (
171
- best == "text/html"
172
- and request .accept_mimetypes [best ]
173
- > request .accept_mimetypes ["application/json" ]
174
- )
186
+ return best == "text/html" and request .accept_mimetypes [best ] > request .accept_mimetypes ["application/json" ]
0 commit comments