Spaces:
Running
Running
File size: 15,543 Bytes
e9917a9 53f71cb e20a6db 53f71cb 8c5287b e9917a9 50ca5e2 ee5b0c8 8c5287b 96bf0f5 8c5287b b723a99 e8a312d 54694b3 ab317be 804d4e4 487dcba ee5b0c8 54694b3 ee5b0c8 e9917a9 9eb8ba7 3e84455 60c0102 3e84455 60c0102 3e84455 60c0102 a882b42 60c0102 a882b42 60c0102 a882b42 60c0102 a882b42 3e84455 d4bcc30 3e84455 d4bcc30 3e84455 a882b42 d4bcc30 a882b42 3e84455 a882b42 3e84455 9eb8ba7 e20a6db ab317be ee5b0c8 54694b3 6ecc165 ee5b0c8 6ecc165 54694b3 6ecc165 cc6323b ee5b0c8 cc6323b ab317be cc6323b 54694b3 ee5b0c8 ab317be 6ecc165 54694b3 ee5b0c8 e20a6db 54694b3 6ecc165 ee5b0c8 1126c53 e20a6db 1126c53 60c0102 1126c53 60c0102 9eb8ba7 1126c53 487dcba 60c0102 487dcba 60c0102 487dcba 60c0102 487dcba 60c0102 1126c53 804d4e4 60c0102 e20a6db 6ecc165 3e84455 1126c53 6ecc165 487dcba 1126c53 6ecc165 487dcba 60c0102 1126c53 9eb8ba7 e20a6db 3e84455 e20a6db 804d4e4 e20a6db 804d4e4 6ecc165 804d4e4 6ecc165 1126c53 9eb8ba7 e20a6db 804d4e4 e8a312d 54694b3 e20a6db ee5b0c8 1126c53 54694b3 3e84455 54694b3 cc6323b 54694b3 cc6323b 54694b3 cc6323b 54694b3 cc6323b 54694b3 e20a6db ee5b0c8 e20a6db 54694b3 3e84455 e20a6db ee5b0c8 e20a6db 804d4e4 54694b3 804d4e4 e20a6db 804d4e4 a882b42 804d4e4 cc6323b 804d4e4 cc6323b 804d4e4 a1ad959 e8a312d cc6323b e8a312d cc6323b e8a312d e20a6db e8a312d e20a6db ee5b0c8 e8a312d a1ad959 e8a312d 9eb8ba7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 |
import os
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTextConfig
from safetensors.torch import load_file
from collections import OrderedDict
import requests
from urllib.parse import urlparse, unquote
from pathlib import Path
import hashlib
from datetime import datetime
from typing import Dict, List, Optional
from huggingface_hub import login, HfApi, hf_hub_download
from huggingface_hub.utils import validate_repo_id, HFValidationError
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
import subprocess # Import subprocess
from huggingface_hub.utils import HfHubHTTPError
from accelerate import Accelerator
import re # Import the 're' module
# ---------------------- DEPENDENCIES ----------------------
def install_dependencies_gradio():
"""Installs the necessary dependencies."""
try:
subprocess.run(
[
"pip",
"install",
"-U",
"torch",
"diffusers",
"transformers",
"accelerate",
"safetensors",
"huggingface_hub",
"xformers",
],
check=True,
capture_output=True,
text=True
)
print("Dependencies installed successfully.")
except subprocess.CalledProcessError as e:
print(f"Error installing dependencies:\n{e.stderr}")
raise
# ---------------------- UTILITY FUNCTIONS ----------------------
def download_model(model_path_or_url):
"""Downloads a model, handling URLs, HF repos, and local paths."""
try:
# 1. Check if it's a valid Hugging Face repo ID
try:
validate_repo_id(model_path_or_url)
local_path = hf_hub_download(repo_id=model_path_or_url)
return local_path
except HFValidationError:
pass
# 2. Check if it's a URL
if model_path_or_url.startswith("http://") or model_path_or_url.startswith("https://"):
response = requests.get(model_path_or_url, stream=True)
response.raise_for_status()
parsed_url = urlparse(model_path_or_url)
filename = os.path.basename(unquote(parsed_url.path))
if not filename:
filename = hashlib.sha256(model_path_or_url.encode()).hexdigest()
cache_dir = os.path.join(HUGGINGFACE_HUB_CACHE, "downloads")
os.makedirs(cache_dir, exist_ok=True)
local_path = os.path.join(cache_dir, filename)
with open(local_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
return local_path
# 3. Check if it's a local file
elif os.path.isfile(model_path_or_url):
return model_path_or_url
# 4. Handle Hugging Face repo with a specific file
else:
try:
parts = model_path_or_url.split("/", 1)
if len(parts) == 2:
repo_id, filename = parts
validate_repo_id(repo_id)
local_path = hf_hub_download(repo_id=repo_id, filename=filename)
return local_path
else:
raise ValueError("Invalid input format.")
except HFValidationError:
raise ValueError(f"Invalid model path or URL: {model_path_or_url}")
except Exception as e:
raise ValueError(f"Error downloading or accessing model: {e}")
def create_model_repo(api, user, orgs_name, model_name, make_private=False):
"""Creates a Hugging Face model repository, handling missing inputs and sanitizing the username."""
print("---- create_model_repo Called ----")
print(f" user: {user}")
print(f" orgs_name: {orgs_name}")
print(f" model_name: {model_name}")
if not model_name:
model_name = f"converted-model-{datetime.now().strftime('%Y%m%d%H%M%S')}"
print(f" Using default model_name: {model_name}")
# --- Sanitize model_name and orgs_name ---
if orgs_name:
orgs_name = re.sub(r"[^a-zA-Z0-9._-]", "-", orgs_name)
print(f" Sanitized orgs_name: {orgs_name}")
if model_name:
model_name = re.sub(r"[^a-zA-Z0-9._-]", "-", model_name)
print(f" Sanitized model_name: {model_name}")
if orgs_name:
repo_id = f"{orgs_name}/{model_name.strip()}"
elif user:
sanitized_username = re.sub(r"[^a-zA-Z0-9._-]", "-", user['name'])
print(f" Original Username: {user['name']}")
print(f" Sanitized Username: {sanitized_username}")
repo_id = f"{sanitized_username}/{model_name.strip()}"
else:
raise ValueError(
"Must provide either an organization name or be logged in."
)
print(f" repo_id: {repo_id}")
try:
api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
print(f"Model repo '{repo_id}' created.")
return repo_id
except Exception as e:
print(f"Error creating repo: {e}")
raise
def load_sdxl_checkpoint(checkpoint_path):
"""Loads checkpoint and extracts state dicts."""
if checkpoint_path.endswith(".safetensors"):
state_dict = load_file(checkpoint_path, device="cpu")
elif checkpoint_path.endswith(".ckpt"):
state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
else:
raise ValueError("Unsupported checkpoint format. Must be .safetensors or .ckpt")
text_encoder1_state = OrderedDict()
text_encoder2_state = OrderedDict()
vae_state = OrderedDict()
unet_state = OrderedDict()
for key, value in state_dict.items():
if key.startswith("first_stage_model."): # VAE
vae_state[key.replace("first_stage_model.", "")] = value.to(torch.float16)
elif key.startswith("condition_model.model.text_encoder."): # First Text Encoder
text_encoder1_state[key.replace("condition_model.model.text_encoder.", "")] = value.to(torch.float16)
elif key.startswith("condition_model.model.text_encoder_2."): # Second Text Encoder
text_encoder2_state[key.replace("condition_model.model.text_encoder_2.", "")] = value.to(torch.float16)
elif key.startswith("model.diffusion_model."): # UNet
unet_state[key.replace("model.diffusion_model.", "")] = value.to(torch.float16)
return text_encoder1_state, text_encoder2_state, vae_state, unet_state
def build_diffusers_model(text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path=None):
"""Builds Diffusers components using accelerate for low-memory loading."""
if not reference_model_path:
reference_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
# Initialize the Accelerator
accelerator = Accelerator(mixed_precision="fp16") # Use mixed precision
device = accelerator.device
# Load configurations from the reference model
config_text_encoder1 = CLIPTextConfig.from_pretrained(
reference_model_path, subfolder="text_encoder"
)
config_text_encoder2 = CLIPTextConfig.from_pretrained(
reference_model_path, subfolder="text_encoder_2"
)
# Use from_pretrained with device_map and low_cpu_mem_usage for all components
text_encoder1 = CLIPTextModel.from_pretrained(reference_model_path, subfolder="text_encoder", config=config_text_encoder1, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(device)
text_encoder2 = CLIPTextModelWithProjection.from_pretrained(reference_model_path, subfolder="text_encoder_2", config=config_text_encoder2, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(reference_model_path, subfolder="vae", low_cpu_mem_usage=True, torch_dtype=torch.float16).to(device)
unet = UNet2DConditionModel.from_pretrained(reference_model_path, subfolder="unet", low_cpu_mem_usage=True, torch_dtype=torch.float16).to(device)
# Load state dicts with strict=False
text_encoder1.load_state_dict(text_encoder1_state, strict=False)
text_encoder2.load_state_dict(text_encoder2_state, strict=False)
vae.load_state_dict(vae_state, strict=False)
unet.load_state_dict(unet_state, strict=False)
return text_encoder1, text_encoder2, vae, unet
def convert_and_save_sdxl_to_diffusers(checkpoint_path_or_url, output_path, reference_model_path):
"""Converts and saves the checkpoint to Diffusers format."""
checkpoint_path = download_model(checkpoint_path_or_url)
text_encoder1_state, text_encoder2_state, vae_state, unet_state = load_sdxl_checkpoint(checkpoint_path)
text_encoder1, text_encoder2, vae, unet = build_diffusers_model(
text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path
)
# Load tokenizer and scheduler from the reference model
pipeline = StableDiffusionXLPipeline.from_pretrained(
reference_model_path,
text_encoder=text_encoder1,
text_encoder_2=text_encoder2,
vae=vae,
unet=unet,
torch_dtype=torch.float16,
)
pipeline.save_pretrained(output_path)
print(f"Model saved as Diffusers format: {output_path}")
# ---------------------- MAIN FUNCTION (with Debugging Prints) ----------------------
def main(
model_to_load,
reference_model,
output_path,
hf_token,
orgs_name,
model_name,
make_private,
):
"""Main function: SDXL checkpoint to Diffusers, always fp16."""
print("---- Main Function Called ----")
print(f" model_to_load: {model_to_load}")
print(f" reference_model: {reference_model}")
print(f" output_path: {output_path}")
print(f" hf_token: {hf_token}")
print(f" orgs_name: {orgs_name}")
print(f" model_name: {model_name}")
print(f" make_private: {make_private}")
# --- Force Login at the Beginning of main() ---
try:
login(token=hf_token, add_to_git_credential=True)
api = HfApi()
user = api.whoami() # Get logged-in user info
print(f" Logged-in user: {user}")
except Exception as e:
error_message = f"Error during login: {e} Ensure a valid WRITE token is provided."
print(f"---- Main Function Error: {error_message} ----")
return error_message
# --- Strip Whitespace and Sanitize from Inputs ---
model_to_load = model_to_load.strip()
reference_model = reference_model.strip()
output_path = output_path.strip()
hf_token = hf_token.strip() # Even though it's a password field
orgs_name = orgs_name.strip() if orgs_name else ""
model_name = model_name.strip() if model_name else ""
# --- Sanitize model_name and orgs_name ---
if orgs_name:
orgs_name = re.sub(r"[^a-zA-Z0-9._-]", "-", orgs_name)
if model_name:
model_name = re.sub(r"[^a-zA-Z0-9._-]", "-", model_name)
try:
convert_and_save_sdxl_to_diffusers(model_to_load, output_path, reference_model)
# --- Create Repo and Upload (Simplified) ---
if not model_name:
model_name = f"converted-model-{datetime.now().strftime('%Y%m%d%H%M%S')}"
print(f"Using default model_name: {model_name}")
if orgs_name:
repo_id = f"{orgs_name}/{model_name}"
elif user:
# Sanitize username here as well:
sanitized_username = re.sub(r"[^a-zA-Z0-9._-]", "-", user['name'])
print(f" Sanitized Username: {sanitized_username}")
repo_id = f"{sanitized_username}/{model_name}"
else: # Should never happen because of login, but good practice
raise ValueError("Must provide either an organization name or be logged in.")
print(f"repo_id = {repo_id}")
try:
api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
print(f"Model repo '{repo_id}' created.")
except Exception as e:
print(f"Error in creating model repo: {e}")
raise
api.upload_folder(folder_path=output_path, repo_id=repo_id)
print(f"Model uploaded to: https://huggingface.co/{repo_id}")
result = "Conversion and upload completed successfully!"
print(f"---- Main Function Successful: {result} ----")
return result
except Exception as e:
error_message = f"An error occurred: {e}"
print(f"---- Main Function Error: {error_message} ----")
return error_message
# ---------------------- GRADIO INTERFACE ----------------------
css = """
#main-container {
display: flex;
flex-direction: column;
font-family: 'Arial', sans-serif;
font-size: 16px;
color: #333;
}
#convert-button {
margin-top: 1em;
}
"""
with gr.Blocks(css=css) as demo:
gr.Markdown(
"""
# π¨ SDXL Model Converter
Convert SDXL checkpoints to Diffusers format (FP16, CPU-only).
### π₯ Input Sources Supported:
- Local model files (.safetensors, .ckpt)
- Direct URLs to model files
- Hugging Face model repositories (e.g., 'my-org/my-model' or 'my-org/my-model/file.safetensors')
### βΉοΈ Important Notes:
- This tool runs on **CPU**, conversion might be slower than on GPU.
- For Hugging Face uploads, you need a **WRITE** token (not a read token).
- Get your HF token here: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
### πΎ Memory Usage:
- This space is configured for **FP16** precision to reduce memory usage.
- Close other applications during conversion.
- For large models, ensure you have at least 16GB of RAM.
### π» Source Code:
- [GitHub Repository](https://github.com/Ktiseos-Nyx/Gradio-SDXL-Diffusers)
### π Support:
- If you're interested in funding more projects: [Ko-fi](https://ko-fi.com/duskfallcrew)
"""
)
with gr.Row():
with gr.Column():
model_to_load = gr.Textbox(
label="SDXL Checkpoint (Path, URL, or HF Repo)",
placeholder="Path, URL, or Hugging Face Repo ID (e.g., my-org/my-model or my-org/my-model/file.safetensors)",
)
reference_model = gr.Textbox(
label="Reference Diffusers Model (Optional)",
placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0 (Leave blank for default)",
)
output_path = gr.Textbox(label="Output Path (Diffusers Format)", value="output")
hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Your Hugging Face write token", type="password")
orgs_name = gr.Textbox(label="Organization Name (Optional)", placeholder="Your organization name")
model_name = gr.Textbox(label="Model Name", placeholder="The name of your model on Hugging Face")
make_private = gr.Checkbox(label="Make Repository Private", value=False)
convert_button = gr.Button("Convert and Upload")
with gr.Column(variant="panel"):
output = gr.Markdown(container=True)
convert_button.click(
fn=main,
inputs=[
model_to_load,
reference_model,
output_path,
hf_token,
orgs_name,
model_name,
make_private,
],
outputs=output,
)
demo.launch() |