From b21a75364b350bb84dd61600c5f7c272420c419e Mon Sep 17 00:00:00 2001 From: Evan Han Date: Wed, 7 May 2025 13:03:55 +0900 Subject: [PATCH 1/4] Update test_models_transformer_ltx.py --- .../test_models_transformer_ltx.py | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py index 128bf04155e7..a44b98c0e681 100644 --- a/tests/models/transformers/test_models_transformer_ltx.py +++ b/tests/models/transformers/test_models_transformer_ltx.py @@ -18,7 +18,14 @@ import torch from diffusers import LTXVideoTransformer3DModel -from diffusers.utils.testing_utils import enable_full_determinism, torch_device +from diffusers.utils.testing_utils import ( + enable_full_determinism, + is_torch_compile, + require_torch_2, + require_torch_gpu, + slow, + torch_device, +) from ..test_modeling_common import ModelTesterMixin @@ -81,3 +88,19 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"LTXVideoTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @require_torch_gpu + @require_torch_2 + @is_torch_compile + @slow + def test_torch_compile_recompilation_and_graph_break(self): + torch._dynamo.reset() + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict).to(torch_device) + model.eval() + model = torch.compile(model, fullgraph=True) + + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): + _ = model(**inputs_dict) + _ = model(**inputs_dict) From 46fb70d578c5c5260fd877b59d515f9e86d73249 Mon Sep 17 00:00:00 2001 From: Evan Han Date: Wed, 7 May 2025 15:27:52 +0900 Subject: [PATCH 2/4] Update test_models_transformer_ltx.py --- tests/models/transformers/test_models_transformer_ltx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py index a44b98c0e681..e86ede254f12 100644 --- a/tests/models/transformers/test_models_transformer_ltx.py +++ b/tests/models/transformers/test_models_transformer_ltx.py @@ -27,13 +27,13 @@ torch_device, ) -from ..test_modeling_common import ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin enable_full_determinism() -class LTXTransformerTests(ModelTesterMixin, unittest.TestCase): +class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): model_class = LTXVideoTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True From f12dc97bd816da626006d95d61d10bbdb49ffe72 Mon Sep 17 00:00:00 2001 From: Evan Han Date: Wed, 7 May 2025 17:13:44 +0900 Subject: [PATCH 3/4] Update test_models_transformer_ltx.py --- .../test_models_transformer_ltx.py | 22 +------------------ 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py index e86ede254f12..3ab4b5ed9543 100644 --- a/tests/models/transformers/test_models_transformer_ltx.py +++ b/tests/models/transformers/test_models_transformer_ltx.py @@ -19,13 +19,7 @@ from diffusers import LTXVideoTransformer3DModel from diffusers.utils.testing_utils import ( - enable_full_determinism, - is_torch_compile, - require_torch_2, - require_torch_gpu, - slow, - torch_device, -) +from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin @@ -87,20 +81,6 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"LTXVideoTransformer3DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - - @require_torch_gpu - @require_torch_2 - @is_torch_compile - @slow - def test_torch_compile_recompilation_and_graph_break(self): torch._dynamo.reset() - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict).to(torch_device) - model.eval() - model = torch.compile(model, fullgraph=True) - - with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): _ = model(**inputs_dict) _ = model(**inputs_dict) From ada170d5b923c5851ebeff83cede29e87f13be8d Mon Sep 17 00:00:00 2001 From: Evan Han Date: Wed, 7 May 2025 17:16:35 +0900 Subject: [PATCH 4/4] Update test_models_transformer_ltx.py --- tests/models/transformers/test_models_transformer_ltx.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py index 3ab4b5ed9543..8649ce97a52e 100644 --- a/tests/models/transformers/test_models_transformer_ltx.py +++ b/tests/models/transformers/test_models_transformer_ltx.py @@ -18,7 +18,6 @@ import torch from diffusers import LTXVideoTransformer3DModel -from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin @@ -81,6 +80,4 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"LTXVideoTransformer3DModel"} - torch._dynamo.reset() - _ = model(**inputs_dict) - _ = model(**inputs_dict) + super().test_gradient_checkpointing_is_applied(expected_set=expected_set)