diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index dce520c147d..d950bdc9c52 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -10,7 +10,7 @@ from enum import Enum from functools import partial from http import HTTPStatus -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import ServiceError @@ -453,7 +453,7 @@ def __init__( def route( self, rule: str, - method: str, + method: Union[str, Union[List[str], Tuple[str]]], cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, @@ -461,19 +461,22 @@ def route( """Route decorator includes parameter `method`""" def register_resolver(func: Callable): - logger.debug(f"Adding route using rule {rule} and method {method.upper()}") + methods = (method,) if isinstance(method, str) else method + logger.debug(f"Adding route using rule {rule} and methods: {','.join((m.upper() for m in methods))}") if cors is None: cors_enabled = self._cors_enabled else: cors_enabled = cors - self._routes.append(Route(method, self._compile_regex(rule), func, cors_enabled, compress, cache_control)) - route_key = method + rule - if route_key in self._route_keys: - warnings.warn(f"A route like this was already registered. method: '{method}' rule: '{rule}'") - self._route_keys.append(route_key) - if cors_enabled: - logger.debug(f"Registering method {method.upper()} to Allow Methods in CORS") - self._cors_methods.add(method.upper()) + + for item in methods: + self._routes.append(Route(item, self._compile_regex(rule), func, cors_enabled, compress, cache_control)) + route_key = item + rule + if route_key in self._route_keys: + warnings.warn(f"A route like this was already registered. method: '{item}' rule: '{rule}'") + self._route_keys.append(route_key) + if cors_enabled: + logger.debug(f"Registering method {item.upper()} to Allow Methods in CORS") + self._cors_methods.add(item.upper()) return func return register_resolver @@ -679,14 +682,14 @@ def __init__(self): def route( self, rule: str, - method: Union[str, List[str]], + method: Union[str, Union[List[str], Tuple[str]]], cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, ): def register_route(func: Callable): - methods = method if isinstance(method, list) else [method] - for item in methods: - self._routes[(rule, item, cors, compress, cache_control)] = func + # Convert methods to tuple. It needs to be hashable as its part of the self._routes dict key + methods = (method,) if isinstance(method, str) else tuple(method) + self._routes[(rule, methods, cors, compress, cache_control)] = func return register_route diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index f9482edaacf..8c0d5e6621e 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -42,45 +42,27 @@ This is the sample infrastructure for API Gateway we are using for the examples Timeout: 5 Runtime: python3.8 Tracing: Active - Environment: + Environment: Variables: LOG_LEVEL: INFO POWERTOOLS_LOGGER_SAMPLE_RATE: 0.1 POWERTOOLS_LOGGER_LOG_EVENT: true POWERTOOLS_METRICS_NAMESPACE: MyServerlessApplication - POWERTOOLS_SERVICE_NAME: hello + POWERTOOLS_SERVICE_NAME: my_api-service Resources: - HelloWorldFunction: + ApiFunction: Type: AWS::Serverless::Function Properties: Handler: app.lambda_handler - CodeUri: hello_world - Description: Hello World function + CodeUri: api_handler/ + Description: API handler function Events: - HelloUniverse: - Type: Api - Properties: - Path: /hello - Method: GET - HelloYou: - Type: Api - Properties: - Path: /hello/{name} # see Dynamic routes section - Method: GET - CustomMessage: - Type: Api - Properties: - Path: /{message}/{name} # see Dynamic routes section - Method: GET - - Outputs: - HelloWorldApigwURL: - Description: "API Gateway endpoint URL for Prod environment for Hello World Function" - Value: !Sub "https://${ServerlessRestApi}.execute-api.${AWS::Region}.amazonaws.com/Prod/hello" - HelloWorldFunction: - Description: "Hello World Lambda Function ARN" - Value: !GetAtt HelloWorldFunction.Arn + ApiEvent: + Type: Api + Properties: + Path: /{proxy+} # Send requests on any path to the lambda function + Method: ANY # Send requests using any http method to the lambda function ``` ### API Gateway decorator @@ -360,6 +342,87 @@ You can also combine nested paths with greedy regex to catch in between routes. ... } ``` +### HTTP Methods +You can use named decorators to specify the HTTP method that should be handled in your functions. As well as the +`get` method already shown above, you can use `post`, `put`, `patch`, `delete`, and `patch`. + +=== "app.py" + + ```python hl_lines="9-10" + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.logging import correlation_paths + from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver + + tracer = Tracer() + logger = Logger() + app = ApiGatewayResolver() + + # Only POST HTTP requests to the path /hello will route to this function + @app.post("/hello") + @tracer.capture_method + def get_hello_you(): + name = app.current_event.json_body.get("name") + return {"message": f"hello {name}"} + + # You can continue to use other utilities just as before + @logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST) + @tracer.capture_lambda_handler + def lambda_handler(event, context): + return app.resolve(event, context) + ``` + +=== "sample_request.json" + + ```json + { + "resource": "/hello/{name}", + "path": "/hello/lessa", + "httpMethod": "GET", + ... + } + ``` + +If you need to accept multiple HTTP methods in a single function, you can use the `route` method and pass a list of +HTTP methods. + +=== "app.py" + + ```python hl_lines="9-10" + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.logging import correlation_paths + from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver + + tracer = Tracer() + logger = Logger() + app = ApiGatewayResolver() + + # PUT and POST HTTP requests to the path /hello will route to this function + @app.route("/hello", method=["PUT", "POST"]) + @tracer.capture_method + def get_hello_you(): + name = app.current_event.json_body.get("name") + return {"message": f"hello {name}"} + + # You can continue to use other utilities just as before + @logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST) + @tracer.capture_lambda_handler + def lambda_handler(event, context): + return app.resolve(event, context) + ``` + +=== "sample_request.json" + + ```json + { + "resource": "/hello/{name}", + "path": "/hello/lessa", + "httpMethod": "GET", + ... + } + ``` + +!!! note "It is usually better to have separate functions for each HTTP method, as the functionality tends to differ +depending on which method is used." ### Accessing request details diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index f4543fa300c..09594789ac3 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1021,3 +1021,39 @@ def get_func_another_duplicate(): # THEN only execute the first registered route # AND print warnings assert result["statusCode"] == 200 + + +def test_route_multiple_methods(): + # GIVEN a function with http methods passed as a list + app = ApiGatewayResolver() + req = "foo" + get_event = deepcopy(LOAD_GW_EVENT) + get_event["resource"] = "/accounts/{account_id}" + get_event["path"] = f"/accounts/{req}" + + post_event = deepcopy(get_event) + post_event["httpMethod"] = "POST" + + put_event = deepcopy(get_event) + put_event["httpMethod"] = "PUT" + + lambda_context = {} + + @app.route(rule="/accounts/", method=["GET", "POST"]) + def foo(account_id): + assert app.lambda_context == lambda_context + assert account_id == f"{req}" + return {} + + # WHEN calling the event handler with the supplied methods + get_result = app(get_event, lambda_context) + post_result = app(post_event, lambda_context) + put_result = app(put_event, lambda_context) + + # THEN events are processed correctly + assert get_result["statusCode"] == 200 + assert get_result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + assert post_result["statusCode"] == 200 + assert post_result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + assert put_result["statusCode"] == 404 + assert put_result["headers"]["Content-Type"] == content_types.APPLICATION_JSON