Skip to content

JwtSuperuserConnection #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions arangoasync/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import time
from dataclasses import dataclass
from typing import Optional

import jwt

Expand Down Expand Up @@ -39,6 +40,38 @@ def __init__(self, token: str) -> None:
self._token = token
self._validate()

@staticmethod
def generate_token(
secret: str | bytes,
iat: Optional[int] = None,
exp: int = 3600,
iss: str = "arangodb",
server_id: str = "client",
) -> "JwtToken":
"""Generate and return a JWT token.

Args:
secret (str | bytes): JWT secret.
iat (int): Time the token was issued in seconds. Defaults to current time.
exp (int): Time to expire in seconds.
iss (str): Issuer.
server_id (str): Server ID.

Returns:
str: JWT token.
"""
iat = iat or int(time.time())
token = jwt.encode(
payload={
"iat": iat,
"exp": iat + exp,
"iss": iss,
"server_id": server_id,
},
key=secret,
)
return JwtToken(token)

@property
def token(self) -> str:
"""Get token."""
Expand Down
140 changes: 117 additions & 23 deletions arangoasync/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
__all__ = [
"BaseConnection",
"BasicConnection",
"JwtConnection",
"JwtSuperuserConnection",
]

import json
Expand All @@ -9,6 +11,7 @@

import jwt

