import os
import random
import json
from pathlib import Path

import gradio as gr
import httpx

if os.environ.get("IN_SPACES", None) is not None:
    in_spaces = True
    import spaces

    os.system("pip install git+https://${GIT_USER}:${GIT_TOKEN}@github.com/KohakuBlueleaf/XUT")
else:
    in_spaces = False

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from PIL import Image
from tqdm import trange

try:
    # Pre-importing triton avoids import errors from diffusers/transformers
    import triton
except ImportError:
    print("Triton not found, skipping pre-import")

torch.set_float32_matmul_precision("high")

## HDM model dep
import xut.env

xut.env.TORCH_COMPILE = False
xut.env.USE_LIGER = True
xut.env.USE_XFORMERS = False
xut.env.USE_XFORMERS_LAYERS = False

from xut.xut import XUDiT
from transformers import Qwen3Model, Qwen2Tokenizer
from diffusers import AutoencoderKL

## TIPO
import kgen.models as kgen_models
import kgen.executor.tipo as tipo
from kgen.formatter import apply_format, seperate_tags


DEFAULT_FORMAT = """
<|special|>, <|characters|>, <|copyrights|>,
<|artist|>, <|quality|>, <|meta|>, <|rating|>, <|general|>,
<|extended|>.
""".strip()


def GPU(func, duration=None):
    # Wrap with spaces.GPU only when running on HF Spaces
    if in_spaces:
        return spaces.GPU(func, duration=duration)
    else:
        return func


def download_model(url: str, filepath: str):
    """Minimal fast download function"""
    if Path(filepath).exists():
        print(f"Model already exists at {filepath}")
        return

    print("Downloading model...")
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)
    with httpx.stream("GET", url, follow_redirects=True) as response:
        response.raise_for_status()
        with open(filepath, "wb") as f:
            for chunk in response.iter_bytes(chunk_size=128 * 1024):
                f.write(chunk)
    print(f"Download completed: {filepath}")


def prompt_opt(tags, nl_prompt, aspect_ratio, seed):
    meta, operations, general, nl_prompt = tipo.parse_tipo_request(
        seperate_tags(tags.split(",")),
        nl_prompt,
        tag_length_target="long",
        nl_length_target="short",
        generate_extra_nl_prompt=True,
    )
    meta["aspect_ratio"] = f"{aspect_ratio:.3f}"
    result, timing = tipo.tipo_runner(meta, operations, general, nl_prompt, seed=seed)
    return apply_format(result, DEFAULT_FORMAT).strip().strip(".").strip(",")


# --- User's core functions (copied directly) ---
def cfg_wrapper(
    prompt: str | list[str],
    neg_prompt: str | list[str],
    unet: nn.Module,  # should be k_diffusion wrapper
    te: Qwen3Model,
    tokenizer: Qwen2Tokenizer,
    cfg_scale: float = 3.0,
):
    prompt_token = {
        k: v.to(device)
        for k, v in tokenizer(
            prompt,
            padding="longest",
            return_tensors="pt",
        ).items()
    }
    neg_prompt_token = {
        k: v.to(device)
        for k, v in tokenizer(
            neg_prompt,
            padding="longest",
            return_tensors="pt",
        ).items()
    }
    emb = te(**prompt_token).last_hidden_state
    neg_emb = te(**neg_prompt_token).last_hidden_state
    # Pad the shorter sequence so positive/negative embeddings can be batched together
    if emb.size(1) > neg_emb.size(1):
        pad_setting = (0, 0, 0, emb.size(1) - neg_emb.size(1))
        neg_emb = F.pad(neg_emb, pad_setting)
    if neg_emb.size(1) > emb.size(1):
        pad_setting = (0, 0, 0, neg_emb.size(1) - emb.size(1))
        emb = F.pad(emb, pad_setting)
    text_ctx_emb = torch.concat([emb, neg_emb])

    def cfg_fn(x, t, cfg=cfg_scale):
        cond, uncond = unet(
            x.repeat(2, 1, 1, 1),
            t.expand(x.size(0) * 2),
            text_ctx_emb,
        ).chunk(2)
        cond = cond.float()
        uncond = uncond.float()
        return uncond + (cond - uncond) * cfg

    return cfg_fn


