diff --git a/src/sagemaker/serve/model_server/multi_model_server/inference.py b/src/sagemaker/serve/model_server/multi_model_server/inference.py
index 595b9d9c39..3cece40c5e 100644
--- a/src/sagemaker/serve/model_server/multi_model_server/inference.py
+++ b/src/sagemaker/serve/model_server/multi_model_server/inference.py
@@ -45,11 +45,11 @@ def input_fn(input_data, content_type):
     try:
         if hasattr(schema_builder, "custom_input_translator"):
             deserialized_data = schema_builder.custom_input_translator.deserialize(
-                io.BytesIO(input_data), content_type
+                io.BytesIO(input_data) if type(input_data) == bytes else io.BytesIO(input_data.encode('utf-8')), content_type
             )
         else:
             deserialized_data = schema_builder.input_deserializer.deserialize(
-                io.BytesIO(input_data), content_type[0]
+                io.BytesIO(input_data) if type(input_data) == bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0]
             )
 
         # Check if preprocess method is defined and call it
diff --git a/src/sagemaker/serve/model_server/torchserve/inference.py b/src/sagemaker/serve/model_server/torchserve/inference.py
index cad94cc817..294c032ccc 100644
--- a/src/sagemaker/serve/model_server/torchserve/inference.py
+++ b/src/sagemaker/serve/model_server/torchserve/inference.py
@@ -67,11 +67,11 @@ def input_fn(input_data, content_type):
     try:
         if hasattr(schema_builder, "custom_input_translator"):
             deserialized_data = schema_builder.custom_input_translator.deserialize(
-                io.BytesIO(input_data), content_type
+                io.BytesIO(input_data) if type(input_data) == bytes else io.BytesIO(input_data.encode('utf-8')), content_type
             )
         else:
             deserialized_data = schema_builder.input_deserializer.deserialize(
-                io.BytesIO(input_data), content_type[0]
+                io.BytesIO(input_data) if type(input_data) == bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0]
             )
 
         # Check if preprocess method is defined and call it
diff --git a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py
index 4e82ec66b2..6dab9bc6c6 100644
--- a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py
+++ b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py
@@ -70,11 +70,11 @@ def input_fn(input_data, content_type):
     try:
         if hasattr(schema_builder, "custom_input_translator"):
             return schema_builder.custom_input_translator.deserialize(
-                io.BytesIO(input_data), content_type
+                io.BytesIO(input_data) if type(input_data) == bytes else io.BytesIO(input_data.encode('utf-8')), content_type
             )
         else:
             return schema_builder.input_deserializer.deserialize(
-                io.BytesIO(input_data), content_type[0]
+                io.BytesIO(input_data) if type(input_data) == bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0]
             )
     except Exception as e:
         raise Exception("Encountered error in deserialize_request.") from e
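
Note on the repeated change: each hunk guards against input_data arriving as a str rather than bytes by encoding it as UTF-8 before wrapping it in io.BytesIO, so the deserializer always receives a byte stream. The sketch below is a minimal, standalone illustration of that normalization factored into a helper; the name _to_bytes_io is hypothetical and not part of the SageMaker SDK, it only mirrors the conditional expression the diff repeats in each input_fn.

    import io


    def _to_bytes_io(input_data):
        # Hypothetical helper mirroring the diff's conditional: pass bytes
        # through unchanged, encode str payloads as UTF-8 first, so the
        # downstream deserializer always reads from a byte stream.
        if isinstance(input_data, bytes):
            return io.BytesIO(input_data)
        return io.BytesIO(input_data.encode("utf-8"))


    # Usage sketch: both bytes and str payloads yield the same byte stream.
    assert _to_bytes_io(b'{"x": 1}').read() == b'{"x": 1}'
    assert _to_bytes_io('{"x": 1}').read() == b'{"x": 1}'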