Skip to content

Commit b2c4915

Browse files
author
Michael Brewer
committed
feat(event-handler): allow for a custom serializer
1 parent 1135314 commit b2c4915

File tree

2 files changed

+49
-5
lines changed

2 files changed

+49
-5
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import traceback
77
import zlib
88
from enum import Enum
9+
from functools import partial
910
from http import HTTPStatus
1011
from typing import Any, Callable, Dict, List, Optional, Set, Union
1112

@@ -263,6 +264,7 @@ def __init__(
263264
proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent,
264265
cors: Optional[CORSConfig] = None,
265266
debug: Optional[bool] = None,
267+
serializer: Optional[Callable[[Dict], str]] = None,
266268
):
267269
"""
268270
Parameters
@@ -284,6 +286,14 @@ def __init__(
284286
env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug
285287
)
286288

289+
if serializer:
290+
self._serializer = serializer
291+
elif self._debug:
292+
"""Does a concise json serialization or pretty print when in debug mode"""
293+
self._serializer = partial(json.dumps, indent=4, cls=Encoder)
294+
else:
295+
self._serializer = partial(json.dumps, separators=(",", ":"), cls=Encoder)
296+
287297
def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
288298
"""Get route decorator with GET `method`
289299
@@ -592,8 +602,4 @@ def _to_response(self, result: Union[Dict, Response]) -> Response:
592602
)
593603

594604
def _json_dump(self, obj: Any) -> str:
595-
"""Does a concise json serialization or pretty print when in debug mode"""
596-
if self._debug:
597-
return json.dumps(obj, indent=4, cls=Encoder)
598-
else:
599-
return json.dumps(obj, separators=(",", ":"), cls=Encoder)
605+
return self._serializer(obj)

tests/functional/event_handler/test_api_gateway.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import zlib
44
from copy import deepcopy
55
from decimal import Decimal
6+
from enum import Enum
7+
from json import JSONEncoder
68
from pathlib import Path
79
from typing import Dict
810

@@ -728,3 +730,39 @@ def get_account(account_id: str):
728730

729731
ret = app.resolve(event, None)
730732
assert ret["statusCode"] == 200
733+
734+
735+
def test_custom_serializer():
736+
class Color(Enum):
737+
RED = 1
738+
BLUE = 2
739+
740+
class CustomEncoder(JSONEncoder):
741+
def default(self, data):
742+
if isinstance(data, Enum):
743+
return data.value
744+
try:
745+
iterable = iter(data)
746+
except TypeError:
747+
pass
748+
else:
749+
return list(iterable)
750+
return JSONEncoder.default(self, data)
751+
752+
def custom_serializer(data) -> str:
753+
return json.dumps(data, cls=CustomEncoder)
754+
755+
app = ApiGatewayResolver(serializer=custom_serializer)
756+
757+
@app.get("/colors")
758+
def get_color() -> Dict:
759+
return {
760+
"color": Color.RED,
761+
"variations": {"light", "dark"},
762+
}
763+
764+
response = app({"httpMethod": "GET", "path": "/colors"}, None)
765+
766+
body = response["body"]
767+
expected = '{"color": 1, "variations": ["light", "dark"]}'
768+
assert expected == body

0 commit comments

Comments
 (0)