Skip to content

Commit 67a3e5a

Browse files
authored
fix: Map llama models to correct script (#5159)
1 parent b50b6fc commit 67a3e5a

File tree

2 files changed

+34
-34
lines changed

2 files changed

+34
-34
lines changed

src/sagemaker/modules/train/sm_recipes/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
129129
"""Get the model base name and script for the training recipe."""
130130

131131
model_type_to_script = {
132-
"llama_v3": ("llama", "llama_pretrain.py"),
132+
"llama": ("llama", "llama_pretrain.py"),
133133
"mistral": ("mistral", "mistral_pretrain.py"),
134134
"mixtral": ("mixtral", "mixtral_pretrain.py"),
135135
"deepseek": ("deepseek", "deepseek_pretrain.py"),

tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py

+33-33
Original file line numberDiff line numberDiff line change
@@ -180,36 +180,36 @@ def test_get_args_from_recipe_compute(
180180
assert mock_trainium_args.call_count == 0
181181
assert args is None
182182

183-
@pytest.mark.parametrize(
184-
"test_case",
185-
[
186-
{
187-
"model_type": "llama_v3",
188-
"script": "llama_pretrain.py",
189-
"model_base_name": "llama_v3",
190-
},
191-
{
192-
"model_type": "mistral",
193-
"script": "mistral_pretrain.py",
194-
"model_base_name": "mistral",
195-
},
196-
{
197-
"model_type": "deepseek_llamav3",
198-
"script": "deepseek_pretrain.py",
199-
"model_base_name": "deepseek",
200-
},
201-
{
202-
"model_type": "deepseek_qwenv2",
203-
"script": "deepseek_pretrain.py",
204-
"model_base_name": "deepseek",
205-
},
206-
],
207-
)
208-
def test_get_trainining_recipe_gpu_model_name_and_script(test_case):
209-
model_type = test_case["model_type"]
210-
script = test_case["script"]
211-
model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(
212-
model_type, script
213-
)
214-
assert model_base_name == test_case["model_base_name"]
215-
assert script == test_case["script"]
183+
184+
@pytest.mark.parametrize(
185+
"test_case",
186+
[
187+
{"model_type": "llama_v4", "script": "llama_pretrain.py", "model_base_name": "llama"},
188+
{
189+
"model_type": "llama_v3",
190+
"script": "llama_pretrain.py",
191+
"model_base_name": "llama",
192+
},
193+
{
194+
"model_type": "mistral",
195+
"script": "mistral_pretrain.py",
196+
"model_base_name": "mistral",
197+
},
198+
{
199+
"model_type": "deepseek_llamav3",
200+
"script": "deepseek_pretrain.py",
201+
"model_base_name": "deepseek",
202+
},
203+
{
204+
"model_type": "deepseek_qwenv2",
205+
"script": "deepseek_pretrain.py",
206+
"model_base_name": "deepseek",
207+
},
208+
],
209+
)
210+
def test_get_trainining_recipe_gpu_model_name_and_script(test_case):
211+
model_type = test_case["model_type"]
212+
script = test_case["script"]
213+
model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type)
214+
assert model_base_name == test_case["model_base_name"]
215+
assert script == test_case["script"]

0 commit comments

Comments
 (0)