From 86bce915de34982a5232e1468daf9feba669e714 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 19 Sep 2023 16:05:48 -0400 Subject: [PATCH 01/75] feat: generate OpenAPI spec from event handler --- .../event_handler/api_gateway.py | 393 +++++------- .../event_handler/openapi/__init__.py | 11 + .../event_handler/openapi/dependant.py | 103 ++++ .../event_handler/openapi/models.py | 557 ++++++++++++++++++ .../event_handler/openapi/params.py | 342 +++++++++++ .../event_handler/openapi/utils.py | 77 +++ aws_lambda_powertools/event_handler/route.py | 366 ++++++++++++ 7 files changed, 1608 insertions(+), 241 deletions(-) create mode 100644 aws_lambda_powertools/event_handler/openapi/__init__.py create mode 100644 aws_lambda_powertools/event_handler/openapi/dependant.py create mode 100644 aws_lambda_powertools/event_handler/openapi/models.py create mode 100644 aws_lambda_powertools/event_handler/openapi/params.py create mode 100644 aws_lambda_powertools/event_handler/openapi/utils.py create mode 100644 aws_lambda_powertools/event_handler/route.py diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 46cb5587135..d148b5b9ae3 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -4,15 +4,35 @@ import re import traceback import warnings -import zlib from abc import ABC, abstractmethod from enum import Enum from functools import partial from http import HTTPStatus -from typing import Any, Callable, Dict, List, Match, Optional, Pattern, Set, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + List, + Match, + Optional, + Pattern, + Sequence, + Set, + Tuple, + Type, + Union, +) + +import zlib +from pydantic.fields import ModelField +from pydantic.schema import get_flat_models_from_fields, get_model_name_map, model_process_schema from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError +from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant +from aws_lambda_powertools.event_handler.openapi.models import Contact, License, OpenAPI, Server, Tag +from aws_lambda_powertools.event_handler.openapi.utils import get_flat_params +from aws_lambda_powertools.event_handler.route import Route from aws_lambda_powertools.shared.cookies import Cookie from aws_lambda_powertools.shared.functions import powertools_dev_is_set from aws_lambda_powertools.shared.json_encoder import Encoder @@ -197,142 +217,6 @@ def __init__( self.headers.setdefault("Content-Type", content_type) -class Route: - """Internally used Route Configuration""" - - def __init__( - self, - method: str, - rule: Pattern, - func: Callable, - cors: bool, - compress: bool, - cache_control: Optional[str], - middlewares: Optional[List[Callable[..., Response]]], - ): - """ - - Parameters - ---------- - - method: str - The HTTP method, example "GET" - rule: Pattern - The route rule, example "/my/path" - func: Callable - The route handler function - cors: bool - Whether or not to enable CORS for this route - compress: bool - Whether or not to enable gzip compression for this route - cache_control: Optional[str] - The cache control header value, example "max-age=3600" - middlewares: Optional[List[Callable[..., Response]]] - The list of route middlewares to be called in order. - """ - self.method = method.upper() - self.rule = rule - self.func = func - self._middleware_stack = func - self.cors = cors - self.compress = compress - self.cache_control = cache_control - self.middlewares = middlewares or [] - - # _middleware_stack_built is used to ensure the middleware stack is only built once. - self._middleware_stack_built = False - - def __call__( - self, - router_middlewares: List[Callable], - app: "ApiGatewayResolver", - route_arguments: Dict[str, str], - ) -> Union[Dict, Tuple, Response]: - """Calling the Router class instance will trigger the following actions: - 1. If Route Middleware stack has not been built, build it - 2. Call the Route Middleware stack wrapping the original function - handler with the app and route arguments. - - Parameters - ---------- - router_middlewares: List[Callable] - The list of Router Middlewares (assigned to ALL routes) - app: "ApiGatewayResolver" - The ApiGatewayResolver instance to pass into the middleware stack - route_arguments: Dict[str, str] - The route arguments to pass to the app function (extracted from the Api Gateway - Lambda Message structure from AWS) - - Returns - ------- - Union[Dict, Tuple, Response] - API Response object in ALL cases, except when the original API route - handler is called which may also return a Dict, Tuple, or Response. - """ - - # Save CPU cycles by building middleware stack once - if not self._middleware_stack_built: - self._build_middleware_stack(router_middlewares=router_middlewares) - - # If debug is turned on then output the middleware stack to the console - if app._debug: - print(f"\nProcessing Route:::{self.func.__name__} ({app.context['_path']})") - # Collect ALL middleware for debug printing - include internal _registered_api_adapter - all_middlewares = router_middlewares + self.middlewares + [_registered_api_adapter] - print("\nMiddleware Stack:") - print("=================") - print("\n".join(getattr(item, "__name__", "Unknown") for item in all_middlewares)) - print("=================") - - # Add Route Arguments to app context - app.append_context(_route_args=route_arguments) - - # Call the Middleware Wrapped _call_stack function handler with the app - return self._middleware_stack(app) - - def _build_middleware_stack(self, router_middlewares: List[Callable[..., Any]]) -> None: - """ - Builds the middleware stack for the handler by wrapping each - handler in an instance of MiddlewareWrapper which is used to contain the state - of each middleware step. - - Middleware is represented by a standard Python Callable construct. Any Middleware - handler wanting to short-circuit the middlware call chain can raise an exception - to force the Python call stack created by the handler call-chain to naturally un-wind. - - This becomes a simple concept for developers to understand and reason with - no additional - gymanstics other than plain old try ... except. - - Notes - ----- - The Route Middleware stack is processed in reverse order. This is so the stack of - middleware handlers is applied in the order of being added to the handler. - """ - all_middlewares = router_middlewares + self.middlewares - logger.debug(f"Building middleware stack: {all_middlewares}") - - # IMPORTANT: - # this must be the last middleware in the stack (tech debt for backward - # compatibility purposes) - # - # This adapter will: - # 1. Call the registered API passing only the expected route arguments extracted from the path - # and not the middleware. - # 2. Adapt the response type of the route handler (Union[Dict, Tuple, Response]) - # and normalise into a Response object so middleware will always have a constant signature - all_middlewares.append(_registered_api_adapter) - - # Wrap the original route handler function in the middleware handlers - # using the MiddlewareWrapper class callable construct in reverse order to - # ensure middleware is applied in the order the user defined. - # - # Start with the route function and wrap from last to the first Middleware handler. - for handler in reversed(all_middlewares): - self._middleware_stack = MiddlewareFrame(current_middleware=handler, next_middleware=self._middleware_stack) - - self._middleware_stack_built = True - - class ResponseBuilder: """Internally used Response builder""" @@ -674,109 +558,6 @@ def clear_context(self): self.context.clear() -class MiddlewareFrame: - """ - creates a Middle Stack Wrapper instance to be used as a "Frame" in the overall stack of - middleware functions. Each instance contains the current middleware and the next - middleware function to be called in the stack. - - In this way the middleware stack is constructed in a recursive fashion, with each middleware - calling the next as a simple function call. The actual Python call-stack will contain - each MiddlewareStackWrapper "Frame", meaning any Middleware function can cause the - entire Middleware call chain to be exited early (short-circuited) by raising an exception - or by simply returning early with a custom Response. The decision to short-circuit the middleware - chain is at the user's discretion but instantly available due to the Wrapped nature of the - callable constructs in the Middleware stack and each Middleware function having complete control over - whether the "Next" handler in the stack is called or not. - - Parameters - ---------- - current_middleware : Callable - The current middleware function to be called as a request is processed. - next_middleware : Callable - The next middleware in the middleware stack. - """ - - def __init__( - self, - current_middleware: Callable[..., Any], - next_middleware: Callable[..., Any], - ) -> None: - self.current_middleware: Callable[..., Any] = current_middleware - self.next_middleware: Callable[..., Any] = next_middleware - self._next_middleware_name = next_middleware.__name__ - - @property - def __name__(self) -> str: # noqa: A003 - """Current middleware name - - It ensures backward compatibility with view functions being callable. This - improves debugging since we need both current and next middlewares/callable names. - """ - return self.current_middleware.__name__ - - def __str__(self) -> str: - """Identify current middleware identity and call chain for debugging purposes.""" - middleware_name = self.__name__ - return f"[{middleware_name}] next call chain is {middleware_name} -> {self._next_middleware_name}" - - def __call__(self, app: "ApiGatewayResolver") -> Union[Dict, Tuple, Response]: - """ - Call the middleware Frame to process the request. - - Parameters - ---------- - app: BaseRouter - The router instance - - Returns - ------- - Union[Dict, Tuple, Response] - (tech-debt for backward compatibility). The response type should be a - Response object in all cases excepting when the original API route handler - is called which will return one of 3 outputs. - - """ - # Do debug printing and push processed stack frame AFTER calling middleware - # else the stack frame text of `current calling next` is confusing. - logger.debug("MiddlewareFrame: %s", self) - app._push_processed_stack_frame(str(self)) - - return self.current_middleware(app, self.next_middleware) - - -def _registered_api_adapter( - app: "ApiGatewayResolver", - next_middleware: Callable[..., Any], -) -> Union[Dict, Tuple, Response]: - """ - Calls the registered API using the "_route_args" from the Resolver context to ensure the last call - in the chain will match the API route function signature and ensure that Powertools passes the API - route handler the expected arguments. - - **IMPORTANT: This internal middleware ensures the actual API route is called with the correct call signature - and it MUST be the final frame in the middleware stack. This can only be removed when the API Route - function accepts `app: BaseRouter` as the first argument - which is the breaking change. - - Parameters - ---------- - app: ApiGatewayResolver - The API Gateway resolver - next_middleware: Callable[..., Any] - The function to handle the API - - Returns - ------- - Response - The API Response Object - - """ - route_args: Dict = app.context.get("_route_args", {}) - logger.debug(f"Calling API Route Handler: {route_args}") - - return app._to_response(next_middleware(**route_args)) - - class ApiGatewayResolver(BaseRouter): """API Gateway and ALB proxy resolver @@ -847,6 +628,119 @@ def __init__( # Allow for a custom serializer or a concise json serialization self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) + def get_openapi_schema( + self, + *, + title: str, + version: str, + openapi_version: str = "3.1.0", + summary: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[List[Tag]] = None, + servers: Optional[List[Server]] = None, + terms_of_service: Optional[str] = None, + contact: Optional[Contact] = None, + license_info: Optional[License] = None, + ) -> OpenAPI: + info: Dict[str, Any] = {"title": title, "version": version} + if summary: + info["summary"] = summary + if description: + info["description"] = description + if terms_of_service: + info["termsOfService"] = terms_of_service + if contact: + info["contact"] = contact + if license_info: + info["license"] = license_info + + output: Dict[str, Any] = {"openapi": openapi_version, "info": info} + if servers: + output["servers"] = servers + else: + # If the servers property is not provided, or is an empty array, the default value would be a Server Object + # with a url value of /. + output["servers"] = [Server(url="/")] + + components: Dict[str, Dict[str, Any]] = {} + paths: Dict[str, Dict[str, Any]] = {} + operation_ids: Set[str] = set() + + all_routes = self._dynamic_routes + self._static_routes + all_fields = self._get_fields_from_routes(all_routes) + models = get_flat_models_from_fields(all_fields, known_models=set()) + model_name_map = get_model_name_map(models) + + definitions: Dict[str, Dict[str, Any]] = {} + for model in models: + m_schema, m_definitions, _ = model_process_schema( + model, + model_name_map=model_name_map, + ref_prefix="#/components/schemas/", + ) + definitions.update(m_definitions) + model_name = model_name_map[model] + if "description" in m_schema: + m_schema["description"] = m_schema["description"].split("\f")[0] + definitions[model_name] = m_schema + + for route in all_routes: + dependant = get_dependant( + path=route.func.__name__, + call=route.func, + ) + + result = route._openapi_path( + dependant=dependant, + operation_ids=operation_ids, + model_name_map=model_name_map, + ) + if result: + path, path_definitions = result + if path: + paths.setdefault(route.path, {}).update(path) + if path_definitions: + definitions.update(path_definitions) + + if definitions: + components["schemas"] = {k: definitions[k] for k in sorted(definitions)} + if components: + output["components"] = components + if tags: + output["tags"] = tags + + output["paths"] = paths + + return OpenAPI(**output) # .dict(by_alias=True, exclude_none=True) + + def get_openapi_json_schema( + self, + *, + title: str, + version: str, + openapi_version: str = "3.1.0", + summary: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[List[Tag]] = None, + servers: Optional[List[Server]] = None, + terms_of_service: Optional[str] = None, + contact: Optional[Contact] = None, + license_info: Optional[License] = None, + ) -> str: + """Returns the OpenAPI schema as a JSON serializable dict""" + return self.get_openapi_schema( + title=title, + version=version, + openapi_version=openapi_version, + summary=summary, + description=description, + tags=tags, + servers=servers, + terms_of_service=terms_of_service, + contact=contact, + license_info=license_info, + ).json(by_alias=True, exclude_none=True, indent=2) + def route( self, rule: str, @@ -869,6 +763,7 @@ def register_resolver(func: Callable): for item in methods: _route = Route( item, + rule, self._compile_regex(rule), func, cors_enabled, @@ -1229,6 +1124,22 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None # Still need to ignore for mypy checks or will cause failures (false-positive) self.route(*new_route, middlewares=middlewares)(func) # type: ignore + @staticmethod + def _get_fields_from_routes(routes: Sequence[Route]) -> List[ModelField]: + responses_from_routes: List[ModelField] = [] + request_fields_from_routes: List[ModelField] = [] + + for route in routes: + dependant = get_dependant(path=route.path, call=route.func) + params = get_flat_params(dependant) + request_fields_from_routes.extend(params) + + if dependant.return_param: + responses_from_routes.append(dependant.return_param) + + flat_models = list(responses_from_routes + request_fields_from_routes) + return flat_models + class Router(BaseRouter): """Router helper class to allow splitting ApiGatewayResolver into multiple files""" diff --git a/aws_lambda_powertools/event_handler/openapi/__init__.py b/aws_lambda_powertools/event_handler/openapi/__init__.py new file mode 100644 index 00000000000..91c5b0259f2 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/__init__.py @@ -0,0 +1,11 @@ +from aws_lambda_powertools.event_handler.openapi.models import ( + Example, + Info, + MediaType, + Operation, + Reference, + Response, + Schema, +) + +__all__ = ["Info", "Operation", "Response", "MediaType", "Reference", "Schema", "Example"] diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py new file mode 100644 index 00000000000..8ebc2f84caf --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -0,0 +1,103 @@ +import inspect +import re +from typing import Any, Callable, Dict, ForwardRef, Optional, Set, cast + +from pydantic.fields import ModelField +from pydantic.typing import evaluate_forwardref + +from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param, ParamTypes, analyze_param + + +def add_param_to_fields( + *, + field: ModelField, + dependant: Dependant, +) -> None: + field_info = cast(Param, field.field_info) + if field_info.in_ == ParamTypes.path: + dependant.path_params.append(field) + elif field_info.in_ == ParamTypes.query: + dependant.query_params.append(field) + elif field_info.in_ == ParamTypes.header: + dependant.header_params.append(field) + else: + assert field_info.in_ == ParamTypes.cookie + dependant.cookie_params.append(field) + + +def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation = evaluate_forwardref(annotation, globalns, globalns) + return annotation + + +def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + signature = inspect.signature(call) + globalns = getattr(call, "__global__", {}) + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=get_typed_annotation(param.annotation, globalns), + ) + for param in signature.parameters.values() + ] + + if signature.return_annotation is not inspect.Signature.empty: + return_param = inspect.Parameter( + name="Return", + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=None, + annotation=get_typed_annotation(signature.return_annotation, globalns), + ) + return inspect.Signature(typed_params, return_annotation=return_param.annotation) + else: + return inspect.Signature(typed_params) + + +def get_path_param_names(path: str) -> Set[str]: + return set(re.findall("{(.*?)}", path)) + + +def get_dependant( + *, + path: str, + call: Callable[..., Any], + name: Optional[str] = None, +) -> Dependant: + path_param_names = get_path_param_names(path) + endpoint_signature = get_typed_signature(call) + signature_params = endpoint_signature.parameters + dependant = Dependant( + call=call, + name=name, + path=path, + ) + + for param_name, param in signature_params.items(): + is_path_param = param_name in path_param_names + type_annotation, param_field = analyze_param( + param_name=param_name, + annotation=param.annotation, + value=param.default, + is_path_param=is_path_param, + ) + assert param_field is not None + + add_param_to_fields(field=param_field, dependant=dependant) + + return_annotation = endpoint_signature.return_annotation + if return_annotation is not inspect.Signature.empty: + type_annotation, param_field = analyze_param( + param_name="Return", + annotation=return_annotation, + value=None, + is_path_param=False, + ) + assert param_field is not None + + dependant.return_param = param_field + + return dependant diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py new file mode 100644 index 00000000000..e492416e30a --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -0,0 +1,557 @@ +from enum import Enum +from typing import Any, Dict, List, Optional, Set, Union + +from pydantic import AnyUrl, BaseModel, Field +from pydantic.version import VERSION as PYDANTIC_VERSION +from typing_extensions import Annotated, Literal + +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + + +class Contact(BaseModel): + name: Optional[str] = None + url: Optional[AnyUrl] = None + email: Optional[str] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + else: + + class Config: + extra = "allow" + + +class License(BaseModel): + name: str + identifier: Optional[str] = None + url: Optional[AnyUrl] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class Info(BaseModel): + title: str + summary: Optional[str] = None + description: Optional[str] = None + termsOfService: Optional[str] = None + contact: Optional[Contact] = None + license: Optional[License] = None + version: str + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class ServerVariable(BaseModel): + enum: Annotated[Optional[List[str]], Field(min_length=1)] = None + default: str + description: Optional[str] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class Server(BaseModel): + url: Union[AnyUrl, str] + description: Optional[str] = None + variables: Optional[Dict[str, ServerVariable]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class Reference(BaseModel): + ref: str = Field(alias="$ref") + + +class Discriminator(BaseModel): + propertyName: str + mapping: Optional[Dict[str, str]] = None + + +class XML(BaseModel): + name: Optional[str] = None + namespace: Optional[str] = None + prefix: Optional[str] = None + attribute: Optional[bool] = None + wrapped: Optional[bool] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class ExternalDocumentation(BaseModel): + description: Optional[str] = None + url: AnyUrl + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class Schema(BaseModel): + # Ref: JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-core.html#name-the-json-schema-core-vocabu + # Core Vocabulary + schema_: Optional[str] = Field(default=None, alias="$schema") + vocabulary: Optional[str] = Field(default=None, alias="$vocabulary") + id: Optional[str] = Field(default=None, alias="$id") + anchor: Optional[str] = Field(default=None, alias="$anchor") + dynamicAnchor: Optional[str] = Field(default=None, alias="$dynamicAnchor") + ref: Optional[str] = Field(default=None, alias="$ref") + dynamicRef: Optional[str] = Field(default=None, alias="$dynamicRef") + defs: Optional[Dict[str, "SchemaOrBool"]] = Field(default=None, alias="$defs") + comment: Optional[str] = Field(default=None, alias="$comment") + # Ref: JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-core.html#name-a-vocabulary-for-applying-s + # A Vocabulary for Applying Subschemas + allOf: Optional[List["SchemaOrBool"]] = None + anyOf: Optional[List["SchemaOrBool"]] = None + oneOf: Optional[List["SchemaOrBool"]] = None + not_: Optional["SchemaOrBool"] = Field(default=None, alias="not") + if_: Optional["SchemaOrBool"] = Field(default=None, alias="if") + then: Optional["SchemaOrBool"] = None + else_: Optional["SchemaOrBool"] = Field(default=None, alias="else") + dependentSchemas: Optional[Dict[str, "SchemaOrBool"]] = None + prefixItems: Optional[List["SchemaOrBool"]] = None + # TODO: uncomment and remove below when deprecating Pydantic v1 + # It generales a list of schemas for tuples, before prefixItems was available + # items: Optional["SchemaOrBool"] = None + items: Optional[Union["SchemaOrBool", List["SchemaOrBool"]]] = None + contains: Optional["SchemaOrBool"] = None + properties: Optional[Dict[str, "SchemaOrBool"]] = None + patternProperties: Optional[Dict[str, "SchemaOrBool"]] = None + additionalProperties: Optional["SchemaOrBool"] = None + propertyNames: Optional["SchemaOrBool"] = None + unevaluatedItems: Optional["SchemaOrBool"] = None + unevaluatedProperties: Optional["SchemaOrBool"] = None + # Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-structural + # A Vocabulary for Structural Validation + type: Optional[str] = None + enum: Optional[List[Any]] = None + const: Optional[Any] = None + multipleOf: Optional[float] = Field(default=None, gt=0) + maximum: Optional[float] = None + exclusiveMaximum: Optional[float] = None + minimum: Optional[float] = None + exclusiveMinimum: Optional[float] = None + maxLength: Optional[int] = Field(default=None, ge=0) + minLength: Optional[int] = Field(default=None, ge=0) + pattern: Optional[str] = None + maxItems: Optional[int] = Field(default=None, ge=0) + minItems: Optional[int] = Field(default=None, ge=0) + uniqueItems: Optional[bool] = None + maxContains: Optional[int] = Field(default=None, ge=0) + minContains: Optional[int] = Field(default=None, ge=0) + maxProperties: Optional[int] = Field(default=None, ge=0) + minProperties: Optional[int] = Field(default=None, ge=0) + required: Optional[List[str]] = None + dependentRequired: Optional[Dict[str, Set[str]]] = None + # Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-vocabularies-for-semantic-c + # Vocabularies for Semantic Content With "format" + format: Optional[str] = None + # Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-the-conten + # A Vocabulary for the Contents of String-Encoded Data + contentEncoding: Optional[str] = None + contentMediaType: Optional[str] = None + contentSchema: Optional["SchemaOrBool"] = None + # Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-basic-meta + # A Vocabulary for Basic Meta-Data Annotations + title: Optional[str] = None + description: Optional[str] = None + default: Optional[Any] = None + deprecated: Optional[bool] = None + readOnly: Optional[bool] = None + writeOnly: Optional[bool] = None + examples: Optional[List[Any]] = None + # Ref: OpenAPI 3.1.0: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#schema-object + # Schema Object + discriminator: Optional[Discriminator] = None + xml: Optional[XML] = None + externalDocs: Optional[ExternalDocumentation] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# Ref: https://json-schema.org/draft/2020-12/json-schema-core.html#name-json-schema-documents +# A JSON Schema MUST be an object or a boolean. +SchemaOrBool = Union[Schema, bool] + + +class Example(BaseModel): + summary: Optional[str] = None + description: Optional[str] = None + value: Optional[Any] = None + externalValue: Optional[AnyUrl] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class ParameterInType(Enum): + query = "query" + header = "header" + path = "path" + cookie = "cookie" + + +class Encoding(BaseModel): + contentType: Optional[str] = None + headers: Optional[Dict[str, Union["Header", Reference]]] = None + style: Optional[str] = None + explode: Optional[bool] = None + allowReserved: Optional[bool] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class MediaType(BaseModel): + schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema") + example: Optional[Any] = None + examples: Optional[Dict[str, Union[Example, Reference]]] = None + encoding: Optional[Dict[str, Encoding]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class ParameterBase(BaseModel): + description: Optional[str] = None + required: Optional[bool] = None + deprecated: Optional[bool] = None + # Serialization rules for simple scenarios + style: Optional[str] = None + explode: Optional[bool] = None + allowReserved: Optional[bool] = None + schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema") + example: Optional[Any] = None + examples: Optional[Dict[str, Union[Example, Reference]]] = None + # Serialization rules for more complex scenarios + content: Optional[Dict[str, MediaType]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class Parameter(ParameterBase): + name: str + in_: ParameterInType = Field(alias="in") + + +class Header(ParameterBase): + pass + + +class RequestBody(BaseModel): + description: Optional[str] = None + content: Dict[str, MediaType] + required: Optional[bool] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class Link(BaseModel): + operationRef: Optional[str] = None + operationId: Optional[str] = None + parameters: Optional[Dict[str, Union[Any, str]]] = None + requestBody: Optional[Union[Any, str]] = None + description: Optional[str] = None + server: Optional[Server] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class Response(BaseModel): + description: str + headers: Optional[Dict[str, Union[Header, Reference]]] = None + content: Optional[Dict[str, MediaType]] = None + links: Optional[Dict[str, Union[Link, Reference]]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class Operation(BaseModel): + tags: Optional[List[str]] = None + summary: Optional[str] = None + description: Optional[str] = None + externalDocs: Optional[ExternalDocumentation] = None + operationId: Optional[str] = None + parameters: Optional[List[Union[Parameter, Reference]]] = None + requestBody: Optional[Union[RequestBody, Reference]] = None + # Using Any for Specification Extensions + responses: Optional[Dict[str, Union[Response, Any]]] = None + callbacks: Optional[Dict[str, Union[Dict[str, "PathItem"], Reference]]] = None + deprecated: Optional[bool] = None + security: Optional[List[Dict[str, List[str]]]] = None + servers: Optional[List[Server]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class PathItem(BaseModel): + ref: Optional[str] = Field(default=None, alias="$ref") + summary: Optional[str] = None + description: Optional[str] = None + get: Optional[Operation] = None + put: Optional[Operation] = None + post: Optional[Operation] = None + delete: Optional[Operation] = None + options: Optional[Operation] = None + head: Optional[Operation] = None + patch: Optional[Operation] = None + trace: Optional[Operation] = None + servers: Optional[List[Server]] = None + parameters: Optional[List[Union[Parameter, Reference]]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class SecuritySchemeType(Enum): + apiKey = "apiKey" + http = "http" + oauth2 = "oauth2" + openIdConnect = "openIdConnect" + + +class SecurityBase(BaseModel): + type_: SecuritySchemeType = Field(alias="type") + description: Optional[str] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class APIKeyIn(Enum): + query = "query" + header = "header" + cookie = "cookie" + + +class APIKey(SecurityBase): + type_: SecuritySchemeType = Field(default=SecuritySchemeType.apiKey, alias="type") + in_: APIKeyIn = Field(alias="in") + name: str + + +class HTTPBase(SecurityBase): + type_: SecuritySchemeType = Field(default=SecuritySchemeType.http, alias="type") + scheme: str + + +class HTTPBearer(HTTPBase): + scheme: Literal["bearer"] = "bearer" + bearerFormat: Optional[str] = None + + +class OAuthFlow(BaseModel): + refreshUrl: Optional[str] = None + scopes: Dict[str, str] = {} + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class OAuthFlowImplicit(OAuthFlow): + authorizationUrl: str + + +class OAuthFlowPassword(OAuthFlow): + tokenUrl: str + + +class OAuthFlowClientCredentials(OAuthFlow): + tokenUrl: str + + +class OAuthFlowAuthorizationCode(OAuthFlow): + authorizationUrl: str + tokenUrl: str + + +class OAuthFlows(BaseModel): + implicit: Optional[OAuthFlowImplicit] = None + password: Optional[OAuthFlowPassword] = None + clientCredentials: Optional[OAuthFlowClientCredentials] = None + authorizationCode: Optional[OAuthFlowAuthorizationCode] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class OAuth2(SecurityBase): + type_: SecuritySchemeType = Field(default=SecuritySchemeType.oauth2, alias="type") + flows: OAuthFlows + + +class OpenIdConnect(SecurityBase): + type_: SecuritySchemeType = Field( + default=SecuritySchemeType.openIdConnect, + alias="type", + ) + openIdConnectUrl: str + + +SecurityScheme = Union[APIKey, HTTPBase, OAuth2, OpenIdConnect, HTTPBearer] + + +class Components(BaseModel): + schemas: Optional[Dict[str, Union[Schema, Reference]]] = None + responses: Optional[Dict[str, Union[Response, Reference]]] = None + parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None + examples: Optional[Dict[str, Union[Example, Reference]]] = None + requestBodies: Optional[Dict[str, Union[RequestBody, Reference]]] = None + headers: Optional[Dict[str, Union[Header, Reference]]] = None + securitySchemes: Optional[Dict[str, Union[SecurityScheme, Reference]]] = None + links: Optional[Dict[str, Union[Link, Reference]]] = None + # Using Any for Specification Extensions + callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference, Any]]] = None + pathItems: Optional[Dict[str, Union[PathItem, Reference]]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class Tag(BaseModel): + name: str + description: Optional[str] = None + externalDocs: Optional[ExternalDocumentation] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class OpenAPI(BaseModel): + openapi: str + info: Info + jsonSchemaDialect: Optional[str] = None + servers: Optional[List[Server]] = None + # Using Any for Specification Extensions + paths: Optional[Dict[str, Union[PathItem, Any]]] = None + webhooks: Optional[Dict[str, Union[PathItem, Reference]]] = None + components: Optional[Components] = None + security: Optional[List[Dict[str, List[str]]]] = None + tags: Optional[List[Tag]] = None + externalDocs: Optional[ExternalDocumentation] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +Schema.update_forward_refs() +Operation.update_forward_refs() +Encoding.update_forward_refs() diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py new file mode 100644 index 00000000000..63e1cb14c0b --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -0,0 +1,342 @@ +import inspect +from copy import copy +from enum import Enum +from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin + +from pydantic import BaseConfig +from pydantic.fields import FieldInfo, ModelField, Required, Undefined +from pydantic.schema import get_annotation_from_field_info + +from aws_lambda_powertools.event_handler.openapi import Example + + +class Dependant: + def __init__( + self, + *, + path_params: Optional[List[ModelField]] = None, + query_params: Optional[List[ModelField]] = None, + header_params: Optional[List[ModelField]] = None, + cookie_params: Optional[List[ModelField]] = None, + body_params: Optional[List[ModelField]] = None, + return_param: Optional[ModelField] = None, + dependencies: Optional[List["Dependant"]] = None, + name: Optional[str] = None, + call: Optional[Callable[..., Any]] = None, + request_param_name: Optional[str] = None, + websocket_param_name: Optional[str] = None, + http_connection_param_name: Optional[str] = None, + response_param_name: Optional[str] = None, + background_tasks_param_name: Optional[str] = None, + path: Optional[str] = None, + ) -> None: + self.path_params = path_params or [] + self.query_params = query_params or [] + self.header_params = header_params or [] + self.cookie_params = cookie_params or [] + self.body_params = body_params or [] + self.return_param = return_param or None + self.dependencies = dependencies or [] + self.request_param_name = request_param_name + self.websocket_param_name = websocket_param_name + self.http_connection_param_name = http_connection_param_name + self.response_param_name = response_param_name + self.background_tasks_param_name = background_tasks_param_name + self.name = name + self.call = call + # Store the path to be able to re-generate a dependable from it in overrides + self.path = path + # Save the cache key at creation to optimize performance + self.cache_key = self.call + + +class ParamTypes(Enum): + query = "query" + header = "header" + path = "path" + cookie = "cookie" + + +_Unset: Any = Undefined + + +class Param(FieldInfo): + in_: ParamTypes + + def __init__( + self, + default: Any = Undefined, + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + # TODO: update when deprecating Pydantic v1, import these types + # validation_alias: str | AliasPath | AliasChoices | None + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + openapi_examples: Optional[Dict[str, Example]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + self.deprecated = deprecated + self.include_in_schema = include_in_schema + self.openapi_examples = openapi_examples + kwargs = dict( + default=default, + default_factory=default_factory, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + discriminator=discriminator, + multiple_of=multiple_of, + allow_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + **extra, + ) + if examples is not None: + kwargs["examples"] = examples + + current_json_schema_extra = json_schema_extra or extra + kwargs["regex"] = pattern + kwargs.update(**current_json_schema_extra) + use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset} + + super().__init__(**use_kwargs) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.default})" + + +class Path(Param): + in_ = ParamTypes.path + + def __init__( + self, + default: Any = ..., + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + # TODO: update when deprecating Pydantic v1, import these types + # validation_alias: str | AliasPath | AliasChoices | None + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + openapi_examples: Optional[Dict[str, Example]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + assert default is ..., "Path parameters cannot have a default value" + self.in_ = self.in_ + super(Path, self).__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + deprecated=deprecated, + examples=examples, + openapi_examples=openapi_examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + **extra, + ) + + +class Query(Param): + in_ = ParamTypes.query + + def __init__( + self, + default: Any = Undefined, + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + openapi_examples: Optional[Dict[str, Example]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + super().__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + deprecated=deprecated, + examples=examples, + openapi_examples=openapi_examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + **extra, + ) + + +def analyze_param( + *, + param_name: str, + annotation: Any, + value: Any, + is_path_param: bool, +) -> Tuple[Any, Optional[ModelField]]: + field_info: Optional[FieldInfo] = None + type_annotation: Any = Any + + if annotation is not inspect.Signature.empty and get_origin(annotation) is Annotated: + annotated_args = get_args(annotation) + type_annotation = annotated_args[0] + powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)] + assert len(powertools_annotations) <= 1 + + powertools_annotation = next(iter(powertools_annotations), None) + + if isinstance(powertools_annotation, FieldInfo): + # Copy `field_info` because we mutate `field_info.default` later + field_info = copy(powertools_annotation) + assert field_info.default is Undefined or field_info.default is Required + if value is not inspect.Signature.empty: + assert not is_path_param + field_info.default = value + else: + field_info.default = Required + elif annotation is not inspect.Signature.empty: + type_annotation = annotation + + if isinstance(value, FieldInfo): + assert field_info is None + field_info = value + + if field_info is None: + default_value = value if value is not inspect.Signature.empty else Required + if is_path_param: + field_info = Path(annotation=type_annotation, default=default_value) + else: + field_info = Query(annotation=type_annotation, default=default_value) + + field = None + if field_info is not None: + if is_path_param: + assert isinstance(field_info, Path) + elif isinstance(field_info, Param) and getattr(field_info, "in_", None) is None: + field_info.in_ = ParamTypes.query + + use_annotation = get_annotation_from_field_info(type_annotation, field_info, param_name) + + if not field_info.alias and getattr(field_info, "convert_underscores", None): + alias = param_name.replace("_", "-") + else: + alias = field_info.alias or param_name + + field_info.alias = alias + + field = ModelField( + name=param_name, + field_info=field_info, + type_=use_annotation, + class_validators={}, + default=field_info.default, + required=field_info.default in (Required, Undefined), + model_config=BaseConfig, + alias=alias, + ) + + return type_annotation, field diff --git a/aws_lambda_powertools/event_handler/openapi/utils.py b/aws_lambda_powertools/event_handler/openapi/utils.py new file mode 100644 index 00000000000..cad6f18975d --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/utils.py @@ -0,0 +1,77 @@ +from typing import Any, Callable, List, Optional, Tuple, Type, Union + +from pydantic import BaseConfig +from pydantic.fields import FieldInfo, ModelField, Undefined, UndefinedType + +from aws_lambda_powertools.event_handler.openapi.params import Dependant + +CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] + + +def get_flat_dependant( + dependant: Dependant, + *, + skip_repeats: bool = False, + visited: Optional[List[CacheKey]] = None, +) -> Dependant: + if visited is None: + visited = [] + visited.append(dependant.cache_key) + + flat_dependant = Dependant( + path_params=dependant.path_params.copy(), + query_params=dependant.query_params.copy(), + header_params=dependant.header_params.copy(), + cookie_params=dependant.cookie_params.copy(), + body_params=dependant.body_params.copy(), + path=dependant.path, + ) + for sub_dependant in dependant.dependencies: + if skip_repeats and sub_dependant.cache_key in visited: + continue + + flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited) + + flat_dependant.path_params.extend(flat_sub.path_params) + flat_dependant.query_params.extend(flat_sub.query_params) + flat_dependant.header_params.extend(flat_sub.header_params) + flat_dependant.cookie_params.extend(flat_sub.cookie_params) + flat_dependant.body_params.extend(flat_sub.body_params) + + return flat_dependant + + +def get_flat_params(dependant: Dependant) -> List[ModelField]: + flat_dependant = get_flat_dependant(dependant, skip_repeats=True) + return ( + flat_dependant.path_params + + flat_dependant.query_params + + flat_dependant.header_params + + flat_dependant.cookie_params + ) + + +def create_response_field( + name: str, + type_: Type[Any], + default: Optional[Any] = Undefined, + required: Union[bool, UndefinedType] = Undefined, + model_config: Type[BaseConfig] = BaseConfig, + alias: Optional[str] = None, +) -> ModelField: + """ + Create a new response field. + """ + field_info = FieldInfo() + + kwargs = { + "name": name, + "field_info": field_info, + "type_": type_, + "default": default, + "required": required, + "model_config": model_config, + "alias": alias, + "class_validators": {}, + } + return ModelField(**kwargs) diff --git a/aws_lambda_powertools/event_handler/route.py b/aws_lambda_powertools/event_handler/route.py new file mode 100644 index 00000000000..7ca0837a0ae --- /dev/null +++ b/aws_lambda_powertools/event_handler/route.py @@ -0,0 +1,366 @@ +import warnings +from re import Pattern +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, cast + +from pydantic.fields import ModelField +from pydantic.schema import TypeModelOrEnum, field_schema + +from aws_lambda_powertools.event_handler import Response +from aws_lambda_powertools.event_handler.api_gateway import logger +from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param +from aws_lambda_powertools.event_handler.openapi.utils import get_flat_params + + +class MiddlewareFrame: + """ + creates a Middle Stack Wrapper instance to be used as a "Frame" in the overall stack of + middleware functions. Each instance contains the current middleware and the next + middleware function to be called in the stack. + + In this way the middleware stack is constructed in a recursive fashion, with each middleware + calling the next as a simple function call. The actual Python call-stack will contain + each MiddlewareStackWrapper "Frame", meaning any Middleware function can cause the + entire Middleware call chain to be exited early (short-circuited) by raising an exception + or by simply returning early with a custom Response. The decision to short-circuit the middleware + chain is at the user's discretion but instantly available due to the Wrapped nature of the + callable constructs in the Middleware stack and each Middleware function having complete control over + whether the "Next" handler in the stack is called or not. + + Parameters + ---------- + current_middleware : Callable + The current middleware function to be called as a request is processed. + next_middleware : Callable + The next middleware in the middleware stack. + """ + + def __init__( + self, + current_middleware: Callable[..., Any], + next_middleware: Callable[..., Any], + ) -> None: + self.current_middleware: Callable[..., Any] = current_middleware + self.next_middleware: Callable[..., Any] = next_middleware + self._next_middleware_name = next_middleware.__name__ + + @property + def __name__(self) -> str: # noqa: A003 + """Current middleware name + + It ensures backward compatibility with view functions being callable. This + improves debugging since we need both current and next middlewares/callable names. + """ + return self.current_middleware.__name__ + + def __str__(self) -> str: + """Identify current middleware identity and call chain for debugging purposes.""" + middleware_name = self.__name__ + return f"[{middleware_name}] next call chain is {middleware_name} -> {self._next_middleware_name}" + + def __call__(self, app: "ApiGatewayResolver") -> Union[Dict, Tuple, Response]: + """ + Call the middleware Frame to process the request. + + Parameters + ---------- + app: BaseRouter + The router instance + + Returns + ------- + Union[Dict, Tuple, Response] + (tech-debt for backward compatibility). The response type should be a + Response object in all cases excepting when the original API route handler + is called which will return one of 3 outputs. + + """ + # Do debug printing and push processed stack frame AFTER calling middleware + # else the stack frame text of `current calling next` is confusing. + logger.debug("MiddlewareFrame: %s", self) + app._push_processed_stack_frame(str(self)) + + return self.current_middleware(app, self.next_middleware) + + +class Route: + """Internally used Route Configuration""" + + def __init__( + self, + method: str, + path: str, + rule: Pattern, + func: Callable, + cors: bool, + compress: bool, + cache_control: Optional[str], + middlewares: Optional[List[Callable[..., Response]]], + ): + """ + + Parameters + ---------- + + method: str + The HTTP method, example "GET" + rule: Pattern + The route rule, example "/my/path" + path: str + The path of the route + func: Callable + The route handler function + cors: bool + Whether or not to enable CORS for this route + compress: bool + Whether or not to enable gzip compression for this route + cache_control: Optional[str] + The cache control header value, example "max-age=3600" + middlewares: Optional[List[Callable[..., Response]]] + The list of route middlewares to be called in order. + """ + self.method = method.upper() + self.path = path + self.rule = rule + self.func = func + self._middleware_stack = func + self.cors = cors + self.compress = compress + self.cache_control = cache_control + self.middlewares = middlewares or [] + self.operation_id = self.method.title() + self.func.__name__.title() + + # _middleware_stack_built is used to ensure the middleware stack is only built once. + self._middleware_stack_built = False + + def __call__( + self, + router_middlewares: List[Callable], + app: "ApiGatewayResolver", + route_arguments: Dict[str, str], + ) -> Union[Dict, Tuple, Response]: + """Calling the Router class instance will trigger the following actions: + 1. If Route Middleware stack has not been built, build it + 2. Call the Route Middleware stack wrapping the original function + handler with the app and route arguments. + + Parameters + ---------- + router_middlewares: List[Callable] + The list of Router Middlewares (assigned to ALL routes) + app: "ApiGatewayResolver" + The ApiGatewayResolver instance to pass into the middleware stack + route_arguments: Dict[str, str] + The route arguments to pass to the app function (extracted from the Api Gateway + Lambda Message structure from AWS) + + Returns + ------- + Union[Dict, Tuple, Response] + API Response object in ALL cases, except when the original API route + handler is called which may also return a Dict, Tuple, or Response. + """ + + # Save CPU cycles by building middleware stack once + if not self._middleware_stack_built: + self._build_middleware_stack(router_middlewares=router_middlewares) + + # If debug is turned on then output the middleware stack to the console + if app._debug: + print(f"\nProcessing Route:::{self.func.__name__} ({app.context['_path']})") + # Collect ALL middleware for debug printing - include internal _registered_api_adapter + all_middlewares = router_middlewares + self.middlewares + [_registered_api_adapter] + print("\nMiddleware Stack:") + print("=================") + print("\n".join(getattr(item, "__name__", "Unknown") for item in all_middlewares)) + print("=================") + + # Add Route Arguments to app context + app.append_context(_route_args=route_arguments) + + # Call the Middleware Wrapped _call_stack function handler with the app + return self._middleware_stack(app) + + def _build_middleware_stack(self, router_middlewares: List[Callable[..., Any]]) -> None: + """ + Builds the middleware stack for the handler by wrapping each + handler in an instance of MiddlewareWrapper which is used to contain the state + of each middleware step. + + Middleware is represented by a standard Python Callable construct. Any Middleware + handler wanting to short-circuit the middlware call chain can raise an exception + to force the Python call stack created by the handler call-chain to naturally un-wind. + + This becomes a simple concept for developers to understand and reason with - no additional + gymanstics other than plain old try ... except. + + Notes + ----- + The Route Middleware stack is processed in reverse order. This is so the stack of + middleware handlers is applied in the order of being added to the handler. + """ + all_middlewares = router_middlewares + self.middlewares + logger.debug(f"Building middleware stack: {all_middlewares}") + + # IMPORTANT: + # this must be the last middleware in the stack (tech debt for backward + # compatibility purposes) + # + # This adapter will: + # 1. Call the registered API passing only the expected route arguments extracted from the path + # and not the middleware. + # 2. Adapt the response type of the route handler (Union[Dict, Tuple, Response]) + # and normalise into a Response object so middleware will always have a constant signature + all_middlewares.append(_registered_api_adapter) + + # Wrap the original route handler function in the middleware handlers + # using the MiddlewareWrapper class callable construct in reverse order to + # ensure middleware is applied in the order the user defined. + # + # Start with the route function and wrap from last to the first Middleware handler. + for handler in reversed(all_middlewares): + self._middleware_stack = MiddlewareFrame(current_middleware=handler, next_middleware=self._middleware_stack) + + self._middleware_stack_built = True + + def _openapi_path( + self, + *, + dependant: Dependant, + operation_ids: Set[str], + model_name_map: Dict[TypeModelOrEnum, str], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + path = {} + definitions: Dict[str, Any] = {} + + operation = self._openapi_operation_metadata(operation_ids=operation_ids) + parameters: List[Dict[str, Any]] = [] + all_route_params = get_flat_params(dependant) + operation_params = self._openapi_operation_parameters( + all_route_params=all_route_params, + model_name_map=model_name_map, + ) + + parameters.extend(operation_params) + if parameters: + all_parameters = {(param["in"], param["name"]): param for param in parameters} + required_parameters = {(param["in"], param["name"]): param for param in parameters if param.get("required")} + all_parameters.update(required_parameters) + operation["parameters"] = list(all_parameters.values()) + + responses = operation.setdefault("responses", {}) + success_response = responses.setdefault("200", {}) + success_response["description"] = "Success" + success_response["content"] = {"application/json": {"schema": {}}} + json_response = success_response["content"].setdefault("application/json", {}) + + json_response["schema"] = self._openapi_operation_return( + operation_id=self.operation_id, + param=dependant.return_param, + model_name_map=model_name_map, + ) + + path[self.method.lower()] = operation + + # Generate the response schema + return path, definitions + + def _openapi_operation_summary(self): + # TODO: add name, summary to Route, and allow it to be customized during creation + self.rule.__str__().replace("_", " ").title() + + def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any]: + operation: Dict[str, Any] = {"summary": self._openapi_operation_summary()} + + # TODO: description, tags + operation_id = self.operation_id + if operation_id in operation_ids: + message = f"Duplicate Operation ID {operation_id} for function {self.func.__name__}" + file_name = getattr(self.func, "__globals__", {}).get("__file__") + if file_name: + message += f" in {file_name}" + warnings.warn(message, stacklevel=1) + operation_ids.add(operation_id) + operation["operationId"] = operation_id + return operation + + @staticmethod + def _openapi_operation_parameters( + *, + all_route_params: Sequence[ModelField], + model_name_map: Dict[TypeModelOrEnum, str], + ) -> List[Dict[str, Any]]: + parameters = [] + for param in all_route_params: + field_info = param.field_info + field_info = cast(Param, field_info) + if not field_info.include_in_schema: + continue + + param_schema = field_schema(param, model_name_map=model_name_map, ref_prefix="#/components/schemas/")[0] + + parameter = { + "name": param.alias, + "in": field_info.in_.value, + "required": param.required, + "schema": param_schema, + } + + if field_info.description: + parameter["description"] = field_info.description + + if field_info.deprecated: + parameter["deprecated"] = field_info.deprecated + + parameters.append(parameter) + + return parameters + + @staticmethod + def _openapi_operation_return( + *, + operation_id: str, + param: Optional[ModelField], + model_name_map: Dict[TypeModelOrEnum, str], + ) -> Dict[str, Any]: + if param is None: + return {} + + return_schema = field_schema( + param, + model_name_map=model_name_map, + ref_prefix="#/components/schemas/", + )[0] + + return {"name": f"Return {operation_id}", "schema": return_schema} + + +def _registered_api_adapter( + app: "ApiGatewayResolver", + next_middleware: Callable[..., Any], +) -> Union[Dict, Tuple, Response]: + """ + Calls the registered API using the "_route_args" from the Resolver context to ensure the last call + in the chain will match the API route function signature and ensure that Powertools passes the API + route handler the expected arguments. + + **IMPORTANT: This internal middleware ensures the actual API route is called with the correct call signature + and it MUST be the final frame in the middleware stack. This can only be removed when the API Route + function accepts `app: BaseRouter` as the first argument - which is the breaking change. + + Parameters + ---------- + app: ApiGatewayResolver + The API Gateway resolver + next_middleware: Callable[..., Any] + The function to handle the API + + Returns + ------- + Response + The API Response Object + + """ + route_args: Dict = app.context.get("_route_args", {}) + logger.debug(f"Calling API Route Handler: {route_args}") + + return app._to_response(next_middleware(**route_args)) From e74cee6c700f02ccdd163b1fe8ba43a22056ee59 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 9 Oct 2023 15:58:05 +0200 Subject: [PATCH 02/75] fix: resolver circular dependencies --- .../event_handler/__init__.py | 2 +- .../event_handler/api_gateway.py | 1 + .../event_handler/middlewares/base.py | 2 +- .../middlewares/schema_validation.py | 2 +- .../event_handler/response.py | 41 +++++++++++++++++++ aws_lambda_powertools/event_handler/route.py | 6 ++- .../src/binary_responses.py | 2 +- 7 files changed, 50 insertions(+), 6 deletions(-) create mode 100644 aws_lambda_powertools/event_handler/response.py diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index 7bdd9a97f72..14372784adb 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -8,12 +8,12 @@ ApiGatewayResolver, APIGatewayRestResolver, CORSConfig, - Response, ) from aws_lambda_powertools.event_handler.appsync import AppSyncResolver from aws_lambda_powertools.event_handler.lambda_function_url import ( LambdaFunctionUrlResolver, ) +from aws_lambda_powertools.event_handler.response import Response from aws_lambda_powertools.event_handler.vpc_lattice import VPCLatticeResolver, VPCLatticeV2Resolver __all__ = [ diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index d148b5b9ae3..bc28c9f541c 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -32,6 +32,7 @@ from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant from aws_lambda_powertools.event_handler.openapi.models import Contact, License, OpenAPI, Server, Tag from aws_lambda_powertools.event_handler.openapi.utils import get_flat_params +from aws_lambda_powertools.event_handler.response import Response from aws_lambda_powertools.event_handler.route import Route from aws_lambda_powertools.shared.cookies import Cookie from aws_lambda_powertools.shared.functions import powertools_dev_is_set diff --git a/aws_lambda_powertools/event_handler/middlewares/base.py b/aws_lambda_powertools/event_handler/middlewares/base.py index fb4bf37cc74..a6b1bff6d4a 100644 --- a/aws_lambda_powertools/event_handler/middlewares/base.py +++ b/aws_lambda_powertools/event_handler/middlewares/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Generic -from aws_lambda_powertools.event_handler.api_gateway import Response +from aws_lambda_powertools.event_handler import Response from aws_lambda_powertools.event_handler.types import EventHandlerInstance from aws_lambda_powertools.shared.types import Protocol diff --git a/aws_lambda_powertools/event_handler/middlewares/schema_validation.py b/aws_lambda_powertools/event_handler/middlewares/schema_validation.py index 66be47a48f3..a4d3a1c17ab 100644 --- a/aws_lambda_powertools/event_handler/middlewares/schema_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/schema_validation.py @@ -1,7 +1,7 @@ import logging from typing import Dict, Optional -from aws_lambda_powertools.event_handler.api_gateway import Response +from aws_lambda_powertools.event_handler import Response from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware from aws_lambda_powertools.event_handler.types import EventHandlerInstance diff --git a/aws_lambda_powertools/event_handler/response.py b/aws_lambda_powertools/event_handler/response.py new file mode 100644 index 00000000000..3c5ffd0152d --- /dev/null +++ b/aws_lambda_powertools/event_handler/response.py @@ -0,0 +1,41 @@ +from typing import Dict, List, Optional, Union + +from aws_lambda_powertools.shared.cookies import Cookie + + +class Response: + """Response data class that provides greater control over what is returned from the proxy event""" + + def __init__( + self, + status_code: int, + content_type: Optional[str] = None, + body: Union[str, bytes, None] = None, + headers: Optional[Dict[str, Union[str, List[str]]]] = None, + cookies: Optional[List[Cookie]] = None, + compress: Optional[bool] = None, + ): + """ + + Parameters + ---------- + status_code: int + Http status code, example 200 + content_type: str + Optionally set the Content-Type header, example "application/json". Note this will be merged into any + provided http headers + body: Union[str, bytes, None] + Optionally set the response body. Note: bytes body will be automatically base64 encoded + headers: dict[str, Union[str, List[str]]] + Optionally set specific http headers. Setting "Content-Type" here would override the `content_type` value. + cookies: list[Cookie] + Optionally set cookies. + """ + self.status_code = status_code + self.body = body + self.base64_encoded = False + self.headers: Dict[str, Union[str, List[str]]] = headers if headers else {} + self.cookies = cookies or [] + self.compress = compress + if content_type: + self.headers.setdefault("Content-Type", content_type) diff --git a/aws_lambda_powertools/event_handler/route.py b/aws_lambda_powertools/event_handler/route.py index 7ca0837a0ae..ee631f8b98b 100644 --- a/aws_lambda_powertools/event_handler/route.py +++ b/aws_lambda_powertools/event_handler/route.py @@ -1,3 +1,4 @@ +import logging import warnings from re import Pattern from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, cast @@ -5,10 +6,11 @@ from pydantic.fields import ModelField from pydantic.schema import TypeModelOrEnum, field_schema -from aws_lambda_powertools.event_handler import Response -from aws_lambda_powertools.event_handler.api_gateway import logger from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param from aws_lambda_powertools.event_handler.openapi.utils import get_flat_params +from aws_lambda_powertools.event_handler.response import Response + +logger = logging.getLogger(__name__) class MiddlewareFrame: diff --git a/examples/event_handler_rest/src/binary_responses.py b/examples/event_handler_rest/src/binary_responses.py index f91dc879402..0c6d15a0e8c 100644 --- a/examples/event_handler_rest/src/binary_responses.py +++ b/examples/event_handler_rest/src/binary_responses.py @@ -2,9 +2,9 @@ from pathlib import Path from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.event_handler import Response from aws_lambda_powertools.event_handler.api_gateway import ( APIGatewayRestResolver, - Response, ) from aws_lambda_powertools.logging import correlation_paths from aws_lambda_powertools.utilities.typing import LambdaContext From 510ad25e080a5c43cd3c124f94a139f8bca469ca Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 9 Oct 2023 15:59:16 +0200 Subject: [PATCH 03/75] fix: rebase --- aws_lambda_powertools/event_handler/api_gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index bc28c9f541c..49473434aa5 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -4,6 +4,7 @@ import re import traceback import warnings +import zlib from abc import ABC, abstractmethod from enum import Enum from functools import partial @@ -23,7 +24,6 @@ Union, ) -import zlib from pydantic.fields import ModelField from pydantic.schema import get_flat_models_from_fields, get_model_name_map, model_process_schema From a2c1c9252bb1debe573dd8467fac5f462fceb9f4 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 25 Sep 2023 17:39:56 +0200 Subject: [PATCH 04/75] fix: document the new methods --- .../event_handler/api_gateway.py | 2 +- .../event_handler/openapi/dependant.py | 85 +++++++++++++++++- .../event_handler/openapi/models.py | 28 ++++++ .../event_handler/openapi/params.py | 88 +++++++++++++------ .../event_handler/openapi/utils.py | 70 +++++++++------ 5 files changed, 215 insertions(+), 58 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 49473434aa5..05d4599eb11 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -687,7 +687,7 @@ def get_openapi_schema( for route in all_routes: dependant = get_dependant( - path=route.func.__name__, + path=route.path, call=route.func, ) diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 8ebc2f84caf..541face1387 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -7,12 +7,38 @@ from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param, ParamTypes, analyze_param +""" +This turns the opaque function signature into typed, validated models. + +It relies on Pydantic's typing and validation to achieve this in a declarative way. +This enables traits like autocompletion, validation, and declarative structure vs imperative parsing. + +This code parses an OpenAPI operation handler function signature into Pydantic models. It uses inspect to get the +signature and regex to parse path parameters. Each parameter is analyzed to extract its type annotation and generate +a corresponding Pydantic field, which are added to a Dependant model. Return values are handled similarly. + +This modeling allows for type checking, automatic parameter name/location/type extraction, and input validation - +turning the opaque signature into validated models. It relies on Pydantic's typing and validation for a declarative +approach over imperative parsing, enabling autocompletion, validation and structure. +""" + def add_param_to_fields( *, field: ModelField, dependant: Dependant, ) -> None: + """ + Adds a parameter to the list of parameters in the dependant model. + + Parameters + ---------- + field: ModelField + The field to add + dependant: Dependant + The dependant model to add the field to + + """ field_info = cast(Param, field.field_info) if field_info.in_ == ParamTypes.path: dependant.path_params.append(field) @@ -26,6 +52,9 @@ def add_param_to_fields( def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: + """ + Evaluates a type annotation, which can be a string or a ForwardRef. + """ if isinstance(annotation, str): annotation = ForwardRef(annotation) annotation = evaluate_forwardref(annotation, globalns, globalns) @@ -33,8 +62,24 @@ def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + """ + Returns a typed signature for a callable, resolving forward references. + + Parameters + ---------- + call: Callable[..., Any] + The callable to get the signature for + + Returns + ------- + inspect.Signature + The typed signature + """ signature = inspect.signature(call) + + # Gets the global namespace for the call. This is used to resolve forward references. globalns = getattr(call, "__global__", {}) + typed_params = [ inspect.Parameter( name=param.name, @@ -45,6 +90,7 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: for param in signature.parameters.values() ] + # If the return annotation is not empty, add it to the signature. if signature.return_annotation is not inspect.Signature.empty: return_param = inspect.Parameter( name="Return", @@ -58,7 +104,21 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: def get_path_param_names(path: str) -> Set[str]: - return set(re.findall("{(.*?)}", path)) + """ + Returns the path parameter names from a path template. Those are the strings between < and >. + + Parameters + ---------- + path: str + The path template + + Returns + ------- + Set[str] + The path parameter names + + """ + return set(re.findall("<(.*?)>", path)) def get_dependant( @@ -67,9 +127,28 @@ def get_dependant( call: Callable[..., Any], name: Optional[str] = None, ) -> Dependant: + """ + Returns a dependant model for a handler function. A dependant model is a model that contains + the parameters and return value of a handler function. + + Parameters + ---------- + path: str + The path template + call: Callable[..., Any] + The handler function + name: str, optional + The name of the handler function + + Returns + ------- + Dependant + The dependant model for the handler function + """ path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters + dependant = Dependant( call=call, name=name, @@ -77,7 +156,10 @@ def get_dependant( ) for param_name, param in signature_params.items(): + # If the parameter is a path parameter, we need to set the in_ field to "path". is_path_param = param_name in path_param_names + + # Analyze the parameter to get the type annotation and the Pydantic field. type_annotation, param_field = analyze_param( param_name=param_name, annotation=param.annotation, @@ -88,6 +170,7 @@ def get_dependant( add_param_to_fields(field=param_field, dependant=dependant) + # If the return annotation is not empty, add it to the dependant model. return_annotation = endpoint_signature.return_annotation if return_annotation is not inspect.Signature.empty: type_annotation, param_field = analyze_param( diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index e492416e30a..b5ed0d9215b 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -7,7 +7,13 @@ PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") +""" +The code defines Pydantic models for the various OpenAPI objects like OpenAPI, PathItem, Operation, Parameter etc. +These models can be used to parse OpenAPI JSON/YAML files into Python objects, or generate OpenAPI from Python data. +""" + +# https://swagger.io/specification/#contact-object class Contact(BaseModel): name: Optional[str] = None url: Optional[AnyUrl] = None @@ -21,6 +27,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#license-object class License(BaseModel): name: str identifier: Optional[str] = None @@ -35,6 +42,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#info-object class Info(BaseModel): title: str summary: Optional[str] = None @@ -53,6 +61,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#server-variable-object class ServerVariable(BaseModel): enum: Annotated[Optional[List[str]], Field(min_length=1)] = None default: str @@ -67,6 +76,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#server-object class Server(BaseModel): url: Union[AnyUrl, str] description: Optional[str] = None @@ -81,15 +91,18 @@ class Config: extra = "allow" +# https://swagger.io/specification/#reference-object class Reference(BaseModel): ref: str = Field(alias="$ref") +# https://swagger.io/specification/#discriminator-object class Discriminator(BaseModel): propertyName: str mapping: Optional[Dict[str, str]] = None +# https://swagger.io/specification/#xml-object class XML(BaseModel): name: Optional[str] = None namespace: Optional[str] = None @@ -106,6 +119,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#external-documentation-object class ExternalDocumentation(BaseModel): description: Optional[str] = None url: AnyUrl @@ -119,6 +133,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#schema-object class Schema(BaseModel): # Ref: JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-core.html#name-the-json-schema-core-vocabu # Core Vocabulary @@ -212,6 +227,7 @@ class Config: SchemaOrBool = Union[Schema, bool] +# https://swagger.io/specification/#example-object class Example(BaseModel): summary: Optional[str] = None description: Optional[str] = None @@ -234,6 +250,7 @@ class ParameterInType(Enum): cookie = "cookie" +# https://swagger.io/specification/#encoding-object class Encoding(BaseModel): contentType: Optional[str] = None headers: Optional[Dict[str, Union["Header", Reference]]] = None @@ -250,6 +267,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#media-type-object class MediaType(BaseModel): schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema") example: Optional[Any] = None @@ -265,6 +283,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#parameter-object class ParameterBase(BaseModel): description: Optional[str] = None required: Optional[bool] = None @@ -297,6 +316,7 @@ class Header(ParameterBase): pass +# https://swagger.io/specification/#request-body-object class RequestBody(BaseModel): description: Optional[str] = None content: Dict[str, MediaType] @@ -311,6 +331,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#link-object class Link(BaseModel): operationRef: Optional[str] = None operationId: Optional[str] = None @@ -328,6 +349,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#response-object class Response(BaseModel): description: str headers: Optional[Dict[str, Union[Header, Reference]]] = None @@ -343,6 +365,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#operation-object class Operation(BaseModel): tags: Optional[List[str]] = None summary: Optional[str] = None @@ -367,6 +390,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#path-item-object class PathItem(BaseModel): ref: Optional[str] = Field(default=None, alias="$ref") summary: Optional[str] = None @@ -391,6 +415,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#security-scheme-object class SecuritySchemeType(Enum): apiKey = "apiKey" http = "http" @@ -494,6 +519,7 @@ class OpenIdConnect(SecurityBase): SecurityScheme = Union[APIKey, HTTPBase, OAuth2, OpenIdConnect, HTTPBearer] +# https://swagger.io/specification/#components-object class Components(BaseModel): schemas: Optional[Dict[str, Union[Schema, Reference]]] = None responses: Optional[Dict[str, Union[Response, Reference]]] = None @@ -516,6 +542,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#tag-object class Tag(BaseModel): name: str description: Optional[str] = None @@ -530,6 +557,7 @@ class Config: extra = "allow" +# https://swagger.io/specification/#openapi-object class OpenAPI(BaseModel): openapi: str info: Info diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 63e1cb14c0b..baba6d6d662 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -9,8 +9,16 @@ from aws_lambda_powertools.event_handler.openapi import Example +""" +This turns the low-level function signature into typed, validated Pydantic models for consumption. +""" + class Dependant: + """ + A class used internally to represent a dependency between path operation decorators and the path operation function. + """ + def __init__( self, *, @@ -57,9 +65,6 @@ class ParamTypes(Enum): cookie = "cookie" -_Unset: Any = Undefined - - class Param(FieldInfo): in_: ParamTypes @@ -67,10 +72,10 @@ def __init__( self, default: Any = Undefined, *, - default_factory: Union[Callable[[], Any], None] = _Unset, + default_factory: Union[Callable[[], Any], None] = Undefined, annotation: Optional[Any] = None, alias: Optional[str] = None, - alias_priority: Union[int, None] = _Unset, + alias_priority: Union[int, None] = Undefined, # TODO: update when deprecating Pydantic v1, import these types # validation_alias: str | AliasPath | AliasChoices | None validation_alias: Union[str, None] = None, @@ -85,11 +90,11 @@ def __init__( max_length: Optional[int] = None, pattern: Optional[str] = None, discriminator: Union[str, None] = None, - strict: Union[bool, None] = _Unset, - multiple_of: Union[float, None] = _Unset, - allow_inf_nan: Union[bool, None] = _Unset, - max_digits: Union[int, None] = _Unset, - decimal_places: Union[int, None] = _Unset, + strict: Union[bool, None] = Undefined, + multiple_of: Union[float, None] = Undefined, + allow_inf_nan: Union[bool, None] = Undefined, + max_digits: Union[int, None] = Undefined, + decimal_places: Union[int, None] = Undefined, examples: Optional[List[Any]] = None, openapi_examples: Optional[Dict[str, Example]] = None, deprecated: Optional[bool] = None, @@ -125,7 +130,7 @@ def __init__( current_json_schema_extra = json_schema_extra or extra kwargs["regex"] = pattern kwargs.update(**current_json_schema_extra) - use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset} + use_kwargs = {k: v for k, v in kwargs.items() if v is not Undefined} super().__init__(**use_kwargs) @@ -140,10 +145,10 @@ def __init__( self, default: Any = ..., *, - default_factory: Union[Callable[[], Any], None] = _Unset, + default_factory: Union[Callable[[], Any], None] = Undefined, annotation: Optional[Any] = None, alias: Optional[str] = None, - alias_priority: Union[int, None] = _Unset, + alias_priority: Union[int, None] = Undefined, # TODO: update when deprecating Pydantic v1, import these types # validation_alias: str | AliasPath | AliasChoices | None validation_alias: Union[str, None] = None, @@ -158,11 +163,11 @@ def __init__( max_length: Optional[int] = None, pattern: Optional[str] = None, discriminator: Union[str, None] = None, - strict: Union[bool, None] = _Unset, - multiple_of: Union[float, None] = _Unset, - allow_inf_nan: Union[bool, None] = _Unset, - max_digits: Union[int, None] = _Unset, - decimal_places: Union[int, None] = _Unset, + strict: Union[bool, None] = Undefined, + multiple_of: Union[float, None] = Undefined, + allow_inf_nan: Union[bool, None] = Undefined, + max_digits: Union[int, None] = Undefined, + decimal_places: Union[int, None] = Undefined, examples: Optional[List[Any]] = None, openapi_examples: Optional[Dict[str, Example]] = None, deprecated: Optional[bool] = None, @@ -211,10 +216,10 @@ def __init__( self, default: Any = Undefined, *, - default_factory: Union[Callable[[], Any], None] = _Unset, + default_factory: Union[Callable[[], Any], None] = Undefined, annotation: Optional[Any] = None, alias: Optional[str] = None, - alias_priority: Union[int, None] = _Unset, + alias_priority: Union[int, None] = Undefined, validation_alias: Union[str, None] = None, serialization_alias: Union[str, None] = None, title: Optional[str] = None, @@ -227,11 +232,11 @@ def __init__( max_length: Optional[int] = None, pattern: Optional[str] = None, discriminator: Union[str, None] = None, - strict: Union[bool, None] = _Unset, - multiple_of: Union[float, None] = _Unset, - allow_inf_nan: Union[bool, None] = _Unset, - max_digits: Union[int, None] = _Unset, - decimal_places: Union[int, None] = _Unset, + strict: Union[bool, None] = Undefined, + multiple_of: Union[float, None] = Undefined, + allow_inf_nan: Union[bool, None] = Undefined, + max_digits: Union[int, None] = Undefined, + decimal_places: Union[int, None] = Undefined, examples: Optional[List[Any]] = None, openapi_examples: Optional[Dict[str, Example]] = None, deprecated: Optional[bool] = None, @@ -278,9 +283,29 @@ def analyze_param( value: Any, is_path_param: bool, ) -> Tuple[Any, Optional[ModelField]]: + """ + Analyze a parameter annotation and value to determine the type and default value of the parameter. + + Parameters + ---------- + param_name: str + The name of the parameter + annotation + The annotation of the parameter + value + The value of the parameter + is_path_param + Whether the parameter is a path parameter + + Returns + ------- + Tuple[Any, Optional[ModelField]] + The type annotation and the Pydantic field representing the parameter + """ field_info: Optional[FieldInfo] = None type_annotation: Any = Any + # If the annotation is an Annotated type, we need to extract the type annotation and the FieldInfo if annotation is not inspect.Signature.empty and get_origin(annotation) is Annotated: annotated_args = get_args(annotation) type_annotation = annotated_args[0] @@ -293,41 +318,50 @@ def analyze_param( # Copy `field_info` because we mutate `field_info.default` later field_info = copy(powertools_annotation) assert field_info.default is Undefined or field_info.default is Required + if value is not inspect.Signature.empty: assert not is_path_param field_info.default = value else: field_info.default = Required + + # If the annotation is not an Annotated type, we use it as the type annotation elif annotation is not inspect.Signature.empty: type_annotation = annotation + # If the value is a FieldInfo, we use it as the FieldInfo for the parameter if isinstance(value, FieldInfo): assert field_info is None field_info = value + # If we didn't determine the FieldInfo yet, we create a default one if field_info is None: default_value = value if value is not inspect.Signature.empty else Required + if is_path_param: field_info = Path(annotation=type_annotation, default=default_value) else: field_info = Query(annotation=type_annotation, default=default_value) + # Now that we have the FieldInfo, we can determine the type annotation field = None if field_info is not None: if is_path_param: - assert isinstance(field_info, Path) + assert isinstance(field_info, Path), "Path parameters must be of type Path" elif isinstance(field_info, Param) and getattr(field_info, "in_", None) is None: field_info.in_ = ParamTypes.query + # If the field_info is a Param, we use the `in_` attribute to determine the type annotation use_annotation = get_annotation_from_field_info(type_annotation, field_info, param_name) + # If the field doesn't have a defined alias, we use the param name if not field_info.alias and getattr(field_info, "convert_underscores", None): alias = param_name.replace("_", "-") else: alias = field_info.alias or param_name - field_info.alias = alias + # Create the Pydantic field field = ModelField( name=param_name, field_info=field_info, diff --git a/aws_lambda_powertools/event_handler/openapi/utils.py b/aws_lambda_powertools/event_handler/openapi/utils.py index cad6f18975d..c01d4f45885 100644 --- a/aws_lambda_powertools/event_handler/openapi/utils.py +++ b/aws_lambda_powertools/event_handler/openapi/utils.py @@ -1,12 +1,15 @@ -from typing import Any, Callable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, List, Optional, Tuple -from pydantic import BaseConfig -from pydantic.fields import FieldInfo, ModelField, Undefined, UndefinedType +from pydantic.fields import ModelField from aws_lambda_powertools.event_handler.openapi.params import Dependant CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] +""" +This file defines utility functions for working with OpenAPI/JSON Schema models. +""" + def get_flat_dependant( dependant: Dependant, @@ -14,6 +17,27 @@ def get_flat_dependant( skip_repeats: bool = False, visited: Optional[List[CacheKey]] = None, ) -> Dependant: + """ + Flatten a recursive Dependant model structure. + + This function recursively concatenates the parameter fields of a Dependant model and its dependencies into a flat + Dependant structure. This is useful for scenarios like parameter validation where the nested structure is not + relevant. + + Parameters + ---------- + dependant: Dependant + The dependant model to flatten + skip_repeats: bool + If True, child Dependents already visited will be skipped to avoid duplicates + visited: List[CacheKey], optional + Keeps track of visited Dependents to avoid infinite recursion. Defaults to empty list. + + Returns + ------- + Dependant + The flattened Dependant model + """ if visited is None: visited = [] visited.append(dependant.cache_key) @@ -42,6 +66,20 @@ def get_flat_dependant( def get_flat_params(dependant: Dependant) -> List[ModelField]: + """ + Get a list of all the parameters from a Dependant object. + + Parameters + ---------- + dependant : Dependant + The Dependant object containing the parameters. + + Returns + ------- + List[ModelField] + A list of ModelField objects containing the flat parameters from the Dependant object. + + """ flat_dependant = get_flat_dependant(dependant, skip_repeats=True) return ( flat_dependant.path_params @@ -49,29 +87,3 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]: + flat_dependant.header_params + flat_dependant.cookie_params ) - - -def create_response_field( - name: str, - type_: Type[Any], - default: Optional[Any] = Undefined, - required: Union[bool, UndefinedType] = Undefined, - model_config: Type[BaseConfig] = BaseConfig, - alias: Optional[str] = None, -) -> ModelField: - """ - Create a new response field. - """ - field_info = FieldInfo() - - kwargs = { - "name": name, - "field_info": field_info, - "type_": type_, - "default": default, - "required": required, - "model_config": model_config, - "alias": alias, - "class_validators": {}, - } - return ModelField(**kwargs) From ba333d38745a294ddc77d4f19177ad8b3f29e779 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 26 Sep 2023 11:52:15 +0200 Subject: [PATCH 05/75] fix: linter --- .../event_handler/api_gateway.py | 507 ++++++++++++++++-- .../event_handler/openapi/models.py | 14 +- .../event_handler/openapi/params.py | 90 ++-- .../event_handler/openapi/utils.py | 4 +- aws_lambda_powertools/event_handler/route.py | 368 ------------- 5 files changed, 528 insertions(+), 455 deletions(-) delete mode 100644 aws_lambda_powertools/event_handler/route.py diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 05d4599eb11..cbc02bd0686 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -22,19 +22,25 @@ Tuple, Type, Union, + cast, ) from pydantic.fields import ModelField -from pydantic.schema import get_flat_models_from_fields, get_model_name_map, model_process_schema +from pydantic.schema import ( + TypeModelOrEnum, + field_schema, + get_flat_models_from_fields, + get_model_name_map, + model_process_schema, +) from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant from aws_lambda_powertools.event_handler.openapi.models import Contact, License, OpenAPI, Server, Tag +from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param from aws_lambda_powertools.event_handler.openapi.utils import get_flat_params from aws_lambda_powertools.event_handler.response import Response -from aws_lambda_powertools.event_handler.route import Route -from aws_lambda_powertools.shared.cookies import Cookie from aws_lambda_powertools.shared.functions import powertools_dev_is_set from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ( @@ -180,42 +186,272 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]: return headers -class Response: - """Response data class that provides greater control over what is returned from the proxy event""" +class Route: + """Internally used Route Configuration""" def __init__( self, - status_code: int, - content_type: Optional[str] = None, - body: Union[str, bytes, None] = None, - headers: Optional[Dict[str, Union[str, List[str]]]] = None, - cookies: Optional[List[Cookie]] = None, - compress: Optional[bool] = None, + method: str, + path: str, + rule: Pattern, + func: Callable, + cors: bool, + compress: bool, + cache_control: Optional[str], + middlewares: Optional[List[Callable[..., Response]]], + description: Optional[str], + tags: Optional[List[Tag]], ): """ Parameters ---------- - status_code: int - Http status code, example 200 - content_type: str - Optionally set the Content-Type header, example "application/json". Note this will be merged into any - provided http headers - body: Union[str, bytes, None] - Optionally set the response body. Note: bytes body will be automatically base64 encoded - headers: dict[str, Union[str, List[str]]] - Optionally set specific http headers. Setting "Content-Type" here would override the `content_type` value. - cookies: list[Cookie] - Optionally set cookies. + + method: str + The HTTP method, example "GET" + rule: Pattern + The route rule, example "/my/path" + path: str + The path of the route + func: Callable + The route handler function + cors: bool + Whether or not to enable CORS for this route + compress: bool + Whether or not to enable gzip compression for this route + cache_control: Optional[str] + The cache control header value, example "max-age=3600" + middlewares: Optional[List[Callable[..., Response]]] + The list of route middlewares to be called in order. + description: Optional[str] + The OpenAPI description for this route + tags: Optional[List[Tag]] + The list of OpenAPI tags to be used for this route """ - self.status_code = status_code - self.body = body - self.base64_encoded = False - self.headers: Dict[str, Union[str, List[str]]] = headers if headers else {} - self.cookies = cookies or [] + self.method = method.upper() + self.path = path + self.rule = rule + self.func = func + self._middleware_stack = func + self.cors = cors self.compress = compress - if content_type: - self.headers.setdefault("Content-Type", content_type) + self.cache_control = cache_control + self.middlewares = middlewares or [] + self.description = description + self.tags = tags or [] + self.operation_id = self.method.title() + self.func.__name__.title() + + # _middleware_stack_built is used to ensure the middleware stack is only built once. + self._middleware_stack_built = False + + def __call__( + self, + router_middlewares: List[Callable], + app: "ApiGatewayResolver", + route_arguments: Dict[str, str], + ) -> Union[Dict, Tuple, Response]: + """Calling the Router class instance will trigger the following actions: + 1. If Route Middleware stack has not been built, build it + 2. Call the Route Middleware stack wrapping the original function + handler with the app and route arguments. + + Parameters + ---------- + router_middlewares: List[Callable] + The list of Router Middlewares (assigned to ALL routes) + app: "ApiGatewayResolver" + The ApiGatewayResolver instance to pass into the middleware stack + route_arguments: Dict[str, str] + The route arguments to pass to the app function (extracted from the Api Gateway + Lambda Message structure from AWS) + + Returns + ------- + Union[Dict, Tuple, Response] + API Response object in ALL cases, except when the original API route + handler is called which may also return a Dict, Tuple, or Response. + """ + + # Save CPU cycles by building middleware stack once + if not self._middleware_stack_built: + self._build_middleware_stack(router_middlewares=router_middlewares) + + # If debug is turned on then output the middleware stack to the console + if app._debug: + print(f"\nProcessing Route:::{self.func.__name__} ({app.context['_path']})") + # Collect ALL middleware for debug printing - include internal _registered_api_adapter + all_middlewares = router_middlewares + self.middlewares + [_registered_api_adapter] + print("\nMiddleware Stack:") + print("=================") + print("\n".join(getattr(item, "__name__", "Unknown") for item in all_middlewares)) + print("=================") + + # Add Route Arguments to app context + app.append_context(_route_args=route_arguments) + + # Call the Middleware Wrapped _call_stack function handler with the app + return self._middleware_stack(app) + + def _build_middleware_stack(self, router_middlewares: List[Callable[..., Any]]) -> None: + """ + Builds the middleware stack for the handler by wrapping each + handler in an instance of MiddlewareWrapper which is used to contain the state + of each middleware step. + + Middleware is represented by a standard Python Callable construct. Any Middleware + handler wanting to short-circuit the middlware call chain can raise an exception + to force the Python call stack created by the handler call-chain to naturally un-wind. + + This becomes a simple concept for developers to understand and reason with - no additional + gymanstics other than plain old try ... except. + + Notes + ----- + The Route Middleware stack is processed in reverse order. This is so the stack of + middleware handlers is applied in the order of being added to the handler. + """ + all_middlewares = router_middlewares + self.middlewares + logger.debug(f"Building middleware stack: {all_middlewares}") + + # IMPORTANT: + # this must be the last middleware in the stack (tech debt for backward + # compatibility purposes) + # + # This adapter will: + # 1. Call the registered API passing only the expected route arguments extracted from the path + # and not the middleware. + # 2. Adapt the response type of the route handler (Union[Dict, Tuple, Response]) + # and normalise into a Response object so middleware will always have a constant signature + all_middlewares.append(_registered_api_adapter) + + # Wrap the original route handler function in the middleware handlers + # using the MiddlewareWrapper class callable construct in reverse order to + # ensure middleware is applied in the order the user defined. + # + # Start with the route function and wrap from last to the first Middleware handler. + for handler in reversed(all_middlewares): + self._middleware_stack = MiddlewareFrame(current_middleware=handler, next_middleware=self._middleware_stack) + + self._middleware_stack_built = True + + def _get_openapi_path( + self, + *, + dependant: Dependant, + operation_ids: Set[str], + model_name_map: Dict[TypeModelOrEnum, str], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + path = {} + definitions: Dict[str, Any] = {} + + operation = self._openapi_operation_metadata(operation_ids=operation_ids) + parameters: List[Dict[str, Any]] = [] + all_route_params = get_flat_params(dependant) + operation_params = self._openapi_operation_parameters( + all_route_params=all_route_params, + model_name_map=model_name_map, + ) + + parameters.extend(operation_params) + if parameters: + all_parameters = {(param["in"], param["name"]): param for param in parameters} + required_parameters = {(param["in"], param["name"]): param for param in parameters if param.get("required")} + all_parameters.update(required_parameters) + operation["parameters"] = list(all_parameters.values()) + + responses = operation.setdefault("responses", {}) + success_response = responses.setdefault("200", {}) + success_response["description"] = "Success" + success_response["content"] = {"application/json": {"schema": {}}} + json_response = success_response["content"].setdefault("application/json", {}) + + json_response["schema"] = self._openapi_operation_return( + operation_id=self.operation_id, + param=dependant.return_param, + model_name_map=model_name_map, + ) + + path[self.method.lower()] = operation + + # Generate the response schema + return path, definitions + + def _openapi_operation_summary(self) -> str: + # Generate a summary from the pattern + return self.rule.__str__().replace("_", " ").title() + + def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any]: + operation: Dict[str, Any] = {} + + if self.tags: + operation["tags"] = self.tags + + operation["summary"] = self._openapi_operation_summary() + + if self.description: + operation["description"] = self.description + + # Ensure operationId is unique + if self.operation_id in operation_ids: + message = f"Duplicate Operation ID {self.operation_id} for function {self.func.__name__}" + file_name = getattr(self.func, "__globals__", {}).get("__file__") + if file_name: + message += f" in {file_name}" + warnings.warn(message, stacklevel=1) + operation_ids.add(self.operation_id) + operation["operationId"] = self.operation_id + + return operation + + @staticmethod + def _openapi_operation_parameters( + *, + all_route_params: Sequence[ModelField], + model_name_map: Dict[TypeModelOrEnum, str], + ) -> List[Dict[str, Any]]: + parameters = [] + for param in all_route_params: + field_info = param.field_info + field_info = cast(Param, field_info) + if not field_info.include_in_schema: + continue + + param_schema = field_schema(param, model_name_map=model_name_map, ref_prefix="#/components/schemas/")[0] + + parameter = { + "name": param.alias, + "in": field_info.in_.value, + "required": param.required, + "schema": param_schema, + } + + if field_info.description: + parameter["description"] = field_info.description + + if field_info.deprecated: + parameter["deprecated"] = field_info.deprecated + + parameters.append(parameter) + + return parameters + + @staticmethod + def _openapi_operation_return( + *, + operation_id: str, + param: Optional[ModelField], + model_name_map: Dict[TypeModelOrEnum, str], + ) -> Dict[str, Any]: + if param is None: + return {} + + return_schema = field_schema( + param, + model_name_map=model_name_map, + ref_prefix="#/components/schemas/", + )[0] + + return {"name": f"Return {operation_id}", "schema": return_schema} class ResponseBuilder: @@ -329,6 +565,8 @@ def route( compress: bool = False, cache_control: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, + description: Optional[str] = None, + tags: Optional[List[Tag]] = None, ): raise NotImplementedError() @@ -380,6 +618,8 @@ def get( compress: bool = False, cache_control: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, + description: Optional[str] = None, + tags: Optional[List[Tag]] = None, ): """Get route decorator with GET `method` @@ -403,7 +643,7 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "GET", cors, compress, cache_control, middlewares) + return self.route(rule, "GET", cors, compress, cache_control, middlewares, description, tags) def post( self, @@ -412,6 +652,8 @@ def post( compress: bool = False, cache_control: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, + description: Optional[str] = None, + tags: Optional[List[Tag]] = None, ): """Post route decorator with POST `method` @@ -436,7 +678,7 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "POST", cors, compress, cache_control, middlewares) + return self.route(rule, "POST", cors, compress, cache_control, middlewares, description, tags) def put( self, @@ -445,6 +687,8 @@ def put( compress: bool = False, cache_control: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, + description: Optional[str] = None, + tags: Optional[List[Tag]] = None, ): """Put route decorator with PUT `method` @@ -469,7 +713,7 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "PUT", cors, compress, cache_control, middlewares) + return self.route(rule, "PUT", cors, compress, cache_control, middlewares, description, tags) def delete( self, @@ -478,6 +722,8 @@ def delete( compress: bool = False, cache_control: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, + description: Optional[str] = None, + tags: Optional[List[Tag]] = None, ): """Delete route decorator with DELETE `method` @@ -501,7 +747,7 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "DELETE", cors, compress, cache_control, middlewares) + return self.route(rule, "DELETE", cors, compress, cache_control, middlewares, description, tags) def patch( self, @@ -510,6 +756,8 @@ def patch( compress: bool = False, cache_control: Optional[str] = None, middlewares: Optional[List[Callable]] = None, + description: Optional[str] = None, + tags: Optional[List[Tag]] = None, ): """Patch route decorator with PATCH `method` @@ -536,7 +784,7 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "PATCH", cors, compress, cache_control, middlewares) + return self.route(rule, "PATCH", cors, compress, cache_control, middlewares, description, tags) def _push_processed_stack_frame(self, frame: str): """ @@ -559,6 +807,109 @@ def clear_context(self): self.context.clear() +class MiddlewareFrame: + """ + Creates a Middle Stack Wrapper instance to be used as a "Frame" in the overall stack of + middleware functions. Each instance contains the current middleware and the next + middleware function to be called in the stack. + + In this way the middleware stack is constructed in a recursive fashion, with each middleware + calling the next as a simple function call. The actual Python call-stack will contain + each MiddlewareStackWrapper "Frame", meaning any Middleware function can cause the + entire Middleware call chain to be exited early (short-circuited) by raising an exception + or by simply returning early with a custom Response. The decision to short-circuit the middleware + chain is at the user's discretion but instantly available due to the Wrapped nature of the + callable constructs in the Middleware stack and each Middleware function having complete control over + whether the "Next" handler in the stack is called or not. + + Parameters + ---------- + current_middleware : Callable + The current middleware function to be called as a request is processed. + next_middleware : Callable + The next middleware in the middleware stack. + """ + + def __init__( + self, + current_middleware: Callable[..., Any], + next_middleware: Callable[..., Any], + ) -> None: + self.current_middleware: Callable[..., Any] = current_middleware + self.next_middleware: Callable[..., Any] = next_middleware + self._next_middleware_name = next_middleware.__name__ + + @property + def __name__(self) -> str: # noqa: A003 + """Current middleware name + + It ensures backward compatibility with view functions being callable. This + improves debugging since we need both current and next middlewares/callable names. + """ + return self.current_middleware.__name__ + + def __str__(self) -> str: + """Identify current middleware identity and call chain for debugging purposes.""" + middleware_name = self.__name__ + return f"[{middleware_name}] next call chain is {middleware_name} -> {self._next_middleware_name}" + + def __call__(self, app: "ApiGatewayResolver") -> Union[Dict, Tuple, Response]: + """ + Call the middleware Frame to process the request. + + Parameters + ---------- + app: BaseRouter + The router instance + + Returns + ------- + Union[Dict, Tuple, Response] + (tech-debt for backward compatibility). The response type should be a + Response object in all cases excepting when the original API route handler + is called which will return one of 3 outputs. + + """ + # Do debug printing and push processed stack frame AFTER calling middleware + # else the stack frame text of `current calling next` is confusing. + logger.debug("MiddlewareFrame: %s", self) + app._push_processed_stack_frame(str(self)) + + return self.current_middleware(app, self.next_middleware) + + +def _registered_api_adapter( + app: "ApiGatewayResolver", + next_middleware: Callable[..., Any], +) -> Union[Dict, Tuple, Response]: + """ + Calls the registered API using the "_route_args" from the Resolver context to ensure the last call + in the chain will match the API route function signature and ensure that Powertools passes the API + route handler the expected arguments. + + **IMPORTANT: This internal middleware ensures the actual API route is called with the correct call signature + and it MUST be the final frame in the middleware stack. This can only be removed when the API Route + function accepts `app: BaseRouter` as the first argument - which is the breaking change. + + Parameters + ---------- + app: ApiGatewayResolver + The API Gateway resolver + next_middleware: Callable[..., Any] + The function to handle the API + + Returns + ------- + Response + The API Response Object + + """ + route_args: Dict = app.context.get("_route_args", {}) + logger.debug(f"Calling API Route Handler: {route_args}") + + return app._to_response(next_middleware(**route_args)) + + class ApiGatewayResolver(BaseRouter): """API Gateway and ALB proxy resolver @@ -643,6 +994,39 @@ def get_openapi_schema( contact: Optional[Contact] = None, license_info: Optional[License] = None, ) -> OpenAPI: + """ + Returns the OpenAPI schema as a pydantic model. + + Parameters + ---------- + title: str + The title of the application. + version: str + The version of the OpenAPI document (which is distinct from the OpenAPI Specification version or the API + openapi_version: str, default = "3.1.0" + The version of the OpenAPI Specification (which the document uses). + summary: str, optional + A short summary of what the application does. + description: str, optional + A verbose explanation of the application behavior. + tags: List[Tag], optional + A list of tags used by the specification with additional metadata. + servers: List[Server], optional + An array of Server Objects, which provide connectivity information to a target server. + terms_of_service: str, optional + A URL to the Terms of Service for the API. MUST be in the format of a URL. + contact: Contact, optional + The contact information for the exposed API. + license_info: + The license information for the exposed API. + + Returns + ------- + OpenAPI: pydantic model + The OpenAPI schema as a pydantic model. + """ + + # Start with the bare minimum required for a valid OpenAPI schema info: Dict[str, Any] = {"title": title, "version": version} if summary: info["summary"] = summary @@ -672,6 +1056,7 @@ def get_openapi_schema( models = get_flat_models_from_fields(all_fields, known_models=set()) model_name_map = get_model_name_map(models) + # Collect all models and definitions definitions: Dict[str, Dict[str, Any]] = {} for model in models: m_schema, m_definitions, _ = model_process_schema( @@ -681,17 +1066,16 @@ def get_openapi_schema( ) definitions.update(m_definitions) model_name = model_name_map[model] - if "description" in m_schema: - m_schema["description"] = m_schema["description"].split("\f")[0] definitions[model_name] = m_schema + # Add routes to the OpenAPI schema for route in all_routes: dependant = get_dependant( path=route.path, call=route.func, ) - result = route._openapi_path( + result = route._get_openapi_path( dependant=dependant, operation_ids=operation_ids, model_name_map=model_name_map, @@ -712,7 +1096,7 @@ def get_openapi_schema( output["paths"] = paths - return OpenAPI(**output) # .dict(by_alias=True, exclude_none=True) + return OpenAPI(**output) def get_openapi_json_schema( self, @@ -728,7 +1112,37 @@ def get_openapi_json_schema( contact: Optional[Contact] = None, license_info: Optional[License] = None, ) -> str: - """Returns the OpenAPI schema as a JSON serializable dict""" + """ + Returns the OpenAPI schema as a JSON serializable dict + + Parameters + ---------- + title: str + The title of the application. + version: str + The version of the OpenAPI document (which is distinct from the OpenAPI Specification version or the API + openapi_version: str, default = "3.1.0" + The version of the OpenAPI Specification (which the document uses). + summary: str, optional + A short summary of what the application does. + description: str, optional + A verbose explanation of the application behavior. + tags: List[Tag], optional + A list of tags used by the specification with additional metadata. + servers: List[Server], optional + An array of Server Objects, which provide connectivity information to a target server. + terms_of_service: str, optional + A URL to the Terms of Service for the API. MUST be in the format of a URL. + contact: Contact, optional + The contact information for the exposed API. + license_info: + The license information for the exposed API. + + Returns + ------- + str + The OpenAPI schema as a JSON serializable dict. + """ return self.get_openapi_schema( title=title, version=version, @@ -750,6 +1164,8 @@ def route( compress: bool = False, cache_control: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, + description: Optional[str] = None, + tags: Optional[List[Tag]] = None, ): """Route decorator includes parameter `method`""" @@ -771,6 +1187,8 @@ def register_resolver(func: Callable): compress, cache_control, middlewares, + description, + tags, ) # The more specific route wins. @@ -1127,6 +1545,9 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None @staticmethod def _get_fields_from_routes(routes: Sequence[Route]) -> List[ModelField]: + """ + Returns a list of fields from the routes + """ responses_from_routes: List[ModelField] = [] request_fields_from_routes: List[ModelField] = [] @@ -1159,12 +1580,14 @@ def route( compress: bool = False, cache_control: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, + description: Optional[str] = None, + tags: Optional[List[Tag]] = None, ): def register_route(func: Callable): # Convert methods to tuple. It needs to be hashable as its part of the self._routes dict key methods = (method,) if isinstance(method, str) else tuple(method) - route_key = (rule, methods, cors, compress, cache_control) + route_key = (rule, methods, cors, compress, cache_control, description, tags) # Collate Middleware for routes if middlewares is not None: @@ -1205,9 +1628,11 @@ def route( compress: bool = False, cache_control: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, + description: Optional[str] = None, + tags: Optional[List[Tag]] = None, ): # NOTE: see #1552 for more context. - return super().route(rule.rstrip("/"), method, cors, compress, cache_control, middlewares) + return super().route(rule.rstrip("/"), method, cors, compress, cache_control, middlewares, description, tags) # Override _compile_regex to exclude trailing slashes for route resolution @staticmethod diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index b5ed0d9215b..2cf06155de5 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -49,7 +49,7 @@ class Info(BaseModel): description: Optional[str] = None termsOfService: Optional[str] = None contact: Optional[Contact] = None - license: Optional[License] = None + license: Optional[License] = None # noqa: A003 version: str if PYDANTIC_V2: @@ -139,7 +139,7 @@ class Schema(BaseModel): # Core Vocabulary schema_: Optional[str] = Field(default=None, alias="$schema") vocabulary: Optional[str] = Field(default=None, alias="$vocabulary") - id: Optional[str] = Field(default=None, alias="$id") + id: Optional[str] = Field(default=None, alias="$id") # noqa: A003 anchor: Optional[str] = Field(default=None, alias="$anchor") dynamicAnchor: Optional[str] = Field(default=None, alias="$dynamicAnchor") ref: Optional[str] = Field(default=None, alias="$ref") @@ -157,9 +157,9 @@ class Schema(BaseModel): else_: Optional["SchemaOrBool"] = Field(default=None, alias="else") dependentSchemas: Optional[Dict[str, "SchemaOrBool"]] = None prefixItems: Optional[List["SchemaOrBool"]] = None - # TODO: uncomment and remove below when deprecating Pydantic v1 - # It generales a list of schemas for tuples, before prefixItems was available - # items: Optional["SchemaOrBool"] = None + # MAINTENANCE: uncomment and remove below when deprecating Pydantic v1 + # MAINTENANCE: It generates a list of schemas for tuples, before prefixItems was available + # MAINTENANCE: items: Optional["SchemaOrBool"] = None items: Optional[Union["SchemaOrBool", List["SchemaOrBool"]]] = None contains: Optional["SchemaOrBool"] = None properties: Optional[Dict[str, "SchemaOrBool"]] = None @@ -170,7 +170,7 @@ class Schema(BaseModel): unevaluatedProperties: Optional["SchemaOrBool"] = None # Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-structural # A Vocabulary for Structural Validation - type: Optional[str] = None + type: Optional[str] = None # noqa: A003 enum: Optional[List[Any]] = None const: Optional[Any] = None multipleOf: Optional[float] = Field(default=None, gt=0) @@ -192,7 +192,7 @@ class Schema(BaseModel): dependentRequired: Optional[Dict[str, Set[str]]] = None # Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-vocabularies-for-semantic-c # Vocabularies for Semantic Content With "format" - format: Optional[str] = None + format: Optional[str] = None # noqa: A003 # Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-the-conten # A Vocabulary for the Contents of String-Encoded Data contentEncoding: Optional[str] = None diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index baba6d6d662..8443d1da6cb 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -6,13 +6,16 @@ from pydantic import BaseConfig from pydantic.fields import FieldInfo, ModelField, Required, Undefined from pydantic.schema import get_annotation_from_field_info +from pydantic.version import VERSION as PYDANTIC_VERSION -from aws_lambda_powertools.event_handler.openapi import Example +from aws_lambda_powertools.event_handler.openapi.utils import CacheKey """ This turns the low-level function signature into typed, validated Pydantic models for consumption. """ +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + class Dependant: """ @@ -55,7 +58,7 @@ def __init__( # Store the path to be able to re-generate a dependable from it in overrides self.path = path # Save the cache key at creation to optimize performance - self.cache_key = self.call + self.cache_key: CacheKey = self.call class ParamTypes(Enum): @@ -65,6 +68,10 @@ class ParamTypes(Enum): cookie = "cookie" +# MAINTENANCE: update when deprecating Pydantic v1, remove this alias +_Unset: Any = Undefined + + class Param(FieldInfo): in_: ParamTypes @@ -72,12 +79,12 @@ def __init__( self, default: Any = Undefined, *, - default_factory: Union[Callable[[], Any], None] = Undefined, + default_factory: Union[Callable[[], Any], None] = _Unset, annotation: Optional[Any] = None, alias: Optional[str] = None, - alias_priority: Union[int, None] = Undefined, - # TODO: update when deprecating Pydantic v1, import these types - # validation_alias: str | AliasPath | AliasChoices | None + alias_priority: Union[int, None] = _Unset, + # MAINTENANCE: update when deprecating Pydantic v1, import these types + # MAINTENANCE: validation_alias: str | AliasPath | AliasChoices | None validation_alias: Union[str, None] = None, serialization_alias: Union[str, None] = None, title: Optional[str] = None, @@ -90,13 +97,12 @@ def __init__( max_length: Optional[int] = None, pattern: Optional[str] = None, discriminator: Union[str, None] = None, - strict: Union[bool, None] = Undefined, - multiple_of: Union[float, None] = Undefined, - allow_inf_nan: Union[bool, None] = Undefined, - max_digits: Union[int, None] = Undefined, - decimal_places: Union[int, None] = Undefined, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, examples: Optional[List[Any]] = None, - openapi_examples: Optional[Dict[str, Example]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, @@ -104,7 +110,7 @@ def __init__( ): self.deprecated = deprecated self.include_in_schema = include_in_schema - self.openapi_examples = openapi_examples + kwargs = dict( default=default, default_factory=default_factory, @@ -128,9 +134,23 @@ def __init__( kwargs["examples"] = examples current_json_schema_extra = json_schema_extra or extra - kwargs["regex"] = pattern - kwargs.update(**current_json_schema_extra) - use_kwargs = {k: v for k, v in kwargs.items() if v is not Undefined} + if PYDANTIC_V2: + kwargs.update( + { + "annotation": annotation, + "alias_priority": alias_priority, + "validation_alias": validation_alias, + "serialization_alias": serialization_alias, + "strict": strict, + "json_schema_extra": current_json_schema_extra, + }, + ) + kwargs["pattern"] = pattern + else: + kwargs["regex"] = pattern + kwargs.update(**current_json_schema_extra) + + use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset} super().__init__(**use_kwargs) @@ -145,12 +165,12 @@ def __init__( self, default: Any = ..., *, - default_factory: Union[Callable[[], Any], None] = Undefined, + default_factory: Union[Callable[[], Any], None] = _Unset, annotation: Optional[Any] = None, alias: Optional[str] = None, - alias_priority: Union[int, None] = Undefined, - # TODO: update when deprecating Pydantic v1, import these types - # validation_alias: str | AliasPath | AliasChoices | None + alias_priority: Union[int, None] = _Unset, + # MAINTENANCE: update when deprecating Pydantic v1, import these types + # MAINTENANCE: validation_alias: str | AliasPath | AliasChoices | None validation_alias: Union[str, None] = None, serialization_alias: Union[str, None] = None, title: Optional[str] = None, @@ -163,13 +183,12 @@ def __init__( max_length: Optional[int] = None, pattern: Optional[str] = None, discriminator: Union[str, None] = None, - strict: Union[bool, None] = Undefined, - multiple_of: Union[float, None] = Undefined, - allow_inf_nan: Union[bool, None] = Undefined, - max_digits: Union[int, None] = Undefined, - decimal_places: Union[int, None] = Undefined, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, examples: Optional[List[Any]] = None, - openapi_examples: Optional[Dict[str, Example]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, @@ -202,7 +221,6 @@ def __init__( decimal_places=decimal_places, deprecated=deprecated, examples=examples, - openapi_examples=openapi_examples, include_in_schema=include_in_schema, json_schema_extra=json_schema_extra, **extra, @@ -214,12 +232,12 @@ class Query(Param): def __init__( self, - default: Any = Undefined, + default: Any = _Unset, *, - default_factory: Union[Callable[[], Any], None] = Undefined, + default_factory: Union[Callable[[], Any], None] = _Unset, annotation: Optional[Any] = None, alias: Optional[str] = None, - alias_priority: Union[int, None] = Undefined, + alias_priority: Union[int, None] = _Unset, validation_alias: Union[str, None] = None, serialization_alias: Union[str, None] = None, title: Optional[str] = None, @@ -232,13 +250,12 @@ def __init__( max_length: Optional[int] = None, pattern: Optional[str] = None, discriminator: Union[str, None] = None, - strict: Union[bool, None] = Undefined, - multiple_of: Union[float, None] = Undefined, - allow_inf_nan: Union[bool, None] = Undefined, - max_digits: Union[int, None] = Undefined, - decimal_places: Union[int, None] = Undefined, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, examples: Optional[List[Any]] = None, - openapi_examples: Optional[Dict[str, Example]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, @@ -269,7 +286,6 @@ def __init__( decimal_places=decimal_places, deprecated=deprecated, examples=examples, - openapi_examples=openapi_examples, include_in_schema=include_in_schema, json_schema_extra=json_schema_extra, **extra, diff --git a/aws_lambda_powertools/event_handler/openapi/utils.py b/aws_lambda_powertools/event_handler/openapi/utils.py index c01d4f45885..eeca1d0fd1e 100644 --- a/aws_lambda_powertools/event_handler/openapi/utils.py +++ b/aws_lambda_powertools/event_handler/openapi/utils.py @@ -1,10 +1,10 @@ -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional from pydantic.fields import ModelField from aws_lambda_powertools.event_handler.openapi.params import Dependant -CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] +CacheKey = Optional[Callable[..., Any]] """ This file defines utility functions for working with OpenAPI/JSON Schema models. diff --git a/aws_lambda_powertools/event_handler/route.py b/aws_lambda_powertools/event_handler/route.py deleted file mode 100644 index ee631f8b98b..00000000000 --- a/aws_lambda_powertools/event_handler/route.py +++ /dev/null @@ -1,368 +0,0 @@ -import logging -import warnings -from re import Pattern -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, cast - -from pydantic.fields import ModelField -from pydantic.schema import TypeModelOrEnum, field_schema - -from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param -from aws_lambda_powertools.event_handler.openapi.utils import get_flat_params -from aws_lambda_powertools.event_handler.response import Response - -logger = logging.getLogger(__name__) - - -class MiddlewareFrame: - """ - creates a Middle Stack Wrapper instance to be used as a "Frame" in the overall stack of - middleware functions. Each instance contains the current middleware and the next - middleware function to be called in the stack. - - In this way the middleware stack is constructed in a recursive fashion, with each middleware - calling the next as a simple function call. The actual Python call-stack will contain - each MiddlewareStackWrapper "Frame", meaning any Middleware function can cause the - entire Middleware call chain to be exited early (short-circuited) by raising an exception - or by simply returning early with a custom Response. The decision to short-circuit the middleware - chain is at the user's discretion but instantly available due to the Wrapped nature of the - callable constructs in the Middleware stack and each Middleware function having complete control over - whether the "Next" handler in the stack is called or not. - - Parameters - ---------- - current_middleware : Callable - The current middleware function to be called as a request is processed. - next_middleware : Callable - The next middleware in the middleware stack. - """ - - def __init__( - self, - current_middleware: Callable[..., Any], - next_middleware: Callable[..., Any], - ) -> None: - self.current_middleware: Callable[..., Any] = current_middleware - self.next_middleware: Callable[..., Any] = next_middleware - self._next_middleware_name = next_middleware.__name__ - - @property - def __name__(self) -> str: # noqa: A003 - """Current middleware name - - It ensures backward compatibility with view functions being callable. This - improves debugging since we need both current and next middlewares/callable names. - """ - return self.current_middleware.__name__ - - def __str__(self) -> str: - """Identify current middleware identity and call chain for debugging purposes.""" - middleware_name = self.__name__ - return f"[{middleware_name}] next call chain is {middleware_name} -> {self._next_middleware_name}" - - def __call__(self, app: "ApiGatewayResolver") -> Union[Dict, Tuple, Response]: - """ - Call the middleware Frame to process the request. - - Parameters - ---------- - app: BaseRouter - The router instance - - Returns - ------- - Union[Dict, Tuple, Response] - (tech-debt for backward compatibility). The response type should be a - Response object in all cases excepting when the original API route handler - is called which will return one of 3 outputs. - - """ - # Do debug printing and push processed stack frame AFTER calling middleware - # else the stack frame text of `current calling next` is confusing. - logger.debug("MiddlewareFrame: %s", self) - app._push_processed_stack_frame(str(self)) - - return self.current_middleware(app, self.next_middleware) - - -class Route: - """Internally used Route Configuration""" - - def __init__( - self, - method: str, - path: str, - rule: Pattern, - func: Callable, - cors: bool, - compress: bool, - cache_control: Optional[str], - middlewares: Optional[List[Callable[..., Response]]], - ): - """ - - Parameters - ---------- - - method: str - The HTTP method, example "GET" - rule: Pattern - The route rule, example "/my/path" - path: str - The path of the route - func: Callable - The route handler function - cors: bool - Whether or not to enable CORS for this route - compress: bool - Whether or not to enable gzip compression for this route - cache_control: Optional[str] - The cache control header value, example "max-age=3600" - middlewares: Optional[List[Callable[..., Response]]] - The list of route middlewares to be called in order. - """ - self.method = method.upper() - self.path = path - self.rule = rule - self.func = func - self._middleware_stack = func - self.cors = cors - self.compress = compress - self.cache_control = cache_control - self.middlewares = middlewares or [] - self.operation_id = self.method.title() + self.func.__name__.title() - - # _middleware_stack_built is used to ensure the middleware stack is only built once. - self._middleware_stack_built = False - - def __call__( - self, - router_middlewares: List[Callable], - app: "ApiGatewayResolver", - route_arguments: Dict[str, str], - ) -> Union[Dict, Tuple, Response]: - """Calling the Router class instance will trigger the following actions: - 1. If Route Middleware stack has not been built, build it - 2. Call the Route Middleware stack wrapping the original function - handler with the app and route arguments. - - Parameters - ---------- - router_middlewares: List[Callable] - The list of Router Middlewares (assigned to ALL routes) - app: "ApiGatewayResolver" - The ApiGatewayResolver instance to pass into the middleware stack - route_arguments: Dict[str, str] - The route arguments to pass to the app function (extracted from the Api Gateway - Lambda Message structure from AWS) - - Returns - ------- - Union[Dict, Tuple, Response] - API Response object in ALL cases, except when the original API route - handler is called which may also return a Dict, Tuple, or Response. - """ - - # Save CPU cycles by building middleware stack once - if not self._middleware_stack_built: - self._build_middleware_stack(router_middlewares=router_middlewares) - - # If debug is turned on then output the middleware stack to the console - if app._debug: - print(f"\nProcessing Route:::{self.func.__name__} ({app.context['_path']})") - # Collect ALL middleware for debug printing - include internal _registered_api_adapter - all_middlewares = router_middlewares + self.middlewares + [_registered_api_adapter] - print("\nMiddleware Stack:") - print("=================") - print("\n".join(getattr(item, "__name__", "Unknown") for item in all_middlewares)) - print("=================") - - # Add Route Arguments to app context - app.append_context(_route_args=route_arguments) - - # Call the Middleware Wrapped _call_stack function handler with the app - return self._middleware_stack(app) - - def _build_middleware_stack(self, router_middlewares: List[Callable[..., Any]]) -> None: - """ - Builds the middleware stack for the handler by wrapping each - handler in an instance of MiddlewareWrapper which is used to contain the state - of each middleware step. - - Middleware is represented by a standard Python Callable construct. Any Middleware - handler wanting to short-circuit the middlware call chain can raise an exception - to force the Python call stack created by the handler call-chain to naturally un-wind. - - This becomes a simple concept for developers to understand and reason with - no additional - gymanstics other than plain old try ... except. - - Notes - ----- - The Route Middleware stack is processed in reverse order. This is so the stack of - middleware handlers is applied in the order of being added to the handler. - """ - all_middlewares = router_middlewares + self.middlewares - logger.debug(f"Building middleware stack: {all_middlewares}") - - # IMPORTANT: - # this must be the last middleware in the stack (tech debt for backward - # compatibility purposes) - # - # This adapter will: - # 1. Call the registered API passing only the expected route arguments extracted from the path - # and not the middleware. - # 2. Adapt the response type of the route handler (Union[Dict, Tuple, Response]) - # and normalise into a Response object so middleware will always have a constant signature - all_middlewares.append(_registered_api_adapter) - - # Wrap the original route handler function in the middleware handlers - # using the MiddlewareWrapper class callable construct in reverse order to - # ensure middleware is applied in the order the user defined. - # - # Start with the route function and wrap from last to the first Middleware handler. - for handler in reversed(all_middlewares): - self._middleware_stack = MiddlewareFrame(current_middleware=handler, next_middleware=self._middleware_stack) - - self._middleware_stack_built = True - - def _openapi_path( - self, - *, - dependant: Dependant, - operation_ids: Set[str], - model_name_map: Dict[TypeModelOrEnum, str], - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - path = {} - definitions: Dict[str, Any] = {} - - operation = self._openapi_operation_metadata(operation_ids=operation_ids) - parameters: List[Dict[str, Any]] = [] - all_route_params = get_flat_params(dependant) - operation_params = self._openapi_operation_parameters( - all_route_params=all_route_params, - model_name_map=model_name_map, - ) - - parameters.extend(operation_params) - if parameters: - all_parameters = {(param["in"], param["name"]): param for param in parameters} - required_parameters = {(param["in"], param["name"]): param for param in parameters if param.get("required")} - all_parameters.update(required_parameters) - operation["parameters"] = list(all_parameters.values()) - - responses = operation.setdefault("responses", {}) - success_response = responses.setdefault("200", {}) - success_response["description"] = "Success" - success_response["content"] = {"application/json": {"schema": {}}} - json_response = success_response["content"].setdefault("application/json", {}) - - json_response["schema"] = self._openapi_operation_return( - operation_id=self.operation_id, - param=dependant.return_param, - model_name_map=model_name_map, - ) - - path[self.method.lower()] = operation - - # Generate the response schema - return path, definitions - - def _openapi_operation_summary(self): - # TODO: add name, summary to Route, and allow it to be customized during creation - self.rule.__str__().replace("_", " ").title() - - def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any]: - operation: Dict[str, Any] = {"summary": self._openapi_operation_summary()} - - # TODO: description, tags - operation_id = self.operation_id - if operation_id in operation_ids: - message = f"Duplicate Operation ID {operation_id} for function {self.func.__name__}" - file_name = getattr(self.func, "__globals__", {}).get("__file__") - if file_name: - message += f" in {file_name}" - warnings.warn(message, stacklevel=1) - operation_ids.add(operation_id) - operation["operationId"] = operation_id - return operation - - @staticmethod - def _openapi_operation_parameters( - *, - all_route_params: Sequence[ModelField], - model_name_map: Dict[TypeModelOrEnum, str], - ) -> List[Dict[str, Any]]: - parameters = [] - for param in all_route_params: - field_info = param.field_info - field_info = cast(Param, field_info) - if not field_info.include_in_schema: - continue - - param_schema = field_schema(param, model_name_map=model_name_map, ref_prefix="#/components/schemas/")[0] - - parameter = { - "name": param.alias, - "in": field_info.in_.value, - "required": param.required, - "schema": param_schema, - } - - if field_info.description: - parameter["description"] = field_info.description - - if field_info.deprecated: - parameter["deprecated"] = field_info.deprecated - - parameters.append(parameter) - - return parameters - - @staticmethod - def _openapi_operation_return( - *, - operation_id: str, - param: Optional[ModelField], - model_name_map: Dict[TypeModelOrEnum, str], - ) -> Dict[str, Any]: - if param is None: - return {} - - return_schema = field_schema( - param, - model_name_map=model_name_map, - ref_prefix="#/components/schemas/", - )[0] - - return {"name": f"Return {operation_id}", "schema": return_schema} - - -def _registered_api_adapter( - app: "ApiGatewayResolver", - next_middleware: Callable[..., Any], -) -> Union[Dict, Tuple, Response]: - """ - Calls the registered API using the "_route_args" from the Resolver context to ensure the last call - in the chain will match the API route function signature and ensure that Powertools passes the API - route handler the expected arguments. - - **IMPORTANT: This internal middleware ensures the actual API route is called with the correct call signature - and it MUST be the final frame in the middleware stack. This can only be removed when the API Route - function accepts `app: BaseRouter` as the first argument - which is the breaking change. - - Parameters - ---------- - app: ApiGatewayResolver - The API Gateway resolver - next_middleware: Callable[..., Any] - The function to handle the API - - Returns - ------- - Response - The API Response Object - - """ - route_args: Dict = app.context.get("_route_args", {}) - logger.debug(f"Calling API Route Handler: {route_args}") - - return app._to_response(next_middleware(**route_args)) From 303fb2e5f0a9066667478000c2bc292448464181 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 26 Sep 2023 11:54:47 +0200 Subject: [PATCH 06/75] fix: remove unneeded code --- aws_lambda_powertools/event_handler/openapi/params.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 8443d1da6cb..26c31d939fd 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -195,7 +195,6 @@ def __init__( **extra: Any, ): assert default is ..., "Path parameters cannot have a default value" - self.in_ = self.in_ super(Path, self).__init__( default=default, default_factory=default_factory, From d1be57ba0b6aa5492500766153f99c79e7491975 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 26 Sep 2023 11:56:57 +0200 Subject: [PATCH 07/75] fix: reduce duplication --- aws_lambda_powertools/event_handler/api_gateway.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index cbc02bd0686..8a38e5f22be 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -62,6 +62,7 @@ _UNSAFE_URI = r"%<> \[\]{}|^" _NAMED_GROUP_BOUNDARY_PATTERN = rf"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)" _ROUTE_REGEX = "^{}$" +_COMPONENT_REF_PREFIX = "#/components/schemas/" class ProxyEventType(Enum): @@ -416,7 +417,7 @@ def _openapi_operation_parameters( if not field_info.include_in_schema: continue - param_schema = field_schema(param, model_name_map=model_name_map, ref_prefix="#/components/schemas/")[0] + param_schema = field_schema(param, model_name_map=model_name_map, ref_prefix=_COMPONENT_REF_PREFIX)[0] parameter = { "name": param.alias, @@ -448,7 +449,7 @@ def _openapi_operation_return( return_schema = field_schema( param, model_name_map=model_name_map, - ref_prefix="#/components/schemas/", + ref_prefix=_COMPONENT_REF_PREFIX, )[0] return {"name": f"Return {operation_id}", "schema": return_schema} @@ -1062,7 +1063,7 @@ def get_openapi_schema( m_schema, m_definitions, _ = model_process_schema( model, model_name_map=model_name_map, - ref_prefix="#/components/schemas/", + ref_prefix=_COMPONENT_REF_PREFIX, ) definitions.update(m_definitions) model_name = model_name_map[model] From 40fcca12205e9ed83098d1508f98fc1f5ba06071 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 26 Sep 2023 17:03:49 +0200 Subject: [PATCH 08/75] fix: types and sonarcube --- .../event_handler/api_gateway.py | 44 ++++++++++--------- .../event_handler/openapi/dependant.py | 6 +-- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 8a38e5f22be..5813b674822 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1029,16 +1029,16 @@ def get_openapi_schema( # Start with the bare minimum required for a valid OpenAPI schema info: Dict[str, Any] = {"title": title, "version": version} - if summary: - info["summary"] = summary - if description: - info["description"] = description - if terms_of_service: - info["termsOfService"] = terms_of_service - if contact: - info["contact"] = contact - if license_info: - info["license"] = license_info + + optional_fields = { + "summary": summary, + "description": description, + "termsOfService": terms_of_service, + "contact": contact, + "license": license_info, + } + + info.update({field: value for field, value in optional_fields.items() if value}) output: Dict[str, Any] = {"openapi": openapi_version, "info": info} if servers: @@ -1173,10 +1173,8 @@ def route( def register_resolver(func: Callable): methods = (method,) if isinstance(method, str) else method logger.debug(f"Adding route using rule {rule} and methods: {','.join((m.upper() for m in methods))}") - if cors is None: - cors_enabled = self._cors_enabled - else: - cors_enabled = cors + + cors_enabled = self._cors_enabled if cors is None else cors for item in methods: _route = Route( @@ -1201,13 +1199,8 @@ def register_resolver(func: Callable): else: self._static_routes.append(_route) - route_key = item + rule - if route_key in self._route_keys: - warnings.warn( - f"A route like this was already registered. method: '{item}' rule: '{rule}'", - stacklevel=2, - ) - self._route_keys.append(route_key) + self._create_route_key(item, rule) + if cors_enabled: logger.debug(f"Registering method {item.upper()} to Allow Methods in CORS") self._cors_methods.add(item.upper()) @@ -1261,6 +1254,15 @@ def resolve(self, event, context) -> Dict[str, Any]: def __call__(self, event, context) -> Any: return self.resolve(event, context) + def _create_route_key(self, item: str, rule: str): + route_key = item + rule + if route_key in self._route_keys: + warnings.warn( + f"A route like this was already registered. method: '{item}' rule: '{rule}'", + stacklevel=2, + ) + self._route_keys.append(route_key) + @staticmethod def _has_debug(debug: Optional[bool] = None) -> bool: # It might have been explicitly switched off (debug=False) diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 541face1387..74271413cd5 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -159,8 +159,8 @@ def get_dependant( # If the parameter is a path parameter, we need to set the in_ field to "path". is_path_param = param_name in path_param_names - # Analyze the parameter to get the type annotation and the Pydantic field. - type_annotation, param_field = analyze_param( + # Analyze the parameter to get the Pydantic field. + _, param_field = analyze_param( param_name=param_name, annotation=param.annotation, value=param.default, @@ -173,7 +173,7 @@ def get_dependant( # If the return annotation is not empty, add it to the dependant model. return_annotation = endpoint_signature.return_annotation if return_annotation is not inspect.Signature.empty: - type_annotation, param_field = analyze_param( + _, param_field = analyze_param( param_name="Return", annotation=return_annotation, value=None, From 079f3d79a40b378c3795dc201ceedcd9404424ad Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 26 Sep 2023 17:28:35 +0200 Subject: [PATCH 09/75] chore: refactor complex function --- .../event_handler/openapi/params.py | 101 ++++++++++-------- 1 file changed, 58 insertions(+), 43 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 26c31d939fd..3c93505cba4 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -317,6 +317,28 @@ def analyze_param( Tuple[Any, Optional[ModelField]] The type annotation and the Pydantic field representing the parameter """ + field_info, type_annotation = _get_field_info_and_type_annotation(annotation, value, is_path_param) + + # If the value is a FieldInfo, we use it as the FieldInfo for the parameter + if isinstance(value, FieldInfo): + assert field_info is None + field_info = value + + # If we didn't determine the FieldInfo yet, we create a default one + if field_info is None: + default_value = value if value is not inspect.Signature.empty else Required + + # Check if the parameter is part of the path. Otherwise, defaults to query. + if is_path_param: + field_info = Path(annotation=type_annotation, default=default_value) + else: + field_info = Query(annotation=type_annotation, default=default_value) + + field = _create_model_field(field_info, type_annotation, param_name, is_path_param) + return type_annotation, field + + +def _get_field_info_and_type_annotation(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]: field_info: Optional[FieldInfo] = None type_annotation: Any = Any @@ -344,48 +366,41 @@ def analyze_param( elif annotation is not inspect.Signature.empty: type_annotation = annotation - # If the value is a FieldInfo, we use it as the FieldInfo for the parameter - if isinstance(value, FieldInfo): - assert field_info is None - field_info = value - - # If we didn't determine the FieldInfo yet, we create a default one - if field_info is None: - default_value = value if value is not inspect.Signature.empty else Required - - if is_path_param: - field_info = Path(annotation=type_annotation, default=default_value) - else: - field_info = Query(annotation=type_annotation, default=default_value) + return field_info, type_annotation - # Now that we have the FieldInfo, we can determine the type annotation - field = None - if field_info is not None: - if is_path_param: - assert isinstance(field_info, Path), "Path parameters must be of type Path" - elif isinstance(field_info, Param) and getattr(field_info, "in_", None) is None: - field_info.in_ = ParamTypes.query - - # If the field_info is a Param, we use the `in_` attribute to determine the type annotation - use_annotation = get_annotation_from_field_info(type_annotation, field_info, param_name) - - # If the field doesn't have a defined alias, we use the param name - if not field_info.alias and getattr(field_info, "convert_underscores", None): - alias = param_name.replace("_", "-") - else: - alias = field_info.alias or param_name - field_info.alias = alias - - # Create the Pydantic field - field = ModelField( - name=param_name, - field_info=field_info, - type_=use_annotation, - class_validators={}, - default=field_info.default, - required=field_info.default in (Required, Undefined), - model_config=BaseConfig, - alias=alias, - ) - return type_annotation, field +def _create_model_field( + field_info: Optional[FieldInfo], + type_annotation: Any, + param_name: str, + is_path_param: bool, +) -> Optional[ModelField]: + if field_info is None: + return None + + if is_path_param: + assert isinstance(field_info, Path), "Path parameters must be of type Path" + elif isinstance(field_info, Param) and getattr(field_info, "in_", None) is None: + field_info.in_ = ParamTypes.query + + # If the field_info is a Param, we use the `in_` attribute to determine the type annotation + use_annotation = get_annotation_from_field_info(type_annotation, field_info, param_name) + + # If the field doesn't have a defined alias, we use the param name + if not field_info.alias and getattr(field_info, "convert_underscores", None): + alias = param_name.replace("_", "-") + else: + alias = field_info.alias or param_name + field_info.alias = alias + + # Create the Pydantic field + return ModelField( + name=param_name, + field_info=field_info, + type_=use_annotation, + class_validators={}, + default=field_info.default, + required=field_info.default in (Required, Undefined), + model_config=BaseConfig, + alias=alias, + ) From 44bc06752d1a382c54df3504e5f69704542992f1 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 26 Sep 2023 17:34:45 +0200 Subject: [PATCH 10/75] fix: typing extensions --- aws_lambda_powertools/event_handler/openapi/params.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 3c93505cba4..db03123bc20 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1,12 +1,13 @@ import inspect from copy import copy from enum import Enum -from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin from pydantic import BaseConfig from pydantic.fields import FieldInfo, ModelField, Required, Undefined from pydantic.schema import get_annotation_from_field_info from pydantic.version import VERSION as PYDANTIC_VERSION +from typing_extensions import Annotated from aws_lambda_powertools.event_handler.openapi.utils import CacheKey From c11dda4a0d1bbb751b247a7132c22f418567ef69 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 27 Sep 2023 10:34:06 +0200 Subject: [PATCH 11/75] fix: tests --- .../event_handler/api_gateway.py | 23 +++-- .../event_handler/openapi/dependant.py | 81 ++++++++++++++++- .../event_handler/openapi/params.py | 2 +- .../event_handler/openapi/types.py | 3 + .../event_handler/openapi/utils.py | 89 ------------------- 5 files changed, 95 insertions(+), 103 deletions(-) create mode 100644 aws_lambda_powertools/event_handler/openapi/types.py delete mode 100644 aws_lambda_powertools/event_handler/openapi/utils.py diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 5813b674822..ff31f014f0b 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -36,10 +36,9 @@ from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError -from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant +from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant, get_flat_params from aws_lambda_powertools.event_handler.openapi.models import Contact, License, OpenAPI, Server, Tag from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param -from aws_lambda_powertools.event_handler.openapi.utils import get_flat_params from aws_lambda_powertools.event_handler.response import Response from aws_lambda_powertools.shared.functions import powertools_dev_is_set from aws_lambda_powertools.shared.json_encoder import Encoder @@ -565,9 +564,9 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, - middlewares: Optional[List[Callable[..., Any]]] = None, description: Optional[str] = None, tags: Optional[List[Tag]] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, ): raise NotImplementedError() @@ -644,7 +643,7 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "GET", cors, compress, cache_control, middlewares, description, tags) + return self.route(rule, "GET", cors, compress, cache_control, description, tags, middlewares) def post( self, @@ -679,7 +678,7 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "POST", cors, compress, cache_control, middlewares, description, tags) + return self.route(rule, "POST", cors, compress, cache_control, description, tags, middlewares) def put( self, @@ -714,7 +713,7 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "PUT", cors, compress, cache_control, middlewares, description, tags) + return self.route(rule, "PUT", cors, compress, cache_control, description, tags, middlewares) def delete( self, @@ -748,7 +747,7 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "DELETE", cors, compress, cache_control, middlewares, description, tags) + return self.route(rule, "DELETE", cors, compress, cache_control, description, tags, middlewares) def patch( self, @@ -785,7 +784,7 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "PATCH", cors, compress, cache_control, middlewares, description, tags) + return self.route(rule, "PATCH", cors, compress, cache_control, description, tags, middlewares) def _push_processed_stack_frame(self, frame: str): """ @@ -1164,9 +1163,9 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, - middlewares: Optional[List[Callable[..., Any]]] = None, description: Optional[str] = None, tags: Optional[List[Tag]] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, ): """Route decorator includes parameter `method`""" @@ -1582,9 +1581,9 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, - middlewares: Optional[List[Callable[..., Any]]] = None, description: Optional[str] = None, tags: Optional[List[Tag]] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, ): def register_route(func: Callable): # Convert methods to tuple. It needs to be hashable as its part of the self._routes dict key @@ -1630,12 +1629,12 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, - middlewares: Optional[List[Callable[..., Any]]] = None, description: Optional[str] = None, tags: Optional[List[Tag]] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, ): # NOTE: see #1552 for more context. - return super().route(rule.rstrip("/"), method, cors, compress, cache_control, middlewares, description, tags) + return super().route(rule.rstrip("/"), method, cors, compress, cache_control, description, tags, middlewares) # Override _compile_regex to exclude trailing slashes for route resolution @staticmethod diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 74271413cd5..fa915c0a65f 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -1,11 +1,12 @@ import inspect import re -from typing import Any, Callable, Dict, ForwardRef, Optional, Set, cast +from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, cast from pydantic.fields import ModelField from pydantic.typing import evaluate_forwardref from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param, ParamTypes, analyze_param +from aws_lambda_powertools.event_handler.openapi.types import CacheKey """ This turns the opaque function signature into typed, validated models. @@ -184,3 +185,81 @@ def get_dependant( dependant.return_param = param_field return dependant + + +def get_flat_dependant( + dependant: Dependant, + *, + skip_repeats: bool = False, + visited: Optional[List[CacheKey]] = None, +) -> Dependant: + """ + Flatten a recursive Dependant model structure. + + This function recursively concatenates the parameter fields of a Dependant model and its dependencies into a flat + Dependant structure. This is useful for scenarios like parameter validation where the nested structure is not + relevant. + + Parameters + ---------- + dependant: Dependant + The dependant model to flatten + skip_repeats: bool + If True, child Dependents already visited will be skipped to avoid duplicates + visited: List[CacheKey], optional + Keeps track of visited Dependents to avoid infinite recursion. Defaults to empty list. + + Returns + ------- + Dependant + The flattened Dependant model + """ + if visited is None: + visited = [] + visited.append(dependant.cache_key) + + flat_dependant = Dependant( + path_params=dependant.path_params.copy(), + query_params=dependant.query_params.copy(), + header_params=dependant.header_params.copy(), + cookie_params=dependant.cookie_params.copy(), + body_params=dependant.body_params.copy(), + path=dependant.path, + ) + for sub_dependant in dependant.dependencies: + if skip_repeats and sub_dependant.cache_key in visited: + continue + + flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited) + + flat_dependant.path_params.extend(flat_sub.path_params) + flat_dependant.query_params.extend(flat_sub.query_params) + flat_dependant.header_params.extend(flat_sub.header_params) + flat_dependant.cookie_params.extend(flat_sub.cookie_params) + flat_dependant.body_params.extend(flat_sub.body_params) + + return flat_dependant + + +def get_flat_params(dependant: Dependant) -> List[ModelField]: + """ + Get a list of all the parameters from a Dependant object. + + Parameters + ---------- + dependant : Dependant + The Dependant object containing the parameters. + + Returns + ------- + List[ModelField] + A list of ModelField objects containing the flat parameters from the Dependant object. + + """ + flat_dependant = get_flat_dependant(dependant, skip_repeats=True) + return ( + flat_dependant.path_params + + flat_dependant.query_params + + flat_dependant.header_params + + flat_dependant.cookie_params + ) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index db03123bc20..8e60e4df436 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -9,7 +9,7 @@ from pydantic.version import VERSION as PYDANTIC_VERSION from typing_extensions import Annotated -from aws_lambda_powertools.event_handler.openapi.utils import CacheKey +from aws_lambda_powertools.event_handler.openapi.types import CacheKey """ This turns the low-level function signature into typed, validated Pydantic models for consumption. diff --git a/aws_lambda_powertools/event_handler/openapi/types.py b/aws_lambda_powertools/event_handler/openapi/types.py new file mode 100644 index 00000000000..f779caf4c98 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/types.py @@ -0,0 +1,3 @@ +from typing import Any, Callable, Optional + +CacheKey = Optional[Callable[..., Any]] diff --git a/aws_lambda_powertools/event_handler/openapi/utils.py b/aws_lambda_powertools/event_handler/openapi/utils.py deleted file mode 100644 index eeca1d0fd1e..00000000000 --- a/aws_lambda_powertools/event_handler/openapi/utils.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import Any, Callable, List, Optional - -from pydantic.fields import ModelField - -from aws_lambda_powertools.event_handler.openapi.params import Dependant - -CacheKey = Optional[Callable[..., Any]] - -""" -This file defines utility functions for working with OpenAPI/JSON Schema models. -""" - - -def get_flat_dependant( - dependant: Dependant, - *, - skip_repeats: bool = False, - visited: Optional[List[CacheKey]] = None, -) -> Dependant: - """ - Flatten a recursive Dependant model structure. - - This function recursively concatenates the parameter fields of a Dependant model and its dependencies into a flat - Dependant structure. This is useful for scenarios like parameter validation where the nested structure is not - relevant. - - Parameters - ---------- - dependant: Dependant - The dependant model to flatten - skip_repeats: bool - If True, child Dependents already visited will be skipped to avoid duplicates - visited: List[CacheKey], optional - Keeps track of visited Dependents to avoid infinite recursion. Defaults to empty list. - - Returns - ------- - Dependant - The flattened Dependant model - """ - if visited is None: - visited = [] - visited.append(dependant.cache_key) - - flat_dependant = Dependant( - path_params=dependant.path_params.copy(), - query_params=dependant.query_params.copy(), - header_params=dependant.header_params.copy(), - cookie_params=dependant.cookie_params.copy(), - body_params=dependant.body_params.copy(), - path=dependant.path, - ) - for sub_dependant in dependant.dependencies: - if skip_repeats and sub_dependant.cache_key in visited: - continue - - flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited) - - flat_dependant.path_params.extend(flat_sub.path_params) - flat_dependant.query_params.extend(flat_sub.query_params) - flat_dependant.header_params.extend(flat_sub.header_params) - flat_dependant.cookie_params.extend(flat_sub.cookie_params) - flat_dependant.body_params.extend(flat_sub.body_params) - - return flat_dependant - - -def get_flat_params(dependant: Dependant) -> List[ModelField]: - """ - Get a list of all the parameters from a Dependant object. - - Parameters - ---------- - dependant : Dependant - The Dependant object containing the parameters. - - Returns - ------- - List[ModelField] - A list of ModelField objects containing the flat parameters from the Dependant object. - - """ - flat_dependant = get_flat_dependant(dependant, skip_repeats=True) - return ( - flat_dependant.path_params - + flat_dependant.query_params - + flat_dependant.header_params - + flat_dependant.cookie_params - ) From 9c7c37f4c443065e5f54f7002d48193c16437d7a Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 27 Sep 2023 10:37:05 +0200 Subject: [PATCH 12/75] fix: mypy --- aws_lambda_powertools/event_handler/openapi/params.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 8e60e4df436..df22a803a40 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1,13 +1,13 @@ import inspect from copy import copy from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from pydantic import BaseConfig from pydantic.fields import FieldInfo, ModelField, Required, Undefined from pydantic.schema import get_annotation_from_field_info from pydantic.version import VERSION as PYDANTIC_VERSION -from typing_extensions import Annotated +from typing_extensions import Annotated, get_args, get_origin from aws_lambda_powertools.event_handler.openapi.types import CacheKey From f02f189bac59c01b858d9ad2bcdc6589e8f64f9b Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 27 Sep 2023 10:58:14 +0200 Subject: [PATCH 13/75] fix: security baseline --- .../event_handler/openapi/dependant.py | 9 ++- .../event_handler/openapi/params.py | 59 ++++++++++++------- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index fa915c0a65f..a97edd6c91b 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -48,7 +48,8 @@ def add_param_to_fields( elif field_info.in_ == ParamTypes.header: dependant.header_params.append(field) else: - assert field_info.in_ == ParamTypes.cookie + if field_info.in_ != ParamTypes.cookie: + raise AssertionError(f"Unsupported param type: {field_info.in_}") dependant.cookie_params.append(field) @@ -167,7 +168,8 @@ def get_dependant( value=param.default, is_path_param=is_path_param, ) - assert param_field is not None + if param_field is None: + raise AssertionError(f"Param field is None for param: {param_name}") add_param_to_fields(field=param_field, dependant=dependant) @@ -180,7 +182,8 @@ def get_dependant( value=None, is_path_param=False, ) - assert param_field is not None + if param_field is None: + raise AssertionError("Param field is None for return annotation") dependant.return_param = param_field diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index df22a803a40..fb55040f10a 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -195,7 +195,9 @@ def __init__( json_schema_extra: Union[Dict[str, Any], None] = None, **extra: Any, ): - assert default is ..., "Path parameters cannot have a default value" + if default is not ...: + raise AssertionError("Path parameters cannot have a default value") + super(Path, self).__init__( default=default, default_factory=default_factory, @@ -322,7 +324,8 @@ def analyze_param( # If the value is a FieldInfo, we use it as the FieldInfo for the parameter if isinstance(value, FieldInfo): - assert field_info is None + if field_info is not None: + raise AssertionError("Cannot use a FieldInfo as a parameter annotation and pass a FieldInfo as a value") field_info = value # If we didn't determine the FieldInfo yet, we create a default one @@ -343,29 +346,40 @@ def _get_field_info_and_type_annotation(annotation, value, is_path_param: bool) field_info: Optional[FieldInfo] = None type_annotation: Any = Any - # If the annotation is an Annotated type, we need to extract the type annotation and the FieldInfo - if annotation is not inspect.Signature.empty and get_origin(annotation) is Annotated: - annotated_args = get_args(annotation) - type_annotation = annotated_args[0] - powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)] - assert len(powertools_annotations) <= 1 + if annotation is not inspect.Signature.empty: + # If the annotation is an Annotated type, we need to extract the type annotation and the FieldInfo + if get_origin(annotation) is Annotated: + type_annotation = _get_field_info_annotated_type(annotation, value, is_path_param) + # If the annotation is not an Annotated type, we use it as the type annotation + else: + type_annotation = annotation + + return field_info, type_annotation + + +def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]: + field_info: Optional[FieldInfo] = None + annotated_args = get_args(annotation) + type_annotation = annotated_args[0] + powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)] - powertools_annotation = next(iter(powertools_annotations), None) + if len(powertools_annotations) > 1: + raise AssertionError("Only one FieldInfo can be used per parameter") - if isinstance(powertools_annotation, FieldInfo): - # Copy `field_info` because we mutate `field_info.default` later - field_info = copy(powertools_annotation) - assert field_info.default is Undefined or field_info.default is Required + powertools_annotation = next(iter(powertools_annotations), None) - if value is not inspect.Signature.empty: - assert not is_path_param - field_info.default = value - else: - field_info.default = Required + if isinstance(powertools_annotation, FieldInfo): + # Copy `field_info` because we mutate `field_info.default` later + field_info = copy(powertools_annotation) + if field_info.default not in [Undefined, Required]: + raise AssertionError("FieldInfo needs to have a default value of Undefined or Required") - # If the annotation is not an Annotated type, we use it as the type annotation - elif annotation is not inspect.Signature.empty: - type_annotation = annotation + if value is not inspect.Signature.empty: + if is_path_param: + raise AssertionError("Cannot use a FieldInfo as a path parameter and pass a value") + field_info.default = value + else: + field_info.default = Required return field_info, type_annotation @@ -380,7 +394,8 @@ def _create_model_field( return None if is_path_param: - assert isinstance(field_info, Path), "Path parameters must be of type Path" + if not isinstance(field_info, Path): + raise AssertionError("Path parameters must be of type Path") elif isinstance(field_info, Param) and getattr(field_info, "in_", None) is None: field_info.in_ = ParamTypes.query From 2d304431a976f2c93a65d4926026fe510b920ede Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 27 Sep 2023 14:14:32 +0200 Subject: [PATCH 14/75] feat: add simultaneous support for Pydantic v2 --- .flake8 | 1 + .../event_handler/api_gateway.py | 67 +++--- .../event_handler/openapi/compat.py | 201 ++++++++++++++++++ .../event_handler/openapi/dependant.py | 4 +- .../event_handler/openapi/models.py | 3 +- .../event_handler/openapi/params.py | 41 ++-- .../event_handler/openapi/types.py | 15 +- 7 files changed, 284 insertions(+), 48 deletions(-) create mode 100644 aws_lambda_powertools/event_handler/openapi/compat.py diff --git a/.flake8 b/.flake8 index 1db8406d9e4..0f309f6621a 100644 --- a/.flake8 +++ b/.flake8 @@ -8,6 +8,7 @@ per-file-ignores = tests/e2e/utils/data_builder/__init__.py:F401 tests/e2e/utils/data_fetcher/__init__.py:F401 aws_lambda_powertools/utilities/data_classes/s3_event.py:A003 + aws_lambda_powertools/event_handler/openapi/compat.py:F401 [isort] multi_line_output = 3 diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index ff31f014f0b..290c562e54f 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -25,20 +25,25 @@ cast, ) -from pydantic.fields import ModelField -from pydantic.schema import ( - TypeModelOrEnum, - field_schema, - get_flat_models_from_fields, - get_model_name_map, - model_process_schema, -) +from typing_extensions import Literal from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError +from aws_lambda_powertools.event_handler.openapi.compat import ( + GenerateJsonSchema, + JsonSchemaValue, + ModelField, + get_compat_model_name_map, + get_definitions, + get_schema_from_model_field, +) from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant, get_flat_params from aws_lambda_powertools.event_handler.openapi.models import Contact, License, OpenAPI, Server, Tag from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param +from aws_lambda_powertools.event_handler.openapi.types import ( + COMPONENT_REF_TEMPLATE, + TypeModelOrEnum, +) from aws_lambda_powertools.event_handler.response import Response from aws_lambda_powertools.shared.functions import powertools_dev_is_set from aws_lambda_powertools.shared.json_encoder import Encoder @@ -61,7 +66,6 @@ _UNSAFE_URI = r"%<> \[\]{}|^" _NAMED_GROUP_BOUNDARY_PATTERN = rf"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)" _ROUTE_REGEX = "^{}$" -_COMPONENT_REF_PREFIX = "#/components/schemas/" class ProxyEventType(Enum): @@ -340,6 +344,7 @@ def _get_openapi_path( dependant: Dependant, operation_ids: Set[str], model_name_map: Dict[TypeModelOrEnum, str], + field_mapping: Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: path = {} definitions: Dict[str, Any] = {} @@ -350,6 +355,7 @@ def _get_openapi_path( operation_params = self._openapi_operation_parameters( all_route_params=all_route_params, model_name_map=model_name_map, + field_mapping=field_mapping, ) parameters.extend(operation_params) @@ -369,6 +375,7 @@ def _get_openapi_path( operation_id=self.operation_id, param=dependant.return_param, model_name_map=model_name_map, + field_mapping=field_mapping, ) path[self.method.lower()] = operation @@ -408,6 +415,10 @@ def _openapi_operation_parameters( *, all_route_params: Sequence[ModelField], model_name_map: Dict[TypeModelOrEnum, str], + field_mapping: Dict[ + Tuple[ModelField, Literal["validation", "serialization"]], + JsonSchemaValue, + ], ) -> List[Dict[str, Any]]: parameters = [] for param in all_route_params: @@ -416,7 +427,11 @@ def _openapi_operation_parameters( if not field_info.include_in_schema: continue - param_schema = field_schema(param, model_name_map=model_name_map, ref_prefix=_COMPONENT_REF_PREFIX)[0] + param_schema = get_schema_from_model_field( + field=param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) parameter = { "name": param.alias, @@ -441,15 +456,19 @@ def _openapi_operation_return( operation_id: str, param: Optional[ModelField], model_name_map: Dict[TypeModelOrEnum, str], + field_mapping: Dict[ + Tuple[ModelField, Literal["validation", "serialization"]], + JsonSchemaValue, + ], ) -> Dict[str, Any]: if param is None: return {} - return_schema = field_schema( - param, + return_schema = get_schema_from_model_field( + field=param, model_name_map=model_name_map, - ref_prefix=_COMPONENT_REF_PREFIX, - )[0] + field_mapping=field_mapping, + ) return {"name": f"Return {operation_id}", "schema": return_schema} @@ -1053,20 +1072,15 @@ def get_openapi_schema( all_routes = self._dynamic_routes + self._static_routes all_fields = self._get_fields_from_routes(all_routes) - models = get_flat_models_from_fields(all_fields, known_models=set()) - model_name_map = get_model_name_map(models) + model_name_map = get_compat_model_name_map(all_fields) # Collect all models and definitions - definitions: Dict[str, Dict[str, Any]] = {} - for model in models: - m_schema, m_definitions, _ = model_process_schema( - model, - model_name_map=model_name_map, - ref_prefix=_COMPONENT_REF_PREFIX, - ) - definitions.update(m_definitions) - model_name = model_name_map[model] - definitions[model_name] = m_schema + schema_generator = GenerateJsonSchema(ref_template=COMPONENT_REF_TEMPLATE) + field_mapping, definitions = get_definitions( + fields=all_fields, + schema_generator=schema_generator, + model_name_map=model_name_map, + ) # Add routes to the OpenAPI schema for route in all_routes: @@ -1079,6 +1093,7 @@ def get_openapi_schema( dependant=dependant, operation_ids=operation_ids, model_name_map=model_name_map, + field_mapping=field_mapping, ) if result: path, path_definitions = result diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py new file mode 100644 index 00000000000..cfb8f777436 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -0,0 +1,201 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Set, Tuple, Type, Union + +from typing_extensions import Annotated, Literal + +from aws_lambda_powertools.event_handler.openapi.types import COMPONENT_REF_PREFIX, PYDANTIC_V2, ModelNameMap + +if PYDANTIC_V2: + from pydantic import TypeAdapter + from pydantic._internal._typing_extra import eval_type_lenient + from pydantic.fields import FieldInfo + from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue + from pydantic_core import PydanticUndefined + + from aws_lambda_powertools.event_handler.openapi.types import IncEx + + Undefined = PydanticUndefined + Required = PydanticUndefined + + evaluate_forwardref = eval_type_lenient + + @dataclass + class ModelField: + field_info: FieldInfo + name: str + mode: Literal["validation", "serialization"] = "validation" + + @property + def alias(self) -> str: + value = self.field_info.alias + return value if value is not None else self.name + + @property + def required(self) -> bool: + return self.field_info.is_required() + + @property + def default(self) -> Any: + return self.get_default() + + @property + def type_(self) -> Any: + return self.field_info.annotation + + def __post__init__(self) -> None: + self._type_adapter: TypeAdapter[Any] = TypeAdapter( + Annotated[self.field_info.annotation, self.field_info], + ) + + def get_default(self) -> Any: + if self.field_info.is_required(): + return Undefined + return self.field_info.get_default(call_default_factory=True) + + def serialize( + self, + value: Any, + *, + mode: Literal["json", "python"] = "json", + include: Union[IncEx, None] = None, + exclude: Union[IncEx, None] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> Any: + return self._type_adapter.dump_python( + value, + mode=mode, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + def __hash__(self) -> int: + # Each ModelField is unique for our purposes + return id(self) + + def get_schema_from_model_field( + *, + field: ModelField, + model_name_map: ModelNameMap, + field_mapping: Dict[ + Tuple[ModelField, Literal["validation", "serialization"]], + JsonSchemaValue, + ], + ) -> Dict[str, Any]: + json_schema = field_mapping[(field, field.mode)] + if "$ref" not in json_schema: + # MAINTENANCE: remove when deprecating Pydantic v1 + # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207 + json_schema["title"] = field.field_info.title or field.alias.title().replace("_", " ") + return json_schema + + def get_definitions( + *, + fields: List[ModelField], + schema_generator: GenerateJsonSchema, + model_name_map: ModelNameMap, + ) -> Tuple[ + Dict[ + Tuple[ModelField, Literal["validation", "serialization"]], + Dict[str, Any], + ], + Dict[str, Dict[str, Any]], + ]: + inputs = [(field, field.mode, field._type_adapter.core_schema) for field in fields] + field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs) + + return field_mapping, definitions # type: ignore[return-value] + + def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap: + return {} + + def get_annotation_from_field_info(annotation: Any, field_info: FieldInfo, field_name: str) -> Any: + return annotation + +else: + from pydantic import BaseModel + from pydantic.fields import ( # type: ignore + ModelField as ModelField, # noqa: F401, PLC0414 + ) + from pydantic.fields import ( # type: ignore + Required as Required, # noqa: F401, PLC0414 + ) + from pydantic.fields import ( # type: ignore + Undefined as Undefined, # noqa: PLC0414 + ) + from pydantic.schema import ( # type: ignore[no-redef] + field_schema, + get_annotation_from_field_info, + get_flat_models_from_fields, + get_model_name_map, + model_process_schema, + ) + + # re-export for compatibility + vars()["get_annotation_from_field_info"] = get_annotation_from_field_info + + # type ignore[no-redef] + from pydantic.typing import evaluate_forwardref # type: ignore[no-redef] # noqa: F401 + + JsonSchemaValue = Dict[str, Any] # type: ignore[misc] + + @dataclass + class GenerateJsonSchema: # type: ignore[no-redef] + ref_template: str + + def get_schema_from_model_field( + *, + field: ModelField, + model_name_map: ModelNameMap, + field_mapping: Dict[ + Tuple[ModelField, Literal["validation", "serialization"]], + JsonSchemaValue, + ], + ) -> Dict[str, Any]: + return field_schema( + field, + model_name_map=model_name_map, + ref_prefix=COMPONENT_REF_PREFIX, + )[0] + + def get_definitions( + *, + fields: List[ModelField], + schema_generator: GenerateJsonSchema, + model_name_map: ModelNameMap, + ) -> Tuple[ + Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], + Dict[str, Dict[str, Any]], + ]: + models = get_flat_models_from_fields(fields, known_models=set()) + return {}, get_model_definitions(flat_models=models, model_name_map=model_name_map) + + def get_model_definitions( + *, + flat_models: Set[Union[Type[BaseModel], Type[Enum]]], + model_name_map: ModelNameMap, + ) -> Dict[str, Any]: + definitions: Dict[str, Dict[str, Any]] = {} + for model in flat_models: + m_schema, m_definitions, _ = model_process_schema( + model, + model_name_map=model_name_map, + ref_prefix=COMPONENT_REF_PREFIX, + ) + definitions.update(m_definitions) + model_name = model_name_map[model] + if "description" in m_schema: + m_schema["description"] = m_schema["description"].split("\f")[0] + definitions[model_name] = m_schema + return definitions + + def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap: + models = get_flat_models_from_fields(fields, known_models=set()) + return get_model_name_map(models) diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index a97edd6c91b..8cc57944f80 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -2,9 +2,7 @@ import re from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, cast -from pydantic.fields import ModelField -from pydantic.typing import evaluate_forwardref - +from aws_lambda_powertools.event_handler.openapi.compat import ModelField, evaluate_forwardref from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param, ParamTypes, analyze_param from aws_lambda_powertools.event_handler.openapi.types import CacheKey diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 2cf06155de5..f9749895d3f 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -2,10 +2,9 @@ from typing import Any, Dict, List, Optional, Set, Union from pydantic import AnyUrl, BaseModel, Field -from pydantic.version import VERSION as PYDANTIC_VERSION from typing_extensions import Annotated, Literal -PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") +from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2 """ The code defines Pydantic models for the various OpenAPI objects like OpenAPI, PathItem, Operation, Parameter etc. diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index fb55040f10a..18909d56b44 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -4,19 +4,21 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from pydantic import BaseConfig -from pydantic.fields import FieldInfo, ModelField, Required, Undefined -from pydantic.schema import get_annotation_from_field_info -from pydantic.version import VERSION as PYDANTIC_VERSION +from pydantic.fields import FieldInfo from typing_extensions import Annotated, get_args, get_origin -from aws_lambda_powertools.event_handler.openapi.types import CacheKey +from aws_lambda_powertools.event_handler.openapi.compat import ( + ModelField, + Required, + Undefined, + get_annotation_from_field_info, +) +from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2, CacheKey """ This turns the low-level function signature into typed, validated Pydantic models for consumption. """ -PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") - class Dependant: """ @@ -410,13 +412,20 @@ def _create_model_field( field_info.alias = alias # Create the Pydantic field - return ModelField( - name=param_name, - field_info=field_info, - type_=use_annotation, - class_validators={}, - default=field_info.default, - required=field_info.default in (Required, Undefined), - model_config=BaseConfig, - alias=alias, - ) + kwargs = {"name": param_name, "field_info": field_info} + + if PYDANTIC_V2: + kwargs.update({"mode": "validation"}) + else: + kwargs.update( + { + "type_": use_annotation, + "class_validators": {}, + "default": field_info.default, + "required": field_info.default in (Required, Undefined), + "model_config": BaseConfig, + "alias": alias, + }, + ) + + return ModelField(**kwargs) # type: ignore[arg-type] diff --git a/aws_lambda_powertools/event_handler/openapi/types.py b/aws_lambda_powertools/event_handler/openapi/types.py index f779caf4c98..bc994e7cfc9 100644 --- a/aws_lambda_powertools/event_handler/openapi/types.py +++ b/aws_lambda_powertools/event_handler/openapi/types.py @@ -1,3 +1,16 @@ -from typing import Any, Callable, Optional +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type, Union + +from pydantic.version import VERSION as PYDANTIC_VERSION + +if TYPE_CHECKING: + from pydantic import BaseModel # noqa: F401 CacheKey = Optional[Callable[..., Any]] +IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]] +ModelNameMap = Dict[Union[Type["BaseModel"], Type[Enum]], str] +TypeModelOrEnum = Union[Type["BaseModel"], Type[Enum]] + +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") +COMPONENT_REF_PREFIX = "#/components/schemas/" +COMPONENT_REF_TEMPLATE = "#/components/schemas/{model}" From 3b0037fe0e9608beec4623bf0935c84f86bed054 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 27 Sep 2023 15:10:59 +0200 Subject: [PATCH 15/75] fix: disable mypy and ruff on openapi compat --- .../event_handler/openapi/compat.py | 28 ++++++------------- pyproject.toml | 1 + 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index cfb8f777436..a10d20746f9 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -1,3 +1,6 @@ +# mypy: ignore-errors +# flake8: noqa + from dataclasses import dataclass from enum import Enum from typing import Any, Dict, List, Set, Tuple, Type, Union @@ -111,7 +114,7 @@ def get_definitions( inputs = [(field, field.mode, field._type_adapter.core_schema) for field in fields] field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs) - return field_mapping, definitions # type: ignore[return-value] + return field_mapping, definitions def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap: return {} @@ -121,33 +124,20 @@ def get_annotation_from_field_info(annotation: Any, field_info: FieldInfo, field else: from pydantic import BaseModel - from pydantic.fields import ( # type: ignore - ModelField as ModelField, # noqa: F401, PLC0414 - ) - from pydantic.fields import ( # type: ignore - Required as Required, # noqa: F401, PLC0414 - ) - from pydantic.fields import ( # type: ignore - Undefined as Undefined, # noqa: PLC0414 - ) - from pydantic.schema import ( # type: ignore[no-redef] + from pydantic.fields import ModelField, Required, Undefined + from pydantic.schema import ( field_schema, get_annotation_from_field_info, get_flat_models_from_fields, get_model_name_map, model_process_schema, ) + from pydantic.typing import evaluate_forwardref - # re-export for compatibility - vars()["get_annotation_from_field_info"] = get_annotation_from_field_info - - # type ignore[no-redef] - from pydantic.typing import evaluate_forwardref # type: ignore[no-redef] # noqa: F401 - - JsonSchemaValue = Dict[str, Any] # type: ignore[misc] + JsonSchemaValue = Dict[str, Any] @dataclass - class GenerateJsonSchema: # type: ignore[no-redef] + class GenerateJsonSchema: ref_template: str def get_schema_from_model_field( diff --git a/pyproject.toml b/pyproject.toml index d2e785b4761..2d6aa14a42b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -169,6 +169,7 @@ exclude = ''' | buck-out | build | dist + | aws_lambda_powertools/event_handler/openapi/compat.py )/ | example ) From 633ceb40281fd84135af1c032352cf873a6a55ce Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 27 Sep 2023 15:17:07 +0200 Subject: [PATCH 16/75] chore: add explanation to imports --- aws_lambda_powertools/event_handler/openapi/compat.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index a10d20746f9..bd5d8c4c445 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -1,6 +1,9 @@ # mypy: ignore-errors # flake8: noqa +# MAINTENANCE: remove when deprecating Pydantic v1. Mypy doesn't handle two different code paths that import different +# versions of a module, so we need to ignore errors here. + from dataclasses import dataclass from enum import Enum from typing import Any, Dict, List, Set, Tuple, Type, Union From b4fcde6ce495a195baeb1e1f7f4a2af1007b53c4 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 28 Sep 2023 11:30:46 +0200 Subject: [PATCH 17/75] chore: add first test --- .../event_handler/test_openapi_servers.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/functional/event_handler/test_openapi_servers.py diff --git a/tests/functional/event_handler/test_openapi_servers.py b/tests/functional/event_handler/test_openapi_servers.py new file mode 100644 index 00000000000..20d0dd5550d --- /dev/null +++ b/tests/functional/event_handler/test_openapi_servers.py @@ -0,0 +1,26 @@ +from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver +from aws_lambda_powertools.event_handler.openapi.models import Server + + +def test_openapi_schema_default_server(): + app = ApiGatewayResolver() + + schema = app.get_openapi_schema(title="Hello API", version="1.0.0") + assert schema.servers + assert len(schema.servers) == 1 + assert schema.servers[0].url == "/" + + +def test_openapi_schema_custom_server(): + app = ApiGatewayResolver() + + schema = app.get_openapi_schema( + title="Hello API", + version="1.0.0", + servers=[Server(url="https://example.org", description="Example website")], + ) + + assert schema.servers + assert len(schema.servers) == 1 + assert schema.servers[0].url == "https://example.org" + assert schema.servers[0].description == "Example website" From bca3c715dc1382942a940eae394926f6a71a347a Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 28 Sep 2023 11:47:43 +0200 Subject: [PATCH 18/75] fix: test --- tests/functional/event_handler/test_openapi_servers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/event_handler/test_openapi_servers.py b/tests/functional/event_handler/test_openapi_servers.py index 20d0dd5550d..ac96bf59597 100644 --- a/tests/functional/event_handler/test_openapi_servers.py +++ b/tests/functional/event_handler/test_openapi_servers.py @@ -22,5 +22,5 @@ def test_openapi_schema_custom_server(): assert schema.servers assert len(schema.servers) == 1 - assert schema.servers[0].url == "https://example.org" + assert str(schema.servers[0].url) == "https://example.org" assert schema.servers[0].description == "Example website" From 88ec111ae76ce5470ff443fb7dac48976877dba8 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 28 Sep 2023 11:54:24 +0200 Subject: [PATCH 19/75] fix: test --- tests/functional/event_handler/test_openapi_servers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/functional/event_handler/test_openapi_servers.py b/tests/functional/event_handler/test_openapi_servers.py index ac96bf59597..e348afbd08c 100644 --- a/tests/functional/event_handler/test_openapi_servers.py +++ b/tests/functional/event_handler/test_openapi_servers.py @@ -17,10 +17,10 @@ def test_openapi_schema_custom_server(): schema = app.get_openapi_schema( title="Hello API", version="1.0.0", - servers=[Server(url="https://example.org", description="Example website")], + servers=[Server(url="https://example.org/", description="Example website")], ) assert schema.servers assert len(schema.servers) == 1 - assert str(schema.servers[0].url) == "https://example.org" + assert str(schema.servers[0].url) == "https://example.org/" assert schema.servers[0].description == "Example website" From ba2e8f0f8dc64b49f145bb4ed3141b76cf697914 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 28 Sep 2023 14:06:11 +0200 Subject: [PATCH 20/75] fix: don't require pydantic to run normal things --- .../event_handler/api_gateway.py | 127 +++++++++++------- .../event_handler/openapi/__init__.py | 11 -- 2 files changed, 79 insertions(+), 59 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 290c562e54f..ce0e6dd23d8 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -10,6 +10,7 @@ from functools import partial from http import HTTPStatus from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -29,21 +30,6 @@ from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError -from aws_lambda_powertools.event_handler.openapi.compat import ( - GenerateJsonSchema, - JsonSchemaValue, - ModelField, - get_compat_model_name_map, - get_definitions, - get_schema_from_model_field, -) -from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant, get_flat_params -from aws_lambda_powertools.event_handler.openapi.models import Contact, License, OpenAPI, Server, Tag -from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param -from aws_lambda_powertools.event_handler.openapi.types import ( - COMPONENT_REF_TEMPLATE, - TypeModelOrEnum, -) from aws_lambda_powertools.event_handler.response import Response from aws_lambda_powertools.shared.functions import powertools_dev_is_set from aws_lambda_powertools.shared.json_encoder import Encoder @@ -67,6 +53,21 @@ _NAMED_GROUP_BOUNDARY_PATTERN = rf"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)" _ROUTE_REGEX = "^{}$" +if TYPE_CHECKING: + from aws_lambda_powertools.event_handler.openapi.compat import ( + JsonSchemaValue, + ModelField, + ) + from aws_lambda_powertools.event_handler.openapi.models import ( + Contact, + License, + OpenAPI, + Server, + Tag, + ) + from aws_lambda_powertools.event_handler.openapi.params import Dependant + from aws_lambda_powertools.event_handler.openapi.types import TypeModelOrEnum + class ProxyEventType(Enum): """An enumerations of the supported proxy event types.""" @@ -204,7 +205,7 @@ def __init__( cache_control: Optional[str], middlewares: Optional[List[Callable[..., Response]]], description: Optional[str], - tags: Optional[List[Tag]], + tags: Optional[List["Tag"]], ): """ @@ -341,11 +342,15 @@ def _build_middleware_stack(self, router_middlewares: List[Callable[..., Any]]) def _get_openapi_path( self, *, - dependant: Dependant, + dependant: "Dependant", operation_ids: Set[str], - model_name_map: Dict[TypeModelOrEnum, str], - field_mapping: Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], + model_name_map: Dict["TypeModelOrEnum", str], + field_mapping: Dict[Tuple["ModelField", Literal["validation", "serialization"]], "JsonSchemaValue"], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + from aws_lambda_powertools.event_handler.openapi.dependant import ( + get_flat_params, + ) + path = {} definitions: Dict[str, Any] = {} @@ -413,13 +418,18 @@ def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any] @staticmethod def _openapi_operation_parameters( *, - all_route_params: Sequence[ModelField], - model_name_map: Dict[TypeModelOrEnum, str], + all_route_params: Sequence["ModelField"], + model_name_map: Dict["TypeModelOrEnum", str], field_mapping: Dict[ - Tuple[ModelField, Literal["validation", "serialization"]], - JsonSchemaValue, + Tuple["ModelField", Literal["validation", "serialization"]], + "JsonSchemaValue", ], ) -> List[Dict[str, Any]]: + from aws_lambda_powertools.event_handler.openapi.compat import ( + get_schema_from_model_field, + ) + from aws_lambda_powertools.event_handler.openapi.params import Param + parameters = [] for param in all_route_params: field_info = param.field_info @@ -454,16 +464,20 @@ def _openapi_operation_parameters( def _openapi_operation_return( *, operation_id: str, - param: Optional[ModelField], - model_name_map: Dict[TypeModelOrEnum, str], + param: Optional["ModelField"], + model_name_map: Dict["TypeModelOrEnum", str], field_mapping: Dict[ - Tuple[ModelField, Literal["validation", "serialization"]], - JsonSchemaValue, + Tuple["ModelField", Literal["validation", "serialization"]], + "JsonSchemaValue", ], ) -> Dict[str, Any]: if param is None: return {} + from aws_lambda_powertools.event_handler.openapi.compat import ( + get_schema_from_model_field, + ) + return_schema = get_schema_from_model_field( field=param, model_name_map=model_name_map, @@ -584,7 +598,7 @@ def route( compress: bool = False, cache_control: Optional[str] = None, description: Optional[str] = None, - tags: Optional[List[Tag]] = None, + tags: Optional[List["Tag"]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): raise NotImplementedError() @@ -638,7 +652,7 @@ def get( cache_control: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, description: Optional[str] = None, - tags: Optional[List[Tag]] = None, + tags: Optional[List["Tag"]] = None, ): """Get route decorator with GET `method` @@ -672,7 +686,7 @@ def post( cache_control: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, description: Optional[str] = None, - tags: Optional[List[Tag]] = None, + tags: Optional[List["Tag"]] = None, ): """Post route decorator with POST `method` @@ -707,7 +721,7 @@ def put( cache_control: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, description: Optional[str] = None, - tags: Optional[List[Tag]] = None, + tags: Optional[List["Tag"]] = None, ): """Put route decorator with PUT `method` @@ -742,7 +756,7 @@ def delete( cache_control: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, description: Optional[str] = None, - tags: Optional[List[Tag]] = None, + tags: Optional[List["Tag"]] = None, ): """Delete route decorator with DELETE `method` @@ -776,7 +790,7 @@ def patch( cache_control: Optional[str] = None, middlewares: Optional[List[Callable]] = None, description: Optional[str] = None, - tags: Optional[List[Tag]] = None, + tags: Optional[List["Tag"]] = None, ): """Patch route decorator with PATCH `method` @@ -1007,12 +1021,12 @@ def get_openapi_schema( openapi_version: str = "3.1.0", summary: Optional[str] = None, description: Optional[str] = None, - tags: Optional[List[Tag]] = None, - servers: Optional[List[Server]] = None, + tags: Optional[List["Tag"]] = None, + servers: Optional[List["Server"]] = None, terms_of_service: Optional[str] = None, - contact: Optional[Contact] = None, - license_info: Optional[License] = None, - ) -> OpenAPI: + contact: Optional["Contact"] = None, + license_info: Optional["License"] = None, + ) -> "OpenAPI": """ Returns the OpenAPI schema as a pydantic model. @@ -1045,6 +1059,17 @@ def get_openapi_schema( The OpenAPI schema as a pydantic model. """ + from aws_lambda_powertools.event_handler.openapi.compat import ( + GenerateJsonSchema, + get_compat_model_name_map, + get_definitions, + ) + from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant + from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, Server + from aws_lambda_powertools.event_handler.openapi.types import ( + COMPONENT_REF_TEMPLATE, + ) + # Start with the bare minimum required for a valid OpenAPI schema info: Dict[str, Any] = {"title": title, "version": version} @@ -1121,11 +1146,11 @@ def get_openapi_json_schema( openapi_version: str = "3.1.0", summary: Optional[str] = None, description: Optional[str] = None, - tags: Optional[List[Tag]] = None, - servers: Optional[List[Server]] = None, + tags: Optional[List["Tag"]] = None, + servers: Optional[List["Server"]] = None, terms_of_service: Optional[str] = None, - contact: Optional[Contact] = None, - license_info: Optional[License] = None, + contact: Optional["Contact"] = None, + license_info: Optional["License"] = None, ) -> str: """ Returns the OpenAPI schema as a JSON serializable dict @@ -1179,7 +1204,7 @@ def route( compress: bool = False, cache_control: Optional[str] = None, description: Optional[str] = None, - tags: Optional[List[Tag]] = None, + tags: Optional[List["Tag"]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Route decorator includes parameter `method`""" @@ -1561,12 +1586,18 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None self.route(*new_route, middlewares=middlewares)(func) # type: ignore @staticmethod - def _get_fields_from_routes(routes: Sequence[Route]) -> List[ModelField]: + def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]: """ Returns a list of fields from the routes """ - responses_from_routes: List[ModelField] = [] - request_fields_from_routes: List[ModelField] = [] + + from aws_lambda_powertools.event_handler.openapi.dependant import ( + get_dependant, + get_flat_params, + ) + + responses_from_routes: List["ModelField"] = [] + request_fields_from_routes: List["ModelField"] = [] for route in routes: dependant = get_dependant(path=route.path, call=route.func) @@ -1597,7 +1628,7 @@ def route( compress: bool = False, cache_control: Optional[str] = None, description: Optional[str] = None, - tags: Optional[List[Tag]] = None, + tags: Optional[List["Tag"]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): def register_route(func: Callable): @@ -1645,7 +1676,7 @@ def route( compress: bool = False, cache_control: Optional[str] = None, description: Optional[str] = None, - tags: Optional[List[Tag]] = None, + tags: Optional[List["Tag"]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): # NOTE: see #1552 for more context. diff --git a/aws_lambda_powertools/event_handler/openapi/__init__.py b/aws_lambda_powertools/event_handler/openapi/__init__.py index 91c5b0259f2..e69de29bb2d 100644 --- a/aws_lambda_powertools/event_handler/openapi/__init__.py +++ b/aws_lambda_powertools/event_handler/openapi/__init__.py @@ -1,11 +0,0 @@ -from aws_lambda_powertools.event_handler.openapi.models import ( - Example, - Info, - MediaType, - Operation, - Reference, - Response, - Schema, -) - -__all__ = ["Info", "Operation", "Response", "MediaType", "Reference", "Schema", "Example"] From c97d016696047ece887a4c366f2ee891967c7fea Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 28 Sep 2023 17:16:04 +0200 Subject: [PATCH 21/75] chore: added first tests --- .../event_handler/api_gateway.py | 15 +- .../event_handler/openapi/models.py | 2 - .../event_handler/test_openapi_params.py | 152 ++++++++++++++++++ 3 files changed, 160 insertions(+), 9 deletions(-) create mode 100644 tests/functional/event_handler/test_openapi_params.py diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index ce0e6dd23d8..5c2fec2cabc 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -376,11 +376,13 @@ def _get_openapi_path( success_response["content"] = {"application/json": {"schema": {}}} json_response = success_response["content"].setdefault("application/json", {}) - json_response["schema"] = self._openapi_operation_return( - operation_id=self.operation_id, - param=dependant.return_param, - model_name_map=model_name_map, - field_mapping=field_mapping, + json_response.update( + self._openapi_operation_return( + operation_id=self.operation_id, + param=dependant.return_param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ), ) path[self.method.lower()] = operation @@ -389,8 +391,7 @@ def _get_openapi_path( return path, definitions def _openapi_operation_summary(self) -> str: - # Generate a summary from the pattern - return self.rule.__str__().replace("_", " ").title() + return f"{self.method.upper()} {self.path}" def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any]: operation: Dict[str, Any] = {} diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index f9749895d3f..40d80003f1c 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -269,7 +269,6 @@ class Config: # https://swagger.io/specification/#media-type-object class MediaType(BaseModel): schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema") - example: Optional[Any] = None examples: Optional[Dict[str, Union[Example, Reference]]] = None encoding: Optional[Dict[str, Encoding]] = None @@ -292,7 +291,6 @@ class ParameterBase(BaseModel): explode: Optional[bool] = None allowReserved: Optional[bool] = None schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema") - example: Optional[Any] = None examples: Optional[Dict[str, Union[Example, Reference]]] = None # Serialization rules for more complex scenarios content: Optional[Dict[str, MediaType]] = None diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py new file mode 100644 index 00000000000..245dd7e138d --- /dev/null +++ b/tests/functional/event_handler/test_openapi_params.py @@ -0,0 +1,152 @@ +from dataclasses import dataclass + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver +from aws_lambda_powertools.event_handler.openapi.models import ( + Parameter, + ParameterInType, + Schema, +) + + +def test_openapi_no_params(): + app = ApiGatewayResolver() + + @app.get("/") + def handler(): + pass + + schema = app.get_openapi_schema(title="Hello API", version="1.0.0") + + assert len(schema.paths.keys()) == 1 + assert "/" in schema.paths + + path = schema.paths["/"] + assert path.get + + get = path.get + assert get.summary == "GET /" + assert get.operationId == "GetHandler" + + assert "200" in get.responses + response = get.responses["200"] + assert response.description == "Success" + + assert "application/json" in response.content + json_response = response.content["application/json"] + assert json_response.schema_ == Schema() + assert not json_response.examples + assert not json_response.encoding + + +def test_openapi_with_scalar_params(): + app = ApiGatewayResolver() + + @app.get("/users/") + def handler(user_id: str, include_extra: bool = False): + pass + + schema = app.get_openapi_schema(title="Hello API", version="1.0.0") + + assert len(schema.paths.keys()) == 1 + assert "/users/" in schema.paths + + path = schema.paths["/users/"] + assert path.get + + get = path.get + assert get.summary == "GET /users/" + assert get.operationId == "GetHandler" + assert len(get.parameters) == 2 + + parameter = get.parameters[0] + assert isinstance(parameter, Parameter) + assert parameter.in_ == ParameterInType.path + assert parameter.name == "user_id" + assert parameter.required is True + assert parameter.schema_.default is None + assert parameter.schema_.type == "string" + assert parameter.schema_.title == "User Id" + + parameter = get.parameters[1] + assert isinstance(parameter, Parameter) + assert parameter.in_ == ParameterInType.query + assert parameter.name == "include_extra" + assert parameter.required is False + assert parameter.schema_.default is False + assert parameter.schema_.type == "boolean" + assert parameter.schema_.title == "Include Extra" + + +def test_openapi_with_scalar_returns(): + app = ApiGatewayResolver() + + @app.get("/") + def handler() -> str: + return "Hello, world" + + schema = app.get_openapi_schema(title="Hello API", version="1.0.0") + assert len(schema.paths.keys()) == 1 + + get = schema.paths["/"].get + assert get.parameters is None + + response = get.responses["200"].content["application/json"] + assert response.schema_.title == "Return" + assert response.schema_.type == "string" + + +def test_openapi_with_pydantic_returns(): + app = ApiGatewayResolver() + + class User(BaseModel): + name: str + + @app.get("/") + def handler() -> User: + return User(name="Ruben Fonseca") + + schema = app.get_openapi_schema(title="Hello API", version="1.0.0") + assert len(schema.paths.keys()) == 1 + + get = schema.paths["/"].get + assert get.parameters is None + + response = get.responses["200"].content["application/json"] + reference = response.schema_ + assert reference.ref == "#/components/schemas/User" + + assert "User" in schema.components.schemas + user_schema = schema.components.schemas["User"] + assert isinstance(user_schema, Schema) + assert user_schema.title == "User" + assert "name" in user_schema.properties + + +def test_openapi_with_dataclasse_return(): + app = ApiGatewayResolver() + + @dataclass + class User: + surname: str + + @app.get("/") + def handler() -> User: + return User(name="Ruben Fonseca") + + schema = app.get_openapi_schema(title="Hello API", version="1.0.0") + assert len(schema.paths.keys()) == 1 + + get = schema.paths["/"].get + assert get.parameters is None + + response = get.responses["200"].content["application/json"] + reference = response.schema_ + assert reference.ref == "#/components/schemas/User" + + assert "User" in schema.components.schemas + user_schema = schema.components.schemas["User"] + assert isinstance(user_schema, Schema) + assert user_schema.title == "User" + assert "surname" in user_schema.properties From e4de16c14cf1103e751693f8819072a2d351e0a9 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 28 Sep 2023 17:26:17 +0200 Subject: [PATCH 22/75] fix: refactored tests to remove code smell --- .../event_handler/api_gateway.py | 8 ++--- .../event_handler/test_openapi_params.py | 30 +++++++++++-------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 5c2fec2cabc..ad05573a167 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1017,8 +1017,8 @@ def __init__( def get_openapi_schema( self, *, - title: str, - version: str, + title: str = "Powertools API", + version: str = "1.0.0", openapi_version: str = "3.1.0", summary: Optional[str] = None, description: Optional[str] = None, @@ -1142,8 +1142,8 @@ def get_openapi_schema( def get_openapi_json_schema( self, *, - title: str, - version: str, + title: str = "Powertools API", + version: str = "1.0.0", openapi_version: str = "3.1.0", summary: Optional[str] = None, description: Optional[str] = None, diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index 245dd7e138d..7c170cfe98f 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -9,15 +9,19 @@ Schema, ) +JSON_CONTENT_TYPE = "application/json" + def test_openapi_no_params(): app = ApiGatewayResolver() @app.get("/") def handler(): - pass + raise NotImplementedError() - schema = app.get_openapi_schema(title="Hello API", version="1.0.0") + schema = app.get_openapi_schema() + assert schema.info.title == "Powertools API" + assert schema.info.version == "1.0.0" assert len(schema.paths.keys()) == 1 assert "/" in schema.paths @@ -33,8 +37,8 @@ def handler(): response = get.responses["200"] assert response.description == "Success" - assert "application/json" in response.content - json_response = response.content["application/json"] + assert JSON_CONTENT_TYPE in response.content + json_response = response.content[JSON_CONTENT_TYPE] assert json_response.schema_ == Schema() assert not json_response.examples assert not json_response.encoding @@ -45,9 +49,11 @@ def test_openapi_with_scalar_params(): @app.get("/users/") def handler(user_id: str, include_extra: bool = False): - pass + raise NotImplementedError() - schema = app.get_openapi_schema(title="Hello API", version="1.0.0") + schema = app.get_openapi_schema(title="My API", version="0.2.2") + assert schema.info.title == "My API" + assert schema.info.version == "0.2.2" assert len(schema.paths.keys()) == 1 assert "/users/" in schema.paths @@ -86,13 +92,13 @@ def test_openapi_with_scalar_returns(): def handler() -> str: return "Hello, world" - schema = app.get_openapi_schema(title="Hello API", version="1.0.0") + schema = app.get_openapi_schema() assert len(schema.paths.keys()) == 1 get = schema.paths["/"].get assert get.parameters is None - response = get.responses["200"].content["application/json"] + response = get.responses["200"].content[JSON_CONTENT_TYPE] assert response.schema_.title == "Return" assert response.schema_.type == "string" @@ -107,13 +113,13 @@ class User(BaseModel): def handler() -> User: return User(name="Ruben Fonseca") - schema = app.get_openapi_schema(title="Hello API", version="1.0.0") + schema = app.get_openapi_schema() assert len(schema.paths.keys()) == 1 get = schema.paths["/"].get assert get.parameters is None - response = get.responses["200"].content["application/json"] + response = get.responses["200"].content[JSON_CONTENT_TYPE] reference = response.schema_ assert reference.ref == "#/components/schemas/User" @@ -135,13 +141,13 @@ class User: def handler() -> User: return User(name="Ruben Fonseca") - schema = app.get_openapi_schema(title="Hello API", version="1.0.0") + schema = app.get_openapi_schema() assert len(schema.paths.keys()) == 1 get = schema.paths["/"].get assert get.parameters is None - response = get.responses["200"].content["application/json"] + response = get.responses["200"].content[JSON_CONTENT_TYPE] reference = response.schema_ assert reference.ref == "#/components/schemas/User" From c92b8c0e30a5dbdd58220a82d6b1e72734938fa0 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 2 Oct 2023 11:49:33 +0200 Subject: [PATCH 23/75] fix: customize the handler methods --- .../event_handler/api_gateway.py | 190 ++++++++++++++++-- .../event_handler/openapi/params.py | 2 +- .../event_handler/test_openapi_params.py | 72 ++++++- 3 files changed, 238 insertions(+), 26 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index ad05573a167..00b00b3eb54 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -51,6 +51,7 @@ # API GW/ALB decode non-safe URI chars; we must support them too _UNSAFE_URI = r"%<> \[\]{}|^" _NAMED_GROUP_BOUNDARY_PATTERN = rf"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)" +_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response" _ROUTE_REGEX = "^{}$" if TYPE_CHECKING: @@ -203,9 +204,13 @@ def __init__( cors: bool, compress: bool, cache_control: Optional[str], - middlewares: Optional[List[Callable[..., Response]]], + summary: Optional[str], description: Optional[str], + responses: Optional[Dict[Union[int, str], Dict[str, Any]]], + response_description: Optional[str], tags: Optional[List["Tag"]], + operation_id: Optional[str], + middlewares: Optional[List[Callable[..., Response]]], ): """ @@ -214,10 +219,10 @@ def __init__( method: str The HTTP method, example "GET" - rule: Pattern - The route rule, example "/my/path" path: str The path of the route + rule: Pattern + The route rule, example "/my/path" func: Callable The route handler function cors: bool @@ -226,12 +231,20 @@ def __init__( Whether or not to enable gzip compression for this route cache_control: Optional[str] The cache control header value, example "max-age=3600" - middlewares: Optional[List[Callable[..., Response]]] - The list of route middlewares to be called in order. + summary: Optional[str] + The OpenAPI summary for this route description: Optional[str] The OpenAPI description for this route + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] + The OpenAPI responses for this route + response_description: Optional[str] + The OpenAPI response description for this route tags: Optional[List[Tag]] The list of OpenAPI tags to be used for this route + operation_id: Optional[str] + The OpenAPI operationId for this route + middlewares: Optional[List[Callable[..., Response]]] + The list of route middlewares to be called in order. """ self.method = method.upper() self.path = path @@ -241,10 +254,13 @@ def __init__( self.cors = cors self.compress = compress self.cache_control = cache_control - self.middlewares = middlewares or [] + self.summary = summary self.description = description + self.responses = responses + self.response_description = response_description self.tags = tags or [] - self.operation_id = self.method.title() + self.func.__name__.title() + self.middlewares = middlewares or [] + self.operation_id = operation_id or (self.method.title() + self.func.__name__.title()) # _middleware_stack_built is used to ensure the middleware stack is only built once. self._middleware_stack_built = False @@ -372,7 +388,7 @@ def _get_openapi_path( responses = operation.setdefault("responses", {}) success_response = responses.setdefault("200", {}) - success_response["description"] = "Success" + success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION success_response["content"] = {"application/json": {"schema": {}}} json_response = success_response["content"].setdefault("application/json", {}) @@ -391,7 +407,7 @@ def _get_openapi_path( return path, definitions def _openapi_operation_summary(self) -> str: - return f"{self.method.upper()} {self.path}" + return self.summary or f"{self.method.upper()} {self.path}" def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any]: operation: Dict[str, Any] = {} @@ -598,8 +614,12 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + summary: Optional[str] = None, description: Optional[str] = None, + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): raise NotImplementedError() @@ -651,9 +671,13 @@ def get( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, - middlewares: Optional[List[Callable[..., Any]]] = None, + summary: Optional[str] = None, description: Optional[str] = None, + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, ): """Get route decorator with GET `method` @@ -677,7 +701,20 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "GET", cors, compress, cache_control, description, tags, middlewares) + return self.route( + rule, + "GET", + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + middlewares, + ) def post( self, @@ -685,9 +722,13 @@ def post( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, - middlewares: Optional[List[Callable[..., Any]]] = None, + summary: Optional[str] = None, description: Optional[str] = None, + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, ): """Post route decorator with POST `method` @@ -712,7 +753,20 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "POST", cors, compress, cache_control, description, tags, middlewares) + return self.route( + rule, + "POST", + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + middlewares, + ) def put( self, @@ -720,9 +774,13 @@ def put( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, - middlewares: Optional[List[Callable[..., Any]]] = None, + summary: Optional[str] = None, description: Optional[str] = None, + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, ): """Put route decorator with PUT `method` @@ -747,7 +805,20 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "PUT", cors, compress, cache_control, description, tags, middlewares) + return self.route( + rule, + "PUT", + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + middlewares, + ) def delete( self, @@ -755,9 +826,13 @@ def delete( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, - middlewares: Optional[List[Callable[..., Any]]] = None, + summary: Optional[str] = None, description: Optional[str] = None, + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, ): """Delete route decorator with DELETE `method` @@ -781,7 +856,20 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "DELETE", cors, compress, cache_control, description, tags, middlewares) + return self.route( + rule, + "DELETE", + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + middlewares, + ) def patch( self, @@ -789,9 +877,13 @@ def patch( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, - middlewares: Optional[List[Callable]] = None, + summary: Optional[str] = None, description: Optional[str] = None, + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, + middlewares: Optional[List[Callable]] = None, ): """Patch route decorator with PATCH `method` @@ -818,7 +910,20 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "PATCH", cors, compress, cache_control, description, tags, middlewares) + return self.route( + rule, + "PATCH", + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + middlewares, + ) def _push_processed_stack_frame(self, frame: str): """ @@ -1204,8 +1309,12 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + summary: Optional[str] = None, description: Optional[str] = None, + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Route decorator includes parameter `method`""" @@ -1225,9 +1334,13 @@ def register_resolver(func: Callable): cors_enabled, compress, cache_control, - middlewares, + summary, description, + responses, + response_description, tags, + operation_id, + middlewares, ) # The more specific route wins. @@ -1628,15 +1741,31 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + summary: Optional[str] = None, description: Optional[str] = None, + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): def register_route(func: Callable): # Convert methods to tuple. It needs to be hashable as its part of the self._routes dict key methods = (method,) if isinstance(method, str) else tuple(method) - route_key = (rule, methods, cors, compress, cache_control, description, tags) + route_key = ( + rule, + methods, + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + ) # Collate Middleware for routes if middlewares is not None: @@ -1676,12 +1805,29 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + summary: Optional[str] = None, description: Optional[str] = None, + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): # NOTE: see #1552 for more context. - return super().route(rule.rstrip("/"), method, cors, compress, cache_control, description, tags, middlewares) + return super().route( + rule.rstrip("/"), + method, + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + middlewares, + ) # Override _compile_regex to exclude trailing slashes for route resolution @staticmethod diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 18909d56b44..9f404525eaf 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -351,7 +351,7 @@ def _get_field_info_and_type_annotation(annotation, value, is_path_param: bool) if annotation is not inspect.Signature.empty: # If the annotation is an Annotated type, we need to extract the type annotation and the FieldInfo if get_origin(annotation) is Annotated: - type_annotation = _get_field_info_annotated_type(annotation, value, is_path_param) + field_info, type_annotation = _get_field_info_annotated_type(annotation, value, is_path_param) # If the annotation is not an Annotated type, we use it as the type annotation else: type_annotation = annotation diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index 7c170cfe98f..1db549977af 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -1,13 +1,18 @@ from dataclasses import dataclass +from datetime import datetime +from typing import List from pydantic import BaseModel +from typing_extensions import Annotated from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver from aws_lambda_powertools.event_handler.openapi.models import ( + Example, Parameter, ParameterInType, Schema, ) +from aws_lambda_powertools.event_handler.openapi.params import Query JSON_CONTENT_TYPE = "application/json" @@ -33,9 +38,10 @@ def handler(): assert get.summary == "GET /" assert get.operationId == "GetHandler" - assert "200" in get.responses + assert get.responses is not None + assert "200" in get.responses.keys() response = get.responses["200"] - assert response.description == "Success" + assert response.description == "Successful Response" assert JSON_CONTENT_TYPE in response.content json_response = response.content[JSON_CONTENT_TYPE] @@ -85,6 +91,41 @@ def handler(user_id: str, include_extra: bool = False): assert parameter.schema_.title == "Include Extra" +def test_openapi_with_custom_params(): + app = ApiGatewayResolver() + + @app.get("/users", summary="Get Users", operation_id="GetUsers", description="Get paginated users", tags=["Users"]) + def handler( + count: Annotated[ + int, + Query(lt=100, gt=0, examples=[Example(summary="Example 1", value=10)]), + ] = 1, + ): + raise NotImplementedError() + + schema = app.get_openapi_schema() + + get = schema.paths["/users"].get + assert len(get.parameters) == 1 + assert get.summary == "Get Users" + assert get.operationId == "GetUsers" + assert get.description == "Get paginated users" + assert get.tags == ["Users"] + + parameter = get.parameters[0] + assert parameter.required is False + assert parameter.name == "count" + assert parameter.in_ == ParameterInType.query + assert parameter.schema_.type == "integer" + assert parameter.schema_.default == 1 + assert parameter.schema_.title == "Count" + assert parameter.schema_.exclusiveMinimum == 0 + assert parameter.schema_.exclusiveMaximum == 100 + assert len(parameter.schema_.examples) == 1 + assert parameter.schema_.examples[0].summary == "Example 1" + assert parameter.schema_.examples[0].value == 10 + + def test_openapi_with_scalar_returns(): app = ApiGatewayResolver() @@ -130,7 +171,32 @@ def handler() -> User: assert "name" in user_schema.properties -def test_openapi_with_dataclasse_return(): +def test_openapi_with_pydantic_nested_returns(): + app = ApiGatewayResolver() + + class Order(BaseModel): + date: datetime + + class User(BaseModel): + name: str + orders: List[Order] + + @app.get("/") + def handler() -> User: + return User(name="Ruben Fonseca", orders=[Order(date=datetime.now())]) + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + assert "User" in schema.components.schemas + assert "Order" in schema.components.schemas + + user_schema = schema.components.schemas["User"] + assert "orders" in user_schema.properties + assert user_schema.properties["orders"].type == "array" + + +def test_openapi_with_dataclass_return(): app = ApiGatewayResolver() @dataclass From a80f53bd454d2b02d7e67cb7a0f12c06ae95e705 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 2 Oct 2023 20:51:50 +0200 Subject: [PATCH 24/75] fix: tests --- .../event_handler/api_gateway.py | 32 ++++--- .../event_handler/openapi/compat.py | 29 +++++- .../event_handler/openapi/dependant.py | 8 +- .../event_handler/openapi/models.py | 9 +- .../event_handler/openapi/params.py | 91 +++++++++++++------ .../event_handler/test_openapi_params.py | 2 +- 6 files changed, 119 insertions(+), 52 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 00b00b3eb54..624255cca8f 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -265,6 +265,9 @@ def __init__( # _middleware_stack_built is used to ensure the middleware stack is only built once. self._middleware_stack_built = False + # _dependant is used to cache the dependant model for the handler function + self._dependant: Optional["Dependant"] = None + def __call__( self, router_middlewares: List[Callable], @@ -355,6 +358,15 @@ def _build_middleware_stack(self, router_middlewares: List[Callable[..., Any]]) self._middleware_stack_built = True + @property + def dependant(self) -> "Dependant": + if self._dependant is None: + from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant + + self._dependant = get_dependant(path=self.path, call=self.func) + + return self._dependant + def _get_openapi_path( self, *, @@ -1170,8 +1182,7 @@ def get_openapi_schema( get_compat_model_name_map, get_definitions, ) - from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant - from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, Server + from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, PathItem, Server from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_TEMPLATE, ) @@ -1215,13 +1226,8 @@ def get_openapi_schema( # Add routes to the OpenAPI schema for route in all_routes: - dependant = get_dependant( - path=route.path, - call=route.func, - ) - result = route._get_openapi_path( - dependant=dependant, + dependant=route.dependant, operation_ids=operation_ids, model_name_map=model_name_map, field_mapping=field_mapping, @@ -1240,7 +1246,7 @@ def get_openapi_schema( if tags: output["tags"] = tags - output["paths"] = paths + output["paths"] = {k: PathItem(**v) for k, v in paths.items()} return OpenAPI(**output) @@ -1706,7 +1712,6 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]: """ from aws_lambda_powertools.event_handler.openapi.dependant import ( - get_dependant, get_flat_params, ) @@ -1714,12 +1719,11 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]: request_fields_from_routes: List["ModelField"] = [] for route in routes: - dependant = get_dependant(path=route.path, call=route.func) - params = get_flat_params(dependant) + params = get_flat_params(route.dependant) request_fields_from_routes.extend(params) - if dependant.return_param: - responses_from_routes.append(dependant.return_param) + if route.dependant.return_param: + responses_from_routes.append(route.dependant.return_param) flat_models = list(responses_from_routes + request_fields_from_routes) return flat_models diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index bd5d8c4c445..33aaf934c27 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -1,5 +1,6 @@ # mypy: ignore-errors # flake8: noqa +from copy import copy # MAINTENANCE: remove when deprecating Pydantic v1. Mypy doesn't handle two different code paths that import different # versions of a module, so we need to ignore errors here. @@ -10,19 +11,27 @@ from typing_extensions import Annotated, Literal -from aws_lambda_powertools.event_handler.openapi.types import COMPONENT_REF_PREFIX, PYDANTIC_V2, ModelNameMap +from pydantic import BaseModel +from pydantic.fields import FieldInfo + +from aws_lambda_powertools.event_handler.openapi.types import ( + COMPONENT_REF_PREFIX, + PYDANTIC_V2, + ModelNameMap, +) if PYDANTIC_V2: from pydantic import TypeAdapter from pydantic._internal._typing_extra import eval_type_lenient from pydantic.fields import FieldInfo from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue - from pydantic_core import PydanticUndefined + from pydantic_core import PydanticUndefined, PydanticUndefinedType from aws_lambda_powertools.event_handler.openapi.types import IncEx Undefined = PydanticUndefined Required = PydanticUndefined + UndefinedType = PydanticUndefinedType evaluate_forwardref = eval_type_lenient @@ -49,7 +58,7 @@ def default(self) -> Any: def type_(self) -> Any: return self.field_info.annotation - def __post__init__(self) -> None: + def __post_init__(self) -> None: self._type_adapter: TypeAdapter[Any] = TypeAdapter( Annotated[self.field_info.annotation, self.field_info], ) @@ -125,9 +134,15 @@ def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap: def get_annotation_from_field_info(annotation: Any, field_info: FieldInfo, field_name: str) -> Any: return annotation + def model_rebuild(model: Type[BaseModel]) -> None: + model.model_rebuild() + + def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: + return type(field_info).from_annotation(annotation) + else: from pydantic import BaseModel - from pydantic.fields import ModelField, Required, Undefined + from pydantic.fields import ModelField, Required, Undefined, UndefinedType from pydantic.schema import ( field_schema, get_annotation_from_field_info, @@ -192,3 +207,9 @@ def get_model_definitions( def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap: models = get_flat_models_from_fields(fields, known_models=set()) return get_model_name_map(models) + + def model_rebuild(model: Type[BaseModel]) -> None: + model.update_forward_refs() + + def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: + return copy(field_info) diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 8cc57944f80..a9e0d5bfa3c 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -160,11 +160,12 @@ def get_dependant( is_path_param = param_name in path_param_names # Analyze the parameter to get the Pydantic field. - _, param_field = analyze_param( + param_field = analyze_param( param_name=param_name, annotation=param.annotation, value=param.default, is_path_param=is_path_param, + is_response_param=False, ) if param_field is None: raise AssertionError(f"Param field is None for param: {param_name}") @@ -174,11 +175,12 @@ def get_dependant( # If the return annotation is not empty, add it to the dependant model. return_annotation = endpoint_signature.return_annotation if return_annotation is not inspect.Signature.empty: - _, param_field = analyze_param( - param_name="Return", + param_field = analyze_param( + param_name="return", annotation=return_annotation, value=None, is_path_param=False, + is_response_param=True, ) if param_field is None: raise AssertionError("Param field is None for return annotation") diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 40d80003f1c..4b5218f9833 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -4,6 +4,7 @@ from pydantic import AnyUrl, BaseModel, Field from typing_extensions import Annotated, Literal +from aws_lambda_powertools.event_handler.openapi.compat import model_rebuild from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2 """ @@ -205,7 +206,7 @@ class Schema(BaseModel): deprecated: Optional[bool] = None readOnly: Optional[bool] = None writeOnly: Optional[bool] = None - examples: Optional[List[Any]] = None + examples: Optional[List["Example"]] = None # Ref: OpenAPI 3.1.0: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#schema-object # Schema Object discriminator: Optional[Discriminator] = None @@ -577,6 +578,6 @@ class Config: extra = "allow" -Schema.update_forward_refs() -Operation.update_forward_refs() -Encoding.update_forward_refs() +model_rebuild(Schema) +model_rebuild(Operation) +model_rebuild(Encoding) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 9f404525eaf..fae8fb7fc35 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1,16 +1,17 @@ import inspect -from copy import copy from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from pydantic import BaseConfig from pydantic.fields import FieldInfo -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import Annotated, Literal, get_args, get_origin from aws_lambda_powertools.event_handler.openapi.compat import ( ModelField, Required, Undefined, + UndefinedType, + copy_field_info, get_annotation_from_field_info, ) from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2, CacheKey @@ -302,7 +303,8 @@ def analyze_param( annotation: Any, value: Any, is_path_param: bool, -) -> Tuple[Any, Optional[ModelField]]: + is_response_param: bool, +) -> Optional[ModelField]: """ Analyze a parameter annotation and value to determine the type and default value of the parameter. @@ -316,10 +318,12 @@ def analyze_param( The value of the parameter is_path_param Whether the parameter is a path parameter + is_response_param + Whether the parameter is the return annotation Returns ------- - Tuple[Any, Optional[ModelField]] + Optional[ModelField] The type annotation and the Pydantic field representing the parameter """ field_info, type_annotation = _get_field_info_and_type_annotation(annotation, value, is_path_param) @@ -336,12 +340,16 @@ def analyze_param( # Check if the parameter is part of the path. Otherwise, defaults to query. if is_path_param: - field_info = Path(annotation=type_annotation, default=default_value) + field_info = Path(annotation=type_annotation) else: field_info = Query(annotation=type_annotation, default=default_value) + # When we have a response field, we need to set the default value to Required + if is_response_param: + field_info.default = Required + field = _create_model_field(field_info, type_annotation, param_name, is_path_param) - return type_annotation, field + return field def _get_field_info_and_type_annotation(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]: @@ -372,7 +380,10 @@ def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tu if isinstance(powertools_annotation, FieldInfo): # Copy `field_info` because we mutate `field_info.default` later - field_info = copy(powertools_annotation) + field_info = copy_field_info( + field_info=powertools_annotation, + annotation=annotation, + ) if field_info.default not in [Undefined, Required]: raise AssertionError("FieldInfo needs to have a default value of Undefined or Required") @@ -386,6 +397,44 @@ def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tu return field_info, type_annotation +def _create_response_field( + name: str, + type_: Type[Any], + default: Optional[Any] = Undefined, + required: Union[bool, UndefinedType] = Undefined, + model_config: Type[BaseConfig] = BaseConfig, + field_info: Optional[FieldInfo] = None, + alias: Optional[str] = None, + mode: Literal["validation", "serialization"] = "validation", +) -> ModelField: + """ + Create a new response field. Raises if type_ is invalid. + """ + if PYDANTIC_V2: + field_info = field_info or FieldInfo( + annotation=type_, + default=default, + alias=alias, + ) + else: + field_info = field_info or FieldInfo() + kwargs = {"name": name, "field_info": field_info} + if PYDANTIC_V2: + kwargs.update({"mode": mode}) + else: + kwargs.update( + { + "type_": type_, + "class_validators": {}, + "default": default, + "required": required, + "model_config": model_config, + "alias": alias, + }, + ) + return ModelField(**kwargs) # type: ignore[arg-type] + + def _create_model_field( field_info: Optional[FieldInfo], type_annotation: Any, @@ -411,21 +460,11 @@ def _create_model_field( alias = field_info.alias or param_name field_info.alias = alias - # Create the Pydantic field - kwargs = {"name": param_name, "field_info": field_info} - - if PYDANTIC_V2: - kwargs.update({"mode": "validation"}) - else: - kwargs.update( - { - "type_": use_annotation, - "class_validators": {}, - "default": field_info.default, - "required": field_info.default in (Required, Undefined), - "model_config": BaseConfig, - "alias": alias, - }, - ) - - return ModelField(**kwargs) # type: ignore[arg-type] + return _create_response_field( + name=param_name, + type_=use_annotation, + default=field_info.default, + alias=alias, + required=field_info.default in (Required, Undefined), + field_info=field_info, + ) diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index 1db549977af..a2f444bcacc 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -98,7 +98,7 @@ def test_openapi_with_custom_params(): def handler( count: Annotated[ int, - Query(lt=100, gt=0, examples=[Example(summary="Example 1", value=10)]), + Query(gt=0, lt=100, examples=[Example(summary="Example 1", value=10)]), ] = 1, ): raise NotImplementedError() From 79ea08249267bf05f4b3cb76dff94b8d6be850fe Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 9 Oct 2023 15:51:36 +0200 Subject: [PATCH 25/75] feat: add a validation middleware --- .../event_handler/api_gateway.py | 111 ++++- .../middlewares/openapi_validation.py | 273 +++++++++++ .../event_handler/openapi/compat.py | 286 +++++++++++- .../event_handler/openapi/dependant.py | 89 ++-- .../event_handler/openapi/encoders.py | 194 ++++++++ .../event_handler/openapi/exceptions.py | 15 + .../event_handler/openapi/params.py | 435 ++++++++++++++++-- .../event_handler/openapi/types.py | 28 ++ .../event_handler/openapi/utils.py | 62 +++ 9 files changed, 1381 insertions(+), 112 deletions(-) create mode 100644 aws_lambda_powertools/event_handler/middlewares/openapi_validation.py create mode 100644 aws_lambda_powertools/event_handler/openapi/encoders.py create mode 100644 aws_lambda_powertools/event_handler/openapi/exceptions.py create mode 100644 aws_lambda_powertools/event_handler/openapi/utils.py diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 624255cca8f..4dbdd064b61 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -30,6 +30,12 @@ from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError +from aws_lambda_powertools.event_handler.openapi.types import ( + COMPONENT_REF_PREFIX, + METHODS_WITH_BODY, + validation_error_definition, + validation_error_response_definition, +) from aws_lambda_powertools.event_handler.response import Response from aws_lambda_powertools.shared.functions import powertools_dev_is_set from aws_lambda_powertools.shared.json_encoder import Encoder @@ -67,7 +73,9 @@ Tag, ) from aws_lambda_powertools.event_handler.openapi.params import Dependant - from aws_lambda_powertools.event_handler.openapi.types import TypeModelOrEnum + from aws_lambda_powertools.event_handler.openapi.types import ( + TypeModelOrEnum, + ) class ProxyEventType(Enum): @@ -268,6 +276,9 @@ def __init__( # _dependant is used to cache the dependant model for the handler function self._dependant: Optional["Dependant"] = None + # _body_field is used to cache the dependant model for the body field + self._body_field: Optional["ModelField"] = None + def __call__( self, router_middlewares: List[Callable], @@ -367,6 +378,15 @@ def dependant(self) -> "Dependant": return self._dependant + @property + def body_field(self) -> Optional["ModelField"]: + if self._body_field is None: + from aws_lambda_powertools.event_handler.openapi.params import _get_body_field + + self._body_field = _get_body_field(dependant=self.dependant, name=self.operation_id) + + return self._body_field + def _get_openapi_path( self, *, @@ -375,9 +395,7 @@ def _get_openapi_path( model_name_map: Dict["TypeModelOrEnum", str], field_mapping: Dict[Tuple["ModelField", Literal["validation", "serialization"]], "JsonSchemaValue"], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - from aws_lambda_powertools.event_handler.openapi.dependant import ( - get_flat_params, - ) + from aws_lambda_powertools.event_handler.openapi.dependant import get_flat_params path = {} definitions: Dict[str, Any] = {} @@ -398,6 +416,15 @@ def _get_openapi_path( all_parameters.update(required_parameters) operation["parameters"] = list(all_parameters.values()) + if self.method.upper() in METHODS_WITH_BODY: + request_body_oai = self._openapi_operation_request_body( + body_field=self.body_field, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + if request_body_oai: + operation["requestBody"] = request_body_oai + responses = operation.setdefault("responses", {}) success_response = responses.setdefault("200", {}) success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION @@ -413,6 +440,24 @@ def _get_openapi_path( ), ) + # Validation responses + operation["responses"]["422"] = { + "description": "Validation Error", + "content": { + "application/json": { + "schema": {"$ref": COMPONENT_REF_PREFIX + "HTTPValidationError"}, + }, + }, + } + + if "ValidationError" not in definitions: + definitions.update( + { + "ValidationError": validation_error_definition, + "HTTPValidationError": validation_error_response_definition, + }, + ) + path[self.method.lower()] = operation # Generate the response schema @@ -444,6 +489,38 @@ def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any] return operation + @staticmethod + def _openapi_operation_request_body( + *, + body_field: Optional["ModelField"], + model_name_map: Dict["TypeModelOrEnum", str], + field_mapping: Dict[Tuple["ModelField", Literal["validation", "serialization"]], "JsonSchemaValue"], + ) -> Optional[Dict[str, Any]]: + from aws_lambda_powertools.event_handler.openapi.compat import ModelField, get_schema_from_model_field + from aws_lambda_powertools.event_handler.openapi.params import Body + + if not body_field: + return None + + if not isinstance(body_field, ModelField): + raise AssertionError(f"Expected ModelField, got {body_field}") + + body_schema = get_schema_from_model_field( + field=body_field, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + field_info = cast(Body, body_field.field_info) + request_media_type = field_info.media_type + required = body_field.required + request_body_oai: Dict[str, Any] = {} + if required: + request_body_oai["required"] = required + request_media_content: Dict[str, Any] = {"schema": body_schema} + request_body_oai["content"] = {request_media_type: request_media_content} + return request_body_oai + @staticmethod def _openapi_operation_parameters( *, @@ -1097,6 +1174,7 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: Optional[bool] = False, ): """ Parameters @@ -1114,6 +1192,8 @@ def __init__( optional list of prefixes to be removed from the request path before doing the routing. This is often used with api gateways with multiple custom mappings. Each prefix can be a static string or a compiled regex pattern + enable_validation: Optional[bool] + Enables validation of the request body against the route schema, by default False. """ self._proxy_type = proxy_type self._dynamic_routes: List[Route] = [] @@ -1124,6 +1204,7 @@ def __init__( self._cors_enabled: bool = cors is not None self._cors_methods: Set[str] = {"OPTIONS"} self._debug = self._has_debug(debug) + self._enable_validation = enable_validation self._strip_prefixes = strip_prefixes self.context: Dict = {} # early init as customers might add context before event resolution self.processed_stack_frames = [] @@ -1131,6 +1212,19 @@ def __init__( # Allow for a custom serializer or a concise json serialization self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) + if self._enable_validation: + from aws_lambda_powertools.event_handler.middlewares.openapi_validation import OpenAPIValidationMiddleware + + self.use([OpenAPIValidationMiddleware()]) + + # When using validation, we need to skip the serializer, as the middleware is doing it automatically + # However, if the user is using a custom serializer, we need to abort + if serializer: + raise ValueError("Cannot use a custom serializer when using validation") + + # Install a dummy serializer + self._serializer = lambda args: args # type: ignore + def get_openapi_schema( self, *, @@ -1711,21 +1805,28 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]: Returns a list of fields from the routes """ + from aws_lambda_powertools.event_handler.openapi.compat import ModelField from aws_lambda_powertools.event_handler.openapi.dependant import ( get_flat_params, ) + body_fields_from_routes: List["ModelField"] = [] responses_from_routes: List["ModelField"] = [] request_fields_from_routes: List["ModelField"] = [] for route in routes: + if route.body_field: + if not isinstance(route.body_field, ModelField): + raise AssertionError("A request body myst be a Pydantic Field") + body_fields_from_routes.append(route.body_field) + params = get_flat_params(route.dependant) request_fields_from_routes.extend(params) if route.dependant.return_param: responses_from_routes.append(route.dependant.return_param) - flat_models = list(responses_from_routes + request_fields_from_routes) + flat_models = list(responses_from_routes + request_fields_from_routes + body_fields_from_routes) return flat_models diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py new file mode 100644 index 00000000000..3b2155cee30 --- /dev/null +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -0,0 +1,273 @@ +import dataclasses +import json +import logging +from copy import deepcopy +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler import Response +from aws_lambda_powertools.event_handler.api_gateway import Route +from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware +from aws_lambda_powertools.event_handler.openapi.compat import ( + ErrorWrapper, + ModelField, + _model_dump, + _normalize_errors, + _regenerate_error_with_loc, + get_missing_field_error, +) +from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder +from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError +from aws_lambda_powertools.event_handler.openapi.params import Param +from aws_lambda_powertools.event_handler.openapi.types import IncEx +from aws_lambda_powertools.event_handler.types import EventHandlerInstance + +logger = logging.getLogger(__name__) + + +class OpenAPIValidationMiddleware(BaseMiddlewareHandler): + def __init__(self): + super().__init__() + + def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: + logger.debug("OpenAPIValidationMiddleware handler") + + route: Route = app.context["_route"] + + values: Dict[str, Any] = {} + errors: List[Any] = [] + + try: + path_values, path_errors = self._request_params_to_args( + route.dependant.path_params, + app.context["_route_args"], + ) + query_values, query_errors = self._request_params_to_args( + route.dependant.query_params, + app.current_event.query_string_parameters or {}, + ) + + values.update(path_values) + values.update(query_values) + + errors += path_errors + query_errors + + if route.dependant.body_params: + (body_values, body_errors) = self._request_body_to_args( + required_params=route.dependant.body_params, + received_body=self._get_body(app, route), + ) + values.update(body_values) + errors.extend(body_errors) + + if errors: + raise RequestValidationError(_normalize_errors(errors)) + else: + app.context["_route_args"] = values + response = next_middleware(app) + + raw_response = jsonable_encoder(response.body) + return self._serialize_response(field=route.dependant.return_param, response_content=raw_response) + except RequestValidationError as e: + return Response( + status_code=422, + content_type="application/json", + body=json.dumps({"detail": e.errors()}), + ) + + def _serialize_response( + self, + *, + field: Optional[ModelField] = None, + response_content: Any, + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> Any: + if field: + errors = [] + # MAINTENANCE: remove this when we drop pydantic v1 + if not hasattr(field, "serializable"): + response_content = self._prepare_response_content( + response_content, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + value, errors_ = field.validate(response_content, {}, loc=("response",)) + + if isinstance(errors_, list): + errors.extend(errors_) + elif errors_: + errors.append(errors_) + + if errors: + raise RequestValidationError(errors=_normalize_errors(errors), body=response_content) + + if hasattr(field, "serialize"): + return field.serialize( + value, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + return jsonable_encoder( + value, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + else: + return jsonable_encoder(response_content) + + def _prepare_response_content( + self, + res: Any, + *, + exclude_unset: bool, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> Any: + if isinstance(res, BaseModel): + return _model_dump( + res, + by_alias=True, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + elif isinstance(res, list): + return [ + self._prepare_response_content(item, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults) + for item in res + ] + elif isinstance(res, dict): + return { + k: self._prepare_response_content(v, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults) + for k, v in res.items() + } + elif dataclasses.is_dataclass(res): + return dataclasses.asdict(res) + return res + + def _get_body(self, app: EventHandlerInstance, route: Route) -> Dict[str, Any]: + content_type_value = app.current_event.get_header_value("content-type") + if not content_type_value or content_type_value.startswith("application/json"): + try: + return app.current_event.json_body + except json.JSONDecodeError as e: + raise RequestValidationError( + [ + { + "type": "json_invalid", + "loc": ("body", e.pos), + "msg": "JSON decode error", + "input": {}, + "ctx": {"error": e.msg}, + }, + ], + body=e.doc, + ) from e + else: + raise NotImplementedError("Only JSON body is supported") + + @staticmethod + def _request_params_to_args( + required_params: Sequence[ModelField], + received_params: Mapping[str, Any], + ) -> Tuple[Dict[str, Any], List[Any]]: + values = {} + errors = [] + + for field in required_params: + value = received_params.get(field.alias) + + field_info = field.field_info + if not isinstance(field_info, Param): + raise AssertionError(f"Expected Param field_info, got {field_info}") + + loc = (field_info.in_.value, field.alias) + if value is None: + if field.required: + errors.append(get_missing_field_error(loc=loc)) + else: + values[field.name] = deepcopy(field.default) + continue + + v_, errors_ = field.validate(value, values, loc=loc) + if isinstance(errors_, ErrorWrapper): + errors.append(errors_) + elif isinstance(errors_, list): + new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) + errors.extend(new_errors) + else: + values[field.name] = v_ + + return values, errors + + @staticmethod + def _request_body_to_args( + required_params: List[ModelField], + received_body: Optional[Dict[str, Any]], + ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: + values: Dict[str, Any] = {} + errors: List[Dict[str, Any]] = [] + + if not required_params: + return values, errors + + field = required_params[0] + field_info = field.field_info + embed = getattr(field_info, "embed", None) + field_alias_omitted = len(required_params) == 1 and not embed + if field_alias_omitted: + received_body = {field.alias: received_body} + + for field in required_params: + loc: Tuple[str, ...] + if field_alias_omitted: + loc = ("body",) + else: + loc = ("body", field.alias) + + value: Optional[Any] = None + + if received_body is not None: + try: + value = received_body.get(field.alias) + except AttributeError: + errors.append(get_missing_field_error(loc)) + continue + + # Determine if the field is required + if value is None: + if field.required: + errors.append(get_missing_field_error(loc)) + else: + values[field.name] = deepcopy(field.default) + continue + + # MAINTENANCE: Handle byte and file fields + + v_, errors_ = field.validate(value, values, loc=loc) + + if isinstance(errors_, list): + errors.extend(errors_) + elif errors_: + errors.append(errors_) + else: + values[field.name] = v_ + + return values, errors diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index 33aaf934c27..78f50f20b7b 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -1,17 +1,19 @@ # mypy: ignore-errors # flake8: noqa +from collections import deque from copy import copy # MAINTENANCE: remove when deprecating Pydantic v1. Mypy doesn't handle two different code paths that import different # versions of a module, so we need to ignore errors here. -from dataclasses import dataclass +from dataclasses import dataclass, is_dataclass from enum import Enum -from typing import Any, Dict, List, Set, Tuple, Type, Union +from types import UnionType +from typing import Any, Dict, List, Set, Tuple, Type, Union, FrozenSet, Deque, Sequence, Mapping -from typing_extensions import Annotated, Literal +from typing_extensions import Annotated, Literal, get_origin, get_args -from pydantic import BaseModel +from pydantic import BaseModel, create_model from pydantic.fields import FieldInfo from aws_lambda_powertools.event_handler.openapi.types import ( @@ -20,10 +22,29 @@ ModelNameMap, ) +sequence_annotation_to_type = { + Sequence: list, + List: list, + list: list, + Tuple: tuple, + tuple: tuple, + Set: set, + set: set, + FrozenSet: frozenset, + frozenset: frozenset, + Deque: deque, + deque: deque, +} + +sequence_types = tuple(sequence_annotation_to_type.keys()) + +RequestErrorModel: Type[BaseModel] = create_model("Request") + if PYDANTIC_V2: - from pydantic import TypeAdapter + from pydantic import TypeAdapter, ValidationError from pydantic._internal._typing_extra import eval_type_lenient from pydantic.fields import FieldInfo + from pydantic._internal._utils import lenient_issubclass from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue from pydantic_core import PydanticUndefined, PydanticUndefinedType @@ -35,6 +56,9 @@ evaluate_forwardref = eval_type_lenient + class ErrorWrapper(Exception): + pass + @dataclass class ModelField: field_info: FieldInfo @@ -91,6 +115,14 @@ def serialize( exclude_none=exclude_none, ) + def validate( + self, value: Any, values: Dict[str, Any] = {}, *, loc: Tuple[Union[int, str], ...] = () + ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: + try: + return (self._type_adapter.validate_python(value, from_attributes=True), None) + except ValidationError as exc: + return None, _regenerate_error_with_loc(errors=exc.errors(), loc_prefix=loc) + def __hash__(self) -> int: # Each ModelField is unique for our purposes return id(self) @@ -140,9 +172,58 @@ def model_rebuild(model: Type[BaseModel]) -> None: def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: return type(field_info).from_annotation(annotation) + def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: + error = ValidationError.from_exception_data( + "Field required", [{"type": "missing", "loc": loc, "input": {}}] + ).errors()[0] + error["input"] = None + return error + + def is_scalar_field(field: ModelField) -> bool: + from aws_lambda_powertools.event_handler.openapi.params import Body + + return field_annotation_is_scalar(field.field_info.annotation) and not isinstance(field.field_info, Body) + + def is_scalar_sequence_field(field: ModelField) -> bool: + return field_annotation_is_scalar_sequence(field.field_info.annotation) + + def is_sequence_field(field: ModelField) -> bool: + return field_annotation_is_sequence(field.field_info.annotation) + + def is_bytes_field(field: ModelField) -> bool: + return is_bytes_or_nonable_bytes_annotation(field.type_) + + def is_bytes_sequence_field(field: ModelField) -> bool: + return is_bytes_sequence_annotation(field.type_) + + def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: + origin_type = get_origin(field.field_info.annotation) or field.field_info.annotation + assert issubclass(origin_type, sequence_types) # type: ignore[arg-type] + return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return] + + def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: + return model.dict(**kwargs) + + def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> Type[BaseModel]: + field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields} + model: Type[BaseModel] = create_model(model_name, **field_params) + return model + else: - from pydantic import BaseModel - from pydantic.fields import ModelField, Required, Undefined, UndefinedType + from pydantic import BaseModel, ValidationError + from pydantic.fields import ( + ModelField, + Required, + Undefined, + UndefinedType, + SHAPE_LIST, + SHAPE_SET, + SHAPE_FROZENSET, + SHAPE_TUPLE, + SHAPE_SEQUENCE, + SHAPE_TUPLE_ELLIPSIS, + SHAPE_SINGLETON, + ) from pydantic.schema import ( field_schema, get_annotation_from_field_info, @@ -150,10 +231,29 @@ def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: get_model_name_map, model_process_schema, ) + from pydantic.errors import MissingError + from pydantic.error_wrappers import ErrorWrapper + from pydantic.utils import lenient_issubclass from pydantic.typing import evaluate_forwardref JsonSchemaValue = Dict[str, Any] + sequence_shapes = [ + SHAPE_LIST, + SHAPE_SET, + SHAPE_FROZENSET, + SHAPE_TUPLE, + SHAPE_SEQUENCE, + SHAPE_TUPLE_ELLIPSIS, + ] + sequence_shape_to_type = { + SHAPE_LIST: list, + SHAPE_SET: set, + SHAPE_TUPLE: tuple, + SHAPE_SEQUENCE: list, + SHAPE_TUPLE_ELLIPSIS: list, + } + @dataclass class GenerateJsonSchema: ref_template: str @@ -213,3 +313,175 @@ def model_rebuild(model: Type[BaseModel]) -> None: def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: return copy(field_info) + + def is_pv1_scalar_field(field: ModelField) -> bool: + from aws_lambda_powertools.event_handler.openapi.params import Body + + if not ( + field.shape == SHAPE_SINGLETON + and not lenient_issubclass(field.type_, BaseModel) + and not lenient_issubclass(field.type_, dict) + and not field_annotation_is_sequence(field.type_) + and not is_dataclass(field.type_) + and not isinstance(field.field_info, Body) + ): + return False + + if field.sub_fields: + if not all(is_pv1_scalar_sequence_field(f) for f in field.sub_fields): + return False + + return True + + def is_pv1_scalar_sequence_field(field: ModelField) -> bool: + if (field.shape in sequence_shapes) and not lenient_issubclass(field.type_, BaseModel): + if field.sub_fields is not None: + for sub_field in field.sub_fields: + if not is_pv1_scalar_field(sub_field): + return False + return True + if _annotation_is_sequence(field.type_): + return True + return False + + def is_scalar_field(field: ModelField) -> bool: + return is_pv1_scalar_field(field) + + def is_scalar_sequence_field(field: ModelField) -> bool: + return is_pv1_scalar_sequence_field(field) + + def is_sequence_field(field: ModelField) -> bool: + return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) + + def is_bytes_field(field: ModelField) -> bool: + return lenient_issubclass(field.type_, bytes) + + def is_bytes_sequence_field(field: ModelField) -> bool: + return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes) # type: ignore[attr-defined] + + def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: + if lenient_issubclass(annotation, (str, bytes)): + return False + return lenient_issubclass(annotation, sequence_types) + + def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: + missing_field_error = ErrorWrapper(MissingError(), loc=loc) + new_error = ValidationError([missing_field_error], RequestErrorModel) + return new_error.errors()[0] + + def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: + use_errors: List[Any] = [] + for error in errors: + if isinstance(error, ErrorWrapper): + new_errors = ValidationError(errors=[error], model=RequestErrorModel).errors() # type: ignore[call-arg] + use_errors.extend(new_errors) + elif isinstance(error, list): + use_errors.extend(_normalize_errors(error)) + else: + use_errors.append(error) + return use_errors + + def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> Type[BaseModel]: + body_model = create_model(model_name) + for f in fields: + body_model.__fields__[f.name] = f # type: ignore[index] + return body_model + + def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: + return sequence_shape_to_type[field.shape](value) + + def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: + return model.dict(**kwargs) + + +# Common code for both versions + + +def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + return any(field_annotation_is_complex(arg) for arg in get_args(annotation)) + + return ( + _annotation_is_complex(annotation) + or _annotation_is_complex(origin) + or hasattr(origin, "__pydantic_core_schema__") + or hasattr(origin, "__get_pydantic_core_schema__") + ) + + +def field_annotation_is_scalar(annotation: Any) -> bool: + return annotation is Ellipsis or not field_annotation_is_complex(annotation) + + +def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: + return _annotation_is_sequence(annotation) or _annotation_is_sequence(get_origin(annotation)) + + +def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + at_least_one_scalar_sequence = False + for arg in get_args(annotation): + if field_annotation_is_scalar_sequence(arg): + at_least_one_scalar_sequence = True + continue + elif not field_annotation_is_scalar(arg): + return False + return at_least_one_scalar_sequence + return field_annotation_is_sequence(annotation) and all( + field_annotation_is_scalar(sub_annotation) for sub_annotation in get_args(annotation) + ) + + +def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool: + if lenient_issubclass(annotation, bytes): + return True + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + for arg in get_args(annotation): + if lenient_issubclass(arg, bytes): + return True + return False + + +def is_bytes_sequence_annotation(annotation: Any) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + at_least_one = False + for arg in get_args(annotation): + if is_bytes_sequence_annotation(arg): + at_least_one = True + break + return at_least_one + return field_annotation_is_sequence(annotation) and all( + is_bytes_or_nonable_bytes_annotation(sub_annotation) for sub_annotation in get_args(annotation) + ) + + +def value_is_sequence(value: Any) -> bool: + return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type] + + +def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: + return ( + lenient_issubclass(annotation, (BaseModel, Mapping)) # TODO: UploadFile + or _annotation_is_sequence(annotation) + or is_dataclass(annotation) + ) + + +def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: + if lenient_issubclass(annotation, (str, bytes)): + return False + return lenient_issubclass(annotation, sequence_types) + + +def _regenerate_error_with_loc( + *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] +) -> List[Dict[str, Any]]: + updated_loc_errors: List[Any] = [ + {**err, "loc": loc_prefix + err.get("loc", ())} for err in _normalize_errors(errors) + ] + + return updated_loc_errors diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index a9e0d5bfa3c..b7770f7bc9d 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -2,9 +2,22 @@ import re from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, cast -from aws_lambda_powertools.event_handler.openapi.compat import ModelField, evaluate_forwardref -from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param, ParamTypes, analyze_param -from aws_lambda_powertools.event_handler.openapi.types import CacheKey +from aws_lambda_powertools.event_handler.openapi.compat import ( + ModelField, + evaluate_forwardref, + is_scalar_field, + is_scalar_sequence_field, +) +from aws_lambda_powertools.event_handler.openapi.params import ( + Body, + Dependant, + Header, + Param, + ParamTypes, + Query, + analyze_param, +) +from aws_lambda_powertools.event_handler.openapi.utils import get_flat_dependant """ This turns the opaque function signature into typed, validated models. @@ -170,7 +183,10 @@ def get_dependant( if param_field is None: raise AssertionError(f"Param field is None for param: {param_name}") - add_param_to_fields(field=param_field, dependant=dependant) + if is_body_param(param_field=param_field, is_path_param=is_path_param): + dependant.body_params.append(param_field) + else: + add_param_to_fields(field=param_field, dependant=dependant) # If the return annotation is not empty, add it to the dependant model. return_annotation = endpoint_signature.return_annotation @@ -190,58 +206,19 @@ def get_dependant( return dependant -def get_flat_dependant( - dependant: Dependant, - *, - skip_repeats: bool = False, - visited: Optional[List[CacheKey]] = None, -) -> Dependant: - """ - Flatten a recursive Dependant model structure. - - This function recursively concatenates the parameter fields of a Dependant model and its dependencies into a flat - Dependant structure. This is useful for scenarios like parameter validation where the nested structure is not - relevant. - - Parameters - ---------- - dependant: Dependant - The dependant model to flatten - skip_repeats: bool - If True, child Dependents already visited will be skipped to avoid duplicates - visited: List[CacheKey], optional - Keeps track of visited Dependents to avoid infinite recursion. Defaults to empty list. - - Returns - ------- - Dependant - The flattened Dependant model - """ - if visited is None: - visited = [] - visited.append(dependant.cache_key) - - flat_dependant = Dependant( - path_params=dependant.path_params.copy(), - query_params=dependant.query_params.copy(), - header_params=dependant.header_params.copy(), - cookie_params=dependant.cookie_params.copy(), - body_params=dependant.body_params.copy(), - path=dependant.path, - ) - for sub_dependant in dependant.dependencies: - if skip_repeats and sub_dependant.cache_key in visited: - continue - - flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited) - - flat_dependant.path_params.extend(flat_sub.path_params) - flat_dependant.query_params.extend(flat_sub.query_params) - flat_dependant.header_params.extend(flat_sub.header_params) - flat_dependant.cookie_params.extend(flat_sub.cookie_params) - flat_dependant.body_params.extend(flat_sub.body_params) - - return flat_dependant +def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: + if is_path_param: + if not is_scalar_field(field=param_field): + raise AssertionError("Path params must be of one of the supported types") + return False + elif is_scalar_field(field=param_field): + return False + elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field): + return False + else: + if not isinstance(param_field.field_info, Body): + raise AssertionError(f"Param: {param_field.name} can only be a request body, using Body()") + return True def get_flat_params(dependant: Dependant) -> List[ModelField]: diff --git a/aws_lambda_powertools/event_handler/openapi/encoders.py b/aws_lambda_powertools/event_handler/openapi/encoders.py new file mode 100644 index 00000000000..d94c7768c62 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/encoders.py @@ -0,0 +1,194 @@ +import dataclasses +import datetime +from collections import defaultdict, deque +from decimal import Decimal +from enum import Enum +from pathlib import Path, PurePath +from re import Pattern +from types import GeneratorType +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from uuid import UUID + +from pydantic import BaseModel +from pydantic.color import Color +from pydantic.types import SecretBytes, SecretStr + +from aws_lambda_powertools.event_handler.openapi.compat import _model_dump +from aws_lambda_powertools.event_handler.openapi.types import IncEx + + +def isoformat(o: Union[datetime.date, datetime.time]) -> str: + return o.isoformat() + + +def decimal_encoder(dec_value: Decimal) -> Union[int, float]: + """ + Encodes a Decimal as int of there's no exponent, otherwise float + + This is useful when we use ConstrainedDecimal to represent Numeric(x,0) + where a integer (but not int typed) is used. Encoding this as a float + results in failed round-tripping between encode and parse. + Our Id type is a prime example of this. + + >>> decimal_encoder(Decimal("1.0")) + 1.0 + + >>> decimal_encoder(Decimal("1")) + 1 + """ + if dec_value.as_tuple().exponent >= 0: # type: ignore[operator] + return int(dec_value) + else: + return float(dec_value) + + +ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { + bytes: lambda o: o.decode(), + Color: str, + datetime.date: isoformat, + datetime.datetime: isoformat, + datetime.time: isoformat, + datetime.timedelta: lambda td: td.total_seconds(), + Decimal: decimal_encoder, + Enum: lambda o: o.value, + frozenset: list, + deque: list, + GeneratorType: list, + Path: str, + Pattern: lambda o: o.pattern, + SecretBytes: str, + SecretStr: str, + set: list, + UUID: str, +} + + +def generate_encoders_by_class_tuples( + type_encoder_map: Dict[Any, Callable[[Any], Any]], +) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]: + encoders: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(tuple) + for type_, encoder in type_encoder_map.items(): + encoders[encoder] += (type_,) + return encoders + + +encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) + + +def jsonable_encoder( # noqa: C901, PLR0911, PLR0912 + obj: Any, + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, +) -> Any: + if include is not None and not isinstance(include, (set, dict)): + include = set(include) + if exclude is not None and not isinstance(exclude, (set, dict)): + exclude = set(exclude) + if isinstance(obj, BaseModel): + obj_dict = _model_dump( + obj, + mode="json", + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + ) + if "__root__" in obj_dict: + obj_dict = obj_dict["__root__"] + + return jsonable_encoder( + obj_dict, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + ) + + if dataclasses.is_dataclass(obj): + obj_dict = dataclasses.asdict(obj) + return jsonable_encoder( + obj_dict, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + if isinstance(obj, Enum): + return obj.value + if isinstance(obj, PurePath): + return str(obj) + if isinstance(obj, (str, int, float, type(None))): + return obj + if isinstance(obj, dict): + encoded_dict = {} + allowed_keys = set(obj.keys()) + if include is not None: + allowed_keys &= set(include) + if exclude is not None: + allowed_keys -= set(exclude) + for key, value in obj.items(): + if ( + (not isinstance(key, str) or not key.startswith("_sa")) + and (value is not None or not exclude_none) + and key in allowed_keys + ): + encoded_key = jsonable_encoder( + key, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + ) + encoded_value = jsonable_encoder( + value, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + ) + encoded_dict[encoded_key] = encoded_value + return encoded_dict + if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)): + encoded_list = [] + for item in obj: + encoded_list.append( + jsonable_encoder( + item, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ), + ) + return encoded_list + + if type(obj) in ENCODERS_BY_TYPE: + return ENCODERS_BY_TYPE[type(obj)](obj) + for encoder, classes_tuple in encoders_by_class_tuples.items(): + if isinstance(obj, classes_tuple): + return encoder(obj) + + try: + data = dict(obj) + except Exception as e: + errors: List[Exception] = [e] + try: + data = vars(obj) + except Exception as e: + errors.append(e) + raise ValueError(errors) from e + return jsonable_encoder( + data, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) diff --git a/aws_lambda_powertools/event_handler/openapi/exceptions.py b/aws_lambda_powertools/event_handler/openapi/exceptions.py new file mode 100644 index 00000000000..6dbed56876c --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/exceptions.py @@ -0,0 +1,15 @@ +from typing import Any, Sequence + + +class ValidationException(Exception): + def __init__(self, errors: Sequence[Any]) -> None: + self._errors = errors + + def errors(self) -> Sequence[Any]: + return self._errors + + +class RequestValidationError(ValidationException): + def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: + super().__init__(errors) + self.body = body diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index fae8fb7fc35..31a04920cb4 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -12,59 +12,18 @@ Undefined, UndefinedType, copy_field_info, + create_body_model, + field_annotation_is_scalar, get_annotation_from_field_info, ) from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2, CacheKey +from aws_lambda_powertools.event_handler.openapi.utils import get_flat_dependant """ This turns the low-level function signature into typed, validated Pydantic models for consumption. """ -class Dependant: - """ - A class used internally to represent a dependency between path operation decorators and the path operation function. - """ - - def __init__( - self, - *, - path_params: Optional[List[ModelField]] = None, - query_params: Optional[List[ModelField]] = None, - header_params: Optional[List[ModelField]] = None, - cookie_params: Optional[List[ModelField]] = None, - body_params: Optional[List[ModelField]] = None, - return_param: Optional[ModelField] = None, - dependencies: Optional[List["Dependant"]] = None, - name: Optional[str] = None, - call: Optional[Callable[..., Any]] = None, - request_param_name: Optional[str] = None, - websocket_param_name: Optional[str] = None, - http_connection_param_name: Optional[str] = None, - response_param_name: Optional[str] = None, - background_tasks_param_name: Optional[str] = None, - path: Optional[str] = None, - ) -> None: - self.path_params = path_params or [] - self.query_params = query_params or [] - self.header_params = header_params or [] - self.cookie_params = cookie_params or [] - self.body_params = body_params or [] - self.return_param = return_param or None - self.dependencies = dependencies or [] - self.request_param_name = request_param_name - self.websocket_param_name = websocket_param_name - self.http_connection_param_name = http_connection_param_name - self.response_param_name = response_param_name - self.background_tasks_param_name = background_tasks_param_name - self.name = name - self.call = call - # Store the path to be able to re-generate a dependable from it in overrides - self.path = path - # Save the cache key at creation to optimize performance - self.cache_key: CacheKey = self.call - - class ParamTypes(Enum): query = "query" header = "header" @@ -297,6 +256,296 @@ def __init__( ) +class Header(Param): + in_ = ParamTypes.header + + def __init__( + self, + default: Any = Undefined, + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + # MAINTENANCE: update when deprecating Pydantic v1, import these types + # str | AliasPath | AliasChoices | None + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + convert_underscores: bool = True, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + self.convert_underscores = convert_underscores + super().__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + deprecated=deprecated, + examples=examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + **extra, + ) + + +class Body(FieldInfo): + def __init__( + self, + default: Any = Undefined, + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + embed: bool = False, + media_type: str = "application/json", + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + # MAINTENANCE: update when deprecating Pydantic v1, import these types + # str | AliasPath | AliasChoices | None + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + self.embed = embed + self.media_type = media_type + self.deprecated = deprecated + self.include_in_schema = include_in_schema + kwargs = dict( + default=default, + default_factory=default_factory, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + discriminator=discriminator, + multiple_of=multiple_of, + allow_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + **extra, + ) + if examples is not None: + kwargs["examples"] = examples + current_json_schema_extra = json_schema_extra or extra + if PYDANTIC_V2: + kwargs.update( + { + "annotation": annotation, + "alias_priority": alias_priority, + "validation_alias": validation_alias, + "serialization_alias": serialization_alias, + "strict": strict, + "json_schema_extra": current_json_schema_extra, + }, + ) + kwargs["pattern"] = pattern + else: + kwargs["regex"] = pattern + kwargs.update(**current_json_schema_extra) + + use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset} + + super().__init__(**use_kwargs) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.default})" + + +class Form(Body): + def __init__( + self, + default: Any = Undefined, + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + media_type: str = "application/x-www-form-urlencoded", + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + # MAINTENANCE: update when deprecating Pydantic v1, import these types + # str | AliasPath | AliasChoices | None + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + super().__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + embed=True, + media_type=media_type, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + deprecated=deprecated, + examples=examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + **extra, + ) + + +class File(Form): + def __init__( + self, + default: Any = Undefined, + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + media_type: str = "multipart/form-data", + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + # MAINTENANCE: update when deprecating Pydantic v1, import these types + # str | AliasPath | AliasChoices | None + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + super().__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + media_type=media_type, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + deprecated=deprecated, + examples=examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + **extra, + ) + + def analyze_param( *, param_name: str, @@ -334,6 +583,9 @@ def analyze_param( raise AssertionError("Cannot use a FieldInfo as a parameter annotation and pass a FieldInfo as a value") field_info = value + if PYDANTIC_V2: + field_info.annotation = type_annotation # type: ignore + # If we didn't determine the FieldInfo yet, we create a default one if field_info is None: default_value = value if value is not inspect.Signature.empty else Required @@ -341,6 +593,8 @@ def analyze_param( # Check if the parameter is part of the path. Otherwise, defaults to query. if is_path_param: field_info = Path(annotation=type_annotation) + elif not field_annotation_is_scalar(annotation=type_annotation): + field_info = Body(annotation=type_annotation, default=default_value) else: field_info = Query(annotation=type_annotation, default=default_value) @@ -397,6 +651,99 @@ def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tu return field_info, type_annotation +class Dependant: + """ + A class used internally to represent a dependency between path operation decorators and the path operation function. + """ + + def __init__( + self, + *, + path_params: Optional[List[ModelField]] = None, + query_params: Optional[List[ModelField]] = None, + header_params: Optional[List[ModelField]] = None, + cookie_params: Optional[List[ModelField]] = None, + body_params: Optional[List[ModelField]] = None, + return_param: Optional[ModelField] = None, + dependencies: Optional[List["Dependant"]] = None, + name: Optional[str] = None, + call: Optional[Callable[..., Any]] = None, + request_param_name: Optional[str] = None, + websocket_param_name: Optional[str] = None, + http_connection_param_name: Optional[str] = None, + response_param_name: Optional[str] = None, + background_tasks_param_name: Optional[str] = None, + path: Optional[str] = None, + ) -> None: + self.path_params = path_params or [] + self.query_params = query_params or [] + self.header_params = header_params or [] + self.cookie_params = cookie_params or [] + self.body_params = body_params or [] + self.return_param = return_param or None + self.dependencies = dependencies or [] + self.request_param_name = request_param_name + self.websocket_param_name = websocket_param_name + self.http_connection_param_name = http_connection_param_name + self.response_param_name = response_param_name + self.background_tasks_param_name = background_tasks_param_name + self.name = name + self.call = call + # Store the path to be able to re-generate a dependable from it in overrides + self.path = path + # Save the cache key at creation to optimize performance + self.cache_key: CacheKey = self.call + + +def _get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: + flat_dependant = get_flat_dependant(dependant) + if not flat_dependant.body_params: + return None + + first_param = flat_dependant.body_params[0] + field_info = first_param.field_info + embed = getattr(field_info, "embed", None) + body_param_names_set = {param.name for param in flat_dependant.body_params} + if len(body_param_names_set) == 1 and not embed: + return first_param + + # If one field requires to embed, all have to be embedded + for param in flat_dependant.body_params: + setattr(param.field_info, "embed", True) # noqa: B010 + + model_name = "Body_" + name + body_model = create_body_model(fields=flat_dependant.body_params, model_name=model_name) + + required = any(True for f in flat_dependant.body_params if f.required) + + body_field_info_kwargs: Dict[str, Any] = {"annotation": body_model, "alias": "body"} + + if not required: + body_field_info_kwargs["default"] = None + + if any(isinstance(f.field_info, File) for f in flat_dependant.body_params): + body_field_info: Type[Body] = File + elif any(isinstance(f.field_info, Form) for f in flat_dependant.body_params): + body_field_info = Form + else: + body_field_info = Body + + body_param_media_types = [ + f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, Body) + ] + if len(set(body_param_media_types)) == 1: + body_field_info_kwargs["media_type"] = body_param_media_types[0] + + final_field = _create_response_field( + name="body", + type_=body_model, + required=required, + alias="body", + field_info=body_field_info(**body_field_info_kwargs), + ) + return final_field + + def _create_response_field( name: str, type_: Type[Any], diff --git a/aws_lambda_powertools/event_handler/openapi/types.py b/aws_lambda_powertools/event_handler/openapi/types.py index bc994e7cfc9..ec2114d12a5 100644 --- a/aws_lambda_powertools/event_handler/openapi/types.py +++ b/aws_lambda_powertools/event_handler/openapi/types.py @@ -14,3 +14,31 @@ PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") COMPONENT_REF_PREFIX = "#/components/schemas/" COMPONENT_REF_TEMPLATE = "#/components/schemas/{model}" +METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"} + +validation_error_definition = { + "title": "ValidationError", + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + "required": ["loc", "msg", "type"], +} + +validation_error_response_definition = { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": COMPONENT_REF_PREFIX + "ValidationError"}, + }, + }, +} diff --git a/aws_lambda_powertools/event_handler/openapi/utils.py b/aws_lambda_powertools/event_handler/openapi/utils.py new file mode 100644 index 00000000000..0c50ae6ea1b --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/utils.py @@ -0,0 +1,62 @@ +from typing import TYPE_CHECKING, List, Optional + +from aws_lambda_powertools.event_handler.openapi.types import CacheKey + +if TYPE_CHECKING: + from aws_lambda_powertools.event_handler.openapi.dependant import Dependant + + +def get_flat_dependant( + dependant: "Dependant", + *, + skip_repeats: bool = False, + visited: Optional[List[CacheKey]] = None, +) -> "Dependant": + """ + Flatten a recursive Dependant model structure. + + This function recursively concatenates the parameter fields of a Dependant model and its dependencies into a flat + Dependant structure. This is useful for scenarios like parameter validation where the nested structure is not + relevant. + + Parameters + ---------- + dependant: Dependant + The dependant model to flatten + skip_repeats: bool + If True, child Dependents already visited will be skipped to avoid duplicates + visited: List[CacheKey], optional + Keeps track of visited Dependents to avoid infinite recursion. Defaults to empty list. + + Returns + ------- + Dependant + The flattened Dependant model + """ + if visited is None: + visited = [] + visited.append(dependant.cache_key) + + from aws_lambda_powertools.event_handler.openapi.dependant import Dependant + + flat_dependant = Dependant( + path_params=dependant.path_params.copy(), + query_params=dependant.query_params.copy(), + header_params=dependant.header_params.copy(), + cookie_params=dependant.cookie_params.copy(), + body_params=dependant.body_params.copy(), + path=dependant.path, + ) + for sub_dependant in dependant.dependencies: + if skip_repeats and sub_dependant.cache_key in visited: + continue + + flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited) + + flat_dependant.path_params.extend(flat_sub.path_params) + flat_dependant.query_params.extend(flat_sub.query_params) + flat_dependant.header_params.extend(flat_sub.header_params) + flat_dependant.cookie_params.extend(flat_sub.cookie_params) + flat_dependant.body_params.extend(flat_sub.body_params) + + return flat_dependant From 9b8ce4a829fd6f89262725778473c8727106d560 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 9 Oct 2023 16:15:59 +0200 Subject: [PATCH 26/75] fix: uniontype --- aws_lambda_powertools/event_handler/openapi/compat.py | 2 +- aws_lambda_powertools/event_handler/openapi/types.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index 78f50f20b7b..a76583ab76d 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -8,7 +8,6 @@ from dataclasses import dataclass, is_dataclass from enum import Enum -from types import UnionType from typing import Any, Dict, List, Set, Tuple, Type, Union, FrozenSet, Deque, Sequence, Mapping from typing_extensions import Annotated, Literal, get_origin, get_args @@ -20,6 +19,7 @@ COMPONENT_REF_PREFIX, PYDANTIC_V2, ModelNameMap, + UnionType, ) sequence_annotation_to_type = { diff --git a/aws_lambda_powertools/event_handler/openapi/types.py b/aws_lambda_powertools/event_handler/openapi/types.py index ec2114d12a5..ae3fc53ef79 100644 --- a/aws_lambda_powertools/event_handler/openapi/types.py +++ b/aws_lambda_powertools/event_handler/openapi/types.py @@ -1,3 +1,4 @@ +import types from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type, Union @@ -10,6 +11,7 @@ IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]] ModelNameMap = Dict[Union[Type["BaseModel"], Type[Enum]], str] TypeModelOrEnum = Union[Type["BaseModel"], Type[Enum]] +UnionType = getattr(types, "UnionType", Union) PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") COMPONENT_REF_PREFIX = "#/components/schemas/" From c3f25f8b5685cd928c5350775af8fa9cf044bd97 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 9 Oct 2023 16:31:31 +0200 Subject: [PATCH 27/75] fix: types --- aws_lambda_powertools/event_handler/openapi/params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 31a04920cb4..a45d3b25e03 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -584,7 +584,7 @@ def analyze_param( field_info = value if PYDANTIC_V2: - field_info.annotation = type_annotation # type: ignore + field_info.annotation = type_annotation # type: ignore[attr-defined] # If we didn't determine the FieldInfo yet, we create a default one if field_info is None: From 13ccd5fb3c76b33266436ce4c270273a420e1002 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 9 Oct 2023 17:19:36 +0200 Subject: [PATCH 28/75] fix: ignore unused-ignore --- aws_lambda_powertools/event_handler/openapi/params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index a45d3b25e03..59457dbac48 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -584,7 +584,7 @@ def analyze_param( field_info = value if PYDANTIC_V2: - field_info.annotation = type_annotation # type: ignore[attr-defined] + field_info.annotation = type_annotation # type: ignore[attr-defined,unused-ignore] # If we didn't determine the FieldInfo yet, we create a default one if field_info is None: From cf1b86656c1c68cbba2547e9e246fc5a84d54e07 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 9 Oct 2023 17:57:46 +0200 Subject: [PATCH 29/75] fix: moved things around --- .../event_handler/api_gateway.py | 2 +- .../event_handler/openapi/compat.py | 11 +- .../event_handler/openapi/dependant.py | 57 ++++- .../event_handler/openapi/encoders.py | 17 +- .../event_handler/openapi/exceptions.py | 8 + .../event_handler/openapi/params.py | 230 ++++++++++-------- .../event_handler/openapi/types.py | 12 +- .../event_handler/openapi/utils.py | 62 ----- 8 files changed, 227 insertions(+), 172 deletions(-) delete mode 100644 aws_lambda_powertools/event_handler/openapi/utils.py diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 4dbdd064b61..bea3126c3a8 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -381,7 +381,7 @@ def dependant(self) -> "Dependant": @property def body_field(self) -> Optional["ModelField"]: if self._body_field is None: - from aws_lambda_powertools.event_handler.openapi.params import _get_body_field + from aws_lambda_powertools.event_handler.openapi.dependant import _get_body_field self._body_field = _get_body_field(dependant=self.dependant, name=self.operation_id) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index a76583ab76d..f7d558af880 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -201,8 +201,8 @@ def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: assert issubclass(origin_type, sequence_types) # type: ignore[arg-type] return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return] - def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: - return model.dict(**kwargs) + def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: + return errors # type: ignore[return-value] def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> Type[BaseModel]: field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields} @@ -390,9 +390,6 @@ def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> Type[ def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: return sequence_shape_to_type[field.shape](value) - def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: - return model.dict(**kwargs) - # Common code for both versions @@ -485,3 +482,7 @@ def _regenerate_error_with_loc( ] return updated_loc_errors + + +def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: + return model.dict(**kwargs) diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index b7770f7bc9d..f1172e85b36 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -1,9 +1,10 @@ import inspect import re -from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, cast +from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Type, cast from aws_lambda_powertools.event_handler.openapi.compat import ( ModelField, + create_body_model, evaluate_forwardref, is_scalar_field, is_scalar_sequence_field, @@ -11,13 +12,16 @@ from aws_lambda_powertools.event_handler.openapi.params import ( Body, Dependant, + File, + Form, Header, Param, ParamTypes, Query, + _create_response_field, analyze_param, + get_flat_dependant, ) -from aws_lambda_powertools.event_handler.openapi.utils import get_flat_dependant """ This turns the opaque function signature into typed, validated models. @@ -243,3 +247,52 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]: + flat_dependant.header_params + flat_dependant.cookie_params ) + + +def _get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: + flat_dependant = get_flat_dependant(dependant) + if not flat_dependant.body_params: + return None + + first_param = flat_dependant.body_params[0] + field_info = first_param.field_info + embed = getattr(field_info, "embed", None) + body_param_names_set = {param.name for param in flat_dependant.body_params} + if len(body_param_names_set) == 1 and not embed: + return first_param + + # If one field requires to embed, all have to be embedded + for param in flat_dependant.body_params: + setattr(param.field_info, "embed", True) # noqa: B010 + + model_name = "Body_" + name + body_model = create_body_model(fields=flat_dependant.body_params, model_name=model_name) + + required = any(True for f in flat_dependant.body_params if f.required) + + body_field_info_kwargs: Dict[str, Any] = {"annotation": body_model, "alias": "body"} + + if not required: + body_field_info_kwargs["default"] = None + + if any(isinstance(f.field_info, File) for f in flat_dependant.body_params): + body_field_info: Type[Body] = File + elif any(isinstance(f.field_info, Form) for f in flat_dependant.body_params): + body_field_info = Form + else: + body_field_info = Body + + body_param_media_types = [ + f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, Body) + ] + if len(set(body_param_media_types)) == 1: + body_field_info_kwargs["media_type"] = body_param_media_types[0] + + final_field = _create_response_field( + name="body", + type_=body_model, + required=required, + alias="body", + field_info=body_field_info(**body_field_info_kwargs), + ) + return final_field diff --git a/aws_lambda_powertools/event_handler/openapi/encoders.py b/aws_lambda_powertools/event_handler/openapi/encoders.py index d94c7768c62..11a76e29e17 100644 --- a/aws_lambda_powertools/event_handler/openapi/encoders.py +++ b/aws_lambda_powertools/event_handler/openapi/encoders.py @@ -17,7 +17,10 @@ from aws_lambda_powertools.event_handler.openapi.types import IncEx -def isoformat(o: Union[datetime.date, datetime.time]) -> str: +def iso_format(o: Union[datetime.date, datetime.time]) -> str: + """ + ISO format for date and time + """ return o.isoformat() @@ -42,12 +45,13 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]: return float(dec_value) +# Encoders for types that are not JSON serializable ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { bytes: lambda o: o.decode(), Color: str, - datetime.date: isoformat, - datetime.datetime: isoformat, - datetime.time: isoformat, + datetime.date: iso_format, + datetime.datetime: iso_format, + datetime.time: iso_format, datetime.timedelta: lambda td: td.total_seconds(), Decimal: decimal_encoder, Enum: lambda o: o.value, @@ -63,6 +67,7 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]: } +# Generates a mapping of encoders to a tuple of classes that they can encode def generate_encoders_by_class_tuples( type_encoder_map: Dict[Any, Callable[[Any], Any]], ) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]: @@ -72,6 +77,7 @@ def generate_encoders_by_class_tuples( return encoders +# Mapping of encoders to a tuple of classes that they can encode encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) @@ -84,6 +90,9 @@ def jsonable_encoder( # noqa: C901, PLR0911, PLR0912 exclude_defaults: bool = False, exclude_none: bool = False, ) -> Any: + """ + JSON encodes an arbitrary Python object into JSON serializable data types. + """ if include is not None and not isinstance(include, (set, dict)): include = set(include) if exclude is not None and not isinstance(exclude, (set, dict)): diff --git a/aws_lambda_powertools/event_handler/openapi/exceptions.py b/aws_lambda_powertools/event_handler/openapi/exceptions.py index 6dbed56876c..fdd829ba9b1 100644 --- a/aws_lambda_powertools/event_handler/openapi/exceptions.py +++ b/aws_lambda_powertools/event_handler/openapi/exceptions.py @@ -2,6 +2,10 @@ class ValidationException(Exception): + """ + Base exception for all validation errors + """ + def __init__(self, errors: Sequence[Any]) -> None: self._errors = errors @@ -10,6 +14,10 @@ def errors(self) -> Sequence[Any]: class RequestValidationError(ValidationException): + """ + Raised when the request body does not match the OpenAPI schema + """ + def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: super().__init__(errors) self.body = body diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 59457dbac48..1e506d88820 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -12,12 +12,10 @@ Undefined, UndefinedType, copy_field_info, - create_body_model, field_annotation_is_scalar, get_annotation_from_field_info, ) from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2, CacheKey -from aws_lambda_powertools.event_handler.openapi.utils import get_flat_dependant """ This turns the low-level function signature into typed, validated Pydantic models for consumption. @@ -35,7 +33,55 @@ class ParamTypes(Enum): _Unset: Any = Undefined +class Dependant: + """ + A class used internally to represent a dependency between path operation decorators and the path operation function. + """ + + def __init__( + self, + *, + path_params: Optional[List[ModelField]] = None, + query_params: Optional[List[ModelField]] = None, + header_params: Optional[List[ModelField]] = None, + cookie_params: Optional[List[ModelField]] = None, + body_params: Optional[List[ModelField]] = None, + return_param: Optional[ModelField] = None, + dependencies: Optional[List["Dependant"]] = None, + name: Optional[str] = None, + call: Optional[Callable[..., Any]] = None, + request_param_name: Optional[str] = None, + websocket_param_name: Optional[str] = None, + http_connection_param_name: Optional[str] = None, + response_param_name: Optional[str] = None, + background_tasks_param_name: Optional[str] = None, + path: Optional[str] = None, + ) -> None: + self.path_params = path_params or [] + self.query_params = query_params or [] + self.header_params = header_params or [] + self.cookie_params = cookie_params or [] + self.body_params = body_params or [] + self.return_param = return_param or None + self.dependencies = dependencies or [] + self.request_param_name = request_param_name + self.websocket_param_name = websocket_param_name + self.http_connection_param_name = http_connection_param_name + self.response_param_name = response_param_name + self.background_tasks_param_name = background_tasks_param_name + self.name = name + self.call = call + # Store the path to be able to re-generate a dependable from it in overrides + self.path = path + # Save the cache key at creation to optimize performance + self.cache_key: CacheKey = self.call + + class Param(FieldInfo): + """ + A class used internally to represent a parameter in a path operation. + """ + in_: ParamTypes def __init__( @@ -122,6 +168,10 @@ def __repr__(self) -> str: class Path(Param): + """ + A class used internally to represent a path parameter in a path operation. + """ + in_ = ParamTypes.path def __init__( @@ -192,6 +242,10 @@ def __init__( class Query(Param): + """ + A class used internally to represent a query parameter in a path operation. + """ + in_ = ParamTypes.query def __init__( @@ -257,6 +311,10 @@ def __init__( class Header(Param): + """ + A class used internally to represent a header parameter in a path operation. + """ + in_ = ParamTypes.header def __init__( @@ -326,6 +384,10 @@ def __init__( class Body(FieldInfo): + """ + A class used internally to represent a body parameter in a path operation. + """ + def __init__( self, default: Any = Undefined, @@ -412,6 +474,10 @@ def __repr__(self) -> str: class Form(Body): + """ + A class used internally to represent a form parameter in a path operation. + """ + def __init__( self, default: Any = Undefined, @@ -480,6 +546,10 @@ def __init__( class File(Form): + """ + A class used internally to represent a file parameter in a path operation. + """ + def __init__( self, default: Any = Undefined, @@ -546,6 +616,60 @@ def __init__( ) +def get_flat_dependant( + dependant: Dependant, + *, + skip_repeats: bool = False, + visited: Optional[List[CacheKey]] = None, +) -> Dependant: + """ + Flatten a recursive Dependant model structure. + + This function recursively concatenates the parameter fields of a Dependant model and its dependencies into a flat + Dependant structure. This is useful for scenarios like parameter validation where the nested structure is not + relevant. + + Parameters + ---------- + dependant: Dependant + The dependant model to flatten + skip_repeats: bool + If True, child Dependents already visited will be skipped to avoid duplicates + visited: List[CacheKey], optional + Keeps track of visited Dependents to avoid infinite recursion. Defaults to empty list. + + Returns + ------- + Dependant + The flattened Dependant model + """ + if visited is None: + visited = [] + visited.append(dependant.cache_key) + + flat_dependant = Dependant( + path_params=dependant.path_params.copy(), + query_params=dependant.query_params.copy(), + header_params=dependant.header_params.copy(), + cookie_params=dependant.cookie_params.copy(), + body_params=dependant.body_params.copy(), + path=dependant.path, + ) + for sub_dependant in dependant.dependencies: + if skip_repeats and sub_dependant.cache_key in visited: + continue + + flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited) + + flat_dependant.path_params.extend(flat_sub.path_params) + flat_dependant.query_params.extend(flat_sub.query_params) + flat_dependant.header_params.extend(flat_sub.header_params) + flat_dependant.cookie_params.extend(flat_sub.cookie_params) + flat_dependant.body_params.extend(flat_sub.body_params) + + return flat_dependant + + def analyze_param( *, param_name: str, @@ -607,6 +731,9 @@ def analyze_param( def _get_field_info_and_type_annotation(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]: + """ + Get the FieldInfo and type annotation from an annotation and value. + """ field_info: Optional[FieldInfo] = None type_annotation: Any = Any @@ -622,6 +749,9 @@ def _get_field_info_and_type_annotation(annotation, value, is_path_param: bool) def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]: + """ + Get the FieldInfo and type annotation from an Annotated type. + """ field_info: Optional[FieldInfo] = None annotated_args = get_args(annotation) type_annotation = annotated_args[0] @@ -651,99 +781,6 @@ def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tu return field_info, type_annotation -class Dependant: - """ - A class used internally to represent a dependency between path operation decorators and the path operation function. - """ - - def __init__( - self, - *, - path_params: Optional[List[ModelField]] = None, - query_params: Optional[List[ModelField]] = None, - header_params: Optional[List[ModelField]] = None, - cookie_params: Optional[List[ModelField]] = None, - body_params: Optional[List[ModelField]] = None, - return_param: Optional[ModelField] = None, - dependencies: Optional[List["Dependant"]] = None, - name: Optional[str] = None, - call: Optional[Callable[..., Any]] = None, - request_param_name: Optional[str] = None, - websocket_param_name: Optional[str] = None, - http_connection_param_name: Optional[str] = None, - response_param_name: Optional[str] = None, - background_tasks_param_name: Optional[str] = None, - path: Optional[str] = None, - ) -> None: - self.path_params = path_params or [] - self.query_params = query_params or [] - self.header_params = header_params or [] - self.cookie_params = cookie_params or [] - self.body_params = body_params or [] - self.return_param = return_param or None - self.dependencies = dependencies or [] - self.request_param_name = request_param_name - self.websocket_param_name = websocket_param_name - self.http_connection_param_name = http_connection_param_name - self.response_param_name = response_param_name - self.background_tasks_param_name = background_tasks_param_name - self.name = name - self.call = call - # Store the path to be able to re-generate a dependable from it in overrides - self.path = path - # Save the cache key at creation to optimize performance - self.cache_key: CacheKey = self.call - - -def _get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: - flat_dependant = get_flat_dependant(dependant) - if not flat_dependant.body_params: - return None - - first_param = flat_dependant.body_params[0] - field_info = first_param.field_info - embed = getattr(field_info, "embed", None) - body_param_names_set = {param.name for param in flat_dependant.body_params} - if len(body_param_names_set) == 1 and not embed: - return first_param - - # If one field requires to embed, all have to be embedded - for param in flat_dependant.body_params: - setattr(param.field_info, "embed", True) # noqa: B010 - - model_name = "Body_" + name - body_model = create_body_model(fields=flat_dependant.body_params, model_name=model_name) - - required = any(True for f in flat_dependant.body_params if f.required) - - body_field_info_kwargs: Dict[str, Any] = {"annotation": body_model, "alias": "body"} - - if not required: - body_field_info_kwargs["default"] = None - - if any(isinstance(f.field_info, File) for f in flat_dependant.body_params): - body_field_info: Type[Body] = File - elif any(isinstance(f.field_info, Form) for f in flat_dependant.body_params): - body_field_info = Form - else: - body_field_info = Body - - body_param_media_types = [ - f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, Body) - ] - if len(set(body_param_media_types)) == 1: - body_field_info_kwargs["media_type"] = body_param_media_types[0] - - final_field = _create_response_field( - name="body", - type_=body_model, - required=required, - alias="body", - field_info=body_field_info(**body_field_info_kwargs), - ) - return final_field - - def _create_response_field( name: str, type_: Type[Any], @@ -788,6 +825,9 @@ def _create_model_field( param_name: str, is_path_param: bool, ) -> Optional[ModelField]: + """ + Create a new ModelField from a FieldInfo and type annotation. + """ if field_info is None: return None diff --git a/aws_lambda_powertools/event_handler/openapi/types.py b/aws_lambda_powertools/event_handler/openapi/types.py index ae3fc53ef79..9161d8dc170 100644 --- a/aws_lambda_powertools/event_handler/openapi/types.py +++ b/aws_lambda_powertools/event_handler/openapi/types.py @@ -2,8 +2,6 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type, Union -from pydantic.version import VERSION as PYDANTIC_VERSION - if TYPE_CHECKING: from pydantic import BaseModel # noqa: F401 @@ -13,11 +11,19 @@ TypeModelOrEnum = Union[Type["BaseModel"], Type[Enum]] UnionType = getattr(types, "UnionType", Union) -PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + COMPONENT_REF_PREFIX = "#/components/schemas/" COMPONENT_REF_TEMPLATE = "#/components/schemas/{model}" METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"} +try: + from pydantic.version import VERSION as PYDANTIC_VERSION + + PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") +except ImportError: + PYDANTIC_V2 = False + + validation_error_definition = { "title": "ValidationError", "type": "object", diff --git a/aws_lambda_powertools/event_handler/openapi/utils.py b/aws_lambda_powertools/event_handler/openapi/utils.py deleted file mode 100644 index 0c50ae6ea1b..00000000000 --- a/aws_lambda_powertools/event_handler/openapi/utils.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import TYPE_CHECKING, List, Optional - -from aws_lambda_powertools.event_handler.openapi.types import CacheKey - -if TYPE_CHECKING: - from aws_lambda_powertools.event_handler.openapi.dependant import Dependant - - -def get_flat_dependant( - dependant: "Dependant", - *, - skip_repeats: bool = False, - visited: Optional[List[CacheKey]] = None, -) -> "Dependant": - """ - Flatten a recursive Dependant model structure. - - This function recursively concatenates the parameter fields of a Dependant model and its dependencies into a flat - Dependant structure. This is useful for scenarios like parameter validation where the nested structure is not - relevant. - - Parameters - ---------- - dependant: Dependant - The dependant model to flatten - skip_repeats: bool - If True, child Dependents already visited will be skipped to avoid duplicates - visited: List[CacheKey], optional - Keeps track of visited Dependents to avoid infinite recursion. Defaults to empty list. - - Returns - ------- - Dependant - The flattened Dependant model - """ - if visited is None: - visited = [] - visited.append(dependant.cache_key) - - from aws_lambda_powertools.event_handler.openapi.dependant import Dependant - - flat_dependant = Dependant( - path_params=dependant.path_params.copy(), - query_params=dependant.query_params.copy(), - header_params=dependant.header_params.copy(), - cookie_params=dependant.cookie_params.copy(), - body_params=dependant.body_params.copy(), - path=dependant.path, - ) - for sub_dependant in dependant.dependencies: - if skip_repeats and sub_dependant.cache_key in visited: - continue - - flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited) - - flat_dependant.path_params.extend(flat_sub.path_params) - flat_dependant.query_params.extend(flat_sub.query_params) - flat_dependant.header_params.extend(flat_sub.header_params) - flat_dependant.cookie_params.extend(flat_sub.cookie_params) - flat_dependant.body_params.extend(flat_sub.body_params) - - return flat_dependant From f4d944678d1e4fe91f67184d1eea53a875e1ee47 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 9 Oct 2023 18:08:25 +0200 Subject: [PATCH 30/75] fix: compatibility with pydantic v2 --- .../event_handler/api_gateway.py | 31 ++++++++++++------- .../event_handler/openapi/compat.py | 16 +++++++--- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index bea3126c3a8..ee5238365a2 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1389,18 +1389,25 @@ def get_openapi_json_schema( str The OpenAPI schema as a JSON serializable dict. """ - return self.get_openapi_schema( - title=title, - version=version, - openapi_version=openapi_version, - summary=summary, - description=description, - tags=tags, - servers=servers, - terms_of_service=terms_of_service, - contact=contact, - license_info=license_info, - ).json(by_alias=True, exclude_none=True, indent=2) + from aws_lambda_powertools.event_handler.openapi.compat import model_json + + return model_json( + self.get_openapi_schema( + title=title, + version=version, + openapi_version=openapi_version, + summary=summary, + description=description, + tags=tags, + servers=servers, + terms_of_service=terms_of_service, + contact=contact, + license_info=license_info, + ), + by_alias=True, + exclude_none=True, + indent=2, + ) def route( self, diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index f7d558af880..146441329fe 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -209,6 +209,12 @@ def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> Type[ model: Type[BaseModel] = create_model(model_name, **field_params) return model + def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: + return model.model_dump(mode=mode, **kwargs) + + def model_json(model: BaseModel, **kwargs: Any) -> Any: + return model.model_dump_json(**kwargs) + else: from pydantic import BaseModel, ValidationError from pydantic.fields import ( @@ -390,6 +396,12 @@ def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> Type[ def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: return sequence_shape_to_type[field.shape](value) + def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: + return model.dict(**kwargs) + + def model_json(model: BaseModel, **kwargs: Any) -> Any: + return model.json(**kwargs) + # Common code for both versions @@ -482,7 +494,3 @@ def _regenerate_error_with_loc( ] return updated_loc_errors - - -def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: - return model.dict(**kwargs) From 24a98187ff35680ad297aa806d7bd0e22b424ce0 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 9 Oct 2023 20:19:08 +0200 Subject: [PATCH 31/75] chore: add tests on the body request --- .../event_handler/test_openapi_params.py | 58 ++++++++++++++++++- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index a2f444bcacc..d59e7e04de6 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -12,7 +12,7 @@ ParameterInType, Schema, ) -from aws_lambda_powertools.event_handler.openapi.params import Query +from aws_lambda_powertools.event_handler.openapi.params import Body, Query JSON_CONTENT_TYPE = "application/json" @@ -101,6 +101,7 @@ def handler( Query(gt=0, lt=100, examples=[Example(summary="Example 1", value=10)]), ] = 1, ): + print(count) raise NotImplementedError() schema = app.get_openapi_schema() @@ -205,7 +206,7 @@ class User: @app.get("/") def handler() -> User: - return User(name="Ruben Fonseca") + return User(surname="Fonseca") schema = app.get_openapi_schema() assert len(schema.paths.keys()) == 1 @@ -222,3 +223,56 @@ def handler() -> User: assert isinstance(user_schema, Schema) assert user_schema.title == "User" assert "surname" in user_schema.properties + + +def test_openapi_with_body_param(): + app = ApiGatewayResolver() + + class User(BaseModel): + name: str + + @app.post("/users") + def handler(user: User): + print(user) + pass + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + post = schema.paths["/users"].post + assert post.parameters is None + assert post.requestBody is not None + + request_body = post.requestBody + assert request_body.required is True + assert request_body.content[JSON_CONTENT_TYPE].schema_.ref == "#/components/schemas/User" + + +def test_openapi_with_embed_body_param(): + app = ApiGatewayResolver() + + class User(BaseModel): + name: str + + @app.post("/users") + def handler(user: Annotated[User, Body(embed=True)]): + print(user) + pass + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + post = schema.paths["/users"].post + assert post.parameters is None + assert post.requestBody is not None + + request_body = post.requestBody + assert request_body.required is True + # Notice here we craft a specific schema for the embedded user + assert request_body.content[JSON_CONTENT_TYPE].schema_.ref == "#/components/schemas/Body_PostHandler" + + # Ensure that the custom body schema actually points to the real user class + components = schema.components + assert "Body_PostHandler" in components.schemas + body_posthandler_schema = components.schemas["Body_PostHandler"] + assert body_posthandler_schema.properties["user"].ref == "#/components/schemas/User" From d17cc646f846e5262eb01d208dd1bd328091fe7e Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 9 Oct 2023 20:35:35 +0200 Subject: [PATCH 32/75] chore: add tests for validation middleware --- .../test_openapi_validation_middleware.py | 167 ++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 tests/functional/event_handler/test_openapi_validation_middleware.py diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py new file mode 100644 index 00000000000..064cb6b29db --- /dev/null +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -0,0 +1,167 @@ +import json +from typing import Annotated + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler import ApiGatewayResolver +from aws_lambda_powertools.event_handler.openapi.params import Body +from tests.functional.utils import load_event + +LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") + + +def test_validate_scalars(): + app = ApiGatewayResolver(enable_validation=True) + + @app.get("/users/") + def handler(user_id: int): + print(user_id) + + # sending a number + LOAD_GW_EVENT["path"] = "/users/123" + + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + + # sending a string + LOAD_GW_EVENT["path"] = "/users/abc" + + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert "Input should be a valid integer, unable to parse string as an integer" in result["body"] + + +def test_validate_scalars_with_default(): + app = ApiGatewayResolver(enable_validation=True) + + @app.get("/users/") + def handler(user_id: int = 123): + print(user_id) + + # sending a number + LOAD_GW_EVENT["path"] = "/users/123" + + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + + # sending a string + LOAD_GW_EVENT["path"] = "/users/abc" + + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert "Input should be a valid integer, unable to parse string as an integer" in result["body"] + + +def test_validate_scalars_with_default_and_optional(): + app = ApiGatewayResolver(enable_validation=True) + + @app.get("/users/") + def handler(user_id: int = 123, include_extra: bool = False): + print(user_id) + + # sending a number + LOAD_GW_EVENT["path"] = "/users/123" + + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + + # sending a string + LOAD_GW_EVENT["path"] = "/users/abc" + + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert "Input should be a valid integer, unable to parse string as an integer" in result["body"] + + +def test_validate_return_type(): + app = ApiGatewayResolver(enable_validation=True) + + @app.get("/") + def handler() -> int: + return 123 + + LOAD_GW_EVENT["path"] = "/" + + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == 123 + + +def test_validate_return_model(): + app = ApiGatewayResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get("/") + def handler() -> Model: + return Model(name="John", age=30) + + LOAD_GW_EVENT["path"] = "/" + + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == {"name": "John", "age": 30} + + +def test_validate_invalid_return_model(): + app = ApiGatewayResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get("/") + def handler() -> Model: + return {"name": "John"} # type: ignore + + LOAD_GW_EVENT["path"] = "/" + + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert "Field required" in result["body"] + + +def test_validate_body_param(): + app = ApiGatewayResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.post("/") + def handler(user: Model) -> Model: + return user + + LOAD_GW_EVENT["httpMethod"] = "POST" + LOAD_GW_EVENT["path"] = "/" + LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == {"name": "John", "age": 30} + + +def test_validate_embed_body_param(): + app = ApiGatewayResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.post("/") + def handler(user: Annotated[Model, Body(embed=True)]) -> Model: + return user + + LOAD_GW_EVENT["httpMethod"] = "POST" + LOAD_GW_EVENT["path"] = "/" + LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert "Field required" in result["body"] + + LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}}) + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 From 280abf5e14e2f9a2553ebfe15f6d625ec30e4806 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 9 Oct 2023 20:40:48 +0200 Subject: [PATCH 33/75] fix: assorted fixes --- tests/functional/event_handler/test_openapi_params.py | 2 -- .../event_handler/test_openapi_validation_middleware.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index d59e7e04de6..fd0c2c2b2c7 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -234,7 +234,6 @@ class User(BaseModel): @app.post("/users") def handler(user: User): print(user) - pass schema = app.get_openapi_schema() assert len(schema.paths.keys()) == 1 @@ -257,7 +256,6 @@ class User(BaseModel): @app.post("/users") def handler(user: Annotated[User, Body(embed=True)]): print(user) - pass schema = app.get_openapi_schema() assert len(schema.paths.keys()) == 1 diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index 064cb6b29db..0d686c06698 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -1,7 +1,7 @@ import json -from typing import Annotated from pydantic import BaseModel +from typing_extensions import Annotated from aws_lambda_powertools.event_handler import ApiGatewayResolver from aws_lambda_powertools.event_handler.openapi.params import Body From 1bb73c67d30e2748968c7634be4b69e1f87f60ce Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 9 Oct 2023 20:53:09 +0200 Subject: [PATCH 34/75] fix: make tests pass in both pydantic versions --- .../test_openapi_validation_middleware.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index 0d686c06698..b34d3f7c9d9 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -28,7 +28,7 @@ def handler(user_id: int): result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 422 - assert "Input should be a valid integer, unable to parse string as an integer" in result["body"] + assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) def test_validate_scalars_with_default(): @@ -49,7 +49,7 @@ def handler(user_id: int = 123): result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 422 - assert "Input should be a valid integer, unable to parse string as an integer" in result["body"] + assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) def test_validate_scalars_with_default_and_optional(): @@ -70,7 +70,7 @@ def handler(user_id: int = 123, include_extra: bool = False): result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 422 - assert "Input should be a valid integer, unable to parse string as an integer" in result["body"] + assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) def test_validate_return_type(): @@ -120,7 +120,7 @@ def handler() -> Model: result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 422 - assert "Field required" in result["body"] + assert "missing" in result["body"] def test_validate_body_param(): @@ -160,7 +160,7 @@ def handler(user: Annotated[Model, Body(embed=True)]) -> Model: result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 422 - assert "Field required" in result["body"] + assert "missing" in result["body"] LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}}) result = app(LOAD_GW_EVENT, {}) From a559ed6b7e2087c60de66a847397336107095272 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 9 Oct 2023 20:57:37 +0200 Subject: [PATCH 35/75] fix: remove assert --- aws_lambda_powertools/event_handler/openapi/compat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index 146441329fe..54b78f7e5f6 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -198,7 +198,8 @@ def is_bytes_sequence_field(field: ModelField) -> bool: def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: origin_type = get_origin(field.field_info.annotation) or field.field_info.annotation - assert issubclass(origin_type, sequence_types) # type: ignore[arg-type] + if not issubclass(origin_type, sequence_types): # type: ignore[arg-type] + raise AssertionError(f"Expected sequence type, got {origin_type}") return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return] def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: From d7317ec87c3f9b14c30a335a162721c17a828f94 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 10 Oct 2023 11:11:35 +0200 Subject: [PATCH 36/75] fix: complexity --- .../middlewares/openapi_validation.py | 69 ++++--- .../event_handler/openapi/dependant.py | 43 +++- .../event_handler/openapi/encoders.py | 190 +++++++++++++----- 3 files changed, 216 insertions(+), 86 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 3b2155cee30..fed8047df2f 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -10,7 +10,6 @@ from aws_lambda_powertools.event_handler.api_gateway import Route from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware from aws_lambda_powertools.event_handler.openapi.compat import ( - ErrorWrapper, ModelField, _model_dump, _normalize_errors, @@ -206,14 +205,7 @@ def _request_params_to_args( values[field.name] = deepcopy(field.default) continue - v_, errors_ = field.validate(value, values, loc=loc) - if isinstance(errors_, ErrorWrapper): - errors.append(errors_) - elif isinstance(errors_, list): - new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) - errors.extend(new_errors) - else: - values[field.name] = v_ + _validate_field(field=field, value=value, loc=loc, existing_values=values, existing_errors=errors) return values, errors @@ -228,19 +220,16 @@ def _request_body_to_args( if not required_params: return values, errors - field = required_params[0] - field_info = field.field_info - embed = getattr(field_info, "embed", None) - field_alias_omitted = len(required_params) == 1 and not embed - if field_alias_omitted: - received_body = {field.alias: received_body} + received_body, field_alias_omitted = _get_embed_body( + field=required_params[0], + required_params=required_params, + received_body=received_body, + ) for field in required_params: - loc: Tuple[str, ...] + loc: Tuple[str, ...] = ("body", field.alias) if field_alias_omitted: loc = ("body",) - else: - loc = ("body", field.alias) value: Optional[Any] = None @@ -261,13 +250,41 @@ def _request_body_to_args( # MAINTENANCE: Handle byte and file fields - v_, errors_ = field.validate(value, values, loc=loc) - - if isinstance(errors_, list): - errors.extend(errors_) - elif errors_: - errors.append(errors_) - else: - values[field.name] = v_ + _validate_field(field=field, value=value, loc=loc, existing_values=values, existing_errors=errors) return values, errors + + +def _validate_field( + *, + field: ModelField, + value: Any, + loc: Tuple[str, ...], + existing_values: Dict[str, Any], + existing_errors: List[Dict[str, Any]], +): + validated_value, errors = field.validate(value, existing_values, loc=loc) + + if isinstance(errors, list): + processed_errors = _regenerate_error_with_loc(errors=errors, loc_prefix=()) + existing_errors.extend(processed_errors) + elif errors: + existing_errors.append(errors) + else: + existing_values[field.name] = validated_value + + +def _get_embed_body( + *, + field: ModelField, + required_params: List[ModelField], + received_body: Optional[Dict[str, Any]], +) -> Tuple[Optional[Dict[str, Any]], bool]: + field_info = field.field_info + embed = getattr(field_info, "embed", None) + + field_alias_omitted = len(required_params) == 1 and not embed + if field_alias_omitted: + received_body = {field.alias: received_body} + + return received_body, field_alias_omitted diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index f1172e85b36..48a2d6b4c9e 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -1,6 +1,8 @@ import inspect import re -from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Type, cast +from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, cast + +from pydantic import BaseModel from aws_lambda_powertools.event_handler.openapi.compat import ( ModelField, @@ -250,6 +252,10 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]: def _get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: + """ + Get the Body field for a given Dependant object. + """ + flat_dependant = get_flat_dependant(dependant) if not flat_dependant.body_params: return None @@ -270,6 +276,32 @@ def _get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: required = any(True for f in flat_dependant.body_params if f.required) + body_field_info, body_field_info_kwargs = _get_body_field_info( + body_model=body_model, + flat_dependant=flat_dependant, + required=required, + ) + + final_field = _create_response_field( + name="body", + type_=body_model, + required=required, + alias="body", + field_info=body_field_info(**body_field_info_kwargs), + ) + return final_field + + +def _get_body_field_info( + *, + body_model: Type[BaseModel], + flat_dependant: Dependant, + required: bool, +) -> Tuple[Type[Body], Dict[str, Any]]: + """ + Get the Body field info and kwargs for a given body model. + """ + body_field_info_kwargs: Dict[str, Any] = {"annotation": body_model, "alias": "body"} if not required: @@ -288,11 +320,4 @@ def _get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: if len(set(body_param_media_types)) == 1: body_field_info_kwargs["media_type"] = body_param_media_types[0] - final_field = _create_response_field( - name="body", - type_=body_model, - required=required, - alias="body", - field_info=body_field_info(**body_field_info_kwargs), - ) - return final_field + return body_field_info, body_field_info_kwargs diff --git a/aws_lambda_powertools/event_handler/openapi/encoders.py b/aws_lambda_powertools/event_handler/openapi/encoders.py index 11a76e29e17..56597aac302 100644 --- a/aws_lambda_powertools/event_handler/openapi/encoders.py +++ b/aws_lambda_powertools/event_handler/openapi/encoders.py @@ -97,10 +97,10 @@ def jsonable_encoder( # noqa: C901, PLR0911, PLR0912 include = set(include) if exclude is not None and not isinstance(exclude, (set, dict)): exclude = set(exclude) + if isinstance(obj, BaseModel): - obj_dict = _model_dump( - obj, - mode="json", + return _dump_base_model( + obj=obj, include=include, exclude=exclude, by_alias=by_alias, @@ -108,14 +108,6 @@ def jsonable_encoder( # noqa: C901, PLR0911, PLR0912 exclude_none=exclude_none, exclude_defaults=exclude_defaults, ) - if "__root__" in obj_dict: - obj_dict = obj_dict["__root__"] - - return jsonable_encoder( - obj_dict, - exclude_none=exclude_none, - exclude_defaults=exclude_defaults, - ) if dataclasses.is_dataclass(obj): obj_dict = dataclasses.asdict(obj) @@ -128,6 +120,7 @@ def jsonable_encoder( # noqa: C901, PLR0911, PLR0912 exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) + if isinstance(obj, Enum): return obj.value if isinstance(obj, PurePath): @@ -135,54 +128,149 @@ def jsonable_encoder( # noqa: C901, PLR0911, PLR0912 if isinstance(obj, (str, int, float, type(None))): return obj if isinstance(obj, dict): - encoded_dict = {} - allowed_keys = set(obj.keys()) - if include is not None: - allowed_keys &= set(include) - if exclude is not None: - allowed_keys -= set(exclude) - for key, value in obj.items(): - if ( - (not isinstance(key, str) or not key.startswith("_sa")) - and (value is not None or not exclude_none) - and key in allowed_keys - ): - encoded_key = jsonable_encoder( - key, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - ) - encoded_value = jsonable_encoder( - value, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - ) - encoded_dict[encoded_key] = encoded_value - return encoded_dict + return _dump_dict( + obj=obj, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_none=exclude_none, + exclude_unset=exclude_unset, + ) if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)): - encoded_list = [] - for item in obj: - encoded_list.append( - jsonable_encoder( - item, - include=include, - exclude=exclude, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ), - ) - return encoded_list + return _dump_sequence( + obj=obj, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + exclude_unset=exclude_unset, + ) if type(obj) in ENCODERS_BY_TYPE: return ENCODERS_BY_TYPE[type(obj)](obj) + for encoder, classes_tuple in encoders_by_class_tuples.items(): if isinstance(obj, classes_tuple): return encoder(obj) + return _dump_other( + obj=obj, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_none=exclude_none, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + ) + + +def _dump_base_model( + *, + obj: Any, + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_none: bool = False, + exclude_defaults: bool = False, +): + """ + Dump a BaseModel object to a dict, using the same parameters as jsonable_encoder + """ + obj_dict = _model_dump( + obj, + mode="json", + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + ) + if "__root__" in obj_dict: + obj_dict = obj_dict["__root__"] + + return jsonable_encoder( + obj_dict, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + ) + + +def _dump_dict( + *, + obj: Any, + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_none: bool = False, +) -> Dict[str, Any]: + encoded_dict = {} + allowed_keys = set(obj.keys()) + if include is not None: + allowed_keys &= set(include) + if exclude is not None: + allowed_keys -= set(exclude) + for key, value in obj.items(): + if ( + (not isinstance(key, str) or not key.startswith("_sa")) + and (value is not None or not exclude_none) + and key in allowed_keys + ): + encoded_key = jsonable_encoder( + key, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + ) + encoded_value = jsonable_encoder( + value, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + ) + encoded_dict[encoded_key] = encoded_value + return encoded_dict + + +def _dump_sequence( + *, + obj: Any, + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_none: bool = False, + exclude_defaults: bool = False, +) -> List[Any]: + encoded_list = [] + for item in obj: + encoded_list.append( + jsonable_encoder( + item, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ), + ) + return encoded_list + + +def _dump_other( + *, + obj: Any, + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_none: bool = False, + exclude_defaults: bool = False, +) -> Any: try: data = dict(obj) except Exception as e: From 6b445750da6e81f8820d271ecb74680e61e65e3d Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 10 Oct 2023 14:42:38 +0200 Subject: [PATCH 37/75] fix: move Response class back --- .../event_handler/__init__.py | 2 +- .../event_handler/api_gateway.py | 40 +++++++++++++++++- .../event_handler/openapi/encoders.py | 3 +- .../event_handler/response.py | 41 ------------------- 4 files changed, 41 insertions(+), 45 deletions(-) delete mode 100644 aws_lambda_powertools/event_handler/response.py diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index 14372784adb..7bdd9a97f72 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -8,12 +8,12 @@ ApiGatewayResolver, APIGatewayRestResolver, CORSConfig, + Response, ) from aws_lambda_powertools.event_handler.appsync import AppSyncResolver from aws_lambda_powertools.event_handler.lambda_function_url import ( LambdaFunctionUrlResolver, ) -from aws_lambda_powertools.event_handler.response import Response from aws_lambda_powertools.event_handler.vpc_lattice import VPCLatticeResolver, VPCLatticeV2Resolver __all__ = [ diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index ee5238365a2..02df889d77e 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -36,7 +36,7 @@ validation_error_definition, validation_error_response_definition, ) -from aws_lambda_powertools.event_handler.response import Response +from aws_lambda_powertools.shared.cookies import Cookie from aws_lambda_powertools.shared.functions import powertools_dev_is_set from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ( @@ -200,6 +200,44 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]: return headers +class Response: + """Response data class that provides greater control over what is returned from the proxy event""" + + def __init__( + self, + status_code: int, + content_type: Optional[str] = None, + body: Union[str, bytes, None] = None, + headers: Optional[Dict[str, Union[str, List[str]]]] = None, + cookies: Optional[List[Cookie]] = None, + compress: Optional[bool] = None, + ): + """ + + Parameters + ---------- + status_code: int + Http status code, example 200 + content_type: str + Optionally set the Content-Type header, example "application/json". Note this will be merged into any + provided http headers + body: Union[str, bytes, None] + Optionally set the response body. Note: bytes body will be automatically base64 encoded + headers: dict[str, Union[str, List[str]]] + Optionally set specific http headers. Setting "Content-Type" here would override the `content_type` value. + cookies: list[Cookie] + Optionally set cookies. + """ + self.status_code = status_code + self.body = body + self.base64_encoded = False + self.headers: Dict[str, Union[str, List[str]]] = headers if headers else {} + self.cookies = cookies or [] + self.compress = compress + if content_type: + self.headers.setdefault("Content-Type", content_type) + + class Route: """Internally used Route Configuration""" diff --git a/aws_lambda_powertools/event_handler/openapi/encoders.py b/aws_lambda_powertools/event_handler/openapi/encoders.py index 56597aac302..6b3ba8ac65e 100644 --- a/aws_lambda_powertools/event_handler/openapi/encoders.py +++ b/aws_lambda_powertools/event_handler/openapi/encoders.py @@ -29,9 +29,8 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]: Encodes a Decimal as int of there's no exponent, otherwise float This is useful when we use ConstrainedDecimal to represent Numeric(x,0) - where a integer (but not int typed) is used. Encoding this as a float + where an integer (but not int typed) is used. Encoding this as a float results in failed round-tripping between encode and parse. - Our Id type is a prime example of this. >>> decimal_encoder(Decimal("1.0")) 1.0 diff --git a/aws_lambda_powertools/event_handler/response.py b/aws_lambda_powertools/event_handler/response.py deleted file mode 100644 index 3c5ffd0152d..00000000000 --- a/aws_lambda_powertools/event_handler/response.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Dict, List, Optional, Union - -from aws_lambda_powertools.shared.cookies import Cookie - - -class Response: - """Response data class that provides greater control over what is returned from the proxy event""" - - def __init__( - self, - status_code: int, - content_type: Optional[str] = None, - body: Union[str, bytes, None] = None, - headers: Optional[Dict[str, Union[str, List[str]]]] = None, - cookies: Optional[List[Cookie]] = None, - compress: Optional[bool] = None, - ): - """ - - Parameters - ---------- - status_code: int - Http status code, example 200 - content_type: str - Optionally set the Content-Type header, example "application/json". Note this will be merged into any - provided http headers - body: Union[str, bytes, None] - Optionally set the response body. Note: bytes body will be automatically base64 encoded - headers: dict[str, Union[str, List[str]]] - Optionally set specific http headers. Setting "Content-Type" here would override the `content_type` value. - cookies: list[Cookie] - Optionally set cookies. - """ - self.status_code = status_code - self.body = body - self.base64_encoded = False - self.headers: Dict[str, Union[str, List[str]]] = headers if headers else {} - self.cookies = cookies or [] - self.compress = compress - if content_type: - self.headers.setdefault("Content-Type", content_type) From eb90c5652ef8385b5325245d841a8b44c0357fa7 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 10 Oct 2023 14:43:50 +0200 Subject: [PATCH 38/75] fix: more fix --- aws_lambda_powertools/event_handler/middlewares/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/base.py b/aws_lambda_powertools/event_handler/middlewares/base.py index a6b1bff6d4a..fb4bf37cc74 100644 --- a/aws_lambda_powertools/event_handler/middlewares/base.py +++ b/aws_lambda_powertools/event_handler/middlewares/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Generic -from aws_lambda_powertools.event_handler import Response +from aws_lambda_powertools.event_handler.api_gateway import Response from aws_lambda_powertools.event_handler.types import EventHandlerInstance from aws_lambda_powertools.shared.types import Protocol From 31dca104d6ed9ee1d6d38701138943ce4916654a Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 10 Oct 2023 14:45:03 +0200 Subject: [PATCH 39/75] fix: more fix --- .../event_handler/middlewares/schema_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/schema_validation.py b/aws_lambda_powertools/event_handler/middlewares/schema_validation.py index a4d3a1c17ab..66be47a48f3 100644 --- a/aws_lambda_powertools/event_handler/middlewares/schema_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/schema_validation.py @@ -1,7 +1,7 @@ import logging from typing import Dict, Optional -from aws_lambda_powertools.event_handler import Response +from aws_lambda_powertools.event_handler.api_gateway import Response from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware from aws_lambda_powertools.event_handler.types import EventHandlerInstance From 550528dea0c2cbcbc3c91d23cf35c24bfb9a026b Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 10 Oct 2023 14:45:57 +0200 Subject: [PATCH 40/75] fix: one more fix --- examples/event_handler_rest/src/binary_responses.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/event_handler_rest/src/binary_responses.py b/examples/event_handler_rest/src/binary_responses.py index 0c6d15a0e8c..d56eda1afe8 100644 --- a/examples/event_handler_rest/src/binary_responses.py +++ b/examples/event_handler_rest/src/binary_responses.py @@ -2,10 +2,7 @@ from pathlib import Path from aws_lambda_powertools import Logger, Tracer -from aws_lambda_powertools.event_handler import Response -from aws_lambda_powertools.event_handler.api_gateway import ( - APIGatewayRestResolver, -) +from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Response from aws_lambda_powertools.logging import correlation_paths from aws_lambda_powertools.utilities.typing import LambdaContext From cdfbfbf3c8134335983328e54afecca482ed767b Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 10 Oct 2023 14:56:34 +0200 Subject: [PATCH 41/75] fix: refactor OpenAPI validation middleware --- .../middlewares/openapi_validation.py | 193 +++++++++++------- 1 file changed, 122 insertions(+), 71 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index fed8047df2f..586f87a3d5e 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -26,8 +26,33 @@ class OpenAPIValidationMiddleware(BaseMiddlewareHandler): - def __init__(self): - super().__init__() + """ + OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the + Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It + should not be used directly, but rather through the `enable_validation` parameter of the `APIGatewayProxyHandler`. + + Examples + -------- + + ```python + from typing import List + + from pydantic import BaseModel + + from aws_lambda_powertools.event_handler.api_gateway import ( + APIGatewayProxyHandler, + ) + + class Todo(BaseModel): + name: str + + app = APIGatewayProxyHandler(enable_validation=True) + + @app.get("/todos") + def get_todos(): List[Todo]: + return [Todo(name="hello world")] + ``` + """ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: logger.debug("OpenAPIValidationMiddleware handler") @@ -38,35 +63,43 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> errors: List[Any] = [] try: - path_values, path_errors = self._request_params_to_args( + # Process path values, which can be found on the route_args + path_values, path_errors = _request_params_to_args( route.dependant.path_params, app.context["_route_args"], ) - query_values, query_errors = self._request_params_to_args( + + # Process query values + query_values, query_errors = _request_params_to_args( route.dependant.query_params, app.current_event.query_string_parameters or {}, ) values.update(path_values) values.update(query_values) - errors += path_errors + query_errors + # Process the request body, if it exists if route.dependant.body_params: - (body_values, body_errors) = self._request_body_to_args( + (body_values, body_errors) = _request_body_to_args( required_params=route.dependant.body_params, - received_body=self._get_body(app, route), + received_body=self._get_body(app), ) values.update(body_values) errors.extend(body_errors) if errors: + # Raise the validation errors raise RequestValidationError(_normalize_errors(errors)) else: + # Re-write the route_args with the validated values, and call the next middleware app.context["_route_args"] = values response = next_middleware(app) + # Process the response body, if it exists raw_response = jsonable_encoder(response.body) + + # Validate and serialize the response return self._serialize_response(field=route.dependant.return_param, response_content=raw_response) except RequestValidationError as e: return Response( @@ -87,6 +120,9 @@ def _serialize_response( exclude_defaults: bool = False, exclude_none: bool = False, ) -> Any: + """ + Serialize the response content according to the field type. + """ if field: errors = [] # MAINTENANCE: remove this when we drop pydantic v1 @@ -139,6 +175,10 @@ def _prepare_response_content( exclude_defaults: bool = False, exclude_none: bool = False, ) -> Any: + """ + Prepares the response content for serialization. + """ + if isinstance(res, BaseModel): return _model_dump( res, @@ -161,7 +201,11 @@ def _prepare_response_content( return dataclasses.asdict(res) return res - def _get_body(self, app: EventHandlerInstance, route: Route) -> Dict[str, Any]: + def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]: + """ + Get the request body from the event, and parse it as JSON. + """ + content_type_value = app.current_event.get_header_value("content-type") if not content_type_value or content_type_value.startswith("application/json"): try: @@ -182,77 +226,84 @@ def _get_body(self, app: EventHandlerInstance, route: Route) -> Dict[str, Any]: else: raise NotImplementedError("Only JSON body is supported") - @staticmethod - def _request_params_to_args( - required_params: Sequence[ModelField], - received_params: Mapping[str, Any], - ) -> Tuple[Dict[str, Any], List[Any]]: - values = {} - errors = [] - - for field in required_params: - value = received_params.get(field.alias) - - field_info = field.field_info - if not isinstance(field_info, Param): - raise AssertionError(f"Expected Param field_info, got {field_info}") - - loc = (field_info.in_.value, field.alias) - if value is None: - if field.required: - errors.append(get_missing_field_error(loc=loc)) - else: - values[field.name] = deepcopy(field.default) - continue - _validate_field(field=field, value=value, loc=loc, existing_values=values, existing_errors=errors) +def _request_params_to_args( + required_params: Sequence[ModelField], + received_params: Mapping[str, Any], +) -> Tuple[Dict[str, Any], List[Any]]: + """ + Convert the request params to a dictionary of values using validation, and returns a list of errors. + """ + values = {} + errors = [] + + for field in required_params: + value = received_params.get(field.alias) + + field_info = field.field_info + if not isinstance(field_info, Param): + raise AssertionError(f"Expected Param field_info, got {field_info}") + + loc = (field_info.in_.value, field.alias) + if value is None: + if field.required: + errors.append(get_missing_field_error(loc=loc)) + else: + values[field.name] = deepcopy(field.default) + continue + + _validate_field(field=field, value=value, loc=loc, existing_values=values, existing_errors=errors) + + return values, errors + + +def _request_body_to_args( + required_params: List[ModelField], + received_body: Optional[Dict[str, Any]], +) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: + """ + Convert the request body to a dictionary of values using validation, and returns a list of errors. + """ + + values: Dict[str, Any] = {} + errors: List[Dict[str, Any]] = [] + if not required_params: return values, errors - @staticmethod - def _request_body_to_args( - required_params: List[ModelField], - received_body: Optional[Dict[str, Any]], - ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: - values: Dict[str, Any] = {} - errors: List[Dict[str, Any]] = [] - - if not required_params: - return values, errors - - received_body, field_alias_omitted = _get_embed_body( - field=required_params[0], - required_params=required_params, - received_body=received_body, - ) - - for field in required_params: - loc: Tuple[str, ...] = ("body", field.alias) - if field_alias_omitted: - loc = ("body",) - - value: Optional[Any] = None - - if received_body is not None: - try: - value = received_body.get(field.alias) - except AttributeError: - errors.append(get_missing_field_error(loc)) - continue - - # Determine if the field is required - if value is None: - if field.required: - errors.append(get_missing_field_error(loc)) - else: - values[field.name] = deepcopy(field.default) + received_body, field_alias_omitted = _get_embed_body( + field=required_params[0], + required_params=required_params, + received_body=received_body, + ) + + for field in required_params: + loc: Tuple[str, ...] = ("body", field.alias) + if field_alias_omitted: + loc = ("body",) + + value: Optional[Any] = None + + if received_body is not None: + try: + value = received_body.get(field.alias) + except AttributeError: + errors.append(get_missing_field_error(loc)) continue - # MAINTENANCE: Handle byte and file fields + # Determine if the field is required + if value is None: + if field.required: + errors.append(get_missing_field_error(loc)) + else: + values[field.name] = deepcopy(field.default) + continue + + # MAINTENANCE: Handle byte and file fields - _validate_field(field=field, value=value, loc=loc, existing_values=values, existing_errors=errors) + _validate_field(field=field, value=value, loc=loc, existing_values=values, existing_errors=errors) - return values, errors + return values, errors def _validate_field( From 5ff491bcb3f4a7e9942632bc9cbe1bddb995bc57 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 10 Oct 2023 15:02:12 +0200 Subject: [PATCH 42/75] fix: refactor dependant.py --- .../event_handler/api_gateway.py | 4 +-- .../event_handler/openapi/dependant.py | 31 +++++++++++++++---- .../event_handler/openapi/params.py | 4 +-- 3 files changed, 29 insertions(+), 10 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 02df889d77e..655251f412e 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -419,9 +419,9 @@ def dependant(self) -> "Dependant": @property def body_field(self) -> Optional["ModelField"]: if self._body_field is None: - from aws_lambda_powertools.event_handler.openapi.dependant import _get_body_field + from aws_lambda_powertools.event_handler.openapi.dependant import get_body_field - self._body_field = _get_body_field(dependant=self.dependant, name=self.operation_id) + self._body_field = get_body_field(dependant=self.dependant, name=self.operation_id) return self._body_field diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 48a2d6b4c9e..6f05ed55e6d 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -20,8 +20,8 @@ Param, ParamTypes, Query, - _create_response_field, analyze_param, + create_response_field, get_flat_dependant, ) @@ -174,6 +174,7 @@ def get_dependant( path=path, ) + # Add each parameter to the dependant model for param_name, param in signature_params.items(): # If the parameter is a path parameter, we need to set the in_ field to "path". is_path_param = param_name in path_param_names @@ -187,7 +188,7 @@ def get_dependant( is_response_param=False, ) if param_field is None: - raise AssertionError(f"Param field is None for param: {param_name}") + raise AssertionError(f"Parameter field is None for param: {param_name}") if is_body_param(param_field=param_field, is_path_param=is_path_param): dependant.body_params.append(param_field) @@ -213,6 +214,21 @@ def get_dependant( def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: + """ + Returns whether a parameter is a request body parameter, by checking if it is a scalar field or a body field. + + Parameters + ---------- + param_field: ModelField + The parameter field + is_path_param: bool + Whether the parameter is a path parameter + + Returns + ------- + bool + Whether the parameter is a request body parameter + """ if is_path_param: if not is_scalar_field(field=param_field): raise AssertionError("Path params must be of one of the supported types") @@ -251,7 +267,7 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]: ) -def _get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: +def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: """ Get the Body field for a given Dependant object. """ @@ -262,6 +278,8 @@ def _get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: first_param = flat_dependant.body_params[0] field_info = first_param.field_info + + # Handle the case where there is only one body parameter and it is embedded embed = getattr(field_info, "embed", None) body_param_names_set = {param.name for param in flat_dependant.body_params} if len(body_param_names_set) == 1 and not embed: @@ -271,18 +289,19 @@ def _get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: for param in flat_dependant.body_params: setattr(param.field_info, "embed", True) # noqa: B010 + # Generate a custom body model for this endpoint model_name = "Body_" + name body_model = create_body_model(fields=flat_dependant.body_params, model_name=model_name) required = any(True for f in flat_dependant.body_params if f.required) - body_field_info, body_field_info_kwargs = _get_body_field_info( + body_field_info, body_field_info_kwargs = get_body_field_info( body_model=body_model, flat_dependant=flat_dependant, required=required, ) - final_field = _create_response_field( + final_field = create_response_field( name="body", type_=body_model, required=required, @@ -292,7 +311,7 @@ def _get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: return final_field -def _get_body_field_info( +def get_body_field_info( *, body_model: Type[BaseModel], flat_dependant: Dependant, diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 1e506d88820..21e59126e49 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -781,7 +781,7 @@ def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tu return field_info, type_annotation -def _create_response_field( +def create_response_field( name: str, type_: Type[Any], default: Optional[Any] = Undefined, @@ -847,7 +847,7 @@ def _create_model_field( alias = field_info.alias or param_name field_info.alias = alias - return _create_response_field( + return create_response_field( name=param_name, type_=use_annotation, default=field_info.default, From bbb9c254b78123649398782053bdf87c70f6c4a1 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 10 Oct 2023 15:07:45 +0200 Subject: [PATCH 43/75] fix: beautify encoders --- .../event_handler/openapi/encoders.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/aws_lambda_powertools/event_handler/openapi/encoders.py b/aws_lambda_powertools/event_handler/openapi/encoders.py index 6b3ba8ac65e..439cb0f038c 100644 --- a/aws_lambda_powertools/event_handler/openapi/encoders.py +++ b/aws_lambda_powertools/event_handler/openapi/encoders.py @@ -16,6 +16,10 @@ from aws_lambda_powertools.event_handler.openapi.compat import _model_dump from aws_lambda_powertools.event_handler.openapi.types import IncEx +""" +This module contains the encoders used by jsonable_encoder to convert Python objects to JSON serializable data types. +""" + def iso_format(o: Union[datetime.date, datetime.time]) -> str: """ @@ -91,12 +95,43 @@ def jsonable_encoder( # noqa: C901, PLR0911, PLR0912 ) -> Any: """ JSON encodes an arbitrary Python object into JSON serializable data types. + + This is a modified version of fastapi.encoders.jsonable_encoder that supports + encoding of pydantic.BaseModel objects. + + This function is used to encode the response body of a FastAPI endpoint. + + Parameters + ---------- + obj : Any + The object to encode + include : Optional[IncEx], optional + A set or dictionary of strings that specifies which properties should be included, by default None, + meaning everything is included + exclude : Optional[IncEx], optional + A set or dictionary of strings that specifies which properties should be excluded, by default None, + meaning nothing is excluded + by_alias : bool, optional + Whether field aliases should be respected, by default True + exclude_unset : bool, optional + Whether fields that are not set should be excluded, by default False + exclude_defaults : bool, optional + Whether fields that are equal to their default value (as specified in the model) should be excluded, + by default False + exclude_none : bool, optional + Whether fields that are equal to None should be excluded, by default False + + Returns + ------- + Any + The JSON serializable data types """ if include is not None and not isinstance(include, (set, dict)): include = set(include) if exclude is not None and not isinstance(exclude, (set, dict)): exclude = set(exclude) + # Pydantic models if isinstance(obj, BaseModel): return _dump_base_model( obj=obj, @@ -108,6 +143,7 @@ def jsonable_encoder( # noqa: C901, PLR0911, PLR0912 exclude_defaults=exclude_defaults, ) + # Dataclasses if dataclasses.is_dataclass(obj): obj_dict = dataclasses.asdict(obj) return jsonable_encoder( @@ -120,12 +156,19 @@ def jsonable_encoder( # noqa: C901, PLR0911, PLR0912 exclude_none=exclude_none, ) + # Enums if isinstance(obj, Enum): return obj.value + + # Paths if isinstance(obj, PurePath): return str(obj) + + # Scalars if isinstance(obj, (str, int, float, type(None))): return obj + + # Dictionaries if isinstance(obj, dict): return _dump_dict( obj=obj, @@ -135,6 +178,8 @@ def jsonable_encoder( # noqa: C901, PLR0911, PLR0912 exclude_none=exclude_none, exclude_unset=exclude_unset, ) + + # Sequences if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)): return _dump_sequence( obj=obj, @@ -146,6 +191,7 @@ def jsonable_encoder( # noqa: C901, PLR0911, PLR0912 exclude_unset=exclude_unset, ) + # Other types if type(obj) in ENCODERS_BY_TYPE: return ENCODERS_BY_TYPE[type(obj)](obj) @@ -153,6 +199,7 @@ def jsonable_encoder( # noqa: C901, PLR0911, PLR0912 if isinstance(obj, classes_tuple): return encoder(obj) + # Default return _dump_other( obj=obj, include=include, @@ -206,6 +253,9 @@ def _dump_dict( exclude_unset: bool = False, exclude_none: bool = False, ) -> Dict[str, Any]: + """ + Dump a dict to a dict, using the same parameters as jsonable_encoder + """ encoded_dict = {} allowed_keys = set(obj.keys()) if include is not None: @@ -244,6 +294,9 @@ def _dump_sequence( exclude_none: bool = False, exclude_defaults: bool = False, ) -> List[Any]: + """ + Dump a sequence to a list, using the same parameters as jsonable_encoder + """ encoded_list = [] for item in obj: encoded_list.append( @@ -270,6 +323,9 @@ def _dump_other( exclude_none: bool = False, exclude_defaults: bool = False, ) -> Any: + """ + Dump an object to ah hashable object, using the same parameters as jsonable_encoder + """ try: data = dict(obj) except Exception as e: From 5bd4a5039b728ab0388f131e665b88ec4b366dc9 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 10 Oct 2023 15:10:10 +0200 Subject: [PATCH 44/75] fix: move things around --- .../event_handler/openapi/encoders.py | 128 +++++++++--------- 1 file changed, 64 insertions(+), 64 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/encoders.py b/aws_lambda_powertools/event_handler/openapi/encoders.py index 439cb0f038c..c31df5c5ed9 100644 --- a/aws_lambda_powertools/event_handler/openapi/encoders.py +++ b/aws_lambda_powertools/event_handler/openapi/encoders.py @@ -21,70 +21,7 @@ """ -def iso_format(o: Union[datetime.date, datetime.time]) -> str: - """ - ISO format for date and time - """ - return o.isoformat() - - -def decimal_encoder(dec_value: Decimal) -> Union[int, float]: - """ - Encodes a Decimal as int of there's no exponent, otherwise float - - This is useful when we use ConstrainedDecimal to represent Numeric(x,0) - where an integer (but not int typed) is used. Encoding this as a float - results in failed round-tripping between encode and parse. - - >>> decimal_encoder(Decimal("1.0")) - 1.0 - - >>> decimal_encoder(Decimal("1")) - 1 - """ - if dec_value.as_tuple().exponent >= 0: # type: ignore[operator] - return int(dec_value) - else: - return float(dec_value) - - -# Encoders for types that are not JSON serializable -ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { - bytes: lambda o: o.decode(), - Color: str, - datetime.date: iso_format, - datetime.datetime: iso_format, - datetime.time: iso_format, - datetime.timedelta: lambda td: td.total_seconds(), - Decimal: decimal_encoder, - Enum: lambda o: o.value, - frozenset: list, - deque: list, - GeneratorType: list, - Path: str, - Pattern: lambda o: o.pattern, - SecretBytes: str, - SecretStr: str, - set: list, - UUID: str, -} - - -# Generates a mapping of encoders to a tuple of classes that they can encode -def generate_encoders_by_class_tuples( - type_encoder_map: Dict[Any, Callable[[Any], Any]], -) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]: - encoders: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(tuple) - for type_, encoder in type_encoder_map.items(): - encoders[encoder] += (type_,) - return encoders - - -# Mapping of encoders to a tuple of classes that they can encode -encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) - - -def jsonable_encoder( # noqa: C901, PLR0911, PLR0912 +def jsonable_encoder( # noqa: PLR0911 obj: Any, include: Optional[IncEx] = None, exclude: Optional[IncEx] = None, @@ -344,3 +281,66 @@ def _dump_other( exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) + + +def iso_format(o: Union[datetime.date, datetime.time]) -> str: + """ + ISO format for date and time + """ + return o.isoformat() + + +def decimal_encoder(dec_value: Decimal) -> Union[int, float]: + """ + Encodes a Decimal as int of there's no exponent, otherwise float + + This is useful when we use ConstrainedDecimal to represent Numeric(x,0) + where an integer (but not int typed) is used. Encoding this as a float + results in failed round-tripping between encode and parse. + + >>> decimal_encoder(Decimal("1.0")) + 1.0 + + >>> decimal_encoder(Decimal("1")) + 1 + """ + if dec_value.as_tuple().exponent >= 0: # type: ignore[operator] + return int(dec_value) + else: + return float(dec_value) + + +# Encoders for types that are not JSON serializable +ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { + bytes: lambda o: o.decode(), + Color: str, + datetime.date: iso_format, + datetime.datetime: iso_format, + datetime.time: iso_format, + datetime.timedelta: lambda td: td.total_seconds(), + Decimal: decimal_encoder, + Enum: lambda o: o.value, + frozenset: list, + deque: list, + GeneratorType: list, + Path: str, + Pattern: lambda o: o.pattern, + SecretBytes: str, + SecretStr: str, + set: list, + UUID: str, +} + + +# Generates a mapping of encoders to a tuple of classes that they can encode +def generate_encoders_by_class_tuples( + type_encoder_map: Dict[Any, Callable[[Any], Any]], +) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]: + encoders: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(tuple) + for type_, encoder in type_encoder_map.items(): + encoders[encoder] += (type_,) + return encoders + + +# Mapping of encoders to a tuple of classes that they can encode +encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) From a3cef34d20c76bed3d521ed2f7f7574963aefe5c Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 10 Oct 2023 15:14:13 +0200 Subject: [PATCH 45/75] fix: costmetic changes --- .../event_handler/openapi/params.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 21e59126e49..41ecd8130c5 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -152,9 +152,9 @@ def __init__( "serialization_alias": serialization_alias, "strict": strict, "json_schema_extra": current_json_schema_extra, + "pattern": pattern, }, ) - kwargs["pattern"] = pattern else: kwargs["regex"] = pattern kwargs.update(**current_json_schema_extra) @@ -458,9 +458,9 @@ def __init__( "serialization_alias": serialization_alias, "strict": strict, "json_schema_extra": current_json_schema_extra, + "pattern": pattern, }, ) - kwargs["pattern"] = pattern else: kwargs["regex"] = pattern kwargs.update(**current_json_schema_extra) @@ -699,7 +699,7 @@ def analyze_param( Optional[ModelField] The type annotation and the Pydantic field representing the parameter """ - field_info, type_annotation = _get_field_info_and_type_annotation(annotation, value, is_path_param) + field_info, type_annotation = get_field_info_and_type_annotation(annotation, value, is_path_param) # If the value is a FieldInfo, we use it as the FieldInfo for the parameter if isinstance(value, FieldInfo): @@ -730,7 +730,7 @@ def analyze_param( return field -def _get_field_info_and_type_annotation(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]: +def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]: """ Get the FieldInfo and type annotation from an annotation and value. """ @@ -740,7 +740,7 @@ def _get_field_info_and_type_annotation(annotation, value, is_path_param: bool) if annotation is not inspect.Signature.empty: # If the annotation is an Annotated type, we need to extract the type annotation and the FieldInfo if get_origin(annotation) is Annotated: - field_info, type_annotation = _get_field_info_annotated_type(annotation, value, is_path_param) + field_info, type_annotation = get_field_info_annotated_type(annotation, value, is_path_param) # If the annotation is not an Annotated type, we use it as the type annotation else: type_annotation = annotation @@ -748,7 +748,7 @@ def _get_field_info_and_type_annotation(annotation, value, is_path_param: bool) return field_info, type_annotation -def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]: +def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]: """ Get the FieldInfo and type annotation from an Annotated type. """ @@ -803,6 +803,7 @@ def create_response_field( else: field_info = field_info or FieldInfo() kwargs = {"name": name, "field_info": field_info} + if PYDANTIC_V2: kwargs.update({"mode": mode}) else: From de22a9397b0962cfe70823307be048803b70b526 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 10 Oct 2023 15:21:53 +0200 Subject: [PATCH 46/75] fix: add more comments --- .../event_handler/api_gateway.py | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 655251f412e..1b74b12bc52 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -438,6 +438,7 @@ def _get_openapi_path( path = {} definitions: Dict[str, Any] = {} + # Gather all the route parameters operation = self._openapi_operation_metadata(operation_ids=operation_ids) parameters: List[Dict[str, Any]] = [] all_route_params = get_flat_params(dependant) @@ -446,14 +447,16 @@ def _get_openapi_path( model_name_map=model_name_map, field_mapping=field_mapping, ) - parameters.extend(operation_params) + + # Add the parameters to the OpenAPI operation if parameters: all_parameters = {(param["in"], param["name"]): param for param in parameters} required_parameters = {(param["in"], param["name"]): param for param in parameters if param.get("required")} all_parameters.update(required_parameters) operation["parameters"] = list(all_parameters.values()) + # Add the request body to the OpenAPI operation, if applicable if self.method.upper() in METHODS_WITH_BODY: request_body_oai = self._openapi_operation_request_body( body_field=self.body_field, @@ -463,12 +466,14 @@ def _get_openapi_path( if request_body_oai: operation["requestBody"] = request_body_oai + # Add the response to the OpenAPI operation responses = operation.setdefault("responses", {}) success_response = responses.setdefault("200", {}) success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION success_response["content"] = {"application/json": {"schema": {}}} json_response = success_response["content"].setdefault("application/json", {}) + # Add the response schema to the OpenAPI response json_response.update( self._openapi_operation_return( operation_id=self.operation_id, @@ -478,7 +483,7 @@ def _get_openapi_path( ), ) - # Validation responses + # Add validation responses operation["responses"]["422"] = { "description": "Validation Error", "content": { @@ -488,6 +493,7 @@ def _get_openapi_path( }, } + # We need to add the validation error schema to the definitions once if "ValidationError" not in definitions: definitions.update( { @@ -502,16 +508,27 @@ def _get_openapi_path( return path, definitions def _openapi_operation_summary(self) -> str: + """ + Returns the OpenAPI operation summary. If the user has not provided a summary, we + generate one based on the route path and method. + """ return self.summary or f"{self.method.upper()} {self.path}" def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any]: + """ + Returns the OpenAPI operation metadata. If the user has not provided a description, we + generate one based on the route path and method. + """ operation: Dict[str, Any] = {} + # Ensure tags is added to the operation if self.tags: operation["tags"] = self.tags + # Ensure summary is added to the operation operation["summary"] = self._openapi_operation_summary() + # Ensure description is added to the operation if self.description: operation["description"] = self.description @@ -522,6 +539,8 @@ def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any] if file_name: message += f" in {file_name}" warnings.warn(message, stacklevel=1) + + # Adds the operation operation_ids.add(self.operation_id) operation["operationId"] = self.operation_id @@ -534,15 +553,20 @@ def _openapi_operation_request_body( model_name_map: Dict["TypeModelOrEnum", str], field_mapping: Dict[Tuple["ModelField", Literal["validation", "serialization"]], "JsonSchemaValue"], ) -> Optional[Dict[str, Any]]: + """ + Returns the OpenAPI operation request body. + """ from aws_lambda_powertools.event_handler.openapi.compat import ModelField, get_schema_from_model_field from aws_lambda_powertools.event_handler.openapi.params import Body + # Check tat there is a body field and it's a Pydantic's model field if not body_field: return None if not isinstance(body_field, ModelField): raise AssertionError(f"Expected ModelField, got {body_field}") + # Generate the request body schema body_schema = get_schema_from_model_field( field=body_field, model_name_map=model_name_map, @@ -555,6 +579,8 @@ def _openapi_operation_request_body( request_body_oai: Dict[str, Any] = {} if required: request_body_oai["required"] = required + + # Generate the request body media type request_media_content: Dict[str, Any] = {"schema": body_schema} request_body_oai["content"] = {request_media_type: request_media_content} return request_body_oai @@ -569,6 +595,9 @@ def _openapi_operation_parameters( "JsonSchemaValue", ], ) -> List[Dict[str, Any]]: + """ + Returns the OpenAPI operation parameters. + """ from aws_lambda_powertools.event_handler.openapi.compat import ( get_schema_from_model_field, ) @@ -615,6 +644,9 @@ def _openapi_operation_return( "JsonSchemaValue", ], ) -> Dict[str, Any]: + """ + Returns the OpenAPI operation return. + """ if param is None: return {} From e60f7df1832390e1b73045f5eebd8eb1fe86db4b Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 10 Oct 2023 15:23:09 +0200 Subject: [PATCH 47/75] fix: format --- examples/event_handler_rest/src/binary_responses.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/event_handler_rest/src/binary_responses.py b/examples/event_handler_rest/src/binary_responses.py index d56eda1afe8..f91dc879402 100644 --- a/examples/event_handler_rest/src/binary_responses.py +++ b/examples/event_handler_rest/src/binary_responses.py @@ -2,7 +2,10 @@ from pathlib import Path from aws_lambda_powertools import Logger, Tracer -from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Response +from aws_lambda_powertools.event_handler.api_gateway import ( + APIGatewayRestResolver, + Response, +) from aws_lambda_powertools.logging import correlation_paths from aws_lambda_powertools.utilities.typing import LambdaContext From 0cd690e1e995177374cb05b75fdb8777ac63c503 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 10 Oct 2023 15:40:55 +0200 Subject: [PATCH 48/75] fix: cyclomatic --- .../middlewares/openapi_validation.py | 26 +++++-------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 586f87a3d5e..47756cabc8c 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -124,7 +124,7 @@ def _serialize_response( Serialize the response content according to the field type. """ if field: - errors = [] + errors: List[Dict[str, Any]] = [] # MAINTENANCE: remove this when we drop pydantic v1 if not hasattr(field, "serializable"): response_content = self._prepare_response_content( @@ -134,13 +134,7 @@ def _serialize_response( exclude_none=exclude_none, ) - value, errors_ = field.validate(response_content, {}, loc=("response",)) - - if isinstance(errors_, list): - errors.extend(errors_) - elif errors_: - errors.append(errors_) - + value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors) if errors: raise RequestValidationError(errors=_normalize_errors(errors), body=response_content) @@ -178,7 +172,6 @@ def _prepare_response_content( """ Prepares the response content for serialization. """ - if isinstance(res, BaseModel): return _model_dump( res, @@ -252,7 +245,7 @@ def _request_params_to_args( values[field.name] = deepcopy(field.default) continue - _validate_field(field=field, value=value, loc=loc, existing_values=values, existing_errors=errors) + values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) return values, errors @@ -264,13 +257,9 @@ def _request_body_to_args( """ Convert the request body to a dictionary of values using validation, and returns a list of errors. """ - values: Dict[str, Any] = {} errors: List[Dict[str, Any]] = [] - if not required_params: - return values, errors - received_body, field_alias_omitted = _get_embed_body( field=required_params[0], required_params=required_params, @@ -301,7 +290,7 @@ def _request_body_to_args( # MAINTENANCE: Handle byte and file fields - _validate_field(field=field, value=value, loc=loc, existing_values=values, existing_errors=errors) + values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) return values, errors @@ -311,18 +300,17 @@ def _validate_field( field: ModelField, value: Any, loc: Tuple[str, ...], - existing_values: Dict[str, Any], existing_errors: List[Dict[str, Any]], ): - validated_value, errors = field.validate(value, existing_values, loc=loc) + validated_value, errors = field.validate(value, value, loc=loc) if isinstance(errors, list): processed_errors = _regenerate_error_with_loc(errors=errors, loc_prefix=()) existing_errors.extend(processed_errors) elif errors: existing_errors.append(errors) - else: - existing_values[field.name] = validated_value + + return validated_value def _get_embed_body( From eebdc2f6fdcb84d564a8c60a80c4121ee2ee7d55 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 12:35:45 +0200 Subject: [PATCH 49/75] fix: change method of generating operation id --- aws_lambda_powertools/event_handler/api_gateway.py | 8 +++++++- tests/functional/event_handler/test_openapi_params.py | 10 +++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 1b74b12bc52..f9f8b57f56c 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -306,7 +306,7 @@ def __init__( self.response_description = response_description self.tags = tags or [] self.middlewares = middlewares or [] - self.operation_id = operation_id or (self.method.title() + self.func.__name__.title()) + self.operation_id = operation_id or self._generate_operation_id() # _middleware_stack_built is used to ensure the middleware stack is only built once. self._middleware_stack_built = False @@ -662,6 +662,12 @@ def _openapi_operation_return( return {"name": f"Return {operation_id}", "schema": return_schema} + def _generate_operation_id(self) -> str: + operation_id = self.func.__name__ + self.path + operation_id = re.sub(r"\W", "_", operation_id) + operation_id = operation_id + "_" + self.method.lower() + return operation_id + class ResponseBuilder: """Internally used Response builder""" diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index fd0c2c2b2c7..ce6ee2039d2 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -36,7 +36,7 @@ def handler(): get = path.get assert get.summary == "GET /" - assert get.operationId == "GetHandler" + assert get.operationId == "handler__get" assert get.responses is not None assert "200" in get.responses.keys() @@ -69,7 +69,7 @@ def handler(user_id: str, include_extra: bool = False): get = path.get assert get.summary == "GET /users/" - assert get.operationId == "GetHandler" + assert get.operationId == "handler_users__user_id__get" assert len(get.parameters) == 2 parameter = get.parameters[0] @@ -267,10 +267,10 @@ def handler(user: Annotated[User, Body(embed=True)]): request_body = post.requestBody assert request_body.required is True # Notice here we craft a specific schema for the embedded user - assert request_body.content[JSON_CONTENT_TYPE].schema_.ref == "#/components/schemas/Body_PostHandler" + assert request_body.content[JSON_CONTENT_TYPE].schema_.ref == "#/components/schemas/Body_handler_users_post" # Ensure that the custom body schema actually points to the real user class components = schema.components - assert "Body_PostHandler" in components.schemas - body_posthandler_schema = components.schemas["Body_PostHandler"] + assert "Body_handler_users_post" in components.schemas + body_posthandler_schema = components.schemas["Body_handler_users_post"] assert body_posthandler_schema.properties["user"].ref == "#/components/schemas/User" From b308f63253822b1691ce6d5a185b5f68732d4b53 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 13:21:42 +0200 Subject: [PATCH 50/75] fix: allow validation in all resolvers --- .../event_handler/api_gateway.py | 27 +++++++++++++++---- .../event_handler/lambda_function_url.py | 10 ++++++- .../event_handler/vpc_lattice.py | 6 +++-- .../event_handler/test_openapi_params.py | 4 +-- 4 files changed, 37 insertions(+), 10 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index f9f8b57f56c..a4124f0dd74 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1375,7 +1375,7 @@ def get_openapi_schema( output["servers"] = servers else: # If the servers property is not provided, or is an empty array, the default value would be a Server Object - # with a url value of /. + # with an url value of /. output["servers"] = [Server(url="/")] components: Dict[str, Dict[str, Any]] = {} @@ -1757,7 +1757,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response logger.exception(exc) if self._debug: # If the user has turned on debug mode, - # we'll let the original exception propagate so + # we'll let the original exception propagate, so # they get more information about what went wrong. return ResponseBuilder( Response( @@ -1981,9 +1981,17 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: Optional[bool] = False, ): """Amazon API Gateway REST and HTTP API v1 payload resolver""" - super().__init__(ProxyEventType.APIGatewayProxyEvent, cors, debug, serializer, strip_prefixes) + super().__init__( + ProxyEventType.APIGatewayProxyEvent, + cors, + debug, + serializer, + strip_prefixes, + enable_validation, + ) # override route to ignore trailing "/" in routes for REST API def route( @@ -2032,9 +2040,17 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: Optional[bool] = False, ): """Amazon API Gateway HTTP API v2 payload resolver""" - super().__init__(ProxyEventType.APIGatewayProxyEventV2, cors, debug, serializer, strip_prefixes) + super().__init__( + ProxyEventType.APIGatewayProxyEventV2, + cors, + debug, + serializer, + strip_prefixes, + enable_validation, + ) class ALBResolver(ApiGatewayResolver): @@ -2046,6 +2062,7 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: Optional[bool] = False, ): """Amazon Application Load Balancer (ALB) resolver""" - super().__init__(ProxyEventType.ALBEvent, cors, debug, serializer, strip_prefixes) + super().__init__(ProxyEventType.ALBEvent, cors, debug, serializer, strip_prefixes, enable_validation) diff --git a/aws_lambda_powertools/event_handler/lambda_function_url.py b/aws_lambda_powertools/event_handler/lambda_function_url.py index 433a013ab0b..ff7adeb6412 100644 --- a/aws_lambda_powertools/event_handler/lambda_function_url.py +++ b/aws_lambda_powertools/event_handler/lambda_function_url.py @@ -52,5 +52,13 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: Optional[bool] = False, ): - super().__init__(ProxyEventType.LambdaFunctionUrlEvent, cors, debug, serializer, strip_prefixes) + super().__init__( + ProxyEventType.LambdaFunctionUrlEvent, + cors, + debug, + serializer, + strip_prefixes, + enable_validation, + ) diff --git a/aws_lambda_powertools/event_handler/vpc_lattice.py b/aws_lambda_powertools/event_handler/vpc_lattice.py index bcee046e382..4fa8d061afb 100644 --- a/aws_lambda_powertools/event_handler/vpc_lattice.py +++ b/aws_lambda_powertools/event_handler/vpc_lattice.py @@ -48,9 +48,10 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: bool = False, ): """Amazon VPC Lattice resolver""" - super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes) + super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes, enable_validation) class VPCLatticeV2Resolver(ApiGatewayResolver): @@ -93,6 +94,7 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: bool = False, ): """Amazon VPC Lattice resolver""" - super().__init__(ProxyEventType.VPCLatticeEventV2, cors, debug, serializer, strip_prefixes) + super().__init__(ProxyEventType.VPCLatticeEventV2, cors, debug, serializer, strip_prefixes, enable_validation) diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index ce6ee2039d2..f658b091338 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -272,5 +272,5 @@ def handler(user: Annotated[User, Body(embed=True)]): # Ensure that the custom body schema actually points to the real user class components = schema.components assert "Body_handler_users_post" in components.schemas - body_posthandler_schema = components.schemas["Body_handler_users_post"] - assert body_posthandler_schema.properties["user"].ref == "#/components/schemas/User" + body_post_handler_schema = components.schemas["Body_handler_users_post"] + assert body_post_handler_schema.properties["user"].ref == "#/components/schemas/User" From 2c7367ee652d46a0873a1cc6237924f1f37ac3fd Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 13:31:09 +0200 Subject: [PATCH 51/75] fix: use proper resolver in tests --- .../event_handler/api_gateway.py | 2 +- .../event_handler/test_openapi_params.py | 20 +++++++++---------- .../event_handler/test_openapi_servers.py | 6 +++--- .../test_openapi_validation_middleware.py | 18 ++++++++--------- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index a4124f0dd74..bf7e083f91c 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -293,7 +293,7 @@ def __init__( The list of route middlewares to be called in order. """ self.method = method.upper() - self.path = path + self.path = "/" if path.strip() == "" else path self.rule = rule self.func = func self._middleware_stack = func diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index f658b091338..ec1cd4670a9 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from typing_extensions import Annotated -from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver +from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver from aws_lambda_powertools.event_handler.openapi.models import ( Example, Parameter, @@ -18,7 +18,7 @@ def test_openapi_no_params(): - app = ApiGatewayResolver() + app = APIGatewayRestResolver() @app.get("/") def handler(): @@ -51,7 +51,7 @@ def handler(): def test_openapi_with_scalar_params(): - app = ApiGatewayResolver() + app = APIGatewayRestResolver() @app.get("/users/") def handler(user_id: str, include_extra: bool = False): @@ -92,7 +92,7 @@ def handler(user_id: str, include_extra: bool = False): def test_openapi_with_custom_params(): - app = ApiGatewayResolver() + app = APIGatewayRestResolver() @app.get("/users", summary="Get Users", operation_id="GetUsers", description="Get paginated users", tags=["Users"]) def handler( @@ -128,7 +128,7 @@ def handler( def test_openapi_with_scalar_returns(): - app = ApiGatewayResolver() + app = APIGatewayRestResolver() @app.get("/") def handler() -> str: @@ -146,7 +146,7 @@ def handler() -> str: def test_openapi_with_pydantic_returns(): - app = ApiGatewayResolver() + app = APIGatewayRestResolver() class User(BaseModel): name: str @@ -173,7 +173,7 @@ def handler() -> User: def test_openapi_with_pydantic_nested_returns(): - app = ApiGatewayResolver() + app = APIGatewayRestResolver() class Order(BaseModel): date: datetime @@ -198,7 +198,7 @@ def handler() -> User: def test_openapi_with_dataclass_return(): - app = ApiGatewayResolver() + app = APIGatewayRestResolver() @dataclass class User: @@ -226,7 +226,7 @@ def handler() -> User: def test_openapi_with_body_param(): - app = ApiGatewayResolver() + app = APIGatewayRestResolver() class User(BaseModel): name: str @@ -248,7 +248,7 @@ def handler(user: User): def test_openapi_with_embed_body_param(): - app = ApiGatewayResolver() + app = APIGatewayRestResolver() class User(BaseModel): name: str diff --git a/tests/functional/event_handler/test_openapi_servers.py b/tests/functional/event_handler/test_openapi_servers.py index e348afbd08c..a1ae70a1237 100644 --- a/tests/functional/event_handler/test_openapi_servers.py +++ b/tests/functional/event_handler/test_openapi_servers.py @@ -1,9 +1,9 @@ -from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver +from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver from aws_lambda_powertools.event_handler.openapi.models import Server def test_openapi_schema_default_server(): - app = ApiGatewayResolver() + app = APIGatewayRestResolver() schema = app.get_openapi_schema(title="Hello API", version="1.0.0") assert schema.servers @@ -12,7 +12,7 @@ def test_openapi_schema_default_server(): def test_openapi_schema_custom_server(): - app = ApiGatewayResolver() + app = APIGatewayRestResolver() schema = app.get_openapi_schema( title="Hello API", diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index b34d3f7c9d9..fe1fb2816d6 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from typing_extensions import Annotated -from aws_lambda_powertools.event_handler import ApiGatewayResolver +from aws_lambda_powertools.event_handler import APIGatewayRestResolver from aws_lambda_powertools.event_handler.openapi.params import Body from tests.functional.utils import load_event @@ -11,7 +11,7 @@ def test_validate_scalars(): - app = ApiGatewayResolver(enable_validation=True) + app = APIGatewayRestResolver(enable_validation=True) @app.get("/users/") def handler(user_id: int): @@ -32,7 +32,7 @@ def handler(user_id: int): def test_validate_scalars_with_default(): - app = ApiGatewayResolver(enable_validation=True) + app = APIGatewayRestResolver(enable_validation=True) @app.get("/users/") def handler(user_id: int = 123): @@ -53,7 +53,7 @@ def handler(user_id: int = 123): def test_validate_scalars_with_default_and_optional(): - app = ApiGatewayResolver(enable_validation=True) + app = APIGatewayRestResolver(enable_validation=True) @app.get("/users/") def handler(user_id: int = 123, include_extra: bool = False): @@ -74,7 +74,7 @@ def handler(user_id: int = 123, include_extra: bool = False): def test_validate_return_type(): - app = ApiGatewayResolver(enable_validation=True) + app = APIGatewayRestResolver(enable_validation=True) @app.get("/") def handler() -> int: @@ -88,7 +88,7 @@ def handler() -> int: def test_validate_return_model(): - app = ApiGatewayResolver(enable_validation=True) + app = APIGatewayRestResolver(enable_validation=True) class Model(BaseModel): name: str @@ -106,7 +106,7 @@ def handler() -> Model: def test_validate_invalid_return_model(): - app = ApiGatewayResolver(enable_validation=True) + app = APIGatewayRestResolver(enable_validation=True) class Model(BaseModel): name: str @@ -124,7 +124,7 @@ def handler() -> Model: def test_validate_body_param(): - app = ApiGatewayResolver(enable_validation=True) + app = APIGatewayRestResolver(enable_validation=True) class Model(BaseModel): name: str @@ -144,7 +144,7 @@ def handler(user: Model) -> Model: def test_validate_embed_body_param(): - app = ApiGatewayResolver(enable_validation=True) + app = APIGatewayRestResolver(enable_validation=True) class Model(BaseModel): name: str From c87e47e5403ed39ee56660c3246cec1fb7b9a2cb Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 13:41:05 +0200 Subject: [PATCH 52/75] fix: move from flake8 to ruff --- .flake8 | 1 - ruff.toml | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/.flake8 b/.flake8 index 0f309f6621a..1db8406d9e4 100644 --- a/.flake8 +++ b/.flake8 @@ -8,7 +8,6 @@ per-file-ignores = tests/e2e/utils/data_builder/__init__.py:F401 tests/e2e/utils/data_fetcher/__init__.py:F401 aws_lambda_powertools/utilities/data_classes/s3_event.py:A003 - aws_lambda_powertools/event_handler/openapi/compat.py:F401 [isort] multi_line_output = 3 diff --git a/ruff.toml b/ruff.toml index a0f8e4fe74f..553a8c47b3d 100644 --- a/ruff.toml +++ b/ruff.toml @@ -87,5 +87,6 @@ split-on-trailing-comma = true "tests/e2e/utils/data_fetcher/__init__.py" = ["F401"] "aws_lambda_powertools/utilities/data_classes/s3_event.py" = ["A003"] "aws_lambda_powertools/utilities/parser/models/__init__.py" = ["E402"] +"aws_lambda_powertools/event_handler/openapi/compat.py" = ["F401"] # Maintenance: we're keeping EphemeralMetrics code in case of Hyrum's law so we can quickly revert it "aws_lambda_powertools/metrics/metrics.py" = ["ERA001"] From 9427ed6a2d3dd4aaead96220e9adb9aa3f16a721 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 14:07:54 +0200 Subject: [PATCH 53/75] fix: customizing responses --- .../event_handler/api_gateway.py | 95 ++++++++++--------- .../middlewares/openapi_validation.py | 2 +- .../event_handler/test_openapi_responses.py | 49 ++++++++++ 3 files changed, 100 insertions(+), 46 deletions(-) create mode 100644 tests/functional/event_handler/test_openapi_responses.py diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index bf7e083f91c..3b61c903f75 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -252,7 +252,7 @@ def __init__( cache_control: Optional[str], summary: Optional[str], description: Optional[str], - responses: Optional[Dict[Union[int, str], Dict[str, Any]]], + responses: Optional[Dict[int, Dict[str, Any]]], response_description: Optional[str], tags: Optional[List["Tag"]], operation_id: Optional[str], @@ -281,7 +281,7 @@ def __init__( The OpenAPI summary for this route description: Optional[str] The OpenAPI description for this route - responses: Optional[Dict[Union[int, str], Dict[str, Any]]] + responses: Optional[Dict[int, Dict[str, Any]]] The OpenAPI responses for this route response_description: Optional[str] The OpenAPI response description for this route @@ -467,40 +467,45 @@ def _get_openapi_path( operation["requestBody"] = request_body_oai # Add the response to the OpenAPI operation - responses = operation.setdefault("responses", {}) - success_response = responses.setdefault("200", {}) - success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION - success_response["content"] = {"application/json": {"schema": {}}} - json_response = success_response["content"].setdefault("application/json", {}) - - # Add the response schema to the OpenAPI response - json_response.update( - self._openapi_operation_return( - operation_id=self.operation_id, - param=dependant.return_param, - model_name_map=model_name_map, - field_mapping=field_mapping, - ), - ) + if self.responses: + operation["responses"] = self.responses + else: + responses = operation.setdefault("responses", self.responses or {}) + + # Handle the default 200 response + success_response = responses.setdefault(200, {}) + success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION + success_response["content"] = {"application/json": {"schema": {}}} + json_response = success_response["content"].setdefault("application/json", {}) + + # Add the response schema to the OpenAPI response + json_response.update( + self._openapi_operation_return( + operation_id=self.operation_id, + param=dependant.return_param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ), + ) - # Add validation responses - operation["responses"]["422"] = { - "description": "Validation Error", - "content": { - "application/json": { - "schema": {"$ref": COMPONENT_REF_PREFIX + "HTTPValidationError"}, + # Add validation responses + operation["responses"][422] = { + "description": "Validation Error", + "content": { + "application/json": { + "schema": {"$ref": COMPONENT_REF_PREFIX + "HTTPValidationError"}, + }, }, - }, - } + } - # We need to add the validation error schema to the definitions once - if "ValidationError" not in definitions: - definitions.update( - { - "ValidationError": validation_error_definition, - "HTTPValidationError": validation_error_response_definition, - }, - ) + # We need to add the validation error schema to the definitions once + if "ValidationError" not in definitions: + definitions.update( + { + "ValidationError": validation_error_definition, + "HTTPValidationError": validation_error_response_definition, + }, + ) path[self.method.lower()] = operation @@ -781,7 +786,7 @@ def route( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, @@ -838,7 +843,7 @@ def get( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, @@ -889,7 +894,7 @@ def post( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, @@ -941,7 +946,7 @@ def put( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, @@ -993,7 +998,7 @@ def delete( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, @@ -1044,7 +1049,7 @@ def patch( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, @@ -1262,7 +1267,7 @@ def __init__( debug: Optional[bool] Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_DEV" environment variable - serializer : Callable, optional + serializer: Callable, optional function to serialize `obj` to a JSON formatted `str`, by default json.dumps strip_prefixes: List[Union[str, Pattern]], optional optional list of prefixes to be removed from the request path before doing the routing. @@ -1293,8 +1298,8 @@ def __init__( self.use([OpenAPIValidationMiddleware()]) - # When using validation, we need to skip the serializer, as the middleware is doing it automatically - # However, if the user is using a custom serializer, we need to abort + # When using validation, we need to skip the serializer, as the middleware is doing it automatically. + # However, if the user is using a custom serializer, we need to abort. if serializer: raise ValueError("Cannot use a custom serializer when using validation") @@ -1494,7 +1499,7 @@ def route( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, @@ -1931,7 +1936,7 @@ def route( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, @@ -2003,7 +2008,7 @@ def route( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 47756cabc8c..4c84c92c908 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -96,7 +96,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> app.context["_route_args"] = values response = next_middleware(app) - # Process the response body, if it exists + # Process the response body if it exists raw_response = jsonable_encoder(response.body) # Validate and serialize the response diff --git a/tests/functional/event_handler/test_openapi_responses.py b/tests/functional/event_handler/test_openapi_responses.py new file mode 100644 index 00000000000..e566a2f5ca1 --- /dev/null +++ b/tests/functional/event_handler/test_openapi_responses.py @@ -0,0 +1,49 @@ +from aws_lambda_powertools.event_handler import APIGatewayRestResolver + + +def test_openapi_default_response(): + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/") + def handler(): + pass + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert "200" in responses.keys() + assert responses["200"].description == "Successful Response" + + assert "422" in responses.keys() + assert responses["422"].description == "Validation Error" + + +def test_openapi_200_response_with_description(): + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/", response_description="Custom response") + def handler(): + return {"message": "hello world"} + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert "200" in responses.keys() + assert responses["200"].description == "Custom response" + + assert "422" in responses.keys() + assert responses["422"].description == "Validation Error" + + +def test_openapi_200_custom_response(): + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/", responses={202: {"description": "Custom response"}}) + def handler(): + return {"message": "hello world"} + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert "202" in responses.keys() + assert responses["202"].description == "Custom response" + + assert "200" not in responses.keys() + assert "422" not in responses.keys() From 2cb7c6766ca117119514a4bc13d37d594306c41b Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 14:20:42 +0200 Subject: [PATCH 54/75] fix: add documentation to a method --- aws_lambda_powertools/event_handler/api_gateway.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 3b61c903f75..6bb2589e99f 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -433,6 +433,9 @@ def _get_openapi_path( model_name_map: Dict["TypeModelOrEnum", str], field_mapping: Dict[Tuple["ModelField", Literal["validation", "serialization"]], "JsonSchemaValue"], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Returns the OpenAPI path and definitions for the route. + """ from aws_lambda_powertools.event_handler.openapi.dependant import get_flat_params path = {} From 0a695822dda0f0ba1b25b05d26661e7facbffad1 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 14:23:03 +0200 Subject: [PATCH 55/75] fix: more explicit comments --- aws_lambda_powertools/event_handler/api_gateway.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 6bb2589e99f..941b12d757d 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -471,17 +471,17 @@ def _get_openapi_path( # Add the response to the OpenAPI operation if self.responses: + # If the user supplied responses, we use them and don't set a default 200 response operation["responses"] = self.responses else: + # Set the default 200 response responses = operation.setdefault("responses", self.responses or {}) - - # Handle the default 200 response success_response = responses.setdefault(200, {}) success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION success_response["content"] = {"application/json": {"schema": {}}} json_response = success_response["content"].setdefault("application/json", {}) - # Add the response schema to the OpenAPI response + # Add the response schema to the OpenAPI 200 response json_response.update( self._openapi_operation_return( operation_id=self.operation_id, @@ -491,7 +491,7 @@ def _get_openapi_path( ), ) - # Add validation responses + # Add validation failure response (422) operation["responses"][422] = { "description": "Validation Error", "content": { @@ -501,7 +501,7 @@ def _get_openapi_path( }, } - # We need to add the validation error schema to the definitions once + # Add the validation error schema to the definitions, but only if it hasn't been added yet if "ValidationError" not in definitions: definitions.update( { From ab21cb34d9698c7689f1c15415cb4aab3546a5dc Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 14:24:49 +0200 Subject: [PATCH 56/75] fix: typo --- .../event_handler/middlewares/openapi_validation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 4c84c92c908..8bc420a080f 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -29,7 +29,7 @@ class OpenAPIValidationMiddleware(BaseMiddlewareHandler): """ OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It - should not be used directly, but rather through the `enable_validation` parameter of the `APIGatewayProxyHandler`. + should not be used directly, but rather through the `enable_validation` parameter of the `ApiGatewayResolver`. Examples -------- @@ -40,13 +40,13 @@ class OpenAPIValidationMiddleware(BaseMiddlewareHandler): from pydantic import BaseModel from aws_lambda_powertools.event_handler.api_gateway import ( - APIGatewayProxyHandler, + APIGatewayRestResolver, ) class Todo(BaseModel): name: str - app = APIGatewayProxyHandler(enable_validation=True) + app = APIGatewayRestResolver(enable_validation=True) @app.get("/todos") def get_todos(): List[Todo]: From 2fa15a463de2988227ceee73dd05b4112c5ed837 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 14:31:29 +0200 Subject: [PATCH 57/75] fix: add extra comment --- .../event_handler/middlewares/openapi_validation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 8bc420a080f..3913e5cf973 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -159,6 +159,7 @@ def _serialize_response( exclude_none=exclude_none, ) else: + # Just serialize the response content returned from the handler return jsonable_encoder(response_content) def _prepare_response_content( From efd339c3d8b84ebf4da4509212929f50b54e1e95 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 14:35:18 +0200 Subject: [PATCH 58/75] fix: comment --- .../event_handler/middlewares/openapi_validation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 3913e5cf973..dc9554a65e7 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -239,6 +239,8 @@ def _request_params_to_args( raise AssertionError(f"Expected Param field_info, got {field_info}") loc = (field_info.in_.value, field.alias) + + # If we don't have a value, see if it's required or has a default if value is None: if field.required: errors.append(get_missing_field_error(loc=loc)) @@ -246,6 +248,7 @@ def _request_params_to_args( values[field.name] = deepcopy(field.default) continue + # Finally, validate the value values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) return values, errors From 0c2db13dc8e5e3fe814e4146e3f1ca09451fd0fa Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 14:37:43 +0200 Subject: [PATCH 59/75] fix: add comments --- .../event_handler/middlewares/openapi_validation.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index dc9554a65e7..1ef13dd86bb 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -271,12 +271,16 @@ def _request_body_to_args( ) for field in required_params: + # This sets the location to: + # { "user": { object } } if field.alias == user + # { { object } if field_alias is omitted loc: Tuple[str, ...] = ("body", field.alias) if field_alias_omitted: loc = ("body",) value: Optional[Any] = None + # Now that we know what to look for, try to get the value from the received body if received_body is not None: try: value = received_body.get(field.alias) @@ -294,6 +298,7 @@ def _request_body_to_args( # MAINTENANCE: Handle byte and file fields + # Finally, validate the value values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) return values, errors From c2d7bc32288c1a395247ab14db4de1e03930e949 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 14:39:10 +0200 Subject: [PATCH 60/75] fix: comments --- .../event_handler/middlewares/openapi_validation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 1ef13dd86bb..ea7b303bfa5 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -311,6 +311,9 @@ def _validate_field( loc: Tuple[str, ...], existing_errors: List[Dict[str, Any]], ): + """ + Validate a field, and append any errors to the existing_errors list. + """ validated_value, errors = field.validate(value, value, loc=loc) if isinstance(errors, list): @@ -331,6 +334,7 @@ def _get_embed_body( field_info = field.field_info embed = getattr(field_info, "embed", None) + # If the field is an embed, and the field alias is omitted, we need to wrap the received body in the field alias. field_alias_omitted = len(required_params) == 1 and not embed if field_alias_omitted: received_body = {field.alias: received_body} From a0a9adc052b57d3f85db6c50a32f76a47eba7cf1 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 14:40:44 +0200 Subject: [PATCH 61/75] fix: typo --- aws_lambda_powertools/event_handler/api_gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 941b12d757d..5af7002de0c 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -567,7 +567,7 @@ def _openapi_operation_request_body( from aws_lambda_powertools.event_handler.openapi.compat import ModelField, get_schema_from_model_field from aws_lambda_powertools.event_handler.openapi.params import Body - # Check tat there is a body field and it's a Pydantic's model field + # Check that there is a body field and it's a Pydantic's model field if not body_field: return None From 526d9f7e3864394eaf28f21ff8186ab2a5b8e6b7 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 11 Oct 2023 14:45:28 +0200 Subject: [PATCH 62/75] fix: remove leftover comment --- aws_lambda_powertools/event_handler/openapi/encoders.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/encoders.py b/aws_lambda_powertools/event_handler/openapi/encoders.py index c31df5c5ed9..94c1cb5d659 100644 --- a/aws_lambda_powertools/event_handler/openapi/encoders.py +++ b/aws_lambda_powertools/event_handler/openapi/encoders.py @@ -36,8 +36,6 @@ def jsonable_encoder( # noqa: PLR0911 This is a modified version of fastapi.encoders.jsonable_encoder that supports encoding of pydantic.BaseModel objects. - This function is used to encode the response body of a FastAPI endpoint. - Parameters ---------- obj : Any From 76f3a32e75e1cd88a67d5529f1f8321d78d960dc Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 17 Oct 2023 10:53:56 +0200 Subject: [PATCH 63/75] fix: addressing comments --- .../event_handler/api_gateway.py | 3 +-- .../event_handler/openapi/models.py | 2 +- .../event_handler/openapi/params.py | 2 +- aws_lambda_powertools/shared/types.py | 16 +++++++++++----- .../event_handler/test_openapi_params.py | 2 +- .../test_openapi_validation_middleware.py | 2 +- 6 files changed, 16 insertions(+), 11 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 5af7002de0c..7bd897e8036 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -26,8 +26,6 @@ cast, ) -from typing_extensions import Literal - from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError from aws_lambda_powertools.event_handler.openapi.types import ( @@ -39,6 +37,7 @@ from aws_lambda_powertools.shared.cookies import Cookie from aws_lambda_powertools.shared.functions import powertools_dev_is_set from aws_lambda_powertools.shared.json_encoder import Encoder +from aws_lambda_powertools.shared.types import Literal from aws_lambda_powertools.utilities.data_classes import ( ALBEvent, APIGatewayProxyEvent, diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 4b5218f9833..3a3433cf531 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -2,10 +2,10 @@ from typing import Any, Dict, List, Optional, Set, Union from pydantic import AnyUrl, BaseModel, Field -from typing_extensions import Annotated, Literal from aws_lambda_powertools.event_handler.openapi.compat import model_rebuild from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2 +from aws_lambda_powertools.shared.types import Annotated, Literal """ The code defines Pydantic models for the various OpenAPI objects like OpenAPI, PathItem, Operation, Parameter etc. diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 41ecd8130c5..e3f53089495 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -4,7 +4,6 @@ from pydantic import BaseConfig from pydantic.fields import FieldInfo -from typing_extensions import Annotated, Literal, get_args, get_origin from aws_lambda_powertools.event_handler.openapi.compat import ( ModelField, @@ -16,6 +15,7 @@ get_annotation_from_field_info, ) from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2, CacheKey +from aws_lambda_powertools.shared.types import Annotated, Literal, get_args, get_origin """ This turns the low-level function signature into typed, validated Pydantic models for consumption. diff --git a/aws_lambda_powertools/shared/types.py b/aws_lambda_powertools/shared/types.py index 633db46c587..2c5bab1b8b6 100644 --- a/aws_lambda_powertools/shared/types.py +++ b/aws_lambda_powertools/shared/types.py @@ -2,10 +2,14 @@ from typing import Any, Callable, Dict, List, TypeVar, Union if sys.version_info >= (3, 8): - from typing import Literal, Protocol, TypedDict + from typing import Literal, Protocol, TypedDict, get_args else: - from typing_extensions import Literal, Protocol, TypedDict + from typing_extensions import Literal, Protocol, TypedDict, get_args +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated if sys.version_info >= (3, 11): from typing import NotRequired @@ -13,13 +17,15 @@ from typing_extensions import NotRequired +# Even though `get_origin` was added in Python 3.8, it only handles Annotated correctly on 3.10. +# So for python < 3.10 we use the backport from typing_extensions. if sys.version_info >= (3, 10): - from typing import TypeAlias + from typing import TypeAlias, get_origin else: - from typing_extensions import TypeAlias + from typing_extensions import TypeAlias, get_origin AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001 # JSON primitives only, mypy doesn't support recursive tho JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]] -__all__ = ["Protocol", "TypedDict", "Literal", "NotRequired", "TypeAlias"] +__all__ = ["get_args", "get_origin", "Annotated", "Protocol", "TypedDict", "Literal", "NotRequired", "TypeAlias"] diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index ec1cd4670a9..655f3bbbd94 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -3,7 +3,6 @@ from typing import List from pydantic import BaseModel -from typing_extensions import Annotated from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver from aws_lambda_powertools.event_handler.openapi.models import ( @@ -13,6 +12,7 @@ Schema, ) from aws_lambda_powertools.event_handler.openapi.params import Body, Query +from aws_lambda_powertools.shared.types import Annotated JSON_CONTENT_TYPE = "application/json" diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index fe1fb2816d6..c08200ca3a1 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -1,10 +1,10 @@ import json from pydantic import BaseModel -from typing_extensions import Annotated from aws_lambda_powertools.event_handler import APIGatewayRestResolver from aws_lambda_powertools.event_handler.openapi.params import Body +from aws_lambda_powertools.shared.types import Annotated from tests.functional.utils import load_event LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") From e243200c6bf6b1749ad808f9c9c2df2fa121c6d6 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 17 Oct 2023 11:14:14 +0200 Subject: [PATCH 64/75] fix: pydantic2 models --- .../event_handler/openapi/models.py | 2 +- .../event_handler/test_openapi_params.py | 10 ++++---- .../event_handler/test_openapi_responses.py | 24 +++++++++---------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 3a3433cf531..80818315f18 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -373,7 +373,7 @@ class Operation(BaseModel): parameters: Optional[List[Union[Parameter, Reference]]] = None requestBody: Optional[Union[RequestBody, Reference]] = None # Using Any for Specification Extensions - responses: Optional[Dict[str, Union[Response, Any]]] = None + responses: Optional[Dict[int, Union[Response, Any]]] = None callbacks: Optional[Dict[str, Union[Dict[str, "PathItem"], Reference]]] = None deprecated: Optional[bool] = None security: Optional[List[Dict[str, List[str]]]] = None diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index 655f3bbbd94..41c2aa8e65d 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -39,8 +39,8 @@ def handler(): assert get.operationId == "handler__get" assert get.responses is not None - assert "200" in get.responses.keys() - response = get.responses["200"] + assert 200 in get.responses.keys() + response = get.responses[200] assert response.description == "Successful Response" assert JSON_CONTENT_TYPE in response.content @@ -140,7 +140,7 @@ def handler() -> str: get = schema.paths["/"].get assert get.parameters is None - response = get.responses["200"].content[JSON_CONTENT_TYPE] + response = get.responses[200].content[JSON_CONTENT_TYPE] assert response.schema_.title == "Return" assert response.schema_.type == "string" @@ -161,7 +161,7 @@ def handler() -> User: get = schema.paths["/"].get assert get.parameters is None - response = get.responses["200"].content[JSON_CONTENT_TYPE] + response = get.responses[200].content[JSON_CONTENT_TYPE] reference = response.schema_ assert reference.ref == "#/components/schemas/User" @@ -214,7 +214,7 @@ def handler() -> User: get = schema.paths["/"].get assert get.parameters is None - response = get.responses["200"].content[JSON_CONTENT_TYPE] + response = get.responses[200].content[JSON_CONTENT_TYPE] reference = response.schema_ assert reference.ref == "#/components/schemas/User" diff --git a/tests/functional/event_handler/test_openapi_responses.py b/tests/functional/event_handler/test_openapi_responses.py index e566a2f5ca1..bd470867428 100644 --- a/tests/functional/event_handler/test_openapi_responses.py +++ b/tests/functional/event_handler/test_openapi_responses.py @@ -10,11 +10,11 @@ def handler(): schema = app.get_openapi_schema() responses = schema.paths["/"].get.responses - assert "200" in responses.keys() - assert responses["200"].description == "Successful Response" + assert 200 in responses.keys() + assert responses[200].description == "Successful Response" - assert "422" in responses.keys() - assert responses["422"].description == "Validation Error" + assert 422 in responses.keys() + assert responses[422].description == "Validation Error" def test_openapi_200_response_with_description(): @@ -26,11 +26,11 @@ def handler(): schema = app.get_openapi_schema() responses = schema.paths["/"].get.responses - assert "200" in responses.keys() - assert responses["200"].description == "Custom response" + assert 200 in responses.keys() + assert responses[200].description == "Custom response" - assert "422" in responses.keys() - assert responses["422"].description == "Validation Error" + assert 422 in responses.keys() + assert responses[422].description == "Validation Error" def test_openapi_200_custom_response(): @@ -42,8 +42,8 @@ def handler(): schema = app.get_openapi_schema() responses = schema.paths["/"].get.responses - assert "202" in responses.keys() - assert responses["202"].description == "Custom response" + assert 202 in responses.keys() + assert responses[202].description == "Custom response" - assert "200" not in responses.keys() - assert "422" not in responses.keys() + assert 200 not in responses.keys() + assert 422 not in responses.keys() From 64c6192d6ce30eea0760750b2e2ca6acd2f65d29 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 17 Oct 2023 11:48:48 +0200 Subject: [PATCH 65/75] fix: typing extension problems --- .../event_handler/openapi/dependant.py | 2 +- aws_lambda_powertools/shared/types.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 6f05ed55e6d..1a4fb5d0102 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -239,7 +239,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: return False else: if not isinstance(param_field.field_info, Body): - raise AssertionError(f"Param: {param_field.name} can only be a request body, using Body()") + raise AssertionError(f"Param: {param_field.name} can only be a request body, use Body()") return True diff --git a/aws_lambda_powertools/shared/types.py b/aws_lambda_powertools/shared/types.py index 2c5bab1b8b6..100005159e4 100644 --- a/aws_lambda_powertools/shared/types.py +++ b/aws_lambda_powertools/shared/types.py @@ -2,9 +2,9 @@ from typing import Any, Callable, Dict, List, TypeVar, Union if sys.version_info >= (3, 8): - from typing import Literal, Protocol, TypedDict, get_args + from typing import Literal, Protocol, TypedDict else: - from typing_extensions import Literal, Protocol, TypedDict, get_args + from typing_extensions import Literal, Protocol, TypedDict if sys.version_info >= (3, 9): from typing import Annotated @@ -17,12 +17,12 @@ from typing_extensions import NotRequired -# Even though `get_origin` was added in Python 3.8, it only handles Annotated correctly on 3.10. +# Even though `get_args` and `get_origin` were added in Python 3.8, they only handle Annotated correctly on 3.10. # So for python < 3.10 we use the backport from typing_extensions. if sys.version_info >= (3, 10): - from typing import TypeAlias, get_origin + from typing import TypeAlias, get_args, get_origin else: - from typing_extensions import TypeAlias, get_origin + from typing_extensions import TypeAlias, get_args, get_origin AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001 # JSON primitives only, mypy doesn't support recursive tho From 006f8544cb116b310f8fb8d0e652da48d2c3e557 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 18 Oct 2023 17:50:43 +0100 Subject: [PATCH 66/75] Adding more tests and fixing small things --- .../event_handler/api_gateway.py | 24 +-- .../event_handler/lambda_function_url.py | 2 +- .../test_openapi_validation_middleware.py | 164 ++++++++++++++++++ 3 files changed, 177 insertions(+), 13 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 7bd897e8036..1383b74ada0 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -789,7 +789,7 @@ def route( summary: Optional[str] = None, description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, - response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, @@ -846,7 +846,7 @@ def get( summary: Optional[str] = None, description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, - response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, @@ -897,7 +897,7 @@ def post( summary: Optional[str] = None, description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, - response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, @@ -949,7 +949,7 @@ def put( summary: Optional[str] = None, description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, - response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, @@ -1001,7 +1001,7 @@ def delete( summary: Optional[str] = None, description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, - response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, @@ -1052,7 +1052,7 @@ def patch( summary: Optional[str] = None, description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, - response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, middlewares: Optional[List[Callable]] = None, @@ -1257,7 +1257,7 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, - enable_validation: Optional[bool] = False, + enable_validation: bool = False, ): """ Parameters @@ -1502,7 +1502,7 @@ def route( summary: Optional[str] = None, description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, - response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, @@ -1988,7 +1988,7 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, - enable_validation: Optional[bool] = False, + enable_validation: bool = False, ): """Amazon API Gateway REST and HTTP API v1 payload resolver""" super().__init__( @@ -2011,7 +2011,7 @@ def route( summary: Optional[str] = None, description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, - response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, @@ -2047,7 +2047,7 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, - enable_validation: Optional[bool] = False, + enable_validation: bool = False, ): """Amazon API Gateway HTTP API v2 payload resolver""" super().__init__( @@ -2069,7 +2069,7 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, - enable_validation: Optional[bool] = False, + enable_validation: bool = False, ): """Amazon Application Load Balancer (ALB) resolver""" super().__init__(ProxyEventType.ALBEvent, cors, debug, serializer, strip_prefixes, enable_validation) diff --git a/aws_lambda_powertools/event_handler/lambda_function_url.py b/aws_lambda_powertools/event_handler/lambda_function_url.py index ff7adeb6412..bacdc8549c7 100644 --- a/aws_lambda_powertools/event_handler/lambda_function_url.py +++ b/aws_lambda_powertools/event_handler/lambda_function_url.py @@ -52,7 +52,7 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, - enable_validation: Optional[bool] = False, + enable_validation: bool = False, ): super().__init__( ProxyEventType.LambdaFunctionUrlEvent, diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index c08200ca3a1..2aebfe5dcab 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -1,4 +1,9 @@ import json +from dataclasses import dataclass +from decimal import Decimal +from enum import Enum +from pathlib import PurePath +from typing import Tuple from pydantic import BaseModel @@ -11,8 +16,10 @@ def test_validate_scalars(): + # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) + # WHEN a handler is defined with a scalar parameter @app.get("/users/") def handler(user_id: int): print(user_id) @@ -20,20 +27,24 @@ def handler(user_id: int): # sending a number LOAD_GW_EVENT["path"] = "/users/123" + # THEN the handler should be invoked and return 200 result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 # sending a string LOAD_GW_EVENT["path"] = "/users/abc" + # THEN the handler should be invoked and return 422 result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 422 assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) def test_validate_scalars_with_default(): + # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) + # WHEN a handler is defined with a default scalar parameter @app.get("/users/") def handler(user_id: int = 123): print(user_id) @@ -41,20 +52,24 @@ def handler(user_id: int = 123): # sending a number LOAD_GW_EVENT["path"] = "/users/123" + # THEN the handler should be invoked and return 200 result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 # sending a string LOAD_GW_EVENT["path"] = "/users/abc" + # THEN the handler should be invoked and return 422 result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 422 assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) def test_validate_scalars_with_default_and_optional(): + # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) + # WHEN a handler is defined with a default scalar parameter @app.get("/users/") def handler(user_id: int = 123, include_extra: bool = False): print(user_id) @@ -62,74 +77,215 @@ def handler(user_id: int = 123, include_extra: bool = False): # sending a number LOAD_GW_EVENT["path"] = "/users/123" + # THEN the handler should be invoked and return 200 result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 # sending a string LOAD_GW_EVENT["path"] = "/users/abc" + # THEN the handler should be invoked and return 422 result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 422 assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) def test_validate_return_type(): + # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) + # WHEN a handler is defined with a return type @app.get("/") def handler() -> int: return 123 LOAD_GW_EVENT["path"] = "/" + # THEN the handler should be invoked and return 200 + # THEN the body must be 123 result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 assert result["body"] == 123 +def test_validate_return_tuple(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + sample_tuple = (1, 2, 3) + + # WHEN a handler is defined with a return type as Tuple + @app.get("/") + def handler() -> Tuple: + return sample_tuple + + LOAD_GW_EVENT["path"] = "/" + + # THEN the handler should be invoked and return 200 + # THEN the body must be a tuple + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == list(sample_tuple) + + +def test_validate_return_decimal_as_int(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + sample_decimal = Decimal(10) + + # WHEN a handler is defined with a return type as Decimal + @app.get("/") + def handler() -> Decimal: + return sample_decimal + + LOAD_GW_EVENT["path"] = "/" + + # THEN the handler should be invoked and return 200 + # THEN the body must be a decimal as int + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == sample_decimal + + +def test_validate_return_decimal_as_float(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + sample_decimal = Decimal(10.20) + + # WHEN a handler is defined with a return type as Decimal + @app.get("/") + def handler() -> Decimal: + return sample_decimal + + LOAD_GW_EVENT["path"] = "/" + + # THEN the handler should be invoked and return 200 + # THEN the body must be a decimal as float + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == sample_decimal + + +def test_validate_return_purepath(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + sample_path = PurePath(__file__) + + # WHEN a handler is defined with a return type as string + # WHEN return value is a PurePath + @app.get("/") + def handler() -> str: + return sample_path + + LOAD_GW_EVENT["path"] = "/" + + # THEN the handler should be invoked and return 200 + # THEN the body must be a string + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == sample_path.as_posix() + + +def test_validate_return_enum(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(Enum): + name = "powertools" + + # WHEN a handler is defined with a return type as Enum + @app.get("/") + def handler() -> Model: + return Model.name.value + + LOAD_GW_EVENT["path"] = "/" + + # THEN the handler should be invoked and return 200 + # THEN the body must be a string + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == "powertools" + + +def test_validate_return_dataclass(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + @dataclass + class Model: + name: str + age: int + + # WHEN a handler is defined with a return type as dataclass + @app.get("/") + def handler() -> Model: + return Model(name="John", age=30) + + LOAD_GW_EVENT["path"] = "/" + + # THEN the handler should be invoked and return 200 + # THEN the body must be a dict + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == {"name": "John", "age": 30} + + def test_validate_return_model(): + # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) class Model(BaseModel): name: str age: int + # WHEN a handler is defined with a return type as Pydantic model @app.get("/") def handler() -> Model: return Model(name="John", age=30) LOAD_GW_EVENT["path"] = "/" + # THEN the handler should be invoked and return 200 + # THEN the body must be a dict result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 assert result["body"] == {"name": "John", "age": 30} def test_validate_invalid_return_model(): + # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) class Model(BaseModel): name: str age: int + # WHEN a handler is defined with a return type as Pydantic model @app.get("/") def handler() -> Model: return {"name": "John"} # type: ignore LOAD_GW_EVENT["path"] = "/" + # THEN the handler should be invoked and return 422 + # THEN the body must be a dict result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 422 assert "missing" in result["body"] def test_validate_body_param(): + # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) class Model(BaseModel): name: str age: int + # WHEN a handler is defined with a body parameter @app.post("/") def handler(user: Model) -> Model: return user @@ -138,18 +294,22 @@ def handler(user: Model) -> Model: LOAD_GW_EVENT["path"] = "/" LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + # THEN the handler should be invoked and return 200 + # THEN the body must be a dict result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 assert result["body"] == {"name": "John", "age": 30} def test_validate_embed_body_param(): + # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) class Model(BaseModel): name: str age: int + # WHEN a handler is defined with a body parameter @app.post("/") def handler(user: Annotated[Model, Body(embed=True)]) -> Model: return user @@ -158,10 +318,14 @@ def handler(user: Annotated[Model, Body(embed=True)]) -> Model: LOAD_GW_EVENT["path"] = "/" LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + # THEN the handler should be invoked and return 422 + # THEN the body must be a dict result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 422 assert "missing" in result["body"] + # THEN the handler should be invoked and return 200 + # THEN the body must be a dict LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}}) result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 From 0e79b812c963abeeb06105a6137b1597c61d53aa Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 18 Oct 2023 17:57:47 +0100 Subject: [PATCH 67/75] Adding more tests and fixing small things --- .../event_handler/test_openapi_validation_middleware.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index 2aebfe5dcab..b028bfc8bc1 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -145,14 +145,14 @@ def handler() -> Decimal: # THEN the body must be a decimal as int result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 - assert result["body"] == sample_decimal + assert result["body"] == "10" def test_validate_return_decimal_as_float(): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) - sample_decimal = Decimal(10.20) + sample_decimal = Decimal(10.22) # WHEN a handler is defined with a return type as Decimal @app.get("/") @@ -165,7 +165,7 @@ def handler() -> Decimal: # THEN the body must be a decimal as float result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 - assert result["body"] == sample_decimal + assert result["body"] == "10.22" def test_validate_return_purepath(): From acf89280a1c6e37656e8dbbf806ad11158a35f3c Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 18 Oct 2023 18:03:44 +0100 Subject: [PATCH 68/75] Adding more tests and fixing small things --- .../event_handler/test_openapi_validation_middleware.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index b028bfc8bc1..b18d1ee91df 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -145,7 +145,7 @@ def handler() -> Decimal: # THEN the body must be a decimal as int result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 - assert result["body"] == "10" + assert result["body"] == 10 def test_validate_return_decimal_as_float(): @@ -165,7 +165,7 @@ def handler() -> Decimal: # THEN the body must be a decimal as float result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 - assert result["body"] == "10.22" + assert result["body"] == 10.22 def test_validate_return_purepath(): From 4779d394f0f33fcce997ccd95d58b4abb5d9bf52 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 18 Oct 2023 18:14:49 +0100 Subject: [PATCH 69/75] Removing flaky tests --- .../test_openapi_validation_middleware.py | 41 ------------------- 1 file changed, 41 deletions(-) diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index b18d1ee91df..6b4b94405d8 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -1,6 +1,5 @@ import json from dataclasses import dataclass -from decimal import Decimal from enum import Enum from pathlib import PurePath from typing import Tuple @@ -128,46 +127,6 @@ def handler() -> Tuple: assert result["body"] == list(sample_tuple) -def test_validate_return_decimal_as_int(): - # GIVEN an APIGatewayRestResolver with validation enabled - app = APIGatewayRestResolver(enable_validation=True) - - sample_decimal = Decimal(10) - - # WHEN a handler is defined with a return type as Decimal - @app.get("/") - def handler() -> Decimal: - return sample_decimal - - LOAD_GW_EVENT["path"] = "/" - - # THEN the handler should be invoked and return 200 - # THEN the body must be a decimal as int - result = app(LOAD_GW_EVENT, {}) - assert result["statusCode"] == 200 - assert result["body"] == 10 - - -def test_validate_return_decimal_as_float(): - # GIVEN an APIGatewayRestResolver with validation enabled - app = APIGatewayRestResolver(enable_validation=True) - - sample_decimal = Decimal(10.22) - - # WHEN a handler is defined with a return type as Decimal - @app.get("/") - def handler() -> Decimal: - return sample_decimal - - LOAD_GW_EVENT["path"] = "/" - - # THEN the handler should be invoked and return 200 - # THEN the body must be a decimal as float - result = app(LOAD_GW_EVENT, {}) - assert result["statusCode"] == 200 - assert result["body"] == 10.22 - - def test_validate_return_purepath(): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) From 2e115ddbe123177fba6ca705451b5f32f265b466 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 23 Oct 2023 17:26:35 +0200 Subject: [PATCH 70/75] fix: improve coverage of encoders --- .../event_handler/test_openapi_encoders.py | 183 ++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 tests/functional/event_handler/test_openapi_encoders.py diff --git a/tests/functional/event_handler/test_openapi_encoders.py b/tests/functional/event_handler/test_openapi_encoders.py new file mode 100644 index 00000000000..25dcca3a1c0 --- /dev/null +++ b/tests/functional/event_handler/test_openapi_encoders.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass +from typing import List + +from pydantic import BaseModel +from pydantic.color import Color + +from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder + + +def test_openapi_encode_include(): + class User(BaseModel): + name: str + age: int + + result = jsonable_encoder(User(name="John", age=20), include=["name"]) + assert result == {"name": "John"} + + +def test_openapi_encode_exclude(): + class User(BaseModel): + name: str + age: int + + result = jsonable_encoder(User(name="John", age=20), exclude=["age"]) + assert result == {"name": "John"} + + +def test_openapi_encode_pydantic(): + class Order(BaseModel): + quantity: int + + class User(BaseModel): + name: str + order: Order + + result = jsonable_encoder(User(name="John", order=Order(quantity=2))) + assert result == {"name": "John", "order": {"quantity": 2}} + + +def test_openapi_encode_pydantic_root_types(): + class User(BaseModel): + __root__: List[str] + + result = jsonable_encoder(User(__root__=["John", "Jane"])) + assert result == ["John", "Jane"] + + +def test_openapi_encode_dataclass(): + @dataclass + class Order: + quantity: int + + @dataclass + class User: + name: str + order: Order + + result = jsonable_encoder(User(name="John", order=Order(quantity=2))) + assert result == {"name": "John", "order": {"quantity": 2}} + + +def test_openapi_encode_enum(): + from enum import Enum + + class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + result = jsonable_encoder(Color.RED) + assert result == "red" + + +def test_openapi_encode_purepath(): + from pathlib import PurePath + + result = jsonable_encoder(PurePath("/foo/bar")) + assert result == "/foo/bar" + + +def test_openapi_encode_scalars(): + result = jsonable_encoder("foo") + assert result == "foo" + + result = jsonable_encoder(1) + assert result == 1 + + result = jsonable_encoder(1.0) + assert result == 1.0 + + result = jsonable_encoder(True) + assert result is True + + result = jsonable_encoder(None) + assert result is None + + +def test_openapi_encode_dict(): + result = jsonable_encoder({"foo": "bar"}) + assert result == {"foo": "bar"} + + +def test_openapi_encode_dict_with_include(): + result = jsonable_encoder({"foo": "bar", "bar": "foo"}, include=["foo"]) + assert result == {"foo": "bar"} + + +def test_openapi_encode_dict_with_exclude(): + result = jsonable_encoder({"foo": "bar", "bar": "foo"}, exclude=["bar"]) + assert result == {"foo": "bar"} + + +def test_openapi_encode_sequences(): + result = jsonable_encoder(["foo", "bar"]) + assert result == ["foo", "bar"] + + result = jsonable_encoder(("foo", "bar")) + assert result == ["foo", "bar"] + + result = jsonable_encoder({"foo", "bar"}) + assert set(result) == {"foo", "bar"} + + result = jsonable_encoder(frozenset(("foo", "bar"))) + assert set(result) == {"foo", "bar"} + + +def test_openapi_encode_bytes(): + result = jsonable_encoder(b"foo") + assert result == "foo" + + +def test_openapi_encode_timedelta(): + from datetime import timedelta + + result = jsonable_encoder(timedelta(seconds=1)) + assert result == 1 + + +def test_openapi_encode_decimal(): + from decimal import Decimal + + result = jsonable_encoder(Decimal("1.0")) + assert result == 1.0 + + result = jsonable_encoder(Decimal("1")) + assert result == 1 + + +def test_openapi_encode_uuid(): + from uuid import UUID + + result = jsonable_encoder(UUID("123e4567-e89b-12d3-a456-426614174000")) + assert result == "123e4567-e89b-12d3-a456-426614174000" + + +def test_openapi_encode_encodable(): + from datetime import date, datetime, time + + result = jsonable_encoder(date(2021, 1, 1)) + assert result == "2021-01-01" + + result = jsonable_encoder(datetime(2021, 1, 1, 0, 0, 0)) + assert result == "2021-01-01T00:00:00" + + result = jsonable_encoder(time(0, 0, 0)) + assert result == "00:00:00" + + +def test_openapi_encode_subclasses(): + class MyColor(Color): + pass + + result = jsonable_encoder(MyColor("red")) + assert result == "red" + + +def test_openapi_encode_other(): + class User: + def __init__(self, name: str): + self.name = name + + result = jsonable_encoder(User(name="John")) + assert result == {"name": "John"} From bf0aaaebe66b88ddf1f90a5f792d6b01fe8153c3 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 23 Oct 2023 17:47:11 +0200 Subject: [PATCH 71/75] fix: mark test as pydantic v1 only --- .../functional/event_handler/test_openapi_encoders.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/functional/event_handler/test_openapi_encoders.py b/tests/functional/event_handler/test_openapi_encoders.py index 25dcca3a1c0..89fa3c0ba60 100644 --- a/tests/functional/event_handler/test_openapi_encoders.py +++ b/tests/functional/event_handler/test_openapi_encoders.py @@ -1,12 +1,22 @@ from dataclasses import dataclass from typing import List +import pytest from pydantic import BaseModel from pydantic.color import Color from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder +@pytest.fixture +def pydanticv1_only(): + from pydantic import __version__ + + version = __version__.split(".") + if version[0] != "1": + pytest.skip("pydanticv1 test only") + + def test_openapi_encode_include(): class User(BaseModel): name: str @@ -37,6 +47,7 @@ class User(BaseModel): assert result == {"name": "John", "order": {"quantity": 2}} +@pytest.mark.usefixtures("pydanticv1_only") def test_openapi_encode_pydantic_root_types(): class User(BaseModel): __root__: List[str] From 80375e4aa7d99f74138bd329f701be49645c1bbd Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 23 Oct 2023 17:59:29 +0200 Subject: [PATCH 72/75] fix: make sonarcube happy --- tests/functional/event_handler/test_openapi_encoders.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/functional/event_handler/test_openapi_encoders.py b/tests/functional/event_handler/test_openapi_encoders.py index 89fa3c0ba60..4062384b16e 100644 --- a/tests/functional/event_handler/test_openapi_encoders.py +++ b/tests/functional/event_handler/test_openapi_encoders.py @@ -1,3 +1,4 @@ +import math from dataclasses import dataclass from typing import List @@ -97,7 +98,7 @@ def test_openapi_encode_scalars(): assert result == 1 result = jsonable_encoder(1.0) - assert result == 1.0 + assert math.isclose(result, 1.0) result = jsonable_encoder(True) assert result is True @@ -151,7 +152,7 @@ def test_openapi_encode_decimal(): from decimal import Decimal result = jsonable_encoder(Decimal("1.0")) - assert result == 1.0 + assert math.isclose(result, 1.0) result = jsonable_encoder(Decimal("1")) assert result == 1 From 18f441855f70a9c4b40b866ab676deb4f4f3b45c Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 23 Oct 2023 19:02:24 +0200 Subject: [PATCH 73/75] fix: improve coverage of params.py --- .../event_handler/openapi/dependant.py | 2 +- .../event_handler/openapi/params.py | 19 +-------- .../event_handler/test_openapi_params.py | 41 ++++++++++++++++++- 3 files changed, 42 insertions(+), 20 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 1a4fb5d0102..8cbb8b942ed 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -258,7 +258,7 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]: A list of ModelField objects containing the flat parameters from the Dependant object. """ - flat_dependant = get_flat_dependant(dependant, skip_repeats=True) + flat_dependant = get_flat_dependant(dependant) return ( flat_dependant.path_params + flat_dependant.query_params diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index e3f53089495..797b44f6232 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -47,7 +47,6 @@ def __init__( cookie_params: Optional[List[ModelField]] = None, body_params: Optional[List[ModelField]] = None, return_param: Optional[ModelField] = None, - dependencies: Optional[List["Dependant"]] = None, name: Optional[str] = None, call: Optional[Callable[..., Any]] = None, request_param_name: Optional[str] = None, @@ -63,7 +62,6 @@ def __init__( self.cookie_params = cookie_params or [] self.body_params = body_params or [] self.return_param = return_param or None - self.dependencies = dependencies or [] self.request_param_name = request_param_name self.websocket_param_name = websocket_param_name self.http_connection_param_name = http_connection_param_name @@ -618,8 +616,6 @@ def __init__( def get_flat_dependant( dependant: Dependant, - *, - skip_repeats: bool = False, visited: Optional[List[CacheKey]] = None, ) -> Dependant: """ @@ -647,7 +643,7 @@ def get_flat_dependant( visited = [] visited.append(dependant.cache_key) - flat_dependant = Dependant( + return Dependant( path_params=dependant.path_params.copy(), query_params=dependant.query_params.copy(), header_params=dependant.header_params.copy(), @@ -655,19 +651,6 @@ def get_flat_dependant( body_params=dependant.body_params.copy(), path=dependant.path, ) - for sub_dependant in dependant.dependencies: - if skip_repeats and sub_dependant.cache_key in visited: - continue - - flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited) - - flat_dependant.path_params.extend(flat_sub.path_params) - flat_dependant.query_params.extend(flat_sub.query_params) - flat_dependant.header_params.extend(flat_sub.header_params) - flat_dependant.cookie_params.extend(flat_sub.cookie_params) - flat_dependant.body_params.extend(flat_sub.body_params) - - return flat_dependant def analyze_param( diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index 41c2aa8e65d..b5f4afa9fbe 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -11,7 +11,14 @@ ParameterInType, Schema, ) -from aws_lambda_powertools.event_handler.openapi.params import Body, Query +from aws_lambda_powertools.event_handler.openapi.params import ( + Body, + Header, + Param, + ParamTypes, + Query, + _create_model_field, +) from aws_lambda_powertools.shared.types import Annotated JSON_CONTENT_TYPE = "application/json" @@ -274,3 +281,35 @@ def handler(user: Annotated[User, Body(embed=True)]): assert "Body_handler_users_post" in components.schemas body_post_handler_schema = components.schemas["Body_handler_users_post"] assert body_post_handler_schema.properties["user"].ref == "#/components/schemas/User" + + +def test_create_header(): + header = Header(convert_underscores=True) + assert header.convert_underscores is True + + +def test_create_body(): + body = Body(embed=True, examples=[Example(summary="Example 1", value=10)]) + assert body.embed is True + + +# Tests that when we try to create a model without a field type, we return None +def test_create_empty_model_field(): + result = _create_model_field(None, int, "name", False) + assert result is None + + +# Tests that when we try to crate a param model without a source, we default to "query" +def test_create_model_field_with_empty_in(): + field_info = Param() + + result = _create_model_field(field_info, int, "name", False) + assert result.field_info.in_ == ParamTypes.query + + +# Tests that when we try to create a model field with convert_underscore, we convert the field name +def test_create_model_field_convert_underscore(): + field_info = Header(alias=None, convert_underscores=True) + + result = _create_model_field(field_info, int, "user_id", False) + assert result.alias == "user-id" From fdadd6b88587e07b9494a7a15161f06b41853df1 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 23 Oct 2023 20:13:01 +0200 Subject: [PATCH 74/75] fix: add codecov.yml file to ignore compat.py --- codecov.yml | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 codecov.yml diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000000..caff9bcbfab --- /dev/null +++ b/codecov.yml @@ -0,0 +1,2 @@ +ignore: + - "aws_lambda_powertools/event_handler/openapi/compat.py" From 9f4672a2940c1c972ba6c7e0e9416168e5c17bcf Mon Sep 17 00:00:00 2001 From: Cavalcante Damascena Date: Tue, 24 Oct 2023 15:12:13 +0100 Subject: [PATCH 75/75] Increasing coverage --- .../test_openapi_serialization.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/functional/event_handler/test_openapi_serialization.py diff --git a/tests/functional/event_handler/test_openapi_serialization.py b/tests/functional/event_handler/test_openapi_serialization.py new file mode 100644 index 00000000000..63f1c0e4f9d --- /dev/null +++ b/tests/functional/event_handler/test_openapi_serialization.py @@ -0,0 +1,39 @@ +import json +from typing import Dict + +import pytest + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver + + +def test_openapi_duplicated_serialization(): + # GIVEN APIGatewayRestResolver is initialized with enable_validation=True + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN we have duplicated operations + @app.get("/") + def handler(): + pass + + @app.get("/") + def handler(): # noqa: F811 + pass + + # THEN we should get a warning + with pytest.warns(UserWarning, match="Duplicate Operation*"): + app.get_openapi_schema() + + +def test_openapi_serialize_json(): + # GIVEN APIGatewayRestResolver is initialized with enable_validation=True + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/") + def handler(): + pass + + # WHEN we serialize as json_schema + schema = json.loads(app.get_openapi_json_schema()) + + # THEN we should get a dictionary + assert isinstance(schema, Dict)