Skip to content

Commit 8dea5f4

Browse files
authored
Introducing custom serialization (#19)
1 parent 5b84719 commit 8dea5f4

11 files changed

+428
-27
lines changed

arangoasync/client.py

+30
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
from arangoasync.database import StandardDatabase
1515
from arangoasync.http import DefaultHTTPClient, HTTPClient
1616
from arangoasync.resolver import HostResolver, get_resolver
17+
from arangoasync.serialization import (
18+
DefaultDeserializer,
19+
DefaultSerializer,
20+
Deserializer,
21+
Serializer,
22+
)
1723
from arangoasync.version import __version__
1824

1925

@@ -45,6 +51,14 @@ class ArangoClient:
4551
<arangoasync.compression.DefaultCompressionManager>`
4652
or a custom subclass of :class:`CompressionManager
4753
<arangoasync.compression.CompressionManager>`.
54+
serializer (Serializer | None): Custom serializer implementation.
55+
Leave as `None` to use the default serializer.
56+
See :class:`DefaultSerializer
57+
<arangoasync.serialization.DefaultSerializer>`.
58+
deserializer (Deserializer | None): Custom deserializer implementation.
59+
Leave as `None` to use the default deserializer.
60+
See :class:`DefaultDeserializer
61+
<arangoasync.serialization.DefaultDeserializer>`.
4862
4963
Raises:
5064
ValueError: If the `host_resolver` is not supported.
@@ -56,6 +70,8 @@ def __init__(
5670
host_resolver: str | HostResolver = "default",
5771
http_client: Optional[HTTPClient] = None,
5872
compression: Optional[CompressionManager] = None,
73+
serializer: Optional[Serializer] = None,
74+
deserializer: Optional[Deserializer] = None,
5975
) -> None:
6076
self._hosts = [hosts] if isinstance(hosts, str) else hosts
6177
self._host_resolver = (
@@ -68,6 +84,8 @@ def __init__(
6884
self._http_client.create_session(host) for host in self._hosts
6985
]
7086
self._compression = compression
87+
self._serializer = serializer or DefaultSerializer()
88+
self._deserializer = deserializer or DefaultDeserializer()
7189

7290
def __repr__(self) -> str:
7391
return f"<ArangoClient {','.join(self._hosts)}>"
@@ -124,6 +142,8 @@ async def db(
124142
token: Optional[JwtToken] = None,
125143
verify: bool = False,
126144
compression: Optional[CompressionManager] = None,
145+
serializer: Optional[Serializer] = None,
146+
deserializer: Optional[Deserializer] = None,
127147
) -> StandardDatabase:
128148
"""Connects to a database and returns and API wrapper.
129149
@@ -145,6 +165,10 @@ async def db(
145165
verify (bool): Verify the connection by sending a test request.
146166
compression (CompressionManager | None): If set, supersedes the
147167
client-level compression settings.
168+
serializer (Serializer | None): If set, supersedes the client-level
169+
serializer.
170+
deserializer (Deserializer | None): If set, supersedes the client-level
171+
deserializer.
148172
149173
Returns:
150174
StandardDatabase: Database API wrapper.
@@ -163,6 +187,8 @@ async def db(
163187
http_client=self._http_client,
164188
db_name=name,
165189
compression=compression or self._compression,
190+
serializer=serializer or self._serializer,
191+
deserializer=deserializer or self._deserializer,
166192
auth=auth,
167193
)
168194
elif auth_method == "jwt":
@@ -176,6 +202,8 @@ async def db(
176202
http_client=self._http_client,
177203
db_name=name,
178204
compression=compression or self._compression,
205+
serializer=serializer or self._serializer,
206+
deserializer=deserializer or self._deserializer,
179207
auth=auth,
180208
token=token,
181209
)
@@ -190,6 +218,8 @@ async def db(
190218
http_client=self._http_client,
191219
db_name=name,
192220
compression=compression or self._compression,
221+
serializer=serializer or self._serializer,
222+
deserializer=deserializer or self._deserializer,
193223
token=token,
194224
)
195225
else:

arangoasync/connection.py

+71-15
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
"JwtSuperuserConnection",
77
]
88

9-
import json
109
from abc import ABC, abstractmethod
10+
from json import JSONDecodeError
1111
from typing import Any, List, Optional
1212

13-
import jwt
13+
from jwt import ExpiredSignatureError
1414

1515
from arangoasync import errno, logger
1616
from arangoasync.auth import Auth, JwtToken
@@ -26,6 +26,12 @@
2626
from arangoasync.request import Method, Request
2727
from arangoasync.resolver import HostResolver
2828
from arangoasync.response import Response
29+
from arangoasync.serialization import (
30+
DefaultDeserializer,
31+
DefaultSerializer,
32+
Deserializer,
33+
Serializer,
34+
)
2935

3036

3137
class BaseConnection(ABC):
@@ -37,6 +43,10 @@ class BaseConnection(ABC):
3743
http_client (HTTPClient): HTTP client.
3844
db_name (str): Database name.
3945
compression (CompressionManager | None): Compression manager.
46+
serializer (Serializer | None): For custom serialization.
47+
Leave `None` for default.
48+
deserializer (Deserializer | None): For custom deserialization.
49+
Leave `None` for default.
4050
"""
4151

4252
def __init__(
@@ -46,19 +56,33 @@ def __init__(
4656
http_client: HTTPClient,
4757
db_name: str,
4858
compression: Optional[CompressionManager] = None,
59+
serializer: Optional[Serializer] = None,
60+
deserializer: Optional[Deserializer] = None,
4961
) -> None:
5062
self._sessions = sessions
5163
self._db_endpoint = f"/_db/{db_name}"
5264
self._host_resolver = host_resolver
5365
self._http_client = http_client
5466
self._db_name = db_name
5567
self._compression = compression
68+
self._serializer = serializer or DefaultSerializer()
69+
self._deserializer = deserializer or DefaultDeserializer()
5670

5771
@property
5872
def db_name(self) -> str:
5973
"""Return the database name."""
6074
return self._db_name
6175

76+
@property
77+
def serializer(self) -> Serializer:
78+
"""Return the serializer."""
79+
return self._serializer
80+
81+
@property
82+
def deserializer(self) -> Deserializer:
83+
"""Return the deserializer."""
84+
return self._deserializer
85+
6286
@staticmethod
6387
def raise_for_status(request: Request, resp: Response) -> None:
6488
"""Raise an exception based on the response.
@@ -75,8 +99,7 @@ def raise_for_status(request: Request, resp: Response) -> None:
7599
if not resp.is_success:
76100
raise ServerConnectionError(resp, request, "Bad server response.")
77101

78-
@staticmethod
79-
def prep_response(request: Request, resp: Response) -> Response:
102+
def prep_response(self, request: Request, resp: Response) -> Response:
80103
"""Prepare response for return.
81104
82105
Args:
@@ -89,8 +112,8 @@ def prep_response(request: Request, resp: Response) -> Response:
89112
resp.is_success = 200 <= resp.status_code < 300
90113
if not resp.is_success:
91114
try:
92-
body = json.loads(resp.raw_body)
93-
except json.JSONDecodeError as e:
115+
body = self._deserializer.from_bytes(resp.raw_body)
116+
except JSONDecodeError as e:
94117
logger.debug(
95118
f"Failed to decode response body: {e} (from request {request})"
96119
)
@@ -202,6 +225,8 @@ class BasicConnection(BaseConnection):
202225
http_client (HTTPClient): HTTP client.
203226
db_name (str): Database name.
204227
compression (CompressionManager | None): Compression manager.
228+
serializer (Serializer | None): For custom serialization.
229+
deserializer (Deserializer | None): For custom deserialization.
205230
auth (Auth | None): Authentication information.
206231
"""
207232

@@ -212,9 +237,19 @@ def __init__(
212237
http_client: HTTPClient,
213238
db_name: str,
214239
compression: Optional[CompressionManager] = None,
240+
serializer: Optional[Serializer] = None,
241+
deserializer: Optional[Deserializer] = None,
215242
auth: Optional[Auth] = None,
216243
) -> None:
217-
super().__init__(sessions, host_resolver, http_client, db_name, compression)
244+
super().__init__(
245+
sessions,
246+
host_resolver,
247+
http_client,
248+
db_name,
249+
compression,
250+
serializer,
251+
deserializer,
252+
)
218253
self._auth = auth
219254

220255
async def send_request(self, request: Request) -> Response:
@@ -249,6 +284,8 @@ class JwtConnection(BaseConnection):
249284
http_client (HTTPClient): HTTP client.
250285
db_name (str): Database name.
251286
compression (CompressionManager | None): Compression manager.
287+
serializer (Serializer | None): For custom serialization.
288+
deserializer (Deserializer | None): For custom deserialization.
252289
auth (Auth | None): Authentication information.
253290
token (JwtToken | None): JWT token.
254291
@@ -263,10 +300,20 @@ def __init__(
263300
http_client: HTTPClient,
264301
db_name: str,
265302
compression: Optional[CompressionManager] = None,
303+
serializer: Optional[Serializer] = None,
304+
deserializer: Optional[Deserializer] = None,
266305
auth: Optional[Auth] = None,
267306
token: Optional[JwtToken] = None,
268307
) -> None:
269-
super().__init__(sessions, host_resolver, http_client, db_name, compression)
308+
super().__init__(
309+
sessions,
310+
host_resolver,
311+
http_client,
312+
db_name,
313+
compression,
314+
serializer,
315+
deserializer,
316+
)
270317
self._auth = auth
271318
self._expire_leeway: int = 0
272319
self._token: Optional[JwtToken] = token
@@ -306,10 +353,8 @@ async def refresh_token(self) -> None:
306353
if self._auth is None:
307354
raise JWTRefreshError("Auth must be provided to refresh the token.")
308355

309-
auth_data = json.dumps(
356+
auth_data = self._serializer.to_str(
310357
dict(username=self._auth.username, password=self._auth.password),
311-
separators=(",", ":"),
312-
ensure_ascii=False,
313358
)
314359
request = Request(
315360
method=Method.POST,
@@ -330,10 +375,10 @@ async def refresh_token(self) -> None:
330375
f"{resp.status_code} {resp.status_text}"
331376
)
332377

333-
token = json.loads(resp.raw_body)
378+
token = self._deserializer.from_bytes(resp.raw_body)
334379
try:
335380
self.token = JwtToken(token["jwt"])
336-
except jwt.ExpiredSignatureError as e:
381+
except ExpiredSignatureError as e:
337382
raise JWTRefreshError(
338383
"Failed to refresh the JWT token: got an expired token"
339384
) from e
@@ -385,6 +430,8 @@ class JwtSuperuserConnection(BaseConnection):
385430
http_client (HTTPClient): HTTP client.
386431
db_name (str): Database name.
387432
compression (CompressionManager | None): Compression manager.
433+
serializer (Serializer | None): For custom serialization.
434+
deserializer (Deserializer | None): For custom deserialization.
388435
token (JwtToken | None): JWT token.
389436
"""
390437

@@ -395,10 +442,19 @@ def __init__(
395442
http_client: HTTPClient,
396443
db_name: str,
397444
compression: Optional[CompressionManager] = None,
445+
serializer: Optional[Serializer] = None,
446+
deserializer: Optional[Deserializer] = None,
398447
token: Optional[JwtToken] = None,
399448
) -> None:
400-
super().__init__(sessions, host_resolver, http_client, db_name, compression)
401-
self._expire_leeway: int = 0
449+
super().__init__(
450+
sessions,
451+
host_resolver,
452+
http_client,
453+
db_name,
454+
compression,
455+
serializer,
456+
deserializer,
457+
)
402458
self._token: Optional[JwtToken] = token
403459
self._auth_header: Optional[str] = None
404460
self.token = self._token

arangoasync/database.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
"StandardDatabase",
44
]
55

6-
import json
7-
from typing import Any
86

97
from arangoasync.connection import Connection
108
from arangoasync.exceptions import ServerStatusError
119
from arangoasync.executor import ApiExecutor, DefaultApiExecutor
1210
from arangoasync.request import Method, Request
1311
from arangoasync.response import Response
12+
from arangoasync.serialization import Deserializer, Serializer
13+
from arangoasync.typings import Result
14+
from arangoasync.wrapper import ServerStatusInformation
1415

1516

1617
class Database:
@@ -29,25 +30,31 @@ def name(self) -> str:
2930
"""Return the name of the current database."""
3031
return self.connection.db_name
3132

32-
# TODO - user real return type
33-
async def status(self) -> Any:
33+
@property
34+
def serializer(self) -> Serializer:
35+
"""Return the serializer."""
36+
return self._executor.serializer
37+
38+
@property
39+
def deserializer(self) -> Deserializer:
40+
"""Return the deserializer."""
41+
return self._executor.deserializer
42+
43+
async def status(self) -> Result[ServerStatusInformation]:
3444
"""Query the server status.
3545
3646
Returns:
37-
Json: Server status.
47+
ServerStatusInformation: Server status.
3848
3949
Raises:
4050
ServerSatusError: If retrieval fails.
4151
"""
4252
request = Request(method=Method.GET, endpoint="/_admin/status")
4353

44-
# TODO
45-
# - introduce specific return type for response_handler
46-
# - introduce specific serializer and deserializer
47-
def response_handler(resp: Response) -> Any:
54+
def response_handler(resp: Response) -> ServerStatusInformation:
4855
if not resp.is_success:
4956
raise ServerStatusError(resp, request)
50-
return json.loads(resp.raw_body)
57+
return ServerStatusInformation(self.deserializer.from_bytes(resp.raw_body))
5158

5259
return await self._executor.execute(request, response_handler)
5360

arangoasync/executor.py

+9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from arangoasync.connection import Connection
44
from arangoasync.request import Request
55
from arangoasync.response import Response
6+
from arangoasync.serialization import Deserializer, Serializer
67

78
T = TypeVar("T")
89

@@ -27,6 +28,14 @@ def connection(self) -> Connection:
2728
def context(self) -> str:
2829
return "default"
2930

31+
@property
32+
def serializer(self) -> Serializer:
33+
return self._conn.serializer
34+
35+
@property
36+
def deserializer(self) -> Deserializer:
37+
return self._conn.deserializer
38+
3039
async def execute(
3140
self, request: Request, response_handler: Callable[[Response], T]
3241
) -> T:

arangoasync/job.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
__all__ = ["AsyncJob"]
2+
3+
4+
from typing import Generic, TypeVar
5+
6+
T = TypeVar("T")
7+
8+
9+
class AsyncJob(Generic[T]):
10+
"""Job for tracking and retrieving result of an async API execution."""
11+
12+
pass

0 commit comments

Comments
 (0)