diff --git a/arangoasync/auth.py b/arangoasync/auth.py index 9b5da9c..5a1ab04 100644 --- a/arangoasync/auth.py +++ b/arangoasync/auth.py @@ -33,7 +33,7 @@ class JwtToken: Raises: TypeError: If the token type is not str or bytes. - jwt.ExpiredSignatureError: If the token expired. + jwt.exceptions.ExpiredSignatureError: If the token expired. """ def __init__(self, token: str) -> None: @@ -82,7 +82,7 @@ def token(self, token: str) -> None: """Set token. Raises: - jwt.ExpiredSignatureError: If the token expired. + jwt.exceptions.ExpiredSignatureError: If the token expired. """ self._token = token self._validate() diff --git a/arangoasync/client.py b/arangoasync/client.py new file mode 100644 index 0000000..ab54426 --- /dev/null +++ b/arangoasync/client.py @@ -0,0 +1,201 @@ +__all__ = ["ArangoClient"] + +import asyncio +from typing import Any, Optional, Sequence + +from arangoasync.auth import Auth, JwtToken +from arangoasync.compression import CompressionManager +from arangoasync.connection import ( + BasicConnection, + Connection, + JwtConnection, + JwtSuperuserConnection, +) +from arangoasync.database import Database +from arangoasync.http import DefaultHTTPClient, HTTPClient +from arangoasync.resolver import HostResolver, get_resolver +from arangoasync.version import __version__ + + +class ArangoClient: + """ArangoDB client. + + Args: + hosts (str | Sequence[str]): Host URL or list of URL's. + In case of a cluster, this would be the list of coordinators. + Which coordinator to use is determined by the `host_resolver`. + host_resolver (str | HostResolver): Host resolver strategy. + This determines how the client will choose which server to use. + Passing a string would configure a resolver with the default settings. + See :class:`DefaultHostResolver ` + and :func:`get_resolver ` + for more information. + If you need more customization, pass a subclass of + :class:`HostResolver `. + http_client (HTTPClient | None): HTTP client implementation. + This is the core component that sends requests to the ArangoDB server. + Defaults to :class:`DefaultHttpClient `, + but you can fully customize its parameters or even use a different + implementation by subclassing + :class:`HTTPClient `. + compression (CompressionManager | None): Disabled by default. + Used to compress requests to the server or instruct the server to compress + responses. Enable it by passing an instance of + :class:`DefaultCompressionManager + ` + or a subclass of :class:`CompressionManager + `. + + Raises: + ValueError: If the `host_resolver` is not supported. + """ + + def __init__( + self, + hosts: str | Sequence[str] = "http://127.0.0.1:8529", + host_resolver: str | HostResolver = "default", + http_client: Optional[HTTPClient] = None, + compression: Optional[CompressionManager] = None, + ) -> None: + self._hosts = [hosts] if isinstance(hosts, str) else hosts + self._host_resolver = ( + get_resolver(host_resolver, len(self._hosts)) + if isinstance(host_resolver, str) + else host_resolver + ) + self._http_client = http_client or DefaultHTTPClient() + self._sessions = [ + self._http_client.create_session(host) for host in self._hosts + ] + self._compression = compression + + def __repr__(self) -> str: + return f"" + + async def __aenter__(self) -> "ArangoClient": + return self + + async def __aexit__(self, *exc: Any) -> None: + await self.close() + + @property + def hosts(self) -> Sequence[str]: + """Return the list of hosts.""" + return self._hosts + + @property + def host_resolver(self) -> HostResolver: + """Return the host resolver.""" + return self._host_resolver + + @property + def compression(self) -> Optional[CompressionManager]: + """Return the compression manager.""" + return self._compression + + @property + def sessions(self) -> Sequence[Any]: + """Return the list of sessions. + + You may use this to customize sessions on the fly (for example, + adjust the timeout). Not recommended unless you know what you are doing. + + Warning: + Modifying only a subset of sessions may lead to unexpected behavior. + In order to keep the client in a consistent state, you should make sure + all sessions are configured in the same way. + """ + return self._sessions + + @property + def version(self) -> str: + """Return the version of the client.""" + return __version__ + + async def close(self) -> None: + """Close HTTP sessions.""" + await asyncio.gather(*(session.close() for session in self._sessions)) + + async def db( + self, + name: str, + auth_method: str = "basic", + auth: Optional[Auth] = None, + token: Optional[JwtToken] = None, + verify: bool = False, + compression: Optional[CompressionManager] = None, + ) -> Database: + """Connects to a database and returns and API wrapper. + + Args: + name (str): Database name. + auth_method (str): The following methods are supported: + + - "basic": HTTP authentication. + Requires the `auth` parameter. The `token` parameter is ignored. + - "jwt": User JWT authentication. + At least one of the `auth` or `token` parameters are required. + If `auth` is provided, but the `token` is not, the token will be + refreshed automatically. This assumes that the clocks of the server + and client are synchronized. + - "superuser": Superuser JWT authentication. + The `token` parameter is required. The `auth` parameter is ignored. + auth (Auth | None): Login information. + token (JwtToken | None): JWT token. + verify (bool): Verify the connection by sending a test request. + compression (CompressionManager | None): Supersedes the client-level + compression settings. + + Returns: + Database: Database API wrapper. + + Raises: + ValueError: If the authentication is invalid. + ServerConnectionError: If `verify` is `True` and the connection fails. + """ + connection: Connection + if auth_method == "basic": + if auth is None: + raise ValueError("Basic authentication requires the `auth` parameter") + connection = BasicConnection( + sessions=self._sessions, + host_resolver=self._host_resolver, + http_client=self._http_client, + db_name=name, + compression=compression or self._compression, + auth=auth, + ) + elif auth_method == "jwt": + if auth is None and token is None: + raise ValueError( + "JWT authentication requires the `auth` or `token` parameter" + ) + connection = JwtConnection( + sessions=self._sessions, + host_resolver=self._host_resolver, + http_client=self._http_client, + db_name=name, + compression=compression or self._compression, + auth=auth, + token=token, + ) + elif auth_method == "superuser": + if token is None: + raise ValueError( + "Superuser JWT authentication requires the `token` parameter" + ) + connection = JwtSuperuserConnection( + sessions=self._sessions, + host_resolver=self._host_resolver, + http_client=self._http_client, + db_name=name, + compression=compression or self._compression, + token=token, + ) + else: + raise ValueError(f"Invalid authentication method: {auth_method}") + + if verify: + await connection.ping() + + return Database(connection) diff --git a/arangoasync/compression.py b/arangoasync/compression.py index d7b260a..adc3957 100644 --- a/arangoasync/compression.py +++ b/arangoasync/compression.py @@ -72,7 +72,7 @@ def accept_encoding(self) -> str | None: """Return the accept encoding. This is the value of the Accept-Encoding header in the HTTP request. - Currently, only deflate and "gzip" are supported. + Currently, only "deflate" and "gzip" are supported. Returns: str: Accept encoding diff --git a/arangoasync/connection.py b/arangoasync/connection.py index cb52a4c..0d342de 100644 --- a/arangoasync/connection.py +++ b/arangoasync/connection.py @@ -1,6 +1,7 @@ __all__ = [ "BaseConnection", "BasicConnection", + "Connection", "JwtConnection", "JwtSuperuserConnection", ] @@ -244,9 +245,9 @@ def __init__( super().__init__(sessions, host_resolver, http_client, db_name, compression) self._auth = auth self._expire_leeway: int = 0 - self._token: Optional[JwtToken] = None + self._token: Optional[JwtToken] = token self._auth_header: Optional[str] = None - self.token = token + self.token = self._token if self._token is None and self._auth is None: raise ValueError("Either token or auth must be provided.") @@ -323,6 +324,7 @@ async def send_request(self, request: Request) -> Response: Response: HTTP response Raises: + AuthHeaderError: If the authentication header could not be generated. ArangoClientError: If an error occurred from the client side. ArangoServerError: If an error occurred from the server side. """ @@ -372,9 +374,9 @@ def __init__( ) -> None: super().__init__(sessions, host_resolver, http_client, db_name, compression) self._expire_leeway: int = 0 - self._token: Optional[JwtToken] = None + self._token: Optional[JwtToken] = token self._auth_header: Optional[str] = None - self.token = token + self.token = self._token @property def token(self) -> Optional[JwtToken]: @@ -407,6 +409,7 @@ async def send_request(self, request: Request) -> Response: Response: HTTP response Raises: + AuthHeaderError: If the authentication header could not be generated. ArangoClientError: If an error occurred from the client side. ArangoServerError: If an error occurred from the server side. """ @@ -417,3 +420,6 @@ async def send_request(self, request: Request) -> Response: resp = await self.process_request(request) self.raise_for_status(request, resp) return resp + + +Connection = BasicConnection | JwtConnection | JwtSuperuserConnection diff --git a/arangoasync/database.py b/arangoasync/database.py new file mode 100644 index 0000000..8a6e52a --- /dev/null +++ b/arangoasync/database.py @@ -0,0 +1,17 @@ +__all__ = [ + "Database", +] + +from arangoasync.connection import BaseConnection + + +class Database: + """Database API.""" + + def __init__(self, connection: BaseConnection) -> None: + self._conn = connection + + @property + def conn(self) -> BaseConnection: + """Return the HTTP connection.""" + return self._conn diff --git a/arangoasync/exceptions.py b/arangoasync/exceptions.py index e816a1b..b0cd62c 100644 --- a/arangoasync/exceptions.py +++ b/arangoasync/exceptions.py @@ -80,10 +80,6 @@ class AuthHeaderError(ArangoClientError): """The authentication header could not be determined.""" -class JWTExpiredError(ArangoClientError): - """JWT token has expired.""" - - class JWTRefreshError(ArangoClientError): """Failed to refresh the JWT token.""" diff --git a/arangoasync/http.py b/arangoasync/http.py index 7fba5c2..02b88da 100644 --- a/arangoasync/http.py +++ b/arangoasync/http.py @@ -5,6 +5,7 @@ ] from abc import ABC, abstractmethod +from ssl import SSLContext, create_default_context from typing import Any, Optional from aiohttp import ( @@ -82,6 +83,10 @@ class AioHTTPClient(HTTPClient): timeout (aiohttp.ClientTimeout | None): Client timeout settings. 300s total timeout by default for a complete request/response operation. read_bufsize (int): Size of read buffer (64KB default). + ssl_context (ssl.SSLContext | bool): SSL validation mode. + `True` for default SSL checks (see :func:`ssl.create_default_context`). + `False` disables SSL checks. + Additionally, you can pass a custom :class:`ssl.SSLContext`. .. _aiohttp: https://docs.aiohttp.org/en/stable/ @@ -92,6 +97,7 @@ def __init__( connector: Optional[BaseConnector] = None, timeout: Optional[ClientTimeout] = None, read_bufsize: int = 2**16, + ssl_context: bool | SSLContext = True, ) -> None: self._connector = connector or TCPConnector( keepalive_timeout=60, # timeout for connection reusing after releasing @@ -102,6 +108,9 @@ def __init__( connect=60, # max number of seconds for acquiring a pool connection ) self._read_bufsize = read_bufsize + self._ssl_context = ( + ssl_context if ssl_context is not True else create_default_context() + ) def create_session(self, host: str) -> ClientSession: """Return a new session given the base host URL. @@ -155,6 +164,7 @@ async def send_request( params=request.normalized_params(), data=request.data, auth=auth, + ssl=self._ssl_context, ) as response: raw_body = await response.read() return Response( diff --git a/arangoasync/resolver.py b/arangoasync/resolver.py index 1aa2bd8..ab1e2c2 100644 --- a/arangoasync/resolver.py +++ b/arangoasync/resolver.py @@ -15,7 +15,8 @@ class HostResolver(ABC): Args: host_count (int): Number of hosts. - max_tries (int): Maximum number of attempts to try a host. + max_tries (int | None): Maximum number of attempts to try a host. + Will default to 3 times the number of hosts if not provided. Raises: ValueError: If max_tries is less than host_count. @@ -42,7 +43,7 @@ def get_host_index(self) -> int: # pragma: no cover raise NotImplementedError def change_host(self) -> None: - """If there aer multiple hosts available, switch to the next one.""" + """If there are multiple hosts available, switch to the next one.""" self._index = (self._index + 1) % self.host_count @property @@ -57,7 +58,10 @@ def max_tries(self) -> int: class SingleHostResolver(HostResolver): - """Single host resolver. Always returns the same host index.""" + """Single host resolver. + + Always returns the same host index, unless prompted to change. + """ def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None: super().__init__(host_count, max_tries) diff --git a/docs/conf.py b/docs/conf.py index 6dae081..163bc1d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -23,6 +23,7 @@ autodoc_typehints = "none" intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), "aiohttp": ("https://docs.aiohttp.org/en/stable/", None), "jwt": ("https://pyjwt.readthedocs.io/en/stable/", None), } diff --git a/docs/specs.rst b/docs/specs.rst index 29ba812..d9f6ad7 100644 --- a/docs/specs.rst +++ b/docs/specs.rst @@ -4,14 +4,20 @@ API Specification This page contains the specification for all classes and methods available in python-arango-async. +.. automodule:: arangoasync.client + :members: + .. automodule:: arangoasync.auth :members: +.. automodule:: arangoasync.compression + :members: + .. automodule:: arangoasync.connection :members: .. automodule:: arangoasync.exceptions - :members: ArangoError, ArangoClientError + :members: .. automodule:: arangoasync.http :members: diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..f8fc7a7 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,123 @@ +import pytest + +from arangoasync.auth import Auth, JwtToken +from arangoasync.client import ArangoClient +from arangoasync.compression import DefaultCompressionManager +from arangoasync.http import DefaultHTTPClient +from arangoasync.resolver import DefaultHostResolver, RoundRobinHostResolver +from arangoasync.version import __version__ + + +@pytest.mark.asyncio +async def test_client_attributes(monkeypatch): + hosts = ["http://127.0.0.1:8529", "http://localhost:8529"] + + async with ArangoClient(hosts=hosts[0]) as client: + assert client.version == __version__ + assert client.hosts == [hosts[0]] + assert repr(client) == f"" + assert isinstance(client.host_resolver, DefaultHostResolver) + assert client.compression is None + assert len(client.sessions) == 1 + + with pytest.raises(ValueError): + async with ArangoClient(hosts=hosts, host_resolver="invalid") as _: + pass + + http_client = DefaultHTTPClient() + create_session = 0 + close_session = 0 + + class MockSession: + async def close(self): + nonlocal close_session + close_session += 1 + + def mock_method(*args, **kwargs): + nonlocal create_session + create_session += 1 + return MockSession() + + monkeypatch.setattr(http_client, "create_session", mock_method) + async with ArangoClient( + hosts=hosts, + host_resolver="roundrobin", + http_client=http_client, + compression=DefaultCompressionManager(threshold=5000), + ) as client: + assert repr(client) == f"" + assert isinstance(client.host_resolver, RoundRobinHostResolver) + assert isinstance(client.compression, DefaultCompressionManager) + assert client.compression.threshold == 5000 + assert len(client.sessions) == len(hosts) + assert create_session == 2 + assert close_session == 2 + + +@pytest.mark.asyncio +async def test_client_bad_auth_method(url, sys_db_name, root, password): + async with ArangoClient(hosts=url) as client: + with pytest.raises(ValueError): + await client.db(sys_db_name, auth_method="invalid") + + +@pytest.mark.asyncio +async def test_client_basic_auth(url, sys_db_name, root, password): + auth = Auth(username=root, password=password) + + # successful authentication + async with ArangoClient(hosts=url) as client: + await client.db(sys_db_name, auth_method="basic", auth=auth, verify=True) + + # auth missing + async with ArangoClient(hosts=url) as client: + with pytest.raises(ValueError): + await client.db( + sys_db_name, + auth_method="basic", + auth=None, + token=JwtToken.generate_token("test"), + verify=True, + ) + + +@pytest.mark.asyncio +async def test_client_jwt_auth(url, sys_db_name, root, password): + auth = Auth(username=root, password=password) + token: JwtToken + + # successful authentication with auth only + async with ArangoClient(hosts=url) as client: + db = await client.db(sys_db_name, auth_method="jwt", auth=auth, verify=True) + token = db.conn.token + + # successful authentication with token only + async with ArangoClient(hosts=url) as client: + await client.db(sys_db_name, auth_method="jwt", token=token, verify=True) + + # successful authentication with both + async with ArangoClient(hosts=url) as client: + await client.db( + sys_db_name, auth_method="jwt", auth=auth, token=token, verify=True + ) + + # auth and token missing + async with ArangoClient(hosts=url) as client: + with pytest.raises(ValueError): + await client.db(sys_db_name, auth_method="jwt", verify=True) + + +@pytest.mark.asyncio +async def test_client_jwt_superuser_auth(url, sys_db_name, root, password, token): + auth = Auth(username=root, password=password) + + # successful authentication + async with ArangoClient(hosts=url) as client: + await client.db(sys_db_name, auth_method="superuser", token=token, verify=True) + + # token missing + async with ArangoClient(hosts=url) as client: + with pytest.raises(ValueError): + await client.db( + sys_db_name, auth_method="superuser", auth=auth, verify=True + )