diff --git a/arango/client.py b/arango/client.py index 1666982e..2b6e4993 100644 --- a/arango/client.py +++ b/arango/client.py @@ -13,7 +13,12 @@ ) from arango.database import StandardDatabase from arango.exceptions import ServerConnectionError -from arango.http import DEFAULT_REQUEST_TIMEOUT, DefaultHTTPClient, HTTPClient +from arango.http import ( + DEFAULT_REQUEST_TIMEOUT, + DefaultHTTPClient, + HTTPClient, + RequestCompression, +) from arango.resolver import ( FallbackHostResolver, HostResolver, @@ -33,7 +38,7 @@ def default_serializer(x: Any) -> str: :return: The object serialized as a JSON string :rtype: str """ - return dumps(x) + return dumps(x, separators=(",", ":")) def default_deserializer(x: str) -> Any: @@ -85,6 +90,12 @@ class ArangoClient: None: No timeout. int: Timeout value in seconds. :type request_timeout: int | float + :param request_compression: Will compress requests to the server according to + the given algorithm. No compression happens by default. + :type request_compression: arango.http.RequestCompression | None + :param response_compression: Tells the server what compression algorithm is + acceptable for the response. No compression happens by default. + :type response_compression: str | None """ def __init__( @@ -97,6 +108,8 @@ def __init__( deserializer: Callable[[str], Any] = default_deserializer, verify_override: Union[bool, str, None] = None, request_timeout: Union[int, float, None] = DEFAULT_REQUEST_TIMEOUT, + request_compression: Optional[RequestCompression] = None, + response_compression: Optional[str] = None, ) -> None: if isinstance(hosts, str): self._hosts = [host.strip("/") for host in hosts.split(",")] @@ -133,6 +146,9 @@ def __init__( for session in self._sessions: session.verify = verify_override + self._request_compression = request_compression + self._response_compression = response_compression + def __repr__(self) -> str: return f"" @@ -231,6 +247,8 @@ def db( serializer=self._serializer, deserializer=self._deserializer, superuser_token=superuser_token, + request_compression=self._request_compression, + response_compression=self._response_compression, ) elif user_token is not None: connection = JwtConnection( @@ -242,6 +260,8 @@ def db( serializer=self._serializer, deserializer=self._deserializer, user_token=user_token, + request_compression=self._request_compression, + response_compression=self._response_compression, ) elif auth_method.lower() == "basic": connection = BasicConnection( @@ -254,6 +274,8 @@ def db( http_client=self._http, serializer=self._serializer, deserializer=self._deserializer, + request_compression=self._request_compression, + response_compression=self._response_compression, ) elif auth_method.lower() == "jwt": connection = JwtConnection( @@ -266,6 +288,8 @@ def db( http_client=self._http, serializer=self._serializer, deserializer=self._deserializer, + request_compression=self._request_compression, + response_compression=self._response_compression, ) else: raise ValueError(f"invalid auth_method: {auth_method}") diff --git a/arango/collection.py b/arango/collection.py index 01f11896..181ac001 100644 --- a/arango/collection.py +++ b/arango/collection.py @@ -1813,7 +1813,7 @@ def insert_many( index caches if document insertions affect the edge index or cache-enabled persistent indexes. :type refill_index_caches: bool | None - param version_attribute: support for simple external versioning to + :param version_attribute: support for simple external versioning to document operations. :type version_attribute: str :return: List of document metadata (e.g. document keys, revisions) and @@ -1939,7 +1939,7 @@ def update_many( as opposed to returning the error as an object in the result list. Defaults to False. :type raise_on_document_error: bool - param version_attribute: support for simple external versioning to + :param version_attribute: support for simple external versioning to document operations. :type version_attribute: str :return: List of document metadata (e.g. document keys, revisions) and @@ -2138,7 +2138,7 @@ def replace_many( index caches if document operations affect the edge index or cache-enabled persistent indexes. :type refill_index_caches: bool | None - param version_attribute: support for simple external versioning to + :param version_attribute: support for simple external versioning to document operations. :type version_attribute: str :return: List of document metadata (e.g. document keys, revisions) and @@ -2670,7 +2670,7 @@ def insert( index caches if document insertions affect the edge index or cache-enabled persistent indexes. :type refill_index_caches: bool | None - param version_attribute: support for simple external versioning to + :param version_attribute: support for simple external versioning to document operations. :type version_attribute: str :return: Document metadata (e.g. document key, revision) or True if @@ -2765,7 +2765,7 @@ def update( index caches if document insertions affect the edge index or cache-enabled persistent indexes. :type refill_index_caches: bool | None - param version_attribute: support for simple external versioning + :param version_attribute: support for simple external versioning to document operations. :type version_attribute: str :return: Document metadata (e.g. document key, revision) or True if @@ -2850,7 +2850,7 @@ def replace( index caches if document insertions affect the edge index or cache-enabled persistent indexes. :type refill_index_caches: bool | None - param version_attribute: support for simple external versioning to + :param version_attribute: support for simple external versioning to document operations. :type version_attribute: str :return: Document metadata (e.g. document key, revision) or True if diff --git a/arango/connection.py b/arango/connection.py index 3daa4585..d25d2f78 100644 --- a/arango/connection.py +++ b/arango/connection.py @@ -23,7 +23,7 @@ JWTRefreshError, ServerConnectionError, ) -from arango.http import HTTPClient +from arango.http import HTTPClient, RequestCompression from arango.request import Request from arango.resolver import HostResolver from arango.response import Response @@ -44,6 +44,8 @@ def __init__( http_client: HTTPClient, serializer: Callable[..., str], deserializer: Callable[[str], Any], + request_compression: Optional[RequestCompression] = None, + response_compression: Optional[str] = None, ) -> None: self._url_prefixes = [f"{host}/_db/{db_name}" for host in hosts] self._host_resolver = host_resolver @@ -53,6 +55,8 @@ def __init__( self._serializer = serializer self._deserializer = deserializer self._username: Optional[str] = None + self._request_compression = request_compression + self._response_compression = response_compression @property def db_name(self) -> str: @@ -133,6 +137,19 @@ def process_request( """ tries = 0 indexes_to_filter: Set[int] = set() + + data = self.normalize_data(request.data) + if ( + self._request_compression is not None + and isinstance(data, str) + and self._request_compression.needs_compression(data) + ): + request.headers["content-encoding"] = self._request_compression.encoding() + data = self._request_compression.compress(data) + + if self._response_compression is not None: + request.headers["accept-encoding"] = self._response_compression + while tries < self._host_resolver.max_tries: try: resp = self._http.send_request( @@ -140,7 +157,7 @@ def process_request( method=request.method, url=self._url_prefixes[host_index] + request.endpoint, params=request.params, - data=self.normalize_data(request.data), + data=data, headers=request.headers, auth=auth, ) @@ -243,6 +260,10 @@ class BasicConnection(BaseConnection): :type password: str :param http_client: User-defined HTTP client. :type http_client: arango.http.HTTPClient + :param: request_compression: The request compression algorithm. + :type request_compression: arango.http.RequestCompression | None + :param: response_compression: The response compression algorithm. + :type response_compression: str | None """ def __init__( @@ -256,6 +277,8 @@ def __init__( http_client: HTTPClient, serializer: Callable[..., str], deserializer: Callable[[str], Any], + request_compression: Optional[RequestCompression] = None, + response_compression: Optional[str] = None, ) -> None: super().__init__( hosts, @@ -265,6 +288,8 @@ def __init__( http_client, serializer, deserializer, + request_compression, + response_compression, ) self._username = username self._auth = (username, password) @@ -298,6 +323,10 @@ class JwtConnection(BaseConnection): :type password: str :param http_client: User-defined HTTP client. :type http_client: arango.http.HTTPClient + :param request_compression: The request compression algorithm. + :type request_compression: arango.http.RequestCompression | None + :param response_compression: The response compression algorithm. + :type response_compression: str | None """ def __init__( @@ -312,6 +341,8 @@ def __init__( username: Optional[str] = None, password: Optional[str] = None, user_token: Optional[str] = None, + request_compression: Optional[RequestCompression] = None, + response_compression: Optional[str] = None, ) -> None: super().__init__( hosts, @@ -321,6 +352,8 @@ def __init__( http_client, serializer, deserializer, + request_compression, + response_compression, ) self._username = username self._password = password @@ -439,6 +472,10 @@ class JwtSuperuserConnection(BaseConnection): :type http_client: arango.http.HTTPClient :param superuser_token: User generated token for superuser access. :type superuser_token: str + :param request_compression: The request compression algorithm. + :type request_compression: arango.http.RequestCompression | None + :param response_compression: The response compression algorithm. + :type response_compression: str | None """ def __init__( @@ -451,6 +488,8 @@ def __init__( serializer: Callable[..., str], deserializer: Callable[[str], Any], superuser_token: str, + request_compression: Optional[RequestCompression] = None, + response_compression: Optional[str] = None, ) -> None: super().__init__( hosts, @@ -460,6 +499,8 @@ def __init__( http_client, serializer, deserializer, + request_compression, + response_compression, ) self._auth_header = f"bearer {superuser_token}" diff --git a/arango/http.py b/arango/http.py index c5eb0acd..d0b17939 100644 --- a/arango/http.py +++ b/arango/http.py @@ -1,6 +1,13 @@ -__all__ = ["HTTPClient", "DefaultHTTPClient", "DEFAULT_REQUEST_TIMEOUT"] +__all__ = [ + "HTTPClient", + "DefaultHTTPClient", + "DeflateRequestCompression", + "RequestCompression", + "DEFAULT_REQUEST_TIMEOUT", +] import typing +import zlib from abc import ABC, abstractmethod from typing import Any, MutableMapping, Optional, Tuple, Union @@ -40,7 +47,7 @@ def send_request( url: str, headers: Optional[Headers] = None, params: Optional[MutableMapping[str, str]] = None, - data: Union[str, MultipartEncoder, None] = None, + data: Union[str, bytes, MultipartEncoder, None] = None, auth: Optional[Tuple[str, str]] = None, ) -> Response: """Send an HTTP request. @@ -58,7 +65,7 @@ def send_request( :param params: URL (query) parameters. :type params: dict :param data: Request payload. - :type data: str | MultipartEncoder | None + :type data: str | bytes | MultipartEncoder | None :param auth: Username and password. :type auth: tuple :returns: HTTP response. @@ -198,7 +205,7 @@ def send_request( url: str, headers: Optional[Headers] = None, params: Optional[MutableMapping[str, str]] = None, - data: Union[str, MultipartEncoder, None] = None, + data: Union[str, bytes, MultipartEncoder, None] = None, auth: Optional[Tuple[str, str]] = None, ) -> Response: """Send an HTTP request. @@ -214,7 +221,7 @@ def send_request( :param params: URL (query) parameters. :type params: dict :param data: Request payload. - :type data: str | MultipartEncoder | None + :type data: str | bytes | MultipartEncoder | None :param auth: Username and password. :type auth: tuple :returns: HTTP response. @@ -237,3 +244,75 @@ def send_request( status_text=response.reason, raw_body=response.text, ) + + +class RequestCompression(ABC): # pragma: no cover + """Abstract base class for request compression.""" + + @abstractmethod + def needs_compression(self, data: str) -> bool: + """ + :param data: Data to be compressed. + :type data: str + :returns: True if the data needs to be compressed. + :rtype: bool + """ + raise NotImplementedError + + @abstractmethod + def compress(self, data: str) -> bytes: + """Compress the data. + + :param data: Data to be compressed. + :type data: str + :returns: Compressed data. + :rtype: bytes + """ + raise NotImplementedError + + @abstractmethod + def encoding(self) -> str: + """Return the content encoding exactly as it should + appear in the headers. + + :returns: Content encoding. + :rtype: str + """ + raise NotImplementedError + + +class DeflateRequestCompression(RequestCompression): + """Compress requests using the 'deflate' algorithm.""" + + def __init__(self, threshold: int = 1024, level: int = 6): + """ + :param threshold: Will compress requests to the server if + the size of the request body (in bytes) is at least the value of this + option. + :type threshold: int + :param level: Compression level, in 0-9 or -1. + :type level: int + """ + self._threshold = threshold + self._level = level + + def needs_compression(self, data: str) -> bool: + """ + :param data: Data to be compressed. + :type data: str + :returns: True if the data needs to be compressed. + :rtype: bool + """ + return len(data) >= self._threshold + + def compress(self, data: str) -> bytes: + """ + :param data: Data to be compressed. + :type data: str + :returns: Compressed data. + :rtype: bytes + """ + return zlib.compress(data.encode("utf-8"), level=self._level) + + def encoding(self) -> str: + return "deflate" diff --git a/docs/compression.rst b/docs/compression.rst new file mode 100644 index 00000000..526e20f1 --- /dev/null +++ b/docs/compression.rst @@ -0,0 +1,40 @@ +Compression +------------ + +The :ref:`ArangoClient` lets you define the preferred compression policy for request and responses. By default +compression is disabled. You can change this by setting the `request_compression` and `response_compression` parameters +when creating the client. Currently, only the "deflate" compression algorithm is supported. + +.. testcode:: + + from arango import ArangoClient + + from arango.http import DeflateRequestCompression + + client = ArangoClient( + hosts='http://localhost:8529', + request_compression=DeflateRequestCompression(), + response_compression="deflate" + ) + +Furthermore, you can customize the request compression policy by defining the minimum size of the request body that +should be compressed and the desired compression level. For example, the following code sets the minimum size to 2 KB +and the compression level to 8: + +.. code-block:: python + + client = ArangoClient( + hosts='http://localhost:8529', + request_compression=DeflateRequestCompression( + threshold=2048, + level=8), + ) + +If you want to implement your own compression policy, you can do so by implementing the +:class:`arango.http.RequestCompression` interface. + +.. note:: + The `response_compression` parameter is only used to inform the server that the client prefers compressed responses + (in the form of an *Accept-Encoding* header). Note that the server may or may not honor this preference, depending + on how it is configured. This can be controlled by setting the `--http.compress-response-threshold` option to + a value greater than 0 when starting the ArangoDB server. diff --git a/docs/index.rst b/docs/index.rst index cd073440..22a9b8a3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -80,6 +80,7 @@ Miscellaneous logging auth http + compression serializer schema cursor diff --git a/docs/specs.rst b/docs/specs.rst index b4f61854..87e1d184 100644 --- a/docs/specs.rst +++ b/docs/specs.rst @@ -103,6 +103,12 @@ DefaultHTTPClient .. autoclass:: arango.http.DefaultHTTPClient :members: +DeflateRequestCompression +========================= + +.. autoclass:: arango.http.DeflateRequestCompression + :members: + .. _EdgeCollection: EdgeCollection diff --git a/tests/test_client.py b/tests/test_client.py index 5faa84db..e43180f7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,9 +9,14 @@ from arango.client import ArangoClient from arango.database import StandardDatabase from arango.exceptions import ServerConnectionError -from arango.http import DefaultHTTPClient +from arango.http import DefaultHTTPClient, DeflateRequestCompression from arango.resolver import FallbackHostResolver, RandomHostResolver, SingleHostResolver -from tests.helpers import generate_db_name, generate_string, generate_username +from tests.helpers import ( + generate_col_name, + generate_db_name, + generate_string, + generate_username, +) def test_client_attributes(): @@ -184,3 +189,56 @@ def test_can_serialize_deserialize_client() -> None: client_pstr = pickle.dumps(client) client2 = pickle.loads(client_pstr) assert len(client2._sessions) > 0 + + +def test_client_compression(db, username, password): + class CheckCompression: + def __init__(self, should_compress: bool): + self.should_compress = should_compress + + def check(self, headers): + if self.should_compress: + assert headers["content-encoding"] == "deflate" + else: + assert "content-encoding" not in headers + + class MyHTTPClient(DefaultHTTPClient): + def __init__(self, compression_checker: CheckCompression) -> None: + super().__init__() + self.checker = compression_checker + + def send_request( + self, session, method, url, headers=None, params=None, data=None, auth=None + ): + self.checker.check(headers) + return super().send_request( + session, method, url, headers, params, data, auth + ) + + checker = CheckCompression(should_compress=False) + + # should not compress, as threshold is 0 + client = ArangoClient( + hosts="http://127.0.0.1:8529", + http_client=MyHTTPClient(compression_checker=checker), + response_compression="gzip", + ) + db = client.db(db.name, username, password) + col = db.create_collection(generate_col_name()) + col.insert({"_key": "1"}) + + # should not compress, as size of payload is less than threshold + checker = CheckCompression(should_compress=False) + client = ArangoClient( + hosts="http://127.0.0.1:8529", + http_client=MyHTTPClient(compression_checker=checker), + request_compression=DeflateRequestCompression(250, level=7), + response_compression="deflate", + ) + db = client.db(db.name, username, password) + col = db.create_collection(generate_col_name()) + col.insert({"_key": "2"}) + + # should compress + checker.should_compress = True + col.insert({"_key": "3" * 250})