from arangoasync import errno, logger
from arangoasync.auth import Auth, JwtToken
from arangoasync.compression import CompressionManager, DefaultCompressionManager
from arangoasync.exceptions import (
Expand Down Expand Up @@ -55,25 +58,45 @@ def db_name(self) -> str:
"""Return the database name."""
return self._db_name

def prep_response(self, request: Request, resp: Response) -> Response:
"""Prepare response for return.
@staticmethod
def raise_for_status(request: Request, resp: Response) -> None:
"""Raise an exception based on the response.

Args:
request (Request): Request object.
resp (Response): Response object.

Returns:
Response: Response object

Raises:
ServerConnectionError: If the response status code is not successful.
"""
# TODO needs refactoring such that it does not throw
resp.is_success = 200 <= resp.status_code < 300
if resp.status_code in {401, 403}:
raise ServerConnectionError(resp, request, "Authentication failed.")
if not resp.is_success:
raise ServerConnectionError(resp, request, "Bad server response.")

@staticmethod
def prep_response(request: Request, resp: Response) -> Response:
"""Prepare response for return.

Args:
request (Request): Request object.
resp (Response): Response object.

Returns:
Response: Response object
"""
resp.is_success = 200 <= resp.status_code < 300
if not resp.is_success:
try:
body = json.loads(resp.raw_body)
except json.JSONDecodeError as e:
logger.debug(
f"Failed to decode response body: {e} (from request {request})"
)
else:
if body.get("error") is True:
resp.error_code = body.get("errorNum")
resp.error_message = body.get("errorMessage")
return resp

async def process_request(self, request: Request) -> Response:
Expand All @@ -86,7 +109,7 @@ async def process_request(self, request: Request) -> Response:
Response: Response object.

Raises:
ConnectionAbortedError: If can't connect to host(s) within limit.
ConnectionAbortedError: If it can't connect to host(s) within limit.
"""

host_index = self._host_resolver.get_host_index()
Expand All @@ -100,6 +123,7 @@ async def process_request(self, request: Request) -> Response:
ex_host_index = host_index
host_index = self._host_resolver.get_host_index()
if ex_host_index == host_index:
# Force change host if the same host is selected
self._host_resolver.change_host()
host_index = self._host_resolver.get_host_index()

Expand All @@ -117,8 +141,8 @@ async def ping(self) -> int:
ServerConnectionError: If the response status code is not successful.
"""
request = Request(method=Method.GET, endpoint="/_api/collection")
request.headers = {"abde": "fghi"}
resp = await self.send_request(request)
self.raise_for_status(request, resp)
return resp.status_code

@abstractmethod
Expand Down Expand Up @@ -257,15 +281,15 @@ async def refresh_token(self) -> None:
if self._auth is None:
raise JWTRefreshError("Auth must be provided to refresh the token.")

data = json.dumps(
auth_data = json.dumps(
dict(username=self._auth.username, password=self._auth.password),
separators=(",", ":"),
ensure_ascii=False,
)
request = Request(
method=Method.POST,
endpoint="/_open/auth",
data=data.encode("utf-8"),
data=auth_data.encode("utf-8"),
)

try:
Expand Down Expand Up @@ -310,16 +334,86 @@ async def send_request(self, request: Request) -> Response:

request.headers["authorization"] = self._auth_header

try:
resp = await self.process_request(request)
if (
resp.status_code == 401 # Unauthorized
and self._token is not None
and self._token.needs_refresh(self._expire_leeway)
):
await self.refresh_token()
return await self.process_request(request) # Retry with new token
except ServerConnectionError:
# TODO modify after refactoring of prep_response, so we can inspect response
resp = await self.process_request(request)
if (
resp.status_code == errno.HTTP_UNAUTHORIZED
and self._token is not None
and self._token.needs_refresh(self._expire_leeway)
):
# If the token has expired, refresh it and retry the request
await self.refresh_token()
return await self.process_request(request) # Retry with new token
resp = await self.process_request(request)
self.raise_for_status(request, resp)
return resp


class JwtSuperuserConnection(BaseConnection):
"""Connection to a specific ArangoDB database, using superuser JWT.

The JWT token is not refreshed and (username and password) are not required.

Args:
sessions (list): List of client sessions.
host_resolver (HostResolver): Host resolver.
http_client (HTTPClient): HTTP client.
db_name (str): Database name.
compression (CompressionManager | None): Compression manager.
token (JwtToken | None): JWT token.
"""

def __init__(
self,
sessions: List[Any],
host_resolver: HostResolver,
http_client: HTTPClient,
db_name: str,
compression: Optional[CompressionManager] = None,
token: Optional[JwtToken] = None,
) -> None:
super().__init__(sessions, host_resolver, http_client, db_name, compression)
self._expire_leeway: int = 0
self._token: Optional[JwtToken] = None
self._auth_header: Optional[str] = None
self.token = token

@property
def token(self) -> Optional[JwtToken]:
"""Get the JWT token.

Returns:
JwtToken | None: JWT token.
"""
return self._token

@token.setter
def token(self, token: Optional[JwtToken]) -> None:
"""Set the JWT token.

Args:
token (JwtToken | None): JWT token.
Setting it to None will cause the token to be automatically
refreshed on the next request, if auth information is provided.
"""
self._token = token
self._auth_header = f"bearer {self._token.token}" if self._token else None

async def send_request(self, request: Request) -> Response:
"""Send an HTTP request to the ArangoDB server.

Args:
request (Request): HTTP request.

Returns:
Response: HTTP response

Raises:
ArangoClientError: If an error occurred from the client side.
ArangoServerError: If an error occurred from the server side.
"""
if self._auth_header is None:
raise AuthHeaderError("Failed to generate authorization header.")
request.headers["authorization"] = self._auth_header

resp = await self.process_request(request)
self.raise_for_status(request, resp)
return resp
3 changes: 3 additions & 0 deletions arangoasync/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,6 @@ def normalized_params(self) -> Params:
normalized_params[key] = str(value)

return normalized_params

def __repr__(self) -> str:
return f"<{self.method.name} {self.endpoint}>"
7 changes: 4 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import pytest_asyncio

from tests.helpers import generate_jwt
from arangoasync.auth import JwtToken


@dataclass
Expand Down Expand Up @@ -45,8 +45,8 @@ def pytest_configure(config):
global_data.url = url
global_data.root = config.getoption("root")
global_data.password = config.getoption("password")
global_data.secret = generate_jwt(config.getoption("secret"))
global_data.token = generate_jwt(global_data.secret)
global_data.secret = config.getoption("secret")
global_data.token = JwtToken.generate_token(global_data.secret)


@pytest.fixture(autouse=False)
Expand Down Expand Up @@ -76,6 +76,7 @@ def sys_db_name():

@pytest_asyncio.fixture
async def client_session():
"""Make sure we close all sessions after the test is done."""
sessions = []

def get_client_session(client, url):
Expand Down
25 changes: 0 additions & 25 deletions tests/helpers.py

This file was deleted.

Loading
Loading