ivangabriele's picture
feat: initialize project
2f5127c verified
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python examples/scripts/ddpo.py \
--num_epochs=200 \
--train_gradient_accumulation_steps=1 \
--sample_num_steps=50 \
--sample_batch_size=6 \
--train_batch_size=3 \
--sample_num_batches_per_epoch=4 \
--per_prompt_stat_tracking=True \
--per_prompt_stat_tracking_buffer_size=32 \
--tracker_project_name="stable_diffusion_training" \
--log_with="wandb"
"""
import os
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from transformers import CLIPModel, CLIPProcessor, HfArgumentParser, is_torch_npu_available, is_torch_xpu_available
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
@dataclass
class ScriptArguments:
    r"""
    Command-line options that are specific to this example script.

    Args:
        pretrained_model (`str`, *optional*, defaults to `"runwayml/stable-diffusion-v1-5"`):
            Pretrained model to use.
        pretrained_revision (`str`, *optional*, defaults to `"main"`):
            Pretrained model revision to use.
        hf_hub_model_id (`str`, *optional*, defaults to `"ddpo-finetuned-stable-diffusion"`):
            HuggingFace repo to save model weights to.
        hf_hub_aesthetic_model_id (`str`, *optional*, defaults to `"trl-lib/ddpo-aesthetic-predictor"`):
            Hugging Face model ID for aesthetic scorer model weights.
        hf_hub_aesthetic_model_filename (`str`, *optional*, defaults to `"aesthetic-model.pth"`):
            Hugging Face model filename for aesthetic scorer model weights.
        use_lora (`bool`, *optional*, defaults to `True`):
            Whether to use LoRA.
    """

    # Base diffusion model and the revision of it to load.
    pretrained_model: str = field(
        default="runwayml/stable-diffusion-v1-5",
        metadata={"help": "Pretrained model to use."},
    )
    pretrained_revision: str = field(
        default="main",
        metadata={"help": "Pretrained model revision to use."},
    )

    # Destination repository for the fine-tuned weights.
    hf_hub_model_id: str = field(
        default="ddpo-finetuned-stable-diffusion",
        metadata={"help": "HuggingFace repo to save model weights to."},
    )

    # Location of the aesthetic-scorer weights used by the reward function.
    hf_hub_aesthetic_model_id: str = field(
        default="trl-lib/ddpo-aesthetic-predictor",
        metadata={"help": "Hugging Face model ID for aesthetic scorer model weights."},
    )
    hf_hub_aesthetic_model_filename: str = field(
        default="aesthetic-model.pth",
        metadata={"help": "Hugging Face model filename for aesthetic scorer model weights."},
    )

    use_lora: bool = field(
        default=True,
        metadata={"help": "Whether to use LoRA."},
    )
class MLP(nn.Module):
    """
    Regression head mapping a 768-dimensional embedding to a single scalar.

    The stack consists of linear layers with dropout in between and no activation functions. It is stored under
    the `layers` attribute, so state-dict keys have the form `layers.<index>.weight` / `layers.<index>.bias`.
    """

    def __init__(self):
        super().__init__()
        # Positions in this list determine the state-dict keys, so the order must not change.
        stack = [
            nn.Linear(768, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

    @torch.no_grad()
    def forward(self, embed):
        """Run `embed` through the layer stack; gradient tracking is disabled by the decorator."""
        prediction = self.layers(embed)
        return prediction
class AestheticScorer(torch.nn.Module):
    """
    Predicts an aesthetic score for images, i.e. a numerical approximation of how much
    humans like a given image on average.

    Images are embedded with CLIP (`openai/clip-vit-large-patch14`), the embedding is L2-normalized
    and then fed to an `MLP` head whose weights are loaded from `model_filename`.

    Based on https://github.com/christophschuhmann/improved-aesthetic-predictor

    Args:
        dtype: dtype that the processor outputs are cast to before being passed to CLIP.
        model_id: Hub repository holding the MLP weights. If the file is not found there,
            this value is used as a local directory instead.
        model_filename: Name of the weights file inside `model_id`.
    """

    def __init__(self, *, dtype, model_id, model_filename):
        super().__init__()
        clip_checkpoint = "openai/clip-vit-large-patch14"
        self.clip = CLIPModel.from_pretrained(clip_checkpoint)
        self.processor = CLIPProcessor.from_pretrained(clip_checkpoint)
        self.mlp = MLP()

        # Try the Hub first; fall back to a path on disk when the entry does not exist there.
        try:
            weights_path = hf_hub_download(model_id, model_filename)
        except EntryNotFoundError:
            weights_path = os.path.join(model_id, model_filename)

        # Weights are loaded onto CPU; moving the module to a device is left to the caller.
        mlp_weights = torch.load(weights_path, map_location=torch.device("cpu"), weights_only=True)
        self.mlp.load_state_dict(mlp_weights)

        self.dtype = dtype
        self.eval()

    @torch.no_grad()
    def __call__(self, images):
        """Return one score per image; runs without gradient tracking."""
        # Use whichever device the module's parameters currently live on.
        target_device = next(self.parameters()).device
        processed = self.processor(images=images, return_tensors="pt")

        model_inputs = {}
        for name, tensor in processed.items():
            model_inputs[name] = tensor.to(self.dtype).to(target_device)

        features = self.clip.get_image_features(**model_inputs)
        # Scale each embedding to unit L2 norm before scoring.
        norms = torch.linalg.vector_norm(features, dim=-1, keepdim=True)
        return self.mlp(features / norms).squeeze(1)
def aesthetic_scorer(hub_model_id, model_filename):
    """
    Build a reward function backed by an `AestheticScorer`.

    The scorer is created in float32 and moved to the first available accelerator (NPU, then XPU,
    then CUDA). When none of them is available it stays on CPU instead of failing inside `.cuda()`.

    Args:
        hub_model_id: Forwarded to `AestheticScorer` as `model_id`.
        model_filename: Forwarded to `AestheticScorer` as `model_filename`.

    Returns:
        A callable `(images, prompts, metadata) -> (scores, {})`. `prompts` and `metadata` are
        accepted but not used.
    """
    scorer = AestheticScorer(
        model_id=hub_model_id,
        model_filename=model_filename,
        dtype=torch.float32,
    )
    if is_torch_npu_available():
        scorer = scorer.npu()
    elif is_torch_xpu_available():
        scorer = scorer.xpu()
    elif torch.cuda.is_available():
        scorer = scorer.cuda()
    # Otherwise: no supported accelerator was detected, so the scorer is left on CPU.

    def _fn(images, prompts, metadata):
        # Rescale to 0..255 and quantize to uint8 before scoring
        # (assumes `images` is a float tensor in [0, 1] -- TODO confirm against the trainer).
        images = (images * 255).round().clamp(0, 255).to(torch.uint8)
        scores = scorer(images)
        return scores, {}

    return _fn
# Example prompts to feed Stable Diffusion; `prompt_fn` draws one of them at random.
animals = [
    "cat", "dog", "horse", "monkey", "rabbit", "zebra", "spider", "bird", "sheep",
    "deer", "cow", "goat", "lion", "frog", "chicken", "duck", "goose", "bee",
    "pig", "turkey", "fly", "llama", "camel", "bat", "gorilla", "hedgehog", "kangaroo",
]
def prompt_fn():
    """Return a randomly chosen entry of `animals` together with an empty metadata dict."""
    chosen_prompt = np.random.choice(animals)
    prompt_metadata = {}
    return chosen_prompt, prompt_metadata
def image_outputs_logger(image_data, global_step, accelerate_logger):
    """
    Log the images of the last sampled batch, captioned with their prompt and reward.

    Args:
        image_data: Sequence of 5-element batches; elements 0, 1 and 3 of a batch are the images,
            the prompts and the rewards. Only the last batch is used.
        global_step: Step value passed to the logger.
        accelerate_logger: Object exposing `log_images(mapping, step=...)`.
    """
    # For the sake of this example only the final batch is logged.
    batch_images, batch_prompts, _, batch_rewards, _ = image_data[-1]

    captioned = {}
    for idx, img in enumerate(batch_images):
        text = batch_prompts[idx]
        score = batch_rewards[idx].item()
        # Caption: prompt truncated to 25 characters, reward with two decimals.
        caption = f"{text:.25} | {score:.2f}"
        # Add a leading batch dimension and convert to float for logging.
        captioned[caption] = img.unsqueeze(0).float()

    accelerate_logger.log_images(captioned, step=global_step)
if __name__ == "__main__":
    # Parse the script-specific options and the DDPO training configuration from the command line.
    parser = HfArgumentParser((ScriptArguments, DDPOConfig))
    script_args, training_args = parser.parse_args_into_dataclasses()

    # Fixed logging/checkpoint settings; this assignment replaces whatever value was parsed
    # into `training_args.project_kwargs`.
    training_args.project_kwargs = {
        "logging_dir": "./logs",
        "automatic_checkpoint_naming": True,
        "total_limit": 5,
        "project_dir": "./save",
    }

    pipeline = DefaultDDPOStableDiffusionPipeline(
        script_args.pretrained_model,
        pretrained_model_revision=script_args.pretrained_revision,
        use_lora=script_args.use_lora,
    )

    # Positional arguments: config, reward function, prompt function, pipeline.
    trainer = DDPOTrainer(
        training_args,
        aesthetic_scorer(script_args.hf_hub_aesthetic_model_id, script_args.hf_hub_aesthetic_model_filename),
        prompt_fn,
        pipeline,
        image_samples_hook=image_outputs_logger,
    )

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        # `ScriptArguments` has no `dataset_name` attribute, so reading it raised AttributeError.
        # The destination repository is the one configured through `hf_hub_model_id`.
        trainer.push_to_hub(script_args.hf_hub_model_id)