Skip to content

Adding ArangoClient #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions arangoasync/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class JwtToken:

Raises:
TypeError: If the token type is not str or bytes.
jwt.ExpiredSignatureError: If the token expired.
jwt.exceptions.ExpiredSignatureError: If the token expired.
"""

def __init__(self, token: str) -> None:
Expand Down Expand Up @@ -82,7 +82,7 @@ def token(self, token: str) -> None:
"""Set token.

Raises:
jwt.ExpiredSignatureError: If the token expired.
jwt.exceptions.ExpiredSignatureError: If the token expired.
"""
self._token = token
self._validate()
Expand Down
201 changes: 201 additions & 0 deletions arangoasync/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
__all__ = ["ArangoClient"]

import asyncio
from typing import Any, Optional, Sequence

from arangoasync.auth import Auth, JwtToken
from arangoasync.compression import CompressionManager
from arangoasync.connection import (
BasicConnection,
Connection,
JwtConnection,
JwtSuperuserConnection,
)
from arangoasync.database import Database
from arangoasync.http import DefaultHTTPClient, HTTPClient
from arangoasync.resolver import HostResolver, get_resolver
from arangoasync.version import __version__


class ArangoClient:
"""ArangoDB client.

Args:
hosts (str | Sequence[str]): Host URL or list of URL's.
In case of a cluster, this would be the list of coordinators.
Which coordinator to use is determined by the `host_resolver`.
host_resolver (str | HostResolver): Host resolver strategy.
This determines how the client will choose which server to use.
Passing a string would configure a resolver with the default settings.
See :class:`DefaultHostResolver <arangoasync.resolver.DefaultHostResolver>`
and :func:`get_resolver <arangoasync.resolver.get_resolver>`
for more information.
If you need more customization, pass a subclass of
:class:`HostResolver <arangoasync.resolver.HostResolver>`.
http_client (HTTPClient | None): HTTP client implementation.
This is the core component that sends requests to the ArangoDB server.
Defaults to :class:`DefaultHttpClient <arangoasync.http.DefaultHTTPClient>`,
but you can fully customize its parameters or even use a different
implementation by subclassing
:class:`HTTPClient <arangoasync.http.HTTPClient>`.
compression (CompressionManager | None): Disabled by default.
Used to compress requests to the server or instruct the server to compress
responses. Enable it by passing an instance of
:class:`DefaultCompressionManager
<arangoasync.compression.DefaultCompressionManager>`
or a subclass of :class:`CompressionManager
<arangoasync.compression.CompressionManager>`.

Raises:
ValueError: If the `host_resolver` is not supported.
"""

def __init__(
self,
hosts: str | Sequence[str] = "http://127.0.0.1:8529",
host_resolver: str | HostResolver = "default",
http_client: Optional[HTTPClient] = None,
compression: Optional[CompressionManager] = None,
) -> None:
self._hosts = [hosts] if isinstance(hosts, str) else hosts
self._host_resolver = (
get_resolver(host_resolver, len(self._hosts))
if isinstance(host_resolver, str)
else host_resolver
)
self._http_client = http_client or DefaultHTTPClient()
self._sessions = [
self._http_client.create_session(host) for host in self._hosts
]
self._compression = compression

def __repr__(self) -> str:
return f"<ArangoClient {','.join(self._hosts)}>"

async def __aenter__(self) -> "ArangoClient":
return self

async def __aexit__(self, *exc: Any) -> None:
await self.close()

@property
def hosts(self) -> Sequence[str]:
"""Return the list of hosts."""
return self._hosts

@property
def host_resolver(self) -> HostResolver:
"""Return the host resolver."""
return self._host_resolver

@property
def compression(self) -> Optional[CompressionManager]:
"""Return the compression manager."""
return self._compression

@property
def sessions(self) -> Sequence[Any]:
"""Return the list of sessions.

You may use this to customize sessions on the fly (for example,
adjust the timeout). Not recommended unless you know what you are doing.

