# torchrun --nnodes=1 --nproc_per_node=1 -m pytest -s tests/trainer/test_sft_trainer.py

import json
import os
import pathlib
import tempfile
import time
import unittest

import pytest
import torch
from diffusers.utils import export_to_video
from parameterized import parameterized
from PIL import Image

from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger


os.environ["WANDB_MODE"] = "disabled"
os.environ["FINETRAINERS_LOG_LEVEL"] = "INFO"

from ..models.cogvideox.base_specification import DummyCogVideoXModelSpecification  # noqa
from ..models.cogview4.base_specification import DummyCogView4ModelSpecification  # noqa
from ..models.flux.base_specification import DummyFluxModelSpecification  # noqa
from ..models.hunyuan_video.base_specification import DummyHunyuanVideoModelSpecification  # noqa
from ..models.ltx_video.base_specification import DummyLTXVideoModelSpecification  # noqa
from ..models.wan.base_specification import DummyWanModelSpecification  # noqa


logger = get_logger()


@pytest.fixture(autouse=True)
def slow_down_tests():
    yield
    # Sleep between tests so that process groups are cleaned up and resources are released.
    # Not doing so seems to randomly trigger test failures that do not occur when the tests
    # are run individually. TODO: Look into this in the future.
    time.sleep(5)
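
# Shared scaffolding for all trainer tests: `setUp` writes a tiny on-disk video dataset
# (`num_data_files` dummy clips plus a `metadata.csv` with captions) and a JSON dataset
# config, and `get_base_args` provides a short 10-step training run against it.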
class SFTTrainerFastTestsMixin:
    model_specification_cls = None
    num_data_files = 4
    num_frames = 4
    height = 64
    width = 64

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()
        self.data_files = []
        for i in range(self.num_data_files):
            data_file = pathlib.Path(self.tmpdir.name) / f"{i}.mp4"
            export_to_video(
                [Image.new("RGB", (self.width, self.height))] * self.num_frames, data_file.as_posix(), fps=2
            )
            self.data_files.append(data_file.as_posix())

        csv_filename = pathlib.Path(self.tmpdir.name) / "metadata.csv"
        with open(csv_filename.as_posix(), "w") as f:
            f.write("file_name,caption\n")
            for i in range(self.num_data_files):
                prompt = f"A cat ruling the world - {i}"
                f.write(f'{i}.mp4,"{prompt}"\n')

        dataset_config = {
            "datasets": [
                {
                    "data_root": self.tmpdir.name,
                    "dataset_type": "video",
                    "id_token": "TEST",
                    "video_resolution_buckets": [[self.num_frames, self.height, self.width]],
                    "reshape_mode": "bicubic",
                }
            ]
        }

        self.dataset_config_filename = pathlib.Path(self.tmpdir.name) / "dataset_config.json"
        with open(self.dataset_config_filename.as_posix(), "w") as f:
            json.dump(dataset_config, f)

    def tearDown(self):
        self.tmpdir.cleanup()
        # For some reason, the tests that follow will fail if the process group is not
        # destroyed. Make sure to destroy it manually here.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
        time.sleep(3)

    def get_base_args(self) -> BaseArgs:
        args = BaseArgs()
        args.dataset_config = self.dataset_config_filename.as_posix()
        args.train_steps = 10
        args.max_data_samples = 25
        args.batch_size = 1
        args.gradient_checkpointing = True
        args.output_dir = self.tmpdir.name
        args.checkpointing_steps = 6
        args.enable_precomputation = False
        args.precomputation_items = self.num_data_files
        args.precomputation_dir = os.path.join(self.tmpdir.name, "precomputed")
        args.compile_scopes = "regional"  # Only in effect when `compile_modules` is set
        # args.attn_provider_training = ["transformer:_native_cudnn"]
        # args.attn_provider_inference = ["transformer:_native_cudnn"]
        return args

    def get_args(self) -> BaseArgs:
        raise NotImplementedError("`get_args` must be implemented in the subclass.")

    def _test_training(self, args: BaseArgs):
        model_specification = self.model_specification_cls()
        trainer = SFTTrainer(args, model_specification)
        trainer.run()


# =============== <ACCELERATE> ===============
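
# Each `*TestsMixin` below defines the test matrix for one parallel backend (parallel
# degrees, batch sizes, optional layerwise upcasting / compilation); the concrete classes
# that follow bind a mixin to a per-model dummy specification.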
class SFTTrainerLoRATestsMixin___Accelerate(SFTTrainerFastTestsMixin):
    def get_args(self) -> BaseArgs:
        args = self.get_base_args()
        args.parallel_backend = "accelerate"
        args.training_type = TrainingType.LORA
        args.rank = 4
        args.lora_alpha = 4
        args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
        return args

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___layerwise_upcasting___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.layerwise_upcasting_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)


class SFTTrainerFullFinetuneTestsMixin___Accelerate(SFTTrainerFastTestsMixin):
    def get_args(self) -> BaseArgs:
        args = self.get_base_args()
        args.parallel_backend = "accelerate"
        args.training_type = TrainingType.FULL_FINETUNE
        return args

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)


class SFTTrainerCogVideoXLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyCogVideoXModelSpecification


class SFTTrainerCogVideoXFullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase
):
    model_specification_cls = DummyCogVideoXModelSpecification


class SFTTrainerCogView4LoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyCogView4ModelSpecification


class SFTTrainerCogView4FullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase
):
    model_specification_cls = DummyCogView4ModelSpecification


class SFTTrainerFluxLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyFluxModelSpecification


class SFTTrainerFluxFullFinetuneTests___Accelerate(SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyFluxModelSpecification


class SFTTrainerHunyuanVideoLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyHunyuanVideoModelSpecification


class SFTTrainerHunyuanVideoFullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase
):
    model_specification_cls = DummyHunyuanVideoModelSpecification


class SFTTrainerLTXVideoLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyLTXVideoModelSpecification


class SFTTrainerLTXVideoFullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase
):
    model_specification_cls = DummyLTXVideoModelSpecification


class SFTTrainerWanLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyWanModelSpecification


class SFTTrainerWanFullFinetuneTests___Accelerate(SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyWanModelSpecification


# =============== </ACCELERATE> ===============

# =============== <PTD> ===============
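
# Note: the multi-rank cases below (dp_degree=2, dp_shards=2, tp_degree=2, cp_degree=2)
# presumably need at least two processes to form the device mesh, e.g.:
#   torchrun --nnodes=1 --nproc_per_node=2 -m pytest -s tests/trainer/test_sft_trainer.py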
class SFTTrainerLoRATestsMixin___PTD(SFTTrainerFastTestsMixin):
    def get_args(self) -> BaseArgs:
        args = self.get_base_args()
        args.parallel_backend = "ptd"
        args.training_type = TrainingType.LORA
        args.rank = 4
        args.lora_alpha = 4
        args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
        return args

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___layerwise_upcasting___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.layerwise_upcasting_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___compile___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.compile_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___layerwise_upcasting___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.layerwise_upcasting_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___compile___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.compile_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_shards = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_shards = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.dp_shards = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.tp_degree = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @unittest.skip(
        "TODO: The model specifications for CP with the cudnn/flash/efficient backends require the attention head "
        "dim to be a multiple of 8. Land the math backend first for fast tests, then enable this test."
    )
    @parameterized.expand([(True,)])
    def test___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.cp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @unittest.skip(
        "TODO: The model specifications for CP with the cudnn/flash/efficient backends require the attention head "
        "dim to be a multiple of 8. Land the math backend first for fast tests, then enable this test."
    )
    @parameterized.expand([(True,)])
    def test___dp_degree_2___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.cp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)
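
# The full-finetune matrix below largely mirrors the LoRA one, without the LoRA-specific
# arguments (rank, alpha, target modules) and the layerwise-upcasting cases.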
class SFTTrainerFullFinetuneTestsMixin___PTD(SFTTrainerFastTestsMixin):
    def get_args(self) -> BaseArgs:
        args = self.get_base_args()
        args.parallel_backend = "ptd"
        args.training_type = TrainingType.FULL_FINETUNE
        return args

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___compile___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.compile_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___compile___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.compile_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_shards = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_shards = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.dp_shards = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.tp_degree = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @unittest.skip(
        "TODO: The model specifications for CP with the cudnn/flash/efficient backends require the attention head "
        "dim to be a multiple of 8. Land the math backend first for fast tests, then enable this test."
    )
    @parameterized.expand([(True,)])
    def test___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.cp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @unittest.skip(
        "TODO: The model specifications for CP with the cudnn/flash/efficient backends require the attention head "
        "dim to be a multiple of 8. Land the math backend first for fast tests, then enable this test."
    )
    @parameterized.expand([(True,)])
    def test___dp_degree_2___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.cp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)
class SFTTrainerCogVideoXLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyCogVideoXModelSpecification


class SFTTrainerCogVideoXFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyCogVideoXModelSpecification


class SFTTrainerCogView4LoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyCogView4ModelSpecification


class SFTTrainerCogView4FullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyCogView4ModelSpecification


class SFTTrainerFluxLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyFluxModelSpecification


class SFTTrainerFluxFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyFluxModelSpecification


class SFTTrainerHunyuanVideoLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyHunyuanVideoModelSpecification


class SFTTrainerHunyuanVideoFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyHunyuanVideoModelSpecification


class SFTTrainerLTXVideoLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyLTXVideoModelSpecification


class SFTTrainerLTXVideoFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyLTXVideoModelSpecification


class SFTTrainerWanLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyWanModelSpecification


class SFTTrainerWanFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyWanModelSpecification


# =============== </PTD> ===============
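
# To extend coverage to a new model family, pair its dummy specification with both mixins.
# A minimal sketch, assuming a hypothetical `DummyNewModelSpecification`:
#
# class SFTTrainerNewModelLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
#     model_specification_cls = DummyNewModelSpecification
#
# class SFTTrainerNewModelFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
#     model_specification_cls = DummyNewModelSpecification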