import os
import random
import json
from pathlib import Path
from functools import partial

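# On Hugging Face Spaces (signalled by the IN_SPACES environment variable) the
# XUT package is installed at startup from a private GitHub repo, using the
# GIT_USER/GIT_TOKEN credentials provided via the environment; locally it is
# assumed to already be installed.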
if os.environ.get("IN_SPACES", None) is not None:
    in_spaces = True
    import spaces

    os.system(
        "pip install git+https://${GIT_USER}:${GIT_TOKEN}@github.com/KohakuBlueleaf/XUT"
    )
else:
    in_spaces = False

import gradio as gr
import httpx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from PIL import Image
from tqdm import trange

try:
    import triton
except ImportError:
    print("Triton not found, skipping pre-import")

torch.set_float32_matmul_precision("high")


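# Configure the XUT backend flags before importing the model class:
# torch.compile and the xformers paths are disabled, Liger kernels are enabled.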
import xut.env

xut.env.TORCH_COMPILE = False
xut.env.USE_LIGER = True
xut.env.USE_XFORMERS = False
xut.env.USE_XFORMERS_LAYERS = False
from xut.xut import XUDiT
from transformers import Qwen3Model, Qwen2Tokenizer
from diffusers import AutoencoderKL


import kgen.models as kgen_models
import kgen.executor.tipo as tipo
from kgen.formatter import apply_format, seperate_tags


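# Prompt template consumed by kgen's apply_format: each <|category|> placeholder
# is replaced by the corresponding group of tags produced by TIPO.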
DEFAULT_FORMAT = """
<|special|>,
<|characters|>, <|copyrights|>,
<|artist|>,
<|quality|>, <|meta|>, <|rating|>,

<|general|>,

<|extended|>.
""".strip()


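# No-op decorator when running locally; on Spaces it wraps the function with
# spaces.GPU so a GPU is attached for the call (optionally with a per-call
# duration, in seconds). Works both bare (@GPU) and parameterised
# (@GPU(duration=5)).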
def GPU(func=None, duration=None):
    if func is None:
        return partial(GPU, duration=duration)
    if in_spaces:
        if duration:
            return spaces.GPU(func, duration=duration)
        else:
            return spaces.GPU(func)
    else:
        return func


def download_model(url: str, filepath: str):
    """Minimal fast download function"""
    if Path(filepath).exists():
        print(f"Model already exists at {filepath}")
        return

    print("Downloading model...")
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)

    with httpx.stream("GET", url, follow_redirects=True) as response:
        response.raise_for_status()
        with open(filepath, "wb") as f:
            for chunk in response.iter_bytes(chunk_size=128 * 1024):
                f.write(chunk)
    print(f"Download completed: {filepath}")


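# TIPO prompt optimisation: parse the tag and natural-language prompts into a
# request, run the TIPO model to extend them, then render the result with
# DEFAULT_FORMAT and strip trailing punctuation.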
def prompt_opt(tags, nl_prompt, aspect_ratio, seed):
    meta, operations, general, nl_prompt = tipo.parse_tipo_request(
        seperate_tags(tags.split(",")),
        nl_prompt,
        tag_length_target="long",
        nl_length_target="short",
        generate_extra_nl_prompt=True,
    )
    meta["aspect_ratio"] = f"{aspect_ratio:.3f}"
    result, timing = tipo.tipo_runner(meta, operations, general, nl_prompt, seed=seed)
    return apply_format(result, DEFAULT_FORMAT).strip().strip(".").strip(",")


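# Classifier-free guidance helper. Both prompt batches are encoded once with the
# Qwen3 text encoder; the returned closure then runs the diffusion backbone on
# the conditional and unconditional embeddings and blends the predictions as
#   pred = uncond + (cond - uncond) * cfg_scale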
def cfg_wrapper(
    prompt: str | list[str],
    neg_prompt: str | list[str],
    unet: nn.Module,
    te: Qwen3Model,
    tokenizer: Qwen2Tokenizer,
    cfg_scale: float = 3.0,
):
    prompt_token = {
        k: v.to(device)
        for k, v in tokenizer(
            prompt,
            padding="longest",
            return_tensors="pt",
        ).items()
    }
    neg_prompt_token = {
        k: v.to(device)
        for k, v in tokenizer(
            neg_prompt,
            padding="longest",
            return_tensors="pt",
        ).items()
    }

    emb = te(**prompt_token).last_hidden_state
    neg_emb = te(**neg_prompt_token).last_hidden_state

    def cfg_fn(x, t, cfg=cfg_scale):
        cond = unet(x, t.expand(x.size(0)), emb).float()
        uncond = unet(x, t.expand(x.size(0)), neg_emb).float()
        return uncond + (cond - uncond) * cfg

    return cfg_fn


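# Load all inference components in fp16: the XUDiT backbone (config taken from
# ./config/xut-small-1024-tread.json), the Qwen3-0.6B tokenizer and text
# encoder, and the EQ-SDXL-VAE together with the per-channel latent mean/std
# used to de-normalise latents before decoding.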
print("Loading models, please wait...") |
|
device = torch.device("cuda") |
|
|
|
model = ( |
|
XUDiT(**json.load(open("./config/xut-small-1024-tread.json", "r"))) |
|
.half() |
|
.requires_grad_(False) |
|
.eval() |
|
.to(device) |
|
) |
|
tokenizer = Qwen2Tokenizer.from_pretrained( |
|
"Qwen/Qwen3-0.6B", |
|
) |
|
te = ( |
|
Qwen3Model.from_pretrained( |
|
"Qwen/Qwen3-0.6B", torch_dtype=torch.float16, attn_implementation="sdpa" |
|
) |
|
.half() |
|
.eval() |
|
.requires_grad_(False) |
|
.to(device) |
|
) |
|
vae = ( |
|
AutoencoderKL.from_pretrained("KBlueLeaf/EQ-SDXL-VAE") |
|
.half() |
|
.eval() |
|
.requires_grad_(False) |
|
.to(device) |
|
) |
|
vae_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1).to(device) |
|
vae_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1).to(device) |
|
|
|
|
|
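# Download the demo checkpoint (file name given by the MODEL_FILE environment
# variable) from the KBlueLeaf/XUT-demo repo if it is not cached, then strip the
# "unet."/"model." prefixes left over from the training wrapper so the keys
# match the bare XUDiT module.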
if not os.path.exists("./model/model.safetensors"):
    model_file = os.environ.get("MODEL_FILE")
    os.system(
        f"hfutils download -t model -r KBlueLeaf/XUT-demo -f {model_file} -o model/model.safetensors"
    )

state_dict = load_file("./model/model.safetensors")
model_sd = {
    k.replace("unet.", ""): v for k, v in state_dict.items() if k.startswith("unet.")
}
model_sd = {k.replace("model.", ""): v for k, v in model_sd.items()}
missing, unexpected = model.load_state_dict(model_sd, strict=False)
if missing:
    print(f"Missing keys: {missing}")
if unexpected:
    print(f"Unexpected keys: {unexpected}")


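# Load the first TIPO model from kgen's model registry onto the GPU for prompt
# upsampling.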
tipo_model_name, gguf_list = kgen_models.tipo_model_list[0]
kgen_models.load_model(tipo_model_name, device="cuda")
print("Models loaded successfully. UI is ready.")


@GPU(duration=5)
@torch.no_grad()
def generate(
    nl_prompt: str,
    tag_prompt: str,
    negative_prompt: str,
    tipo_enable: bool,
    format_enable: bool,
    num_images: int,
    steps: int,
    cfg_scale: float,
    size: int,
    aspect_ratio: str,
    fixed_short_edge: bool,
    seed: int,
    progress=gr.Progress(),
):
    as_w, as_h = aspect_ratio.split(":")
    aspect_ratio = float(as_w) / float(as_h)

    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    torch.manual_seed(seed)

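    # Build the final prompt: with TIPO enabled, the negative-prompt tags are
    # banned from its output and the tag/NL prompts are extended; with only
    # formatting enabled, the NL prompt is passed through DEFAULT_FORMAT;
    # otherwise the raw tag and NL prompts are simply concatenated.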
    if tipo_enable:
        tipo.BAN_TAGS = [i.strip() for i in negative_prompt.split(",") if i.strip()]
        final_prompt = prompt_opt(tag_prompt, nl_prompt, aspect_ratio, seed)
    elif format_enable:
        final_prompt = apply_format(nl_prompt, DEFAULT_FORMAT)
    else:
        final_prompt = tag_prompt + "\n" + nl_prompt

    yield None, final_prompt
    all_pil_images = []

    prompts_to_generate = [final_prompt.replace("\n", " ")] * num_images
    negative_prompts_to_generate = [negative_prompt] * num_images

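    # Convert the base size and aspect ratio into latent-space dimensions. With
    # a fixed short edge, the shorter side stays at `size`; otherwise the area
    # is kept roughly constant by scaling width/height with sqrt(aspect_ratio).
    # w and h are latent units, so the decoded image is (w*8) x (h*8) pixels.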
    if fixed_short_edge:
        if aspect_ratio > 1:
            h_factor = 1
            w_factor = aspect_ratio
        else:
            h_factor = 1 / aspect_ratio
            w_factor = 1
    else:
        w_factor = aspect_ratio**0.5
        h_factor = 1 / w_factor

    w = int(size * w_factor / 16) * 2
    h = int(size * h_factor / 16) * 2

print("=" * 100) |
|
print( |
|
f"Generating {num_images} image(s) with seed: {seed} and resolution {w*8}x{h*8}" |
|
) |
|
print("-" * 80) |
|
print(f"Final prompt: {final_prompt}") |
|
print("-" * 80) |
|
print(f"Negative prompt: {negative_prompt}") |
|
print("-" * 80) |
|
|
|
prompts_batch = prompts_to_generate |
|
neg_prompts_batch = negative_prompts_to_generate |
|
|
|
|
|
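    # Sampling: start from Gaussian noise in latent space and integrate the
    # model's predicted velocity with plain Euler steps from t = 1 down to 0
    # (x <- x - dt * v_pred), evaluating the CFG-combined prediction under fp16
    # autocast.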
    cfg_fn = cfg_wrapper(
        prompts_batch,
        neg_prompts_batch,
        unet=model,
        te=te,
        tokenizer=tokenizer,
        cfg_scale=cfg_scale,
    )
    xt = torch.randn(num_images, 4, h, w).to(device)

    t = 1.0
    dt = 1.0 / steps
    with trange(steps, desc="Generating Steps", smoothing=0.05) as cli_prog_bar:
        for step in progress.tqdm(list(range(steps)), desc="Generating Steps"):
            with torch.autocast(device.type, dtype=torch.float16):
                model_pred = cfg_fn(xt, torch.tensor(t, device=device))
            xt = xt - dt * model_pred.float()
            t -= dt
            cli_prog_bar.update(1)

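    # Decode each latent individually to limit VRAM usage: undo the latent
    # normalisation with the VAE's per-channel std/mean, decode in fp16, and
    # gather the results on the CPU.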
    generated_latents = xt.float()
    image_tensors = torch.concat(
        [
            vae.decode(
                (generated_latent[None] * vae_std + vae_mean).half()
            ).sample.cpu()
            for generated_latent in generated_latents
        ]
    )

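    # Map decoded tensors from [-1, 1] to [0, 255] uint8, reorder CHW -> HWC,
    # and convert to PIL images for the gallery.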
    for image_tensor in image_tensors:
        image = Image.fromarray(
            ((image_tensor * 0.5 + 0.5) * 255)
            .clamp(0, 255)
            .numpy()
            .astype(np.uint8)
            .transpose(1, 2, 0)
        )
        all_pil_images.append(image)

    yield all_pil_images, final_prompt


with gr.Blocks(title="HDM Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# HDM Early Demo")
    gr.Markdown(
        "### Enter a natural language prompt and/or specific tags to generate an image."
    )
    with gr.Accordion("Introduction", open=False):
gr.Markdown(""" |
|
# HDM: HomeDiffusion Model Project |
|
HDM is a project to implement a series of generative model that can be pretrained at home. |
|
|
|
## About this Demo |
|
This DEMO used a checkpoint during training to demostrate the functionality of HDM. |
|
Not final model yet. |
|
|
|
## Usage |
|
This early model used a model trained on anime image set only, |
|
so you should expect to see anime style images only in this demo. |
|
|
|
For prompting, enter danbooru tag prompt to the box "Tag Prompt" with comma seperated and remove the underscore. |
|
enter natural language prompt to the box "Natural Language Prompt" and enter negative prompt to the box "Negative Prompt". |
|
|
|
If you don't want to spent so much effort on prompting, try to keep "Enable TIPO" selected. |
|
|
|
If you don't want to apply any pre-defined format, unselect "Enable TIPO" and "Enable Format". |
|
|
|
## Model Spec |
|
- Backbone: 342M custom DiT(UViT modified) arch |
|
- Text Encoder: Qwen3 0.6B (596M) |
|
- VAE: EQ-SDXL-VAE, an EQ-VAE finetuned sdxl vae. |
|
|
|
## Pretraining Dataset |
|
- Danbooru 2023 (latest id around 8M) |
|
- Pixiv famous artist set |
|
- some pvc figure photos |
|
""") |
|
|
|
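    # Input area: prompt boxes and TIPO/format toggles on the left, sampler
    # settings (count, steps, CFG, seed, size, aspect ratio) on the right.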
    with gr.Row():
        with gr.Column(scale=2):
            nl_prompt_box = gr.Textbox(
                label="Natural Language Prompt",
                placeholder="e.g., A beautiful anime girl standing in a blooming cherry blossom forest",
                lines=3,
            )
            tag_prompt_box = gr.Textbox(
                label="Tag Prompt (comma-separated)",
                placeholder="e.g., 1girl, solo, long hair, cherry blossoms, school uniform",
                lines=3,
            )
            neg_prompt_box = gr.Textbox(
                label="Negative Prompt",
                value=(
                    "low quality, worst quality, "
                    "jpeg artifacts, bad anatomy, old, early, "
                    "copyright name, watermark"
                ),
                lines=3,
            )
            with gr.Row():
                tipo_enable = gr.Checkbox(
                    label="Enable TIPO",
                    value=True,
                )
                format_enable = gr.Checkbox(
                    label="Enable Format",
                    value=True,
                )
        with gr.Column(scale=1):
            with gr.Row():
                num_images_slider = gr.Slider(
                    label="Number of Images", minimum=1, maximum=16, value=1, step=1
                )
                steps_slider = gr.Slider(
                    label="Inference Steps", minimum=1, maximum=64, value=32, step=1
                )

            with gr.Row():
                cfg_slider = gr.Slider(
                    label="CFG Scale", minimum=1.0, maximum=5.0, value=3.0, step=0.1
                )
                seed_input = gr.Number(
                    label="Seed",
                    value=-1,
                    precision=0,
                    info="Set to -1 for a random seed.",
                )

            with gr.Row():
                size_slider = gr.Slider(
                    label="Base Image Size",
                    minimum=384,
                    maximum=768,
                    value=512,
                    step=64,
                )
            with gr.Row():
                aspect_ratio_box = gr.Textbox(
                    label="Ratio (W:H)",
                    value="1:1",
                )
                fixed_short_edge = gr.Checkbox(
                    label="Fixed Edge",
                    value=True,
                )

    generate_button = gr.Button("Generate", variant="primary")

    with gr.Row():
        with gr.Column(scale=1):
            output_prompt = gr.TextArea(
                label="Final Prompt",
                show_label=True,
                interactive=False,
                lines=32,
                max_lines=32,
            )
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images",
                show_label=True,
                elem_id="gallery",
                columns=4,
                rows=3,
                height="800px",
            )

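    # Wire the button to the generator; since `generate` is a generator
    # function, the UI first receives the expanded final prompt and later the
    # finished image gallery.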
    generate_button.click(
        fn=generate,
        inputs=[
            nl_prompt_box,
            tag_prompt_box,
            neg_prompt_box,
            tipo_enable,
            format_enable,
            num_images_slider,
            steps_slider,
            cfg_slider,
            size_slider,
            aspect_ratio_box,
            fixed_short_edge,
            seed_input,
        ],
        outputs=[output_gallery, output_prompt],
        show_progress_on=output_gallery,
    )


if __name__ == "__main__":
    demo.launch()