Kohaku-Blueleaf committed on
Commit 3b57e92 · 1 Parent(s): 155fc84
Files changed (1)
  1. app.py +407 -0
app.py ADDED
@@ -0,0 +1,407 @@
+ import os
+ import random
+ import json
+ from pathlib import Path
+
+ import gradio as gr
+ import httpx
+
+ if os.environ.get("IN_SPACES", None) is not None:
+     in_spaces = True
+     import spaces
+ else:
+     in_spaces = False
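+ # The `spaces` (ZeroGPU) integration is only wanted on Hugging Face Spaces;
+ # gating the import on the IN_SPACES environment variable lets the same
+ # script run locally, where the GPU() wrapper below becomes a pass-through.
+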
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from safetensors.torch import load_file
+ from PIL import Image
+ from tqdm import trange
+
+ try:
+     # Pre-importing triton avoids import errors from diffusers/transformers
+     import triton
+ except ImportError:
+     print("Triton not found, skipping pre-import")
+
+ torch.set_float32_matmul_precision("high")
+
+ ## HDM model dependencies
+ import xut.env
+ xut.env.TORCH_COMPILE = False
+ xut.env.USE_LIGER = True
+ xut.env.USE_XFORMERS = False
+ xut.env.USE_XFORMERS_LAYERS = False
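+ # NOTE: these flags appear to configure the xut package at import time
+ # (torch.compile, Liger kernels, xformers attention), so they must be set
+ # before XUDiT is imported below.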
+ from xut.xut import XUDiT
+ from transformers import Qwen3Model, Qwen2Tokenizer
+ from diffusers import AutoencoderKL
+
+ ## TIPO
+ import kgen.models as kgen_models
+ import kgen.executor.tipo as tipo
+ from kgen.formatter import apply_format, seperate_tags
+
+
+ DEFAULT_FORMAT = """
+ <|special|>,
+ <|characters|>, <|copyrights|>,
+ <|artist|>,
+ <|quality|>, <|meta|>, <|rating|>,
+
+ <|general|>,
+
+ <|extended|>.
+ """.strip()
+
+
+ def GPU(func, duration=None):
+     if in_spaces:
+         # Pass duration by keyword (spaces.GPU treats it as keyword-only)
+         return spaces.GPU(func, duration=duration)
+     else:
+         return func
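+ # Minimal usage sketch: on Spaces, @GPU requests ZeroGPU time per call;
+ # locally it is an identity wrapper:
+ #
+ #     @GPU
+ #     def heavy_fn(x):
+ #         return x.cuda()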
+
+
+ def download_model(url: str, filepath: str):
+     """Minimal fast download function"""
+     if Path(filepath).exists():
+         print(f"Model already exists at {filepath}")
+         return
+
+     print(f"Downloading model from {url}...")
+     Path(filepath).parent.mkdir(parents=True, exist_ok=True)
+
+     with httpx.stream("GET", url, follow_redirects=True) as response:
+         response.raise_for_status()
+         with open(filepath, "wb") as f:
+             for chunk in response.iter_bytes(chunk_size=128 * 1024):
+                 f.write(chunk)
+     print(f"Download completed: {filepath}")
+
+
+ def prompt_opt(tags, nl_prompt, aspect_ratio, seed):
+     meta, operations, general, nl_prompt = tipo.parse_tipo_request(
+         seperate_tags(tags.split(",")),
+         nl_prompt,
+         tag_length_target="long",
+         nl_length_target="short",
+         generate_extra_nl_prompt=True,
+     )
+     meta["aspect_ratio"] = f"{aspect_ratio:.3f}"
+     result, timing = tipo.tipo_runner(meta, operations, general, nl_prompt, seed=seed)
+     return apply_format(result, DEFAULT_FORMAT).strip().strip(".").strip(",")
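+ # TIPO upsamples the sparse user input into a full prompt: parse_tipo_request
+ # plans the generation, tipo_runner runs the prompt model, and apply_format
+ # lays the result out according to DEFAULT_FORMAT above.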
+
+
+ # --- User's core functions (copied directly) ---
+ def cfg_wrapper(
+     prompt: str | list[str],
+     neg_prompt: str | list[str],
+     unet: nn.Module,  # should be k_diffusion wrapper
+     te: Qwen3Model,
+     tokenizer: Qwen2Tokenizer,
+     cfg_scale: float = 3.0,
+ ):
+     prompt_token = {
+         k: v.to(device)
+         for k, v in tokenizer(
+             prompt,
+             padding="longest",
+             return_tensors="pt",
+         ).items()
+     }
+     neg_prompt_token = {
+         k: v.to(device)
+         for k, v in tokenizer(
+             neg_prompt,
+             padding="longest",
+             return_tensors="pt",
+         ).items()
+     }
+
+     emb = te(**prompt_token).last_hidden_state
+     neg_emb = te(**neg_prompt_token).last_hidden_state
+
+     # Pad the shorter embedding so both can be batched into one tensor
+     if emb.size(1) > neg_emb.size(1):
+         pad_setting = (0, 0, 0, emb.size(1) - neg_emb.size(1))
+         neg_emb = F.pad(neg_emb, pad_setting)
+     if neg_emb.size(1) > emb.size(1):
+         pad_setting = (0, 0, 0, neg_emb.size(1) - emb.size(1))
+         emb = F.pad(emb, pad_setting)
+     text_ctx_emb = torch.concat([emb, neg_emb])
+
+     def cfg_fn(x, t, cfg=cfg_scale):
+         cond, uncond = unet(
+             x.repeat(2, 1, 1, 1),
+             t.expand(x.size(0) * 2),
+             text_ctx_emb,
+         ).chunk(2)
+         cond = cond.float()
+         uncond = uncond.float()
+         return uncond + (cond - uncond) * cfg
+
+     return cfg_fn
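+ # Classifier-free guidance: conditional and unconditional branches run in a
+ # single batched forward pass (hence the repeat/chunk), then are combined as
+ # uncond + (cond - uncond) * cfg_scale.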
+
+
+ print("Loading models, please wait...")
+ device = torch.device("cuda")
+ print("Using device:", torch.cuda.get_device_name(device))
+
+ model = XUDiT(
+     **json.load(open("./config/xut-small-1024-tread.json", "r"))
+ ).half().requires_grad_(False).eval().to(device)
+ tokenizer = Qwen2Tokenizer.from_pretrained(
+     "Qwen/Qwen3-0.6B",
+ )
+ te = Qwen3Model.from_pretrained(
+     "Qwen/Qwen3-0.6B",
+     torch_dtype=torch.float16,
+     attn_implementation="sdpa",
+ ).half().eval().requires_grad_(False).to(device)
+ vae = AutoencoderKL.from_pretrained(
+     "KBlueLeaf/EQ-SDXL-VAE"
+ ).half().eval().requires_grad_(False).to(device)
+ vae_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1).to(device)
+ vae_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1).to(device)
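+ # latents_mean/latents_std are the VAE's per-channel latent statistics; the
+ # diffusion model works in normalized latent space, and generate() maps
+ # samples back via latent * vae_std + vae_mean before decoding.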
+
+
+ if not os.path.exists("./model/model.safetensors"):
+     model_url = os.environ.get("MODEL_URL")
+     if not model_url:
+         raise RuntimeError("MODEL_URL environment variable is not set; cannot download model.")
+     download_model(model_url, "./model/model.safetensors")
+
+ state_dict = load_file("./model/model.safetensors")
+ # Strip the "unet." wrapper prefix, then any remaining "model." prefix
+ model_sd = {
+     k.replace("unet.", ""): v for k, v in state_dict.items() if k.startswith("unet.")
+ }
+ model_sd = {k.replace("model.", ""): v for k, v in model_sd.items()}
+ missing, unexpected = model.load_state_dict(model_sd, strict=False)
+ if missing:
+     print(f"Missing keys: {missing}")
+ if unexpected:
+     print(f"Unexpected keys: {unexpected}")
+
+
+ tipo_model_name, gguf_list = kgen_models.tipo_model_list[0]
+ kgen_models.download_gguf(
+     tipo_model_name,
+     gguf_list[-1],
+ )
+ kgen_models.load_model(
+     f"{tipo_model_name}_{gguf_list[-1]}", gguf=True, device="cpu"
+ )
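+ # The TIPO prompt model runs as a GGUF quantization on CPU, so it does not
+ # compete with the diffusion model for GPU memory.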
+ print("Models loaded successfully. UI is ready.")
+
+
+ @GPU
+ @torch.no_grad()
+ def generate(
+     nl_prompt: str,
+     tag_prompt: str,
+     negative_prompt: str,
+     num_images: int,
+     steps: int,
+     cfg_scale: float,
+     size: int,
+     aspect_ratio: str,
+     fixed_short_edge: bool,
+     seed: int,
+     progress=gr.Progress(),
+ ):
+     as_w, as_h = aspect_ratio.split(":")
+     aspect_ratio = float(as_w) / float(as_h)
+     # Set seed for reproducibility
+     if seed == -1:
+         seed = random.randint(0, 2**32 - 1)
+     torch.manual_seed(seed)
+
+     # TIPO prompt upsampling; banned tags come from the negative prompt
+     tipo.BAN_TAGS = [i.strip() for i in negative_prompt.split(",") if i.strip()]
+     final_prompt = prompt_opt(tag_prompt, nl_prompt, aspect_ratio, seed)
+     # Stream the upsampled prompt to the UI before any image is ready
+     yield None, final_prompt
+     all_pil_images = []
+
+     prompts_to_generate = [final_prompt.replace("\n", " ")] * num_images
+     negative_prompts_to_generate = [negative_prompt] * num_images
+
+     if fixed_short_edge:
+         if aspect_ratio > 1:
+             h_factor = 1
+             w_factor = aspect_ratio
+         else:
+             h_factor = 1 / aspect_ratio
+             w_factor = 1
+     else:
+         w_factor = aspect_ratio**0.5
+         h_factor = 1 / w_factor
+
+     w = int(size * w_factor / 16) * 2
+     h = int(size * h_factor / 16) * 2
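+     # w and h are latent-space sizes: int(size * factor / 16) * 2 keeps them
+     # even, and the VAE's 8x downscale means the final image is (w*8) x (h*8)
+     # pixels. "Fixed Edge" pins the short side near `size`; otherwise the
+     # w_factor/h_factor pair keeps the pixel area roughly constant at size**2.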
+
+     print("=" * 100)
+     print(
+         f"Generating {num_images} image(s) with seed: {seed} and resolution {w*8}x{h*8}"
+     )
+     print("-" * 80)
+     print(f"Final prompt: {final_prompt}")
+     print("-" * 80)
+     print(f"Negative prompt: {negative_prompt}")
+     print("-" * 80)
+
+     prompts_batch = prompts_to_generate
+     neg_prompts_batch = negative_prompts_to_generate
+
+     # Core logic from the original script
+     cfg_fn = cfg_wrapper(
+         prompts_batch,
+         neg_prompts_batch,
+         unet=model,
+         te=te,
+         tokenizer=tokenizer,
+         cfg_scale=cfg_scale,
+     )
+     xt = torch.randn(num_images, 4, h, w).to(device)
+
+     t = 1.0
+     dt = 1.0 / steps
+     with trange(steps, desc="Generating Steps", smoothing=0.05) as cli_prog_bar:
+         for step in progress.tqdm(list(range(steps)), desc="Generating Steps"):
+             with torch.autocast(device.type, dtype=torch.float16):
+                 model_pred = cfg_fn(xt, torch.tensor(t, device=device))
+             xt = xt - dt * model_pred.float()
+             t -= dt
+             cli_prog_bar.update(1)
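+     # This is a plain Euler integration of the model's predicted velocity:
+     # t runs from 1 (pure noise) down to 0, and each step moves the latent
+     # by -dt * v(x, t).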
+
+     generated_latents = xt.float()
+     # Decode latents one at a time (lower peak VRAM than one batched decode)
+     image_tensors = torch.concat(
+         [
+             vae.decode(
+                 (generated_latent[None] * vae_std + vae_mean).half()
+             ).sample.cpu()
+             for generated_latent in generated_latents
+         ]
+     )
+
+     # Convert tensors to PIL images: decoded output is in [-1, 1] CHW, so
+     # rescale to [0, 255] uint8 and transpose to HWC
+     os.makedirs("inference_output", exist_ok=True)
+     for image_tensor in image_tensors:
+         image = Image.fromarray(
+             ((image_tensor * 0.5 + 0.5) * 255)
+             .clamp(0, 255)
+             .numpy()
+             .astype(np.uint8)
+             .transpose(1, 2, 0)
+         )
+         all_pil_images.append(image)
+         # Save a copy so the "saved to inference_output/" note in the UI
+         # holds (the filename scheme here is an arbitrary choice)
+         image.save(f"inference_output/{seed}_{len(all_pil_images)}.png")
+
+     yield all_pil_images, final_prompt
+
+
+ # --- Gradio UI Definition ---
+ with gr.Blocks(css="footer {display: none !important}") as demo:
+     gr.Markdown("# HomeDiffusion Gradio UI")
+     gr.Markdown(
+         "### Enter a natural language prompt and/or specific tags to generate an image."
+     )
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             nl_prompt_box = gr.Textbox(
+                 label="Natural Language Prompt",
+                 placeholder="e.g., A beautiful anime girl standing in a blooming cherry blossom forest",
+                 lines=3,
+             )
+             tag_prompt_box = gr.Textbox(
+                 label="Tag Prompt (comma-separated)",
+                 placeholder="e.g., 1girl, solo, long hair, cherry blossoms, school uniform",
+                 lines=3,
+             )
+             neg_prompt_box = gr.Textbox(
+                 label="Negative Prompt",
+                 value=(
+                     "low quality, worst quality, "
+                     "jpeg artifacts, bad anatomy, old, early, "
+                     "copyright name, watermark"
+                 ),
+                 lines=3,
+             )
+         with gr.Column(scale=1):
+             with gr.Row():
+                 num_images_slider = gr.Slider(
+                     label="Number of Images", minimum=1, maximum=16, value=1, step=1
+                 )
+                 steps_slider = gr.Slider(
+                     label="Inference Steps", minimum=1, maximum=50, value=32, step=1
+                 )
+
+             with gr.Row():
+                 cfg_slider = gr.Slider(
+                     label="CFG Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1
+                 )
+                 seed_input = gr.Number(
+                     label="Seed",
+                     value=-1,
+                     precision=0,
+                     info="Set to -1 for a random seed.",
+                 )
+
+             with gr.Row():
+                 size_slider = gr.Slider(
+                     label="Base Image Size",
+                     minimum=384,
+                     maximum=768,
+                     value=512,
+                     step=64,
+                 )
+             with gr.Row():
+                 aspect_ratio_box = gr.Textbox(
+                     label="Ratio (W:H)",
+                     value="1:1",
+                 )
+                 fixed_short_edge = gr.Checkbox(
+                     label="Fixed Edge",
+                     value=True,
+                 )
+
+             generate_button = gr.Button("Generate", variant="primary")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             output_prompt = gr.TextArea(
+                 label="TIPO Generated Prompt",
+                 show_label=True,
+                 interactive=False,
+                 lines=32,
+                 max_lines=32,
+             )
+         with gr.Column(scale=2):
+             output_gallery = gr.Gallery(
+                 label="Generated Images",
+                 show_label=True,
+                 elem_id="gallery",
+                 columns=4,
+                 rows=3,
+                 height="800px",
+             )
+     gr.Markdown("Images are also saved to the `inference_output/` folder.")
+
+     generate_button.click(
+         fn=generate,
+         inputs=[
+             nl_prompt_box,
+             tag_prompt_box,
+             neg_prompt_box,
+             num_images_slider,
+             steps_slider,
+             cfg_slider,
+             size_slider,
+             aspect_ratio_box,
+             fixed_short_edge,
+             seed_input,
+         ],
+         outputs=[output_gallery, output_prompt],
+         show_progress_on=output_gallery,
+     )
+
+ if __name__ == "__main__":
+     demo.launch()