Skip to content

Commit 05dd82b

Browse files
fix(event_handler): serialize pydantic/dataclasses in exception handler (#3455)
Co-authored-by: Leandro Damascena <[email protected]>
1 parent 4a49071 commit 05dd82b

File tree

6 files changed

+159
-10
lines changed

6 files changed

+159
-10
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ complexity-baseline:
8484
$(info Maintenability index)
8585
poetry run radon mi aws_lambda_powertools
8686
$(info Cyclomatic complexity index)
87-
poetry run xenon --max-absolute C --max-modules A --max-average A aws_lambda_powertools
87+
poetry run xenon --max-absolute C --max-modules A --max-average A aws_lambda_powertools --exclude aws_lambda_powertools/shared/json_encoder.py
8888

8989
#
9090
# Use `poetry version <major>/<minor></patch>` for version bump

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ class ResponseBuilder(Generic[ResponseEventT]):
709709
def __init__(
710710
self,
711711
response: Response,
712-
serializer: Callable[[Any], str] = json.dumps,
712+
serializer: Callable[[Any], str] = partial(json.dumps, separators=(",", ":"), cls=Encoder),
713713
route: Optional[Route] = None,
714714
):
715715
self.response = response

aws_lambda_powertools/shared/functions.py

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import base64
4-
import dataclasses
54
import itertools
65
import logging
76
import os
@@ -168,8 +167,86 @@ def extract_event_from_common_models(data: Any) -> Dict | Any:
168167
return data.raw_event
169168

170169
# Is it a Pydantic Model?
171-
if callable(getattr(data, "dict", None)):
172-
return data.dict()
170+
if is_pydantic(data):
171+
return pydantic_to_dict(data)
173172

174-
# Is it a Dataclass? If not return as is
175-
return dataclasses.asdict(data) if dataclasses.is_dataclass(data) else data
173+
# Is it a Dataclass?
174+
if is_dataclass(data):
175+
return dataclass_to_dict(data)
176+
177+
# Return as is
178+
return data
179+
180+
181+
def is_pydantic(data) -> bool:
182+
"""Whether data is a Pydantic model by checking common field available in v1/v2
183+
184+
Parameters
185+
----------
186+
data: BaseModel
187+
Pydantic model
188+
189+
Returns
190+
-------
191+
bool
192+
Whether it's a Pydantic model
193+
"""
194+
return getattr(data, "json", False)
195+
196+
197+
def is_dataclass(data) -> bool:
198+
"""Whether data is a dataclass
199+
200+
Parameters
201+
----------
202+
data: dataclass
203+
Dataclass obj
204+
205+
Returns
206+
-------
207+
bool
208+
Whether it's a Dataclass
209+
"""
210+
return getattr(data, "__dataclass_fields__", False)
211+
212+
213+
def pydantic_to_dict(data) -> dict:
214+
"""Dump Pydantic model v1 and v2 as dict.
215+
216+
Note we use lazy import since Pydantic is an optional dependency.
217+
218+
Parameters
219+
----------
220+
data: BaseModel
221+
Pydantic model
222+
223+
Returns
224+
-------
225+
226+
dict:
227+
Pydantic model serialized to dict
228+
"""
229+
from aws_lambda_powertools.event_handler.openapi.compat import _model_dump
230+
231+
return _model_dump(data)
232+
233+
234+
def dataclass_to_dict(data) -> dict:
235+
"""Dump standard dataclass as dict.
236+
237+
Note we use lazy import to prevent bloating other code parts.
238+
239+
Parameters
240+
----------
241+
data: dataclass
242+
Dataclass
243+
244+
Returns
245+
-------
246+
247+
dict:
248+
Pydantic model serialized to dict
249+
"""
250+
import dataclasses
251+
252+
return dataclasses.asdict(data)

aws_lambda_powertools/shared/json_encoder.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,25 @@
22
import json
33
import math
44

5+
from aws_lambda_powertools.shared.functions import dataclass_to_dict, is_dataclass, is_pydantic, pydantic_to_dict
6+
57

68
class Encoder(json.JSONEncoder):
7-
"""
8-
Custom JSON encoder to allow for serialization of Decimals, similar to the serializer used by Lambda internally.
9+
"""Custom JSON encoder to allow for serialization of Decimals, Pydantic and Dataclasses.
10+
11+
It's similar to the serializer used by Lambda internally.
912
"""
1013

1114
def default(self, obj):
1215
if isinstance(obj, decimal.Decimal):
1316
if obj.is_nan():
1417
return math.nan
1518
return str(obj)
19+
20+
if is_pydantic(obj):
21+
return pydantic_to_dict(obj)
22+
23+
if is_dataclass(obj):
24+
return dataclass_to_dict(obj)
25+
1626
return super().default(obj)

tests/functional/event_handler/test_api_gateway.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import Dict
1111

1212
import pytest
13+
from pydantic import BaseModel
1314

1415
from aws_lambda_powertools.event_handler import content_types
1516
from aws_lambda_powertools.event_handler.api_gateway import (
@@ -1465,7 +1466,6 @@ def test_exception_handler_with_data_validation():
14651466

14661467
@app.exception_handler(RequestValidationError)
14671468
def handle_validation_error(ex: RequestValidationError):
1468-
print(f"request path is '{app.current_event.path}'")
14691469
return Response(
14701470
status_code=422,
14711471
content_type=content_types.TEXT_PLAIN,
@@ -1486,6 +1486,34 @@ def get_lambda(param: int):
14861486
assert result["body"] == "Invalid data. Number of errors: 1"
14871487

14881488

1489+
def test_exception_handler_with_data_validation_pydantic_response():
1490+
# GIVEN a resolver with an exception handler defined for RequestValidationError
1491+
app = ApiGatewayResolver(enable_validation=True)
1492+
1493+
class Err(BaseModel):
1494+
msg: str
1495+
1496+
@app.exception_handler(RequestValidationError)
1497+
def handle_validation_error(ex: RequestValidationError):
1498+
return Response(
1499+
status_code=422,
1500+
content_type=content_types.APPLICATION_JSON,
1501+
body=Err(msg=f"Invalid data. Number of errors: {len(ex.errors())}"),
1502+
)
1503+
1504+
@app.get("/my/path")
1505+
def get_lambda(param: int):
1506+
...
1507+
1508+
# WHEN calling the event handler
1509+
# AND a RequestValidationError is raised
1510+
result = app(LOAD_GW_EVENT, {})
1511+
1512+
# THEN exception handler's pydantic response should be serialized correctly
1513+
assert result["statusCode"] == 422
1514+
assert result["body"] == '{"msg":"Invalid data. Number of errors: 1"}'
1515+
1516+
14891517
def test_data_validation_error():
14901518
# GIVEN a resolver without an exception handler
14911519
app = ApiGatewayResolver(enable_validation=True)

tests/unit/test_json_encoder.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import decimal
22
import json
3+
from dataclasses import dataclass
34

45
import pytest
6+
from pydantic import BaseModel
57

68
from aws_lambda_powertools.shared.json_encoder import Encoder
79

@@ -22,3 +24,35 @@ class CustomClass:
2224

2325
with pytest.raises(TypeError):
2426
json.dumps({"val": CustomClass()}, cls=Encoder)
27+
28+
29+
def test_json_encode_pydantic():
30+
# GIVEN a Pydantic model
31+
class Model(BaseModel):
32+
data: dict
33+
34+
data = {"msg": "hello"}
35+
model = Model(data=data)
36+
37+
# WHEN json.dumps use our custom Encoder
38+
result = json.dumps(model, cls=Encoder)
39+
40+
# THEN we should serialize successfully; not raise a TypeError
41+
assert result == json.dumps({"data": data}, cls=Encoder)
42+
43+
44+
def test_json_encode_dataclasses():
45+
# GIVEN a standard dataclass
46+
47+
@dataclass
48+
class Model:
49+
data: dict
50+
51+
data = {"msg": "hello"}
52+
model = Model(data=data)
53+
54+
# WHEN json.dumps use our custom Encoder
55+
result = json.dumps(model, cls=Encoder)
56+
57+
# THEN we should serialize successfully; not raise a TypeError
58+
assert result == json.dumps({"data": data}, cls=Encoder)

0 commit comments

Comments
 (0)