Skip to content

Commit 8c8b237

Browse files
authored
Generic Collection (#20)
1 parent 8dea5f4 commit 8c8b237

11 files changed

+636
-74
lines changed

arangoasync/client.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Deserializer,
2121
Serializer,
2222
)
23+
from arangoasync.typings import Json, Jsons
2324
from arangoasync.version import __version__
2425

2526

@@ -51,14 +52,18 @@ class ArangoClient:
5152
<arangoasync.compression.DefaultCompressionManager>`
5253
or a custom subclass of :class:`CompressionManager
5354
<arangoasync.compression.CompressionManager>`.
54-
serializer (Serializer | None): Custom serializer implementation.
55+
serializer (Serializer | None): Custom JSON serializer implementation.
5556
Leave as `None` to use the default serializer.
5657
See :class:`DefaultSerializer
5758
<arangoasync.serialization.DefaultSerializer>`.
58-
deserializer (Deserializer | None): Custom deserializer implementation.
59+
For custom serialization of collection documents, see :class:`Collection
60+
<arangoasync.collection.Collection>`.
61+
deserializer (Deserializer | None): Custom JSON deserializer implementation.
5962
Leave as `None` to use the default deserializer.
6063
See :class:`DefaultDeserializer
6164
<arangoasync.serialization.DefaultDeserializer>`.
65+
For custom deserialization of collection documents, see :class:`Collection
66+
<arangoasync.collection.Collection>`.
6267
6368
Raises:
6469
ValueError: If the `host_resolver` is not supported.
@@ -70,8 +75,8 @@ def __init__(
7075
host_resolver: str | HostResolver = "default",
7176
http_client: Optional[HTTPClient] = None,
7277
compression: Optional[CompressionManager] = None,
73-
serializer: Optional[Serializer] = None,
74-
deserializer: Optional[Deserializer] = None,
78+
serializer: Optional[Serializer[Json]] = None,
79+
deserializer: Optional[Deserializer[Json, Jsons]] = None,
7580
) -> None:
7681
self._hosts = [hosts] if isinstance(hosts, str) else hosts
7782
self._host_resolver = (
@@ -84,8 +89,10 @@ def __init__(
8489
self._http_client.create_session(host) for host in self._hosts
8590
]
8691
self._compression = compression
87-
self._serializer = serializer or DefaultSerializer()
88-
self._deserializer = deserializer or DefaultDeserializer()
92+
self._serializer: Serializer[Json] = serializer or DefaultSerializer()
93+
self._deserializer: Deserializer[Json, Jsons] = (
94+
deserializer or DefaultDeserializer()
95+
)
8996

9097
def __repr__(self) -> str:
9198
return f"<ArangoClient {','.join(self._hosts)}>"
@@ -142,8 +149,8 @@ async def db(
142149
token: Optional[JwtToken] = None,
143150
verify: bool = False,
144151
compression: Optional[CompressionManager] = None,
145-
serializer: Optional[Serializer] = None,
146-
deserializer: Optional[Deserializer] = None,
152+
serializer: Optional[Serializer[Json]] = None,
153+
deserializer: Optional[Deserializer[Json, Jsons]] = None,
147154
) -> StandardDatabase:
148155
"""Connects to a database and returns and API wrapper.
149156
@@ -178,6 +185,7 @@ async def db(
178185
ServerConnectionError: If `verify` is `True` and the connection fails.
179186
"""
180187
connection: Connection
188+
181189
if auth_method == "basic":
182190
if auth is None:
183191
raise ValueError("Basic authentication requires the `auth` parameter")

arangoasync/collection.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
__all__ = ["Collection", "Collection", "StandardCollection"]
2+
3+
4+
from enum import Enum
5+
from typing import Generic, Optional, Tuple, TypeVar
6+
7+
from arangoasync.errno import HTTP_NOT_FOUND, HTTP_PRECONDITION_FAILED
8+
from arangoasync.exceptions import (
9+
DocumentGetError,
10+
DocumentParseError,
11+
DocumentRevisionError,
12+
)
13+
from arangoasync.executor import ApiExecutor
14+
from arangoasync.request import Method, Request
15+
from arangoasync.response import Response
16+
from arangoasync.serialization import Deserializer, Serializer
17+
from arangoasync.typings import Json, Result
18+
19+
T = TypeVar("T")
20+
U = TypeVar("U")
21+
V = TypeVar("V")
22+
23+
24+
class CollectionType(Enum):
25+
"""Collection types."""
26+
27+
DOCUMENT = 2
28+
EDGE = 3
29+
30+
31+
class Collection(Generic[T, U, V]):
32+
"""Base class for collection API wrappers.
33+
34+
Args:
35+
executor (ApiExecutor): API executor.
36+
name (str): Collection name
37+
doc_serializer (Serializer): Document serializer.
38+
doc_deserializer (Deserializer): Document deserializer.
39+
"""
40+
41+
def __init__(
42+
self,
43+
executor: ApiExecutor,
44+
name: str,
45+
doc_serializer: Serializer[T],
46+
doc_deserializer: Deserializer[U, V],
47+
) -> None:
48+
self._executor = executor
49+
self._name = name
50+
self._doc_serializer = doc_serializer
51+
self._doc_deserializer = doc_deserializer
52+
self._id_prefix = f"{self._name}/"
53+
54+
def __repr__(self) -> str:
55+
return f"<StandardCollection {self.name}>"
56+
57+
def _validate_id(self, doc_id: str) -> str:
58+
"""Check the collection name in the document ID.
59+
60+
Args:
61+
doc_id (str): Document ID.
62+
63+
Returns:
64+
str: Verified document ID.
65+
66+
Raises:
67+
DocumentParseError: On bad collection name.
68+
"""
69+
if not doc_id.startswith(self._id_prefix):
70+
raise DocumentParseError(f'Bad collection name in document ID "{doc_id}"')
71+
return doc_id
72+
73+
def _extract_id(self, body: Json) -> str:
74+
"""Extract the document ID from document body.
75+
76+
Args:
77+
body (dict): Document body.
78+
79+
Returns:
80+
str: Document ID.
81+
82+
Raises:
83+
DocumentParseError: On missing ID and key.
84+
"""
85+
try:
86+
if "_id" in body:
87+
return self._validate_id(body["_id"])
88+
else:
89+
key: str = body["_key"]
90+
return self._id_prefix + key
91+
except KeyError:
92+
raise DocumentParseError('Field "_key" or "_id" required')
93+
94+
def _prep_from_doc(
95+
self,
96+
document: str | Json,
97+
rev: Optional[str] = None,
98+
check_rev: bool = False,
99+
) -> Tuple[str, Json]:
100+
"""Prepare document ID, body and request headers before a query.
101+
102+
Args:
103+
document (str | dict): Document ID, key or body.
104+
rev (str | None): Document revision.
105+
check_rev (bool): Whether to check the revision.
106+
107+
Returns:
108+
Document ID and request headers.
109+
110+
Raises:
111+
DocumentParseError: On missing ID and key.
112+
TypeError: On bad document type.
113+
"""
114+
if isinstance(document, dict):
115+
doc_id = self._extract_id(document)
116+
rev = rev or document.get("_rev")
117+
elif isinstance(document, str):
118+
if "/" in document:
119+
doc_id = self._validate_id(document)
120+
else:
121+
doc_id = self._id_prefix + document
122+
else:
123+
raise TypeError("Document must be str or a dict")
124+
125+
if not check_rev or rev is None:
126+
return doc_id, {}
127+
else:
128+
return doc_id, {"If-Match": rev}
129+
130+
@property
131+
def name(self) -> str:
132+
"""Return the name of the collection.
133+
134+
Returns:
135+
str: Collection name.
136+
"""
137+
return self._name
138+
139+
140+
class StandardCollection(Collection[T, U, V]):
141+
"""Standard collection API wrapper.
142+
143+
Args:
144+
executor (ApiExecutor): API executor.
145+
name (str): Collection name
146+
doc_serializer (Serializer): Document serializer.
147+
doc_deserializer (Deserializer): Document deserializer.
148+
"""
149+
150+
def __init__(
151+
self,
152+
executor: ApiExecutor,
153+
name: str,
154+
doc_serializer: Serializer[T],
155+
doc_deserializer: Deserializer[U, V],
156+
) -> None:
157+
super().__init__(executor, name, doc_serializer, doc_deserializer)
158+
159+
async def get(
160+
self,
161+
document: str | Json,
162+
rev: Optional[str] = None,
163+
check_rev: bool = True,
164+
allow_dirty_read: bool = False,
165+
) -> Result[Optional[U]]:
166+
"""Return a document.
167+
168+
Args:
169+
document (str | dict): Document ID, key or body.
170+
Document body must contain the "_id" or "_key" field.
171+
rev (str | None): Expected document revision. Overrides the
172+
value of "_rev" field in **document** if present.
173+
check_rev (bool): If set to True, revision of **document** (if given)
174+
is compared against the revision of target document.
175+
allow_dirty_read (bool): Allow reads from followers in a cluster.
176+
177+
Returns:
178+
Document or None if not found.
179+
180+
Raises:
181+
DocumentRevisionError: If the revision is incorrect.
182+
DocumentGetError: If retrieval fails.
183+
"""
184+
handle, headers = self._prep_from_doc(document, rev, check_rev)
185+
186+
if allow_dirty_read:
187+
headers["x-arango-allow-dirty-read"] = "true"
188+
189+
request = Request(
190+
method=Method.GET,
191+
endpoint=f"/_api/document/{handle}",
192+
headers=headers,
193+
)
194+
195+
def response_handler(resp: Response) -> Optional[U]:
196+
if resp.is_success:
197+
return self._doc_deserializer.loads(resp.raw_body)
198+
elif resp.error_code == HTTP_NOT_FOUND:
199+
return None
200+
elif resp.error_code == HTTP_PRECONDITION_FAILED:
201+
raise DocumentRevisionError(resp, request)
202+
else:
203+
raise DocumentGetError(resp, request)
204+
205+
return await self._executor.execute(request, response_handler)

0 commit comments

Comments
 (0)