print("Loading models, please wait...")
device = torch.device("cuda")
print("Using device:", torch.cuda.get_device_name(device))

model = XUDiT(
    **json.load(open("./config/xut-small-1024-tread.json", "r"))
).half().requires_grad_(False).eval().to(device)
tokenizer = Qwen2Tokenizer.from_pretrained(
    "Qwen/Qwen3-0.6B",
)
te = Qwen3Model.from_pretrained(
    "Qwen/Qwen3-0.6B", torch_dtype=torch.float16, attn_implementation="sdpa"
).half().eval().requires_grad_(False).to(device)
vae = AutoencoderKL.from_pretrained(
    "KBlueLeaf/EQ-SDXL-VAE"
).half().eval().requires_grad_(False).to(device)
vae_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1).to(device)
vae_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1).to(device)

if not os.path.exists("./model/model.safetensors"):
    model_file = os.environ.get("MODEL_FILE")
    os.system(
        f"hfutils download -t model -r KBlueLeaf/XUT-demo -f {model_file} -o model/model.safetensors"
    )
state_dict = load_file("./model/model.safetensors")
model_sd = {
    k.replace("unet.", ""): v for k, v in state_dict.items() if k.startswith("unet.")
}
model_sd = {k.replace("model.", ""): v for k, v in model_sd.items()}
missing, unexpected = model.load_state_dict(model_sd, strict=False)
if missing:
    print(f"Missing keys: {missing}")
if unexpected:
    print(f"Unexpected keys: {unexpected}")

tipo_model_name, gguf_list = kgen_models.tipo_model_list[0]
kgen_models.download_gguf(
    tipo_model_name,
    gguf_list[-1],
)
kgen_models.load_model(
    f"{tipo_model_name}_{gguf_list[-1]}", gguf=True, device="cpu"
)
print("Models loaded successfully. UI is ready.")


@GPU
@torch.no_grad()
def generate(
    nl_prompt: str,
    tag_prompt: str,
    negative_prompt: str,
    num_images: int,
    steps: int,
    cfg_scale: float,
    size: int,
    aspect_ratio: str,
    fixed_short_edge: bool,
    seed: int,
    progress=gr.Progress(),
):
    as_w, as_h = aspect_ratio.split(":")
    aspect_ratio = float(as_w) / float(as_h)

    # Set seed for reproducibility
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    torch.manual_seed(seed)

    # TIPO prompt expansion
    tipo.BAN_TAGS = [i.strip() for i in negative_prompt.split(",") if i.strip()]
    final_prompt = prompt_opt(tag_prompt, nl_prompt, aspect_ratio, seed)
    yield None, final_prompt

    all_pil_images = []
    prompts_to_generate = [final_prompt.replace("\n", " ")] * num_images
    negative_prompts_to_generate = [negative_prompt] * num_images

    if fixed_short_edge:
        if aspect_ratio > 1:
            h_factor = 1
            w_factor = aspect_ratio
        else:
            h_factor = 1 / aspect_ratio
            w_factor = 1
    else:
        w_factor = aspect_ratio**0.5
        h_factor = 1 / w_factor
    # Latent-space size; the VAE decodes at 8x, hence the w*8 x h*8 resolution below
    w = int(size * w_factor / 16) * 2
    h = int(size * h_factor / 16) * 2

    print("=" * 100)
    print(
        f"Generating {num_images} image(s) with seed: {seed} and resolution {w*8}x{h*8}"
    )
    print("-" * 80)
    print(f"Final prompt: {final_prompt}")
    print("-" * 80)
    print(f"Negative prompt: {negative_prompt}")
    print("-" * 80)

    prompts_batch = prompts_to_generate
    neg_prompts_batch = negative_prompts_to_generate

    # Core logic from the original script
    cfg_fn = cfg_wrapper(
        prompts_batch,
        neg_prompts_batch,
        unet=model,
        te=te,
        tokenizer=tokenizer,
        cfg_scale=cfg_scale,
    )
    xt = torch.randn(num_images, 4, h, w).to(device)

    # Simple Euler sampler: integrate from t=1 (pure noise) down to t=0
    t = 1.0
    dt = 1.0 / steps
    with trange(steps, desc="Generating Steps", smoothing=0.05) as cli_prog_bar:
        for step in progress.tqdm(list(range(steps)), desc="Generating Steps"):
            with torch.autocast(device.type, dtype=torch.float16):
                model_pred = cfg_fn(xt, torch.tensor(t, device=device))
            xt = xt - dt * model_pred.float()
            t -= dt
            cli_prog_bar.update(1)

    generated_latents = xt.float()
    # Decode latents one at a time, denormalized with the VAE's latent mean/std
    image_tensors = torch.concat(
        [
            vae.decode(
                (generated_latent[None] * vae_std + vae_mean).half()
            ).sample.cpu()
            for generated_latent in generated_latents
        ]
    )

    # Convert tensors to PIL images
    for image_tensor in image_tensors:
        image = Image.fromarray(
            ((image_tensor * 0.5 + 0.5) * 255)
            .clamp(0, 255)
            .numpy()
            .astype(np.uint8)
            .transpose(1, 2, 0)
        )
        all_pil_images.append(image)
    yield all_pil_images, final_prompt


