# HDM-demo / app.py
import os
import random
import json
from pathlib import Path
import gradio as gr
import httpx
if os.environ.get("IN_SPACES", None) is not None:
    in_spaces = True
    import spaces

    os.system("pip install git+https://${GIT_USER}:${GIT_TOKEN}@github.com/KohakuBlueleaf/XUT")
else:
    in_spaces = False
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from PIL import Image
from tqdm import trange
try:
    # pre-importing triton avoids import errors from diffusers/transformers
    import triton
except ImportError:
    print("Triton not found, skipping pre-import")
torch.set_float32_matmul_precision("high")
## HDM model dependencies
import xut.env
xut.env.TORCH_COMPILE = False
xut.env.USE_LIGER = True
xut.env.USE_XFORMERS = False
xut.env.USE_XFORMERS_LAYERS = False
from xut.xut import XUDiT
from transformers import Qwen3Model, Qwen2Tokenizer
from diffusers import AutoencoderKL
## TIPO
import kgen.models as kgen_models
import kgen.executor.tipo as tipo
from kgen.formatter import apply_format, seperate_tags
DEFAULT_FORMAT = """
<|special|>,
<|characters|>, <|copyrights|>,
<|artist|>,
<|quality|>, <|meta|>, <|rating|>,
<|general|>,
<|extended|>.
""".strip()
def GPU(func, duration=None):
    if in_spaces:
        return spaces.GPU(func, duration=duration)
    else:
        return func
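
# Generic streaming downloader kept as a utility; the demo below fetches its
# checkpoint through `hfutils` instead of calling this directly.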
def download_model(url: str, filepath: str):
    """Minimal fast download function"""
    if Path(filepath).exists():
        print(f"Model already exists at {filepath}")
        return
    print("Downloading model...")
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)
    with httpx.stream("GET", url, follow_redirects=True) as response:
        response.raise_for_status()
        with open(filepath, "wb") as f:
            for chunk in response.iter_bytes(chunk_size=128 * 1024):
                f.write(chunk)
    print(f"Download completed: {filepath}")
def prompt_opt(tags, nl_prompt, aspect_ratio, seed):
    meta, operations, general, nl_prompt = tipo.parse_tipo_request(
        seperate_tags(tags.split(",")),
        nl_prompt,
        tag_length_target="long",
        nl_length_target="short",
        generate_extra_nl_prompt=True,
    )
    meta["aspect_ratio"] = f"{aspect_ratio:.3f}"
    result, timing = tipo.tipo_runner(meta, operations, general, nl_prompt, seed=seed)
    return apply_format(result, DEFAULT_FORMAT).strip().strip(".").strip(",")
# --- User's core functions (copied directly) ---
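# Encodes the positive/negative prompts once with the Qwen3 text encoder, pads
# them to a common length, and returns a closure that runs the model on a
# doubled batch to apply classifier-free guidance.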
def cfg_wrapper(
    prompt: str | list[str],
    neg_prompt: str | list[str],
    unet: nn.Module,  # should be k_diffusion wrapper
    te: Qwen3Model,
    tokenizer: Qwen2Tokenizer,
    cfg_scale: float = 3.0,
):
    prompt_token = {
        k: v.to(device)
        for k, v in tokenizer(
            prompt,
            padding="longest",
            return_tensors="pt",
        ).items()
    }
    neg_prompt_token = {
        k: v.to(device)
        for k, v in tokenizer(
            neg_prompt,
            padding="longest",
            return_tensors="pt",
        ).items()
    }
    emb = te(**prompt_token).last_hidden_state
    neg_emb = te(**neg_prompt_token).last_hidden_state
    if emb.size(1) > neg_emb.size(1):
        pad_setting = (0, 0, 0, emb.size(1) - neg_emb.size(1))
        neg_emb = F.pad(neg_emb, pad_setting)
    if neg_emb.size(1) > emb.size(1):
        pad_setting = (0, 0, 0, neg_emb.size(1) - emb.size(1))
        emb = F.pad(emb, pad_setting)
    text_ctx_emb = torch.concat([emb, neg_emb])

    def cfg_fn(x, t, cfg=cfg_scale):
        cond, uncond = unet(
            x.repeat(2, 1, 1, 1),
            t.expand(x.size(0) * 2),
            text_ctx_emb,
        ).chunk(2)
        cond = cond.float()
        uncond = uncond.float()
        return uncond + (cond - uncond) * cfg

    return cfg_fn
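
# --- Model loading: HDM (XUDiT) backbone, Qwen3 text encoder, EQ-SDXL VAE ---
# The models are loaded in fp16 at import time so the app starts ready to serve.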
print("Loading models, please wait...")
device = torch.device("cuda")
print("Using device:", torch.cuda.get_device_name(device))
model = XUDiT(
    **json.load(open("./config/xut-small-1024-tread.json", "r"))
).half().requires_grad_(False).eval().to(device)
tokenizer = Qwen2Tokenizer.from_pretrained(
    "Qwen/Qwen3-0.6B",
)
te = Qwen3Model.from_pretrained(
    "Qwen/Qwen3-0.6B",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).half().eval().requires_grad_(False).to(device)
vae = AutoencoderKL.from_pretrained(
    "KBlueLeaf/EQ-SDXL-VAE"
).half().eval().requires_grad_(False).to(device)
vae_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1).to(device)
vae_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1).to(device)
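
# Fetch the HDM checkpoint via hfutils if it is not already cached (the exact
# file name comes from the MODEL_FILE environment variable), then load the
# UNet weights out of the combined state dict.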
if not os.path.exists("./model/model.safetensors"):
    model_file = os.environ.get("MODEL_FILE")
    os.system(f"hfutils download -t model -r KBlueLeaf/XUT-demo -f {model_file} -o model/model.safetensors")
state_dict = load_file("./model/model.safetensors")
model_sd = {k.replace("unet.", ""): v for k, v in state_dict.items() if k.startswith("unet.")}
model_sd = {k.replace("model.", ""): v for k, v in model_sd.items()}
missing, unexpected = model.load_state_dict(model_sd, strict=False)
if missing:
    print(f"Missing keys: {missing}")
if unexpected:
    print(f"Unexpected keys: {unexpected}")
tipo_model_name, gguf_list = kgen_models.tipo_model_list[0]
kgen_models.download_gguf(
    tipo_model_name,
    gguf_list[-1],
)
kgen_models.load_model(
    f"{tipo_model_name}_{gguf_list[-1]}", gguf=True, device="cpu"
)
print("Models loaded successfully. UI is ready.")
@GPU
@torch.no_grad()
def generate(
    nl_prompt: str,
    tag_prompt: str,
    negative_prompt: str,
    num_images: int,
    steps: int,
    cfg_scale: float,
    size: int,
    aspect_ratio: str,
    fixed_short_edge: bool,
    seed: int,
    progress=gr.Progress(),
):
    as_w, as_h = aspect_ratio.split(":")
    aspect_ratio = float(as_w) / float(as_h)
    # Set seed for reproducibility
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    torch.manual_seed(seed)
    # TIPO
    tipo.BAN_TAGS = [i.strip() for i in negative_prompt.split(",") if i.strip()]
    final_prompt = prompt_opt(tag_prompt, nl_prompt, aspect_ratio, seed)
    yield None, final_prompt
    all_pil_images = []
    prompts_to_generate = [final_prompt.replace("\n", " ")] * num_images
    negative_prompts_to_generate = [negative_prompt] * num_images
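    # Derive the latent width/height from the base size and aspect ratio. With
    # fixed_short_edge the short side stays at `size`; otherwise the total area
    # is kept roughly constant. The int(size * factor / 16) * 2 rounding makes
    # the pixel resolution (8x the latent size) a multiple of 16.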
    if fixed_short_edge:
        if aspect_ratio > 1:
            h_factor = 1
            w_factor = aspect_ratio
        else:
            h_factor = 1 / aspect_ratio
            w_factor = 1
    else:
        w_factor = aspect_ratio**0.5
        h_factor = 1 / w_factor
    w = int(size * w_factor / 16) * 2
    h = int(size * h_factor / 16) * 2
    print("=" * 100)
    print(
        f"Generating {num_images} image(s) with seed: {seed} and resolution {w*8}x{h*8}"
    )
    print("-" * 80)
    print(f"Final prompt: {final_prompt}")
    print("-" * 80)
    print(f"Negative prompt: {negative_prompt}")
    print("-" * 80)
    prompts_batch = prompts_to_generate
    neg_prompts_batch = negative_prompts_to_generate
    # Core logic from the original script
    cfg_fn = cfg_wrapper(
        prompts_batch,
        neg_prompts_batch,
        unet=model,
        te=te,
        tokenizer=tokenizer,
        cfg_scale=cfg_scale,
    )
    xt = torch.randn(num_images, 4, h, w).to(device)
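    # Euler integration from t=1 down to t=0: each step subtracts dt times the
    # model prediction (a flow-matching-style velocity), with the forward pass
    # run under fp16 autocast.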
    t = 1.0
    dt = 1.0 / steps
    with trange(steps, desc="Generating Steps", smoothing=0.05) as cli_prog_bar:
        for step in progress.tqdm(list(range(steps)), desc="Generating Steps"):
            with torch.autocast(device.type, dtype=torch.float16):
                model_pred = cfg_fn(xt, torch.tensor(t, device=device))
            xt = xt - dt * model_pred.float()
            t -= dt
            cli_prog_bar.update(1)
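    # Un-normalize each latent with the VAE's recorded mean/std, decode it with
    # the fp16 VAE one sample at a time, and collect the results on the CPU.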
    generated_latents = xt.float()
    image_tensors = torch.concat(
        [
            vae.decode(
                (generated_latent[None] * vae_std + vae_mean).half()
            ).sample.cpu()
            for generated_latent in generated_latents
        ]
    )
    # Convert tensors to PIL images
    for image_tensor in image_tensors:
        image = Image.fromarray(
            ((image_tensor * 0.5 + 0.5) * 255)
            .clamp(0, 255)
            .numpy()
            .astype(np.uint8)
            .transpose(1, 2, 0)
        )
        all_pil_images.append(image)
    yield all_pil_images, final_prompt
# --- Gradio UI Definition ---
with gr.Blocks(title="HDM Demo", theme=gr.themes.Soft()) as demo:
gr.Markdown("# HomeDiffusion Gradio UI")
gr.Markdown(
"### Enter a natural language prompt and/or specific tags to generate an image."
)
with gr.Row():
with gr.Column(scale=2):
nl_prompt_box = gr.Textbox(
label="Natural Language Prompt",
placeholder="e.g., A beautiful anime girl standing in a blooming cherry blossom forest",
lines=3,
)
tag_prompt_box = gr.Textbox(
label="Tag Prompt (comma-separated)",
placeholder="e.g., 1girl, solo, long hair, cherry blossoms, school uniform",
lines=3,
)
neg_prompt_box = gr.Textbox(
label="Negative Prompt",
value=(
"low quality, worst quality, "
"jpeg artifacts, bad anatomy, old, early, "
"copyright name, watermark"
),
lines=3,
)
with gr.Column(scale=1):
with gr.Row():
num_images_slider = gr.Slider(
label="Number of Images", minimum=1, maximum=16, value=1, step=1
)
steps_slider = gr.Slider(
label="Inference Steps", minimum=1, maximum=50, value=32, step=1
)
with gr.Row():
cfg_slider = gr.Slider(
label="CFG Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1
)
seed_input = gr.Number(
label="Seed",
value=-1,
precision=0,
info="Set to -1 for a random seed.",
)
with gr.Row():
size_slider = gr.Slider(
label="Base Image Size",
minimum=384,
maximum=768,
value=512,
step=64,
)
with gr.Row():
aspect_ratio_box = gr.Textbox(
label="Ratio (W:H)",
value="1:1",
)
fixed_short_edge = gr.Checkbox(
label="Fixed Edge",
value=True,
)
generate_button = gr.Button("Generate", variant="primary")
with gr.Row():
with gr.Column(scale=1):
output_prompt = gr.TextArea(
label="TIPO Generated Prompt",
show_label=True,
interactive=False,
lines=32,
max_lines=32,
)
with gr.Column(scale=2):
output_gallery = gr.Gallery(
label="Generated Images",
show_label=True,
elem_id="gallery",
columns=4,
rows=3,
height="800px",
)
gr.Markdown("Images are also saved to the `inference_output/` folder.")
    generate_button.click(
        fn=generate,
        inputs=[
            nl_prompt_box,
            tag_prompt_box,
            neg_prompt_box,
            num_images_slider,
            steps_slider,
            cfg_slider,
            size_slider,
            aspect_ratio_box,
            fixed_short_edge,
            seed_input,
        ],
        outputs=[output_gallery, output_prompt],
        show_progress_on=output_gallery,
    )
if __name__ == "__main__":
    demo.launch()