Skip to content

Commit 4a4f56b

Browse files
committed
model server might have already done a serialization. honor that by not decoding the request again if it is not already bytes or bytestream
1 parent 2102bb7 commit 4a4f56b

File tree

3 files changed

+36
-18
lines changed

3 files changed

+36
-18
lines changed

src/sagemaker/serve/model_server/multi_model_server/inference.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,24 @@ def input_fn(input_data, content_type, context=None):
4646
if hasattr(schema_builder, "custom_input_translator"):
4747
deserialized_data = schema_builder.custom_input_translator.deserialize(
4848
(
49-
io.BytesIO(input_data)
50-
if type(input_data) == bytes
51-
else io.BytesIO(input_data.encode("utf-8"))
49+
input_data
50+
if not any([
51+
isinstance(input_data, bytes),
52+
isinstance(input_data, bytearray),
53+
])
54+
else io.BytesIO(input_data)
5255
),
5356
content_type,
5457
)
5558
else:
5659
deserialized_data = schema_builder.input_deserializer.deserialize(
5760
(
58-
io.BytesIO(input_data)
59-
if type(input_data) == bytes
60-
else io.BytesIO(input_data.encode("utf-8"))
61+
input_data
62+
if not any([
63+
isinstance(input_data, bytes),
64+
isinstance(input_data, bytearray),
65+
])
66+
else io.BytesIO(input_data)
6167
),
6268
content_type[0],
6369
)

src/sagemaker/serve/model_server/torchserve/inference.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,24 @@ def input_fn(input_data, content_type):
6868
if hasattr(schema_builder, "custom_input_translator"):
6969
deserialized_data = schema_builder.custom_input_translator.deserialize(
7070
(
71-
io.BytesIO(input_data)
72-
if type(input_data) == bytes
73-
else io.BytesIO(input_data.encode("utf-8"))
71+
input_data
72+
if not any([
73+
isinstance(input_data, bytes),
74+
isinstance(input_data, bytearray),
75+
])
76+
else io.BytesIO(input_data)
7477
),
7578
content_type,
7679
)
7780
else:
7881
deserialized_data = schema_builder.input_deserializer.deserialize(
7982
(
80-
io.BytesIO(input_data)
81-
if type(input_data) == bytes
82-
else io.BytesIO(input_data.encode("utf-8"))
83+
input_data
84+
if not any([
85+
isinstance(input_data, bytes),
86+
isinstance(input_data, bytearray),
87+
])
88+
else io.BytesIO(input_data)
8389
),
8490
content_type[0],
8591
)

src/sagemaker/serve/model_server/torchserve/xgboost_inference.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,24 @@ def input_fn(input_data, content_type):
7171
if hasattr(schema_builder, "custom_input_translator"):
7272
return schema_builder.custom_input_translator.deserialize(
7373
(
74-
io.BytesIO(input_data)
75-
if type(input_data) == bytes
76-
else io.BytesIO(input_data.encode("utf-8"))
74+
input_data
75+
if not any([
76+
isinstance(input_data, bytes),
77+
isinstance(input_data, bytearray),
78+
])
79+
else io.BytesIO(input_data)
7780
),
7881
content_type,
7982
)
8083
else:
8184
return schema_builder.input_deserializer.deserialize(
8285
(
83-
io.BytesIO(input_data)
84-
if type(input_data) == bytes
85-
else io.BytesIO(input_data.encode("utf-8"))
86+
input_data
87+
if not any([
88+
isinstance(input_data, bytes),
89+
isinstance(input_data, bytearray),
90+
])
91+
else io.BytesIO(input_data)
8692
),
8793
content_type[0],
8894
)

0 commit comments

Comments
 (0)