Warning:
Modifying only a subset of sessions may lead to unexpected behavior.
In order to keep the client in a consistent state, you should make sure
all sessions are configured in the same way.
"""
return self._sessions

@property
def version(self) -> str:
"""Return the version of the client."""
return __version__

async def close(self) -> None:
"""Close HTTP sessions."""
await asyncio.gather(*(session.close() for session in self._sessions))

async def db(
self,
name: str,
auth_method: str = "basic",
auth: Optional[Auth] = None,
token: Optional[JwtToken] = None,
verify: bool = False,
compression: Optional[CompressionManager] = None,
) -> Database:
"""Connects to a database and returns and API wrapper.

Args:
name (str): Database name.
auth_method (str): The following methods are supported:

- "basic": HTTP authentication.
Requires the `auth` parameter. The `token` parameter is ignored.
- "jwt": User JWT authentication.
At least one of the `auth` or `token` parameters are required.
If `auth` is provided, but the `token` is not, the token will be
refreshed automatically. This assumes that the clocks of the server
and client are synchronized.
- "superuser": Superuser JWT authentication.
The `token` parameter is required. The `auth` parameter is ignored.
auth (Auth | None): Login information.
token (JwtToken | None): JWT token.
verify (bool): Verify the connection by sending a test request.
compression (CompressionManager | None): Supersedes the client-level
compression settings.

Returns:
Database: Database API wrapper.

Raises:
ValueError: If the authentication is invalid.
ServerConnectionError: If `verify` is `True` and the connection fails.
"""
connection: Connection
if auth_method == "basic":
if auth is None:
raise ValueError("Basic authentication requires the `auth` parameter")
connection = BasicConnection(
sessions=self._sessions,
host_resolver=self._host_resolver,
http_client=self._http_client,
db_name=name,
compression=compression or self._compression,
auth=auth,
)
elif auth_method == "jwt":
if auth is None and token is None:
raise ValueError(
"JWT authentication requires the `auth` or `token` parameter"
)
connection = JwtConnection(
sessions=self._sessions,
host_resolver=self._host_resolver,
http_client=self._http_client,
db_name=name,
compression=compression or self._compression,
auth=auth,
token=token,
)
elif auth_method == "superuser":
if token is None:
raise ValueError(
"Superuser JWT authentication requires the `token` parameter"
)
connection = JwtSuperuserConnection(
sessions=self._sessions,
host_resolver=self._host_resolver,
http_client=self._http_client,
db_name=name,
compression=compression or self._compression,
token=token,
)
else:
raise ValueError(f"Invalid authentication method: {auth_method}")

if verify:
await connection.ping()

return Database(connection)
2 changes: 1 addition & 1 deletion arangoasync/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def accept_encoding(self) -> str | None:
"""Return the accept encoding.

This is the value of the Accept-Encoding header in the HTTP request.
Currently, only deflate and "gzip" are supported.
Currently, only "deflate" and "gzip" are supported.

Returns:
str: Accept encoding
Expand Down
14 changes: 10 additions & 4 deletions arangoasync/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = [
"BaseConnection",
"BasicConnection",
"Connection",
"JwtConnection",
"JwtSuperuserConnection",
]
Expand Down Expand Up @@ -244,9 +245,9 @@ def __init__(
super().__init__(sessions, host_resolver, http_client, db_name, compression)
self._auth = auth
self._expire_leeway: int = 0
self._token: Optional[JwtToken] = None
self._token: Optional[JwtToken] = token
self._auth_header: Optional[str] = None
self.token = token
self.token = self._token

if self._token is None and self._auth is None:
raise ValueError("Either token or auth must be provided.")
Expand Down Expand Up @@ -323,6 +324,7 @@ async def send_request(self, request: Request) -> Response:
Response: HTTP response

Raises:
AuthHeaderError: If the authentication header could not be generated.
ArangoClientError: If an error occurred from the client side.
ArangoServerError: If an error occurred from the server side.
"""
Expand Down Expand Up @@ -372,9 +374,9 @@ def __init__(
) -> None:
super().__init__(sessions, host_resolver, http_client, db_name, compression)
self._expire_leeway: int = 0
self._token: Optional[JwtToken] = None
self._token: Optional[JwtToken] = token
self._auth_header: Optional[str] = None
self.token = token
self.token = self._token

@property
def token(self) -> Optional[JwtToken]:
Expand Down Expand Up @@ -407,6 +409,7 @@ async def send_request(self, request: Request) -> Response:
Response: HTTP response

Raises:
AuthHeaderError: If the authentication header could not be generated.
ArangoClientError: If an error occurred from the client side.
ArangoServerError: If an error occurred from the server side.
"""
Expand All @@ -417,3 +420,6 @@ async def send_request(self, request: Request) -> Response:
resp = await self.process_request(request)
self.raise_for_status(request, resp)
return resp


Connection = BasicConnection | JwtConnection | JwtSuperuserConnection
17 changes: 17 additions & 0 deletions arangoasync/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
__all__ = [
"Database",
]

from arangoasync.connection import BaseConnection


class Database:
"""Database API."""

def __init__(self, connection: BaseConnection) -> None:
self._conn = connection

@property
def conn(self) -> BaseConnection:
"""Return the HTTP connection."""
return self._conn
4 changes: 0 additions & 4 deletions arangoasync/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,6 @@ class AuthHeaderError(ArangoClientError):
"""The authentication header could not be determined."""


class JWTExpiredError(ArangoClientError):
"""JWT token has expired."""


class JWTRefreshError(ArangoClientError):
"""Failed to refresh the JWT token."""

Expand Down
10 changes: 10 additions & 0 deletions arangoasync/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
]

