@@ -180,36 +180,36 @@ def test_get_args_from_recipe_compute(
180
180
assert mock_trainium_args .call_count == 0
181
181
assert args is None
182
182
183
- @ pytest . mark . parametrize (
184
- "test_case" ,
185
- [
186
- {
187
- "model_type" : "llama_v3" ,
188
- "script" : "llama_pretrain.py" ,
189
- "model_base_name" : "llama_v3" ,
190
- } ,
191
- {
192
- "model_type" : "mistral" ,
193
- "script" : "mistral_pretrain.py" ,
194
- "model_base_name" : "mistral" ,
195
- } ,
196
- {
197
- "model_type" : "deepseek_llamav3" ,
198
- "script" : "deepseek_pretrain.py" ,
199
- "model_base_name" : "deepseek" ,
200
- } ,
201
- {
202
- "model_type" : "deepseek_qwenv2" ,
203
- "script" : "deepseek_pretrain.py" ,
204
- "model_base_name" : "deepseek" ,
205
- } ,
206
- ] ,
207
- )
208
- def test_get_trainining_recipe_gpu_model_name_and_script ( test_case ):
209
- model_type = test_case [ "model_type" ]
210
- script = test_case [ "script" ]
211
- model_base_name , script = _get_trainining_recipe_gpu_model_name_and_script (
212
- model_type , script
213
- )
214
- assert model_base_name == test_case ["model_base_name" ]
215
- assert script == test_case ["script" ]
183
+
184
+ @ pytest . mark . parametrize (
185
+ "test_case" ,
186
+ [
187
+ { "model_type" : "llama_v4" , "script" : "llama_pretrain.py" , "model_base_name" : "llama" } ,
188
+ {
189
+ "model_type" : "llama_v3" ,
190
+ "script" : "llama_pretrain.py" ,
191
+ "model_base_name" : "llama" ,
192
+ } ,
193
+ {
194
+ "model_type" : "mistral" ,
195
+ "script" : "mistral_pretrain.py" ,
196
+ "model_base_name" : "mistral" ,
197
+ } ,
198
+ {
199
+ "model_type" : "deepseek_llamav3" ,
200
+ "script" : "deepseek_pretrain.py" ,
201
+ "model_base_name" : "deepseek" ,
202
+ } ,
203
+ {
204
+ "model_type" : "deepseek_qwenv2" ,
205
+ "script" : "deepseek_pretrain.py" ,
206
+ "model_base_name" : "deepseek" ,
207
+ },
208
+ ],
209
+ )
210
+ def test_get_trainining_recipe_gpu_model_name_and_script ( test_case ):
211
+ model_type = test_case [ "model_type" ]
212
+ script = test_case [ "script" ]
213
+ model_base_name , script = _get_trainining_recipe_gpu_model_name_and_script ( model_type )
214
+ assert model_base_name == test_case ["model_base_name" ]
215
+ assert script == test_case ["script" ]
0 commit comments