alexnasa committed
Commit 8381600 · verified · 1 Parent(s): 2910919

Update src/flux/generate.py

Files changed (1)
  1. src/flux/generate.py +798 -798
src/flux/generate.py CHANGED
@@ -21,818 +21,818 @@ from typing import List, Union, Optional, Dict, Any, Callable
21
  from src.flux.transformer import tranformer_forward
22
  from src.flux.condition import Condition
23
 
24
- # from diffusers.pipelines.flux.pipeline_flux import (
25
- # FluxPipelineOutput,
26
- # calculate_shift,
27
- # retrieve_timesteps,
28
- # np,
 
 
 
29
  # )
30
- from src.flux.pipeline_tools import (
31
- encode_prompt_with_clip_t5, tokenize_t5_prompt, clear_attn_maps, encode_vae_images
32
- )
33
-
34
- from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, decode_vae_images, \
35
- save_attention_maps, gather_attn_maps, clear_attn_maps, load_dit_lora, quantization
36
-
37
- from src.utils.data_utils import pad_to_square, pad_to_target, pil2tensor, get_closest_ratio, get_aspect_ratios
38
- from src.utils.modulation_utils import get_word_index, unpad_input_ids
39
-
40
- def get_config(config_path: str = None):
41
- config_path = config_path or os.environ.get("XFL_CONFIG")
42
- if not config_path:
43
- return {}
44
- with open(config_path, "r") as f:
45
- config = yaml.safe_load(f)
46
- return config
47
-
48
-
49
- def prepare_params(
50
- prompt: Union[str, List[str]] = None,
51
- prompt_2: Optional[Union[str, List[str]]] = None,
52
- height: Optional[int] = 512,
53
- width: Optional[int] = 512,
54
- num_inference_steps: int = 28,
55
- timesteps: List[int] = None,
56
- guidance_scale: float = 3.5,
57
- num_images_per_prompt: Optional[int] = 1,
58
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
59
- latents: Optional[torch.FloatTensor] = None,
60
- prompt_embeds: Optional[torch.FloatTensor] = None,
61
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
62
- output_type: Optional[str] = "pil",
63
- return_dict: bool = True,
64
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
65
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
66
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
67
- max_sequence_length: int = 512,
68
- verbose: bool = False,
69
- **kwargs: dict,
70
- ):
71
- return (
72
- prompt,
73
- prompt_2,
74
- height,
75
- width,
76
- num_inference_steps,
77
- timesteps,
78
- guidance_scale,
79
- num_images_per_prompt,
80
- generator,
81
- latents,
82
- prompt_embeds,
83
- pooled_prompt_embeds,
84
- output_type,
85
- return_dict,
86
- joint_attention_kwargs,
87
- callback_on_step_end,
88
- callback_on_step_end_tensor_inputs,
89
- max_sequence_length,
90
- verbose,
91
- )
92
-
93
-
94
- def seed_everything(seed: int = 42):
95
- torch.backends.cudnn.deterministic = True
96
- torch.manual_seed(seed)
97
- np.random.seed(seed)
98
-
99
-
100
- @torch.no_grad()
101
- def generate(
102
- pipeline: FluxPipeline,
103
- vae_conditions: List[Condition] = None,
104
- config_path: str = None,
105
- model_config: Optional[Dict[str, Any]] = {},
106
- vae_condition_scale: float = 1.0,
107
- default_lora: bool = False,
108
- condition_pad_to: str = "square",
109
- condition_size: int = 512,
110
- text_cond_mask: Optional[torch.FloatTensor] = None,
111
- delta_emb: Optional[torch.FloatTensor] = None,
112
- delta_emb_pblock: Optional[torch.FloatTensor] = None,
113
- delta_emb_mask: Optional[torch.FloatTensor] = None,
114
- delta_start_ends = None,
115
- condition_latents = None,
116
- condition_ids = None,
117
- mod_adapter = None,
118
- store_attn_map: bool = False,
119
- vae_skip_iter: str = None,
120
- control_weight_lambda: str = None,
121
- double_attention: bool = False,
122
- single_attention: bool = False,
123
- ip_scale: str = None,
124
- use_latent_sblora_control: bool = False,
125
- latent_sblora_scale: str = None,
126
- use_condition_sblora_control: bool = False,
127
- condition_sblora_scale: str = None,
128
- idips = None,
129
- **params: dict,
130
- ):
131
- model_config = model_config or get_config(config_path).get("model", {})
132
-
133
- vae_skip_iter = model_config.get("vae_skip_iter", vae_skip_iter)
134
- double_attention = model_config.get("double_attention", double_attention)
135
- single_attention = model_config.get("single_attention", single_attention)
136
- control_weight_lambda = model_config.get("control_weight_lambda", control_weight_lambda)
137
- ip_scale = model_config.get("ip_scale", ip_scale)
138
- use_latent_sblora_control = model_config.get("use_latent_sblora_control", use_latent_sblora_control)
139
- use_condition_sblora_control = model_config.get("use_condition_sblora_control", use_condition_sblora_control)
140
-
141
- latent_sblora_scale = model_config.get("latent_sblora_scale", latent_sblora_scale)
142
- condition_sblora_scale = model_config.get("condition_sblora_scale", condition_sblora_scale)
143
-
144
- model_config["use_attention_double"] = False
145
- model_config["use_attention_single"] = False
146
- use_attention = False
147
 
148
- if idips is not None:
149
- if control_weight_lambda != "no":
150
- parts = control_weight_lambda.split(',')
151
- new_parts = []
152
- for part in parts:
153
- if ':' in part:
154
- left, right = part.split(':')
155
- values = right.split('/')
156
- # save the overall value
157
- global_value = values[0]
158
- id_value = values[1]
159
- ip_value = values[2]
160
- new_values = [global_value]
161
- for is_id in idips:
162
- if is_id:
163
- new_values.append(id_value)
164
- else:
165
- new_values.append(ip_value)
166
- new_part = f"{left}:{('/'.join(new_values))}"
167
- new_parts.append(new_part)
168
- else:
169
- new_parts.append(part)
170
- control_weight_lambda = ','.join(new_parts)
171
-
172
- if vae_condition_scale != 1:
173
- for name, module in pipeline.transformer.named_modules():
174
- if not name.endswith(".attn"):
175
- continue
176
- module.c_factor = torch.ones(1, 1) * vae_condition_scale
177
-
178
- self = pipeline
179
- (
180
- prompt,
181
- prompt_2,
182
- height,
183
- width,
184
- num_inference_steps,
185
- timesteps,
186
- guidance_scale,
187
- num_images_per_prompt,
188
- generator,
189
- latents,
190
- prompt_embeds,
191
- pooled_prompt_embeds,
192
- output_type,
193
- return_dict,
194
- joint_attention_kwargs,
195
- callback_on_step_end,
196
- callback_on_step_end_tensor_inputs,
197
- max_sequence_length,
198
- verbose,
199
- ) = prepare_params(**params)
200
-
201
- height = height or self.default_sample_size * self.vae_scale_factor
202
- width = width or self.default_sample_size * self.vae_scale_factor
203
-
204
- # 1. Check inputs. Raise error if not correct
205
- self.check_inputs(
206
- prompt,
207
- prompt_2,
208
- height,
209
- width,
210
- prompt_embeds=prompt_embeds,
211
- pooled_prompt_embeds=pooled_prompt_embeds,
212
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
213
- max_sequence_length=max_sequence_length,
214
- )
215
-
216
- self._guidance_scale = guidance_scale
217
- self._joint_attention_kwargs = joint_attention_kwargs
218
- self._interrupt = False
219
-
220
- # 2. Define call parameters
221
- if prompt is not None and isinstance(prompt, str):
222
- batch_size = 1
223
- elif prompt is not None and isinstance(prompt, list):
224
- batch_size = len(prompt)
225
- else:
226
- batch_size = prompt_embeds.shape[0]
227
-
228
- device = self._execution_device
229
-
230
- lora_scale = (
231
- self.joint_attention_kwargs.get("scale", None)
232
- if self.joint_attention_kwargs is not None
233
- else None
234
- )
235
- (
236
- t5_prompt_embeds,
237
- pooled_prompt_embeds,
238
- text_ids,
239
- ) = encode_prompt_with_clip_t5(
240
- self=self,
241
- prompt="" if self.text_encoder_2 is None else prompt,
242
- prompt_2=None,
243
- prompt_embeds=prompt_embeds,
244
- pooled_prompt_embeds=pooled_prompt_embeds,
245
- device=device,
246
- num_images_per_prompt=num_images_per_prompt,
247
- max_sequence_length=max_sequence_length,
248
- lora_scale=lora_scale,
249
- )
250
-
251
- # 4. Prepare latent variables
252
- num_channels_latents = self.transformer.config.in_channels // 4
253
- latents, latent_image_ids = self.prepare_latents(
254
- batch_size * num_images_per_prompt,
255
- num_channels_latents,
256
- height,
257
- width,
258
- pooled_prompt_embeds.dtype,
259
- device,
260
- generator,
261
- latents,
262
- )
263
-
264
- latent_height = height // 16
265
-
266
- # 5. Prepare timesteps
267
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
268
- image_seq_len = latents.shape[1]
269
- mu = calculate_shift(
270
- image_seq_len,
271
- self.scheduler.config.base_image_seq_len,
272
- self.scheduler.config.max_image_seq_len,
273
- self.scheduler.config.base_shift,
274
- self.scheduler.config.max_shift,
275
- )
276
- timesteps, num_inference_steps = retrieve_timesteps(
277
- self.scheduler,
278
- num_inference_steps,
279
- device,
280
- timesteps,
281
- sigmas,
282
- mu=mu,
283
- )
284
- num_warmup_steps = max(
285
- len(timesteps) - num_inference_steps * self.scheduler.order, 0
286
- )
287
- self._num_timesteps = len(timesteps)
288
-
289
- attn_map = None
290
-
291
- # 6. Denoising loop
292
- with self.progress_bar(total=num_inference_steps) as progress_bar:
293
- totalsteps = timesteps[0]
294
- if control_weight_lambda is not None:
295
- print("control_weight_lambda", control_weight_lambda)
296
- control_weight_lambda_schedule = []
297
- for scale_str in control_weight_lambda.split(','):
298
- time_region, scale = scale_str.split(':')
299
- start, end = time_region.split('-')
300
- scales = [float(s) for s in scale.split('/')]
301
- control_weight_lambda_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, scales])
302
-
303
- if ip_scale is not None:
304
- print("ip_scale", ip_scale)
305
- ip_scale_schedule = []
306
- for scale_str in ip_scale.split(','):
307
- time_region, scale = scale_str.split(':')
308
- start, end = time_region.split('-')
309
- ip_scale_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, float(scale)])
310
-
311
- if use_latent_sblora_control:
312
- if latent_sblora_scale is not None:
313
- print("latent_sblora_scale", latent_sblora_scale)
314
- latent_sblora_scale_schedule = []
315
- for scale_str in latent_sblora_scale.split(','):
316
- time_region, scale = scale_str.split(':')
317
- start, end = time_region.split('-')
318
- latent_sblora_scale_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, float(scale)])
319
 
320
- if use_condition_sblora_control:
321
- if condition_sblora_scale is not None:
322
- print("condition_sblora_scale", condition_sblora_scale)
323
- condition_sblora_scale_schedule = []
324
- for scale_str in condition_sblora_scale.split(','):
325
- time_region, scale = scale_str.split(':')
326
- start, end = time_region.split('-')
327
- condition_sblora_scale_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, float(scale)])
328
-
329
-
330
- if vae_skip_iter is not None:
331
- print("vae_skip_iter", vae_skip_iter)
332
- vae_skip_iter_schedule = []
333
- for scale_str in vae_skip_iter.split(','):
334
- time_region, scale = scale_str.split(':')
335
- start, end = time_region.split('-')
336
- vae_skip_iter_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, float(scale)])
337
-
338
- if control_weight_lambda is not None and attn_map is None:
339
- batch_size = latents.shape[0]
340
- latent_width = latents.shape[1]//latent_height
341
- attn_map = torch.ones(batch_size, latent_height, latent_width, 128, device=latents.device, dtype=torch.bfloat16)
342
- print("contol_weight_only", attn_map.shape)
343
-
344
- self.scheduler.set_begin_index(0)
345
- self.scheduler._init_step_index(0)
346
- for i, t in enumerate(timesteps):
347
 
348
- if control_weight_lambda is not None:
349
- cur_control_weight_lambda = []
350
- for start, end, scale in control_weight_lambda_schedule:
351
- if t <= start and t >= end:
352
- cur_control_weight_lambda = scale
353
- break
354
- print(f"timestep:{t}, cur_control_weight_lambda:{cur_control_weight_lambda}")
355
 
356
- if cur_control_weight_lambda:
357
- model_config["use_attention_single"] = True
358
- use_attention = True
359
- model_config["use_atten_lambda"] = cur_control_weight_lambda
360
- else:
361
- model_config["use_attention_single"] = False
362
- use_attention = False
363
 
364
- if self.interrupt:
365
- continue
366
-
367
- if isinstance(delta_emb, list):
368
- cur_delta_emb = delta_emb[i]
369
- cur_delta_emb_pblock = delta_emb_pblock[i]
370
- cur_delta_emb_mask = delta_emb_mask[i]
371
- else:
372
- cur_delta_emb = delta_emb
373
- cur_delta_emb_pblock = delta_emb_pblock
374
- cur_delta_emb_mask = delta_emb_mask
375
-
376
-
377
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
378
- timestep = t.expand(latents.shape[0]).to(latents.dtype) / 1000
379
- prompt_embeds = t5_prompt_embeds
380
- text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=prompt_embeds.dtype)
381
-
382
- # handle guidance
383
- if self.transformer.config.guidance_embeds:
384
- guidance = torch.tensor([guidance_scale], device=device)
385
- guidance = guidance.expand(latents.shape[0])
386
- else:
387
- guidance = None
388
- self.transformer.enable_lora()
389
 
390
- lora_weight = 1
391
- if ip_scale is not None:
392
- lora_weight = 0
393
- for start, end, scale in ip_scale_schedule:
394
- if t <= start and t >= end:
395
- lora_weight = scale
396
- break
397
- if lora_weight != 1: print(f"timestep:{t}, lora_weights:{lora_weight}")
398
 
399
- latent_sblora_weight = None
400
- if use_latent_sblora_control:
401
- if latent_sblora_scale is not None:
402
- latent_sblora_weight = 0
403
- for start, end, scale in latent_sblora_scale_schedule:
404
- if t <= start and t >= end:
405
- latent_sblora_weight = scale
406
- break
407
- if latent_sblora_weight != 1: print(f"timestep:{t}, latent_sblora_weight:{latent_sblora_weight}")
408
 
409
- condition_sblora_weight = None
410
- if use_condition_sblora_control:
411
- if condition_sblora_scale is not None:
412
- condition_sblora_weight = 0
413
- for start, end, scale in condition_sblora_scale_schedule:
414
- if t <= start and t >= end:
415
- condition_sblora_weight = scale
416
- break
417
- if condition_sblora_weight !=1: print(f"timestep:{t}, condition_sblora_weight:{condition_sblora_weight}")
418
-
419
- vae_skip_iter_t = False
420
- if vae_skip_iter is not None:
421
- for start, end, scale in vae_skip_iter_schedule:
422
- if t <= start and t >= end:
423
- vae_skip_iter_t = bool(scale)
424
- break
425
- if vae_skip_iter_t:
426
- print(f"timestep:{t}, skip vae:{vae_skip_iter_t}")
427
-
428
- noise_pred = tranformer_forward(
429
- self.transformer,
430
- model_config=model_config,
431
- # Inputs of the condition (new feature)
432
- text_cond_mask=text_cond_mask,
433
- delta_emb=cur_delta_emb,
434
- delta_emb_pblock=cur_delta_emb_pblock,
435
- delta_emb_mask=cur_delta_emb_mask,
436
- delta_start_ends=delta_start_ends,
437
- condition_latents=None if vae_skip_iter_t else condition_latents,
438
- condition_ids=None if vae_skip_iter_t else condition_ids,
439
- condition_type_ids=None,
440
- # Inputs to the original transformer
441
- hidden_states=latents,
442
- # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
443
- timestep=timestep,
444
- guidance=guidance,
445
- pooled_projections=pooled_prompt_embeds,
446
- encoder_hidden_states=prompt_embeds,
447
- txt_ids=text_ids,
448
- img_ids=latent_image_ids,
449
- joint_attention_kwargs={'scale': lora_weight, "latent_sblora_weight": latent_sblora_weight, "condition_sblora_weight": condition_sblora_weight},
450
- store_attn_map=use_attention,
451
- last_attn_map=attn_map if cur_control_weight_lambda else None,
452
- use_text_mod=model_config["modulation"]["use_text_mod"],
453
- use_img_mod=model_config["modulation"]["use_img_mod"],
454
- mod_adapter=mod_adapter,
455
- latent_height=latent_height,
456
- return_dict=False,
457
- )[0]
458
-
459
- if use_attention:
460
- attn_maps, _ = gather_attn_maps(self.transformer, clear=True)
461
-
462
- # compute the previous noisy sample x_t -> x_t-1
463
- latents_dtype = latents.dtype
464
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
465
-
466
- if latents.dtype != latents_dtype:
467
- if torch.backends.mps.is_available():
468
- # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
469
- latents = latents.to(latents_dtype)
470
-
471
- if callback_on_step_end is not None:
472
- callback_kwargs = {}
473
- for k in callback_on_step_end_tensor_inputs:
474
- callback_kwargs[k] = locals()[k]
475
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
476
-
477
- latents = callback_outputs.pop("latents", latents)
478
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
479
-
480
- # call the callback, if provided
481
- if i == len(timesteps) - 1 or (
482
- (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
483
- ):
484
- progress_bar.update()
485
-
486
- if output_type == "latent":
487
- image = latents
488
-
489
- else:
490
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
491
- latents = (
492
- latents / self.vae.config.scaling_factor
493
- ) + self.vae.config.shift_factor
494
- image = self.vae.decode(latents, return_dict=False)[0]
495
- image = self.image_processor.postprocess(image, output_type=output_type)
496
-
497
- # Offload all models
498
- self.maybe_free_model_hooks()
499
-
500
- self.transformer.enable_lora()
501
-
502
- if vae_condition_scale != 1:
503
- for name, module in pipeline.transformer.named_modules():
504
- if not name.endswith(".attn"):
505
- continue
506
- del module.c_factor
507
-
508
- if not return_dict:
509
- return (image,)
510
-
511
- return FluxPipelineOutput(images=image)
512
-
513
-
514
- @torch.no_grad()
515
- def generate_from_test_sample(
516
- test_sample, pipe, config,
517
- num_images=1,
518
- vae_skip_iter: str = None,
519
- target_height: int = None,
520
- target_width: int = None,
521
- seed: int = 42,
522
- control_weight_lambda: str = None,
523
- double_attention: bool = False,
524
- single_attention: bool = False,
525
- ip_scale: str = None,
526
- use_latent_sblora_control: bool = False,
527
- latent_sblora_scale: str = None,
528
- use_condition_sblora_control: bool = False,
529
- condition_sblora_scale: str = None,
530
- use_idip = False,
531
- **kargs
532
- ):
533
- target_size = config["train"]["dataset"]["val_target_size"]
534
- condition_size = config["train"]["dataset"].get("val_condition_size", target_size//2)
535
- condition_pad_to = config["train"]["dataset"]["condition_pad_to"]
536
- pos_offset_type = config["model"].get("pos_offset_type", "width")
537
- seed = config["model"].get("seed", seed)
538
-
539
- device = pipe._execution_device
540
-
541
- condition_imgs = test_sample['input_images']
542
- position_delta = test_sample['position_delta']
543
- prompt = test_sample['prompt']
544
- original_image = test_sample.get('original_image', None)
545
- condition_type = test_sample.get('condition_type', "subject")
546
- modulation_input = test_sample.get('modulation', None)
547
-
548
- delta_start_ends = None
549
- condition_latents = condition_ids = None
550
- text_cond_mask = None
551
 
552
- delta_embs = None
553
- delta_embs_pblock = None
554
- delta_embs_mask = None
555
-
556
- try:
557
- max_length = config["model"]["modulation"]["max_text_len"]
558
- except Exception as e:
559
- print(e)
560
- max_length = 512
561
-
562
- if modulation_input is None or len(modulation_input) == 0:
563
- delta_emb = delta_emb_pblock = delta_emb_mask = None
564
- else:
565
- dtype = torch.bfloat16
566
- batch_size = 1
567
- N = config["model"]["modulation"].get("per_block_adapter_single_blocks", 0) + 19
568
- guidance = torch.tensor([3.5]).to(device).expand(batch_size)
569
- out_dim = config["model"]["modulation"]["out_dim"]
570
-
571
- tar_text_inputs = tokenize_t5_prompt(pipe, prompt, max_length)
572
- tar_padding_mask = tar_text_inputs.attention_mask.to(device).bool()
573
- tar_tokens = tar_text_inputs.input_ids.to(device)
574
- if config["model"]["modulation"]["eos_exclude"]:
575
- tar_padding_mask[tar_tokens == 1] = False
576
-
577
- def get_start_end_by_pompt_matching(src_prompts, tar_prompts):
578
- text_cond_mask = torch.zeros(batch_size, max_length, device=device, dtype=torch.bool)
579
- tar_prompt_input_ids = tokenize_t5_prompt(pipe, tar_prompts, max_length).input_ids
580
- src_prompt_count = 1
581
- start_ends = []
582
- for i, (src_prompt, tar_prompt, tar_prompt_tokens) in enumerate(zip(src_prompts, tar_prompts, tar_prompt_input_ids)):
583
- try:
584
- tar_start, tar_end = get_word_index(pipe, tar_prompt, tar_prompt_tokens, src_prompt, src_prompt_count, max_length, verbose=False)
585
- start_ends.append([tar_start, tar_end])
586
- text_cond_mask[i, tar_start:tar_end] = True
587
- except Exception as e:
588
- print(e)
589
- return start_ends, text_cond_mask
590
-
591
- def encode_mod_image(pil_images):
592
- if config["model"]["modulation"]["use_dit"]:
593
- raise NotImplementedError()
594
- else:
595
- pil_images = [pad_to_square(img).resize((224, 224)) for img in pil_images]
596
- if config["model"]["modulation"]["use_vae"]:
597
- raise NotImplementedError()
598
- else:
599
- clip_pixel_values = pipe.clip_processor(
600
- text=None, images=pil_images, do_resize=False, do_center_crop=False, return_tensors="pt",
601
- ).pixel_values.to(dtype=dtype, device=device)
602
- clip_outputs = pipe.clip_model(clip_pixel_values, output_hidden_states=True, interpolate_pos_encoding=True, return_dict=True)
603
- return clip_outputs
604
-
605
- def rgba_to_white_background(input_path, background=(255,255,255)):
606
- with Image.open(input_path).convert("RGBA") as img:
607
- img_np = np.array(img)
608
- alpha = img_np[:, :, 3] / 255.0 # normalize the alpha channel to [0, 1]
609
- rgb = img_np[:, :, :3].astype(float) # extract the RGB channels
610
 
611
- background_np = np.full_like(rgb, background, dtype=float) # build the background from the given color
612
 
613
- # alpha blend: foreground * alpha + background * (1 - alpha)
614
- result_np = rgb * alpha[..., np.newaxis] + \
615
- background_np * (1 - alpha[..., np.newaxis])
616
 
617
- result = Image.fromarray(result_np.astype(np.uint8), "RGB")
618
- return result
619
- def get_mod_emb(modulation_input, timestep):
620
- delta_emb = torch.zeros((batch_size, max_length, out_dim), dtype=dtype, device=device)
621
- delta_emb_pblock = torch.zeros((batch_size, max_length, N, out_dim), dtype=dtype, device=device)
622
- delta_emb_mask = torch.zeros((batch_size, max_length), dtype=torch.bool, device=device)
623
- delta_start_ends = None
624
- condition_latents = condition_ids = None
625
- text_cond_mask = None
626
-
627
- if modulation_input[0]["type"] == "adapter":
628
- num_inputs = len(modulation_input[0]["src_inputs"])
629
- src_prompts = [x["caption"] for x in modulation_input[0]["src_inputs"]]
630
- src_text_inputs = tokenize_t5_prompt(pipe, src_prompts, max_length)
631
- src_input_ids = unpad_input_ids(src_text_inputs.input_ids, src_text_inputs.attention_mask)
632
- tar_input_ids = unpad_input_ids(tar_text_inputs.input_ids, tar_text_inputs.attention_mask)
633
- src_prompt_embeds = pipe._get_t5_prompt_embeds(prompt=src_prompts, max_sequence_length=max_length, device=device) # (M, 512, 4096)
634
 
635
- pil_images = [rgba_to_white_background(x["image_path"]) for x in modulation_input[0]["src_inputs"]]
636
-
637
- src_ds_scales = [x.get("downsample_scale", 1.0) for x in modulation_input[0]["src_inputs"]]
638
- resized_pil_images = []
639
- for img, ds_scale in zip(pil_images, src_ds_scales):
640
- img = pad_to_square(img)
641
- if ds_scale < 1.0:
642
- assert ds_scale > 0
643
- img = img.resize((int(224 * ds_scale), int(224 * ds_scale))).resize((224, 224))
644
- resized_pil_images.append(img)
645
- pil_images = resized_pil_images
646
 
647
- img_encoded = encode_mod_image(pil_images)
648
- delta_start_ends = []
649
- text_cond_mask = torch.zeros(num_inputs, max_length, device=device, dtype=torch.bool)
650
- if config["model"]["modulation"]["pass_vae"]:
651
- pil_images = [pad_to_square(img).resize((condition_size, condition_size)) for img in pil_images]
652
- with torch.no_grad():
653
- batch_tensor = torch.stack([pil2tensor(x) for x in pil_images])
654
- x_0, img_ids = encode_vae_images(pipe, batch_tensor) # (N, 256, 64)
655
-
656
- condition_latents = x_0.clone().detach().reshape(1, -1, 64) # (1, N256, 64)
657
- condition_ids = img_ids.clone().detach()
658
- condition_ids = condition_ids.unsqueeze(0).repeat_interleave(num_inputs, dim=0) # (N, 256, 3)
659
- for i in range(num_inputs):
660
- condition_ids[i, :, 1] += 0 if pos_offset_type == "width" else -(batch_tensor.shape[-1]//16) * (i + 1)
661
- condition_ids[i, :, 2] += -(batch_tensor.shape[-1]//16) * (i + 1)
662
- condition_ids = condition_ids.reshape(-1, 3) # (N256, 3)
663
-
664
- if config["model"]["modulation"]["use_dit"]:
665
- raise NotImplementedError()
666
- else:
667
- src_delta_embs = [] # [(512, 3072)]
668
- src_delta_emb_pblock = []
669
- for i in range(num_inputs):
670
- if isinstance(img_encoded, dict):
671
- _src_clip_outputs = {}
672
- for key in img_encoded:
673
- if torch.is_tensor(img_encoded[key]):
674
- _src_clip_outputs[key] = img_encoded[key][i:i+1]
675
- else:
676
- _src_clip_outputs[key] = [x[i:i+1] for x in img_encoded[key]]
677
- _img_encoded = _src_clip_outputs
678
- else:
679
- _img_encoded = img_encoded[i:i+1]
680
 
681
- x1, x2 = pipe.modulation_adapters[0](timestep, src_prompt_embeds[i:i+1], _img_encoded)
682
- src_delta_embs.append(x1[0]) # (512, 3072)
683
- src_delta_emb_pblock.append(x2[0]) # (512, N, 3072)
684
-
685
- for input_args in modulation_input[0]["use_words"]:
686
- src_word_count = 1
687
- if len(input_args) == 3:
688
- src_input_index, src_word, tar_word = input_args
689
- tar_word_count = 1
690
- else:
691
- src_input_index, src_word, tar_word, tar_word_count = input_args[:4]
692
- src_prompt = src_prompts[src_input_index]
693
- tar_prompt = prompt
694
-
695
- src_start, src_end = get_word_index(pipe, src_prompt, src_input_ids[src_input_index], src_word, src_word_count, max_length, verbose=False)
696
- tar_start, tar_end = get_word_index(pipe, tar_prompt, tar_input_ids[0], tar_word, tar_word_count, max_length, verbose=False)
697
- if delta_emb is not None:
698
- delta_emb[:, tar_start:tar_end] = src_delta_embs[src_input_index][src_start:src_end] # (B, 512, 3072)
699
- if delta_emb_pblock is not None:
700
- delta_emb_pblock[:, tar_start:tar_end] = src_delta_emb_pblock[src_input_index][src_start:src_end] # (B, 512, N, 3072)
701
- delta_emb_mask[:, tar_start:tar_end] = True
702
- text_cond_mask[src_input_index, tar_start:tar_end] = True
703
- delta_start_ends.append([0, src_input_index, src_start, src_end, tar_start, tar_end])
704
- text_cond_mask = text_cond_mask.transpose(0, 1).unsqueeze(0)
705
-
706
- else:
707
- raise NotImplementedError()
708
- return delta_emb, delta_emb_pblock, delta_emb_mask, \
709
- text_cond_mask, delta_start_ends, condition_latents, condition_ids
710
 
711
- num_inference_steps = 28 # FIXME: hardcoded here
712
- num_channels_latents = pipe.transformer.config.in_channels // 4
713
-
714
- # set timesteps
715
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
716
- mu = calculate_shift(
717
- num_channels_latents,
718
- pipe.scheduler.config.base_image_seq_len,
719
- pipe.scheduler.config.max_image_seq_len,
720
- pipe.scheduler.config.base_shift,
721
- pipe.scheduler.config.max_shift,
722
- )
723
- timesteps, num_inference_steps = retrieve_timesteps(
724
- pipe.scheduler,
725
- num_inference_steps,
726
- device,
727
- None,
728
- sigmas,
729
- mu=mu,
730
- )
731
-
732
- if modulation_input is not None:
733
- delta_embs = []
734
- delta_embs_pblock = []
735
- delta_embs_mask = []
736
- for i, t in enumerate(timesteps):
737
- t = t.expand(1).to(torch.bfloat16) / 1000
738
- (
739
- delta_emb, delta_emb_pblock, delta_emb_mask,
740
- text_cond_mask, delta_start_ends,
741
- condition_latents, condition_ids
742
- ) = get_mod_emb(modulation_input, t)
743
- delta_embs.append(delta_emb)
744
- delta_embs_pblock.append(delta_emb_pblock)
745
- delta_embs_mask.append(delta_emb_mask)
746
-
747
- if original_image is not None:
748
- raise NotImplementedError()
749
- (target_height, target_width), closest_ratio = get_closest_ratio(original_image.height, original_image.width, train_aspect_ratios)
750
- elif modulation_input is None or len(modulation_input) == 0:
751
- delta_emb = delta_emb_pblock = delta_emb_mask = None
752
- else:
753
- for i, t in enumerate(timesteps):
754
- t = t.expand(1).to(torch.bfloat16) / 1000
755
- (
756
- delta_emb, delta_emb_pblock, delta_emb_mask,
757
- text_cond_mask, delta_start_ends,
758
- condition_latents, condition_ids
759
- ) = get_mod_emb(modulation_input, t)
760
- delta_embs.append(delta_emb)
761
- delta_embs_pblock.append(delta_emb_pblock)
762
- delta_embs_mask.append(delta_emb_mask)
763
-
764
- if target_height is None or target_width is None:
765
- target_height = target_width = target_size
766
-
767
- if condition_pad_to == "square":
768
- condition_imgs = [pad_to_square(x) for x in condition_imgs]
769
- elif condition_pad_to == "target":
770
- condition_imgs = [pad_to_target(x, (target_size, target_size)) for x in condition_imgs]
771
- condition_imgs = [x.resize((condition_size, condition_size)).convert("RGB") for x in condition_imgs]
772
- # TODO: fix position_delta
773
- conditions = [
774
- Condition(
775
- condition_type=condition_type,
776
- condition=x,
777
- position_delta=position_delta,
778
- ) for x in condition_imgs
779
- ]
780
- # vlm_images = condition_imgs if config["model"]["use_vlm"] else []
781
-
782
- use_perblock_adapter = False
783
- try:
784
- if config["model"]["modulation"]["use_perblock_adapter"]:
785
- use_perblock_adapter = True
786
- except Exception as e:
787
- pass
788
-
789
- results = []
790
- for i in range(num_images):
791
- clear_attn_maps(pipe.transformer)
792
- generator = torch.Generator(device=device)
793
- generator.manual_seed(seed + i)
794
- if modulation_input is None or len(modulation_input) == 0:
795
- idips = None
796
- else:
797
- idips = ["human" in p["image_path"] for p in modulation_input[0]["src_inputs"]]
798
- if len(modulation_input[0]["use_words"][0])==5:
799
- print("use idips in use_words")
800
- idips = [x[-1] for x in modulation_input[0]["use_words"]]
801
- result_img = generate(
802
- pipe,
803
- prompt=prompt,
804
- max_sequence_length=max_length,
805
- vae_conditions=conditions,
806
- generator=generator,
807
- model_config=config["model"],
808
- height=target_height,
809
- width=target_width,
810
- condition_pad_to=condition_pad_to,
811
- condition_size=condition_size,
812
- text_cond_mask=text_cond_mask,
813
- delta_emb=delta_embs,
814
- delta_emb_pblock=delta_embs_pblock if use_perblock_adapter else None,
815
- delta_emb_mask=delta_embs_mask,
816
- delta_start_ends=delta_start_ends,
817
- condition_latents=condition_latents,
818
- condition_ids=condition_ids,
819
- mod_adapter=pipe.modulation_adapters[0] if config["model"]["modulation"]["use_dit"] else None,
820
- vae_skip_iter=vae_skip_iter,
821
- control_weight_lambda=control_weight_lambda,
822
- double_attention=double_attention,
823
- single_attention=single_attention,
824
- ip_scale=ip_scale,
825
- use_latent_sblora_control=use_latent_sblora_control,
826
- latent_sblora_scale=latent_sblora_scale,
827
- use_condition_sblora_control=use_condition_sblora_control,
828
- condition_sblora_scale=condition_sblora_scale,
829
- idips=idips if use_idip else None,
830
- **kargs,
831
- ).images[0]
832
-
833
- final_image = result_img
834
- results.append(final_image)
835
-
836
- if num_images == 1:
837
- return results[0]
838
- return results
 
21
  from src.flux.transformer import tranformer_forward
22
  from src.flux.condition import Condition
23
 
24
+ # # from diffusers.pipelines.flux.pipeline_flux import (
25
+ # # FluxPipelineOutput,
26
+ # # calculate_shift,
27
+ # # retrieve_timesteps,
28
+ # # np,
29
+ # # )
30
+ # from src.flux.pipeline_tools import (
31
+ # encode_prompt_with_clip_t5, tokenize_t5_prompt, clear_attn_maps, encode_vae_images
32
  # )
33
+
34
+ # from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, decode_vae_images, \
35
+ # save_attention_maps, gather_attn_maps, clear_attn_maps, load_dit_lora, quantization
36
+
37
+ # from src.utils.data_utils import pad_to_square, pad_to_target, pil2tensor, get_closest_ratio, get_aspect_ratios
38
+ # from src.utils.modulation_utils import get_word_index, unpad_input_ids
39
+
40
+ # def get_config(config_path: str = None):
41
+ # config_path = config_path or os.environ.get("XFL_CONFIG")
42
+ # if not config_path:
43
+ # return {}
44
+ # with open(config_path, "r") as f:
45
+ # config = yaml.safe_load(f)
46
+ # return config
47
+
48
+
49
+ # def prepare_params(
50
+ # prompt: Union[str, List[str]] = None,
51
+ # prompt_2: Optional[Union[str, List[str]]] = None,
52
+ # height: Optional[int] = 512,
53
+ # width: Optional[int] = 512,
54
+ # num_inference_steps: int = 28,
55
+ # timesteps: List[int] = None,
56
+ # guidance_scale: float = 3.5,
57
+ # num_images_per_prompt: Optional[int] = 1,
58
+ # generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
59
+ # latents: Optional[torch.FloatTensor] = None,
60
+ # prompt_embeds: Optional[torch.FloatTensor] = None,
61
+ # pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
62
+ # output_type: Optional[str] = "pil",
63
+ # return_dict: bool = True,
64
+ # joint_attention_kwargs: Optional[Dict[str, Any]] = None,
65
+ # callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
66
+ # callback_on_step_end_tensor_inputs: List[str] = ["latents"],
67
+ # max_sequence_length: int = 512,
68
+ # verbose: bool = False,
69
+ # **kwargs: dict,
70
+ # ):
71
+ # return (
72
+ # prompt,
73
+ # prompt_2,
74
+ # height,
75
+ # width,
76
+ # num_inference_steps,
77
+ # timesteps,
78
+ # guidance_scale,
79
+ # num_images_per_prompt,
80
+ # generator,
81
+ # latents,
82
+ # prompt_embeds,
83
+ # pooled_prompt_embeds,
84
+ # output_type,
85
+ # return_dict,
86
+ # joint_attention_kwargs,
87
+ # callback_on_step_end,
88
+ # callback_on_step_end_tensor_inputs,
89
+ # max_sequence_length,
90
+ # verbose,
91
+ # )
92
+
93
+
94
+ # def seed_everything(seed: int = 42):
95
+ # torch.backends.cudnn.deterministic = True
96
+ # torch.manual_seed(seed)
97
+ # np.random.seed(seed)
98
+
99
+
100
+ # @torch.no_grad()
101
+ # def generate(
102
+ # pipeline: FluxPipeline,
103
+ # vae_conditions: List[Condition] = None,
104
+ # config_path: str = None,
105
+ # model_config: Optional[Dict[str, Any]] = {},
106
+ # vae_condition_scale: float = 1.0,
107
+ # default_lora: bool = False,
108
+ # condition_pad_to: str = "square",
109
+ # condition_size: int = 512,
110
+ # text_cond_mask: Optional[torch.FloatTensor] = None,
111
+ # delta_emb: Optional[torch.FloatTensor] = None,
112
+ # delta_emb_pblock: Optional[torch.FloatTensor] = None,
113
+ # delta_emb_mask: Optional[torch.FloatTensor] = None,
114
+ # delta_start_ends = None,
115
+ # condition_latents = None,
116
+ # condition_ids = None,
117
+ # mod_adapter = None,
118
+ # store_attn_map: bool = False,
119
+ # vae_skip_iter: str = None,
120
+ # control_weight_lambda: str = None,
121
+ # double_attention: bool = False,
122
+ # single_attention: bool = False,
123
+ # ip_scale: str = None,
124
+ # use_latent_sblora_control: bool = False,
125
+ # latent_sblora_scale: str = None,
126
+ # use_condition_sblora_control: bool = False,
127
+ # condition_sblora_scale: str = None,
128
+ # idips = None,
129
+ # **params: dict,
130
+ # ):
131
+ # model_config = model_config or get_config(config_path).get("model", {})
132
+
133
+ # vae_skip_iter = model_config.get("vae_skip_iter", vae_skip_iter)
134
+ # double_attention = model_config.get("double_attention", double_attention)
135
+ # single_attention = model_config.get("single_attention", single_attention)
136
+ # control_weight_lambda = model_config.get("control_weight_lambda", control_weight_lambda)
137
+ # ip_scale = model_config.get("ip_scale", ip_scale)
138
+ # use_latent_sblora_control = model_config.get("use_latent_sblora_control", use_latent_sblora_control)
139
+ # use_condition_sblora_control = model_config.get("use_condition_sblora_control", use_condition_sblora_control)
140
+
141
+ # latent_sblora_scale = model_config.get("latent_sblora_scale", latent_sblora_scale)
142
+ # condition_sblora_scale = model_config.get("condition_sblora_scale", condition_sblora_scale)
143
+
144
+ # model_config["use_attention_double"] = False
145
+ # model_config["use_attention_single"] = False
146
+ # use_attention = False
 
 
 
147
 
148
+ # if idips is not None:
149
+ # if control_weight_lambda != "no":
150
+ # parts = control_weight_lambda.split(',')
151
+ # new_parts = []
152
+ # for part in parts:
153
+ # if ':' in part:
154
+ # left, right = part.split(':')
155
+ # values = right.split('/')
156
+ # # save the overall value
157
+ # global_value = values[0]
158
+ # id_value = values[1]
159
+ # ip_value = values[2]
160
+ # new_values = [global_value]
161
+ # for is_id in idips:
162
+ # if is_id:
163
+ # new_values.append(id_value)
164
+ # else:
165
+ # new_values.append(ip_value)
166
+ # new_part = f"{left}:{('/'.join(new_values))}"
167
+ # new_parts.append(new_part)
168
+ # else:
169
+ # new_parts.append(part)
170
+ # control_weight_lambda = ','.join(new_parts)
171
+
172
+ # if vae_condition_scale != 1:
173
+ # for name, module in pipeline.transformer.named_modules():
174
+ # if not name.endswith(".attn"):
175
+ # continue
176
+ # module.c_factor = torch.ones(1, 1) * vae_condition_scale
177
+
178
+ # self = pipeline
179
+ # (
180
+ # prompt,
181
+ # prompt_2,
182
+ # height,
183
+ # width,
184
+ # num_inference_steps,
185
+ # timesteps,
186
+ # guidance_scale,
187
+ # num_images_per_prompt,
188
+ # generator,
189
+ # latents,
190
+ # prompt_embeds,
191
+ # pooled_prompt_embeds,
192
+ # output_type,
193
+ # return_dict,
194
+ # joint_attention_kwargs,
195
+ # callback_on_step_end,
196
+ # callback_on_step_end_tensor_inputs,
197
+ # max_sequence_length,
198
+ # verbose,
199
+ # ) = prepare_params(**params)
200
+
201
+ # height = height or self.default_sample_size * self.vae_scale_factor
202
+ # width = width or self.default_sample_size * self.vae_scale_factor
203
+
204
+ # # 1. Check inputs. Raise error if not correct
205
+ # self.check_inputs(
206
+ # prompt,
207
+ # prompt_2,
208
+ # height,
209
+ # width,
210
+ # prompt_embeds=prompt_embeds,
211
+ # pooled_prompt_embeds=pooled_prompt_embeds,
212
+ # callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
213
+ # max_sequence_length=max_sequence_length,
214
+ # )
215
+
216
+ # self._guidance_scale = guidance_scale
217
+ # self._joint_attention_kwargs = joint_attention_kwargs
218
+ # self._interrupt = False
219
+
220
+ # # 2. Define call parameters
221
+ # if prompt is not None and isinstance(prompt, str):
222
+ # batch_size = 1
223
+ # elif prompt is not None and isinstance(prompt, list):
224
+ # batch_size = len(prompt)
225
+ # else:
226
+ # batch_size = prompt_embeds.shape[0]
227
+
228
+ # device = self._execution_device
229
+
230
+ # lora_scale = (
231
+ # self.joint_attention_kwargs.get("scale", None)
232
+ # if self.joint_attention_kwargs is not None
233
+ # else None
234
+ # )
235
+ # (
236
+ # t5_prompt_embeds,
237
+ # pooled_prompt_embeds,
238
+ # text_ids,
239
+ # ) = encode_prompt_with_clip_t5(
240
+ # self=self,
241
+ # prompt="" if self.text_encoder_2 is None else prompt,
242
+ # prompt_2=None,
243
+ # prompt_embeds=prompt_embeds,
244
+ # pooled_prompt_embeds=pooled_prompt_embeds,
245
+ # device=device,
246
+ # num_images_per_prompt=num_images_per_prompt,
247
+ # max_sequence_length=max_sequence_length,
248
+ # lora_scale=lora_scale,
249
+ # )
250
+
251
+ # # 4. Prepare latent variables
252
+ # num_channels_latents = self.transformer.config.in_channels // 4
253
+ # latents, latent_image_ids = self.prepare_latents(
254
+ # batch_size * num_images_per_prompt,
255
+ # num_channels_latents,
256
+ # height,
257
+ # width,
258
+ # pooled_prompt_embeds.dtype,
259
+ # device,
260
+ # generator,
261
+ # latents,
262
+ # )
263
+
264
+ # latent_height = height // 16
265
+
266
+ # # 5. Prepare timesteps
267
+ # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
268
+ # image_seq_len = latents.shape[1]
269
+ # mu = calculate_shift(
270
+ # image_seq_len,
271
+ # self.scheduler.config.base_image_seq_len,
272
+ # self.scheduler.config.max_image_seq_len,
273
+ # self.scheduler.config.base_shift,
274
+ # self.scheduler.config.max_shift,
275
+ # )
276
+ # timesteps, num_inference_steps = retrieve_timesteps(
277
+ # self.scheduler,
278
+ # num_inference_steps,
279
+ # device,
280
+ # timesteps,
281
+ # sigmas,
282
+ # mu=mu,
283
+ # )
284
+ # num_warmup_steps = max(
285
+ # len(timesteps) - num_inference_steps * self.scheduler.order, 0
286
+ # )
287
+ # self._num_timesteps = len(timesteps)
288
+
289
+ # attn_map = None
290
+
291
+ # # 6. Denoising loop
292
+ # with self.progress_bar(total=num_inference_steps) as progress_bar:
293
+ # totalsteps = timesteps[0]
294
+ # if control_weight_lambda is not None:
295
+ # print("control_weight_lambda", control_weight_lambda)
296
+ # control_weight_lambda_schedule = []
297
+ # for scale_str in control_weight_lambda.split(','):
298
+ # time_region, scale = scale_str.split(':')
299
+ # start, end = time_region.split('-')
300
+ # scales = [float(s) for s in scale.split('/')]
301
+ # control_weight_lambda_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, scales])
302
+
303
+ # if ip_scale is not None:
304
+ # print("ip_scale", ip_scale)
305
+ # ip_scale_schedule = []
306
+ # for scale_str in ip_scale.split(','):
307
+ # time_region, scale = scale_str.split(':')
308
+ # start, end = time_region.split('-')
309
+ # ip_scale_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, float(scale)])
310
+
311
+ # if use_latent_sblora_control:
312
+ # if latent_sblora_scale is not None:
313
+ # print("latent_sblora_scale", latent_sblora_scale)
314
+ # latent_sblora_scale_schedule = []
315
+ # for scale_str in latent_sblora_scale.split(','):
316
+ # time_region, scale = scale_str.split(':')
317
+ # start, end = time_region.split('-')
318
+ # latent_sblora_scale_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, float(scale)])
319
 
320
+ # if use_condition_sblora_control:
321
+ # if condition_sblora_scale is not None:
322
+ # print("condition_sblora_scale", condition_sblora_scale)
323
+ # condition_sblora_scale_schedule = []
324
+ # for scale_str in condition_sblora_scale.split(','):
325
+ # time_region, scale = scale_str.split(':')
326
+ # start, end = time_region.split('-')
327
+ # condition_sblora_scale_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, float(scale)])
328
+
329
+
330
+ # if vae_skip_iter is not None:
331
+ # print("vae_skip_iter", vae_skip_iter)
332
+ # vae_skip_iter_schedule = []
333
+ # for scale_str in vae_skip_iter.split(','):
334
+ # time_region, scale = scale_str.split(':')
335
+ # start, end = time_region.split('-')
336
+ # vae_skip_iter_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, float(scale)])
337
+
338
+ # if control_weight_lambda is not None and attn_map is None:
339
+ # batch_size = latents.shape[0]
340
+ # latent_width = latents.shape[1]//latent_height
341
+ # attn_map = torch.ones(batch_size, latent_height, latent_width, 128, device=latents.device, dtype=torch.bfloat16)
342
+ # print("contol_weight_only", attn_map.shape)
343
+
344
+ # self.scheduler.set_begin_index(0)
345
+ # self.scheduler._init_step_index(0)
346
+ # for i, t in enumerate(timesteps):
347
 
348
+ # if control_weight_lambda is not None:
349
+ # cur_control_weight_lambda = []
350
+ # for start, end, scale in control_weight_lambda_schedule:
351
+ # if t <= start and t >= end:
352
+ # cur_control_weight_lambda = scale
353
+ # break
354
+ # print(f"timestep:{t}, cur_control_weight_lambda:{cur_control_weight_lambda}")
355
 
356
+ # if cur_control_weight_lambda:
357
+ # model_config["use_attention_single"] = True
358
+ # use_attention = True
359
+ # model_config["use_atten_lambda"] = cur_control_weight_lambda
360
+ # else:
361
+ # model_config["use_attention_single"] = False
362
+ # use_attention = False
363
 
364
+ # if self.interrupt:
365
+ # continue
366
+
367
+ # if isinstance(delta_emb, list):
368
+ # cur_delta_emb = delta_emb[i]
369
+ # cur_delta_emb_pblock = delta_emb_pblock[i]
370
+ # cur_delta_emb_mask = delta_emb_mask[i]
371
+ # else:
372
+ # cur_delta_emb = delta_emb
373
+ # cur_delta_emb_pblock = delta_emb_pblock
374
+ # cur_delta_emb_mask = delta_emb_mask
375
+
376
+
377
+ # # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
378
+ # timestep = t.expand(latents.shape[0]).to(latents.dtype) / 1000
379
+ # prompt_embeds = t5_prompt_embeds
380
+ # text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=prompt_embeds.dtype)
381
+
382
+ # # handle guidance
383
+ # if self.transformer.config.guidance_embeds:
384
+ # guidance = torch.tensor([guidance_scale], device=device)
385
+ # guidance = guidance.expand(latents.shape[0])
386
+ # else:
387
+ # guidance = None
388
+ # self.transformer.enable_lora()
389
 
390
+ # lora_weight = 1
391
+ # if ip_scale is not None:
392
+ # lora_weight = 0
393
+ # for start, end, scale in ip_scale_schedule:
394
+ # if t <= start and t >= end:
395
+ # lora_weight = scale
396
+ # break
397
+ # if lora_weight != 1: print(f"timestep:{t}, lora_weights:{lora_weight}")
398
 
399
+ # latent_sblora_weight = None
400
+ # if use_latent_sblora_control:
401
+ # if latent_sblora_scale is not None:
402
+ # latent_sblora_weight = 0
403
+ # for start, end, scale in latent_sblora_scale_schedule:
404
+ # if t <= start and t >= end:
405
+ # latent_sblora_weight = scale
406
+ # break
407
+ # if latent_sblora_weight != 1: print(f"timestep:{t}, latent_sblora_weight:{latent_sblora_weight}")
408
 
409
+ # condition_sblora_weight = None
410
+ # if use_condition_sblora_control:
411
+ # if condition_sblora_scale is not None:
412
+ # condition_sblora_weight = 0
413
+ # for start, end, scale in condition_sblora_scale_schedule:
414
+ # if t <= start and t >= end:
415
+ # condition_sblora_weight = scale
416
+ # break
417
+ # if condition_sblora_weight !=1: print(f"timestep:{t}, condition_sblora_weight:{condition_sblora_weight}")
418
+
419
+ # vae_skip_iter_t = False
420
+ # if vae_skip_iter is not None:
421
+ # for start, end, scale in vae_skip_iter_schedule:
422
+ # if t <= start and t >= end:
423
+ # vae_skip_iter_t = bool(scale)
424
+ # break
425
+ # if vae_skip_iter_t:
426
+ # print(f"timestep:{t}, skip vae:{vae_skip_iter_t}")
427
+
428
+ # noise_pred = tranformer_forward(
429
+ # self.transformer,
430
+ # model_config=model_config,
431
+ # # Inputs of the condition (new feature)
432
+ # text_cond_mask=text_cond_mask,
433
+ # delta_emb=cur_delta_emb,
434
+ # delta_emb_pblock=cur_delta_emb_pblock,
435
+ # delta_emb_mask=cur_delta_emb_mask,
436
+ # delta_start_ends=delta_start_ends,
437
+ # condition_latents=None if vae_skip_iter_t else condition_latents,
438
+ # condition_ids=None if vae_skip_iter_t else condition_ids,
439
+ # condition_type_ids=None,
440
+ # # Inputs to the original transformer
441
+ # hidden_states=latents,
442
+ # # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
443
+ # timestep=timestep,
444
+ # guidance=guidance,
445
+ # pooled_projections=pooled_prompt_embeds,
446
+ # encoder_hidden_states=prompt_embeds,
447
+ # txt_ids=text_ids,
448
+ # img_ids=latent_image_ids,
449
+ # joint_attention_kwargs={'scale': lora_weight, "latent_sblora_weight": latent_sblora_weight, "condition_sblora_weight": condition_sblora_weight},
450
+ # store_attn_map=use_attention,
451
+ # last_attn_map=attn_map if cur_control_weight_lambda else None,
452
+ # use_text_mod=model_config["modulation"]["use_text_mod"],
453
+ # use_img_mod=model_config["modulation"]["use_img_mod"],
454
+ # mod_adapter=mod_adapter,
455
+ # latent_height=latent_height,
456
+ # return_dict=False,
457
+ # )[0]
458
+
459
+ # if use_attention:
460
+ # attn_maps, _ = gather_attn_maps(self.transformer, clear=True)
461
+
462
+ # # compute the previous noisy sample x_t -> x_t-1
463
+ # latents_dtype = latents.dtype
464
+ # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
465
+
466
+ # if latents.dtype != latents_dtype:
467
+ # if torch.backends.mps.is_available():
468
+ # # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
469
+ # latents = latents.to(latents_dtype)
470
+
471
+ # if callback_on_step_end is not None:
472
+ # callback_kwargs = {}
473
+ # for k in callback_on_step_end_tensor_inputs:
474
+ # callback_kwargs[k] = locals()[k]
475
+ # callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
476
+
477
+ # latents = callback_outputs.pop("latents", latents)
478
+ # prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
479
+
480
+ # # call the callback, if provided
481
+ # if i == len(timesteps) - 1 or (
482
+ # (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
483
+ # ):
484
+ # progress_bar.update()
485
+
486
+ # if output_type == "latent":
487
+ # image = latents
488
+
489
+ # else:
490
+ # latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
491
+ # latents = (
492
+ # latents / self.vae.config.scaling_factor
493
+ # ) + self.vae.config.shift_factor
494
+ # image = self.vae.decode(latents, return_dict=False)[0]
495
+ # image = self.image_processor.postprocess(image, output_type=output_type)
496
+
497
+ # # Offload all models
498
+ # self.maybe_free_model_hooks()
499
+
500
+ # self.transformer.enable_lora()
501
+
502
+ # if vae_condition_scale != 1:
503
+ # for name, module in pipeline.transformer.named_modules():
504
+ # if not name.endswith(".attn"):
505
+ # continue
506
+ # del module.c_factor
507
+
508
+ # if not return_dict:
509
+ # return (image,)
510
+
511
+ # return FluxPipelineOutput(images=image)
512
+
513
+
514
+ # @torch.no_grad()
515
+ # def generate_from_test_sample(
516
+ # test_sample, pipe, config,
517
+ # num_images=1,
518
+ # vae_skip_iter: str = None,
519
+ # target_height: int = None,
520
+ # target_width: int = None,
521
+ # seed: int = 42,
522
+ # control_weight_lambda: str = None,
523
+ # double_attention: bool = False,
524
+ # single_attention: bool = False,
525
+ # ip_scale: str = None,
526
+ # use_latent_sblora_control: bool = False,
527
+ # latent_sblora_scale: str = None,
528
+ # use_condition_sblora_control: bool = False,
529
+ # condition_sblora_scale: str = None,
530
+ # use_idip = False,
531
+ # **kargs
532
+ # ):
533
+ # target_size = config["train"]["dataset"]["val_target_size"]
534
+ # condition_size = config["train"]["dataset"].get("val_condition_size", target_size//2)
535
+ # condition_pad_to = config["train"]["dataset"]["condition_pad_to"]
536
+ # pos_offset_type = config["model"].get("pos_offset_type", "width")
537
+ # seed = config["model"].get("seed", seed)
538
+
539
+ # device = pipe._execution_device
540
+
541
+ # condition_imgs = test_sample['input_images']
542
+ # position_delta = test_sample['position_delta']
543
+ # prompt = test_sample['prompt']
544
+ # original_image = test_sample.get('original_image', None)
545
+ # condition_type = test_sample.get('condition_type', "subject")
546
+ # modulation_input = test_sample.get('modulation', None)
547
+
548
+ # delta_start_ends = None
549
+ # condition_latents = condition_ids = None
550
+ # text_cond_mask = None
551
 
552
+ # delta_embs = None
553
+ # delta_embs_pblock = None
554
+ # delta_embs_mask = None
555
+
556
+ # try:
557
+ # max_length = config["model"]["modulation"]["max_text_len"]
558
+ # except Exception as e:
559
+ # print(e)
560
+ # max_length = 512
561
+
562
+ # if modulation_input is None or len(modulation_input) == 0:
563
+ # delta_emb = delta_emb_pblock = delta_emb_mask = None
564
+ # else:
565
+ # dtype = torch.bfloat16
566
+ # batch_size = 1
567
+ # N = config["model"]["modulation"].get("per_block_adapter_single_blocks", 0) + 19
568
+ # guidance = torch.tensor([3.5]).to(device).expand(batch_size)
569
+ # out_dim = config["model"]["modulation"]["out_dim"]
570
+
571
+ # tar_text_inputs = tokenize_t5_prompt(pipe, prompt, max_length)
572
+ # tar_padding_mask = tar_text_inputs.attention_mask.to(device).bool()
573
+ # tar_tokens = tar_text_inputs.input_ids.to(device)
574
+ # if config["model"]["modulation"]["eos_exclude"]:
575
+ # tar_padding_mask[tar_tokens == 1] = False
576
+
577
+ # def get_start_end_by_pompt_matching(src_prompts, tar_prompts):
578
+ # text_cond_mask = torch.zeros(batch_size, max_length, device=device, dtype=torch.bool)
579
+ # tar_prompt_input_ids = tokenize_t5_prompt(pipe, tar_prompts, max_length).input_ids
580
+ # src_prompt_count = 1
581
+ # start_ends = []
582
+ # for i, (src_prompt, tar_prompt, tar_prompt_tokens) in enumerate(zip(src_prompts, tar_prompts, tar_prompt_input_ids)):
583
+ # try:
584
+ # tar_start, tar_end = get_word_index(pipe, tar_prompt, tar_prompt_tokens, src_prompt, src_prompt_count, max_length, verbose=False)
585
+ # start_ends.append([tar_start, tar_end])
586
+ # text_cond_mask[i, tar_start:tar_end] = True
587
+ # except Exception as e:
588
+ # print(e)
589
+ # return start_ends, text_cond_mask
590
+
591
+ # def encode_mod_image(pil_images):
592
+ # if config["model"]["modulation"]["use_dit"]:
593
+ # raise NotImplementedError()
594
+ # else:
595
+ # pil_images = [pad_to_square(img).resize((224, 224)) for img in pil_images]
596
+ # if config["model"]["modulation"]["use_vae"]:
597
+ # raise NotImplementedError()
598
+ # else:
599
+ # clip_pixel_values = pipe.clip_processor(
600
+ # text=None, images=pil_images, do_resize=False, do_center_crop=False, return_tensors="pt",
601
+ # ).pixel_values.to(dtype=dtype, device=device)
602
+ # clip_outputs = pipe.clip_model(clip_pixel_values, output_hidden_states=True, interpolate_pos_encoding=True, return_dict=True)
603
+ # return clip_outputs
604
+
605
+ # def rgba_to_white_background(input_path, background=(255,255,255)):
606
+ # with Image.open(input_path).convert("RGBA") as img:
607
+ # img_np = np.array(img)
608
+ # alpha = img_np[:, :, 3] / 255.0 # normalize the alpha channel to [0, 1]
609
+ # rgb = img_np[:, :, :3].astype(float) # extract the RGB channels
610
 
611
+ # background_np = np.full_like(rgb, background, dtype=float) # build the background from the given color
612
 
613
+ # # alpha blend: foreground * alpha + background * (1 - alpha)
614
+ # result_np = rgb * alpha[..., np.newaxis] + \
615
+ # background_np * (1 - alpha[..., np.newaxis])
616
 
617
+ # result = Image.fromarray(result_np.astype(np.uint8), "RGB")
618
+ # return result
619
+ # def get_mod_emb(modulation_input, timestep):
+ # delta_emb = torch.zeros((batch_size, max_length, out_dim), dtype=dtype, device=device)
+ # delta_emb_pblock = torch.zeros((batch_size, max_length, N, out_dim), dtype=dtype, device=device)
+ # delta_emb_mask = torch.zeros((batch_size, max_length), dtype=torch.bool, device=device)
+ # delta_start_ends = None
+ # condition_latents = condition_ids = None
+ # text_cond_mask = None
+
+ # if modulation_input[0]["type"] == "adapter":
+ # num_inputs = len(modulation_input[0]["src_inputs"])
+ # src_prompts = [x["caption"] for x in modulation_input[0]["src_inputs"]]
+ # src_text_inputs = tokenize_t5_prompt(pipe, src_prompts, max_length)
+ # src_input_ids = unpad_input_ids(src_text_inputs.input_ids, src_text_inputs.attention_mask)
+ # tar_input_ids = unpad_input_ids(tar_text_inputs.input_ids, tar_text_inputs.attention_mask)
+ # src_prompt_embeds = pipe._get_t5_prompt_embeds(prompt=src_prompts, max_sequence_length=max_length, device=device) # (M, 512, 4096)

+ # pil_images = [rgba_to_white_background(x["image_path"]) for x in modulation_input[0]["src_inputs"]]
+
+ # src_ds_scales = [x.get("downsample_scale", 1.0) for x in modulation_input[0]["src_inputs"]]
+ # resized_pil_images = []
+ # for img, ds_scale in zip(pil_images, src_ds_scales):
+ # img = pad_to_square(img)
+ # if ds_scale < 1.0:
+ # assert ds_scale > 0
+ # img = img.resize((int(224 * ds_scale), int(224 * ds_scale))).resize((224, 224))
+ # resized_pil_images.append(img)
+ # pil_images = resized_pil_images

+ # img_encoded = encode_mod_image(pil_images)
+ # delta_start_ends = []
+ # text_cond_mask = torch.zeros(num_inputs, max_length, device=device, dtype=torch.bool)
+ # if config["model"]["modulation"]["pass_vae"]:
+ # pil_images = [pad_to_square(img).resize((condition_size, condition_size)) for img in pil_images]
+ # with torch.no_grad():
+ # batch_tensor = torch.stack([pil2tensor(x) for x in pil_images])
+ # x_0, img_ids = encode_vae_images(pipe, batch_tensor) # (N, 256, 64)
+
+ # condition_latents = x_0.clone().detach().reshape(1, -1, 64) # (1, N256, 64)
+ # condition_ids = img_ids.clone().detach()
+ # condition_ids = condition_ids.unsqueeze(0).repeat_interleave(num_inputs, dim=0) # (N, 256, 3)
+ # for i in range(num_inputs):
+ # condition_ids[i, :, 1] += 0 if pos_offset_type == "width" else -(batch_tensor.shape[-1]//16) * (i + 1)
+ # condition_ids[i, :, 2] += -(batch_tensor.shape[-1]//16) * (i + 1)
+ # condition_ids = condition_ids.reshape(-1, 3) # (N256, 3)
+
+ # if config["model"]["modulation"]["use_dit"]:
+ # raise NotImplementedError()
+ # else:
+ # src_delta_embs = [] # [(512, 3072)]
+ # src_delta_emb_pblock = []
+ # for i in range(num_inputs):
+ # if isinstance(img_encoded, dict):
+ # _src_clip_outputs = {}
+ # for key in img_encoded:
+ # if torch.is_tensor(img_encoded[key]):
+ # _src_clip_outputs[key] = img_encoded[key][i:i+1]
+ # else:
+ # _src_clip_outputs[key] = [x[i:i+1] for x in img_encoded[key]]
+ # _img_encoded = _src_clip_outputs
+ # else:
+ # _img_encoded = img_encoded[i:i+1]

+ # x1, x2 = pipe.modulation_adapters[0](timestep, src_prompt_embeds[i:i+1], _img_encoded)
+ # src_delta_embs.append(x1[0]) # (512, 3072)
+ # src_delta_emb_pblock.append(x2[0]) # (512, N, 3072)
+
+ # for input_args in modulation_input[0]["use_words"]:
+ # src_word_count = 1
+ # if len(input_args) == 3:
+ # src_input_index, src_word, tar_word = input_args
+ # tar_word_count = 1
+ # else:
+ # src_input_index, src_word, tar_word, tar_word_count = input_args[:4]
+ # src_prompt = src_prompts[src_input_index]
+ # tar_prompt = prompt
+
+ # src_start, src_end = get_word_index(pipe, src_prompt, src_input_ids[src_input_index], src_word, src_word_count, max_length, verbose=False)
+ # tar_start, tar_end = get_word_index(pipe, tar_prompt, tar_input_ids[0], tar_word, tar_word_count, max_length, verbose=False)
+ # if delta_emb is not None:
+ # delta_emb[:, tar_start:tar_end] = src_delta_embs[src_input_index][src_start:src_end] # (B, 512, 3072)
+ # if delta_emb_pblock is not None:
+ # delta_emb_pblock[:, tar_start:tar_end] = src_delta_emb_pblock[src_input_index][src_start:src_end] # (B, 512, N, 3072)
+ # delta_emb_mask[:, tar_start:tar_end] = True
+ # text_cond_mask[src_input_index, tar_start:tar_end] = True
+ # delta_start_ends.append([0, src_input_index, src_start, src_end, tar_start, tar_end])
+ # text_cond_mask = text_cond_mask.transpose(0, 1).unsqueeze(0)
+
+ # else:
+ # raise NotImplementedError()
+ # return delta_emb, delta_emb_pblock, delta_emb_mask, \
+ # text_cond_mask, delta_start_ends, condition_latents, condition_ids

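+ # # Timestep setup: build a linear sigma schedule, derive the dynamic shift mu from the scheduler config via calculate_shift, then resolve the final timesteps with retrieve_timesteps.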
+ # num_inference_steps = 28 # FIXME: hardcoded here
+ # num_channels_latents = pipe.transformer.config.in_channels // 4
+
+ # # set timesteps
+ # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ # mu = calculate_shift(
+ # num_channels_latents,
+ # pipe.scheduler.config.base_image_seq_len,
+ # pipe.scheduler.config.max_image_seq_len,
+ # pipe.scheduler.config.base_shift,
+ # pipe.scheduler.config.max_shift,
+ # )
+ # timesteps, num_inference_steps = retrieve_timesteps(
+ # pipe.scheduler,
+ # num_inference_steps,
+ # device,
+ # None,
+ # sigmas,
+ # mu=mu,
+ # )
+
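+ # # Precompute the modulation deltas for every denoising timestep so the generate() call below can use them step by step.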
+ # if modulation_input is not None:
+ # delta_embs = []
+ # delta_embs_pblock = []
+ # delta_embs_mask = []
+ # for i, t in enumerate(timesteps):
+ # t = t.expand(1).to(torch.bfloat16) / 1000
+ # (
+ # delta_emb, delta_emb_pblock, delta_emb_mask,
+ # text_cond_mask, delta_start_ends,
+ # condition_latents, condition_ids
+ # ) = get_mod_emb(modulation_input, t)
+ # delta_embs.append(delta_emb)
+ # delta_embs_pblock.append(delta_emb_pblock)
+ # delta_embs_mask.append(delta_emb_mask)
+
+ # if original_image is not None:
+ # raise NotImplementedError()
+ # (target_height, target_width), closest_ratio = get_closest_ratio(original_image.height, original_image.width, train_aspect_ratios)
+ # elif modulation_input is None or len(modulation_input) == 0:
+ # delta_emb = delta_emb_pblock = delta_emb_mask = None
+ # else:
+ # for i, t in enumerate(timesteps):
+ # t = t.expand(1).to(torch.bfloat16) / 1000
+ # (
+ # delta_emb, delta_emb_pblock, delta_emb_mask,
+ # text_cond_mask, delta_start_ends,
+ # condition_latents, condition_ids
+ # ) = get_mod_emb(modulation_input, t)
+ # delta_embs.append(delta_emb)
+ # delta_embs_pblock.append(delta_emb_pblock)
+ # delta_embs_mask.append(delta_emb_mask)
+
+ # if target_height is None or target_width is None:
+ # target_height = target_width = target_size
+
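+ # # Prepare VAE condition images: pad (to square or to the target size), resize to condition_size, and wrap each one as a Condition.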
+ # if condition_pad_to == "square":
+ # condition_imgs = [pad_to_square(x) for x in condition_imgs]
+ # elif condition_pad_to == "target":
+ # condition_imgs = [pad_to_target(x, (target_size, target_size)) for x in condition_imgs]
+ # condition_imgs = [x.resize((condition_size, condition_size)).convert("RGB") for x in condition_imgs]
+ # # TODO: fix position_delta
+ # conditions = [
+ # Condition(
+ # condition_type=condition_type,
+ # condition=x,
+ # position_delta=position_delta,
+ # ) for x in condition_imgs
+ # ]
+ # # vlm_images = condition_imgs if config["model"]["use_vlm"] else []
+
+ # use_perblock_adapter = False
+ # try:
+ # if config["model"]["modulation"]["use_perblock_adapter"]:
+ # use_perblock_adapter = True
+ # except Exception as e:
+ # pass
+
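+ # # Sampling loop: seed a fresh generator per image (seed + i), derive the idips (identity) flags from the source inputs or use_words, and call generate() once per output image.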
+ # results = []
+ # for i in range(num_images):
+ # clear_attn_maps(pipe.transformer)
+ # generator = torch.Generator(device=device)
+ # generator.manual_seed(seed + i)
+ # if modulation_input is None or len(modulation_input) == 0:
+ # idips = None
+ # else:
+ # idips = ["human" in p["image_path"] for p in modulation_input[0]["src_inputs"]]
+ # if len(modulation_input[0]["use_words"][0])==5:
+ # print("use idips in use_words")
+ # idips = [x[-1] for x in modulation_input[0]["use_words"]]
+ # result_img = generate(
+ # pipe,
+ # prompt=prompt,
+ # max_sequence_length=max_length,
+ # vae_conditions=conditions,
+ # generator=generator,
+ # model_config=config["model"],
+ # height=target_height,
+ # width=target_width,
+ # condition_pad_to=condition_pad_to,
+ # condition_size=condition_size,
+ # text_cond_mask=text_cond_mask,
+ # delta_emb=delta_embs,
+ # delta_emb_pblock=delta_embs_pblock if use_perblock_adapter else None,
+ # delta_emb_mask=delta_embs_mask,
+ # delta_start_ends=delta_start_ends,
+ # condition_latents=condition_latents,
+ # condition_ids=condition_ids,
+ # mod_adapter=pipe.modulation_adapters[0] if config["model"]["modulation"]["use_dit"] else None,
+ # vae_skip_iter=vae_skip_iter,
+ # control_weight_lambda=control_weight_lambda,
+ # double_attention=double_attention,
+ # single_attention=single_attention,
+ # ip_scale=ip_scale,
+ # use_latent_sblora_control=use_latent_sblora_control,
+ # latent_sblora_scale=latent_sblora_scale,
+ # use_condition_sblora_control=use_condition_sblora_control,
+ # condition_sblora_scale=condition_sblora_scale,
+ # idips=idips if use_idip else None,
+ # **kargs,
+ # ).images[0]
+
+ # final_image = result_img
+ # results.append(final_image)
+
+ # if num_images == 1:
+ # return results[0]
+ # return results