Spaces:
Running
Running
# torchrun --nnodes=1 --nproc_per_node=1 -m pytest -s tests/trainer/test_sft_trainer.py | |
import json | |
import os | |
import pathlib | |
import tempfile | |
import time | |
import unittest | |
import pytest | |
import torch | |
from diffusers.utils import export_to_video | |
from parameterized import parameterized | |
from PIL import Image | |
from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger | |
# Environment for the test run: WANDB_MODE=disabled and FINETRAINERS_LOG_LEVEL=INFO.
# NOTE(review): these are set after `finetrainers` has already been imported above; if the
# package reads them at import time rather than lazily, they will not take effect — confirm.
os.environ.update({"WANDB_MODE": "disabled", "FINETRAINERS_LOG_LEVEL": "INFO"})

# Dummy model specifications, one per supported model family; each is plugged into the
# concrete TestCase classes below via `model_specification_cls`.
from ..models.cogvideox.base_specification import DummyCogVideoXModelSpecification  # noqa
from ..models.cogview4.base_specification import DummyCogView4ModelSpecification  # noqa
from ..models.flux.base_specification import DummyFluxModelSpecification  # noqa
from ..models.hunyuan_video.base_specification import DummyHunyuanVideoModelSpecification  # noqa
from ..models.ltx_video.base_specification import DummyLTXVideoModelSpecification  # noqa
from ..models.wan.base_specification import DummyWanModelSpecification  # noqa

logger = get_logger()
@pytest.fixture(autouse=True)
def slow_down_tests():
    """Pause for five seconds after every test in this module.

    Registered as an autouse pytest fixture so it wraps each test, including the
    ``unittest.TestCase`` methods defined below. The test body runs at ``yield`` and the
    sleep happens during fixture teardown.
    """
    yield
    # Sleep between each test so that process groups are cleaned and resources are released.
    # Not doing so seems to randomly trigger some test failures, which wouldn't fail if run individually.
    # !!!Look into this in future!!!
    time.sleep(5)
class SFTTrainerFastTestsMixin:
    """Shared fixtures for the fast SFT trainer tests.

    Subclasses set ``model_specification_cls`` and implement ``get_args``. Before each test a
    fresh temporary directory is populated with ``num_data_files`` blank videos, a
    ``metadata.csv`` holding one caption per video, and a ``dataset_config.json`` that points
    at the directory.
    """

    model_specification_cls = None
    num_data_files = 4
    num_frames = 4
    height = 64
    width = 64

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()
        root = pathlib.Path(self.tmpdir.name)

        # Blank RGB videos 0.mp4, 1.mp4, ...; each repeats one blank frame `num_frames` times.
        self.data_files = []
        for index in range(self.num_data_files):
            video_path = (root / f"{index}.mp4").as_posix()
            frames = [Image.new("RGB", (self.width, self.height))] * self.num_frames
            export_to_video(frames, video_path, fps=2)
            self.data_files.append(video_path)

        # Caption table: header row, then one quoted caption per video file.
        csv_lines = ["file_name,caption\n"]
        for index in range(self.num_data_files):
            prompt = f"A cat ruling the world - {index}"
            csv_lines.append(f'{index}.mp4,"{prompt}"\n')
        with open((root / "metadata.csv").as_posix(), "w") as f:
            f.write("".join(csv_lines))

        # A single video dataset bucketed at exactly (num_frames, height, width).
        video_dataset = {
            "data_root": self.tmpdir.name,
            "dataset_type": "video",
            "id_token": "TEST",
            "video_resolution_buckets": [[self.num_frames, self.height, self.width]],
            "reshape_mode": "bicubic",
        }
        self.dataset_config_filename = root / "dataset_config.json"
        self.dataset_config_filename.write_text(json.dumps({"datasets": [video_dataset]}))

    def tearDown(self):
        self.tmpdir.cleanup()
        # Tests that follow have been observed to fail when a process group is left behind,
        # so tear it down explicitly whenever one exists.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
        time.sleep(3)

    def get_base_args(self) -> BaseArgs:
        """Return ``BaseArgs`` pointing at this test's data, with small step/sample counts."""
        args = BaseArgs()
        settings = {
            "dataset_config": self.dataset_config_filename.as_posix(),
            "train_steps": 10,
            "max_data_samples": 25,
            "batch_size": 1,
            "gradient_checkpointing": True,
            "output_dir": self.tmpdir.name,
            "checkpointing_steps": 6,
            "enable_precomputation": False,
            "precomputation_items": self.num_data_files,
            "precomputation_dir": os.path.join(self.tmpdir.name, "precomputed"),
            # Only in effect when `compile_modules` is set.
            "compile_scopes": "regional",
        }
        for name, value in settings.items():
            setattr(args, name, value)
        # Alternative attention providers, currently disabled:
        # args.attn_provider_training = ["transformer:_native_cudnn"]
        # args.attn_provider_inference = ["transformer:_native_cudnn"]
        return args

    def get_args(self) -> BaseArgs:
        """Return the backend/training-type specific args; must be overridden."""
        raise NotImplementedError("`get_args` must be implemented in the subclass.")

    def _test_training(self, args: BaseArgs):
        """Build an ``SFTTrainer`` for this class's model specification and run it with ``args``."""
        trainer = SFTTrainer(args, self.model_specification_cls())
        trainer.run()
