6
6
"JwtSuperuserConnection" ,
7
7
]
8
8
9
- import json
10
9
from abc import ABC , abstractmethod
10
+ from json import JSONDecodeError
11
11
from typing import Any , List , Optional
12
12
13
- import jwt
13
+ from jwt import ExpiredSignatureError
14
14
15
15
from arangoasync import errno , logger
16
16
from arangoasync .auth import Auth , JwtToken
26
26
from arangoasync .request import Method , Request
27
27
from arangoasync .resolver import HostResolver
28
28
from arangoasync .response import Response
29
+ from arangoasync .serialization import (
30
+ DefaultDeserializer ,
31
+ DefaultSerializer ,
32
+ Deserializer ,
33
+ Serializer ,
34
+ )
29
35
30
36
31
37
class BaseConnection (ABC ):
@@ -37,6 +43,10 @@ class BaseConnection(ABC):
37
43
http_client (HTTPClient): HTTP client.
38
44
db_name (str): Database name.
39
45
compression (CompressionManager | None): Compression manager.
46
+ serializer (Serializer | None): For custom serialization.
47
+ Leave `None` for default.
48
+ deserializer (Deserializer | None): For custom deserialization.
49
+ Leave `None` for default.
40
50
"""
41
51
42
52
def __init__ (
@@ -46,19 +56,33 @@ def __init__(
46
56
http_client : HTTPClient ,
47
57
db_name : str ,
48
58
compression : Optional [CompressionManager ] = None ,
59
+ serializer : Optional [Serializer ] = None ,
60
+ deserializer : Optional [Deserializer ] = None ,
49
61
) -> None :
50
62
self ._sessions = sessions
51
63
self ._db_endpoint = f"/_db/{ db_name } "
52
64
self ._host_resolver = host_resolver
53
65
self ._http_client = http_client
54
66
self ._db_name = db_name
55
67
self ._compression = compression
68
+ self ._serializer = serializer or DefaultSerializer ()
69
+ self ._deserializer = deserializer or DefaultDeserializer ()
56
70
57
71
@property
58
72
def db_name (self ) -> str :
59
73
"""Return the database name."""
60
74
return self ._db_name
61
75
76
+ @property
77
+ def serializer (self ) -> Serializer :
78
+ """Return the serializer."""
79
+ return self ._serializer
80
+
81
+ @property
82
+ def deserializer (self ) -> Deserializer :
83
+ """Return the deserializer."""
84
+ return self ._deserializer
85
+
62
86
@staticmethod
63
87
def raise_for_status (request : Request , resp : Response ) -> None :
64
88
"""Raise an exception based on the response.
@@ -75,8 +99,7 @@ def raise_for_status(request: Request, resp: Response) -> None:
75
99
if not resp .is_success :
76
100
raise ServerConnectionError (resp , request , "Bad server response." )
77
101
78
- @staticmethod
79
- def prep_response (request : Request , resp : Response ) -> Response :
102
+ def prep_response (self , request : Request , resp : Response ) -> Response :
80
103
"""Prepare response for return.
81
104
82
105
Args:
@@ -89,8 +112,8 @@ def prep_response(request: Request, resp: Response) -> Response:
89
112
resp .is_success = 200 <= resp .status_code < 300
90
113
if not resp .is_success :
91
114
try :
92
- body = json . loads (resp .raw_body )
93
- except json . JSONDecodeError as e :
115
+ body = self . _deserializer . from_bytes (resp .raw_body )
116
+ except JSONDecodeError as e :
94
117
logger .debug (
95
118
f"Failed to decode response body: { e } (from request { request } )"
96
119
)
@@ -202,6 +225,8 @@ class BasicConnection(BaseConnection):
202
225
http_client (HTTPClient): HTTP client.
203
226
db_name (str): Database name.
204
227
compression (CompressionManager | None): Compression manager.
228
+ serializer (Serializer | None): For custom serialization.
229
+ deserializer (Deserializer | None): For custom deserialization.
205
230
auth (Auth | None): Authentication information.
206
231
"""
207
232
@@ -212,9 +237,19 @@ def __init__(
212
237
http_client : HTTPClient ,
213
238
db_name : str ,
214
239
compression : Optional [CompressionManager ] = None ,
240
+ serializer : Optional [Serializer ] = None ,
241
+ deserializer : Optional [Deserializer ] = None ,
215
242
auth : Optional [Auth ] = None ,
216
243
) -> None :
217
- super ().__init__ (sessions , host_resolver , http_client , db_name , compression )
244
+ super ().__init__ (
245
+ sessions ,
246
+ host_resolver ,
247
+ http_client ,
248
+ db_name ,
249
+ compression ,
250
+ serializer ,
251
+ deserializer ,
252
+ )
218
253
self ._auth = auth
219
254
220
255
async def send_request (self , request : Request ) -> Response :
@@ -249,6 +284,8 @@ class JwtConnection(BaseConnection):
249
284
http_client (HTTPClient): HTTP client.
250
285
db_name (str): Database name.
251
286
compression (CompressionManager | None): Compression manager.
287
+ serializer (Serializer | None): For custom serialization.
288
+ deserializer (Deserializer | None): For custom deserialization.
252
289
auth (Auth | None): Authentication information.
253
290
token (JwtToken | None): JWT token.
254
291
@@ -263,10 +300,20 @@ def __init__(
263
300
http_client : HTTPClient ,
264
301
db_name : str ,
265
302
compression : Optional [CompressionManager ] = None ,
303
+ serializer : Optional [Serializer ] = None ,
304
+ deserializer : Optional [Deserializer ] = None ,
266
305
auth : Optional [Auth ] = None ,
267
306
token : Optional [JwtToken ] = None ,
268
307
) -> None :
269
- super ().__init__ (sessions , host_resolver , http_client , db_name , compression )
308
+ super ().__init__ (
309
+ sessions ,
310
+ host_resolver ,
311
+ http_client ,
312
+ db_name ,
313
+ compression ,
314
+ serializer ,
315
+ deserializer ,
316
+ )
270
317
self ._auth = auth
271
318
self ._expire_leeway : int = 0
272
319
self ._token : Optional [JwtToken ] = token
@@ -306,10 +353,8 @@ async def refresh_token(self) -> None:
306
353
if self ._auth is None :
307
354
raise JWTRefreshError ("Auth must be provided to refresh the token." )
308
355
309
- auth_data = json . dumps (
356
+ auth_data = self . _serializer . to_str (
310
357
dict (username = self ._auth .username , password = self ._auth .password ),
311
- separators = ("," , ":" ),
312
- ensure_ascii = False ,
313
358
)
314
359
request = Request (
315
360
method = Method .POST ,
@@ -330,10 +375,10 @@ async def refresh_token(self) -> None:
330
375
f"{ resp .status_code } { resp .status_text } "
331
376
)
332
377
333
- token = json . loads (resp .raw_body )
378
+ token = self . _deserializer . from_bytes (resp .raw_body )
334
379
try :
335
380
self .token = JwtToken (token ["jwt" ])
336
- except jwt . ExpiredSignatureError as e :
381
+ except ExpiredSignatureError as e :
337
382
raise JWTRefreshError (
338
383
"Failed to refresh the JWT token: got an expired token"
339
384
) from e
@@ -385,6 +430,8 @@ class JwtSuperuserConnection(BaseConnection):
385
430
http_client (HTTPClient): HTTP client.
386
431
db_name (str): Database name.
387
432
compression (CompressionManager | None): Compression manager.
433
+ serializer (Serializer | None): For custom serialization.
434
+ deserializer (Deserializer | None): For custom deserialization.
388
435
token (JwtToken | None): JWT token.
389
436
"""
390
437
@@ -395,10 +442,19 @@ def __init__(
395
442
http_client : HTTPClient ,
396
443
db_name : str ,
397
444
compression : Optional [CompressionManager ] = None ,
445
+ serializer : Optional [Serializer ] = None ,
446
+ deserializer : Optional [Deserializer ] = None ,
398
447
token : Optional [JwtToken ] = None ,
399
448
) -> None :
400
- super ().__init__ (sessions , host_resolver , http_client , db_name , compression )
401
- self ._expire_leeway : int = 0
449
+ super ().__init__ (
450
+ sessions ,
451
+ host_resolver ,
452
+ http_client ,
453
+ db_name ,
454
+ compression ,
455
+ serializer ,
456
+ deserializer ,
457
+ )
402
458
self ._token : Optional [JwtToken ] = token
403
459
self ._auth_header : Optional [str ] = None
404
460
self .token = self ._token
0 commit comments