# --- Gradio UI Definition ---
with gr.Blocks(title="HDM Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# HomeDiffusion Gradio UI")
    gr.Markdown(
        "### Enter a natural language prompt and/or specific tags to generate an image."
    )
    with gr.Row():
        with gr.Column(scale=2):
            nl_prompt_box = gr.Textbox(
                label="Natural Language Prompt",
                placeholder="e.g., A beautiful anime girl standing in a blooming cherry blossom forest",
                lines=3,
            )
            tag_prompt_box = gr.Textbox(
                label="Tag Prompt (comma-separated)",
                placeholder="e.g., 1girl, solo, long hair, cherry blossoms, school uniform",
                lines=3,
            )
            neg_prompt_box = gr.Textbox(
                label="Negative Prompt",
                value=(
                    "low quality, worst quality, "
                    "jpeg artifacts, bad anatomy, old, early, "
                    "copyright name, watermark"
                ),
                lines=3,
            )
        with gr.Column(scale=1):
            with gr.Row():
                num_images_slider = gr.Slider(
                    label="Number of Images", minimum=1, maximum=16, value=1, step=1
                )
                steps_slider = gr.Slider(
                    label="Inference Steps", minimum=1, maximum=50, value=32, step=1
                )
            with gr.Row():
                cfg_slider = gr.Slider(
                    label="CFG Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1
                )
                seed_input = gr.Number(
                    label="Seed",
                    value=-1,
                    precision=0,
                    info="Set to -1 for a random seed.",
                )
            with gr.Row():
                size_slider = gr.Slider(
                    label="Base Image Size",
                    minimum=384,
                    maximum=768,
                    value=512,
                    step=64,
                )
            with gr.Row():
                aspect_ratio_box = gr.Textbox(
                    label="Ratio (W:H)",
                    value="1:1",
                )
                fixed_short_edge = gr.Checkbox(
                    label="Fixed Edge",
                    value=True,
                )
            generate_button = gr.Button("Generate", variant="primary")

    with gr.Row():
        with gr.Column(scale=1):
            output_prompt = gr.TextArea(
                label="TIPO Generated Prompt",
                show_label=True,
                interactive=False,
                lines=32,
                max_lines=32,
            )
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images",
                show_label=True,
                elem_id="gallery",
                columns=4,
                rows=3,
                height="800px",
            )
    gr.Markdown("Images are also saved to the `inference_output/` folder.")

    generate_button.click(
        fn=generate,
        inputs=[
            nl_prompt_box,
            tag_prompt_box,
            neg_prompt_box,
            num_images_slider,
            steps_slider,
            cfg_slider,
            size_slider,
            aspect_ratio_box,
            fixed_short_edge,
            seed_input,
        ],
        outputs=[output_gallery, output_prompt],
        show_progress_on=output_gallery,
    )

if __name__ == "__main__":
    demo.launch()