Skip to content

Custom serialization #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions arangoasync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
from arangoasync.database import StandardDatabase
from arangoasync.http import DefaultHTTPClient, HTTPClient
from arangoasync.resolver import HostResolver, get_resolver
from arangoasync.serialization import (
DefaultDeserializer,
DefaultSerializer,
Deserializer,
Serializer,
)
from arangoasync.version import __version__


Expand Down Expand Up @@ -45,6 +51,14 @@ class ArangoClient:
<arangoasync.compression.DefaultCompressionManager>`
or a custom subclass of :class:`CompressionManager
<arangoasync.compression.CompressionManager>`.
serializer (Serializer | None): Custom serializer implementation.
Leave as `None` to use the default serializer.
See :class:`DefaultSerializer
<arangoasync.serialization.DefaultSerializer>`.
deserializer (Deserializer | None): Custom deserializer implementation.
Leave as `None` to use the default deserializer.
See :class:`DefaultDeserializer
<arangoasync.serialization.DefaultDeserializer>`.

Raises:
ValueError: If the `host_resolver` is not supported.
Expand All @@ -56,6 +70,8 @@ def __init__(
host_resolver: str | HostResolver = "default",
http_client: Optional[HTTPClient] = None,
compression: Optional[CompressionManager] = None,
serializer: Optional[Serializer] = None,
deserializer: Optional[Deserializer] = None,
) -> None:
self._hosts = [hosts] if isinstance(hosts, str) else hosts
self._host_resolver = (
Expand All @@ -68,6 +84,8 @@ def __init__(
self._http_client.create_session(host) for host in self._hosts
]
self._compression = compression
self._serializer = serializer or DefaultSerializer()
self._deserializer = deserializer or DefaultDeserializer()

def __repr__(self) -> str:
return f"<ArangoClient {','.join(self._hosts)}>"
Expand Down Expand Up @@ -124,6 +142,8 @@ async def db(
token: Optional[JwtToken] = None,
verify: bool = False,
compression: Optional[CompressionManager] = None,
serializer: Optional[Serializer] = None,
deserializer: Optional[Deserializer] = None,
) -> StandardDatabase:
"""Connects to a database and returns and API wrapper.

Expand All @@ -145,6 +165,10 @@ async def db(
verify (bool): Verify the connection by sending a test request.
compression (CompressionManager | None): If set, supersedes the
client-level compression settings.
serializer (Serializer | None): If set, supersedes the client-level
serializer.
deserializer (Deserializer | None): If set, supersedes the client-level
deserializer.

Returns:
StandardDatabase: Database API wrapper.
Expand All @@ -163,6 +187,8 @@ async def db(
http_client=self._http_client,
db_name=name,
compression=compression or self._compression,
serializer=serializer or self._serializer,
deserializer=deserializer or self._deserializer,
auth=auth,
)
elif auth_method == "jwt":
Expand All @@ -176,6 +202,8 @@ async def db(
http_client=self._http_client,
db_name=name,
compression=compression or self._compression,
serializer=serializer or self._serializer,
deserializer=deserializer or self._deserializer,
auth=auth,
token=token,
)
Expand All @@ -190,6 +218,8 @@ async def db(
http_client=self._http_client,
db_name=name,
compression=compression or self._compression,
serializer=serializer or self._serializer,
deserializer=deserializer or self._deserializer,
token=token,
)
else:
Expand Down
86 changes: 71 additions & 15 deletions arangoasync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
"JwtSuperuserConnection",
]

import json
from abc import ABC, abstractmethod
from json import JSONDecodeError
from typing import Any, List, Optional

import jwt
from jwt import ExpiredSignatureError

from arangoasync import errno, logger
from arangoasync.auth import Auth, JwtToken
Expand All @@ -26,6 +26,12 @@
from arangoasync.request import Method, Request
from arangoasync.resolver import HostResolver
from arangoasync.response import Response
from arangoasync.serialization import (
DefaultDeserializer,
DefaultSerializer,
Deserializer,
Serializer,
)


class BaseConnection(ABC):
Expand All @@ -37,6 +43,10 @@ class BaseConnection(ABC):
http_client (HTTPClient): HTTP client.
db_name (str): Database name.
compression (CompressionManager | None): Compression manager.
serializer (Serializer | None): For custom serialization.
Leave `None` for default.
deserializer (Deserializer | None): For custom deserialization.
Leave `None` for default.
"""

def __init__(
Expand All @@ -46,19 +56,33 @@ def __init__(
http_client: HTTPClient,
db_name: str,
compression: Optional[CompressionManager] = None,
serializer: Optional[Serializer] = None,
deserializer: Optional[Deserializer] = None,
) -> None:
self._sessions = sessions
self._db_endpoint = f"/_db/{db_name}"
self._host_resolver = host_resolver
self._http_client = http_client
self._db_name = db_name
self._compression = compression
self._serializer = serializer or DefaultSerializer()
self._deserializer = deserializer or DefaultDeserializer()

@property
def db_name(self) -> str:
"""Return the database name."""
return self._db_name

@property
def serializer(self) -> Serializer:
"""Return the serializer."""
return self._serializer

@property
def deserializer(self) -> Deserializer:
"""Return the deserializer."""
return self._deserializer

@staticmethod
def raise_for_status(request: Request, resp: Response) -> None:
"""Raise an exception based on the response.
Expand All @@ -75,8 +99,7 @@ def raise_for_status(request: Request, resp: Response) -> None:
if not resp.is_success:
raise ServerConnectionError(resp, request, "Bad server response.")

@staticmethod
def prep_response(request: Request, resp: Response) -> Response:
def prep_response(self, request: Request, resp: Response) -> Response:
"""Prepare response for return.

Args:
Expand All @@ -89,8 +112,8 @@ def prep_response(request: Request, resp: Response) -> Response:
resp.is_success = 200 <= resp.status_code < 300
if not resp.is_success:
try:
body = json.loads(resp.raw_body)
except json.JSONDecodeError as e:
body = self._deserializer.from_bytes(resp.raw_body)
except JSONDecodeError as e:
logger.debug(
f"Failed to decode response body: {e} (from request {request})"
)
Expand Down Expand Up @@ -202,6 +225,8 @@ class BasicConnection(BaseConnection):
http_client (HTTPClient): HTTP client.
db_name (str): Database name.
compression (CompressionManager | None): Compression manager.
serializer (Serializer | None): For custom serialization.
deserializer (Deserializer | None): For custom deserialization.
auth (Auth | None): Authentication information.
"""

Expand All @@ -212,9 +237,19 @@ def __init__(
http_client: HTTPClient,
db_name: str,
compression: Optional[CompressionManager] = None,
serializer: Optional[Serializer] = None,
deserializer: Optional[Deserializer] = None,
auth: Optional[Auth] = None,
) -> None:
super().__init__(sessions, host_resolver, http_client, db_name, compression)
super().__init__(
sessions,
host_resolver,
http_client,
db_name,
compression,
serializer,
deserializer,
)
self._auth = auth

async def send_request(self, request: Request) -> Response:
Expand Down Expand Up @@ -249,6 +284,8 @@ class JwtConnection(BaseConnection):
http_client (HTTPClient): HTTP client.
db_name (str): Database name.
compression (CompressionManager | None): Compression manager.
serializer (Serializer | None): For custom serialization.
deserializer (Deserializer | None): For custom deserialization.
auth (Auth | None): Authentication information.
token (JwtToken | None): JWT token.

Expand All @@ -263,10 +300,20 @@ def __init__(
http_client: HTTPClient,
db_name: str,
compression: Optional[CompressionManager] = None,
serializer: Optional[Serializer] = None,
deserializer: Optional[Deserializer] = None,
auth: Optional[Auth] = None,
token: Optional[JwtToken] = None,
) -> None:
super().__init__(sessions, host_resolver, http_client, db_name, compression)
super().__init__(
sessions,
host_resolver,
http_client,
db_name,
compression,
serializer,
deserializer,
)
self._auth = auth
self._expire_leeway: int = 0
self._token: Optional[JwtToken] = token
Expand Down Expand Up @@ -306,10 +353,8 @@ async def refresh_token(self) -> None:
if self._auth is None:
raise JWTRefreshError("Auth must be provided to refresh the token.")

auth_data = json.dumps(
auth_data = self._serializer.to_str(
dict(username=self._auth.username, password=self._auth.password),
separators=(",", ":"),
ensure_ascii=False,
)
request = Request(
method=Method.POST,
Expand All @@ -330,10 +375,10 @@ async def refresh_token(self) -> None:
f"{resp.status_code} {resp.status_text}"
)

token = json.loads(resp.raw_body)
token = self._deserializer.from_bytes(resp.raw_body)
try:
self.token = JwtToken(token["jwt"])
except jwt.ExpiredSignatureError as e:
except ExpiredSignatureError as e:
raise JWTRefreshError(
"Failed to refresh the JWT token: got an expired token"
) from e
Expand Down Expand Up @@ -385,6 +430,8 @@ class JwtSuperuserConnection(BaseConnection):
http_client (HTTPClient): HTTP client.
db_name (str): Database name.
compression (CompressionManager | None): Compression manager.
serializer (Serializer | None): For custom serialization.
deserializer (Deserializer | None): For custom deserialization.
token (JwtToken | None): JWT token.
"""

Expand All @@ -395,10 +442,19 @@ def __init__(
http_client: HTTPClient,
db_name: str,
compression: Optional[CompressionManager] = None,
serializer: Optional[Serializer] = None,
deserializer: Optional[Deserializer] = None,
token: Optional[JwtToken] = None,
) -> None:
super().__init__(sessions, host_resolver, http_client, db_name, compression)
self._expire_leeway: int = 0
super().__init__(
sessions,
host_resolver,
http_client,
db_name,
compression,
serializer,
deserializer,
)
self._token: Optional[JwtToken] = token
self._auth_header: Optional[str] = None
self.token = self._token
Expand Down
27 changes: 17 additions & 10 deletions arangoasync/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
"StandardDatabase",
]

import json
from typing import Any

from arangoasync.connection import Connection
from arangoasync.exceptions import ServerStatusError
from arangoasync.executor import ApiExecutor, DefaultApiExecutor
from arangoasync.request import Method, Request
from arangoasync.response import Response
from arangoasync.serialization import Deserializer, Serializer
from arangoasync.typings import Result
from arangoasync.wrapper import ServerStatusInformation


class Database:
Expand All @@ -29,25 +30,31 @@ def name(self) -> str:
"""Return the name of the current database."""
return self.connection.db_name

# TODO - user real return type
async def status(self) -> Any:
@property
def serializer(self) -> Serializer:
"""Return the serializer."""
return self._executor.serializer

@property
def deserializer(self) -> Deserializer:
"""Return the deserializer."""
return self._executor.deserializer

async def status(self) -> Result[ServerStatusInformation]:
"""Query the server status.

Returns:
Json: Server status.
ServerStatusInformation: Server status.

Raises:
ServerSatusError: If retrieval fails.
"""
request = Request(method=Method.GET, endpoint="/_admin/status")

# TODO
# - introduce specific return type for response_handler
# - introduce specific serializer and deserializer
def response_handler(resp: Response) -> Any:
def response_handler(resp: Response) -> ServerStatusInformation:
if not resp.is_success:
raise ServerStatusError(resp, request)
return json.loads(resp.raw_body)
return ServerStatusInformation(self.deserializer.from_bytes(resp.raw_body))

return await self._executor.execute(request, response_handler)

Expand Down
9 changes: 9 additions & 0 deletions arangoasync/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from arangoasync.connection import Connection
from arangoasync.request import Request
from arangoasync.response import Response
from arangoasync.serialization import Deserializer, Serializer

T = TypeVar("T")

Expand All @@ -27,6 +28,14 @@ def connection(self) -> Connection:
def context(self) -> str:
return "default"

@property
def serializer(self) -> Serializer:
return self._conn.serializer

@property
def deserializer(self) -> Deserializer:
return self._conn.deserializer

async def execute(
self, request: Request, response_handler: Callable[[Response], T]
) -> T:
Expand Down
12 changes: 12 additions & 0 deletions arangoasync/job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
__all__ = ["AsyncJob"]


from typing import Generic, TypeVar

T = TypeVar("T")


class AsyncJob(Generic[T]):
"""Job for tracking and retrieving result of an async API execution."""

pass
Loading
Loading