Skip to content

Commit 2810da1

Browse files
committed
feat: Add function to decode nested messages on SQS events
1 parent 8944f38 commit 2810da1

File tree

5 files changed

+180
-14
lines changed

5 files changed

+180
-14
lines changed

aws_lambda_powertools/utilities/data_classes/event_source.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
def event_source(
1010
handler: Callable[[Any, LambdaContext], Any],
1111
event: Dict[str, Any],
12+
# optional property: original_event_source ??? (what if s3 -> sns -> sqs? should this be recursive?)
1213
context: LambdaContext,
1314
data_class: Type[DictWrapper],
1415
):

aws_lambda_powertools/utilities/data_classes/sns_event.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,63 +20,63 @@ class SNSMessage(DictWrapper):
2020
@property
2121
def signature_version(self) -> str:
2222
"""Version of the Amazon SNS signature used."""
23-
return self["Sns"]["SignatureVersion"]
23+
return self["SignatureVersion"]
2424

2525
@property
2626
def timestamp(self) -> str:
2727
"""The time (GMT) when the subscription confirmation was sent."""
28-
return self["Sns"]["Timestamp"]
28+
return self["Timestamp"]
2929

3030
@property
3131
def signature(self) -> str:
3232
"""Base64-encoded "SHA1withRSA" signature of the Message, MessageId, Type, Timestamp, and TopicArn values."""
33-
return self["Sns"]["Signature"]
33+
return self["Signature"]
3434

3535
@property
3636
def signing_cert_url(self) -> str:
3737
"""The URL to the certificate that was used to sign the message."""
38-
return self["Sns"]["SigningCertUrl"]
38+
return self["SigningCertUrl"]
3939

4040
@property
4141
def message_id(self) -> str:
4242
"""A Universally Unique Identifier, unique for each message published.
4343
4444
For a message that Amazon SNS resends during a retry, the message ID of the original message is used."""
45-
return self["Sns"]["MessageId"]
45+
return self["MessageId"]
4646

4747
@property
4848
def message(self) -> str:
4949
"""A string that describes the message."""
50-
return self["Sns"]["Message"]
50+
return self["Message"]
5151

5252
@property
5353
def message_attributes(self) -> Dict[str, SNSMessageAttribute]:
54-
return {k: SNSMessageAttribute(v) for (k, v) in self["Sns"]["MessageAttributes"].items()}
54+
return {k: SNSMessageAttribute(v) for (k, v) in self["MessageAttributes"].items()}
5555

5656
@property
5757
def get_type(self) -> str:
5858
"""The type of message.
5959
6060
For a subscription confirmation, the type is SubscriptionConfirmation."""
6161
# Note: this name conflicts with existing python builtins
62-
return self["Sns"]["Type"]
62+
return self["Type"]
6363

6464
@property
6565
def unsubscribe_url(self) -> str:
6666
"""A URL that you can use to unsubscribe the endpoint from this topic.
6767
6868
If you visit this URL, Amazon SNS unsubscribes the endpoint and stops sending notifications to this endpoint."""
69-
return self["Sns"]["UnsubscribeUrl"]
69+
return self["UnsubscribeUrl"]
7070

7171
@property
7272
def topic_arn(self) -> str:
7373
"""The Amazon Resource Name (ARN) for the topic that this endpoint is subscribed to."""
74-
return self["Sns"]["TopicArn"]
74+
return self["TopicArn"]
7575

7676
@property
7777
def subject(self) -> str:
7878
"""The Subject parameter specified when the notification was published to the topic."""
79-
return self["Sns"]["Subject"]
79+
return self["Subject"]
8080

8181

8282
class SNSEventRecord(DictWrapper):
@@ -96,7 +96,7 @@ def event_source(self) -> str:
9696

9797
@property
9898
def sns(self) -> SNSMessage:
99-
return SNSMessage(self._data)
99+
return SNSMessage(self._data["Sns"])
100100

101101

102102
class SNSEvent(DictWrapper):

aws_lambda_powertools/utilities/data_classes/sqs_event.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from typing import Any, Dict, Iterator, Optional
1+
from typing import Any, Dict, Iterator, Optional, Type, TypeVar
22

