Skip to content

Commit 33b8303

Browse files
committed
feat(event_handler): add support for multiple headers with same key
1 parent 452ed3f commit 33b8303

File tree

5 files changed

+53
-36
lines changed

5 files changed

+53
-36
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+17-16
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import warnings
88
import zlib
99
from abc import ABC, abstractmethod
10+
from collections import defaultdict
1011
from enum import Enum
1112
from functools import partial
1213
from http import HTTPStatus
@@ -122,18 +123,18 @@ def __init__(
122123
self.max_age = max_age
123124
self.allow_credentials = allow_credentials
124125

125-
def to_dict(self) -> Dict[str, str]:
126+
def to_dict(self) -> Dict[str, List[str]]:
126127
"""Builds the configured Access-Control http headers"""
127-
headers = {
128-
"Access-Control-Allow-Origin": self.allow_origin,
129-
"Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)),
130-
}
128+
headers: Dict[str, List[str]] = defaultdict(list)
129+
headers["Access-Control-Allow-Origin"].append(self.allow_origin)
130+
headers["Access-Control-Allow-Headers"].append(",".join(sorted(self.allow_headers)))
131+
131132
if self.expose_headers:
132-
headers["Access-Control-Expose-Headers"] = ",".join(self.expose_headers)
133+
headers["Access-Control-Expose-Headers"].append(",".join(self.expose_headers))
133134
if self.max_age is not None:
134-
headers["Access-Control-Max-Age"] = str(self.max_age)
135+
headers["Access-Control-Max-Age"].append(str(self.max_age))
135136
if self.allow_credentials is True:
136-
headers["Access-Control-Allow-Credentials"] = "true"
137+
headers["Access-Control-Allow-Credentials"].append("true")
137138
return headers
138139

139140

@@ -145,7 +146,7 @@ def __init__(
145146
status_code: int,
146147
content_type: Optional[str],
147148
body: Union[str, bytes, None],
148-
headers: Optional[Dict[str, str]] = None,
149+
headers: Optional[Dict[str, List[str]]] = None,
149150
cookies: Optional[List[str]] = None,
150151
):
151152
"""
@@ -159,18 +160,18 @@ def __init__(
159160
provided http headers
160161
body: Union[str, bytes, None]
161162
Optionally set the response body. Note: bytes body will be automatically base64 encoded
162-
headers: dict[str, str]
163+
headers: dict[str, List[str]]
163164
Optionally set specific http headers. Setting "Content-Type" here would override the `content_type` value.
164165
cookies: list[str]
165166
Optionally set cookies.
166167
"""
167168
self.status_code = status_code
168169
self.body = body
169170
self.base64_encoded = False
170-
self.headers: Dict[str, str] = headers or {}
171+
self.headers: Dict[str, List[str]] = defaultdict(list, **headers) if headers else defaultdict(list)
171172
self.cookies = cookies or []
172173
if content_type:
173-
self.headers.setdefault("Content-Type", content_type)
174+
self.headers.setdefault("Content-Type", [content_type])
174175

175176

176177
class Route:
@@ -200,11 +201,11 @@ def _add_cors(self, cors: CORSConfig):
200201

201202
def _add_cache_control(self, cache_control: str):
202203
"""Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used."""
203-
self.response.headers["Cache-Control"] = cache_control if self.response.status_code == 200 else "no-cache"
204+
self.response.headers["Cache-Control"].append(cache_control if self.response.status_code == 200 else "no-cache")
204205

205206
def _compress(self):
206207
"""Compress the response body, but only if `Accept-Encoding` headers includes gzip."""
207-
self.response.headers["Content-Encoding"] = "gzip"
208+
self.response.headers["Content-Encoding"].append("gzip")
208209
if isinstance(self.response.body, str):
209210
logger.debug("Converting string response to bytes before compressing it")
210211
self.response.body = bytes(self.response.body, "utf-8")
@@ -602,14 +603,14 @@ def _path_starts_with(path: str, prefix: str):
602603

603604
def _not_found(self, method: str) -> ResponseBuilder:
604605
"""Called when no matching route was found and includes support for the cors preflight response"""
605-
headers = {}
606+
headers: Dict[str, List[str]] = defaultdict(list)
606607
if self._cors:
607608
logger.debug("CORS is enabled, updating headers.")
608609
headers.update(self._cors.to_dict())
609610

610611
if method == "OPTIONS":
611612
logger.debug("Pre-flight request detected. Returning CORS with null response")
612-
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
613+
headers["Access-Control-Allow-Methods"].append(",".join(sorted(self._cors_methods)))
613614
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=""))
614615

615616
handler = self._lookup_exception_handler(NotFoundError)

aws_lambda_powertools/shared/headers_serializer.py