# =============== <ACCELERATE> =============== | |
class SFTTrainerLoRATestsMixin___Accelerate(SFTTrainerFastTestsMixin):
    """LoRA training smoke tests on the ``accelerate`` parallel backend.

    Each test is expanded by ``parameterized`` into two cases: precomputation disabled
    (``False``) and enabled (``True``). unittest never passes arguments to test methods, so
    the decorator is what supplies ``enable_precomputation``.
    """

    def get_args(self) -> BaseArgs:
        """Base args plus LoRA (rank 4, alpha 4) on the to_q/to_k/to_v/to_out.0 modules."""
        args = self.get_base_args()
        args.parallel_backend = "accelerate"
        args.training_type = TrainingType.LORA
        args.rank = 4
        args.lora_alpha = 4
        args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
        return args

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___layerwise_upcasting___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.layerwise_upcasting_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)
class SFTTrainerFullFinetuneTestsMixin___Accelerate(SFTTrainerFastTestsMixin):
    """Full-finetune training smoke tests on the ``accelerate`` parallel backend.

    Each test is expanded by ``parameterized`` into two cases: precomputation disabled
    (``False``) and enabled (``True``). unittest never passes arguments to test methods, so
    the decorator is what supplies ``enable_precomputation``.
    """

    def get_args(self) -> BaseArgs:
        """Base args configured for full finetuning via accelerate."""
        args = self.get_base_args()
        args.parallel_backend = "accelerate"
        args.training_type = TrainingType.FULL_FINETUNE
        return args

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)
class SFTTrainerCogVideoXLoRATests___Accelerate(
    SFTTrainerLoRATestsMixin___Accelerate,
    unittest.TestCase,
):
    """CogVideoX, LoRA, accelerate backend."""

    model_specification_cls = DummyCogVideoXModelSpecification


class SFTTrainerCogVideoXFullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate,
    unittest.TestCase,
):
    """CogVideoX, full finetune, accelerate backend."""

    model_specification_cls = DummyCogVideoXModelSpecification


class SFTTrainerCogView4LoRATests___Accelerate(
    SFTTrainerLoRATestsMixin___Accelerate,
    unittest.TestCase,
):
    """CogView4, LoRA, accelerate backend."""

    model_specification_cls = DummyCogView4ModelSpecification


class SFTTrainerCogView4FullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate,
    unittest.TestCase,
):
    """CogView4, full finetune, accelerate backend."""

    model_specification_cls = DummyCogView4ModelSpecification