3+
from aws_lambda_powertools.utilities.data_classes import S3Event
34
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
5+
from aws_lambda_powertools.utilities.data_classes.sns_event import SNSMessage
46

57

68
class SQSRecordAttributes(DictWrapper):
@@ -83,6 +85,8 @@ def __getitem__(self, key: str) -> Optional[SQSMessageAttribute]: # type: ignor
8385
class SQSRecord(DictWrapper):
8486
"""An Amazon SQS message"""
8587

88+
NestedEvent = TypeVar("NestedEvent", bound=DictWrapper)
89+
8690
@property
8791
def message_id(self) -> str:
8892
"""A unique identifier for the message.
@@ -174,6 +178,63 @@ def queue_url(self) -> str:
174178

175179
return queue_url
176180

181+
@property
182+
def decode_nested_s3_event(self) -> S3Event:
183+
"""Returns the nested `S3Event` object that is sent in the body of a SQS message.
184+
185+
Even though you can typecast the object returned by `record.json_body`
186+
directly, this method is provided as a shortcut for convenience.
187+
188+
Notes
189+
-----
190+
191+
This method does not validate whether the SQS message body is actually a valid S3 event.
192+
193+
Examples
194+
--------
195+
196+
```python
197+
nested_event: S3Event = record.decode_nested_s3_event
198+
```
199+
"""
200+
return self._decode_nested_event(S3Event)
201+
202+
@property
203+
def decode_nested_sns_event(self) -> SNSMessage:
204+
"""Returns the nested `SNSMessage` object that is sent in the body of a SQS message.
205+
206+
Even though you can typecast the object returned by `record.json_body`
207+
directly, this method is provided as a shortcut for convenience.
208+
209+
Notes
210+
-----
211+
212+
This method does not validate whether the SQS message body is actually
213+
a valid SNS message.
214+
215+
Examples
216+
--------
217+
218+
```python
219+
nested_message: SNSMessage = record.decode_nested_sns_event
220+
```
221+
"""
222+
return self._decode_nested_event(SNSMessage)
223+
224+
def _decode_nested_event(self, nested_event_class: Type[NestedEvent]) -> NestedEvent:
225+
"""Returns the nested event source data object.
226+
227+
This is useful for handling events that are sent in the body of a SQS message.
228+
229+
Examples
230+
--------
231+
232+
```python
233+
data: S3Event = self._decode_nested_event(S3Event)
234+
```
235+
"""
236+
return nested_event_class(self.json_body)
237+
177238

178239
class SQSEvent(DictWrapper):
179240
"""SQS Event

tests/functional/test_data_classes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,7 @@ def test_sns_trigger_event():
935935
assert event.sns_message == "Hello from SNS!"
936936

937937

938-
def test_seq_trigger_event():
938+
def test_sqs_trigger_event():
939939
event = SQSEvent(load_event("sqsEvent.json"))
940940

