Skip to content

Commit 7789567

Browse files
authored
BasicConnection (#10)
* Introducing BasicConnection * BasicConnection supports authentication and compression * BasicConnection supports authentication and compression * Fixing linter issues
1 parent 41a9bda commit 7789567

17 files changed

+788
-42
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,4 @@ repos:
3838
hooks:
3939
- id: mypy
4040
files: ^arangoasync/
41-
additional_dependencies: ['types-requests', "types-setuptools"]
41+
additional_dependencies: ["types-requests", "types-setuptools"]

arangoasync/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1+
import logging
2+
13
from .version import __version__
4+
5+
logger = logging.getLogger(__name__)

arangoasync/auth.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
__all__ = [
2+
"Auth",
3+
"JwtToken",
4+
]
5+
6+
from dataclasses import dataclass
7+
8+
import jwt
9+
10+
11+
@dataclass
12+
class Auth:
13+
"""Authentication details for the ArangoDB instance.
14+
15+
Attributes:
16+
username (str): Username.
17+
password (str): Password.
18+
encoding (str): Encoding for the password (default: utf-8)
19+
"""
20+
21+
username: str
22+
password: str
23+
encoding: str = "utf-8"
24+
25+
26+
class JwtToken:
27+
"""JWT token.
28+
29+
Args:
30+
token (str | bytes): JWT token.
31+
32+
Raises:
33+
TypeError: If the token type is not str or bytes.
34+
JWTExpiredError: If the token expired.
35+
"""
36+
37+
def __init__(self, token: str | bytes) -> None:
38+
self._token = token
39+
self._validate()
40+
41+
@property
42+
def token(self) -> str | bytes:
43+
"""Get token."""
44+
return self._token
45+
46+
@token.setter
47+
def token(self, token: str | bytes) -> None:
48+
"""Set token.
49+
50+
Raises:
51+
jwt.ExpiredSignatureError: If the token expired.
52+
"""
53+
self._token = token
54+
self._validate()
55+
56+
def _validate(self) -> None:
57+
"""Validate the token."""
58+
if type(self._token) not in (str, bytes):
59+
raise TypeError("Token must be str or bytes")
60+
61+
jwt_payload = jwt.decode(
62+
self._token,
63+
issuer="arangodb",
64+
algorithms=["HS256"],
65+
options={
66+
"require_exp": True,
67+
"require_iat": True,
68+
"verify_iat": True,
69+
"verify_exp": True,
70+
"verify_signature": False,
71+
},
72+
)
73+
74+
self._token_exp = jwt_payload["exp"]

arangoasync/compression.py

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
__all__ = [
2+
"AcceptEncoding",
3+
"ContentEncoding",
4+
"CompressionManager",
5+
"DefaultCompressionManager",
6+
]
7+
8+
import zlib
9+
from abc import ABC, abstractmethod
10+
from enum import Enum, auto
11+
from typing import Optional
12+
13+
14+
class AcceptEncoding(Enum):
15+
"""Valid accepted encodings for the Accept-Encoding header."""
16+
17+
DEFLATE = auto()
18+
GZIP = auto()
19+
IDENTITY = auto()
20+
21+
22+
class ContentEncoding(Enum):
23+
"""Valid content encodings for the Content-Encoding header."""
24+
25+
DEFLATE = auto()
26+
GZIP = auto()
27+
28+
29+
class CompressionManager(ABC): # pragma: no cover
30+
"""Abstract base class for handling request/response compression."""
31+
32+
@abstractmethod
33+
def needs_compression(self, data: str | bytes) -> bool:
34+
"""Determine if the data needs to be compressed
35+
36+
Args:
37+
data (str | bytes): Data to check
38+
39+
Returns:
40+
bool: True if the data needs to be compressed
41+
"""
42+
raise NotImplementedError
43+
44+
@abstractmethod
45+
def compress(self, data: str | bytes) -> bytes:
46+
"""Compress the data
47+
48+
Args:
49+
data (str | bytes): Data to compress
50+
51+
Returns:
52+
bytes: Compressed data
53+
"""
54+
raise NotImplementedError
55+
56+
@abstractmethod
57+
def content_encoding(self) -> str:
58+
"""Return the content encoding.
59+
60+
This is the value of the Content-Encoding header in the HTTP request.
61+
Must match the encoding used in the compress method.
62+
63+
Returns:
64+
str: Content encoding
65+
"""
66+
raise NotImplementedError
67+
68+
@abstractmethod
69+
def accept_encoding(self) -> str | None:
70+
"""Return the accept encoding.
71+
72+
This is the value of the Accept-Encoding header in the HTTP request.
73+
Currently, only deflate and "gzip" are supported.
74+
75+
Returns:
76+
str: Accept encoding
77+
"""
78+
raise NotImplementedError
79+
80+
81+
class DefaultCompressionManager(CompressionManager):
82+
"""Compress requests using the deflate algorithm.
83+
84+
Args:
85+
threshold (int): Will compress requests to the server if
86+
the size of the request body (in bytes) is at least the value of this option.
87+
Setting it to -1 will disable request compression (default).
88+
level (int): Compression level. Defaults to 6.
89+
accept (str | None): Accepted encoding. By default, there is
90+
no compression of responses.
91+
"""
92+
93+
def __init__(
94+
self,
95+
threshold: int = -1,
96+
level: int = 6,
97+
accept: Optional[AcceptEncoding] = None,
98+
) -> None:
99+
self._threshold = threshold
100+
self._level = level
101+
self._content_encoding = ContentEncoding.DEFLATE.name.lower()
102+
self._accept_encoding = accept.name.lower() if accept else None
103+
104+
def needs_compression(self, data: str | bytes) -> bool:
105+
return self._threshold != -1 and len(data) >= self._threshold
106+
107+
def compress(self, data: str | bytes) -> bytes:
108+
if data is not None:
109+
if isinstance(data, bytes):
110+
return zlib.compress(data, self._level)
111+
return zlib.compress(data.encode("utf-8"), self._level)
112+
return b""
113+
114+
def content_encoding(self) -> str:
115+
return self._content_encoding
116+
117+
def accept_encoding(self) -> str | None:
118+
return self._accept_encoding

arangoasync/connection.py

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
__all__ = [
2+
"BaseConnection",
3+
"BasicConnection",
4+
]
5+
6+
from abc import ABC, abstractmethod
7+
from typing import Any, List, Optional
8+
9+
from arangoasync.auth import Auth
10+
from arangoasync.compression import CompressionManager, DefaultCompressionManager
11+
from arangoasync.exceptions import (
12+
ClientConnectionError,
13+
ConnectionAbortedError,
14+
ServerConnectionError,
15+
)
16+
from arangoasync.http import HTTPClient
17+
from arangoasync.request import Method, Request
18+
from arangoasync.resolver import HostResolver
19+
from arangoasync.response import Response
20+
21+
22+
class BaseConnection(ABC):
23+
"""Blueprint for connection to a specific ArangoDB database.
24+
25+
Args:
26+
sessions (list): List of client sessions.
27+
host_resolver (HostResolver): Host resolver.
28+
http_client (HTTPClient): HTTP client.
29+
db_name (str): Database name.
30+
compression (CompressionManager | None): Compression manager.
31+
"""
32+
33+
def __init__(
34+
self,
35+
sessions: List[Any],
36+
host_resolver: HostResolver,
37+
http_client: HTTPClient,
38+
db_name: str,
39+
compression: Optional[CompressionManager] = None,
40+
) -> None:
41+
self._sessions = sessions
42+
self._db_endpoint = f"/_db/{db_name}"
43+
self._host_resolver = host_resolver
44+
self._http_client = http_client
45+
self._db_name = db_name
46+
self._compression = compression or DefaultCompressionManager()
47+
48+
@property
49+
def db_name(self) -> str:
50+
"""Return the database name."""
51+
return self._db_name
52+
53+
def prep_response(self, request: Request, resp: Response) -> Response:
54+
"""Prepare response for return.
55+
56+
Args:
57+
request (Request): Request object.
58+
resp (Response): Response object.
59+
60+
Returns:
61+
Response: Response object
62+
63+
Raises:
64+
ServerConnectionError: If the response status code is not successful.
65+
"""
66+
resp.is_success = 200 <= resp.status_code < 300
67+
if not resp.is_success:
68+
raise ServerConnectionError(resp, request)
69+
return resp
70+
71+
async def process_request(self, request: Request) -> Response:
72+
"""Process request, potentially trying multiple hosts.
73+
74+
Args:
75+
request (Request): Request object.
76+
77+
Returns:
78+
Response: Response object.
79+
80+
Raises:
81+
ConnectionAbortedError: If can't connect to host(s) within limit.
82+
"""
83+
84+
ex_host_index = -1
85+
host_index = self._host_resolver.get_host_index()
86+
for tries in range(self._host_resolver.max_tries):
87+
try:
88+
resp = await self._http_client.send_request(
89+
self._sessions[host_index], request
90+
)
91+
return self.prep_response(request, resp)
92+
except ClientConnectionError:
93+
ex_host_index = host_index
94+
host_index = self._host_resolver.get_host_index()
95+
if ex_host_index == host_index:
96+
self._host_resolver.change_host()
97+
host_index = self._host_resolver.get_host_index()
98+
99+
raise ConnectionAbortedError(
100+
f"Can't connect to host(s) within limit ({self._host_resolver.max_tries})"
101+
)
102+
103+
async def ping(self) -> int:
104+
"""Ping host to check if connection is established.
105+
106+
Returns:
107+
int: Response status code.
108+
109+
Raises:
110+
ServerConnectionError: If the response status code is not successful.
111+
"""
112+
request = Request(method=Method.GET, endpoint="/_api/collection")
113+
resp = await self.send_request(request)
114+
if resp.status_code in {401, 403}:
115+
raise ServerConnectionError(resp, request, "Authentication failed.")
116+
if not resp.is_success:
117+
raise ServerConnectionError(resp, request, "Bad server response.")
118+
return resp.status_code
119+
120+
@abstractmethod
121+
async def send_request(self, request: Request) -> Response: # pragma: no cover
122+
"""Send an HTTP request to the ArangoDB server.
123+
124+
Args:
125+
request (Request): HTTP request.
126+
127+
Returns:
128+
Response: HTTP response.
129+
"""
130+
raise NotImplementedError
131+
132+
133+
class BasicConnection(BaseConnection):
134+
"""Connection to a specific ArangoDB database.
135+
136+
Allows for basic authentication to be used (username and password).
137+
138+
Args:
139+
sessions (list): List of client sessions.
140+
host_resolver (HostResolver): Host resolver.
141+
http_client (HTTPClient): HTTP client.
142+
db_name (str): Database name.
143+
compression (CompressionManager | None): Compression manager.
144+
auth (Auth | None): Authentication information.
145+
"""
146+
147+
def __init__(
148+
self,
149+
sessions: List[Any],
150+
host_resolver: HostResolver,
151+
http_client: HTTPClient,
152+
db_name: str,
153+
compression: Optional[CompressionManager] = None,
154+
auth: Optional[Auth] = None,
155+
) -> None:
156+
super().__init__(sessions, host_resolver, http_client, db_name, compression)
157+
self._auth = auth
158+
159+
async def send_request(self, request: Request) -> Response:
160+
"""Send an HTTP request to the ArangoDB server."""
161+
if request.data is not None and self._compression.needs_compression(
162+
request.data
163+
):
164+
request.data = self._compression.compress(request.data)
165+
request.headers["content-encoding"] = self._compression.content_encoding()
166+
167+
accept_encoding: str | None = self._compression.accept_encoding()
168+
if accept_encoding is not None:
169+
request.headers["accept-encoding"] = accept_encoding
170+
171+
if self._auth:
172+
request.auth = self._auth
173+
174+
return await self.process_request(request)

0 commit comments

Comments
 (0)