diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 4fa9eb3eb97..a20154b4bbf 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -817,11 +817,7 @@ def _has_compression_enabled( bool True if compression is enabled and the "gzip" encoding is accepted, False otherwise. """ - encoding: str = event.get_header_value( - name="accept-encoding", - default_value="", - case_sensitive=False, - ) # noqa: E501 + encoding = event.headers.get("accept-encoding", "") if "gzip" in encoding: if response_compression is not None: return response_compression # e.g., Response(compress=False/True)) diff --git a/aws_lambda_powertools/event_handler/appsync.py b/aws_lambda_powertools/event_handler/appsync.py index fba5681ef6a..99e9225b504 100644 --- a/aws_lambda_powertools/event_handler/appsync.py +++ b/aws_lambda_powertools/event_handler/appsync.py @@ -127,7 +127,7 @@ def handler(event, context: LambdaContext): class MyCustomModel(AppSyncResolverEvent): @property def country_viewer(self) -> str: - return self.request_headers.get("cloudfront-viewer-country") + return self.request_headers.get("cloudfront-viewer-country", "") @app.resolver(field_name="listLocations") diff --git a/aws_lambda_powertools/event_handler/middlewares/base.py b/aws_lambda_powertools/event_handler/middlewares/base.py index fb4bf37cc74..342b033ec1f 100644 --- a/aws_lambda_powertools/event_handler/middlewares/base.py +++ b/aws_lambda_powertools/event_handler/middlewares/base.py @@ -47,10 +47,7 @@ def __init__(self, header: str): def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: # BEFORE logic request_id = app.current_event.request_context.request_id - correlation_id = app.current_event.get_header_value( - name=self.header, - default_value=request_id, - ) + correlation_id = app.current_event.headers.get(self.header, request_id) # Call next middleware or route handler ('/todos') response = next_middleware(app) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 2eafb0d67bb..12b70987f8a 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -2,7 +2,7 @@ import json import logging from copy import deepcopy -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple from pydantic import BaseModel @@ -237,8 +237,8 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]: Get the request body from the event, and parse it as JSON. """ - content_type_value = app.current_event.get_header_value("content-type") - if not content_type_value or content_type_value.strip().startswith("application/json"): + content_type = app.current_event.headers.get("content-type") + if not content_type or content_type.strip().startswith("application/json"): try: return app.current_event.json_body except json.JSONDecodeError as e: @@ -410,7 +410,7 @@ def _normalize_multi_query_string_with_param( return resolved_query_string -def _normalize_multi_header_values_with_param(headers: Dict[str, Any], params: Sequence[ModelField]): +def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]): """ Extract and normalize resolved_headers_field diff --git a/aws_lambda_powertools/event_handler/util.py b/aws_lambda_powertools/event_handler/util.py index 6f2caf10858..9981e392f82 100644 --- a/aws_lambda_powertools/event_handler/util.py +++ b/aws_lambda_powertools/event_handler/util.py @@ -1,6 +1,4 @@ -from typing import Any, Dict - -from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value +from typing import Any, Mapping, Optional class _FrozenDict(dict): @@ -18,25 +16,19 @@ def __hash__(self): return hash(frozenset(self.keys())) -def extract_origin_header(resolver_headers: Dict[str, Any]): +def extract_origin_header(resolved_headers: Mapping[str, Any]) -> Optional[str]: """ Extracts the 'origin' or 'Origin' header from the provided resolver headers. The 'origin' or 'Origin' header can be either a single header or a multi-header. Args: - resolver_headers (Dict): A dictionary containing the headers. + resolved_headers (Mapping): A dictionary containing the headers. Returns: Optional[str]: The value(s) of the origin header or None. """ - resolved_header = get_header_value( - headers=resolver_headers, - name="origin", - default_value=None, - case_sensitive=False, - ) + resolved_header = resolved_headers.get("origin") if isinstance(resolved_header, list): return resolved_header[0] - return resolved_header diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py index a3fbb24f270..d5aa076c36a 100644 --- a/aws_lambda_powertools/utilities/data_classes/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from aws_lambda_powertools.shared.headers_serializer import ( BaseHeadersSerializer, @@ -7,6 +7,7 @@ ) from aws_lambda_powertools.utilities.data_classes.common import ( BaseProxyEvent, + CaseInsensitiveDict, DictWrapper, ) @@ -37,25 +38,15 @@ def multi_value_query_string_parameters(self) -> Dict[str, List[str]]: @property def resolved_query_string_parameters(self) -> Dict[str, List[str]]: - if self.multi_value_query_string_parameters: - return self.multi_value_query_string_parameters - - return super().resolved_query_string_parameters + return self.multi_value_query_string_parameters or super().resolved_query_string_parameters @property - def resolved_headers_field(self) -> Dict[str, Any]: - headers: Dict[str, Any] = {} - - if self.multi_value_headers: - headers = self.multi_value_headers - else: - headers = self.headers - - return {key.lower(): value for key, value in headers.items()} + def multi_value_headers(self) -> Dict[str, List[str]]: + return CaseInsensitiveDict(self.get("multiValueHeaders")) @property - def multi_value_headers(self) -> Optional[Dict[str, List[str]]]: - return self.get("multiValueHeaders") + def resolved_headers_field(self) -> Dict[str, Any]: + return self.multi_value_headers or self.headers def header_serializer(self) -> BaseHeadersSerializer: # When using the ALB integration, the `multiValueHeaders` feature can be disabled (default) or enabled. diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py index b87c8ddaf20..a9c0bf90e23 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py @@ -1,15 +1,13 @@ import enum import re -from typing import Any, Dict, List, Optional, overload +from typing import Any, Dict, List, Optional from aws_lambda_powertools.utilities.data_classes.common import ( BaseRequestContext, BaseRequestContextV2, + CaseInsensitiveDict, DictWrapper, ) -from aws_lambda_powertools.utilities.data_classes.shared_functions import ( - get_header_value, -) class APIGatewayRouteArn: @@ -144,7 +142,7 @@ def http_method(self) -> str: @property def headers(self) -> Dict[str, str]: - return self["headers"] + return CaseInsensitiveDict(self["headers"]) @property def query_string_parameters(self) -> Dict[str, str]: @@ -162,45 +160,6 @@ def stage_variables(self) -> Dict[str, str]: def request_context(self) -> BaseRequestContext: return BaseRequestContext(self._data) - @overload - def get_header_value( - self, - name: str, - default_value: str, - case_sensitive: bool = False, - ) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - case_sensitive: bool - Whether to use a case-sensitive look up - Returns - ------- - str, optional - Header value - """ - return get_header_value(self.headers, name, default_value, case_sensitive) - class APIGatewayAuthorizerEventV2(DictWrapper): """API Gateway Authorizer Event Format 2.0 @@ -234,14 +193,14 @@ def parsed_arn(self) -> APIGatewayRouteArn: return parse_api_gateway_arn(self.route_arn) @property - def identity_source(self) -> Optional[List[str]]: + def identity_source(self) -> List[str]: """The identity source for which authorization is requested. For a REQUEST authorizer, this is optional. The value is a set of one or more mapping expressions of the specified request parameters. The identity source can be headers, query string parameters, stage variables, and context parameters. """ - return self.get("identitySource") + return self.get("identitySource") or [] @property def route_key(self) -> str: @@ -265,7 +224,7 @@ def cookies(self) -> List[str]: @property def headers(self) -> Dict[str, str]: """Http headers""" - return self["headers"] + return CaseInsensitiveDict(self["headers"]) @property def query_string_parameters(self) -> Dict[str, str]: @@ -276,46 +235,12 @@ def request_context(self) -> BaseRequestContextV2: return BaseRequestContextV2(self._data) @property - def path_parameters(self) -> Optional[Dict[str, str]]: - return self.get("pathParameters") + def path_parameters(self) -> Dict[str, str]: + return self.get("pathParameters") or {} @property - def stage_variables(self) -> Optional[Dict[str, str]]: - return self.get("stageVariables") - - @overload - def get_header_value(self, name: str, default_value: str, case_sensitive: bool = False) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - case_sensitive: bool - Whether to use a case-sensitive look up - Returns - ------- - str, optional - Header value - """ - return get_header_value(self.headers, name, default_value, case_sensitive) + def stage_variables(self) -> Dict[str, str]: + return self.get("stageVariables") or {} class APIGatewayAuthorizerResponseV2: diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py index 48d3c96c84c..f010dad80c3 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py @@ -1,3 +1,4 @@ +from functools import cached_property from typing import Any, Dict, List, Optional from aws_lambda_powertools.shared.headers_serializer import ( @@ -9,6 +10,7 @@ BaseProxyEvent, BaseRequestContext, BaseRequestContextV2, + CaseInsensitiveDict, DictWrapper, ) @@ -113,7 +115,7 @@ def resource(self) -> str: @property def multi_value_headers(self) -> Dict[str, List[str]]: - return self.get("multiValueHeaders") or {} # key might exist but can be `null` + return CaseInsensitiveDict(self.get("multiValueHeaders")) @property def multi_value_query_string_parameters(self) -> Dict[str, List[str]]: @@ -128,26 +130,19 @@ def resolved_query_string_parameters(self) -> Dict[str, List[str]]: @property def resolved_headers_field(self) -> Dict[str, Any]: - headers: Dict[str, Any] = {} - - if self.multi_value_headers: - headers = self.multi_value_headers - else: - headers = self.headers - - return {key.lower(): value for key, value in headers.items()} + return self.multi_value_headers or self.headers @property def request_context(self) -> APIGatewayEventRequestContext: return APIGatewayEventRequestContext(self._data) @property - def path_parameters(self) -> Optional[Dict[str, str]]: - return self.get("pathParameters") + def path_parameters(self) -> Dict[str, str]: + return self.get("pathParameters") or {} @property - def stage_variables(self) -> Optional[Dict[str, str]]: - return self.get("stageVariables") + def stage_variables(self) -> Dict[str, str]: + return self.get("stageVariables") or {} def header_serializer(self) -> BaseHeadersSerializer: return MultiValueHeadersSerializer() @@ -289,20 +284,20 @@ def raw_query_string(self) -> str: return self["rawQueryString"] @property - def cookies(self) -> Optional[List[str]]: - return self.get("cookies") + def cookies(self) -> List[str]: + return self.get("cookies") or [] @property def request_context(self) -> RequestContextV2: return RequestContextV2(self._data) @property - def path_parameters(self) -> Optional[Dict[str, str]]: - return self.get("pathParameters") + def path_parameters(self) -> Dict[str, str]: + return self.get("pathParameters") or {} @property - def stage_variables(self) -> Optional[Dict[str, str]]: - return self.get("stageVariables") + def stage_variables(self) -> Dict[str, str]: + return self.get("stageVariables") or {} @property def path(self) -> str: @@ -319,10 +314,6 @@ def http_method(self) -> str: def header_serializer(self): return HttpApiHeadersSerializer() - @property + @cached_property def resolved_headers_field(self) -> Dict[str, Any]: - if self.headers is not None: - headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()} - return headers - - return {} + return CaseInsensitiveDict((k, v.split(",") if "," in v else v) for k, v in self.headers.items()) diff --git a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py index f58308377ff..4a02177b62e 100644 --- a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py +++ b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py @@ -1,9 +1,6 @@ -from typing import Any, Dict, List, Optional, Union, overload +from typing import Any, Dict, List, Optional, Union -from aws_lambda_powertools.utilities.data_classes.common import DictWrapper -from aws_lambda_powertools.utilities.data_classes.shared_functions import ( - get_header_value, -) +from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict, DictWrapper def get_identity_object(identity: Optional[dict]) -> Any: @@ -118,15 +115,15 @@ def parent_type_name(self) -> str: return self["parentTypeName"] @property - def variables(self) -> Optional[Dict[str, str]]: + def variables(self) -> Dict[str, str]: """A map which holds all variables that are passed into the GraphQL request.""" - return self.get("variables") + return self.get("variables") or {} @property - def selection_set_list(self) -> Optional[List[str]]: + def selection_set_list(self) -> List[str]: """A list representation of the fields in the GraphQL selection set. Fields that are aliased will only be referenced by the alias name, not the field name.""" - return self.get("selectionSetList") + return self.get("selectionSetList") or [] @property def selection_set_graphql(self) -> Optional[str]: @@ -184,14 +181,14 @@ def identity(self) -> Union[None, AppSyncIdentityIAM, AppSyncIdentityCognito]: return get_identity_object(self.get("identity")) @property - def source(self) -> Optional[Dict[str, Any]]: + def source(self) -> Dict[str, Any]: """A map that contains the resolution of the parent field.""" - return self.get("source") + return self.get("source") or {} @property def request_headers(self) -> Dict[str, str]: """Request headers""" - return self["request"]["headers"] + return CaseInsensitiveDict(self["request"]["headers"]) @property def prev_result(self) -> Optional[Dict[str, Any]]: @@ -207,48 +204,9 @@ def info(self) -> AppSyncResolverEventInfo: return self._info @property - def stash(self) -> Optional[dict]: + def stash(self) -> dict: """The stash is a map that is made available inside each resolver and function mapping template. The same stash instance lives through a single resolver execution. This means that you can use the stash to pass arbitrary data across request and response mapping templates, and across functions in a pipeline resolver.""" - return self.get("stash") - - @overload - def get_header_value( - self, - name: str, - default_value: str, - case_sensitive: bool = False, - ) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - case_sensitive: bool - Whether to use a case-sensitive look up - Returns - ------- - str, optional - Header value - """ - return get_header_value(self.request_headers, name, default_value, case_sensitive) + return self.get("stash") or {} diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py index 4c404c73111..71c6f44aa1b 100644 --- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py @@ -80,8 +80,9 @@ def http_method(self) -> str: return self["httpMethod"] @property - def parameters(self) -> Optional[List[BedrockAgentProperty]]: - return [BedrockAgentProperty(x) for x in self["parameters"]] if self.get("parameters") else None + def parameters(self) -> List[BedrockAgentProperty]: + parameters = self.get("parameters") or [] + return [BedrockAgentProperty(x) for x in parameters] @property def request_body(self) -> Optional[BedrockAgentRequestBody]: @@ -104,11 +105,12 @@ def prompt_session_attributes(self) -> Dict[str, str]: def path(self) -> str: return self["apiPath"] - @property - def query_string_parameters(self) -> Optional[Dict[str, str]]: + @cached_property + def query_string_parameters(self) -> Dict[str, str]: # In Bedrock Agent events, query string parameters are passed as undifferentiated parameters, # together with the other parameters. So we just return all parameters here. - return {x["name"]: x["value"] for x in self["parameters"]} if self.get("parameters") else None + parameters = self.get("parameters") or [] + return {x["name"]: x["value"] for x in parameters} @property def resolved_headers_field(self) -> Dict[str, Any]: diff --git a/aws_lambda_powertools/utilities/data_classes/cloud_watch_alarm_event.py b/aws_lambda_powertools/utilities/data_classes/cloud_watch_alarm_event.py index d085228cb37..78106b576e0 100644 --- a/aws_lambda_powertools/utilities/data_classes/cloud_watch_alarm_event.py +++ b/aws_lambda_powertools/utilities/data_classes/cloud_watch_alarm_event.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import Any, Dict, List, Literal, Optional +from typing import Any, List, Literal, Optional from aws_lambda_powertools.utilities.data_classes.common import DictWrapper @@ -117,11 +117,11 @@ def unit(self) -> Optional[str]: return self.get("unit", None) @property - def metric(self) -> Optional[Dict]: + def metric(self) -> dict: """ Metric details """ - return self.get("metric", {}) + return self.get("metric") or {} class CloudWatchAlarmData(DictWrapper): @@ -191,12 +191,12 @@ def alarm_actions_suppressor_extension_period(self) -> Optional[str]: return self.get("actionsSuppressorExtensionPeriod", None) @property - def metrics(self) -> Optional[List[CloudWatchAlarmMetric]]: + def metrics(self) -> List[CloudWatchAlarmMetric]: """ The metrics evaluated for the Alarm. """ - metrics = self.get("metrics") - return [CloudWatchAlarmMetric(i) for i in metrics] if metrics else None + metrics = self.get("metrics") or [] + return [CloudWatchAlarmMetric(i) for i in metrics] class CloudWatchAlarmEvent(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py index 7775dd67333..7a5fe7cec76 100644 --- a/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py @@ -23,9 +23,9 @@ def message(self) -> str: return self["message"] @property - def extracted_fields(self) -> Optional[Dict[str, str]]: + def extracted_fields(self) -> Dict[str, str]: """Get the `extractedFields` property""" - return self.get("extractedFields") + return self.get("extractedFields") or {} class CloudWatchLogsDecodedData(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/code_pipeline_job_event.py b/aws_lambda_powertools/utilities/data_classes/code_pipeline_job_event.py index cc7a75cc05e..1cc409c6988 100644 --- a/aws_lambda_powertools/utilities/data_classes/code_pipeline_job_event.py +++ b/aws_lambda_powertools/utilities/data_classes/code_pipeline_job_event.py @@ -19,12 +19,11 @@ def user_parameters(self) -> Optional[str]: return self.get("UserParameters", None) @cached_property - def decoded_user_parameters(self) -> Optional[Dict[str, Any]]: + def decoded_user_parameters(self) -> Dict[str, Any]: """Json Decoded user parameters""" if self.user_parameters is not None: return self._json_deserializer(self.user_parameters) - - return None + return {} class CodePipelineActionConfiguration(DictWrapper): @@ -177,7 +176,7 @@ def user_parameters(self) -> Optional[str]: return self.data.action_configuration.configuration.user_parameters @property - def decoded_user_parameters(self) -> Optional[Dict[str, Any]]: + def decoded_user_parameters(self) -> Dict[str, Any]: """Json Decoded action configuration user parameters""" return self.data.action_configuration.configuration.decoded_user_parameters diff --git a/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py b/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py index a97bf26a16f..86cf3b0601d 100644 --- a/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py +++ b/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py @@ -61,15 +61,15 @@ def user_attributes(self) -> Dict[str, str]: return self["request"]["userAttributes"] @property - def validation_data(self) -> Optional[Dict[str, str]]: + def validation_data(self) -> Dict[str, str]: """One or more name-value pairs containing the validation data in the request to register a user.""" - return self["request"].get("validationData") + return self["request"].get("validationData") or {} @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the pre sign-up trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class PreSignUpTriggerEventResponse(DictWrapper): @@ -133,10 +133,10 @@ def user_attributes(self) -> Dict[str, str]: return self["request"]["userAttributes"] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the post confirmation trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class PostConfirmationTriggerEvent(BaseTriggerEvent): @@ -165,15 +165,15 @@ def password(self) -> str: return self["request"]["password"] @property - def validation_data(self) -> Optional[Dict[str, str]]: + def validation_data(self) -> Dict[str, str]: """One or more name-value pairs containing the validation data in the request to register a user.""" - return self["request"].get("validationData") + return self["request"].get("validationData") or {} @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the pre sign-up trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class UserMigrationTriggerEventResponse(DictWrapper): @@ -213,8 +213,8 @@ def message_action(self, value: str): self["response"]["messageAction"] = value @property - def desired_delivery_mediums(self) -> Optional[List[str]]: - return self["response"].get("desiredDeliveryMediums") + def desired_delivery_mediums(self) -> List[str]: + return self["response"].get("desiredDeliveryMediums") or [] @desired_delivery_mediums.setter def desired_delivery_mediums(self, value: List[str]): @@ -281,10 +281,10 @@ def user_attributes(self) -> Dict[str, str]: return self["request"]["userAttributes"] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the pre sign-up trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class CustomMessageTriggerEventResponse(DictWrapper): @@ -361,9 +361,9 @@ def user_attributes(self) -> Dict[str, str]: return self["request"]["userAttributes"] @property - def validation_data(self) -> Optional[Dict[str, str]]: + def validation_data(self) -> Dict[str, str]: """One or more key-value pairs containing the validation data in the user's sign-in request.""" - return self["request"].get("validationData") + return self["request"].get("validationData") or {} class PreAuthenticationTriggerEvent(BaseTriggerEvent): @@ -402,10 +402,10 @@ def user_attributes(self) -> Dict[str, str]: return self["request"]["userAttributes"] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the post authentication trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class PostAuthenticationTriggerEvent(BaseTriggerEvent): @@ -433,14 +433,14 @@ def request(self) -> PostAuthenticationTriggerEventRequest: class GroupOverrideDetails(DictWrapper): @property - def groups_to_override(self) -> Optional[List[str]]: + def groups_to_override(self) -> List[str]: """A list of the group names that are associated with the user that the identity token is issued for.""" - return self.get("groupsToOverride") + return self.get("groupsToOverride") or [] @property - def iam_roles_to_override(self) -> Optional[List[str]]: + def iam_roles_to_override(self) -> List[str]: """A list of the current IAM roles associated with these groups.""" - return self.get("iamRolesToOverride") + return self.get("iamRolesToOverride") or [] @property def preferred_role(self) -> Optional[str]: @@ -460,16 +460,16 @@ def user_attributes(self) -> Dict[str, str]: return self["request"]["userAttributes"] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the pre token generation trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class ClaimsOverrideDetails(DictWrapper): @property - def claims_to_add_or_override(self) -> Optional[Dict[str, str]]: - return self.get("claimsToAddOrOverride") + def claims_to_add_or_override(self) -> Dict[str, str]: + return self.get("claimsToAddOrOverride") or {} @claims_to_add_or_override.setter def claims_to_add_or_override(self, value: Dict[str, str]): @@ -478,8 +478,8 @@ def claims_to_add_or_override(self, value: Dict[str, str]): self._data["claimsToAddOrOverride"] = value @property - def claims_to_suppress(self) -> Optional[List[str]]: - return self.get("claimsToSuppress") + def claims_to_suppress(self) -> List[str]: + return self.get("claimsToSuppress") or [] @claims_to_suppress.setter def claims_to_suppress(self, value: List[str]): @@ -599,10 +599,10 @@ def session(self) -> List[ChallengeResult]: return [ChallengeResult(result) for result in self["request"]["session"]] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the defined auth challenge trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class DefineAuthChallengeTriggerEventResponse(DictWrapper): @@ -685,10 +685,10 @@ def session(self) -> List[ChallengeResult]: return [ChallengeResult(result) for result in self["request"]["session"]] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the creation auth challenge trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class CreateAuthChallengeTriggerEventResponse(DictWrapper): @@ -773,10 +773,10 @@ def challenge_answer(self) -> Any: return self["request"]["challengeAnswer"] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the "Verify Auth Challenge" trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} @property def user_not_found(self) -> Optional[bool]: diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 76726ca5129..7e9ed2471d2 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -1,15 +1,52 @@ import base64 import json -from collections.abc import Mapping from functools import cached_property -from typing import Any, Callable, Dict, Iterator, List, Optional, overload +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer -from aws_lambda_powertools.utilities.data_classes.shared_functions import ( - get_header_value, - get_multi_value_query_string_values, - get_query_string_value, -) + + +class CaseInsensitiveDict(dict): + """Case insensitive dict implementation. Assumes string keys only.""" + + def __init__(self, data=None, **kwargs): + super().__init__() + self.update(data, **kwargs) + + def get(self, k, default=None): + return super().get(k.lower(), default) + + def pop(self, k): + return super().pop(k.lower()) + + def setdefault(self, k, default=None): + return super().setdefault(k.lower(), default) + + def update(self, data=None, **kwargs): + if data is not None: + if isinstance(data, Mapping): + data = data.items() + super().update((k.lower(), v) for k, v in data) + super().update((k.lower(), v) for k, v in kwargs) + + def __contains__(self, k): + return super().__contains__(k.lower()) + + def __delitem__(self, k): + super().__delitem__(k.lower()) + + def __eq__(self, other): + if not isinstance(other, Mapping): + return False + if not isinstance(other, CaseInsensitiveDict): + other = CaseInsensitiveDict(other) + return super().__eq__(other) + + def __getitem__(self, k): + return super().__getitem__(k.lower()) + + def __setitem__(self, k, v): + super().__setitem__(k.lower(), v) class DictWrapper(Mapping): @@ -98,17 +135,17 @@ def raw_event(self) -> Dict[str, Any]: class BaseProxyEvent(DictWrapper): @property def headers(self) -> Dict[str, str]: - return self.get("headers") or {} + return CaseInsensitiveDict(self.get("headers")) @property - def query_string_parameters(self) -> Optional[Dict[str, str]]: - return self.get("queryStringParameters") + def query_string_parameters(self) -> Dict[str, str]: + return self.get("queryStringParameters") or {} @property def multi_value_query_string_parameters(self) -> Dict[str, List[str]]: return self.get("multiValueQueryStringParameters") or {} - @property + @cached_property def resolved_query_string_parameters(self) -> Dict[str, List[str]]: """ This property determines the appropriate query string parameter to be used @@ -117,14 +154,10 @@ def resolved_query_string_parameters(self) -> Dict[str, List[str]]: This is necessary because different resolvers use different formats to encode multi query string parameters. """ - if self.query_string_parameters is not None: - query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} - return query_string - - return {} + return {k: v.split(",") for k, v in self.query_string_parameters.items()} @property - def resolved_headers_field(self) -> Dict[str, Any]: + def resolved_headers_field(self) -> Dict[str, str]: """ This property determines the appropriate header to be used as a trusted source for validating OpenAPI. @@ -172,101 +205,6 @@ def http_method(self) -> str: """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" return self["httpMethod"] - @overload - def get_query_string_value(self, name: str, default_value: str) -> str: ... - - @overload - def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ... - - def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: - """Get query string value by name - - Parameters - ---------- - name: str - Query string parameter name - default_value: str, optional - Default value if no value was found by name - Returns - ------- - str, optional - Query string parameter value - """ - return get_query_string_value( - query_string_parameters=self.query_string_parameters, - name=name, - default_value=default_value, - ) - - def get_multi_value_query_string_values( - self, - name: str, - default_values: Optional[List[str]] = None, - ) -> List[str]: - """Get multi-value query string parameter values by name - - Parameters - ---------- - name: str - Multi-Value query string parameter name - default_values: List[str], optional - Default values is no values are found by name - Returns - ------- - List[str], optional - List of query string values - - """ - return get_multi_value_query_string_values( - multi_value_query_string_parameters=self.multi_value_query_string_parameters, - name=name, - default_values=default_values, - ) - - @overload - def get_header_value( - self, - name: str, - default_value: str, - case_sensitive: bool = False, - ) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - case_sensitive: bool - Whether to use a case-sensitive look up. By default we make a case-insensitive lookup. - Returns - ------- - str, optional - Header value - """ - return get_header_value( - headers=self.headers, - name=name, - default_value=default_value, - case_sensitive=case_sensitive, - ) - def header_serializer(self) -> BaseHeadersSerializer: raise NotImplementedError() diff --git a/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py b/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py index d0d1bd7ab41..139a70e9065 100644 --- a/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py +++ b/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py @@ -1,4 +1,5 @@ from enum import Enum +from functools import cached_property from typing import Any, Dict, Iterator, Optional from aws_lambda_powertools.shared.dynamodb_deserializer import TypeDeserializer @@ -27,7 +28,7 @@ def __init__(self, data: Dict[str, Any]): super().__init__(data) self._deserializer = TypeDeserializer() - def _deserialize_dynamodb_dict(self, key: str) -> Optional[Dict[str, Any]]: + def _deserialize_dynamodb_dict(self, key: str) -> Dict[str, Any]: """Deserialize DynamoDB records available in `Keys`, `NewImage`, and `OldImage` Parameters @@ -37,13 +38,10 @@ def _deserialize_dynamodb_dict(self, key: str) -> Optional[Dict[str, Any]]: Returns ------- - Optional[Dict[str, Any]] + Dict[str, Any] Deserialized records in Python native types """ - dynamodb_dict = self._data.get(key) - if dynamodb_dict is None: - return None - + dynamodb_dict = self._data.get(key) or {} return {k: self._deserializer.deserialize(v) for k, v in dynamodb_dict.items()} @property @@ -52,18 +50,18 @@ def approximate_creation_date_time(self) -> Optional[int]: item = self.get("ApproximateCreationDateTime") return None if item is None else int(item) - @property - def keys(self) -> Optional[Dict[str, Any]]: # type: ignore[override] + @cached_property + def keys(self) -> Dict[str, Any]: # type: ignore[override] """The primary key attribute(s) for the DynamoDB item that was modified.""" return self._deserialize_dynamodb_dict("Keys") - @property - def new_image(self) -> Optional[Dict[str, Any]]: + @cached_property + def new_image(self) -> Dict[str, Any]: """The item in the DynamoDB table as it appeared after it was modified.""" return self._deserialize_dynamodb_dict("NewImage") - @property - def old_image(self) -> Optional[Dict[str, Any]]: + @cached_property + def old_image(self) -> Dict[str, Any]: """The item in the DynamoDB table as it appeared before it was modified.""" return self._deserialize_dynamodb_dict("OldImage") @@ -132,9 +130,9 @@ def event_version(self) -> Optional[str]: return self.get("eventVersion") @property - def user_identity(self) -> Optional[dict]: + def user_identity(self) -> dict: """Contains details about the type of identity that made the request""" - return self.get("userIdentity") + return self.get("userIdentity") or {} class DynamoDBStreamEvent(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/kafka_event.py b/aws_lambda_powertools/utilities/data_classes/kafka_event.py index f20c5254730..f73802ba699 100644 --- a/aws_lambda_powertools/utilities/data_classes/kafka_event.py +++ b/aws_lambda_powertools/utilities/data_classes/kafka_event.py @@ -1,11 +1,8 @@ import base64 from functools import cached_property -from typing import Any, Dict, Iterator, List, Optional, overload +from typing import Any, Dict, Iterator, List, Optional -from aws_lambda_powertools.utilities.data_classes.common import DictWrapper -from aws_lambda_powertools.utilities.data_classes.shared_functions import ( - get_header_value, -) +from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict, DictWrapper class KafkaEventRecord(DictWrapper): @@ -64,40 +61,10 @@ def headers(self) -> List[Dict[str, List[int]]]: """The raw Kafka record headers.""" return self["headers"] - @property + @cached_property def decoded_headers(self) -> Dict[str, bytes]: """Decodes the headers as a single dictionary.""" - return {k: bytes(v) for chunk in self.headers for k, v in chunk.items()} - - @overload - def get_header_value( - self, - name: str, - default_value: str, - case_sensitive: bool = True, - ) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = True, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = True, - ) -> Optional[str]: - """Get a decoded header value by name.""" - return get_header_value( - headers=self.decoded_headers, - name=name, - default_value=default_value, - case_sensitive=case_sensitive, - ) + return CaseInsensitiveDict((k, bytes(v)) for chunk in self.headers for k, v in chunk.items()) class KafkaEvent(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/s3_batch_operation_event.py b/aws_lambda_powertools/utilities/data_classes/s3_batch_operation_event.py index 9c742e0c553..5419f6f8088 100644 --- a/aws_lambda_powertools/utilities/data_classes/s3_batch_operation_event.py +++ b/aws_lambda_powertools/utilities/data_classes/s3_batch_operation_event.py @@ -147,9 +147,9 @@ def get_id(self) -> str: return self["id"] @property - def user_arguments(self) -> Optional[Dict[str, str]]: + def user_arguments(self) -> Dict[str, str]: """Get user arguments provided for this job (only for invocation schema 2.0)""" - return self.get("userArguments") + return self.get("userArguments") or {} class S3BatchOperationTask(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/s3_object_event.py b/aws_lambda_powertools/utilities/data_classes/s3_object_event.py index dc79b72766f..728773a717d 100644 --- a/aws_lambda_powertools/utilities/data_classes/s3_object_event.py +++ b/aws_lambda_powertools/utilities/data_classes/s3_object_event.py @@ -1,9 +1,6 @@ -from typing import Dict, Optional, overload +from typing import Dict, Optional -from aws_lambda_powertools.utilities.data_classes.common import DictWrapper -from aws_lambda_powertools.utilities.data_classes.shared_functions import ( - get_header_value, -) +from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict, DictWrapper class S3ObjectContext(DictWrapper): @@ -71,46 +68,7 @@ def headers(self) -> Dict[str, str]: If the same header appears multiple times, their values are combined into a comma-delimited list. The case of the original headers is retained in this map.""" - return self["headers"] - - @overload - def get_header_value( - self, - name: str, - default_value: str, - case_sensitive: bool = False, - ) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - case_sensitive: bool - Whether to use a case-sensitive look up - Returns - ------- - str, optional - Header value - """ - return get_header_value(self.headers, name, default_value, case_sensitive) + return CaseInsensitiveDict(self["headers"]) class S3ObjectSessionIssuer(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/ses_event.py b/aws_lambda_powertools/utilities/data_classes/ses_event.py index 2ebc02e22a0..5adcf7149ee 100644 --- a/aws_lambda_powertools/utilities/data_classes/ses_event.py +++ b/aws_lambda_powertools/utilities/data_classes/ses_event.py @@ -46,24 +46,24 @@ def subject(self) -> str: return str(self["subject"]) @property - def cc(self) -> Optional[List[str]]: + def cc(self) -> List[str]: """The values in the CC header of the email.""" - return self.get("cc") + return self.get("cc") or [] @property - def bcc(self) -> Optional[List[str]]: + def bcc(self) -> List[str]: """The values in the BCC header of the email.""" - return self.get("bcc") + return self.get("bcc") or [] @property - def sender(self) -> Optional[List[str]]: + def sender(self) -> List[str]: """The values in the Sender header of the email.""" - return self.get("sender") + return self.get("sender") or [] @property - def reply_to(self) -> Optional[List[str]]: + def reply_to(self) -> List[str]: """The values in the replyTo header of the email.""" - return self.get("replyTo") + return self.get("replyTo") or [] class SESMail(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/shared_functions.py b/aws_lambda_powertools/utilities/data_classes/shared_functions.py index 0e88a5dac93..4f3451714a1 100644 --- a/aws_lambda_powertools/utilities/data_classes/shared_functions.py +++ b/aws_lambda_powertools/utilities/data_classes/shared_functions.py @@ -1,7 +1,4 @@ -from __future__ import annotations - import base64 -from typing import Any, Dict, overload def base64_decode(value: str) -> str: @@ -19,129 +16,3 @@ def base64_decode(value: str) -> str: The decoded string value. """ return base64.b64decode(value).decode("UTF-8") - - -@overload -def get_header_value( - headers: dict[str, Any], - name: str, - default_value: str, - case_sensitive: bool = False, -) -> str: ... - - -@overload -def get_header_value( - headers: dict[str, Any], - name: str, - default_value: str | None = None, - case_sensitive: bool = False, -) -> str | None: ... - - -def get_header_value( - headers: dict[str, Any], - name: str, - default_value: str | None = None, - case_sensitive: bool = False, -) -> str | None: - """ - Get the value of a header by its name. - - Parameters - ---------- - headers: Dict[str, str] - The dictionary of headers. - name: str - The name of the header to retrieve. - default_value: str, optional - The default value to return if the header is not found. Default is None. - case_sensitive: bool, optional - Indicates whether the header name should be case-sensitive. Default is False. - - Returns - ------- - str, optional - The value of the header if found, otherwise the default value or None. - """ - # If headers is NoneType, return default value - if not headers: - return default_value - - if case_sensitive: - return headers.get(name, default_value) - name_lower = name.lower() - - return next( - # Iterate over the dict and do a case-insensitive key comparison - (value for key, value in headers.items() if key.lower() == name_lower), - # Default value is returned if no matches was found - default_value, - ) - - -@overload -def get_query_string_value( - query_string_parameters: Dict[str, str] | None, - name: str, - default_value: str, -) -> str: ... - - -@overload -def get_query_string_value( - query_string_parameters: Dict[str, str] | None, - name: str, - default_value: str | None = None, -) -> str | None: ... - - -def get_query_string_value( - query_string_parameters: Dict[str, str] | None, - name: str, - default_value: str | None = None, -) -> str | None: - """ - Retrieves the value of a query string parameter specified by the given name. - - Parameters - ---------- - name: str - The name of the query string parameter to retrieve. - default_value: str, optional - The default value to return if the parameter is not found. Defaults to None. - - Returns - ------- - str. optional - The value of the query string parameter if found, or the default value if not found. - """ - params = query_string_parameters - return default_value if params is None else params.get(name, default_value) - - -def get_multi_value_query_string_values( - multi_value_query_string_parameters: Dict[str, list[str]] | None, - name: str, - default_values: list[str] | None = None, -) -> list[str]: - """ - Retrieves the values of a multi-value string parameters specified by the given name. - - Parameters - ---------- - name: str - The name of the query string parameter to retrieve. - default_value: list[str], optional - The default value to return if the parameter is not found. Defaults to None. - - Returns - ------- - List[str]. optional - The values of the query string parameter if found, or the default values if not found. - """ - - default = default_values or [] - params = multi_value_query_string_parameters or {} - - return params.get(name) or default diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index c28977c56ba..f04c58dc5f0 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -1,16 +1,16 @@ from functools import cached_property -from typing import Any, Dict, Optional, overload +from typing import Any, Dict, Optional from aws_lambda_powertools.shared.headers_serializer import ( BaseHeadersSerializer, HttpApiHeadersSerializer, ) -from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper -from aws_lambda_powertools.utilities.data_classes.shared_functions import ( - base64_decode, - get_header_value, - get_query_string_value, +from aws_lambda_powertools.utilities.data_classes.common import ( + BaseProxyEvent, + CaseInsensitiveDict, + DictWrapper, ) +from aws_lambda_powertools.utilities.data_classes.shared_functions import base64_decode class VPCLatticeEventBase(BaseProxyEvent): @@ -27,7 +27,7 @@ def json_body(self) -> Any: @property def headers(self) -> Dict[str, str]: """The VPC Lattice event headers.""" - return self["headers"] + return CaseInsensitiveDict(self["headers"]) @property def decoded_body(self) -> str: @@ -47,76 +47,6 @@ def http_method(self) -> str: """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" return self["method"] - @overload - def get_query_string_value(self, name: str, default_value: str) -> str: ... - - @overload - def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ... - - def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: - """Get query string value by name - - Parameters - ---------- - name: str - Query string parameter name - default_value: str, optional - Default value if no value was found by name - Returns - ------- - str, optional - Query string parameter value - """ - return get_query_string_value( - query_string_parameters=self.query_string_parameters, - name=name, - default_value=default_value, - ) - - @overload - def get_header_value( - self, - name: str, - default_value: str, - case_sensitive: bool = False, - ) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - case_sensitive: bool - Whether to use a case-sensitive look up - Returns - ------- - str, optional - Header value - """ - return get_header_value( - headers=self.headers, - name=name, - default_value=default_value, - case_sensitive=case_sensitive, - ) - def header_serializer(self) -> BaseHeadersSerializer: # When using the VPC Lattice integration, we have multiple HTTP Headers. return HttpApiHeadersSerializer() @@ -144,13 +74,9 @@ def query_string_parameters(self) -> Dict[str, str]: """The request query string parameters.""" return self["query_string_parameters"] - @property + @cached_property def resolved_headers_field(self) -> Dict[str, Any]: - if self.headers is not None: - headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()} - return headers - - return {} + return CaseInsensitiveDict((k, v.split(",") if "," in v else v) for k, v in self.headers.items()) class vpcLatticeEventV2Identity(DictWrapper): @@ -258,22 +184,12 @@ def request_context(self) -> vpcLatticeEventV2RequestContext: """The VPC Lattice v2 Event request context.""" return vpcLatticeEventV2RequestContext(self["requestContext"]) - @property - def query_string_parameters(self) -> Optional[Dict[str, str]]: + @cached_property + def query_string_parameters(self) -> Dict[str, str]: """The request query string parameters. For VPC Lattice V2, the queryStringParameters will contain a Dict[str, List[str]] so to keep compatibility with existing utilities, we merge all the values with a comma. """ - params = self.get("queryStringParameters") - if params: - return {key: ",".join(value) for key, value in params.items()} - else: - return None - - @property - def resolved_headers_field(self) -> Dict[str, str]: - if self.headers is not None: - return {key.lower(): value for key, value in self.headers.items()} - - return {} + params = self.get("queryStringParameters") or {} + return {k: ",".join(v) for k, v in params.items()} diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index aa667f5f169..3c182b30e4e 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -466,7 +466,7 @@ That is why you see `app.resolve(event, context)` in every example. This allows #### Query strings and payload -Within `app.current_event` property, you can access all available query strings as a dictionary via `query_string_parameters`, or a specific one via `get_query_string_value` method. +Within `app.current_event` property, you can access all available query strings as a dictionary via `query_string_parameters`. You can access the raw payload via `body` property, or if it's a JSON string you can quickly deserialize it via `json_body` property - like the earlier example in the [HTTP Methods](#http-methods) section. @@ -476,7 +476,7 @@ You can access the raw payload via `body` property, or if it's a JSON string you #### Headers -Similarly to [Query strings](#query-strings-and-payload), you can access headers as dictionary via `app.current_event.headers`, or by name via `get_header_value`. If you prefer a case-insensitive lookup of the header value, the `app.current_event.get_header_value` function automatically handles it. +Similarly to [Query strings](#query-strings-and-payload), you can access headers as dictionary via `app.current_event.headers`. Specifically for headers, it's a case-insensitive dictionary, so all lookups are case-insensitive. ```python hl_lines="19" title="Accessing HTTP Headers" --8<-- "examples/event_handler_rest/src/accessing_request_details_headers.py" diff --git a/docs/utilities/data_classes.md b/docs/utilities/data_classes.md index 0b43f36933e..b481fe7b3a7 100644 --- a/docs/utilities/data_classes.md +++ b/docs/utilities/data_classes.md @@ -175,7 +175,7 @@ Use **`APIGatewayAuthorizerRequestEvent`** for type `REQUEST` and **`APIGatewayA @event_source(data_class=APIGatewayAuthorizerRequestEvent) def handler(event: APIGatewayAuthorizerRequestEvent, context): - user = get_user_by_token(event.get_header_value("Authorization")) + user = get_user_by_token(event.headers["Authorization"]) if user is None: # No user was found @@ -263,7 +263,7 @@ See also [this blog post](https://aws.amazon.com/blogs/compute/introducing-iam-a @event_source(data_class=APIGatewayAuthorizerEventV2) def handler(event: APIGatewayAuthorizerEventV2, context): - user = get_user_by_token(event.get_header_value("x-token")) + user = get_user_by_token(event.headers["x-token"]) if user is None: # No user was found, so we return not authorized @@ -397,7 +397,7 @@ In this example, we also use the new Logger `correlation_id` and built-in `corre event: AppSyncResolverEvent = AppSyncResolverEvent(event) # Case insensitive look up of request headers - x_forwarded_for = event.get_header_value("x-forwarded-for") + x_forwarded_for = event.headers.get("x-forwarded-for") # Support for AppSyncIdentityCognito or AppSyncIdentityIAM identity types assert isinstance(event.identity, AppSyncIdentityCognito) diff --git a/examples/event_handler_graphql/src/custom_models.py b/examples/event_handler_graphql/src/custom_models.py index 61e03318d14..21f5f07af00 100644 --- a/examples/event_handler_graphql/src/custom_models.py +++ b/examples/event_handler_graphql/src/custom_models.py @@ -26,11 +26,11 @@ class Location(TypedDict, total=False): class MyCustomModel(AppSyncResolverEvent): @property def country_viewer(self) -> str: - return self.get_header_value(name="cloudfront-viewer-country", default_value="", case_sensitive=False) + return self.request_headers.get("cloudfront-viewer-country", "") @property def api_key(self) -> str: - return self.get_header_value(name="x-api-key", default_value="", case_sensitive=False) + return self.request_headers.get("x-api-key", "") @app.resolver(type_name="Query", field_name="listLocations") diff --git a/examples/event_handler_rest/src/accessing_request_details.py b/examples/event_handler_rest/src/accessing_request_details.py index 037b76daa66..e9a5d924017 100644 --- a/examples/event_handler_rest/src/accessing_request_details.py +++ b/examples/event_handler_rest/src/accessing_request_details.py @@ -16,12 +16,12 @@ @app.get("/todos") @tracer.capture_method def get_todos(): - todo_id: str = app.current_event.get_query_string_value(name="id", default_value="") + todo_id: str = app.current_event.query_string_parameters["id"] # alternatively _: Optional[str] = app.current_event.query_string_parameters.get("id") # or multi-value query string parameters; ?category="red"&?category="blue" - _: List[str] = app.current_event.get_multi_value_query_string_values(name="category") + _: List[str] = app.current_event.multi_value_query_string_parameters["category"] # Payload _: Optional[str] = app.current_event.body # raw str | None diff --git a/examples/event_handler_rest/src/accessing_request_details_headers.py b/examples/event_handler_rest/src/accessing_request_details_headers.py index f6bfb88c869..de5df2fed0b 100644 --- a/examples/event_handler_rest/src/accessing_request_details_headers.py +++ b/examples/event_handler_rest/src/accessing_request_details_headers.py @@ -16,7 +16,7 @@ def get_todos(): endpoint = "https://jsonplaceholder.typicode.com/todos" - api_key: str = app.current_event.get_header_value(name="X-Api-Key", case_sensitive=True, default_value="") + api_key = app.current_event.headers["X-Api-Key"] todos: Response = requests.get(endpoint, headers={"X-Api-Key": api_key}) todos.raise_for_status() diff --git a/examples/event_handler_rest/src/exception_handling.py b/examples/event_handler_rest/src/exception_handling.py index ea325bd6dc1..24c14bb868d 100644 --- a/examples/event_handler_rest/src/exception_handling.py +++ b/examples/event_handler_rest/src/exception_handling.py @@ -31,7 +31,7 @@ def handle_invalid_limit_qs(ex: ValueError): # receives exception raised def get_todos(): # educational purpose only: we should receive a `ValueError` # if a query string value for `limit` cannot be coerced to int - max_results: int = int(app.current_event.get_query_string_value(name="limit", default_value=0)) + max_results = int(app.current_event.query_string_parameters.get("limit", 0)) todos: requests.Response = requests.get(f"https://jsonplaceholder.typicode.com/todos?limit={max_results}") todos.raise_for_status() diff --git a/examples/event_handler_rest/src/middleware_extending_middlewares.py b/examples/event_handler_rest/src/middleware_extending_middlewares.py index e492caacf47..ad448c03d30 100644 --- a/examples/event_handler_rest/src/middleware_extending_middlewares.py +++ b/examples/event_handler_rest/src/middleware_extending_middlewares.py @@ -22,10 +22,7 @@ def __init__(self, header: str): # (1)! def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: # (2)! request_id = app.current_event.request_context.request_id - correlation_id = app.current_event.get_header_value( - name=self.header, - default_value=request_id, - ) + correlation_id = app.current_event.headers.get(self.header, request_id) response = next_middleware(app) # (3)! response.headers[self.header] = correlation_id diff --git a/examples/event_handler_rest/src/middleware_global_middlewares_module.py b/examples/event_handler_rest/src/middleware_global_middlewares_module.py index 2b06bc31c71..96745a28448 100644 --- a/examples/event_handler_rest/src/middleware_global_middlewares_module.py +++ b/examples/event_handler_rest/src/middleware_global_middlewares_module.py @@ -34,7 +34,7 @@ def inject_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMidd def enforce_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: # If missing mandatory header raise an error - if not app.current_event.get_header_value("x-correlation-id", case_sensitive=False): + if not app.current_event.headers.get("x-correlation-id"): return Response(status_code=400, body="Correlation ID header is now mandatory.") # (1)! # Get the response from the next middleware and return it diff --git a/examples/event_handler_rest/src/split_route_module.py b/examples/event_handler_rest/src/split_route_module.py index b6a91b3fb3b..b67d5d0568b 100644 --- a/examples/event_handler_rest/src/split_route_module.py +++ b/examples/event_handler_rest/src/split_route_module.py @@ -13,7 +13,7 @@ @router.get("/todos") @tracer.capture_method def get_todos(): - api_key: str = router.current_event.get_header_value(name="X-Api-Key", case_sensitive=True, default_value="") + api_key = router.current_event.headers["X-Api-Key"] todos: Response = requests.get(endpoint, headers={"X-Api-Key": api_key}) todos.raise_for_status() @@ -25,11 +25,7 @@ def get_todos(): @router.get("/todos/") @tracer.capture_method def get_todo_by_id(todo_id: str): # value come as str - api_key: str = router.current_event.get_header_value( - name="X-Api-Key", - case_sensitive=True, - default_value="", - ) # noqa: E501 + api_key = router.current_event.headers["X-Api-Key"] todos: Response = requests.get(f"{endpoint}/{todo_id}", headers={"X-Api-Key": api_key}) todos.raise_for_status() diff --git a/examples/event_handler_rest/src/split_route_prefix_module.py b/examples/event_handler_rest/src/split_route_prefix_module.py index aa17e0cd347..c112a772c6e 100644 --- a/examples/event_handler_rest/src/split_route_prefix_module.py +++ b/examples/event_handler_rest/src/split_route_prefix_module.py @@ -13,7 +13,7 @@ @router.get("/") @tracer.capture_method def get_todos(): - api_key: str = router.current_event.get_header_value(name="X-Api-Key", case_sensitive=True, default_value="") + api_key = router.current_event.headers["X-Api-Key"] todos: Response = requests.get(endpoint, headers={"X-Api-Key": api_key}) todos.raise_for_status() @@ -25,11 +25,7 @@ def get_todos(): @router.get("/") @tracer.capture_method def get_todo_by_id(todo_id: str): # value come as str - api_key: str = router.current_event.get_header_value( - name="X-Api-Key", - case_sensitive=True, - default_value="", - ) # sentinel typing # noqa: E501 + api_key = router.current_event.headers["X-Api-Key"] todos: Response = requests.get(f"{endpoint}/{todo_id}", headers={"X-Api-Key": api_key}) todos.raise_for_status() diff --git a/tests/functional/event_handler/test_api_middlewares.py b/tests/functional/event_handler/test_api_middlewares.py index 58bec259072..ed5c3ecb21b 100644 --- a/tests/functional/event_handler/test_api_middlewares.py +++ b/tests/functional/event_handler/test_api_middlewares.py @@ -484,10 +484,7 @@ def __init__(self, header: str): def handler(self, app: ApiGatewayResolver, get_response: NextMiddleware, **kwargs) -> Response: request_id = app.current_event.request_context.request_id # type: ignore[attr-defined] # using REST event in a base Resolver # noqa: E501 - correlation_id = app.current_event.get_header_value( - name=self.header, - default_value=request_id, - ) # noqa: E501 + correlation_id = app.current_event.headers.get(self.header, request_id) response = get_response(app, **kwargs) response.headers[self.header] = correlation_id diff --git a/tests/functional/event_handler/test_appsync.py b/tests/functional/event_handler/test_appsync.py index 5699e560065..47fd583031b 100644 --- a/tests/functional/event_handler/test_appsync.py +++ b/tests/functional/event_handler/test_appsync.py @@ -145,8 +145,8 @@ def test_resolve_custom_data_model(): class MyCustomModel(AppSyncResolverEvent): @property - def country_viewer(self): - return self.request_headers.get("cloudfront-viewer-country") + def country_viewer(self) -> str: + return self.request_headers.get("cloudfront-viewer-country", "") app = AppSyncResolver() diff --git a/tests/unit/data_classes/test_alb_event.py b/tests/unit/data_classes/test_alb_event.py index 47048ab9407..6945dc67c36 100644 --- a/tests/unit/data_classes/test_alb_event.py +++ b/tests/unit/data_classes/test_alb_event.py @@ -14,6 +14,6 @@ def test_alb_event(): assert parsed_event.multi_value_query_string_parameters == raw_event.get("multiValueQueryStringParameters", {}) - assert parsed_event.multi_value_headers == raw_event.get("multiValueHeaders") + assert parsed_event.multi_value_headers == (raw_event.get("multiValueHeaders") or {}) assert parsed_event.body == raw_event["body"] assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] diff --git a/tests/unit/data_classes/test_api_gateway_authorizer_event.py b/tests/unit/data_classes/test_api_gateway_authorizer_event.py index 2c5f170d924..4ae44643474 100644 --- a/tests/unit/data_classes/test_api_gateway_authorizer_event.py +++ b/tests/unit/data_classes/test_api_gateway_authorizer_event.py @@ -52,16 +52,16 @@ def test_api_gateway_authorizer_v2(): assert parsed_event.path_parameters == raw_event["pathParameters"] assert parsed_event.stage_variables == raw_event["stageVariables"] - assert parsed_event.get_header_value("Authorization") == "value" - assert parsed_event.get_header_value("authorization") == "value" - assert parsed_event.get_header_value("missing") is None + assert parsed_event.headers["Authorization"] == "value" + assert parsed_event.headers["authorization"] == "value" + assert parsed_event.headers.get("missing") is None # Check for optionals event_optionals = APIGatewayAuthorizerEventV2({"requestContext": {}}) - assert event_optionals.identity_source is None + assert event_optionals.identity_source == [] assert event_optionals.request_context.authentication is None - assert event_optionals.path_parameters is None - assert event_optionals.stage_variables is None + assert event_optionals.path_parameters == {} + assert event_optionals.stage_variables == {} def test_api_gateway_authorizer_token_event(): @@ -90,7 +90,7 @@ def test_api_gateway_authorizer_request_event(): assert parsed_event.path == raw_event["path"] assert parsed_event.http_method == raw_event["httpMethod"] assert parsed_event.headers == raw_event["headers"] - assert parsed_event.get_header_value("accept") == "*/*" + assert parsed_event.headers["accept"] == "*/*" assert parsed_event.query_string_parameters == raw_event["queryStringParameters"] assert parsed_event.path_parameters == raw_event["pathParameters"] assert parsed_event.stage_variables == raw_event["stageVariables"] diff --git a/tests/unit/data_classes/test_api_gateway_proxy_event.py b/tests/unit/data_classes/test_api_gateway_proxy_event.py index d86e4b5e19b..42925ee9c9f 100644 --- a/tests/unit/data_classes/test_api_gateway_proxy_event.py +++ b/tests/unit/data_classes/test_api_gateway_proxy_event.py @@ -54,8 +54,8 @@ def test_default_api_gateway_proxy_event(): assert identity.user_arn == identity_raw["userArn"] assert identity.client_cert.subject_dn == "www.example.com" - assert parsed_event.path_parameters == raw_event["pathParameters"] - assert parsed_event.stage_variables == raw_event["stageVariables"] + assert parsed_event.path_parameters == (raw_event["pathParameters"] or {}) + assert parsed_event.stage_variables == (raw_event["stageVariables"] or {}) assert parsed_event.body == raw_event["body"] assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] @@ -121,8 +121,8 @@ def test_api_gateway_proxy_event(): assert identity.user_arn == identity_raw["userArn"] assert identity.client_cert.subject_dn == "www.example.com" - assert parsed_event.path_parameters == raw_event["pathParameters"] - assert parsed_event.stage_variables == raw_event["stageVariables"] + assert parsed_event.path_parameters == (raw_event["pathParameters"] or {}) + assert parsed_event.stage_variables == (raw_event["stageVariables"] or {}) assert parsed_event.body == raw_event["body"] assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] diff --git a/tests/unit/data_classes/test_appsync_resolver_event.py b/tests/unit/data_classes/test_appsync_resolver_event.py index a1a010c251a..da607d05379 100644 --- a/tests/unit/data_classes/test_appsync_resolver_event.py +++ b/tests/unit/data_classes/test_appsync_resolver_event.py @@ -17,19 +17,19 @@ def test_appsync_resolver_event(): assert parsed_event.arguments.get("name") == raw_event["arguments"]["name"] assert parsed_event.identity.claims.get("token_use") == raw_event["identity"]["claims"]["token_use"] assert parsed_event.source.get("name") == raw_event["source"]["name"] - assert parsed_event.get_header_value("X-amzn-trace-id") == "Root=1-60488877-0b0c4e6727ab2a1c545babd0" - assert parsed_event.get_header_value("X-amzn-trace-id", case_sensitive=True) is None - assert parsed_event.get_header_value("missing", default_value="Foo") == "Foo" + assert parsed_event.request_headers["X-amzn-trace-id"] == "Root=1-60488877-0b0c4e6727ab2a1c545babd0" + assert parsed_event.request_headers["x-amzn-trace-id"] == "Root=1-60488877-0b0c4e6727ab2a1c545babd0" + assert parsed_event.request_headers.get("missing", "Foo") == "Foo" assert parsed_event.prev_result == {} - assert parsed_event.stash is None + assert parsed_event.stash == {} info = parsed_event.info assert info is not None assert isinstance(info, AppSyncResolverEventInfo) assert info.field_name == raw_event["fieldName"] assert info.parent_type_name == raw_event["typeName"] - assert info.variables is None - assert info.selection_set_list is None + assert info.variables == {} + assert info.selection_set_list == [] assert info.selection_set_graphql is None assert isinstance(parsed_event.identity, AppSyncIdentityCognito) @@ -80,7 +80,7 @@ def test_appsync_resolver_direct(): raw_event = load_event("appSyncDirectResolver.json") parsed_event = AppSyncResolverEvent(raw_event) - assert parsed_event.source is None + assert parsed_event.source == {} assert parsed_event.arguments.get("id") == raw_event["arguments"]["id"] assert parsed_event.stash == {} assert parsed_event.prev_result is None @@ -90,7 +90,6 @@ def test_appsync_resolver_direct(): info_raw = raw_event["info"] assert info is not None assert isinstance(info, AppSyncResolverEventInfo) - assert info.selection_set_list is not None assert info.selection_set_list == info["selectionSetList"] assert info.selection_set_graphql == info_raw["selectionSetGraphQL"] assert info.parent_type_name == info_raw["parentTypeName"] @@ -112,7 +111,7 @@ def test_appsync_resolver_event_info(): event = AppSyncResolverEvent(event) - assert event.source is None + assert event.source == {} assert event.identity is None assert event.info is not None assert isinstance(event.info, AppSyncResolverEventInfo) diff --git a/tests/unit/data_classes/test_cloud_watch_alarm_event.py b/tests/unit/data_classes/test_cloud_watch_alarm_event.py index 56933a1505d..df72a7ff1e1 100644 --- a/tests/unit/data_classes/test_cloud_watch_alarm_event.py +++ b/tests/unit/data_classes/test_cloud_watch_alarm_event.py @@ -102,3 +102,4 @@ def test_cloud_watch_alarm_event_composite_metric(): parsed_event.alarm_data.configuration.alarm_actions_suppressor == raw_event["alarmData"]["configuration"]["actionsSuppressor"] ) + assert isinstance(parsed_event.alarm_data.configuration.metrics, List) diff --git a/tests/unit/data_classes/test_cloud_watch_logs_event.py b/tests/unit/data_classes/test_cloud_watch_logs_event.py index c65c55d6334..10a3a499dd0 100644 --- a/tests/unit/data_classes/test_cloud_watch_logs_event.py +++ b/tests/unit/data_classes/test_cloud_watch_logs_event.py @@ -24,7 +24,7 @@ def test_cloud_watch_trigger_event(): assert log_event.get_id == "eventId1" assert log_event.timestamp == 1440442987000 assert log_event.message == "[ERROR] First test message" - assert log_event.extracted_fields is None + assert log_event.extracted_fields == {} event2 = CloudWatchLogsEvent(load_event("cloudWatchLogEvent.json")) assert parsed_event.raw_event == event2.raw_event @@ -52,7 +52,7 @@ def test_cloud_watch_trigger_event_with_policy_level(): assert log_event.get_id == "eventId1" assert log_event.timestamp == 1440442987000 assert log_event.message == "[ERROR] First test message" - assert log_event.extracted_fields is None + assert log_event.extracted_fields == {} event2 = CloudWatchLogsEvent(load_event("cloudWatchLogEventWithPolicyLevel.json")) assert parsed_event.raw_event == event2.raw_event diff --git a/tests/unit/data_classes/test_code_pipeline_job_event.py b/tests/unit/data_classes/test_code_pipeline_job_event.py index a1689ede2f1..75e68b44396 100644 --- a/tests/unit/data_classes/test_code_pipeline_job_event.py +++ b/tests/unit/data_classes/test_code_pipeline_job_event.py @@ -93,8 +93,8 @@ def test_code_pipeline_event_missing_user_parameters(): configuration = parsed_event.data.action_configuration.configuration decoded_params = configuration.decoded_user_parameters assert decoded_params == parsed_event.decoded_user_parameters - assert decoded_params is None - assert configuration.decoded_user_parameters is None + assert decoded_params == {} + assert configuration.decoded_user_parameters == {} def test_code_pipeline_event_non_json_user_parameters(): diff --git a/tests/unit/data_classes/test_cognito_user_pool_event.py b/tests/unit/data_classes/test_cognito_user_pool_event.py index 2321f23c16e..9c4285fd18a 100644 --- a/tests/unit/data_classes/test_cognito_user_pool_event.py +++ b/tests/unit/data_classes/test_cognito_user_pool_event.py @@ -32,8 +32,8 @@ def test_cognito_pre_signup_trigger_event(): # Verify properties user_attributes = parsed_event.request.user_attributes assert user_attributes.get("email") == raw_event["request"]["userAttributes"]["email"] - assert parsed_event.request.validation_data is None - assert parsed_event.request.client_metadata is None + assert parsed_event.request.validation_data == {} + assert parsed_event.request.client_metadata == {} # Verify setters parsed_event.response.auto_confirm_user = True @@ -53,7 +53,7 @@ def test_cognito_post_confirmation_trigger_event(): user_attributes = parsed_event.request.user_attributes assert user_attributes.get("email") == raw_event["request"]["userAttributes"]["email"] - assert parsed_event.request.client_metadata is None + assert parsed_event.request.client_metadata == {} def test_cognito_user_migration_trigger_event(): @@ -63,8 +63,8 @@ def test_cognito_user_migration_trigger_event(): assert parsed_event.trigger_source == raw_event["triggerSource"] assert compare_digest(parsed_event.request.password, raw_event["request"]["password"]) - assert parsed_event.request.validation_data is None - assert parsed_event.request.client_metadata is None + assert parsed_event.request.validation_data == {} + assert parsed_event.request.client_metadata == {} parsed_event.response.user_attributes = {"username": "username"} assert parsed_event.response.user_attributes == raw_event["response"]["userAttributes"] @@ -72,7 +72,7 @@ def test_cognito_user_migration_trigger_event(): assert parsed_event.response.final_user_status is None assert parsed_event.response.message_action is None assert parsed_event.response.force_alias_creation is None - assert parsed_event.response.desired_delivery_mediums is None + assert parsed_event.response.desired_delivery_mediums == [] parsed_event.response.final_user_status = "CONFIRMED" assert parsed_event.response.final_user_status == "CONFIRMED" @@ -93,7 +93,7 @@ def test_cognito_custom_message_trigger_event(): assert parsed_event.request.code_parameter == raw_event["request"]["codeParameter"] assert parsed_event.request.username_parameter == raw_event["request"]["usernameParameter"] assert parsed_event.request.user_attributes.get("phone_number_verified") is False - assert parsed_event.request.client_metadata is None + assert parsed_event.request.client_metadata == {} parsed_event.response.sms_message = "sms" assert parsed_event.response.sms_message == parsed_event["response"]["smsMessage"] @@ -113,7 +113,7 @@ def test_cognito_pre_authentication_trigger_event(): parsed_event["request"]["userNotFound"] = True assert parsed_event.request.user_not_found is True assert parsed_event.request.user_attributes.get("email") == raw_event["request"]["userAttributes"]["email"] - assert parsed_event.request.validation_data is None + assert parsed_event.request.validation_data == {} def test_cognito_post_authentication_trigger_event(): @@ -124,7 +124,7 @@ def test_cognito_post_authentication_trigger_event(): assert parsed_event.request.new_device_used is True assert parsed_event.request.user_attributes.get("email") == raw_event["request"]["userAttributes"]["email"] - assert parsed_event.request.client_metadata is None + assert parsed_event.request.client_metadata == {} def test_cognito_pre_token_generation_trigger_event(): @@ -138,7 +138,7 @@ def test_cognito_pre_token_generation_trigger_event(): assert group_configuration.iam_roles_to_override == [] assert group_configuration.preferred_role is None assert parsed_event.request.user_attributes.get("email") == raw_event["request"]["userAttributes"]["email"] - assert parsed_event.request.client_metadata is None + assert parsed_event.request.client_metadata == {} parsed_event["request"]["groupConfiguration"]["preferredRole"] = "temp" group_configuration = parsed_event.request.group_configuration @@ -148,8 +148,8 @@ def test_cognito_pre_token_generation_trigger_event(): claims_override_details = parsed_event.response.claims_override_details assert parsed_event["response"]["claimsOverrideDetails"] == {} - assert claims_override_details.claims_to_add_or_override is None - assert claims_override_details.claims_to_suppress is None + assert claims_override_details.claims_to_add_or_override == {} + assert claims_override_details.claims_to_suppress == [] assert claims_override_details.group_configuration is None claims_override_details.group_configuration = {} @@ -208,7 +208,7 @@ def test_cognito_define_auth_challenge_trigger_event(): assert session[0].challenge_result is True assert session[0].challenge_metadata is None assert session[1].challenge_metadata == raw_event["request"]["session"][1]["challengeMetadata"] - assert parsed_event.request.client_metadata is None + assert parsed_event.request.client_metadata == {} # Verify setters parsed_event.response.challenge_name = "CUSTOM_CHALLENGE" @@ -236,7 +236,7 @@ def test_create_auth_challenge_trigger_event(): assert len(session) == 1 assert session[0].challenge_name == raw_event["request"]["session"][0]["challengeName"] assert session[0].challenge_metadata == raw_event["request"]["session"][0]["challengeMetadata"] - assert parsed_event.request.client_metadata is None + assert parsed_event.request.client_metadata == {} # Verify setters parsed_event.response.public_challenge_parameters = {"test": "value"} @@ -263,7 +263,6 @@ def test_verify_auth_challenge_response_trigger_event(): == raw_event["request"]["privateChallengeParameters"]["answer"] ) assert parsed_event.request.challenge_answer == raw_event["request"]["challengeAnswer"] - assert parsed_event.request.client_metadata is not None assert parsed_event.request.client_metadata.get("foo") == raw_event["request"]["clientMetadata"]["foo"] assert parsed_event.request.user_not_found is True diff --git a/tests/unit/data_classes/test_dynamo_db_stream_event.py b/tests/unit/data_classes/test_dynamo_db_stream_event.py index f7672abd69b..9632563423a 100644 --- a/tests/unit/data_classes/test_dynamo_db_stream_event.py +++ b/tests/unit/data_classes/test_dynamo_db_stream_event.py @@ -30,7 +30,7 @@ def test_dynamodb_stream_trigger_event(): assert record.event_source == record_raw["eventSource"] assert record.event_source_arn == record_raw["eventSourceARN"] assert record.event_version == record_raw["eventVersion"] - assert record.user_identity is None + assert record.user_identity == {} dynamodb = record.dynamodb assert dynamodb is not None assert dynamodb.approximate_creation_date_time == record_raw["dynamodb"]["ApproximateCreationDateTime"] @@ -38,7 +38,7 @@ def test_dynamodb_stream_trigger_event(): assert keys is not None assert keys["Id"] == decimal_context.create_decimal(101) assert dynamodb.new_image.get("Message") == record_raw["dynamodb"]["NewImage"]["Message"]["S"] - assert dynamodb.old_image is None + assert dynamodb.old_image == {} assert dynamodb.sequence_number == record_raw["dynamodb"]["SequenceNumber"] assert dynamodb.size_bytes == record_raw["dynamodb"]["SizeBytes"] assert dynamodb.stream_view_type == StreamViewType.NEW_AND_OLD_IMAGES @@ -94,7 +94,7 @@ def test_dynamodb_stream_record_deserialization(): def test_dynamodb_stream_record_keys_with_no_keys(): record = StreamRecord({}) - assert record.keys is None + assert record.keys == {} def test_dynamodb_stream_record_keys_overrides_dict_wrapper_keys(): diff --git a/tests/unit/data_classes/test_kafka_event.py b/tests/unit/data_classes/test_kafka_event.py index f97fa8e0a0e..fc36171da77 100644 --- a/tests/unit/data_classes/test_kafka_event.py +++ b/tests/unit/data_classes/test_kafka_event.py @@ -31,7 +31,7 @@ def test_kafka_msk_event(): assert record.value == raw_record["value"] assert record.json_value == {"key": "value"} assert record.decoded_headers == {"headerKey": b"headerValue"} - assert record.get_header_value("HeaderKey", case_sensitive=False) == b"headerValue" + assert record.decoded_headers["HeaderKey"] == b"headerValue" assert parsed_event.record == records[0] @@ -62,7 +62,7 @@ def test_kafka_self_managed_event(): assert record.value == raw_record["value"] assert record.json_value == {"key": "value"} assert record.decoded_headers == {"headerKey": b"headerValue"} - assert record.get_header_value("HeaderKey", case_sensitive=False) == b"headerValue" + assert record.decoded_headers["HeaderKey"] == b"headerValue" assert parsed_event.record == records[0] diff --git a/tests/unit/data_classes/test_lambda_function_url.py b/tests/unit/data_classes/test_lambda_function_url.py index f8ce71b1543..ca8e3d78c59 100644 --- a/tests/unit/data_classes/test_lambda_function_url.py +++ b/tests/unit/data_classes/test_lambda_function_url.py @@ -13,17 +13,17 @@ def test_lambda_function_url_event(): assert parsed_event.path == raw_event["rawPath"] assert parsed_event.raw_query_string == raw_event["rawQueryString"] - assert parsed_event.cookies is None + assert parsed_event.cookies == [] headers = parsed_event.headers assert len(headers) == 20 - assert parsed_event.query_string_parameters is None + assert parsed_event.query_string_parameters == {} assert parsed_event.is_base64_encoded is False assert parsed_event.body is None - assert parsed_event.path_parameters is None - assert parsed_event.stage_variables is None + assert parsed_event.path_parameters == {} + assert parsed_event.stage_variables == {} assert parsed_event.http_method == raw_event["requestContext"]["http"]["method"] request_context = parsed_event.request_context @@ -75,8 +75,8 @@ def test_lambda_function_url_event_iam(): assert parsed_event.is_base64_encoded is False assert parsed_event.body == raw_event["body"] assert parsed_event.decoded_body == raw_event["body"] - assert parsed_event.path_parameters is None - assert parsed_event.stage_variables is None + assert parsed_event.path_parameters == {} + assert parsed_event.stage_variables == {} assert parsed_event.http_method == raw_event["requestContext"]["http"]["method"] request_context = parsed_event.request_context diff --git a/tests/unit/data_classes/test_s3_batch_operation_event.py b/tests/unit/data_classes/test_s3_batch_operation_event.py index ca0d4ae635c..44dc65df07d 100644 --- a/tests/unit/data_classes/test_s3_batch_operation_event.py +++ b/tests/unit/data_classes/test_s3_batch_operation_event.py @@ -19,7 +19,7 @@ def test_s3_batch_operation_schema_v1(): job = parsed_event.job assert job.get_id == raw_event["job"]["id"] - assert job.user_arguments is None + assert job.user_arguments == {} assert parsed_event.invocation_schema_version == raw_event["invocationSchemaVersion"] assert parsed_event.invocation_id == raw_event["invocationId"] diff --git a/tests/unit/data_classes/test_s3_object_event.py b/tests/unit/data_classes/test_s3_object_event.py index 47583d9e544..09d0f14e5f6 100644 --- a/tests/unit/data_classes/test_s3_object_event.py +++ b/tests/unit/data_classes/test_s3_object_event.py @@ -23,7 +23,7 @@ def test_s3_object_event_iam(): user_request = parsed_event.user_request assert user_request.url == raw_event["userRequest"]["url"] assert user_request.headers == raw_event["userRequest"]["headers"] - assert user_request.get_header_value("Accept-Encoding") == "identity" + assert user_request.headers["Accept-Encoding"] == "identity" assert parsed_event.user_identity is not None user_identity = parsed_event.user_identity assert user_identity.get_type == raw_event["userIdentity"]["type"] diff --git a/tests/unit/data_classes/test_ses_event.py b/tests/unit/data_classes/test_ses_event.py index 636cf4cccac..e81c546fb1e 100644 --- a/tests/unit/data_classes/test_ses_event.py +++ b/tests/unit/data_classes/test_ses_event.py @@ -29,10 +29,10 @@ def test_ses_trigger_event(): assert common_headers.to == [expected_address] assert common_headers.message_id == common_headers_raw["messageId"] assert common_headers.subject == common_headers_raw["subject"] - assert common_headers.cc is None - assert common_headers.bcc is None - assert common_headers.sender is None - assert common_headers.reply_to is None + assert common_headers.cc == [] + assert common_headers.bcc == [] + assert common_headers.sender == [] + assert common_headers.reply_to == [] receipt = record.ses.receipt raw_receipt = raw_event["Records"][0]["ses"]["receipt"] assert receipt.timestamp == raw_receipt["timestamp"] diff --git a/tests/unit/data_classes/test_vpc_lattice_event.py b/tests/unit/data_classes/test_vpc_lattice_event.py index ab00c51521f..9f5ad742557 100644 --- a/tests/unit/data_classes/test_vpc_lattice_event.py +++ b/tests/unit/data_classes/test_vpc_lattice_event.py @@ -7,8 +7,8 @@ def test_vpc_lattice_event(): parsed_event = VPCLatticeEvent(raw_event) assert parsed_event.raw_path == raw_event["raw_path"] - assert parsed_event.get_query_string_value("order-id") == "1" - assert parsed_event.get_header_value("user_agent") == "curl/7.64.1" + assert parsed_event.query_string_parameters["order-id"] == "1" + assert parsed_event.headers["user_agent"] == "curl/7.64.1" assert parsed_event.decoded_body == '{"test": "event"}' assert parsed_event.json_body == {"test": "event"} assert parsed_event.method == raw_event["method"] diff --git a/tests/unit/data_classes/test_vpc_lattice_eventv2.py b/tests/unit/data_classes/test_vpc_lattice_eventv2.py index 3726831445f..87a9a69be38 100644 --- a/tests/unit/data_classes/test_vpc_lattice_eventv2.py +++ b/tests/unit/data_classes/test_vpc_lattice_eventv2.py @@ -7,8 +7,8 @@ def test_vpc_lattice_v2_event(): parsed_event = VPCLatticeEventV2(raw_event) assert parsed_event.path == raw_event["path"] - assert parsed_event.get_query_string_value("order-id") == "1" - assert parsed_event.get_header_value("user_agent") == "curl/7.64.1" + assert parsed_event.query_string_parameters["order-id"] == "1" + assert parsed_event.headers["user_agent"] == "curl/7.64.1" assert parsed_event.decoded_body == '{"message": "Hello from Lambda!"}' assert parsed_event.json_body == {"message": "Hello from Lambda!"} assert parsed_event.method == raw_event["method"] diff --git a/tests/unit/test_data_classes.py b/tests/unit/test_data_classes.py index 393bcdf250e..63947eade11 100644 --- a/tests/unit/test_data_classes.py +++ b/tests/unit/test_data_classes.py @@ -240,90 +240,6 @@ def data_property(self) -> str: assert str(event_source) == "{'data_property': '[SENSITIVE]', 'raw_event': '[SENSITIVE]'}" -def test_base_proxy_event_get_query_string_value(): - default_value = "default" - set_value = "value" - - event = BaseProxyEvent({}) - value = event.get_query_string_value("test", default_value) - assert value == default_value - - event._data["queryStringParameters"] = {"test": set_value} - value = event.get_query_string_value("test", default_value) - assert value == set_value - - value = event.get_query_string_value("unknown", default_value) - assert value == default_value - - value = event.get_query_string_value("unknown") - assert value is None - - -def test_base_proxy_event_get_multi_value_query_string_values(): - default_values = ["default_1", "default_2"] - set_values = ["value_1", "value_2"] - - event = BaseProxyEvent({}) - values = event.get_multi_value_query_string_values("test", default_values) - assert values == default_values - - event._data["multiValueQueryStringParameters"] = {"test": set_values} - values = event.get_multi_value_query_string_values("test", default_values) - assert values == set_values - - values = event.get_multi_value_query_string_values("unknown", default_values) - assert values == default_values - - values = event.get_multi_value_query_string_values("unknown") - assert values == [] - - -def test_base_proxy_event_get_header_value(): - default_value = "default" - set_value = "value" - - event = BaseProxyEvent({"headers": {}}) - value = event.get_header_value("test", default_value) - assert value == default_value - - event._data["headers"] = {"test": set_value} - value = event.get_header_value("test", default_value) - assert value == set_value - - # Verify that the default look is case insensitive - value = event.get_header_value("Test") - assert value == set_value - - value = event.get_header_value("unknown", default_value) - assert value == default_value - - value = event.get_header_value("unknown") - assert value is None - - -def test_base_proxy_event_get_header_value_case_insensitive(): - default_value = "default" - set_value = "value" - - event = BaseProxyEvent({"headers": {}}) - - event._data["headers"] = {"Test": set_value} - value = event.get_header_value("test", case_sensitive=True) - assert value is None - - value = event.get_header_value("test", default_value=default_value, case_sensitive=True) - assert value == default_value - - value = event.get_header_value("Test", case_sensitive=True) - assert value == set_value - - value = event.get_header_value("unknown", default_value, case_sensitive=True) - assert value == default_value - - value = event.get_header_value("unknown", case_sensitive=True) - assert value is None - - def test_base_proxy_event_json_body(): data = {"message": "Foo"} event = BaseProxyEvent({"body": json.dumps(data)}) @@ -408,7 +324,7 @@ def test_reflected_types(): def lambda_handler(event: APIGatewayProxyEventV2, _): # THEN we except the event to be of the pass in data class type assert isinstance(event, APIGatewayProxyEventV2) - assert event.get_header_value("x-foo") == "Foo" + assert event.headers["x-foo"] == "Foo" # WHEN calling the lambda handler lambda_handler({"headers": {"X-Foo": "Foo"}}, None)