import os
import json
import inspect

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from peft import LoraConfig, get_peft_model
from huggingface_hub import snapshot_download

# ─── 1. Read hyperparameters & mode ───────────────────────────────────────────
model_id = os.environ.get("BASE_MODEL", "HiDream-ai/HiDream-I1-Dev")
trigger_word = os.environ.get("TRIGGER_WORD", "default-style")
num_steps = int(os.environ.get("NUM_STEPS", 100))
lora_r = int(os.environ.get("LORA_R", 16))
lora_alpha = int(os.environ.get("LORA_ALPHA", 16))
LOCAL = os.environ.get("LOCAL_TRAIN", "").lower() in ("1", "true")

# ─── 2. Set up directories ────────────────────────────────────────────────────
if LOCAL:
    DATA_DIR = os.path.join(os.getcwd(), "data")
    OUTPUT_DIR = os.path.join(os.getcwd(), "lora-trained")
    LOCAL_MODEL = os.path.join(os.getcwd(), "hidream-model")
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
else:
    DATA_DIR = "/tmp/data"
    OUTPUT_DIR = "/tmp/lora-trained"
    CACHE_DIR = "/tmp/hidream-model"
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)

print(f"📂 Dataset directory: {DATA_DIR}", flush=True)
print(f"📥 Preparing base model: {model_id}", flush=True)

# ─── 3. Resolve model path ────────────────────────────────────────────────────
def get_model_path():
    # If local and a predownloaded model exists, use it
    if LOCAL and os.path.isdir(LOCAL_MODEL) and os.path.isfile(os.path.join(LOCAL_MODEL, "config.json")):
        print(f"✅ Using local model at: {LOCAL_MODEL}", flush=True)
        return LOCAL_MODEL
    # Otherwise download (to ~/.cache on local, or /tmp on Spaces)
    download_kwargs = {} if LOCAL else {"local_dir": CACHE_DIR}
    path = snapshot_download(model_id, **download_kwargs)
    print(f"✅ Downloaded model to: {path}", flush=True)
    return path

model_path = get_model_path()

# ─── 4. Patch model_index.json to remove unsupported scheduler ────────────────
mi_file = os.path.join(model_path, "model_index.json")
if os.path.isfile(mi_file):
    with open(mi_file, "r") as f:
        mi = json.load(f)
    if "pipeline" in mi and "scheduler" in mi["pipeline"]:
        print("🔧 Removing 'scheduler' entry from model_index.json", flush=True)
        mi["pipeline"].pop("scheduler", None)
        with open(mi_file, "w") as f:
            json.dump(mi, f, indent=2)

# ─── 5. Load & filter scheduler_config.json ───────────────────────────────────
sched_cfg_path = os.path.join(model_path, "scheduler", "scheduler_config.json")
filtered_cfg = {}
if os.path.isfile(sched_cfg_path):
    with open(sched_cfg_path, "r") as f:
        raw_cfg = json.load(f)
    # Keep only the keys that the scheduler's __init__ actually accepts
    sig = inspect.signature(DPMSolverMultistepScheduler.__init__)
    valid_keys = set(sig.parameters.keys()) - {"self", "args", "kwargs"}
    filtered_cfg = {k: v for k, v in raw_cfg.items() if k in valid_keys}
    dropped = set(raw_cfg) - set(filtered_cfg)
    if dropped:
        print(f"⚠️ Dropped unsupported scheduler keys: {dropped}", flush=True)
    try:
        scheduler = DPMSolverMultistepScheduler(**filtered_cfg)
        print("✅ Instantiated DPMSolverMultistepScheduler from config", flush=True)
    except Exception as e:
        print(f"❌ Failed to init scheduler from config ({e}), using defaults", flush=True)
        scheduler = DPMSolverMultistepScheduler()
else:
    print("⚠️ No scheduler_config.json found; using default DPMSolverMultistepScheduler", flush=True)
    scheduler = DPMSolverMultistepScheduler()
# ─── 6. Load the Stable Diffusion pipeline ────────────────────────────────────
print(f"🔧 Loading pipeline from: {model_path}", flush=True)
pipe = StableDiffusionPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    scheduler=scheduler,
).to("cuda")

# ─── 7. Apply LoRA adapters ───────────────────────────────────────────────────
print(f"🧠 Applying LoRA config (r={lora_r}, α={lora_alpha})", flush=True)
# Target the UNet attention projections explicitly. A task_type such as
# "CAUSAL_LM" is meant for language models and should not be set when
# wrapping a diffusion UNet with PEFT.
lora_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    bias="none",
)
pipe.unet = get_peft_model(pipe.unet, lora_config)

# ─── 8. Training loop stub ────────────────────────────────────────────────────
print(f"🚀 Starting fine‑tuning for {num_steps} steps (trigger: {trigger_word})", flush=True)
for step in range(num_steps):
    # TODO: replace this stub with your actual training code:
    #   • Load batches from DATA_DIR
    #   • Forward/backward pass, optimizer.step(), etc.
    print(f"🌀 Step {step+1}/{num_steps}", flush=True)

# ─── 9. Save the fine‑tuned model ─────────────────────────────────────────────
print(f"💾 Saving fine‑tuned model to: {OUTPUT_DIR}", flush=True)
pipe.save_pretrained(OUTPUT_DIR)
print("✅ Training complete!", flush=True)
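
# ─── Appendix: illustrative training step ─────────────────────────────────────
# The loop in section 8 is only a stub. A minimal LoRA denoising step could
# look roughly like the commented sketch below. This is a sketch under
# assumptions: `dataloader` and the AdamW learning rate are hypothetical, and
# in practice a DDPMScheduler is commonly used for adding noise during
# training rather than the inference scheduler configured above.
#
#   optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=1e-4)
#   for step, batch in enumerate(dataloader):
#       # Encode images to latents and add noise at a random timestep
#       latents = pipe.vae.encode(
#           batch["pixel_values"].to("cuda", dtype=torch.float16)
#       ).latent_dist.sample() * pipe.vae.config.scaling_factor
#       noise = torch.randn_like(latents)
#       timesteps = torch.randint(
#           0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device="cuda"
#       )
#       noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)
#       # Predict the noise and step the optimizer on the LoRA parameters
#       encoder_hidden_states = pipe.text_encoder(batch["input_ids"].to("cuda"))[0]
#       noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample
#       loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()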