class SFTTrainerFluxLoRATests___Accelerate(
    SFTTrainerLoRATestsMixin___Accelerate,
    unittest.TestCase,
):
    """Flux, LoRA, accelerate backend."""

    model_specification_cls = DummyFluxModelSpecification


class SFTTrainerFluxFullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate,
    unittest.TestCase,
):
    """Flux, full finetune, accelerate backend."""

    model_specification_cls = DummyFluxModelSpecification


class SFTTrainerHunyuanVideoLoRATests___Accelerate(
    SFTTrainerLoRATestsMixin___Accelerate,
    unittest.TestCase,
):
    """HunyuanVideo, LoRA, accelerate backend."""

    model_specification_cls = DummyHunyuanVideoModelSpecification


class SFTTrainerHunyuanVideoFullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate,
    unittest.TestCase,
):
    """HunyuanVideo, full finetune, accelerate backend."""

    model_specification_cls = DummyHunyuanVideoModelSpecification


class SFTTrainerLTXVideoLoRATests___Accelerate(
    SFTTrainerLoRATestsMixin___Accelerate,
    unittest.TestCase,
):
    """LTX-Video, LoRA, accelerate backend."""

    model_specification_cls = DummyLTXVideoModelSpecification


class SFTTrainerLTXVideoFullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate,
    unittest.TestCase,
):
    """LTX-Video, full finetune, accelerate backend."""

    model_specification_cls = DummyLTXVideoModelSpecification


class SFTTrainerWanLoRATests___Accelerate(
    SFTTrainerLoRATestsMixin___Accelerate,
    unittest.TestCase,
):
    """Wan, LoRA, accelerate backend."""

    model_specification_cls = DummyWanModelSpecification


class SFTTrainerWanFullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate,
    unittest.TestCase,
):
    """Wan, full finetune, accelerate backend."""

    model_specification_cls = DummyWanModelSpecification
# =============== </ACCELERATE> =============== | |
# =============== <PTD> =============== | |
class SFTTrainerLoRATestsMixin___PTD(SFTTrainerFastTestsMixin):
    """LoRA training smoke tests on the ``ptd`` parallel backend.

    The tests cover combinations of data parallelism (``dp_degree``), sharding
    (``dp_shards``), tensor parallelism (``tp_degree``), context parallelism (``cp_degree``),
    layerwise upcasting and regional compilation. Each test is expanded by ``parameterized``
    into two cases: precomputation disabled (``False``) and enabled (``True``). unittest never
    passes arguments to test methods, so the decorator is what supplies
    ``enable_precomputation``. Test names encode the settings they apply; in particular,
    ``batch_size_N`` means ``args.batch_size = N``.
    """

    def get_args(self) -> BaseArgs:
        """Base args plus LoRA (rank 4, alpha 4) on the to_q/to_k/to_v/to_out.0 modules, via ptd."""
        args = self.get_base_args()
        args.parallel_backend = "ptd"
        args.training_type = TrainingType.LORA
        args.rank = 4
        args.lora_alpha = 4
        args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
        return args

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___layerwise_upcasting___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.layerwise_upcasting_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___compile___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.compile_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___layerwise_upcasting___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.layerwise_upcasting_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___compile___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.compile_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_shards = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_shards = 2
        # Previously 1, which duplicated `test___dp_shards_2___batch_size_1`.
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.dp_shards = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.tp_degree = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.cp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.cp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)
class SFTTrainerFullFinetuneTestsMixin___PTD(SFTTrainerFastTestsMixin):
    """Full-finetune training smoke tests on the ``ptd`` parallel backend.

    The tests cover combinations of data parallelism (``dp_degree``), sharding
    (``dp_shards``), tensor parallelism (``tp_degree``), context parallelism (``cp_degree``)
    and regional compilation. Each test is expanded by ``parameterized`` into two cases:
    precomputation disabled (``False``) and enabled (``True``). unittest never passes
    arguments to test methods, so the decorator is what supplies ``enable_precomputation``.
    Test names encode the settings they apply; in particular, ``batch_size_N`` means
    ``args.batch_size = N``.
    """

    def get_args(self) -> BaseArgs:
        """Base args configured for full finetuning via ptd."""
        args = self.get_base_args()
        args.parallel_backend = "ptd"
        args.training_type = TrainingType.FULL_FINETUNE
        return args

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___compile___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.compile_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___compile___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.compile_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_shards = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_shards = 2
        # Previously 1, which duplicated `test___dp_shards_2___batch_size_1`.
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.dp_shards = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.tp_degree = 2
        # Previously 1; now matches the test name and the LoRA ptd variant.
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.cp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.cp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)
class SFTTrainerCogVideoXLoRATests___PTD(
    SFTTrainerLoRATestsMixin___PTD,
    unittest.TestCase,
):
    """CogVideoX, LoRA, ptd backend."""

    model_specification_cls = DummyCogVideoXModelSpecification


