svjack commited on
Commit
eaa504b
ยท
verified ยท
1 Parent(s): 32681ea

Upload zip_app.py

Browse files
Files changed (1) hide show
  1. zip_app.py +529 -0
zip_app.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import zipfile
3
+ import tempfile
4
+ from datasets import load_dataset
5
+ import uuid
6
+ import math
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ import safetensors.torch as sf
11
+ import db_examples
12
+ from PIL import Image
13
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
14
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
15
+ from diffusers.models.attention_processor import AttnProcessor2_0
16
+ from transformers import CLIPTextModel, CLIPTokenizer
17
+ from briarmbg import BriaRMBG
18
+ from enum import Enum
19
+
20
+ # ๅฎšไน‰ไฟๅญ˜่ทฏๅพ„
21
+ save_path = "./examples/xiangxiang_man"
22
+
23
+ # ๆธ…็ฉบ็›ฎๆ ‡่ทฏๅพ„๏ผˆๅฆ‚ๆžœๅญ˜ๅœจ๏ผ‰
24
+ if os.path.exists(save_path):
25
+ for file_name in os.listdir(save_path):
26
+ file_path = os.path.join(save_path, file_name)
27
+ if os.path.isfile(file_path):
28
+ os.remove(file_path)
29
+ print(f"Cleared existing files in {save_path}")
30
+ else:
31
+ os.makedirs(save_path, exist_ok=True)
32
+ print(f"Created directory: {save_path}")
33
+
34
+ # ๅŠ ่ฝฝๆ•ฐๆฎ้›†
35
+ dataset = load_dataset("svjack/Prince_Xiang_iclight_v2")
36
+
37
+ # ้ๅކๆ•ฐๆฎ้›†ๅนถไฟๅญ˜ๅ›พ็‰‡
38
+ for example in dataset["train"]:
39
+ # ่Žทๅ–ๅ›พ็‰‡ๆ•ฐๆฎ
40
+ image = example["image"]
41
+
42
+ # ็”Ÿๆˆๅ”ฏไธ€็š„ๆ–‡ไปถๅ๏ผˆไฝฟ็”จ uuid๏ผ‰
43
+ file_name = f"{uuid.uuid4()}.png"
44
+ file_path = os.path.join(save_path, file_name)
45
+
46
+ # ไฟๅญ˜ๅ›พ็‰‡
47
+ image.save(file_path)
48
+ print(f"Saved {file_path}")
49
+
50
+ print("All images have been saved.")
51
+
52
+ # ... [็œ็•ฅไธญ้—ดไปฃ็ ๏ผŒไฟๆŒไธๅ˜] ...
53
+
54
+ def create_zip_from_gallery(images, prompt):
55
+ # ๅˆ›ๅปบไธ€ไธชไธดๆ—ถ็›ฎๅฝ•ๆฅๅญ˜ๅ‚จๅ›พ็‰‡
56
+ with tempfile.TemporaryDirectory() as temp_dir:
57
+ # ๅฐ†ๅ›พ็‰‡ไฟๅญ˜ๅˆฐไธดๆ—ถ็›ฎๅฝ•
58
+ print("images :")
59
+ print(images)
60
+ print("- " * 100)
61
+
62
+ for i, img in enumerate(images):
63
+ img = img[0]
64
+ img_path = os.path.join(temp_dir, f"image_{i}.png")
65
+ #Image.fromarray(img).save(img_path)
66
+ Image.open(img).save(img_path)
67
+
68
+ # ๅˆ›ๅปบๅŽ‹็ผฉๆ–‡ไปถ
69
+ zip_filename = f"{prompt.replace(' ', '_')}.zip"
70
+ #zip_path = os.path.join(temp_dir, zip_filename)
71
+ zip_path = zip_filename
72
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
73
+ for img_file in os.listdir(temp_dir):
74
+ if img_file.endswith('.png'):
75
+ zipf.write(os.path.join(temp_dir, img_file), img_file)
76
+
77
+ # ่ฟ”ๅ›žๅŽ‹็ผฉๆ–‡ไปถ็š„่ทฏๅพ„
78
+ return zip_path
79
+
80
+ # ... [็œ็•ฅไธญ้—ดไปฃ็ ๏ผŒไฟๆŒไธๅ˜] ...
81
+
82
+ #import spaces
83
+ import math
84
+ import gradio as gr
85
+ import numpy as np
86
+ import torch
87
+ import safetensors.torch as sf
88
+ import db_examples
89
+
90
+ from PIL import Image
91
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
92
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
93
+ from diffusers.models.attention_processor import AttnProcessor2_0
94
+ from transformers import CLIPTextModel, CLIPTokenizer
95
+ from briarmbg import BriaRMBG
96
+ from enum import Enum
97
+ # from torch.hub import download_url_to_file
98
+
99
+
100
+ # 'stablediffusionapi/realistic-vision-v51'
101
+ # 'runwayml/stable-diffusion-v1-5'
102
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
103
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
104
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
105
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
106
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
107
+ rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
108
+
109
+ # Change UNet
110
+
111
+ with torch.no_grad():
112
+ new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
113
+ new_conv_in.weight.zero_()
114
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
115
+ new_conv_in.bias = unet.conv_in.bias
116
+ unet.conv_in = new_conv_in
117
+
118
+ unet_original_forward = unet.forward
119
+
120
+
121
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
122
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
123
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
124
+ new_sample = torch.cat([sample, c_concat], dim=1)
125
+ kwargs['cross_attention_kwargs'] = {}
126
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
127
+
128
+
129
+ unet.forward = hooked_unet_forward
130
+
131
+ # Load
132
+
133
+ model_path = './models/iclight_sd15_fc.safetensors'
134
+ # download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)
135
+ sd_offset = sf.load_file(model_path)
136
+ sd_origin = unet.state_dict()
137
+ keys = sd_origin.keys()
138
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
139
+ unet.load_state_dict(sd_merged, strict=True)
140
+ del sd_offset, sd_origin, sd_merged, keys
141
+
142
+ # Device
143
+
144
+ device = torch.device('cuda')
145
+ text_encoder = text_encoder.to(device=device, dtype=torch.float16)
146
+ vae = vae.to(device=device, dtype=torch.bfloat16)
147
+ unet = unet.to(device=device, dtype=torch.float16)
148
+ rmbg = rmbg.to(device=device, dtype=torch.float32)
149
+
150
+ # SDP
151
+
152
+ unet.set_attn_processor(AttnProcessor2_0())
153
+ vae.set_attn_processor(AttnProcessor2_0())
154
+
155
+ # Samplers
156
+
157
+ ddim_scheduler = DDIMScheduler(
158
+ num_train_timesteps=1000,
159
+ beta_start=0.00085,
160
+ beta_end=0.012,
161
+ beta_schedule="scaled_linear",
162
+ clip_sample=False,
163
+ set_alpha_to_one=False,
164
+ steps_offset=1,
165
+ )
166
+
167
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
168
+ num_train_timesteps=1000,
169
+ beta_start=0.00085,
170
+ beta_end=0.012,
171
+ steps_offset=1
172
+ )
173
+
174
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
175
+ num_train_timesteps=1000,
176
+ beta_start=0.00085,
177
+ beta_end=0.012,
178
+ algorithm_type="sde-dpmsolver++",
179
+ use_karras_sigmas=True,
180
+ steps_offset=1
181
+ )
182
+
183
+ # Pipelines
184
+
185
+ t2i_pipe = StableDiffusionPipeline(
186
+ vae=vae,
187
+ text_encoder=text_encoder,
188
+ tokenizer=tokenizer,
189
+ unet=unet,
190
+ scheduler=dpmpp_2m_sde_karras_scheduler,
191
+ safety_checker=None,
192
+ requires_safety_checker=False,
193
+ feature_extractor=None,
194
+ image_encoder=None
195
+ )
196
+
197
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
198
+ vae=vae,
199
+ text_encoder=text_encoder,
200
+ tokenizer=tokenizer,
201
+ unet=unet,
202
+ scheduler=dpmpp_2m_sde_karras_scheduler,
203
+ safety_checker=None,
204
+ requires_safety_checker=False,
205
+ feature_extractor=None,
206
+ image_encoder=None
207
+ )
208
+
209
+
210
+ @torch.inference_mode()
211
+ def encode_prompt_inner(txt: str):
212
+ max_length = tokenizer.model_max_length
213
+ chunk_length = tokenizer.model_max_length - 2
214
+ id_start = tokenizer.bos_token_id
215
+ id_end = tokenizer.eos_token_id
216
+ id_pad = id_end
217
+
218
+ def pad(x, p, i):
219
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
220
+
221
+ tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
222
+ chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
223
+ chunks = [pad(ck, id_pad, max_length) for ck in chunks]
224
+
225
+ token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
226
+ conds = text_encoder(token_ids).last_hidden_state
227
+
228
+ return conds
229
+
230
+
231
+ @torch.inference_mode()
232
+ def encode_prompt_pair(positive_prompt, negative_prompt):
233
+ c = encode_prompt_inner(positive_prompt)
234
+ uc = encode_prompt_inner(negative_prompt)
235
+
236
+ c_len = float(len(c))
237
+ uc_len = float(len(uc))
238
+ max_count = max(c_len, uc_len)
239
+ c_repeat = int(math.ceil(max_count / c_len))
240
+ uc_repeat = int(math.ceil(max_count / uc_len))
241
+ max_chunk = max(len(c), len(uc))
242
+
243
+ c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
244
+ uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
245
+
246
+ c = torch.cat([p[None, ...] for p in c], dim=1)
247
+ uc = torch.cat([p[None, ...] for p in uc], dim=1)
248
+
249
+ return c, uc
250
+
251
+
252
+ @torch.inference_mode()
253
+ def pytorch2numpy(imgs, quant=True):
254
+ results = []
255
+ for x in imgs:
256
+ y = x.movedim(0, -1)
257
+
258
+ if quant:
259
+ y = y * 127.5 + 127.5
260
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
261
+ else:
262
+ y = y * 0.5 + 0.5
263
+ y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
264
+
265
+ results.append(y)
266
+ return results
267
+
268
+
269
+ @torch.inference_mode()
270
+ def numpy2pytorch(imgs):
271
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
272
+ h = h.movedim(-1, 1)
273
+ return h
274
+
275
+
276
+ def resize_and_center_crop(image, target_width, target_height):
277
+ pil_image = Image.fromarray(image)
278
+ original_width, original_height = pil_image.size
279
+ scale_factor = max(target_width / original_width, target_height / original_height)
280
+ resized_width = int(round(original_width * scale_factor))
281
+ resized_height = int(round(original_height * scale_factor))
282
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
283
+ left = (resized_width - target_width) / 2
284
+ top = (resized_height - target_height) / 2
285
+ right = (resized_width + target_width) / 2
286
+ bottom = (resized_height + target_height) / 2
287
+ cropped_image = resized_image.crop((left, top, right, bottom))
288
+ return np.array(cropped_image)
289
+
290
+
291
+ def resize_without_crop(image, target_width, target_height):
292
+ pil_image = Image.fromarray(image)
293
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
294
+ return np.array(resized_image)
295
+
296
+
297
+ @torch.inference_mode()
298
+ def run_rmbg(img, sigma=0.0):
299
+ H, W, C = img.shape
300
+ assert C == 3
301
+ k = (256.0 / float(H * W)) ** 0.5
302
+ feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
303
+ feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
304
+ alpha = rmbg(feed)[0][0]
305
+ alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
306
+ alpha = alpha.movedim(1, -1)[0]
307
+ alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
308
+ result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
309
+ return result.clip(0, 255).astype(np.uint8), alpha
310
+
311
+
312
+ @torch.inference_mode()
313
+ def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
314
+ bg_source = BGSource(bg_source)
315
+ input_bg = None
316
+
317
+ if bg_source == BGSource.NONE:
318
+ pass
319
+ elif bg_source == BGSource.LEFT:
320
+ gradient = np.linspace(255, 0, image_width)
321
+ image = np.tile(gradient, (image_height, 1))
322
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
323
+ elif bg_source == BGSource.RIGHT:
324
+ gradient = np.linspace(0, 255, image_width)
325
+ image = np.tile(gradient, (image_height, 1))
326
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
327
+ elif bg_source == BGSource.TOP:
328
+ gradient = np.linspace(255, 0, image_height)[:, None]
329
+ image = np.tile(gradient, (1, image_width))
330
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
331
+ elif bg_source == BGSource.BOTTOM:
332
+ gradient = np.linspace(0, 255, image_height)[:, None]
333
+ image = np.tile(gradient, (1, image_width))
334
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
335
+ else:
336
+ raise 'Wrong initial latent!'
337
+
338
+ rng = torch.Generator(device=device).manual_seed(int(seed))
339
+
340
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
341
+
342
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
343
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
344
+
345
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
346
+
347
+ if input_bg is None:
348
+ latents = t2i_pipe(
349
+ prompt_embeds=conds,
350
+ negative_prompt_embeds=unconds,
351
+ width=image_width,
352
+ height=image_height,
353
+ num_inference_steps=steps,
354
+ num_images_per_prompt=num_samples,
355
+ generator=rng,
356
+ output_type='latent',
357
+ guidance_scale=cfg,
358
+ cross_attention_kwargs={'concat_conds': concat_conds},
359
+ ).images.to(vae.dtype) / vae.config.scaling_factor
360
+ else:
361
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
362
+ bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
363
+ bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
364
+ latents = i2i_pipe(
365
+ image=bg_latent,
366
+ strength=lowres_denoise,
367
+ prompt_embeds=conds,
368
+ negative_prompt_embeds=unconds,
369
+ width=image_width,
370
+ height=image_height,
371
+ num_inference_steps=int(round(steps / lowres_denoise)),
372
+ num_images_per_prompt=num_samples,
373
+ generator=rng,
374
+ output_type='latent',
375
+ guidance_scale=cfg,
376
+ cross_attention_kwargs={'concat_conds': concat_conds},
377
+ ).images.to(vae.dtype) / vae.config.scaling_factor
378
+
379
+ pixels = vae.decode(latents).sample
380
+ pixels = pytorch2numpy(pixels)
381
+ pixels = [resize_without_crop(
382
+ image=p,
383
+ target_width=int(round(image_width * highres_scale / 64.0) * 64),
384
+ target_height=int(round(image_height * highres_scale / 64.0) * 64))
385
+ for p in pixels]
386
+
387
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
388
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
389
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
390
+
391
+ image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
392
+
393
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
394
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
395
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
396
+
397
+ latents = i2i_pipe(
398
+ image=latents,
399
+ strength=highres_denoise,
400
+ prompt_embeds=conds,
401
+ negative_prompt_embeds=unconds,
402
+ width=image_width,
403
+ height=image_height,
404
+ num_inference_steps=int(round(steps / highres_denoise)),
405
+ num_images_per_prompt=num_samples,
406
+ generator=rng,
407
+ output_type='latent',
408
+ guidance_scale=cfg,
409
+ cross_attention_kwargs={'concat_conds': concat_conds},
410
+ ).images.to(vae.dtype) / vae.config.scaling_factor
411
+
412
+ pixels = vae.decode(latents).sample
413
+
414
+ return pytorch2numpy(pixels)
415
+
416
+
417
+ #@spaces.GPU
418
+ @torch.inference_mode()
419
+ def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
420
+ input_fg, matting = run_rmbg(input_fg)
421
+ results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
422
+ return input_fg, results
423
+
424
+
425
+ quick_prompts = [
426
+ 'sunshine from window',
427
+ 'neon light, city',
428
+ 'sunset over sea',
429
+ 'golden time',
430
+ 'sci-fi RGB glowing, cyberpunk',
431
+ 'natural lighting',
432
+ 'warm atmosphere, at home, bedroom',
433
+ 'magic lit',
434
+ 'evil, gothic, Yharnam',
435
+ 'light and shadow',
436
+ 'shadow from window',
437
+ 'soft studio lighting',
438
+ 'home atmosphere, cozy bedroom illumination',
439
+ 'neon, Wong Kar-wai, warm'
440
+ ]
441
+ quick_prompts = [[x] for x in quick_prompts]
442
+
443
+
444
+ quick_subjects = [
445
+ 'beautiful woman, detailed face',
446
+ 'handsome man, detailed face',
447
+ ]
448
+ quick_subjects = [[x] for x in quick_subjects]
449
+
450
+
451
+ class BGSource(Enum):
452
+ NONE = "None"
453
+ LEFT = "Left Light"
454
+ RIGHT = "Right Light"
455
+ TOP = "Top Light"
456
+ BOTTOM = "Bottom Light"
457
+
458
+
459
+ block = gr.Blocks().queue()
460
+ with block:
461
+ with gr.Row():
462
+ gr.Markdown("## IC-Light (Relighting with Foreground Condition)")
463
+ with gr.Row():
464
+ gr.Markdown("See also https://github.com/lllyasviel/IC-Light for background-conditioned model and normal estimation")
465
+ with gr.Row():
466
+ with gr.Column():
467
+ with gr.Row():
468
+ input_fg = gr.Image(sources='upload', type="numpy", label="Image", height=480)
469
+ output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
470
+ prompt = gr.Textbox(label="Prompt")
471
+ bg_source = gr.Radio(choices=[e.value for e in BGSource],
472
+ value=BGSource.NONE.value,
473
+ label="Lighting Preference (Initial Latent)", type='value')
474
+ example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
475
+ example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
476
+ relight_button = gr.Button(value="Relight")
477
+
478
+ with gr.Group():
479
+ with gr.Row():
480
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
481
+ seed = gr.Number(label="Seed", value=12345, precision=0)
482
+
483
+ with gr.Row():
484
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
485
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
486
+
487
+ with gr.Accordion("Advanced options", open=False):
488
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1)
489
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01)
490
+ lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
491
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
492
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
493
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
494
+ n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
495
+ with gr.Column():
496
+ result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
497
+ download_button = gr.Button(value="Download Gallery as ZIP")
498
+ download_link = gr.File(label="Download Link", visible=False)
499
+ with gr.Row():
500
+ dummy_image_for_outputs = gr.Image(visible=False, label='Result')
501
+ gr.Examples(
502
+ fn=lambda *args: [[args[-1]], "imgs/dummy.png"],
503
+ examples=db_examples.foreground_conditioned_examples,
504
+ inputs=[
505
+ input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
506
+ ],
507
+ outputs=[result_gallery, output_bg],
508
+ run_on_click=True, examples_per_page=1024
509
+ )
510
+ ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
511
+ relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
512
+ example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
513
+ example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
514
+
515
+ # ๆทปๅŠ ไธ‹่ฝฝๆŒ‰้’ฎ็š„ๅŠŸ่ƒฝ
516
+ download_button.click(
517
+ fn=create_zip_from_gallery,
518
+ inputs=[result_gallery, prompt],
519
+ outputs=download_link,
520
+ show_progress=True
521
+ )
522
+
523
+ import pathlib
524
+ im_l = list(map(str ,pathlib.Path("./examples/xiangxiang_man").rglob("*.png")))
525
+ gr.Examples(
526
+ im_l, input_fg
527
+ )
528
+
529
+ block.launch(share=True)