|
| 1 | +__all__ = ["ArangoClient"] |
| 2 | + |
| 3 | +import asyncio |
| 4 | +from typing import Any, Optional, Sequence |
| 5 | + |
| 6 | +from arangoasync.auth import Auth, JwtToken |
| 7 | +from arangoasync.compression import CompressionManager |
| 8 | +from arangoasync.connection import ( |
| 9 | + BasicConnection, |
| 10 | + Connection, |
| 11 | + JwtConnection, |
| 12 | + JwtSuperuserConnection, |
| 13 | +) |
| 14 | +from arangoasync.database import Database |
| 15 | +from arangoasync.http import DefaultHTTPClient, HTTPClient |
| 16 | +from arangoasync.resolver import HostResolver, get_resolver |
| 17 | +from arangoasync.version import __version__ |
| 18 | + |
| 19 | + |
| 20 | +class ArangoClient: |
| 21 | + """ArangoDB client. |
| 22 | +
|
| 23 | + Args: |
| 24 | + hosts (str | Sequence[str]): Host URL or list of URL's. |
| 25 | + In case of a cluster, this would be the list of coordinators. |
| 26 | + Which coordinator to use is determined by the `host_resolver`. |
| 27 | + host_resolver (str | HostResolver): Host resolver strategy. |
| 28 | + This determines how the client will choose which server to use. |
| 29 | + Passing a string would configure a resolver with the default settings. |
| 30 | + See :class:`DefaultHostResolver <arangoasync.resolver.DefaultHostResolver>` |
| 31 | + and :func:`get_resolver <arangoasync.resolver.get_resolver>` |
| 32 | + for more information. |
| 33 | + If you need more customization, pass a subclass of |
| 34 | + :class:`HostResolver <arangoasync.resolver.HostResolver>`. |
| 35 | + http_client (HTTPClient | None): HTTP client implementation. |
| 36 | + This is the core component that sends requests to the ArangoDB server. |
| 37 | + Defaults to :class:`DefaultHttpClient <arangoasync.http.DefaultHTTPClient>`, |
| 38 | + but you can fully customize its parameters or even use a different |
| 39 | + implementation by subclassing |
| 40 | + :class:`HTTPClient <arangoasync.http.HTTPClient>`. |
| 41 | + compression (CompressionManager | None): Disabled by default. |
| 42 | + Used to compress requests to the server or instruct the server to compress |
| 43 | + responses. Enable it by passing an instance of |
| 44 | + :class:`DefaultCompressionManager |
| 45 | + <arangoasync.compression.DefaultCompressionManager>` |
| 46 | + or a subclass of :class:`CompressionManager |
| 47 | + <arangoasync.compression.CompressionManager>`. |
| 48 | +
|
| 49 | + Raises: |
| 50 | + ValueError: If the `host_resolver` is not supported. |
| 51 | + """ |
| 52 | + |
| 53 | + def __init__( |
| 54 | + self, |
| 55 | + hosts: str | Sequence[str] = "http://127.0.0.1:8529", |
| 56 | + host_resolver: str | HostResolver = "default", |
| 57 | + http_client: Optional[HTTPClient] = None, |
| 58 | + compression: Optional[CompressionManager] = None, |
| 59 | + ) -> None: |
| 60 | + self._hosts = [hosts] if isinstance(hosts, str) else hosts |
| 61 | + self._host_resolver = ( |
| 62 | + get_resolver(host_resolver, len(self._hosts)) |
| 63 | + if isinstance(host_resolver, str) |
| 64 | + else host_resolver |
| 65 | + ) |
| 66 | + self._http_client = http_client or DefaultHTTPClient() |
| 67 | + self._sessions = [ |
| 68 | + self._http_client.create_session(host) for host in self._hosts |
| 69 | + ] |
| 70 | + self._compression = compression |
| 71 | + |
| 72 | + def __repr__(self) -> str: |
| 73 | + return f"<ArangoClient {','.join(self._hosts)}>" |
| 74 | + |
| 75 | + async def __aenter__(self) -> "ArangoClient": |
| 76 | + return self |
| 77 | + |
| 78 | + async def __aexit__(self, *exc: Any) -> None: |
| 79 | + await self.close() |
| 80 | + |
| 81 | + @property |
| 82 | + def hosts(self) -> Sequence[str]: |
| 83 | + """Return the list of hosts.""" |
| 84 | + return self._hosts |
| 85 | + |
| 86 | + @property |
| 87 | + def host_resolver(self) -> HostResolver: |
| 88 | + """Return the host resolver.""" |
| 89 | + return self._host_resolver |
| 90 | + |
| 91 | + @property |
| 92 | + def compression(self) -> Optional[CompressionManager]: |
| 93 | + """Return the compression manager.""" |
| 94 | + return self._compression |
| 95 | + |
| 96 | + @property |
| 97 | + def sessions(self) -> Sequence[Any]: |
| 98 | + """Return the list of sessions. |
| 99 | +
|
| 100 | + You may use this to customize sessions on the fly (for example, |
| 101 | + adjust the timeout). Not recommended unless you know what you are doing. |
| 102 | +
|
| 103 | + Warning: |
| 104 | + Modifying only a subset of sessions may lead to unexpected behavior. |
| 105 | + In order to keep the client in a consistent state, you should make sure |
| 106 | + all sessions are configured in the same way. |
| 107 | + """ |
| 108 | + return self._sessions |
| 109 | + |
| 110 | + @property |
| 111 | + def version(self) -> str: |
| 112 | + """Return the version of the client.""" |
| 113 | + return __version__ |
| 114 | + |
| 115 | + async def close(self) -> None: |
| 116 | + """Close HTTP sessions.""" |
| 117 | + await asyncio.gather(*(session.close() for session in self._sessions)) |
| 118 | + |
| 119 | + async def db( |
| 120 | + self, |
| 121 | + name: str, |
| 122 | + auth_method: str = "basic", |
| 123 | + auth: Optional[Auth] = None, |
| 124 | + token: Optional[JwtToken] = None, |
| 125 | + verify: bool = False, |
| 126 | + compression: Optional[CompressionManager] = None, |
| 127 | + ) -> Database: |
| 128 | + """Connects to a database and returns and API wrapper. |
| 129 | +
|
| 130 | + Args: |
| 131 | + name (str): Database name. |
| 132 | + auth_method (str): The following methods are supported: |
| 133 | +
|
| 134 | + - "basic": HTTP authentication. |
| 135 | + Requires the `auth` parameter. The `token` parameter is ignored. |
| 136 | + - "jwt": User JWT authentication. |
| 137 | + At least one of the `auth` or `token` parameters are required. |
| 138 | + If `auth` is provided, but the `token` is not, the token will be |
| 139 | + refreshed automatically. This assumes that the clocks of the server |
| 140 | + and client are synchronized. |
| 141 | + - "superuser": Superuser JWT authentication. |
| 142 | + The `token` parameter is required. The `auth` parameter is ignored. |
| 143 | + auth (Auth | None): Login information. |
| 144 | + token (JwtToken | None): JWT token. |
| 145 | + verify (bool): Verify the connection by sending a test request. |
| 146 | + compression (CompressionManager | None): Supersedes the client-level |
| 147 | + compression settings. |
| 148 | +
|
| 149 | + Returns: |
| 150 | + Database: Database API wrapper. |
| 151 | +
|
| 152 | + Raises: |
| 153 | + ValueError: If the authentication is invalid. |
| 154 | + ServerConnectionError: If `verify` is `True` and the connection fails. |
| 155 | + """ |
| 156 | + connection: Connection |
| 157 | + if auth_method == "basic": |
| 158 | + if auth is None: |
| 159 | + raise ValueError("Basic authentication requires the `auth` parameter") |
| 160 | + connection = BasicConnection( |
| 161 | + sessions=self._sessions, |
| 162 | + host_resolver=self._host_resolver, |
| 163 | + http_client=self._http_client, |
| 164 | + db_name=name, |
| 165 | + compression=compression or self._compression, |
| 166 | + auth=auth, |
| 167 | + ) |
| 168 | + elif auth_method == "jwt": |
| 169 | + if auth is None and token is None: |
| 170 | + raise ValueError( |
| 171 | + "JWT authentication requires the `auth` or `token` parameter" |
| 172 | + ) |
| 173 | + connection = JwtConnection( |
| 174 | + sessions=self._sessions, |
| 175 | + host_resolver=self._host_resolver, |
| 176 | + http_client=self._http_client, |
| 177 | + db_name=name, |
| 178 | + compression=compression or self._compression, |
| 179 | + auth=auth, |
| 180 | + token=token, |
| 181 | + ) |
| 182 | + elif auth_method == "superuser": |
| 183 | + if token is None: |
| 184 | + raise ValueError( |
| 185 | + "Superuser JWT authentication requires the `token` parameter" |
| 186 | + ) |
| 187 | + connection = JwtSuperuserConnection( |
| 188 | + sessions=self._sessions, |
| 189 | + host_resolver=self._host_resolver, |
| 190 | + http_client=self._http_client, |
| 191 | + db_name=name, |
| 192 | + compression=compression or self._compression, |
| 193 | + token=token, |
| 194 | + ) |
| 195 | + else: |
| 196 | + raise ValueError(f"Invalid authentication method: {auth_method}") |
| 197 | + |
| 198 | + if verify: |
| 199 | + await connection.ping() |
| 200 | + |
| 201 | + return Database(connection) |
0 commit comments