1
1
import inspect
2
- from copy import copy
3
2
from enum import Enum
4
- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
3
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Union
5
4
6
5
from pydantic import BaseConfig
7
6
from pydantic .fields import FieldInfo
8
- from typing_extensions import Annotated , get_args , get_origin
7
+ from typing_extensions import Annotated , Literal , get_args , get_origin
9
8
10
9
from aws_lambda_powertools .event_handler .openapi .compat import (
11
10
ModelField ,
12
11
Required ,
13
12
Undefined ,
13
+ UndefinedType ,
14
+ copy_field_info ,
14
15
get_annotation_from_field_info ,
15
16
)
16
17
from aws_lambda_powertools .event_handler .openapi .types import PYDANTIC_V2 , CacheKey
@@ -302,7 +303,8 @@ def analyze_param(
302
303
annotation : Any ,
303
304
value : Any ,
304
305
is_path_param : bool ,
305
- ) -> Tuple [Any , Optional [ModelField ]]:
306
+ is_response_param : bool ,
307
+ ) -> Optional [ModelField ]:
306
308
"""
307
309
Analyze a parameter annotation and value to determine the type and default value of the parameter.
308
310
@@ -316,10 +318,12 @@ def analyze_param(
316
318
The value of the parameter
317
319
is_path_param
318
320
Whether the parameter is a path parameter
321
+ is_response_param
322
+ Whether the parameter is the return annotation
319
323
320
324
Returns
321
325
-------
322
- Tuple[Any, Optional[ModelField] ]
326
+ Optional[ModelField]
323
327
The type annotation and the Pydantic field representing the parameter
324
328
"""
325
329
field_info , type_annotation = _get_field_info_and_type_annotation (annotation , value , is_path_param )
@@ -336,12 +340,16 @@ def analyze_param(
336
340
337
341
# Check if the parameter is part of the path. Otherwise, defaults to query.
338
342
if is_path_param :
339
- field_info = Path (annotation = type_annotation , default = default_value )
343
+ field_info = Path (annotation = type_annotation )
340
344
else :
341
345
field_info = Query (annotation = type_annotation , default = default_value )
342
346
347
+ # When we have a response field, we need to set the default value to Required
348
+ if is_response_param :
349
+ field_info .default = Required
350
+
343
351
field = _create_model_field (field_info , type_annotation , param_name , is_path_param )
344
- return type_annotation , field
352
+ return field
345
353
346
354
347
355
def _get_field_info_and_type_annotation (annotation , value , is_path_param : bool ) -> Tuple [Optional [FieldInfo ], Any ]:
@@ -372,7 +380,10 @@ def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tu
372
380
373
381
if isinstance (powertools_annotation , FieldInfo ):
374
382
# Copy `field_info` because we mutate `field_info.default` later
375
- field_info = copy (powertools_annotation )
383
+ field_info = copy_field_info (
384
+ field_info = powertools_annotation ,
385
+ annotation = annotation ,
386
+ )
376
387
if field_info .default not in [Undefined , Required ]:
377
388
raise AssertionError ("FieldInfo needs to have a default value of Undefined or Required" )
378
389
@@ -386,6 +397,44 @@ def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tu
386
397
return field_info , type_annotation
387
398
388
399
400
+ def _create_response_field (
401
+ name : str ,
402
+ type_ : Type [Any ],
403
+ default : Optional [Any ] = Undefined ,
404
+ required : Union [bool , UndefinedType ] = Undefined ,
405
+ model_config : Type [BaseConfig ] = BaseConfig ,
406
+ field_info : Optional [FieldInfo ] = None ,
407
+ alias : Optional [str ] = None ,
408
+ mode : Literal ["validation" , "serialization" ] = "validation" ,
409
+ ) -> ModelField :
410
+ """
411
+ Create a new response field. Raises if type_ is invalid.
412
+ """
413
+ if PYDANTIC_V2 :
414
+ field_info = field_info or FieldInfo (
415
+ annotation = type_ ,
416
+ default = default ,
417
+ alias = alias ,
418
+ )
419
+ else :
420
+ field_info = field_info or FieldInfo ()
421
+ kwargs = {"name" : name , "field_info" : field_info }
422
+ if PYDANTIC_V2 :
423
+ kwargs .update ({"mode" : mode })
424
+ else :
425
+ kwargs .update (
426
+ {
427
+ "type_" : type_ ,
428
+ "class_validators" : {},
429
+ "default" : default ,
430
+ "required" : required ,
431
+ "model_config" : model_config ,
432
+ "alias" : alias ,
433
+ },
434
+ )
435
+ return ModelField (** kwargs ) # type: ignore[arg-type]
436
+
437
+
389
438
def _create_model_field (
390
439
field_info : Optional [FieldInfo ],
391
440
type_annotation : Any ,
@@ -411,21 +460,11 @@ def _create_model_field(
411
460
alias = field_info .alias or param_name
412
461
field_info .alias = alias
413
462
414
- # Create the Pydantic field
415
- kwargs = {"name" : param_name , "field_info" : field_info }
416
-
417
- if PYDANTIC_V2 :
418
- kwargs .update ({"mode" : "validation" })
419
- else :
420
- kwargs .update (
421
- {
422
- "type_" : use_annotation ,
423
- "class_validators" : {},
424
- "default" : field_info .default ,
425
- "required" : field_info .default in (Required , Undefined ),
426
- "model_config" : BaseConfig ,
427
- "alias" : alias ,
428
- },
429
- )
430
-
431
- return ModelField (** kwargs ) # type: ignore[arg-type]
463
+ return _create_response_field (
464
+ name = param_name ,
465
+ type_ = use_annotation ,
466
+ default = field_info .default ,
467
+ alias = alias ,
468
+ required = field_info .default in (Required , Undefined ),
469
+ field_info = field_info ,
470
+ )
0 commit comments