+27-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import warnings
22
from abc import ABC
3+
from collections import defaultdict
34
from typing import Any, Dict, List
45

56

@@ -8,14 +9,14 @@ class BaseHeadersSerializer(ABC):
89
Helper class to correctly serialize headers and cookies on the response payload.
910
"""
1011

11-
def serialize(self, headers: Dict[str, str], cookies: List[str]) -> Dict[str, Any]:
12+
def serialize(self, headers: Dict[str, List[str]], cookies: List[str]) -> Dict[str, Any]:
1213
"""
1314
Serializes headers and cookies according to the request type.
1415
Returns a dict that can be merged with the response payload.
1516
1617
Parameters
1718
----------
18-
headers: Dict[str, str]
19+
headers: Dict[str, List[str]]
1920
A dictionary of headers to set in the response
2021
cookies: List[str]
2122
A list of cookies to set in the response
@@ -24,7 +25,7 @@ def serialize(self, headers: Dict[str, str], cookies: List[str]) -> Dict[str, An
2425

2526

2627
class HttpApiSerializer(BaseHeadersSerializer):
27-
def serialize(self, headers: Dict[str, str], cookies: List[str]) -> Dict[str, Any]:
28+
def serialize(self, headers: Dict[str, List[str]], cookies: List[str]) -> Dict[str, Any]:
2829
"""
2930
When using HTTP APIs or LambdaFunctionURLs, everything is taken care automatically for us.
3031
We can directly assign a list of cookies and a dict of headers to the response payload, and the
@@ -33,11 +34,18 @@ def serialize(self, headers: Dict[str, str], cookies: List[str]) -> Dict[str, An
3334
https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html#http-api-develop-integrations-lambda.proxy-format
3435
https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html#http-api-develop-integrations-lambda.response
3536
"""
36-
return {"headers": headers, "cookies": cookies}
37+
38+
# Format 2.0 doesn't have multiValueHeaders or multiValueQueryStringParameters fields.
39+
# Duplicate headers are combined with commas and included in the headers field.
40+
combined_headers: Dict[str, str] = {}
41+
for key, values in headers.items():
42+
combined_headers[key] = ",".join(values)
43+
44+
return {"headers": combined_headers, "cookies": cookies}
3745

3846

3947
class MultiValueHeadersSerializer(BaseHeadersSerializer):
40-
def serialize(self, headers: Dict[str, str], cookies: List[str]) -> Dict[str, Any]:
48+
def serialize(self, headers: Dict[str, List[str]], cookies: List[str]) -> Dict[str, Any]:
4149
"""
4250
When using REST APIs, headers can be encoded using the `multiValueHeaders` key on the response.
4351
This is also the case when using an ALB integration with the `multiValueHeaders` option enabled.
@@ -46,10 +54,11 @@ def serialize(self, headers: Dict[str, str], cookies: List[str]) -> Dict[str, An
4654
https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-output-format
4755
https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html#multi-value-headers-response
4856
"""
49-
payload: Dict[str, List[str]] = {}
57+
payload: Dict[str, List[str]] = defaultdict(list)
5058

51-
for key, value in headers.items():
52-
payload[key] = [value]
59+
for key, values in headers.items():
60+
for value in values:
61+
payload[key].append(value)
5362

5463
if cookies:
5564
payload.setdefault("Set-Cookie", [])
@@ -60,7 +69,7 @@ def serialize(self, headers: Dict[str, str], cookies: List[str]) -> Dict[str, An
6069

6170

6271
class SingleValueHeadersSerializer(BaseHeadersSerializer):
63-
def serialize(self, headers: Dict[str, str], cookies: List[str]) -> Dict[str, Any]:
72+
def serialize(self, headers: Dict[str, List[str]], cookies: List[str]) -> Dict[str, Any]:
6473
"""
6574
The ALB integration has `multiValueHeaders` disabled by default.
6675
If we try to set multiple headers with the same key, or more than one cookie, print a warning.
@@ -80,7 +89,14 @@ def serialize(self, headers: Dict[str, str], cookies: List[str]) -> Dict[str, An
8089
# We can only send one cookie, send the last one
8190
payload["headers"]["Set-Cookie"] = cookies[-1]
8291

83-
for key, value in headers.items():
84-
payload["headers"][key] = value
92+
for key, values in headers.items():
93+
if len(values) > 1:
94+
warnings.warn(
95+
"Can't encode more than one header value for the same key in the response. "
96+
"Did you enable multiValueHeaders on the ALB Target Group?"
97+
)
98+
99+
# We can only set one header per key, send the last one
100+
payload["headers"][key] = values[-1]
85101

86102
return payload

examples/event_handler_rest/src/fine_grained_responses.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def get_todos():
1919
todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos")
2020
todos.raise_for_status()
2121

22-
custom_headers = {"X-Transaction-Id": f"{uuid4()}"}
22+
custom_headers = {"X-Transaction-Id": [f"{uuid4()}"]}
2323

2424
return Response(
2525
status_code=HTTPStatus.OK.value, # 200

tests/functional/event_handler/test_api_gateway.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def rest_func() -> Response:
441441
status_code=404,
442442
content_type="used-if-not-set-in-header",
443443
body="Not found",
444-
headers={"Content-Type": "header-content-type-wins", "custom": "value"},
444+
headers={"Content-Type": ["header-content-type-wins"], "custom": ["value"]},
445445
)
446446

447447
# WHEN calling the event handler
@@ -573,7 +573,7 @@ def custom_preflight():
573573
status_code=200,
574574
content_type=content_types.TEXT_HTML,
575575
body="Foo",
576-
headers={"Access-Control-Allow-Methods": "CUSTOM"},
576+
headers={"Access-Control-Allow-Methods": ["CUSTOM"]},
577577
)
578578

579579
@app.route(method="CUSTOM", rule="/some-call", cors=True)

tests/functional/event_handler/test_headers_serializer.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ def test_headers_serializer_http_api():
1313
payload = serializer.serialize(cookies=[], headers={})
1414
assert payload == {"cookies": [], "headers": {}}
1515

16-
payload = serializer.serialize(cookies=[], headers={"Content-Type": "text/html"})
16+
payload = serializer.serialize(cookies=[], headers={"Content-Type": ["text/html"]})
1717
assert payload == {"cookies": [], "headers": {"Content-Type": "text/html"}}
1818

1919
payload = serializer.serialize(cookies=["UUID=12345"], headers={})
2020
assert payload == {"cookies": ["UUID=12345"], "headers": {}}
2121

22-
payload = serializer.serialize(cookies=["UUID=12345", "SSID=0xdeadbeef"], headers={"Foo": "bar,zbr"})
22+
payload = serializer.serialize(cookies=["UUID=12345", "SSID=0xdeadbeef"], headers={"Foo": ["bar,zbr"]})
2323
assert payload == {"cookies": ["UUID=12345", "SSID=0xdeadbeef"], "headers": {"Foo": "bar,zbr"}}
2424

2525

@@ -29,13 +29,13 @@ def test_headers_serializer_multi_value_headers():
2929
payload = serializer.serialize(cookies=[], headers={})
3030
assert payload == {"multiValueHeaders": {}}
3131

32-
payload = serializer.serialize(cookies=[], headers={"Content-Type": "text/html"})
32+
payload = serializer.serialize(cookies=[], headers={"Content-Type": ["text/html"]})
3333
assert payload == {"multiValueHeaders": {"Content-Type": ["text/html"]}}
3434

3535
payload = serializer.serialize(cookies=["UUID=12345"], headers={})
3636
assert payload == {"multiValueHeaders": {"Set-Cookie": ["UUID=12345"]}}
3737

38-
payload = serializer.serialize(cookies=["UUID=12345", "SSID=0xdeadbeef"], headers={"Foo": "bar,zbr"})
38+
payload = serializer.serialize(cookies=["UUID=12345", "SSID=0xdeadbeef"], headers={"Foo": ["bar,zbr"]})
3939
assert payload == {"multiValueHeaders": {"Set-Cookie": ["UUID=12345", "SSID=0xdeadbeef"], "Foo": ["bar,zbr"]}}
4040

4141

@@ -45,7 +45,7 @@ def test_headers_serializer_single_value_headers():
4545
payload = serializer.serialize(cookies=[], headers={})
4646
assert payload == {"headers": {}}
4747

48-
payload = serializer.serialize(cookies=[], headers={"Content-Type": "text/html"})
48+
payload = serializer.serialize(cookies=[], headers={"Content-Type": ["text/html"]})
4949
assert payload == {"headers": {"Content-Type": "text/html"}}
5050

5151
payload = serializer.serialize(cookies=["UUID=12345"], headers={})
@@ -54,7 +54,7 @@ def test_headers_serializer_single_value_headers():
5454
with warnings.catch_warnings(record=True) as w:
5555
warnings.simplefilter("default")
5656

57-
payload = serializer.serialize(cookies=["UUID=12345", "SSID=0xdeadbeef"], headers={"Foo": "bar,zbr"})
57+
payload = serializer.serialize(cookies=["UUID=12345", "SSID=0xdeadbeef"], headers={"Foo": ["bar,zbr"]})
5858
assert payload == {"headers": {"Set-Cookie": "SSID=0xdeadbeef", "Foo": "bar,zbr"}}
5959

6060
assert len(w) == 1

0 commit comments

Comments
 (0)