Skip to content

Generic Collection #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions arangoasync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Deserializer,
Serializer,
)
from arangoasync.typings import Json, Jsons
from arangoasync.version import __version__


Expand Down Expand Up @@ -51,14 +52,18 @@ class ArangoClient:
<arangoasync.compression.DefaultCompressionManager>`
or a custom subclass of :class:`CompressionManager
<arangoasync.compression.CompressionManager>`.
serializer (Serializer | None): Custom serializer implementation.
serializer (Serializer | None): Custom JSON serializer implementation.
Leave as `None` to use the default serializer.
See :class:`DefaultSerializer
<arangoasync.serialization.DefaultSerializer>`.
deserializer (Deserializer | None): Custom deserializer implementation.
For custom serialization of collection documents, see :class:`Collection
<arangoasync.collection.Collection>`.
deserializer (Deserializer | None): Custom JSON deserializer implementation.
Leave as `None` to use the default deserializer.
See :class:`DefaultDeserializer
<arangoasync.serialization.DefaultDeserializer>`.
For custom deserialization of collection documents, see :class:`Collection
<arangoasync.collection.Collection>`.

Raises:
ValueError: If the `host_resolver` is not supported.
Expand All @@ -70,8 +75,8 @@ def __init__(
host_resolver: str | HostResolver = "default",
http_client: Optional[HTTPClient] = None,
compression: Optional[CompressionManager] = None,
serializer: Optional[Serializer] = None,
deserializer: Optional[Deserializer] = None,
serializer: Optional[Serializer[Json]] = None,
deserializer: Optional[Deserializer[Json, Jsons]] = None,
) -> None:
self._hosts = [hosts] if isinstance(hosts, str) else hosts
self._host_resolver = (
Expand All @@ -84,8 +89,10 @@ def __init__(
self._http_client.create_session(host) for host in self._hosts
]
self._compression = compression
self._serializer = serializer or DefaultSerializer()
self._deserializer = deserializer or DefaultDeserializer()
self._serializer: Serializer[Json] = serializer or DefaultSerializer()
self._deserializer: Deserializer[Json, Jsons] = (
deserializer or DefaultDeserializer()
)

def __repr__(self) -> str:
return f"<ArangoClient {','.join(self._hosts)}>"
Expand Down Expand Up @@ -142,8 +149,8 @@ async def db(
token: Optional[JwtToken] = None,
verify: bool = False,
compression: Optional[CompressionManager] = None,
serializer: Optional[Serializer] = None,
deserializer: Optional[Deserializer] = None,
serializer: Optional[Serializer[Json]] = None,
deserializer: Optional[Deserializer[Json, Jsons]] = None,
) -> StandardDatabase:
"""Connects to a database and returns and API wrapper.

Expand Down Expand Up @@ -178,6 +185,7 @@ async def db(
ServerConnectionError: If `verify` is `True` and the connection fails.
"""
connection: Connection

if auth_method == "basic":
if auth is None:
raise ValueError("Basic authentication requires the `auth` parameter")
Expand Down
205 changes: 205 additions & 0 deletions arangoasync/collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
__all__ = ["Collection", "Collection", "StandardCollection"]


from enum import Enum
from typing import Generic, Optional, Tuple, TypeVar

from arangoasync.errno import HTTP_NOT_FOUND, HTTP_PRECONDITION_FAILED
from arangoasync.exceptions import (
DocumentGetError,
DocumentParseError,
DocumentRevisionError,
)
from arangoasync.executor import ApiExecutor
from arangoasync.request import Method, Request
from arangoasync.response import Response
from arangoasync.serialization import Deserializer, Serializer
from arangoasync.typings import Json, Result

T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")


class CollectionType(Enum):
"""Collection types."""

DOCUMENT = 2
EDGE = 3


class Collection(Generic[T, U, V]):
"""Base class for collection API wrappers.

Args:
executor (ApiExecutor): API executor.
name (str): Collection name
doc_serializer (Serializer): Document serializer.
doc_deserializer (Deserializer): Document deserializer.
"""

def __init__(
self,
executor: ApiExecutor,
name: str,
doc_serializer: Serializer[T],
doc_deserializer: Deserializer[U, V],
) -> None:
self._executor = executor
self._name = name
self._doc_serializer = doc_serializer
self._doc_deserializer = doc_deserializer
self._id_prefix = f"{self._name}/"

def __repr__(self) -> str:
return f"<StandardCollection {self.name}>"

def _validate_id(self, doc_id: str) -> str:
"""Check the collection name in the document ID.

Args:
doc_id (str): Document ID.

Returns:
str: Verified document ID.

Raises:
DocumentParseError: On bad collection name.
"""
if not doc_id.startswith(self._id_prefix):
raise DocumentParseError(f'Bad collection name in document ID "{doc_id}"')
return doc_id

def _extract_id(self, body: Json) -> str:
"""Extract the document ID from document body.

Args:
body (dict): Document body.

Returns:
str: Document ID.

Raises:
DocumentParseError: On missing ID and key.
"""
try:
if "_id" in body:
return self._validate_id(body["_id"])
else:
key: str = body["_key"]
return self._id_prefix + key
except KeyError:
raise DocumentParseError('Field "_key" or "_id" required')

def _prep_from_doc(
self,
document: str | Json,
rev: Optional[str] = None,
check_rev: bool = False,
) -> Tuple[str, Json]:
"""Prepare document ID, body and request headers before a query.

Args:
document (str | dict): Document ID, key or body.
rev (str | None): Document revision.
check_rev (bool): Whether to check the revision.

Returns:
Document ID and request headers.

Raises:
DocumentParseError: On missing ID and key.
TypeError: On bad document type.
"""
if isinstance(document, dict):
doc_id = self._extract_id(document)
rev = rev or document.get("_rev")
elif isinstance(document, str):
if "/" in document:
doc_id = self._validate_id(document)
else:
doc_id = self._id_prefix + document
else:
raise TypeError("Document must be str or a dict")

if not check_rev or rev is None:
return doc_id, {}
else:
return doc_id, {"If-Match": rev}

@property
def name(self) -> str:
"""Return the name of the collection.

Returns:
str: Collection name.
"""
return self._name


class StandardCollection(Collection[T, U, V]):
"""Standard collection API wrapper.

Args:
executor (ApiExecutor): API executor.
name (str): Collection name
doc_serializer (Serializer): Document serializer.
doc_deserializer (Deserializer): Document deserializer.
"""

def __init__(
self,
executor: ApiExecutor,
name: str,
doc_serializer: Serializer[T],
doc_deserializer: Deserializer[U, V],
) -> None:
super().__init__(executor, name, doc_serializer, doc_deserializer)

async def get(
self,
document: str | Json,
rev: Optional[str] = None,
check_rev: bool = True,
allow_dirty_read: bool = False,
) -> Result[Optional[U]]:
"""Return a document.

Args:
document (str | dict): Document ID, key or body.
Document body must contain the "_id" or "_key" field.
rev (str | None): Expected document revision. Overrides the
value of "_rev" field in **document** if present.
check_rev (bool): If set to True, revision of **document** (if given)
is compared against the revision of target document.
allow_dirty_read (bool): Allow reads from followers in a cluster.

Returns:
Document or None if not found.

Raises:
DocumentRevisionError: If the revision is incorrect.
DocumentGetError: If retrieval fails.
"""
handle, headers = self._prep_from_doc(document, rev, check_rev)

if allow_dirty_read:
headers["x-arango-allow-dirty-read"] = "true"

request = Request(
method=Method.GET,
endpoint=f"/_api/document/{handle}",
headers=headers,
)

def response_handler(resp: Response) -> Optional[U]:
if resp.is_success:
return self._doc_deserializer.loads(resp.raw_body)
elif resp.error_code == HTTP_NOT_FOUND:
return None
elif resp.error_code == HTTP_PRECONDITION_FAILED:
raise DocumentRevisionError(resp, request)
else:
raise DocumentGetError(resp, request)

return await self._executor.execute(request, response_handler)
Loading
Loading