# FluxFoundry / diffusers_lora_finetune.py
# stillerman's picture
# generating images
# 3cac2a1
# ---
# deploy: true
# ---
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import modal
app = modal.App(name="dreambooth-lora-flux")

# Commit of `huggingface/diffusers` that gets installed from source below.
GIT_SHA = "e649678bf55aeaa4b60bd1f68b1ee726278c0304"  # specify the commit to fetch

# First layer: Debian slim with Python 3.10 plus the Python dependencies.
_pip_layer = modal.Image.debian_slim(python_version="3.10").pip_install(
    "accelerate==0.31.0",
    "datasets==3.6.0",
    "pillow",
    "fastapi[standard]==0.115.4",
    "ftfy~=6.1.0",
    "gradio~=5.5.0",
    "huggingface-hub==0.32.4",
    "hf_transfer==0.1.8",
    "numpy<2",
    "peft==0.11.1",
    "pydantic==2.9.2",
    "sentencepiece>=0.1.91,!=0.1.92",
    "smart_open~=6.4.0",
    "starlette==0.41.2",
    "transformers~=4.41.2",
    "torch~=2.2.0",
    "torchvision~=0.16",
    "triton~=2.2.0",
    "wandb==0.17.6",
)

# Second layer: install git, shallow-fetch only the target `diffusers` commit
# into the container's home directory (/root), check it out, and install
# `diffusers` from that checkout in editable mode.
image = _pip_layer.apt_install("git").run_commands(
    "cd /root && git init .",
    "cd /root && git remote add origin https://github.com/huggingface/diffusers",
    f"cd /root && git fetch --depth=1 origin {GIT_SHA} && git checkout {GIT_SHA}",
    "cd /root && pip install -e .",
)
# ### Configuration with `dataclass`es
# Machine learning apps often have a lot of configuration information.
# We collect up all of our configuration into dataclasses to avoid scattering special/magic values throughout code.
@dataclass
class SharedConfig:
    """Settings that several parts of the project read from one place."""

    # Proper noun that the model is being taught.
    instance_name: str = "Qwerty"
    # Broader class the instance usually belongs to (person, bird, ...);
    # giving the model this information helps it generalize better.
    class_name: str = "Golden Retriever"
    # Hugging Face identifier of the pretrained base model.
    model_name: str = "black-forest-labs/FLUX.1-dev"
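# ### Storing data created by our app with `modal.Volume`
# The tools we've used so far work well for fetching external information,
# which defines the environment our app runs in,
# but what about data that we create or modify during the app's execution?
# A persisted [`modal.Volume`](https://modal.com/docs/guide/volumes) can store and share data across Modal Apps and Functions.
# We'll use one to store both the original and fine-tuned weights we create during training
# and then load them back in for inference.
# NOTE(review): no `modal.Volume` is created in the code visible in this file --
# the "stateless" functions below download weights into temporary directories
# and upload results to the Hugging Face Hub. This section looks like a leftover
# from the tutorial this file was adapted from; confirm and update or remove.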
# Turn on faster downloads from HF by setting the `HF_HUB_ENABLE_HF_TRANSFER`
# environment variable inside the container image.
_hf_env = {"HF_HUB_ENABLE_HF_TRANSFER": "1"}
image = image.env(_hf_env)
def load_images_from_hf_dataset(dataset_id: str, hf_token: str) -> Path:
    """Download a HuggingFace dataset and write its images to ``/img`` as PNGs.

    The ``train`` split of ``dataset_id`` is loaded and, for every example
    that has an ``image`` field, the image is saved as ``/img/<index>.png``.
    Examples without an ``image`` field are skipped with a warning.

    Args:
        dataset_id: HuggingFace dataset ID (e.g., "username/dataset-name").
        hf_token: HuggingFace API token used to access the dataset.

    Returns:
        Path of the directory that now contains the saved images.

    Raises:
        ValueError: If no example in the dataset yielded an image.
    """
    import PIL.Image
    from datasets import load_dataset

    img_path = Path("/img")
    img_path.mkdir(parents=True, exist_ok=True)

    # The directory is at a fixed path, so files written by an earlier call in
    # a reused container would otherwise be mixed into this run's images.
    for stale_file in img_path.iterdir():
        if stale_file.is_file():
            stale_file.unlink()

    # Load dataset from HuggingFace
    dataset = load_dataset(dataset_id, token=hf_token, split="train")

    saved_count = 0
    for ii, example in enumerate(dataset):
        # Assume the dataset has an 'image' column
        if "image" not in example:
            print(f"Warning: No 'image' field found in dataset example {ii}")
            continue

        # Named `picture` so it does not shadow the module-level Modal `image`.
        picture = example["image"]
        if not isinstance(picture, PIL.Image.Image):
            # Handle other image formats (anything PIL.Image.open accepts)
            picture = PIL.Image.open(picture)
        picture.save(img_path / f"{ii}.png")
        saved_count += 1

    if saved_count == 0:
        raise ValueError(
            f"No images found in the 'image' column of dataset {dataset_id!r}"
        )

    # Report what was actually written, not the dataset length.
    print(f"{saved_count} images loaded from HuggingFace dataset")
    return img_path
# ## Stateless API Training Function
@dataclass
class APITrainConfig:
    """Settings consumed by the API-driven training function."""

    # --- Base model ---
    # Hugging Face identifier of the pretrained model.
    model_name: str = "black-forest-labs/FLUX.1-dev"

    # --- Pieces of the training prompt ---
    instance_name: str = "subject"
    class_name: str = "person"
    prefix: str = "a photo of"
    postfix: str = ""

    # --- Training hyperparameters ---
    resolution: int = 512
    train_batch_size: int = 3
    # LoRA rank.
    # NOTE(review): the training command visible in this file does not pass a
    # rank flag, so this value appears to be unused -- confirm.
    rank: int = 16
    gradient_accumulation_steps: int = 1
    learning_rate: float = 4e-4
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 0
    max_train_steps: int = 500
    checkpointing_steps: int = 1000
    seed: int = 117
@app.function(
    image=image,
    gpu="A100-80GB",  # fine-tuning is VRAM-heavy and requires a high-VRAM GPU
    timeout=3600,  # 60 minutes
)
def train_lora_stateless(
    dataset_id: str,
    hf_token: str,
    output_repo: str,
    instance_name: Optional[str] = None,
    class_name: Optional[str] = None,
    max_train_steps: int = 500,
):
    """
    Stateless LoRA training function that reads from HF dataset and uploads to HF repo.

    The base model is downloaded into a temporary directory, the dataset
    images are written to disk, the diffusers DreamBooth LoRA script is run
    through `accelerate launch`, and the resulting output directory is
    uploaded to `output_repo`.

    Args:
        dataset_id: HuggingFace dataset ID (e.g., "username/dataset-name")
        hf_token: HuggingFace API token
        output_repo: HuggingFace repository to upload the trained LoRA to
        instance_name: Name of the subject (optional, defaults to "subject")
        class_name: Class of the subject (optional, defaults to "person")
        max_train_steps: Number of training steps

    Returns:
        Dictionary with "status", "message", "dataset_used", "training_steps"
        and "training_prompt" keys.

    Raises:
        subprocess.CalledProcessError: If the training subprocess exits with a
            non-zero status; `returncode` carries the real exit code.
    """
    import subprocess
    import tempfile
    from pathlib import Path

    import torch
    from accelerate.utils import write_basic_config
    from diffusers import DiffusionPipeline
    from huggingface_hub import snapshot_download, upload_folder, login, create_repo

    def _exec_subprocess(cmd: list[str]):
        """Executes subprocess and prints log to terminal while subprocess is running."""
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )
        with process.stdout as pipe:
            for line in iter(pipe.readline, b""):
                # errors="replace": a stray non-UTF-8 byte in the training log
                # must not abort the run.
                print(line.decode(errors="replace"), end="")

        # The parentheses matter: without them the walrus binds the result of
        # the comparison (a bool) instead of the actual exit code.
        if (exitcode := process.wait()) != 0:
            raise subprocess.CalledProcessError(exitcode, cmd)

    # Set up training configuration
    config = APITrainConfig(
        instance_name=instance_name or "subject",
        class_name=class_name or "person",
        max_train_steps=max_train_steps,
    )

    # Login to HuggingFace
    login(token=hf_token)

    # Create temporary directories
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        model_dir = temp_path / "model"
        output_dir = temp_path / "output"

        # Download base model
        print("📥 Downloading base model...")
        snapshot_download(
            config.model_name,
            local_dir=str(model_dir),
            ignore_patterns=["*.pt", "*.bin"],  # using safetensors
            token=hf_token,
        )

        # Load and validate model (the pipeline object itself is not kept)
        DiffusionPipeline.from_pretrained(str(model_dir), torch_dtype=torch.bfloat16)
        print("✅ Base model loaded successfully")

        # Load training images from HF dataset
        print(f"📥 Loading images from dataset: {dataset_id}")
        img_path = load_images_from_hf_dataset(dataset_id, hf_token)

        # Set up hugging face accelerate library for fast training
        write_basic_config(mixed_precision="bf16")

        # Define the training prompt
        instance_phrase = f"{config.instance_name} the {config.class_name}"
        prompt = f"{config.prefix} {instance_phrase} {config.postfix}".strip()

        print(f"🎯 Training prompt: {prompt}")
        print(f"🚀 Starting training for {max_train_steps} steps...")

        # Run training
        # NOTE(review): the script path is relative; this assumes the process
        # runs from the diffusers checkout in /root -- confirm.
        # NOTE(review): `config.rank` is not forwarded to the script -- confirm
        # whether that is intended.
        _exec_subprocess([
            "accelerate",
            "launch",
            "examples/dreambooth/train_dreambooth_lora_flux.py",
            "--mixed_precision=bf16",
            f"--pretrained_model_name_or_path={model_dir}",
            f"--instance_data_dir={img_path}",
            f"--output_dir={output_dir}",
            f"--instance_prompt={prompt}",
            f"--resolution={config.resolution}",
            f"--train_batch_size={config.train_batch_size}",
            f"--gradient_accumulation_steps={config.gradient_accumulation_steps}",
            f"--learning_rate={config.learning_rate}",
            f"--lr_scheduler={config.lr_scheduler}",
            f"--lr_warmup_steps={config.lr_warmup_steps}",
            f"--max_train_steps={config.max_train_steps}",
            f"--checkpointing_steps={config.checkpointing_steps}",
            f"--seed={config.seed}",
        ])

        print("✅ Training completed!")

        # Upload trained LoRA to HuggingFace repository
        print(f"📤 Uploading LoRA to repository: {output_repo}")

        # Create repository if it doesn't exist
        create_repo(
            repo_id=output_repo,
            repo_type="model",
            token=hf_token,
            exist_ok=True,
        )

        # print contents of output_dir
        print(f"Contents of {output_dir}:")
        for file in output_dir.iterdir():
            print(file)

        upload_folder(
            folder_path=str(output_dir),
            repo_id=output_repo,
            repo_type="model",
            token=hf_token,
            commit_message=f"Add LoRA trained on {dataset_id}",
        )

        print(f"🎉 Successfully uploaded LoRA to {output_repo}")

        return {
            "status": "success",
            "message": f"LoRA training completed and uploaded to {output_repo}",
            "dataset_used": dataset_id,
            "training_steps": max_train_steps,
            "training_prompt": prompt,
        }
# ## API Endpoints with Job ID System
@app.function(
    image=image,
    min_containers=1,  # Keep one container warm for faster response
)
@modal.fastapi_endpoint(method="POST")
def api_start_training(item: dict):
    """
    Start LoRA training and return a job ID.

    Expected JSON payload:
    {
        "dataset_id": "username/dataset-name",
        "hf_token": "hf_...",
        "output_repo": "username/output-repo",
        "instance_name": "optional_subject_name",
        "class_name": "optional_class_name",
        "max_train_steps": 500
    }

    Returns a dict whose "status" is "started" (with the "job_id" of the
    spawned training call) or "error" (with a "message").
    """
    try:
        # Extract required parameters
        dataset_id = item["dataset_id"]
        hf_token = item["hf_token"]
        output_repo = item["output_repo"]
    except KeyError as e:
        return {
            "status": "error",
            "message": f"Missing required parameter: {e}"
        }

    # Extract optional parameters
    instance_name = item.get("instance_name")
    class_name = item.get("class_name")
    max_train_steps = item.get("max_train_steps", 500)

    # The value comes straight from the request body; reject anything that is
    # not a positive integer before spending GPU time on it. `bool` is
    # excluded explicitly because it is a subclass of `int`.
    if (
        isinstance(max_train_steps, bool)
        or not isinstance(max_train_steps, int)
        or max_train_steps <= 0
    ):
        return {
            "status": "error",
            "message": "max_train_steps must be a positive integer"
        }

    try:
        # Start training (non-blocking)
        call_handle = train_lora_stateless.spawn(
            dataset_id=dataset_id,
            hf_token=hf_token,
            output_repo=output_repo,
            instance_name=instance_name,
            class_name=class_name,
            max_train_steps=max_train_steps,
        )
        job_id = call_handle.object_id
    except Exception as e:
        return {
            "status": "error",
            "message": f"Failed to start training: {str(e)}"
        }

    return {
        "status": "started",
        "job_id": job_id,
        "message": "Training job started successfully",
        "dataset_id": dataset_id,
        "output_repo": output_repo,
        "max_train_steps": max_train_steps
    }
@app.function(
    image=image,
    min_containers=1,
)
@modal.fastapi_endpoint(method="GET")
def api_job_status(job_id: str):
    """
    Check the status of a training job.
    Pass job_id as a query parameter: /job_status?job_id=xyz

    Returns a dict whose "status" is one of:
      - "completed": the job finished; its return value is under "result"
      - "running":   the result is not available yet
      - "failed":    the job raised, or returned a dict with status "error"
      - "error":     the job handle could not be looked up
    """
    try:
        # Get the function call handle
        call_handle = modal.FunctionCall.from_id(job_id)
    except Exception as e:
        return {
            "status": "error",
            "message": f"Error checking job status: {str(e)}"
        }

    if call_handle is None:
        return {
            "status": "error",
            "message": "Job not found"
        }

    # Check if the job is finished
    try:
        result = call_handle.get(timeout=0)  # Non-blocking check
    except TimeoutError:
        return {
            "status": "running",
            "message": "Job is still running"
        }
    except Exception as e:
        return {
            "status": "failed",
            "message": f"Job failed: {str(e)}"
        }

    # A job may report failure through its return value instead of raising
    # (a dict with status "error"); do not present that as a success.
    if isinstance(result, dict) and result.get("status") == "error":
        return {
            "status": "failed",
            "message": f"Job failed: {result.get('message', 'unknown error')}",
            "result": result
        }

    return {
        "status": "completed",
        "result": result
    }
@dataclass
class InferenceConfig:
    """Default sampling settings used when generating images."""

    # Number of denoising steps per image.
    num_inference_steps: int = 20
    # Guidance scale handed to the pipeline.
    guidance_scale: float = 7.5
    # Output image size in pixels.
    width: int = 512
    height: int = 512
@app.function(
    image=image,
    gpu="A100",  # Inference requires GPU
    timeout=1800,  # 30 minutes
)
def generate_images_stateless(
    hf_token: str,
    lora_repo: str,
    prompts: list[str],
    num_inference_steps: int = 20,
    guidance_scale: float = 7.5,
    width: int = 512,
    height: int = 512,
):
    """
    Generate one image per prompt with a LoRA fetched from HuggingFace.

    Both the base model and the LoRA are downloaded into a temporary
    directory on every call; nothing is kept between invocations.

    Args:
        hf_token: HuggingFace API token
        lora_repo: HuggingFace repository containing the LoRA (e.g., "username/my-lora")
        prompts: List of text prompts to generate images for
        num_inference_steps: Number of denoising steps
        guidance_scale: Classifier-free guidance scale
        width: Image width
        height: Image height

    Returns:
        On success, a dictionary with status "success" and an "images" list of
        {"prompt", "image"} entries, where "image" is a base64-encoded PNG.
        On any exception, a dictionary with status "error" and a "message".
    """
    import base64
    import io
    import tempfile
    from pathlib import Path

    import torch
    from diffusers import DiffusionPipeline
    from huggingface_hub import snapshot_download, login

    def _png_as_base64(pil_image) -> str:
        """Serialize a PIL image to PNG and return it as a base64 string."""
        buffer = io.BytesIO()
        pil_image.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

    try:
        # Login to HuggingFace
        login(token=hf_token)

        # Everything downloaded lives in a throwaway directory
        with tempfile.TemporaryDirectory() as scratch:
            scratch_root = Path(scratch)
            base_model_path = scratch_root / "model"
            lora_path = scratch_root / "lora"

            # Fetch the base model
            print("📥 Downloading base model...")
            snapshot_download(
                "black-forest-labs/FLUX.1-dev",
                local_dir=str(base_model_path),
                ignore_patterns=["*.pt", "*.bin"],  # using safetensors
                token=hf_token,
            )

            # Fetch the LoRA
            print(f"📥 Downloading LoRA from {lora_repo}...")
            snapshot_download(
                lora_repo,
                local_dir=str(lora_path),
                token=hf_token,
            )

            # Build the pipeline on the GPU, then attach the LoRA weights
            print("🔄 Loading pipeline...")
            pipeline = DiffusionPipeline.from_pretrained(
                str(base_model_path),
                torch_dtype=torch.bfloat16,
            )
            pipeline = pipeline.to("cuda")
            pipeline.load_lora_weights(str(lora_path))

            print(f"🎨 Generating {len(prompts)} images...")

            # Sampling settings are the same for every prompt
            sampling_kwargs = {
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "width": width,
                "height": height,
            }

            results = []
            for position, prompt in enumerate(prompts, start=1):
                print(f" Generating image {position}/{len(prompts)}: {prompt[:50]}...")
                output = pipeline(prompt, **sampling_kwargs)
                first_image = output.images[0]
                results.append({
                    "prompt": prompt,
                    "image": _png_as_base64(first_image),
                })

            print("✅ All images generated successfully!")

            response = {
                "status": "success",
                "message": f"Generated {len(prompts)} images successfully",
                "lora_repo": lora_repo,
                "images": results,
            }
            return response

    except Exception as e:
        return {
            "status": "error",
            "message": f"Failed to generate images: {str(e)}"
        }
@app.function(
    image=image,
    min_containers=1,
)
@modal.fastapi_endpoint(method="POST")
def api_generate_images(item: dict):
    """
    Generate images using a LoRA model.

    Expected JSON payload:
    {
        "hf_token": "hf_...",
        "lora_repo": "username/my-lora",
        "prompts": ["prompt1", "prompt2", ...],
        "num_inference_steps": 20,  // optional
        "guidance_scale": 7.5,     // optional
        "width": 512,              // optional
        "height": 512              // optional
    }

    Returns a dict whose "status" is "started" (with the "job_id" of the
    spawned generation call) or "error" (with a "message").
    """
    try:
        # Extract required parameters
        hf_token = item["hf_token"]
        lora_repo = item["lora_repo"]
        prompts = item["prompts"]
    except KeyError as e:
        return {
            "status": "error",
            "message": f"Missing required parameter: {e}"
        }

    if not isinstance(prompts, list) or not prompts:
        return {
            "status": "error",
            "message": "prompts must be a non-empty list"
        }

    # The worker slices each prompt as a string; reject other types here
    # instead of failing later inside the GPU job.
    if not all(isinstance(prompt, str) for prompt in prompts):
        return {
            "status": "error",
            "message": "every entry in prompts must be a string"
        }

    # Extract optional parameters
    num_inference_steps = item.get("num_inference_steps", 20)
    guidance_scale = item.get("guidance_scale", 7.5)
    width = item.get("width", 512)
    height = item.get("height", 512)

    try:
        # Start generation (non-blocking)
        call_handle = generate_images_stateless.spawn(
            hf_token=hf_token,
            lora_repo=lora_repo,
            prompts=prompts,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
        )
        job_id = call_handle.object_id
    except Exception as e:
        return {
            "status": "error",
            "message": f"Failed to start image generation: {str(e)}"
        }

    return {
        "status": "started",
        "job_id": job_id,
        "message": "Image generation job started successfully",
        "lora_repo": lora_repo,
        "num_prompts": len(prompts)
    }