941941
records = list(event.records)
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import json
2+
from typing import Dict
3+
4+
import pytest
5+
6+
from aws_lambda_powertools.utilities.data_classes import S3Event, SQSEvent
7+
from aws_lambda_powertools.utilities.data_classes.sns_event import SNSMessage
8+
from tests.functional.utils import load_event
9+
10+
11+
@pytest.mark.parametrize(
12+
"raw_event",
13+
[
14+
pytest.param(load_event("s3SqsEvent.json")),
15+
],
16+
ids=["s3_sqs"],
17+
)
18+
def test_decode_nested_s3_event(raw_event: Dict):
19+
event = SQSEvent(raw_event)
20+
21+
records = list(event.records)
22+
record = records[0]
23+
attributes = record.attributes
24+
25+
assert len(records) == 1
26+
assert record.message_id == "ca3e7a89-c358-40e5-8aa0-5da01403c267"
27+
assert attributes.aws_trace_header is None
28+
assert attributes.approximate_receive_count == "1"
29+
assert attributes.sent_timestamp == "1681332219270"
30+
assert attributes.sender_id == "AIDAJHIPRHEMV73VRJEBU"
31+
assert attributes.approximate_first_receive_timestamp == "1681332239270"
32+
assert attributes.sequence_number is None
33+
assert attributes.message_group_id is None
34+
assert attributes.message_deduplication_id is None
35+
assert record.md5_of_body == "16f4460f4477d8d693a5abe94fdbbd73"
36+
assert record.event_source == "aws:sqs"
37+
assert record.event_source_arn == "arn:aws:sqs:us-east-1:123456789012:SQS"
38+
assert record.aws_region == "us-east-1"
39+
40+
s3_event: S3Event = record.decode_nested_s3_event
41+
s3_record = s3_event.record
42+
43+
assert s3_event.bucket_name == "xxx"
44+
assert s3_event.object_key == "test.pdf"
45+
assert s3_record.aws_region == "us-east-1"
46+
assert s3_record.event_name == "ObjectCreated:Put"
47+
assert s3_record.event_source == "aws:s3"
48+
assert s3_record.event_time == "2023-04-12T20:43:38.021Z"
49+
assert s3_record.event_version == "2.1"
50+
assert s3_record.glacier_event_data is None
51+
assert s3_record.request_parameters.source_ip_address == "93.108.161.96"
52+
assert s3_record.response_elements["x-amz-request-id"] == "YMSSR8BZJ2Y99K6P"
53+
assert s3_record.s3.s3_schema_version == "1.0"
54+
assert s3_record.s3.bucket.arn == "arn:aws:s3:::xxx"
55+
assert s3_record.s3.bucket.name == "xxx"
56+
assert s3_record.s3.bucket.owner_identity.principal_id == "A1YQ72UWCM96UF"
57+
assert s3_record.s3.configuration_id == "SNS"
58+
assert s3_record.s3.get_object.etag == "2e3ad1e983318bbd8e73b080e2997980"
59+
assert s3_record.s3.get_object.key == "test.pdf"
60+
assert s3_record.s3.get_object.sequencer == "00643717F9F8B85354"
61+
assert s3_record.s3.get_object.size == 104681
62+
assert s3_record.s3.get_object.version_id == "yd3d4HaWOT2zguDLvIQLU6ptDTwKBnQV"
63+
assert s3_record.user_identity.principal_id == "A1YQ72UWCM96UF"
64+
65+
66+
@pytest.mark.parametrize(
67+
"raw_event",
68+
[
69+
pytest.param(load_event("snsSqsEvent.json")),
70+
],
71+
ids=["sns_sqs"],
72+
)
73+
def test_decode_nested_sns_event(raw_event: Dict):
74+
event = SQSEvent(raw_event)
75+
76+
records = list(event.records)
77+
record = records[0]
78+
attributes = record.attributes
79+
80+
assert len(records) == 1
81+
assert record.message_id == "79406a00-bf15-46ca-978c-22c3613fcb30"
82+
assert attributes.aws_trace_header is None
83+
assert attributes.approximate_receive_count == "1"
84+
assert attributes.sent_timestamp == "1611050827340"
85+
assert attributes.sender_id == "AIDAISMY7JYY5F7RTT6AO"
86+
assert attributes.approximate_first_receive_timestamp == "1611050827344"
87+
assert attributes.sequence_number is None
88+
assert attributes.message_group_id is None
89+
assert attributes.message_deduplication_id is None
90+
assert record.md5_of_body == "8910bdaaf9a30a607f7891037d4af0b0"
91+
assert record.event_source == "aws:sqs"
92+
assert record.event_source_arn == "arn:aws:sqs:eu-west-1:231436140809:powertools265"
93+
assert record.aws_region == "eu-west-1"
94+
95+
sns_message: SNSMessage = record.decode_nested_sns_event
96+
message = json.loads(sns_message.message)
97+
98+
assert sns_message.get_type == "Notification"
99+
assert sns_message.message_id == "d88d4479-6ec0-54fe-b63f-1cf9df4bb16e"
100+
assert sns_message.topic_arn == "arn:aws:sns:eu-west-1:231436140809:powertools265"
101+
assert sns_message.timestamp == "2021-01-19T10:07:07.287Z"
102+
assert sns_message.signature_version == "1"
103+
assert message["message"] == "hello world"
104+
assert message["username"] == "lessa"

0 commit comments

Comments
 (0)