class SFTTrainerCogVideoXFullFinetuneTests___PTD(
    SFTTrainerFullFinetuneTestsMixin___PTD,
    unittest.TestCase,
):
    """CogVideoX, full finetune, ptd backend."""

    model_specification_cls = DummyCogVideoXModelSpecification


class SFTTrainerCogView4LoRATests___PTD(
    SFTTrainerLoRATestsMixin___PTD,
    unittest.TestCase,
):
    """CogView4, LoRA, ptd backend."""

    model_specification_cls = DummyCogView4ModelSpecification


class SFTTrainerCogView4FullFinetuneTests___PTD(
    SFTTrainerFullFinetuneTestsMixin___PTD,
    unittest.TestCase,
):
    """CogView4, full finetune, ptd backend."""

    model_specification_cls = DummyCogView4ModelSpecification


class SFTTrainerFluxLoRATests___PTD(
    SFTTrainerLoRATestsMixin___PTD,
    unittest.TestCase,
):
    """Flux, LoRA, ptd backend."""

    model_specification_cls = DummyFluxModelSpecification


class SFTTrainerFluxFullFinetuneTests___PTD(
    SFTTrainerFullFinetuneTestsMixin___PTD,
    unittest.TestCase,
):
    """Flux, full finetune, ptd backend."""

    model_specification_cls = DummyFluxModelSpecification


class SFTTrainerHunyuanVideoLoRATests___PTD(
    SFTTrainerLoRATestsMixin___PTD,
    unittest.TestCase,
):
    """HunyuanVideo, LoRA, ptd backend."""

    model_specification_cls = DummyHunyuanVideoModelSpecification


class SFTTrainerHunyuanVideoFullFinetuneTests___PTD(
    SFTTrainerFullFinetuneTestsMixin___PTD,
    unittest.TestCase,
):
    """HunyuanVideo, full finetune, ptd backend."""

    model_specification_cls = DummyHunyuanVideoModelSpecification


class SFTTrainerLTXVideoLoRATests___PTD(
    SFTTrainerLoRATestsMixin___PTD,
    unittest.TestCase,
):
    """LTX-Video, LoRA, ptd backend."""

    model_specification_cls = DummyLTXVideoModelSpecification


class SFTTrainerLTXVideoFullFinetuneTests___PTD(
    SFTTrainerFullFinetuneTestsMixin___PTD,
    unittest.TestCase,
):
    """LTX-Video, full finetune, ptd backend."""

    model_specification_cls = DummyLTXVideoModelSpecification


class SFTTrainerWanLoRATests___PTD(
    SFTTrainerLoRATestsMixin___PTD,
    unittest.TestCase,
):
    """Wan, LoRA, ptd backend."""

    model_specification_cls = DummyWanModelSpecification


class SFTTrainerWanFullFinetuneTests___PTD(
    SFTTrainerFullFinetuneTestsMixin___PTD,
    unittest.TestCase,
):
    """Wan, full finetune, ptd backend."""

    model_specification_cls = DummyWanModelSpecification
# =============== </PTD> =============== | |