|
1 | 1 | import logging
|
2 |
| -from typing import Any, Callable, Optional, Type, TypeVar |
| 2 | +from typing import Any, Callable, Optional, Type, TypeVar, List, Union |
3 | 3 |
|
4 | 4 | from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
|
5 | 5 | from aws_lambda_powertools.utilities.typing import LambdaContext
|
| 6 | +from itertools import groupby |
| 7 | +from operator import itemgetter |
6 | 8 |
|
7 | 9 | logger = logging.getLogger(__name__)
|
8 | 10 |
|
9 | 11 | AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent)
|
10 | 12 |
|
11 | 13 |
|
12 | 14 | class BaseRouter:
|
13 |
| - current_event: AppSyncResolverEventT # type: ignore[valid-type] |
| 15 | + current_event: Union[AppSyncResolverEventT, List[AppSyncResolverEventT]] # type: ignore[valid-type] |
14 | 16 | lambda_context: LambdaContext
|
15 | 17 | context: dict
|
16 | 18 |
|
@@ -152,11 +154,26 @@ def lambda_handler(event, context):
|
152 | 154 | If we could not find a field resolver
|
153 | 155 | """
|
154 | 156 | # Maintenance: revisit generics/overload to fix [attr-defined] in mypy usage
|
155 |
| - BaseRouter.current_event = data_model(event) |
| 157 | + |
| 158 | + # If event is a list it means that AppSync sent batch request |
| 159 | + if isinstance(event, list): |
| 160 | + event_groups = [ |
| 161 | + {"field_name": field_name, "events": list(events)} |
| 162 | + for field_name, events in groupby(event, key=lambda x: x["info"]["fieldName"]) |
| 163 | + ] |
| 164 | + if len(event_groups) > 1: |
| 165 | + ValueError("batch with different field names. It shouldn't happen!") |
| 166 | + |
| 167 | + BaseRouter.current_event = [data_model(event) for event in event_groups[0]["events"]] |
| 168 | + |
| 169 | + resolver = self._get_resolver(BaseRouter.current_event[0].type_name, event_groups[0]["field_name"]) |
| 170 | + response = resolver() |
| 171 | + else: |
| 172 | + BaseRouter.current_event = data_model(event) |
| 173 | + resolver = self._get_resolver(BaseRouter.current_event.type_name, BaseRouter.current_event.field_name) |
| 174 | + response = resolver(**BaseRouter.current_event.arguments) |
156 | 175 | BaseRouter.lambda_context = context
|
157 | 176 |
|
158 |
| - resolver = self._get_resolver(BaseRouter.current_event.type_name, BaseRouter.current_event.field_name) |
159 |
| - response = resolver(**BaseRouter.current_event.arguments) |
160 | 177 | self.clear_context()
|
161 | 178 |
|
162 | 179 | return response
|
|
0 commit comments