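"""Entry point for launching finetrainers training runs.

Dispatches to SFTTrainer or ControlTrainer based on --training_type. A sketch
of an invocation (flag names are inferred from the attributes used below; the
valid training types come from `TrainingType`, and `BaseArgs` holds the
authoritative flag list):

    python train.py --model_name <model> --training_type <training-type> \
        --pretrained_model_name_or_path <repo-or-path> ...
"""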
import sys
import traceback

from finetrainers import BaseArgs, ControlTrainer, SFTTrainer, TrainingType, get_logger
from finetrainers.config import _get_model_specifiction_cls
from finetrainers.trainer.control_trainer.config import ControlFullRankConfig, ControlLowRankConfig
from finetrainers.trainer.sft_trainer.config import SFTFullRankConfig, SFTLowRankConfig


logger = get_logger()


def main():
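    # Prefer the "fork" start method; as the error message below notes, not
    # having it can mean poor performance, high memory usage, or crashes.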
    try:
        import multiprocessing

        multiprocessing.set_start_method("fork")
    except Exception as e:
        logger.error(
            'Failed to set multiprocessing start method to "fork". This can lead to poor performance, high memory usage, or crashes. '
            "See: https://pytorch.org/docs/stable/notes/multiprocessing.html\n"
            f"Error: {e}"
        )

    try:
        args = BaseArgs()
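
        # Pre-scan the raw argv for --training_type: the matching config class
        # must be registered on BaseArgs before the full parse below.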
        argv = [y.strip() for x in sys.argv for y in x.split()]
        # list.index raises ValueError when the flag is absent (it never
        # returns -1), so check membership explicitly first.
        if "--training_type" not in argv:
            raise ValueError("Training type not provided in command line arguments.")
        training_type_index = argv.index("--training_type")

        training_type = argv[training_type_index + 1]
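
        # Map the requested training type to its config class so the
        # type-specific arguments get registered before parsing.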
        training_cls = None
        if training_type == TrainingType.LORA:
            training_cls = SFTLowRankConfig
        elif training_type == TrainingType.FULL_FINETUNE:
            training_cls = SFTFullRankConfig
        elif training_type == TrainingType.CONTROL_LORA:
            training_cls = ControlLowRankConfig
        elif training_type == TrainingType.CONTROL_FULL_FINETUNE:
            training_cls = ControlFullRankConfig
        else:
            raise ValueError(f"Training type {training_type} not supported.")

        args.register_args(training_cls())
        args = args.parse_args()
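
        # The model specification bundles the per-component model IDs and
        # dtypes (tokenizers, text encoders, transformer, VAE) for the
        # selected model family.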
        model_specification_cls = _get_model_specifiction_cls(args.model_name, args.training_type)
        model_specification = model_specification_cls(
            pretrained_model_name_or_path=args.pretrained_model_name_or_path,
            tokenizer_id=args.tokenizer_id,
            tokenizer_2_id=args.tokenizer_2_id,
            tokenizer_3_id=args.tokenizer_3_id,
            text_encoder_id=args.text_encoder_id,
            text_encoder_2_id=args.text_encoder_2_id,
            text_encoder_3_id=args.text_encoder_3_id,
            transformer_id=args.transformer_id,
            vae_id=args.vae_id,
            text_encoder_dtype=args.text_encoder_dtype,
            text_encoder_2_dtype=args.text_encoder_2_dtype,
            text_encoder_3_dtype=args.text_encoder_3_dtype,
            transformer_dtype=args.transformer_dtype,
            vae_dtype=args.vae_dtype,
            revision=args.revision,
            cache_dir=args.cache_dir,
        )
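
        # The variants share trainers: LoRA and full fine-tuning run through
        # SFTTrainer, the control types through ControlTrainer.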
        if args.training_type in [TrainingType.LORA, TrainingType.FULL_FINETUNE]:
            trainer = SFTTrainer(args, model_specification)
        elif args.training_type in [TrainingType.CONTROL_LORA, TrainingType.CONTROL_FULL_FINETUNE]:
            trainer = ControlTrainer(args, model_specification)
        else:
            raise ValueError(f"Training type {args.training_type} not supported.")

        trainer.run()
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt. Exiting...")
    except Exception as e:
        logger.error(f"An error occurred during training: {e}")
        logger.error(traceback.format_exc())


if __name__ == "__main__":
    main()