Skip to content

Commit 98012ca

Browse files
authored
Adding ArangoClient (#16)
1 parent e642e5b commit 98012ca

File tree

11 files changed

+379
-15
lines changed

11 files changed

+379
-15
lines changed

arangoasync/auth.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class JwtToken:
3333
3434
Raises:
3535
TypeError: If the token type is not str or bytes.
36-
jwt.ExpiredSignatureError: If the token expired.
36+
jwt.exceptions.ExpiredSignatureError: If the token expired.
3737
"""
3838

3939
def __init__(self, token: str) -> None:
@@ -82,7 +82,7 @@ def token(self, token: str) -> None:
8282
"""Set token.
8383
8484
Raises:
85-
jwt.ExpiredSignatureError: If the token expired.
85+
jwt.exceptions.ExpiredSignatureError: If the token expired.
8686
"""
8787
self._token = token
8888
self._validate()

arangoasync/client.py

+201
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
__all__ = ["ArangoClient"]
2+
3+
import asyncio
4+
from typing import Any, Optional, Sequence
5+
6+
from arangoasync.auth import Auth, JwtToken
7+
from arangoasync.compression import CompressionManager
8+
from arangoasync.connection import (
9+
BasicConnection,
10+
Connection,
11+
JwtConnection,
12+
JwtSuperuserConnection,
13+
)
14+
from arangoasync.database import Database
15+
from arangoasync.http import DefaultHTTPClient, HTTPClient
16+
from arangoasync.resolver import HostResolver, get_resolver
17+
from arangoasync.version import __version__
18+
19+
20+
class ArangoClient:
21+
"""ArangoDB client.
22+
23+
Args:
24+
hosts (str | Sequence[str]): Host URL or list of URL's.
25+
In case of a cluster, this would be the list of coordinators.
26+
Which coordinator to use is determined by the `host_resolver`.
27+
host_resolver (str | HostResolver): Host resolver strategy.
28+
This determines how the client will choose which server to use.
29+
Passing a string would configure a resolver with the default settings.
30+
See :class:`DefaultHostResolver <arangoasync.resolver.DefaultHostResolver>`
31+
and :func:`get_resolver <arangoasync.resolver.get_resolver>`
32+
for more information.
33+
If you need more customization, pass a subclass of
34+
:class:`HostResolver <arangoasync.resolver.HostResolver>`.
35+
http_client (HTTPClient | None): HTTP client implementation.
36+
This is the core component that sends requests to the ArangoDB server.
37+
Defaults to :class:`DefaultHttpClient <arangoasync.http.DefaultHTTPClient>`,
38+
but you can fully customize its parameters or even use a different
39+
implementation by subclassing
40+
:class:`HTTPClient <arangoasync.http.HTTPClient>`.
41+
compression (CompressionManager | None): Disabled by default.
42+
Used to compress requests to the server or instruct the server to compress
43+
responses. Enable it by passing an instance of
44+
:class:`DefaultCompressionManager
45+
<arangoasync.compression.DefaultCompressionManager>`
46+
or a subclass of :class:`CompressionManager
47+
<arangoasync.compression.CompressionManager>`.
48+
49+
Raises:
50+
ValueError: If the `host_resolver` is not supported.
51+
"""
52+
53+
def __init__(
54+
self,
55+
hosts: str | Sequence[str] = "http://127.0.0.1:8529",
56+
host_resolver: str | HostResolver = "default",
57+
http_client: Optional[HTTPClient] = None,
58+
compression: Optional[CompressionManager] = None,
59+
) -> None:
60+
self._hosts = [hosts] if isinstance(hosts, str) else hosts
61+
self._host_resolver = (
62+
get_resolver(host_resolver, len(self._hosts))
63+
if isinstance(host_resolver, str)
64+
else host_resolver
65+
)
66+
self._http_client = http_client or DefaultHTTPClient()
67+
self._sessions = [
68+
self._http_client.create_session(host) for host in self._hosts
69+
]
70+
self._compression = compression
71+
72+
def __repr__(self) -> str:
73+
return f"<ArangoClient {','.join(self._hosts)}>"
74+
75+
async def __aenter__(self) -> "ArangoClient":
76+
return self
77+
78+
async def __aexit__(self, *exc: Any) -> None:
79+
await self.close()
80+
81+
@property
82+
def hosts(self) -> Sequence[str]:
83+
"""Return the list of hosts."""
84+
return self._hosts
85+
86+
@property
87+
def host_resolver(self) -> HostResolver:
88+
"""Return the host resolver."""
89+
return self._host_resolver
90+
91+
@property
92+
def compression(self) -> Optional[CompressionManager]:
93+
"""Return the compression manager."""
94+
return self._compression
95+
96+
@property
97+
def sessions(self) -> Sequence[Any]:
98+
"""Return the list of sessions.
99+
100+
You may use this to customize sessions on the fly (for example,
101+
adjust the timeout). Not recommended unless you know what you are doing.
102+
103+
Warning:
104+
Modifying only a subset of sessions may lead to unexpected behavior.
105+
In order to keep the client in a consistent state, you should make sure
106+
all sessions are configured in the same way.
107+
"""
108+
return self._sessions
109+
110+
@property
111+
def version(self) -> str:
112+
"""Return the version of the client."""
113+
return __version__
114+
115+
async def close(self) -> None:
116+
"""Close HTTP sessions."""
117+
await asyncio.gather(*(session.close() for session in self._sessions))
118+
119+
async def db(
120+
self,
121+
name: str,
122+
auth_method: str = "basic",
123+
auth: Optional[Auth] = None,
124+
token: Optional[JwtToken] = None,
125+
verify: bool = False,
126+
compression: Optional[CompressionManager] = None,
127+
) -> Database:
128+
"""Connects to a database and returns and API wrapper.
129+
130+
Args:
131+
name (str): Database name.
132+
auth_method (str): The following methods are supported:
133+
134+
- "basic": HTTP authentication.
135+
Requires the `auth` parameter. The `token` parameter is ignored.
136+
- "jwt": User JWT authentication.
137+
At least one of the `auth` or `token` parameters are required.
138+
If `auth` is provided, but the `token` is not, the token will be
139+
refreshed automatically. This assumes that the clocks of the server
140+
and client are synchronized.
141+
- "superuser": Superuser JWT authentication.
142+
The `token` parameter is required. The `auth` parameter is ignored.
143+
auth (Auth | None): Login information.
144+
token (JwtToken | None): JWT token.
145+
verify (bool): Verify the connection by sending a test request.
146+
compression (CompressionManager | None): Supersedes the client-level
147+
compression settings.
148+
149+
Returns:
150+
Database: Database API wrapper.
151+
152+
Raises:
153+
ValueError: If the authentication is invalid.
154+
ServerConnectionError: If `verify` is `True` and the connection fails.
155+
"""
156+
connection: Connection
157+
if auth_method == "basic":
158+
if auth is None:
159+
raise ValueError("Basic authentication requires the `auth` parameter")
160+
connection = BasicConnection(
161+
sessions=self._sessions,
162+
host_resolver=self._host_resolver,
163+
http_client=self._http_client,
164+
db_name=name,
165+
compression=compression or self._compression,
166+
auth=auth,
167+
)
168+
elif auth_method == "jwt":
169+
if auth is None and token is None:
170+
raise ValueError(
171+
"JWT authentication requires the `auth` or `token` parameter"
172+
)
173+
connection = JwtConnection(
174+
sessions=self._sessions,
175+
host_resolver=self._host_resolver,
176+
http_client=self._http_client,
177+
db_name=name,
178+
compression=compression or self._compression,
179+
auth=auth,
180+
token=token,
181+
)
182+
elif auth_method == "superuser":
183+
if token is None:
184+
raise ValueError(
185+
"Superuser JWT authentication requires the `token` parameter"
186+
)
187+
connection = JwtSuperuserConnection(
188+
sessions=self._sessions,
189+
host_resolver=self._host_resolver,
190+
http_client=self._http_client,
191+
db_name=name,
192+
compression=compression or self._compression,
193+
token=token,
194+
)
195+
else:
196+
raise ValueError(f"Invalid authentication method: {auth_method}")
197+
198+
if verify:
199+
await connection.ping()
200+
201+
return Database(connection)

arangoasync/compression.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def accept_encoding(self) -> str | None:
7272
"""Return the accept encoding.
7373
7474
This is the value of the Accept-Encoding header in the HTTP request.
75-
Currently, only deflate and "gzip" are supported.
75+
Currently, only "deflate" and "gzip" are supported.
7676
7777
Returns:
7878
str: Accept encoding

arangoasync/connection.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__all__ = [
22
"BaseConnection",
33
"BasicConnection",
4+
"Connection",
45
"JwtConnection",
56
"JwtSuperuserConnection",
67
]
@@ -244,9 +245,9 @@ def __init__(
244245
super().__init__(sessions, host_resolver, http_client, db_name, compression)
245246
self._auth = auth
246247
self._expire_leeway: int = 0
247-
self._token: Optional[JwtToken] = None
248+
self._token: Optional[JwtToken] = token
248249
self._auth_header: Optional[str] = None
249-
self.token = token
250+
self.token = self._token
250251

251252
if self._token is None and self._auth is None:
252253
raise ValueError("Either token or auth must be provided.")
@@ -323,6 +324,7 @@ async def send_request(self, request: Request) -> Response:
323324
Response: HTTP response
324325
325326
Raises:
327+
AuthHeaderError: If the authentication header could not be generated.
326328
ArangoClientError: If an error occurred from the client side.
327329
ArangoServerError: If an error occurred from the server side.
328330
"""
@@ -372,9 +374,9 @@ def __init__(
372374
) -> None:
373375
super().__init__(sessions, host_resolver, http_client, db_name, compression)
374376
self._expire_leeway: int = 0
375-
self._token: Optional[JwtToken] = None
377+
self._token: Optional[JwtToken] = token
376378
self._auth_header: Optional[str] = None
377-
self.token = token
379+
self.token = self._token
378380

379381
@property
380382
def token(self) -> Optional[JwtToken]:
@@ -407,6 +409,7 @@ async def send_request(self, request: Request) -> Response:
407409
Response: HTTP response
408410
409411
Raises:
412+
AuthHeaderError: If the authentication header could not be generated.
410413
ArangoClientError: If an error occurred from the client side.
411414
ArangoServerError: If an error occurred from the server side.
412415
"""
@@ -417,3 +420,6 @@ async def send_request(self, request: Request) -> Response:
417420
resp = await self.process_request(request)
418421
self.raise_for_status(request, resp)
419422
return resp
423+
424+
425+
Connection = BasicConnection | JwtConnection | JwtSuperuserConnection

arangoasync/database.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
__all__ = [
2+
"Database",
3+
]
4+
5+
from arangoasync.connection import BaseConnection
6+
7+
8+
class Database:
9+
"""Database API."""
10+
11+
def __init__(self, connection: BaseConnection) -> None:
12+
self._conn = connection
13+
14+
@property
15+
def conn(self) -> BaseConnection:
16+
"""Return the HTTP connection."""
17+
return self._conn

arangoasync/exceptions.py

-4
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,6 @@ class AuthHeaderError(ArangoClientError):
8080
"""The authentication header could not be determined."""
8181

8282

83-
class JWTExpiredError(ArangoClientError):
84-
"""JWT token has expired."""
85-
86-
8783
class JWTRefreshError(ArangoClientError):
8884
"""Failed to refresh the JWT token."""
8985

arangoasync/http.py

+10
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
]
66

77
from abc import ABC, abstractmethod
8+
from ssl import SSLContext, create_default_context
89
from typing import Any, Optional
910

1011
from aiohttp import (
@@ -82,6 +83,10 @@ class AioHTTPClient(HTTPClient):
8283
timeout (aiohttp.ClientTimeout | None): Client timeout settings.
8384
300s total timeout by default for a complete request/response operation.
8485
read_bufsize (int): Size of read buffer (64KB default).
86+
ssl_context (ssl.SSLContext | bool): SSL validation mode.
87+
`True` for default SSL checks (see :func:`ssl.create_default_context`).
88+
`False` disables SSL checks.
89+
Additionally, you can pass a custom :class:`ssl.SSLContext`.
8590
8691
.. _aiohttp:
8792
https://docs.aiohttp.org/en/stable/
@@ -92,6 +97,7 @@ def __init__(
9297
connector: Optional[BaseConnector] = None,
9398
timeout: Optional[ClientTimeout] = None,
9499
read_bufsize: int = 2**16,
100+
ssl_context: bool | SSLContext = True,
95101
) -> None:
96102
self._connector = connector or TCPConnector(
97103
keepalive_timeout=60, # timeout for connection reusing after releasing
@@ -102,6 +108,9 @@ def __init__(
102108
connect=60, # max number of seconds for acquiring a pool connection
103109
)
104110
self._read_bufsize = read_bufsize
111+
self._ssl_context = (
112+
ssl_context if ssl_context is not True else create_default_context()
113+
)
105114

106115
def create_session(self, host: str) -> ClientSession:
107116
"""Return a new session given the base host URL.
@@ -155,6 +164,7 @@ async def send_request(
155164
params=request.normalized_params(),
156165
data=request.data,
157166
auth=auth,
167+
ssl=self._ssl_context,
158168
) as response:
159169
raw_body = await response.read()
160170
return Response(

arangoasync/resolver.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ class HostResolver(ABC):
1515
1616
Args:
1717
host_count (int): Number of hosts.
18-
max_tries (int): Maximum number of attempts to try a host.
18+
max_tries (int | None): Maximum number of attempts to try a host.
19+
Will default to 3 times the number of hosts if not provided.
1920
2021
Raises:
2122
ValueError: If max_tries is less than host_count.
@@ -42,7 +43,7 @@ def get_host_index(self) -> int: # pragma: no cover
4243
raise NotImplementedError
4344

4445
def change_host(self) -> None:
45-
"""If there aer multiple hosts available, switch to the next one."""
46+
"""If there are multiple hosts available, switch to the next one."""
4647
self._index = (self._index + 1) % self.host_count
4748

4849
@property
@@ -57,7 +58,10 @@ def max_tries(self) -> int:
5758

5859

5960
class SingleHostResolver(HostResolver):
60-
"""Single host resolver. Always returns the same host index."""
61+
"""Single host resolver.
62+
63+
Always returns the same host index, unless prompted to change.
64+
"""
6165

6266
def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None:
6367
super().__init__(host_count, max_tries)

docs/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
autodoc_typehints = "none"
2424

2525
intersphinx_mapping = {
26+
"python": ("https://docs.python.org/3", None),
2627
"aiohttp": ("https://docs.aiohttp.org/en/stable/", None),
2728
"jwt": ("https://pyjwt.readthedocs.io/en/stable/", None),
2829
}

0 commit comments

Comments
 (0)