# Diffusion Model Alignment Using Direct Preference Optimization
This directory provides LoRA implementations of Diffusion DPO proposed in [Diffusion Model Alignment Using Direct Preference Optimization](https://arxiv.org/abs/2311.12908) by Bram Wallace, Meihua Dang, Rafael Rafailov, Linqi Zhou, Aaron Lou, Senthil Purushwalkam, Stefano Ermon, Caiming Xiong, Shafiq Joty, and Nikhil Naik.
We provide implementations for both Stable Diffusion (SD) and Stable Diffusion XL (SDXL). The original checkpoints are available at the URLs below:

* [mhdang/dpo-sd1.5-text2image-v1](https://huggingface.co/mhdang/dpo-sd1.5-text2image-v1)
* [mhdang/dpo-sdxl-text2image-v1](https://huggingface.co/mhdang/dpo-sdxl-text2image-v1)
> 💡 Note: The scripts are highly experimental and were only tested in low-data regimes. Proceed with caution. Feel free to let us know about your findings via GitHub issues.
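If you just want to try the DPO-tuned checkpoints above, the following is a minimal inference sketch. It assumes the repos follow the standard diffusers layout with a `unet` subfolder, as their model cards suggest; adjust the paths if yours differ.

```python
# Minimal sketch: swap the DPO-tuned UNet into a standard SD 1.5 pipeline.
# Assumes the checkpoint repo ships a "unet" subfolder in diffusers layout.
import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "mhdang/dpo-sd1.5-text2image-v1", subfolder="unet", torch_dtype=torch.float16
)
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16
).to("cuda")

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("dpo_sample.png")
```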
## SD training command
```bash
accelerate launch train_diffusion_dpo.py \
  --pretrained_model_name_or_path=stable-diffusion-v1-5/stable-diffusion-v1-5 \
  --output_dir="diffusion-dpo" \
  --mixed_precision="fp16" \
  --dataset_name=kashif/pickascore \
  --resolution=512 \
  --train_batch_size=16 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --rank=8 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=10000 \
  --checkpointing_steps=2000 \
  --run_validation --validation_steps=200 \
  --seed="0" \
  --push_to_hub
```
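After training, the LoRA weights land in `--output_dir`. Here is a short sketch of wiring them into a pipeline, assuming the standard `pytorch_lora_weights.safetensors` file the script writes:

```python
# Sketch of loading the LoRA weights produced by the run above. The path is
# the --output_dir from the command (or your Hub repo id if you pushed).
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("diffusion-dpo")

image = pipe("a corgi wearing a top hat", num_inference_steps=25).images[0]
image.save("sd_dpo_lora.png")
```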
## SDXL training command
```bash
accelerate launch train_diffusion_dpo_sdxl.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --output_dir="diffusion-sdxl-dpo" \
  --mixed_precision="fp16" \
  --dataset_name=kashif/pickascore \
  --train_batch_size=8 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --rank=8 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=2000 \
  --checkpointing_steps=500 \
  --run_validation --validation_steps=50 \
  --seed="0" \
  --push_to_hub
```
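The SDXL LoRA loads the same way; this sketch also swaps in the fp16-fix VAE used during training. The local path mirrors `--output_dir` above and is an assumption if you pushed to the Hub instead:

```python
# Sketch for the SDXL run: pair the trained LoRA with the same fp16-fix VAE
# passed via --pretrained_vae_model_name_or_path during training.
import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("diffusion-sdxl-dpo")

image = pipe("a watercolor painting of a lighthouse at dusk").images[0]
image.save("sdxl_dpo_lora.png")
```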
## SDXL Turbo training command
```bash
accelerate launch train_diffusion_dpo_sdxl.py \
  --pretrained_model_name_or_path=stabilityai/sdxl-turbo \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --output_dir="diffusion-sdxl-turbo-dpo" \
  --mixed_precision="fp16" \
  --dataset_name=kashif/pickascore \
  --train_batch_size=8 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --rank=8 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=2000 \
  --checkpointing_steps=500 \
  --run_validation --validation_steps=50 \
  --seed="0" \
  --is_turbo --resolution 512 \
  --push_to_hub
```
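SDXL Turbo is distilled for few-step sampling, so inference with the tuned LoRA should keep classifier-free guidance disabled and match the 512 training resolution. A sketch, again assuming the `--output_dir` above holds the LoRA weights:

```python
# Turbo inference sketch: single-step sampling with guidance disabled, at the
# 512x512 resolution the LoRA was trained on.
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("diffusion-sdxl-turbo-dpo")

image = pipe(
    "a cinematic shot of a raccoon astronaut",
    num_inference_steps=1, guidance_scale=0.0, height=512, width=512,
).images[0]
image.save("sdxl_turbo_dpo_lora.png")
```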
## Acknowledgements
This implementation builds on the amazing work on Diffusion DPO by [Bram](https://github.com/bram-w): https://github.com/bram-w/trl/blob/dpo/.