from abc import ABC, abstractmethod
from ssl import SSLContext, create_default_context
from typing import Any, Optional

from aiohttp import (
Expand Down Expand Up @@ -82,6 +83,10 @@ class AioHTTPClient(HTTPClient):
timeout (aiohttp.ClientTimeout | None): Client timeout settings.
300s total timeout by default for a complete request/response operation.
read_bufsize (int): Size of read buffer (64KB default).
ssl_context (ssl.SSLContext | bool): SSL validation mode.
`True` for default SSL checks (see :func:`ssl.create_default_context`).
`False` disables SSL checks.
Additionally, you can pass a custom :class:`ssl.SSLContext`.

.. _aiohttp:
https://docs.aiohttp.org/en/stable/
Expand All @@ -92,6 +97,7 @@ def __init__(
connector: Optional[BaseConnector] = None,
timeout: Optional[ClientTimeout] = None,
read_bufsize: int = 2**16,
ssl_context: bool | SSLContext = True,
) -> None:
self._connector = connector or TCPConnector(
keepalive_timeout=60, # timeout for connection reusing after releasing
Expand All @@ -102,6 +108,9 @@ def __init__(
connect=60, # max number of seconds for acquiring a pool connection
)
self._read_bufsize = read_bufsize
self._ssl_context = (
ssl_context if ssl_context is not True else create_default_context()
)

def create_session(self, host: str) -> ClientSession:
"""Return a new session given the base host URL.
Expand Down Expand Up @@ -155,6 +164,7 @@ async def send_request(
params=request.normalized_params(),
data=request.data,
auth=auth,
ssl=self._ssl_context,
) as response:
raw_body = await response.read()
return Response(
Expand Down
10 changes: 7 additions & 3 deletions arangoasync/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class HostResolver(ABC):

Args:
host_count (int): Number of hosts.
max_tries (int): Maximum number of attempts to try a host.
max_tries (int | None): Maximum number of attempts to try a host.
Will default to 3 times the number of hosts if not provided.

Raises:
ValueError: If max_tries is less than host_count.
Expand All @@ -42,7 +43,7 @@ def get_host_index(self) -> int: # pragma: no cover
raise NotImplementedError

def change_host(self) -> None:
"""If there aer multiple hosts available, switch to the next one."""
"""If there are multiple hosts available, switch to the next one."""
self._index = (self._index + 1) % self.host_count

@property
Expand All @@ -57,7 +58,10 @@ def max_tries(self) -> int:


class SingleHostResolver(HostResolver):
"""Single host resolver. Always returns the same host index."""
"""Single host resolver.

Always returns the same host index, unless prompted to change.
"""

def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None:
super().__init__(host_count, max_tries)
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
autodoc_typehints = "none"

intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"aiohttp": ("https://docs.aiohttp.org/en/stable/", None),
"jwt": ("https://pyjwt.readthedocs.io/en/stable/", None),
}
Expand Down
Loading
Loading