Kohaku-Blueleaf
committed on
Commit · 3b57e92
1 Parent(s): 155fc84
app
app.py
ADDED
@@ -0,0 +1,407 @@
import os
import random
import json
from pathlib import Path

import gradio as gr
import httpx
if os.environ.get("IN_SPACES", None) is not None:
    in_spaces = True
    import spaces
else:
    in_spaces = False
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from PIL import Image
from tqdm import trange

try:
    # pre-importing triton avoids import errors from diffusers/transformers
    import triton
except ImportError:
    print("Triton not found, skip pre import")

torch.set_float32_matmul_precision("high")

## HDM model dep
import xut.env
xut.env.TORCH_COMPILE = False
xut.env.USE_LIGER = True
xut.env.USE_XFORMERS = False
xut.env.USE_XFORMERS_LAYERS = False
from xut.xut import XUDiT
from transformers import Qwen3Model, Qwen2Tokenizer
from diffusers import AutoencoderKL

## TIPO
import kgen.models as kgen_models
import kgen.executor.tipo as tipo
from kgen.formatter import apply_format, seperate_tags


DEFAULT_FORMAT = """
<|special|>,
<|characters|>, <|copyrights|>,
<|artist|>,
<|quality|>, <|meta|>, <|rating|>,

<|general|>,

<|extended|>.
""".strip()

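# Decorator helper: on HF Spaces, wrap the function with spaces.GPU so it gets a
# ZeroGPU allocation; when running locally, return the function unchanged.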
def GPU(func, duration=None):
    if in_spaces:
        return spaces.GPU(func, duration)
    else:
        return func


def download_model(url: str, filepath: str):
    """Minimal fast download function"""
    if Path(filepath).exists():
        print(f"Model already exists at {filepath}")
        return

    print(f"Downloading model from {url}...")
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)

    with httpx.stream("GET", url, follow_redirects=True) as response:
        response.raise_for_status()
        with open(filepath, "wb") as f:
            for chunk in response.iter_bytes(chunk_size=128 * 1024):
                f.write(chunk)
    print(f"Download completed: {filepath}")

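# Expand the user's tags + natural-language prompt with TIPO and render the
# result into a single prompt string using DEFAULT_FORMAT.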
def prompt_opt(tags, nl_prompt, aspect_ratio, seed):
    meta, operations, general, nl_prompt = tipo.parse_tipo_request(
        seperate_tags(tags.split(",")),
        nl_prompt,
        tag_length_target="long",
        nl_length_target="short",
        generate_extra_nl_prompt=True,
    )
    meta["aspect_ratio"] = f"{aspect_ratio:.3f}"
    result, timing = tipo.tipo_runner(meta, operations, general, nl_prompt, seed=seed)
    return apply_format(result, DEFAULT_FORMAT).strip().strip(".").strip(",")

# --- User's core functions (copied directly) ---
def cfg_wrapper(
    prompt: str | list[str],
    neg_prompt: str | list[str],
    unet: nn.Module,  # should be k_diffusion wrapper
    te: Qwen3Model,
    tokenizer: Qwen2Tokenizer,
    cfg_scale: float = 3.0,
):
    prompt_token = {
        k: v.to(device)
        for k, v in tokenizer(
            prompt,
            padding="longest",
            return_tensors="pt",
        ).items()
    }
    neg_prompt_token = {
        k: v.to(device)
        for k, v in tokenizer(
            neg_prompt,
            padding="longest",
            return_tensors="pt",
        ).items()
    }

    emb = te(**prompt_token).last_hidden_state
    neg_emb = te(**neg_prompt_token).last_hidden_state

    if emb.size(1) > neg_emb.size(1):
        pad_setting = (0, 0, 0, emb.size(1) - neg_emb.size(1))
        neg_emb = F.pad(neg_emb, pad_setting)
    if neg_emb.size(1) > emb.size(1):
        pad_setting = (0, 0, 0, neg_emb.size(1) - emb.size(1))
        emb = F.pad(emb, pad_setting)
    text_ctx_emb = torch.concat([emb, neg_emb])

    def cfg_fn(x, t, cfg=cfg_scale):
        cond, uncond = unet(
            x.repeat(2, 1, 1, 1),
            t.expand(x.size(0) * 2),
            text_ctx_emb,
        ).chunk(2)
        cond = cond.float()
        uncond = uncond.float()
        return uncond + (cond - uncond) * cfg

    return cfg_fn


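# --- Global model setup: HDM XUDiT backbone, Qwen3-0.6B text encoder, EQ-SDXL VAE ---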
print("Loading models, please wait...")
device = torch.device("cuda")
print("Using device:", torch.cuda.get_device_name(device))

model = XUDiT(
    **json.load(open("./config/xut-small-1024-tread.json", "r"))
).half().requires_grad_(False).eval().to(device)
tokenizer = Qwen2Tokenizer.from_pretrained(
    "Qwen/Qwen3-0.6B",
)
te = Qwen3Model.from_pretrained(
    "Qwen/Qwen3-0.6B",
    torch_dtype=torch.float16,
    attn_implementation="sdpa"
).half().eval().requires_grad_(False).to(device)
vae = AutoencoderKL.from_pretrained(
    "KBlueLeaf/EQ-SDXL-VAE"
).half().eval().requires_grad_(False).to(device)
vae_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1).to(device)
vae_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1).to(device)

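# Fetch the HDM checkpoint (URL taken from the MODEL_URL env var) if it is not
# cached yet, then load the unet.* weights into the XUDiT backbone.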
if not os.path.exists("./model/model.safetensors"):
    model_url = os.environ.get("MODEL_URL")
    download_model(model_url, "./model/model.safetensors")

state_dict = load_file("./model/model.safetensors")
model_sd = {k.replace("unet.", ""): v for k, v in state_dict.items() if k.startswith("unet.")}
model_sd = {k.replace("model.", ""): v for k, v in model_sd.items()}
missing, unexpected = model.load_state_dict(model_sd, strict=False)
if missing:
    print(f"Missing keys: {missing}")
if unexpected:
    print(f"Unexpected keys: {unexpected}")

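# Download a TIPO GGUF model and load it on CPU; it is used only for prompt expansion.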
tipo_model_name, gguf_list = kgen_models.tipo_model_list[0]
kgen_models.download_gguf(
    tipo_model_name,
    gguf_list[-1],
)
kgen_models.load_model(
    f"{tipo_model_name}_{gguf_list[-1]}", gguf=True, device="cpu"
)
print("Models loaded successfully. UI is ready.")

@GPU
@torch.no_grad()
def generate(
    nl_prompt: str,
    tag_prompt: str,
    negative_prompt: str,
    num_images: int,
    steps: int,
    cfg_scale: float,
    size: int,
    aspect_ratio: str,
    fixed_short_edge: bool,
    seed: int,
    progress=gr.Progress(),
):
    as_w, as_h = aspect_ratio.split(":")
    aspect_ratio = float(as_w) / float(as_h)
    # Set seed for reproducibility
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    torch.manual_seed(seed)

    # TIPO
    tipo.BAN_TAGS = [i.strip() for i in negative_prompt.split(",") if i.strip()]
    final_prompt = prompt_opt(tag_prompt, nl_prompt, aspect_ratio, seed)
    yield None, final_prompt
    all_pil_images = []

    prompts_to_generate = [final_prompt.replace("\n", " ")] * num_images
    negative_prompts_to_generate = [negative_prompt] * num_images

    if fixed_short_edge:
        if aspect_ratio > 1:
            h_factor = 1
            w_factor = aspect_ratio
        else:
            h_factor = 1 / aspect_ratio
            w_factor = 1
    else:
        w_factor = aspect_ratio**0.5
        h_factor = 1 / w_factor

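    # Latent-space resolution: dividing by 16 and multiplying by 2 keeps w/h even;
    # each latent pixel corresponds to 8 image pixels, so the output is (w*8) x (h*8).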
    w = int(size * w_factor / 16) * 2
    h = int(size * h_factor / 16) * 2

    print("=" * 100)
    print(
        f"Generating {num_images} image(s) with seed: {seed} and resolution {w*8}x{h*8}"
    )
    print("-" * 80)
    print(f"Final prompt: {final_prompt}")
    print("-" * 80)
    print(f"Negative prompt: {negative_prompt}")
    print("-" * 80)

    prompts_batch = prompts_to_generate
    neg_prompts_batch = negative_prompts_to_generate

    # Core logic from the original script
    cfg_fn = cfg_wrapper(
        prompts_batch,
        neg_prompts_batch,
        unet=model,
        te=te,
        tokenizer=tokenizer,
        cfg_scale=cfg_scale,
    )
    xt = torch.randn(num_images, 4, h, w).to(device)

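    # Plain Euler integration from t=1 (pure noise) to t=0 over `steps` steps,
    # with classifier-free guidance applied inside cfg_fn at every step.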
    t = 1.0
    dt = 1.0 / steps
    with trange(steps, desc="Generating Steps", smoothing=0.05) as cli_prog_bar:
        for step in progress.tqdm(list(range(steps)), desc="Generating Steps"):
            with torch.autocast(device.type, dtype=torch.float16):
                model_pred = cfg_fn(xt, torch.tensor(t, device=device))
            xt = xt - dt * model_pred.float()
            t -= dt
            cli_prog_bar.update(1)

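    # Un-normalize the latents with the VAE statistics and decode them one by one
    # (per-sample decode keeps peak VRAM lower than a single batched decode).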
    generated_latents = xt.float()
    image_tensors = torch.concat(
        [
            vae.decode(
                (
                    generated_latent[None] * vae_std
                    + vae_mean
                ).half()
            ).sample.cpu()
            for generated_latent in generated_latents
        ]
    )

    # Convert tensors to PIL images
    for image_tensor in image_tensors:
        image = Image.fromarray(
            ((image_tensor * 0.5 + 0.5) * 255)
            .clamp(0, 255)
            .numpy()
            .astype(np.uint8)
            .transpose(1, 2, 0)
        )
        all_pil_images.append(image)

    yield all_pil_images, final_prompt


# --- Gradio UI Definition ---
with gr.Blocks(css="footer {display: none !important}") as demo:
    gr.Markdown("# HomeDiffusion Gradio UI")
    gr.Markdown(
        "### Enter a natural language prompt and/or specific tags to generate an image."
    )

    with gr.Row():
        with gr.Column(scale=2):
            nl_prompt_box = gr.Textbox(
                label="Natural Language Prompt",
                placeholder="e.g., A beautiful anime girl standing in a blooming cherry blossom forest",
                lines=3,
            )
            tag_prompt_box = gr.Textbox(
                label="Tag Prompt (comma-separated)",
                placeholder="e.g., 1girl, solo, long hair, cherry blossoms, school uniform",
                lines=3,
            )
            neg_prompt_box = gr.Textbox(
                label="Negative Prompt",
                value=(
                    "low quality, worst quality, "
                    "jpeg artifacts, bad anatomy, old, early, "
                    "copyright name, watermark"
                ),
                lines=3,
            )
        with gr.Column(scale=1):
            with gr.Row():
                num_images_slider = gr.Slider(
                    label="Number of Images", minimum=1, maximum=16, value=1, step=1
                )
                steps_slider = gr.Slider(
                    label="Inference Steps", minimum=1, maximum=50, value=32, step=1
                )

            with gr.Row():
                cfg_slider = gr.Slider(
                    label="CFG Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1
                )
                seed_input = gr.Number(
                    label="Seed",
                    value=-1,
                    precision=0,
                    info="Set to -1 for a random seed.",
                )

            with gr.Row():
                size_slider = gr.Slider(
                    label="Base Image Size",
                    minimum=384,
                    maximum=768,
                    value=512,
                    step=64,
                )
            with gr.Row():
                aspect_ratio_box = gr.Textbox(
                    label="Ratio (W:H)",
                    value="1:1",
                )
                fixed_short_edge = gr.Checkbox(
                    label="Fixed Edge",
                    value=True,
                )

            generate_button = gr.Button("Generate", variant="primary")

    with gr.Row():
        with gr.Column(scale=1):
            output_prompt = gr.TextArea(
                label="TIPO Generated Prompt",
                show_label=True,
                interactive=False,
                lines=32,
                max_lines=32,
            )
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images",
                show_label=True,
                elem_id="gallery",
                columns=4,
                rows=3,
                height="800px",
            )
            gr.Markdown("Images are also saved to the `inference_output/` folder.")

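    # The generate() generator yields twice: first (None, prompt) to show the TIPO
    # prompt immediately, then (images, prompt) once sampling finishes.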
    generate_button.click(
        fn=generate,
        inputs=[
            nl_prompt_box,
            tag_prompt_box,
            neg_prompt_box,
            num_images_slider,
            steps_slider,
            cfg_slider,
            size_slider,
            aspect_ratio_box,
            fixed_short_edge,
            seed_input,
        ],
        outputs=[output_gallery, output_prompt],
        show_progress_on=output_gallery,
    )

if __name__ == "__main__":
    demo.launch()