fengyutong committed on
Commit e4df51f · 1 Parent(s): bf2440b

first commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__/
app.py ADDED
@@ -0,0 +1,267 @@
+import gradio as gr
+import spaces
+import torch
+import diffusers
+import transformers
+import copy
+import random
+import numpy as np
+import torchvision.transforms as T
+import math
+import os
+import peft
+from peft import LoraConfig
+from safetensors import safe_open
+from omegaconf import OmegaConf
+from omnitry.models.transformer_flux import FluxTransformer2DModel
+from omnitry.pipelines.pipeline_flux_fill import FluxFillPipeline
+
+
+from huggingface_hub import snapshot_download
+snapshot_download(repo_id="black-forest-labs/FLUX.1-Fill-dev", local_dir="./FLUX.1-Fill-dev")
+snapshot_download(repo_id="Kunbyte/OmniTry", local_dir="./OmniTry")
+
+device = torch.device('cuda:0')
+weight_dtype = torch.bfloat16
+args = OmegaConf.load('configs/omnitry_v1_unified.yaml')
+
+# init model
+transformer = FluxTransformer2DModel.from_pretrained('./FLUX.1-Fill-dev/transformer').requires_grad_(False).to(device, dtype=weight_dtype)
+vae = diffusers.AutoencoderKL.from_pretrained('./FLUX.1-Fill-dev/vae').requires_grad_(False).to(device, dtype=weight_dtype)
+text_encoder = transformers.CLIPTextModel.from_pretrained('./FLUX.1-Fill-dev/text_encoder').requires_grad_(False).to(device, dtype=weight_dtype)
+text_encoder_2 = transformers.T5EncoderModel.from_pretrained('./FLUX.1-Fill-dev/text_encoder_2').requires_grad_(False).to(device, dtype=weight_dtype)
+scheduler = diffusers.FlowMatchEulerDiscreteScheduler.from_pretrained('./FLUX.1-Fill-dev/scheduler')
+tokenizer = transformers.CLIPTokenizer.from_pretrained('./FLUX.1-Fill-dev/tokenizer')
+tokenizer_2 = transformers.T5TokenizerFast.from_pretrained('./FLUX.1-Fill-dev/tokenizer_2')
+
+# insert LoRA
+lora_config = LoraConfig(
+    r=args.lora_rank,
+    lora_alpha=args.lora_alpha,
+    init_lora_weights="gaussian",
+    target_modules=[
+        'x_embedder',
+        'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
+        'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
+        'ff.net.0.proj', 'ff.net.2', 'ff_context.net.0.proj', 'ff_context.net.2',
+        'norm1_context.linear', 'norm1.linear', 'norm.linear', 'proj_mlp', 'proj_out'
+    ]
+)
+transformer.add_adapter(lora_config, adapter_name='vtryon_lora')
+transformer.add_adapter(lora_config, adapter_name='garment_lora')
+
+with safe_open('OmniTry/omnitry_v1_unified_stage2.safetensors', framework="pt") as f:
+    lora_weights = {k: f.get_tensor(k) for k in f.keys()}
+    transformer.load_state_dict(lora_weights, strict=False)
+
+# hack lora forward
+def create_hacked_forward(module):
+
+    def lora_forward(self, active_adapter, x, *args, **kwargs):
+        result = self.base_layer(x, *args, **kwargs)
+        if active_adapter is not None:
+            torch_result_dtype = result.dtype
+            lora_A = self.lora_A[active_adapter]
+            lora_B = self.lora_B[active_adapter]
+            dropout = self.lora_dropout[active_adapter]
+            scaling = self.scaling[active_adapter]
+            x = x.to(lora_A.weight.dtype)
+            result = result + lora_B(lora_A(dropout(x))) * scaling
+        return result
+
+    def hacked_lora_forward(self, x, *args, **kwargs):
+        return torch.cat((
+            lora_forward(self, 'vtryon_lora', x[:1], *args, **kwargs),
+            lora_forward(self, 'garment_lora', x[1:], *args, **kwargs),
+        ), dim=0)
+
+    return hacked_lora_forward.__get__(module, type(module))
+
+for n, m in transformer.named_modules():
+    if isinstance(m, peft.tuners.lora.layer.Linear):
+        m.forward = create_hacked_forward(m)
+
+# init pipeline
+pipeline = FluxFillPipeline(
+    transformer=transformer.eval(),
+    scheduler=copy.deepcopy(scheduler),
+    vae=vae,
+    text_encoder=text_encoder,
+    text_encoder_2=text_encoder_2,
+    tokenizer=tokenizer,
+    tokenizer_2=tokenizer_2,
+)
+
+
+def seed_everything(seed=0):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+
+@spaces.GPU
+def generate(person_image, object_image, object_class, steps, guidance_scale, seed):
+    # set seed
+    if seed == -1:
+        seed = random.randint(0, 2**32 - 1)
+    seed_everything(seed)
+
+    # resize model
+    max_area = 1024 * 1024
+    oW = person_image.width
+    oH = person_image.height
+
+    ratio = math.sqrt(max_area / (oW * oH))
+    ratio = min(1, ratio)
+    tW, tH = int(oW * ratio) // 16 * 16, int(oH * ratio) // 16 * 16
+    transform = T.Compose([
+        T.Resize((tH, tW)),
+        T.ToTensor(),
+    ])
+    person_image = transform(person_image)
+
+    # resize and padding garment
+    ratio = min(tW / object_image.width, tH / object_image.height)
+    transform = T.Compose([
+        T.Resize((int(object_image.height * ratio), int(object_image.width * ratio))),
+        T.ToTensor(),
+    ])
+    object_image_padded = torch.ones_like(person_image)
+    object_image = transform(object_image)
+    new_h, new_w = object_image.shape[1], object_image.shape[2]
+    min_x = (tW - new_w) // 2
+    min_y = (tH - new_h) // 2
+    object_image_padded[:, min_y: min_y + new_h, min_x: min_x + new_w] = object_image
+
+    # prepare prompts & conditions
+    prompts = [args.object_map[object_class]] * 2
+    img_cond = torch.stack([person_image, object_image_padded]).to(dtype=weight_dtype, device=device)
+    mask = torch.zeros_like(img_cond).to(img_cond)
+
+    with torch.no_grad():
+        img = pipeline(
+            prompt=prompts,
+            height=tH,
+            width=tW,
+            img_cond=img_cond,
+            mask=mask,
+            guidance_scale=guidance_scale,
+            num_inference_steps=steps,
+            generator=torch.Generator(device).manual_seed(seed),
+        ).images[0]
+
+    return img
+
+
+if __name__ == '__main__':
+
+    with gr.Blocks() as demo:
+        gr.Markdown('# Demo of OmniTry')
+        with gr.Row():
+            with gr.Column():
+                person_image = gr.Image(type="pil", label="Person Image", height=800)
+                run_button = gr.Button(value="Submit", variant='primary')
+
+            with gr.Column():
+                object_image = gr.Image(type="pil", label="Object Image", height=800)
+                object_class = gr.Dropdown(label='Object Class', choices=args.object_map.keys())
+
+            with gr.Column():
+                image_out = gr.Image(type="pil", label="Output", height=800)
+
+        with gr.Accordion("Advanced ⚙️", open=False):
+            guidance_scale = gr.Slider(label="Guidance scale", minimum=1, maximum=50, value=30, step=0.1)
+            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
+            seed = gr.Number(label="Seed", value=-1, precision=0)
+
+        with gr.Row():
+            gr.Examples(
+                examples=[
+                    [
+                        './demo_example/person_top_cloth.jpg',
+                        './demo_example/object_top_cloth.jpg',
+                        'top clothes',
+                    ],
+                    [
+                        './demo_example/person_bottom_cloth.jpg',
+                        './demo_example/object_bottom_cloth.jpg',
+                        'bottom clothes',
+                    ],
+                    [
+                        './demo_example/person_dress.jpg',
+                        './demo_example/object_dress.jpg',
+                        'dress',
+                    ],
+                    [
+                        './demo_example/person_shoes.jpg',
+                        './demo_example/object_shoes.jpg',
+                        'shoe',
+                    ],
+                    [
+                        './demo_example/person_earrings.jpg',
+                        './demo_example/object_earrings.jpg',
+                        'earrings',
+                    ],
+                    [
+                        './demo_example/person_bracelet.jpg',
+                        './demo_example/object_bracelet.jpg',
+                        'bracelet',
+                    ],
+                    [
+                        './demo_example/person_necklace.jpg',
+                        './demo_example/object_necklace.jpg',
+                        'necklace',
+                    ],
+                    [
+                        './demo_example/person_ring.jpg',
+                        './demo_example/object_ring.jpg',
+                        'ring',
+                    ],
+                    [
+                        './demo_example/person_sunglasses.jpg',
+                        './demo_example/object_sunglasses.jpg',
+                        'sunglasses',
+                    ],
+                    [
+                        './demo_example/person_glasses.jpg',
+                        './demo_example/object_glasses.jpg',
+                        'glasses',
+                    ],
+                    [
+                        './demo_example/person_belt.jpg',
+                        './demo_example/object_belt.jpg',
+                        'belt',
+                    ],
+                    [
+                        './demo_example/person_bag.jpg',
+                        './demo_example/object_bag.jpg',
+                        'bag',
+                    ],
+                    [
+                        './demo_example/person_hat.jpg',
+                        './demo_example/object_hat.jpg',
+                        'hat',
+                    ],
+                    [
+                        './demo_example/person_tie.jpg',
+                        './demo_example/object_tie.jpg',
+                        'tie',
+                    ],
+                    [
+                        './demo_example/person_bowtie.jpg',
+                        './demo_example/object_bowtie.jpg',
+                        'bow tie',
+                    ],
+                ],
+
+                inputs=[person_image, object_image, object_class],
+                examples_per_page=100
+            )
+
+        run_button.click(generate, inputs=[person_image, object_image, object_class, steps, guidance_scale, seed], outputs=[image_out])
+
+    demo.launch(server_name="0.0.0.0")
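
The app stacks the person image and the padded garment image into one two-sample batch, routes sample 0 through `vtryon_lora` and sample 1 through `garment_lora` via the hacked LoRA forward, and returns only the first decoded image. A minimal sketch of calling `generate` without the Gradio UI, assuming the module-level setup above has already run in the same process; the file paths and output name below are illustrative only:

```py
# Sketch only (not part of the commit): reuses the pipeline built in app.py above.
from PIL import Image

person = Image.open('./demo_example/person_top_cloth.jpg').convert('RGB')
garment = Image.open('./demo_example/object_top_cloth.jpg').convert('RGB')

# object_class must be a key of args.object_map (see configs/omnitry_v1_unified.yaml)
result = generate(person, garment, 'top clothes', steps=20, guidance_scale=30, seed=-1)
result.save('tryon_result.jpg')   # illustrative output path
```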
configs/omnitry_v1_unified.yaml ADDED
@@ -0,0 +1,24 @@
+model_root: checkpoints/FLUX.1-Fill-dev
+lora_path: checkpoints/omnitry_v1_unified_stage2.safetensors
+lora_rank: 16
+lora_alpha: 16
+
+object_map: {
+  'top clothes': 'replacing the top cloth',
+  'bottom clothes': 'replacing the bottom cloth',
+  'dress': 'replacing the dress',
+  'shoe': 'replacing the shoe',
+
+  'earrings': 'trying on earrings',
+  'bracelet': 'trying on bracelet',
+  'necklace': 'trying on necklace',
+  'ring': 'trying on ring',
+
+  'sunglasses': 'trying on sunglasses',
+  'glasses': 'trying on glasses',
+  'belt': 'trying on belt',
+  'bag': 'trying on bag',
+  'hat': 'trying on hat',
+  'tie': 'trying on tie',
+  'bow tie': 'trying on bow tie',
+}
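
app.py reads this file with OmegaConf: `lora_rank`/`lora_alpha` size the two LoRA adapters, and `object_map` maps each dropdown class to the text prompt passed to the pipeline. A small sketch of that lookup (printed values mirror the entries above):

```py
# Sketch of how app.py consumes this config; OmegaConf is already a dependency above.
from omegaconf import OmegaConf

args = OmegaConf.load('configs/omnitry_v1_unified.yaml')
print(args.lora_rank, args.lora_alpha)     # 16 16
print(args.object_map['top clothes'])      # replacing the top cloth
print(list(args.object_map.keys()))        # dropdown choices in the demo UI
```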
demo_example/object_bag.jpg ADDED

Git LFS Details

  • SHA256: 52e34d4aae41fca9c41896f46a702787dca83f465697ba37b31eca012080d492
  • Pointer size: 130 Bytes
  • Size of remote file: 38.3 kB
demo_example/object_belt.jpg ADDED

Git LFS Details

  • SHA256: aa5eb7a72314f22b1f24e266dfc86dc7a338d2a8a169d0449bca995e39d16ffc
  • Pointer size: 130 Bytes
  • Size of remote file: 60.9 kB
demo_example/object_bottom_cloth.jpg ADDED

Git LFS Details

  • SHA256: c708d3dadf796f60a18128f4e376f9f960a76cdd0a1ce90c50b8c9fed9af3f0d
  • Pointer size: 130 Bytes
  • Size of remote file: 71.8 kB
demo_example/object_bowtie.jpg ADDED

Git LFS Details

  • SHA256: fbb8d0cf2b6fbd16a8a0bec9b9a5613e1a881f0e2135a265d4029e5829dfeff4
  • Pointer size: 131 Bytes
  • Size of remote file: 209 kB
demo_example/object_bracelet.jpg ADDED

Git LFS Details

  • SHA256: f326c938cf43e3b774b7043546e77e55c636ed1489185bf4eb0ea73c505f52ab
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
demo_example/object_dress.jpg ADDED

Git LFS Details

  • SHA256: 8f4a487187ba5733312ea39e7c3f304b468b23a130dd9bfd43280996b78adc79
  • Pointer size: 131 Bytes
  • Size of remote file: 108 kB
demo_example/object_earrings.jpg ADDED

Git LFS Details

  • SHA256: 477009532f59fcda35cb4070d2f54404d6b4debaaf3de792299fbe58c88223cb
  • Pointer size: 130 Bytes
  • Size of remote file: 66.7 kB
demo_example/object_glasses.jpg ADDED

Git LFS Details

  • SHA256: 74390b18c222c3cd8d0159206adfa02ad800a2a6e73719e74bea16944910d23e
  • Pointer size: 130 Bytes
  • Size of remote file: 44.6 kB
demo_example/object_hat.jpg ADDED

Git LFS Details

  • SHA256: 9af0d9df985ddaa526c0cff33a132ae650ced96256e0a08ea4d9b05c76a733b3
  • Pointer size: 130 Bytes
  • Size of remote file: 79.9 kB
demo_example/object_necklace.jpg ADDED

Git LFS Details

  • SHA256: 84d1c33da8ccfa4a18cdf4ef4b7539cff4077ac9854097620a6dd633f4ea2090
  • Pointer size: 130 Bytes
  • Size of remote file: 47.4 kB
demo_example/object_ring.jpg ADDED

Git LFS Details

  • SHA256: 9f866c5a1e93336b2f4af8dc20166a4f7885581f3b619604dab37a716387ca13
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
demo_example/object_shoes.jpg ADDED

Git LFS Details

  • SHA256: 60b95240eba37ead2a0d509ea55feb087e6b7a54c1a25365117b04a218ccb8e3
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB
demo_example/object_sunglasses.jpg ADDED

Git LFS Details

  • SHA256: f155d22d913a0a16ad07ce3324c105bc03660f8cb85d50ff72922e742c1a42c4
  • Pointer size: 130 Bytes
  • Size of remote file: 32.8 kB
demo_example/object_tie.jpg ADDED

Git LFS Details

  • SHA256: 455dcc245f790e69fb90848343ef16be3939c7bc45014ba5fdca7a6c797d25b8
  • Pointer size: 130 Bytes
  • Size of remote file: 29.7 kB
demo_example/object_top_cloth.jpg ADDED

Git LFS Details

  • SHA256: c2e1ecb458c8c45099f7aa629cabaff9dd9d427c51c57cffeaa92e1c0dfaf6da
  • Pointer size: 130 Bytes
  • Size of remote file: 37.5 kB
demo_example/person_bag.jpg ADDED

Git LFS Details

  • SHA256: cef2a0d8289d0fd7bc8553bd0aa1e27b322d7994aa71a5cc54d9fa986b787e46
  • Pointer size: 131 Bytes
  • Size of remote file: 136 kB
demo_example/person_belt.jpg ADDED

Git LFS Details

  • SHA256: b2353138cdf23b447124b187ba8c5ba635a84aa3cebdc37e5230f7c308037a6d
  • Pointer size: 130 Bytes
  • Size of remote file: 60 kB
demo_example/person_bottom_cloth.jpg ADDED

Git LFS Details

  • SHA256: d01eeb7c6275ddedd84bc350244d761ce5514d04bbbc194279f6711a01a313d3
  • Pointer size: 131 Bytes
  • Size of remote file: 134 kB
demo_example/person_bowtie.jpg ADDED

Git LFS Details

  • SHA256: 7df336d11b2e242488767ed961b3f75703d9a42384625b4977660e8bb58fea6a
  • Pointer size: 131 Bytes
  • Size of remote file: 125 kB
demo_example/person_bracelet.jpg ADDED

Git LFS Details

  • SHA256: 6cffdb09cb92f53341d5b8756da1c6eab394e65fd9aeac502661f0c18e8caa0d
  • Pointer size: 130 Bytes
  • Size of remote file: 91.6 kB
demo_example/person_dress.jpg ADDED

Git LFS Details

  • SHA256: 270aa0322b21fa639cd1b46a365f4b723b5f6af0f400c1ccb768bf5294919da9
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB
demo_example/person_earrings.jpg ADDED

Git LFS Details

  • SHA256: dde0486bd9c65dc53edf05dbc699e28e184dd24fa61f92bdac66fb5da0bb6df9
  • Pointer size: 130 Bytes
  • Size of remote file: 89.1 kB
demo_example/person_glasses.jpg ADDED

Git LFS Details

  • SHA256: 77e39f2e57ec236551a64d5fa78b3c6a8e8d542d4f67cbaf417314b8714af31f
  • Pointer size: 130 Bytes
  • Size of remote file: 90.4 kB
demo_example/person_hat.jpg ADDED

Git LFS Details

  • SHA256: c91db3993b94aea35688e5f86c3c849ec45204a534ffed3a66b2b55b8c23b1f5
  • Pointer size: 131 Bytes
  • Size of remote file: 113 kB
demo_example/person_necklace.jpg ADDED

Git LFS Details

  • SHA256: a23eb618152d30caeb5b5ab999c4ae4e4e3bc55775ced40c293c444b7517382f
  • Pointer size: 130 Bytes
  • Size of remote file: 98.5 kB
demo_example/person_ring.jpg ADDED

Git LFS Details

  • SHA256: 4273bd8b7d30d08c8159e46dbf3d6aa9d42ec9698c854a88938fc44b4ac51fc4
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB
demo_example/person_shoes.jpg ADDED

Git LFS Details

  • SHA256: ab9a4c0df363cc4bd6f7eb4c0ddb0c4823fd6a8fbb41076b9540224e36432a27
  • Pointer size: 130 Bytes
  • Size of remote file: 90.1 kB
demo_example/person_sunglasses.jpg ADDED

Git LFS Details

  • SHA256: 5266d92866b2f164c37fea6f5546758a8845dfa1db1e689cbb47610e199cdb54
  • Pointer size: 130 Bytes
  • Size of remote file: 67.4 kB
demo_example/person_tie.jpg ADDED

Git LFS Details

  • SHA256: 84222f3462c2b75bf4a599ccf1bee095f0cb4309d9079cef8dd23ebaa0f9a267
  • Pointer size: 130 Bytes
  • Size of remote file: 28.7 kB
demo_example/person_top_cloth.jpg ADDED

Git LFS Details

  • SHA256: ff42b165e664aa6ebbfdc765ecea475735495b26199443bbad4d8ea30cf8b2cc
  • Pointer size: 130 Bytes
  • Size of remote file: 90.5 kB
omnitry/models/attn_processors.py ADDED
@@ -0,0 +1,191 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.nn.utils.rnn import pad_sequence
4
+
5
+ try:
6
+ from flash_attn import flash_attn_varlen_func
7
+ FLASH_ATTN_AVALIABLE = True
8
+ except:
9
+ FLASH_ATTN_AVALIABLE = False
10
+
11
+
12
+ def apply_rotary_emb(
13
+ x: torch.Tensor,
14
+ freqs_cis,
15
+ use_real = True,
16
+ use_real_unbind_dim = -1,
17
+ ):
18
+ """
19
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
20
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
21
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
22
+ tensors contain rotary embeddings and are returned as real tensors.
23
+
24
+ Args:
25
+ x (`torch.Tensor`):
26
+ Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
27
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([B, S, D], [B, S, D],)
28
+
29
+ Returns:
30
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
31
+ """
32
+ if use_real:
33
+ B, H, S, D = x.size()
34
+ cos, sin = freqs_cis[..., 0], freqs_cis[..., 1]
35
+ cos = cos.unsqueeze(1)
36
+ sin = sin.unsqueeze(1)
37
+ cos, sin = cos.to(x.device), sin.to(x.device)
38
+
39
+ if use_real_unbind_dim == -1:
40
+ # Used for flux, cogvideox, hunyuan-dit
41
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
42
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
43
+ elif use_real_unbind_dim == -2:
44
+ # Used for Stable Audio
45
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
46
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
47
+ else:
48
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
49
+
50
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
51
+
52
+ return out
53
+ else:
54
+ # used for lumina
55
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
56
+ freqs_cis = freqs_cis.unsqueeze(2)
57
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
58
+
59
+ return x_out.type_as(x)
60
+
61
+
62
+ class FluxAttnProcessor2_0:
63
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
64
+
65
+ def __init__(self):
66
+ if not hasattr(F, "scaled_dot_product_attention"):
67
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
68
+
69
+ def __call__(
70
+ self,
71
+ attn,
72
+ hidden_states,
73
+ encoder_hidden_states=None,
74
+ attention_mask=None,
75
+ image_rotary_emb=None,
76
+ lens=None,
77
+ ) -> torch.FloatTensor:
78
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
79
+
80
+ # `sample` projections.
81
+ query = attn.to_q(hidden_states)
82
+ key = attn.to_k(hidden_states)
83
+ value = attn.to_v(hidden_states)
84
+
85
+ inner_dim = key.shape[-1]
86
+ head_dim = inner_dim // attn.heads
87
+
88
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
89
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
90
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
91
+
92
+ if attn.norm_q is not None:
93
+ query = attn.norm_q(query)
94
+ if attn.norm_k is not None:
95
+ key = attn.norm_k(key)
96
+
97
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
98
+ if encoder_hidden_states is not None:
99
+ # `context` projections.
100
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
101
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
102
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
103
+
104
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
105
+ batch_size, -1, attn.heads, head_dim
106
+ ).transpose(1, 2)
107
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
108
+ batch_size, -1, attn.heads, head_dim
109
+ ).transpose(1, 2)
110
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
111
+ batch_size, -1, attn.heads, head_dim
112
+ ).transpose(1, 2)
113
+
114
+ if attn.norm_added_q is not None:
115
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
116
+ if attn.norm_added_k is not None:
117
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
118
+
119
+ # attention
120
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
121
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
122
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
123
+
124
+ if image_rotary_emb is not None:
125
+ query = apply_rotary_emb(query, image_rotary_emb)
126
+ key = apply_rotary_emb(key, image_rotary_emb)
127
+
128
+ # supporting sequence length
129
+ q_lens = lens.clone() if lens is not None else torch.LongTensor([query.shape[2]] * batch_size).to(query.device)
130
+ k_lens = lens.clone() if lens is not None else torch.LongTensor([key.shape[2]] * batch_size).to(key.device)
131
+
132
+ # hacked: shared attention
133
+ txt_len = 512
134
+ context_key = [
135
+ torch.cat([key[0], key[1, :, txt_len:]], dim=1).permute(1, 0, 2),
136
+ key[1].permute(1, 0, 2)
137
+ ]
138
+ context_value = [
139
+ torch.cat([value[0], value[1, :, txt_len:]], dim=1).permute(1, 0, 2),
140
+ value[1].permute(1, 0, 2)
141
+ ]
142
+ k_lens = torch.LongTensor([k.size(0) for k in context_key]).to(query.device)
143
+ key = pad_sequence(context_key, batch_first=True).permute(0, 2, 1, 3)
144
+ value = pad_sequence(context_value, batch_first=True).permute(0, 2, 1, 3)
145
+
146
+ # core attention
147
+ if FLASH_ATTN_AVALIABLE:
148
+ query = query.permute(0, 2, 1, 3) # batch, sequence, num_head, head_dim
149
+ key = key.permute(0, 2, 1, 3)
150
+ value = value.permute(0, 2, 1, 3)
151
+
152
+ query = torch.cat([u[:l] for u, l in zip(query, q_lens)], dim=0)
153
+ key = torch.cat([u[:l] for u, l in zip(key, k_lens)], dim=0)
154
+ value = torch.cat([u[:l] for u, l in zip(value, k_lens)], dim=0)
155
+ cu_seqlens_q = F.pad(q_lens.cumsum(dim=0), (1, 0)).to(torch.int32)
156
+ cu_seqlens_k = F.pad(k_lens.cumsum(dim=0), (1, 0)).to(torch.int32)
157
+ max_seqlen_q = torch.max(q_lens).item()
158
+ max_seqlen_k = torch.max(k_lens).item()
159
+
160
+ hidden_states = flash_attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
161
+ hidden_states = pad_sequence([
162
+ hidden_states[start: end]
163
+ for start, end in zip(cu_seqlens_q[:-1], cu_seqlens_q[1:])
164
+ ], batch_first=True)
165
+ hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
166
+
167
+ else:
168
+ attn_mask = torch.zeros((query.size(0), 1, query.size(2), key.size(2)), dtype=torch.bool).to(query)
169
+ for i, (q_len, k_len) in enumerate(zip(q_lens, k_lens)):
170
+ attn_mask[i, :, :q_len, :k_len] = True
171
+
172
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
173
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
174
+
175
+ hidden_states = hidden_states.to(query.dtype)
176
+
177
+ if encoder_hidden_states is not None:
178
+ encoder_hidden_states, hidden_states = (
179
+ hidden_states[:, : encoder_hidden_states.shape[1]],
180
+ hidden_states[:, encoder_hidden_states.shape[1] :],
181
+ )
182
+
183
+ # linear proj
184
+ hidden_states = attn.to_out[0](hidden_states)
185
+ # dropout
186
+ hidden_states = attn.to_out[1](hidden_states)
187
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
188
+
189
+ return hidden_states, encoder_hidden_states
190
+ else:
191
+ return hidden_states
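
`apply_rotary_emb` above expects queries/keys shaped `[B, H, S, D]` and a `freqs_cis` tensor whose last axis stacks cosine and sine, which matches what `FluxPosEmbed.forward` in `transformer_flux.py` returns. A toy shape check under that assumption, with random tensors for illustration only:

```py
# Toy shape check, not part of the commit; head dim must be even.
import torch
from omnitry.models.attn_processors import apply_rotary_emb

B, H, S, D = 2, 4, 8, 16
x = torch.randn(B, H, S, D)            # query or key
freqs_cis = torch.randn(B, S, D, 2)    # [..., 0] = cos, [..., 1] = sin
out = apply_rotary_emb(x, freqs_cis)
assert out.shape == (B, H, S, D) and out.dtype == x.dtype
```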
omnitry/models/transformer_flux.py ADDED
@@ -0,0 +1,620 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional, Tuple, Union, List
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import copy
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
26
+ from diffusers.models.attention import FeedForward
27
+ from diffusers.models.attention_processor import (
28
+ Attention,
29
+ AttentionProcessor,
30
+ FusedFluxAttnProcessor2_0,
31
+ )
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
34
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
35
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
36
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed
37
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
38
+
39
+ from .attn_processors import FluxAttnProcessor2_0
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+
44
+ def zero_module(module):
45
+ # Zero out the parameters of a module and return it.
46
+ for p in module.parameters():
47
+ p.detach().zero_()
48
+ return module
49
+
50
+
51
+ class FluxPosEmbed(nn.Module):
52
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
53
+ def __init__(self, theta: int, axes_dim: List[int]):
54
+ super().__init__()
55
+ self.theta = theta
56
+ self.axes_dim = axes_dim
57
+
58
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
59
+ # input: ids (S, N)
60
+ # return: [cos, sin] (S, D, 2)
61
+ n_axes = ids.shape[-1]
62
+ cos_out = []
63
+ sin_out = []
64
+ pos = ids.float()
65
+ is_mps = ids.device.type == "mps"
66
+ freqs_dtype = torch.float32 if is_mps else torch.float64
67
+ for i in range(n_axes):
68
+ cos, sin = get_1d_rotary_pos_embed(
69
+ self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
70
+ )
71
+ cos_out.append(cos)
72
+ sin_out.append(sin)
73
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
74
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
75
+
76
+ return torch.cat([freqs_cos.unsqueeze(2), freqs_sin.unsqueeze(2)], dim=2)
77
+
78
+
79
+ @maybe_allow_in_graph
80
+ class FluxSingleTransformerBlock(nn.Module):
81
+ r"""
82
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
83
+
84
+ Reference: https://arxiv.org/abs/2403.03206
85
+
86
+ Parameters:
87
+ dim (`int`): The number of channels in the input and output.
88
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
89
+ attention_head_dim (`int`): The number of channels in each head.
90
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
91
+ processing of `context` conditions.
92
+ """
93
+
94
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
95
+ super().__init__()
96
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
97
+
98
+ self.norm = AdaLayerNormZeroSingle(dim)
99
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
100
+ self.act_mlp = nn.GELU(approximate="tanh")
101
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
102
+ self.dim = dim
103
+
104
+ processor = FluxAttnProcessor2_0()
105
+ self.attn = Attention(
106
+ query_dim=dim,
107
+ cross_attention_dim=None,
108
+ dim_head=attention_head_dim,
109
+ heads=num_attention_heads,
110
+ out_dim=dim,
111
+ bias=True,
112
+ processor=processor,
113
+ qk_norm="rms_norm",
114
+ eps=1e-6,
115
+ pre_only=True,
116
+ )
117
+
118
+ def init_intra_group_adapter(self):
119
+ self.igadapter_attn = copy.deepcopy(self.attn)
120
+ self.igadapter_proj_out = nn.Linear(self.dim, self.dim)
121
+ zero_module(self.igadapter_proj_out)
122
+
123
+ def forward(
124
+ self,
125
+ hidden_states: torch.FloatTensor,
126
+ temb: torch.FloatTensor,
127
+ image_rotary_emb=None,
128
+ lens=None,
129
+ joint_attention_kwargs=None,
130
+ ):
131
+ residual = hidden_states
132
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
133
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
134
+ joint_attention_kwargs = joint_attention_kwargs or {}
135
+ attn_output = self.attn(
136
+ hidden_states=norm_hidden_states,
137
+ image_rotary_emb=image_rotary_emb,
138
+ lens=lens,
139
+ **joint_attention_kwargs,
140
+ )
141
+
142
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
143
+ gate = gate.unsqueeze(1)
144
+ hidden_states = gate * self.proj_out(hidden_states)
145
+ hidden_states = residual + hidden_states
146
+ if hidden_states.dtype == torch.float16:
147
+ hidden_states = hidden_states.clip(-65504, 65504)
148
+
149
+ return hidden_states
150
+
151
+
152
+ @maybe_allow_in_graph
153
+ class FluxTransformerBlock(nn.Module):
154
+ r"""
155
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
156
+
157
+ Reference: https://arxiv.org/abs/2403.03206
158
+
159
+ Parameters:
160
+ dim (`int`): The number of channels in the input and output.
161
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
162
+ attention_head_dim (`int`): The number of channels in each head.
163
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
164
+ processing of `context` conditions.
165
+ """
166
+
167
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
168
+ super().__init__()
169
+
170
+ self.dim = dim
171
+ self.norm1 = AdaLayerNormZero(dim)
172
+
173
+ self.norm1_context = AdaLayerNormZero(dim)
174
+
175
+ processor = FluxAttnProcessor2_0()
176
+
177
+ self.attn = Attention(
178
+ query_dim=dim,
179
+ cross_attention_dim=None,
180
+ added_kv_proj_dim=dim,
181
+ dim_head=attention_head_dim,
182
+ heads=num_attention_heads,
183
+ out_dim=dim,
184
+ context_pre_only=False,
185
+ bias=True,
186
+ processor=processor,
187
+ qk_norm=qk_norm,
188
+ eps=eps,
189
+ )
190
+
191
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
192
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
193
+
194
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
195
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
196
+
197
+ # let chunk size default to None
198
+ self._chunk_size = None
199
+ self._chunk_dim = 0
200
+
201
+ def init_intra_group_adapter(self):
202
+ self.igadapter_attn = copy.deepcopy(self.attn)
203
+ self.igadapter_proj_out = nn.Linear(self.dim, self.dim)
204
+ zero_module(self.igadapter_proj_out)
205
+
206
+ def forward(
207
+ self,
208
+ hidden_states: torch.FloatTensor,
209
+ encoder_hidden_states: torch.FloatTensor,
210
+ temb: torch.FloatTensor,
211
+ image_rotary_emb=None,
212
+ lens=None,
213
+ joint_attention_kwargs=None,
214
+ ):
215
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
216
+
217
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
218
+ encoder_hidden_states, emb=temb
219
+ )
220
+ joint_attention_kwargs = joint_attention_kwargs or {}
221
+ # Attention.
222
+ attn_output, context_attn_output = self.attn(
223
+ hidden_states=norm_hidden_states,
224
+ encoder_hidden_states=norm_encoder_hidden_states,
225
+ image_rotary_emb=image_rotary_emb,
226
+ lens=lens,
227
+ **joint_attention_kwargs,
228
+ )
229
+
230
+ # Process attention outputs for the `hidden_states`.
231
+ attn_output = gate_msa.unsqueeze(1) * attn_output
232
+ hidden_states = hidden_states + attn_output
233
+
234
+ norm_hidden_states = self.norm2(hidden_states)
235
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
236
+
237
+ ff_output = self.ff(norm_hidden_states)
238
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
239
+
240
+ hidden_states = hidden_states + ff_output
241
+
242
+ # Process attention outputs for the `encoder_hidden_states`.
243
+
244
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
245
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
246
+
247
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
248
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
249
+
250
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
251
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
252
+ if encoder_hidden_states.dtype == torch.float16:
253
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
254
+
255
+ return encoder_hidden_states, hidden_states
256
+
257
+
258
+ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
259
+ """
260
+ The Transformer model introduced in Flux.
261
+
262
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
263
+
264
+ Parameters:
265
+ patch_size (`int`): Patch size to turn the input data into small patches.
266
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
267
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
268
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
269
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
270
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
271
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
272
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
273
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
274
+ """
275
+
276
+ _supports_gradient_checkpointing = True
277
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
278
+
279
+ @register_to_config
280
+ def __init__(
281
+ self,
282
+ patch_size: int = 1,
283
+ in_channels: int = 64,
284
+ out_channels: int = 64,
285
+ num_layers: int = 19,
286
+ num_single_layers: int = 38,
287
+ attention_head_dim: int = 128,
288
+ num_attention_heads: int = 24,
289
+ joint_attention_dim: int = 4096,
290
+ pooled_projection_dim: int = 768,
291
+ guidance_embeds: bool = False,
292
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
293
+ ):
294
+ super().__init__()
295
+ self.in_channels = in_channels
296
+ self.out_channels = out_channels
297
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
298
+
299
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
300
+
301
+ text_time_guidance_cls = (
302
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
303
+ )
304
+ self.time_text_embed = text_time_guidance_cls(
305
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
306
+ )
307
+
308
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
309
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
310
+
311
+ self.transformer_blocks = nn.ModuleList(
312
+ [
313
+ FluxTransformerBlock(
314
+ dim=self.inner_dim,
315
+ num_attention_heads=self.config.num_attention_heads,
316
+ attention_head_dim=self.config.attention_head_dim,
317
+ )
318
+ for i in range(self.config.num_layers)
319
+ ]
320
+ )
321
+
322
+ self.single_transformer_blocks = nn.ModuleList(
323
+ [
324
+ FluxSingleTransformerBlock(
325
+ dim=self.inner_dim,
326
+ num_attention_heads=self.config.num_attention_heads,
327
+ attention_head_dim=self.config.attention_head_dim,
328
+ )
329
+ for i in range(self.config.num_single_layers)
330
+ ]
331
+ )
332
+
333
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
334
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
335
+
336
+ self.gradient_checkpointing = False
337
+
338
+ @property
339
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
340
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
341
+ r"""
342
+ Returns:
343
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
344
+ indexed by its weight name.
345
+ """
346
+ # set recursively
347
+ processors = {}
348
+
349
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
350
+ if hasattr(module, "get_processor"):
351
+ processors[f"{name}.processor"] = module.get_processor()
352
+
353
+ for sub_name, child in module.named_children():
354
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
355
+
356
+ return processors
357
+
358
+ for name, module in self.named_children():
359
+ fn_recursive_add_processors(name, module, processors)
360
+
361
+ return processors
362
+
363
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
364
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
365
+ r"""
366
+ Sets the attention processor to use to compute attention.
367
+
368
+ Parameters:
369
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
370
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
371
+ for **all** `Attention` layers.
372
+
373
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
374
+ processor. This is strongly recommended when setting trainable attention processors.
375
+
376
+ """
377
+ count = len(self.attn_processors.keys())
378
+
379
+ if isinstance(processor, dict) and len(processor) != count:
380
+ raise ValueError(
381
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
382
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
383
+ )
384
+
385
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
386
+ if hasattr(module, "set_processor"):
387
+ if not isinstance(processor, dict):
388
+ module.set_processor(processor)
389
+ else:
390
+ module.set_processor(processor.pop(f"{name}.processor"))
391
+
392
+ for sub_name, child in module.named_children():
393
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
394
+
395
+ for name, module in self.named_children():
396
+ fn_recursive_attn_processor(name, module, processor)
397
+
398
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
399
+ def fuse_qkv_projections(self):
400
+ """
401
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
402
+ are fused. For cross-attention modules, key and value projection matrices are fused.
403
+
404
+ <Tip warning={true}>
405
+
406
+ This API is 🧪 experimental.
407
+
408
+ </Tip>
409
+ """
410
+ self.original_attn_processors = None
411
+
412
+ for _, attn_processor in self.attn_processors.items():
413
+ if "Added" in str(attn_processor.__class__.__name__):
414
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
415
+
416
+ self.original_attn_processors = self.attn_processors
417
+
418
+ for module in self.modules():
419
+ if isinstance(module, Attention):
420
+ module.fuse_projections(fuse=True)
421
+
422
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
423
+
424
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
425
+ def unfuse_qkv_projections(self):
426
+ """Disables the fused QKV projection if enabled.
427
+
428
+ <Tip warning={true}>
429
+
430
+ This API is 🧪 experimental.
431
+
432
+ </Tip>
433
+
434
+ """
435
+ if self.original_attn_processors is not None:
436
+ self.set_attn_processor(self.original_attn_processors)
437
+
438
+ def _set_gradient_checkpointing(self, module, value=False):
439
+ if hasattr(module, "gradient_checkpointing"):
440
+ module.gradient_checkpointing = value
441
+
442
+ def forward(
443
+ self,
444
+ hidden_states: torch.Tensor,
445
+ encoder_hidden_states: torch.Tensor = None,
446
+ pooled_projections: torch.Tensor = None,
447
+ timestep: torch.LongTensor = None,
448
+ img_ids: torch.Tensor = None,
449
+ txt_ids: torch.Tensor = None,
450
+ guidance: torch.Tensor = None,
451
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
452
+ img_shapes: list = None,
453
+ img_lens: list = None,
454
+ return_dict: bool = True,
455
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
456
+ """
457
+ The [`FluxTransformer2DModel`] forward method.
458
+
459
+ Args:
460
+ hidden_states (`torch.FloatTensor` of shape `(batch size, sequence, channel)`):
461
+ Input `hidden_states`.
462
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
463
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
464
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
465
+ from the embeddings of input conditions.
466
+ timestep ( `torch.LongTensor`):
467
+ Used to indicate denoising step.
468
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
469
+ A list of tensors that if specified are added to the residuals of transformer blocks.
470
+ joint_attention_kwargs (`dict`, *optional*):
471
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
472
+ `self.processor` in
473
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
474
+ return_dict (`bool`, *optional*, defaults to `True`):
475
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
476
+ tuple.
477
+
478
+ Returns:
479
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
480
+ `tuple` where the first element is the sample tensor.
481
+ """
482
+ if joint_attention_kwargs is not None:
483
+ joint_attention_kwargs = joint_attention_kwargs.copy()
484
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
485
+ else:
486
+ lora_scale = 1.0
487
+
488
+ if USE_PEFT_BACKEND:
489
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
490
+ scale_lora_layers(self, lora_scale)
491
+ else:
492
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
493
+ logger.warning(
494
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
495
+ )
496
+
497
+ # patchify
498
+ hidden_states = self.x_embedder(hidden_states)
499
+
500
+ # conditions (time, guidance, text)
501
+ bsz = hidden_states.size(0)
502
+ timestep = timestep.to(hidden_states.dtype) * 1000
503
+ if guidance is not None:
504
+ guidance = guidance.to(hidden_states.dtype) * 1000
505
+ else:
506
+ guidance = None
507
+ temb = (
508
+ self.time_text_embed(timestep, pooled_projections)
509
+ if guidance is None
510
+ else self.time_text_embed(timestep, guidance, pooled_projections)
511
+ )
512
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
513
+
514
+ if txt_ids.ndim == 2:
515
+ txt_ids = txt_ids[None].repeat(bsz, 1, 1)
516
+ if img_ids.ndim == 2:
517
+ img_ids = img_ids[None].repeat(bsz, 1, 1)
518
+
519
+ # shift pos id
520
+ max_w = img_ids[:, :, 2].max().item()
521
+ for b in range(bsz):
522
+ img_ids[b, :, 0] = b
523
+ txt_ids[b, :, 0] = b
524
+ img_ids[b, :, 2] += b * max_w
525
+
526
+ # prepare rope embedding
527
+ image_rotary_emb = torch.stack([
528
+ self.pos_embed(torch.cat([t_id, i_id], dim=0))
529
+ for t_id, i_id in zip(txt_ids, img_ids)
530
+ ])
531
+
532
+ # sequence length, TODO: varied txt length
533
+ if img_lens is not None:
534
+ lens = img_lens + encoder_hidden_states.size(1)
535
+ else:
536
+ lens = None
537
+
538
+ # transformer blocks
539
+ for block in self.transformer_blocks:
540
+
541
+ if self.training and self.gradient_checkpointing:
542
+
543
+ def create_custom_forward(module, return_dict=None):
544
+ def custom_forward(*inputs):
545
+ if return_dict is not None:
546
+ return module(*inputs, return_dict=return_dict)
547
+ else:
548
+ return module(*inputs)
549
+
550
+ return custom_forward
551
+
552
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
553
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
554
+ create_custom_forward(block),
555
+ hidden_states,
556
+ encoder_hidden_states,
557
+ temb,
558
+ image_rotary_emb,
559
+ lens,
560
+ joint_attention_kwargs,
561
+ **ckpt_kwargs,
562
+ )
563
+
564
+ else:
565
+ encoder_hidden_states, hidden_states = block(
566
+ hidden_states=hidden_states,
567
+ encoder_hidden_states=encoder_hidden_states,
568
+ temb=temb,
569
+ image_rotary_emb=image_rotary_emb,
570
+ lens=lens,
571
+ joint_attention_kwargs=joint_attention_kwargs,
572
+ )
573
+
574
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
575
+
576
+ for block in self.single_transformer_blocks:
577
+ if self.training and self.gradient_checkpointing:
578
+
579
+ def create_custom_forward(module, return_dict=None):
580
+ def custom_forward(*inputs):
581
+ if return_dict is not None:
582
+ return module(*inputs, return_dict=return_dict)
583
+ else:
584
+ return module(*inputs)
585
+
586
+ return custom_forward
587
+
588
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
589
+ hidden_states = torch.utils.checkpoint.checkpoint(
590
+ create_custom_forward(block),
591
+ hidden_states,
592
+ temb,
593
+ image_rotary_emb,
594
+ lens,
595
+ joint_attention_kwargs,
596
+ **ckpt_kwargs,
597
+ )
598
+
599
+ else:
600
+ hidden_states = block(
601
+ hidden_states=hidden_states,
602
+ temb=temb,
603
+ image_rotary_emb=image_rotary_emb,
604
+ lens=lens,
605
+ joint_attention_kwargs=joint_attention_kwargs,
606
+ )
607
+
608
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
609
+
610
+ hidden_states = self.norm_out(hidden_states, temb)
611
+ output = self.proj_out(hidden_states)
612
+
613
+ if USE_PEFT_BACKEND:
614
+ # remove `lora_scale` from each PEFT layer
615
+ unscale_lora_layers(self, lora_scale)
616
+
617
+ if not return_dict:
618
+ return (output,)
619
+
620
+ return Transformer2DModelOutput(sample=output)
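
In `forward`, each sample gets its batch index written into the first RoPE axis and a disjoint range of width position ids (`img_ids[b, :, 2] += b * max_w`), which appears intended to keep the person and garment samples from sharing rotary positions when the attention processor mixes their keys and values. An illustrative sketch of that shift on toy ids, not part of the commit:

```py
# Toy illustration of the per-sample position-id shift in FluxTransformer2DModel.forward.
import torch

bsz, seq = 2, 6
img_ids = torch.zeros(bsz, seq, 3)
img_ids[:, :, 2] = torch.arange(seq)    # width coordinate of each token

max_w = img_ids[:, :, 2].max().item()   # 5
for b in range(bsz):
    img_ids[b, :, 0] = b                # batch index in the first rope axis
    img_ids[b, :, 2] += b * max_w       # sample 1 now spans widths 5..10

print(img_ids[1, :, 2])                 # tensor([ 5.,  6.,  7.,  8.,  9., 10.])
```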
omnitry/pipelines/pipeline_flux.py ADDED
@@ -0,0 +1,799 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
21
+
22
+ from diffusers.image_processor import VaeImageProcessor
23
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
24
+ from diffusers.models.autoencoders import AutoencoderKL
25
+ from diffusers.models.transformers import FluxTransformer2DModel
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import (
28
+ USE_PEFT_BACKEND,
29
+ is_torch_xla_available,
30
+ logging,
31
+ replace_example_docstring,
32
+ scale_lora_layers,
33
+ unscale_lora_layers,
34
+ )
35
+ from diffusers.utils.torch_utils import randn_tensor
36
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
37
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
38
+
39
+
40
+ if is_torch_xla_available():
41
+ import torch_xla.core.xla_model as xm
42
+
43
+ XLA_AVAILABLE = True
44
+ else:
45
+ XLA_AVAILABLE = False
46
+
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+ EXAMPLE_DOC_STRING = """
51
+ Examples:
52
+ ```py
53
+ >>> import torch
54
+ >>> from diffusers import FluxPipeline
55
+
56
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
57
+ >>> pipe.to("cuda")
58
+ >>> prompt = "A cat holding a sign that says hello world"
59
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
60
+ >>> # Refer to the pipeline documentation for more details.
61
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
62
+ >>> image.save("flux.png")
63
+ ```
64
+ """
65
+
66
+
67
+ def calculate_shift(
68
+ image_seq_len,
69
+ base_seq_len: int = 256,
70
+ max_seq_len: int = 4096,
71
+ base_shift: float = 0.5,
72
+ max_shift: float = 1.15,
73
+ ):
74
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
75
+ b = base_shift - m * base_seq_len
76
+ mu = image_seq_len * m + b
77
+ return mu
78
+
79
+
80
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
81
+ def retrieve_timesteps(
82
+ scheduler,
83
+ num_inference_steps: Optional[int] = None,
84
+ device: Optional[Union[str, torch.device]] = None,
85
+ timesteps: Optional[List[int]] = None,
86
+ sigmas: Optional[List[float]] = None,
87
+ **kwargs,
88
+ ):
89
+ r"""
90
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
91
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
92
+
93
+ Args:
94
+ scheduler (`SchedulerMixin`):
95
+ The scheduler to get timesteps from.
96
+ num_inference_steps (`int`):
97
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
98
+ must be `None`.
99
+ device (`str` or `torch.device`, *optional*):
100
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
101
+ timesteps (`List[int]`, *optional*):
102
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
103
+ `num_inference_steps` and `sigmas` must be `None`.
104
+ sigmas (`List[float]`, *optional*):
105
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
106
+ `num_inference_steps` and `timesteps` must be `None`.
107
+
108
+ Returns:
109
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
110
+ second element is the number of inference steps.
111
+ """
112
+ if timesteps is not None and sigmas is not None:
113
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
114
+ if timesteps is not None:
115
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
116
+ if not accepts_timesteps:
117
+ raise ValueError(
118
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
119
+ f" timestep schedules. Please check whether you are using the correct scheduler."
120
+ )
121
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
122
+ timesteps = scheduler.timesteps
123
+ num_inference_steps = len(timesteps)
124
+ elif sigmas is not None:
125
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
126
+ if not accept_sigmas:
127
+ raise ValueError(
128
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
130
+ )
131
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
132
+ timesteps = scheduler.timesteps
133
+ num_inference_steps = len(timesteps)
134
+ else:
135
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
136
+ timesteps = scheduler.timesteps
137
+ return timesteps, num_inference_steps
138
+
139
+
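A minimal sketch of how the two helpers above (`calculate_shift` and `retrieve_timesteps`) are combined further down in `__call__` to build the shifted flow-matching schedule; the local scheduler path and the 1024x1024 resolution are assumptions, not part of the commit.

```py
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

# Assumed local snapshot of the FLUX.1-Fill-dev scheduler config.
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("./FLUX.1-Fill-dev/scheduler")

num_inference_steps = 28
image_seq_len = (1024 // 16) * (1024 // 16)  # 4096 packed latent tokens for a 1024x1024 image
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

# `mu` shifts the sigma schedule towards higher noise for longer token sequences.
mu = calculate_shift(
    image_seq_len,
    scheduler.config.base_image_seq_len,
    scheduler.config.max_image_seq_len,
    scheduler.config.base_shift,
    scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler, num_inference_steps, "cpu", None, sigmas, mu=mu
)
```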
140
+ class FluxPipeline(
141
+ DiffusionPipeline,
142
+ FluxLoraLoaderMixin,
143
+ FromSingleFileMixin,
144
+ TextualInversionLoaderMixin,
145
+ ):
146
+ r"""
147
+ The Flux pipeline for text-to-image generation.
148
+
149
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
150
+
151
+ Args:
152
+ transformer ([`FluxTransformer2DModel`]):
153
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
154
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
155
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
156
+ vae ([`AutoencoderKL`]):
157
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
158
+ text_encoder ([`CLIPTextModel`]):
159
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
160
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
161
+ text_encoder_2 ([`T5EncoderModel`]):
162
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
163
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
164
+ tokenizer (`CLIPTokenizer`):
165
+ Tokenizer of class
166
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
167
+ tokenizer_2 (`T5TokenizerFast`):
168
+ Second Tokenizer of class
169
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
170
+ """
171
+
172
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
173
+ _optional_components = []
174
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
175
+
176
+ def __init__(
177
+ self,
178
+ scheduler: FlowMatchEulerDiscreteScheduler,
179
+ vae: AutoencoderKL,
180
+ text_encoder: CLIPTextModel,
181
+ tokenizer: CLIPTokenizer,
182
+ text_encoder_2: T5EncoderModel,
183
+ tokenizer_2: T5TokenizerFast,
184
+ transformer: FluxTransformer2DModel,
185
+ ):
186
+ super().__init__()
187
+
188
+ self.register_modules(
189
+ vae=vae,
190
+ text_encoder=text_encoder,
191
+ text_encoder_2=text_encoder_2,
192
+ tokenizer=tokenizer,
193
+ tokenizer_2=tokenizer_2,
194
+ transformer=transformer,
195
+ scheduler=scheduler,
196
+ )
197
+ self.vae_scale_factor = (
198
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
199
+ )
200
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
201
+ self.tokenizer_max_length = (
202
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
203
+ )
204
+ self.default_sample_size = 64
205
+
206
+ def _get_t5_prompt_embeds(
207
+ self,
208
+ prompt: Union[str, List[str]] = None,
209
+ num_images_per_prompt: int = 1,
210
+ max_sequence_length: int = 512,
211
+ device: Optional[torch.device] = None,
212
+ dtype: Optional[torch.dtype] = None,
213
+ ):
214
+ device = device or self._execution_device
215
+ dtype = dtype or self.text_encoder.dtype
216
+
217
+ prompt = [prompt] if isinstance(prompt, str) else prompt
218
+ batch_size = len(prompt)
219
+
220
+ if isinstance(self, TextualInversionLoaderMixin):
221
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
222
+
223
+ text_inputs = self.tokenizer_2(
224
+ prompt,
225
+ padding="max_length",
226
+ max_length=max_sequence_length,
227
+ truncation=True,
228
+ return_length=False,
229
+ return_overflowing_tokens=False,
230
+ return_tensors="pt",
231
+ )
232
+ text_input_ids = text_inputs.input_ids
233
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
234
+
235
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
236
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
237
+ # logger.warning(
238
+ # "The following part of your input was truncated because `max_sequence_length` is set to "
239
+ # f" {max_sequence_length} tokens: {removed_text}"
240
+ # )
241
+
242
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
243
+
244
+ dtype = self.text_encoder_2.dtype
245
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
246
+
247
+ _, seq_len, _ = prompt_embeds.shape
248
+
249
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
250
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
251
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
252
+
253
+ return prompt_embeds
254
+
255
+ def _get_clip_prompt_embeds(
256
+ self,
257
+ prompt: Union[str, List[str]],
258
+ num_images_per_prompt: int = 1,
259
+ device: Optional[torch.device] = None,
260
+ ):
261
+ device = device or self._execution_device
262
+
263
+ prompt = [prompt] if isinstance(prompt, str) else prompt
264
+ batch_size = len(prompt)
265
+
266
+ if isinstance(self, TextualInversionLoaderMixin):
267
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
268
+
269
+ text_inputs = self.tokenizer(
270
+ prompt,
271
+ padding="max_length",
272
+ max_length=self.tokenizer_max_length,
273
+ truncation=True,
274
+ return_overflowing_tokens=False,
275
+ return_length=False,
276
+ return_tensors="pt",
277
+ )
278
+
279
+ text_input_ids = text_inputs.input_ids
280
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
281
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
282
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
283
+ # logger.warning(
284
+ # "The following part of your input was truncated because CLIP can only handle sequences up to"
285
+ # f" {self.tokenizer_max_length} tokens: {removed_text}"
286
+ # )
287
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
288
+
289
+ # Use pooled output of CLIPTextModel
290
+ prompt_embeds = prompt_embeds.pooler_output
291
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
292
+
293
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
294
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
295
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
296
+
297
+ return prompt_embeds
298
+
299
+ def encode_prompt(
300
+ self,
301
+ prompt: Union[str, List[str]],
302
+ prompt_2: Union[str, List[str]],
303
+ device: Optional[torch.device] = None,
304
+ num_images_per_prompt: int = 1,
305
+ prompt_embeds: Optional[torch.FloatTensor] = None,
306
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
307
+ max_sequence_length: int = 512,
308
+ lora_scale: Optional[float] = None,
309
+ ):
310
+ r"""
311
+
312
+ Args:
313
+ prompt (`str` or `List[str]`, *optional*):
314
+ prompt to be encoded
315
+ prompt_2 (`str` or `List[str]`, *optional*):
316
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
317
+ used in all text-encoders
318
+ device: (`torch.device`):
319
+ torch device
320
+ num_images_per_prompt (`int`):
321
+ number of images that should be generated per prompt
322
+ prompt_embeds (`torch.FloatTensor`, *optional*):
323
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
324
+ provided, text embeddings will be generated from `prompt` input argument.
325
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
326
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
327
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
328
+ lora_scale (`float`, *optional*):
329
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
330
+ """
331
+ device = device or self._execution_device
332
+
333
+ # set lora scale so that monkey patched LoRA
334
+ # function of text encoder can correctly access it
335
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
336
+ self._lora_scale = lora_scale
337
+
338
+ # dynamically adjust the LoRA scale
339
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
340
+ scale_lora_layers(self.text_encoder, lora_scale)
341
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
342
+ scale_lora_layers(self.text_encoder_2, lora_scale)
343
+
344
+ prompt = [prompt] if isinstance(prompt, str) else prompt
345
+
346
+ if prompt_embeds is None:
347
+ prompt_2 = prompt_2 or prompt
348
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
349
+
350
+ # We only use the pooled prompt output from the CLIPTextModel
351
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
352
+ prompt=prompt,
353
+ device=device,
354
+ num_images_per_prompt=num_images_per_prompt,
355
+ )
356
+ prompt_embeds = self._get_t5_prompt_embeds(
357
+ prompt=prompt_2,
358
+ num_images_per_prompt=num_images_per_prompt,
359
+ max_sequence_length=max_sequence_length,
360
+ device=device,
361
+ )
362
+
363
+ if self.text_encoder is not None:
364
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
365
+ # Retrieve the original scale by scaling back the LoRA layers
366
+ unscale_lora_layers(self.text_encoder, lora_scale)
367
+
368
+ if self.text_encoder_2 is not None:
369
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
370
+ # Retrieve the original scale by scaling back the LoRA layers
371
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
372
+
373
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
374
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
375
+
376
+ return prompt_embeds, pooled_prompt_embeds, text_ids
377
+
378
+ def check_inputs(
379
+ self,
380
+ prompt,
381
+ prompt_2,
382
+ height,
383
+ width,
384
+ prompt_embeds=None,
385
+ pooled_prompt_embeds=None,
386
+ callback_on_step_end_tensor_inputs=None,
387
+ max_sequence_length=None,
388
+ ):
389
+ if height % 8 != 0 or width % 8 != 0:
390
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
391
+
392
+ if callback_on_step_end_tensor_inputs is not None and not all(
393
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
394
+ ):
395
+ raise ValueError(
396
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
397
+ )
398
+
399
+ if prompt is not None and prompt_embeds is not None:
400
+ raise ValueError(
401
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
402
+ " only forward one of the two."
403
+ )
404
+ elif prompt_2 is not None and prompt_embeds is not None:
405
+ raise ValueError(
406
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
407
+ " only forward one of the two."
408
+ )
409
+ elif prompt is None and prompt_embeds is None:
410
+ raise ValueError(
411
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
412
+ )
413
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
414
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
415
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
416
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
417
+
418
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
419
+ raise ValueError(
420
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
421
+ )
422
+
423
+ if max_sequence_length is not None and max_sequence_length > 512:
424
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
425
+
426
+ @staticmethod
427
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
428
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
429
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
430
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
431
+
432
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
433
+
434
+ latent_image_ids = latent_image_ids.reshape(
435
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
436
+ )
437
+
438
+ return latent_image_ids.to(device=device, dtype=dtype)
439
+
440
+ @staticmethod
441
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
442
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
443
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
444
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
445
+
446
+ return latents
447
+
448
+ @staticmethod
449
+ def _unpack_latents(latents, height, width, vae_scale_factor):
450
+ batch_size, num_patches, channels = latents.shape
451
+
452
+ height = height // vae_scale_factor
453
+ width = width // vae_scale_factor
454
+
455
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
456
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
457
+
458
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
459
+
460
+ return latents
461
+
462
+ def enable_vae_slicing(self):
463
+ r"""
464
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
465
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
466
+ """
467
+ self.vae.enable_slicing()
468
+
469
+ def disable_vae_slicing(self):
470
+ r"""
471
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
472
+ computing decoding in one step.
473
+ """
474
+ self.vae.disable_slicing()
475
+
476
+ def enable_vae_tiling(self):
477
+ r"""
478
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
479
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
480
+ processing larger images.
481
+ """
482
+ self.vae.enable_tiling()
483
+
484
+ def disable_vae_tiling(self):
485
+ r"""
486
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
487
+ computing decoding in one step.
488
+ """
489
+ self.vae.disable_tiling()
490
+
491
+ def prepare_latents(
492
+ self,
493
+ batch_size,
494
+ num_channels_latents,
495
+ height,
496
+ width,
497
+ dtype,
498
+ device,
499
+ generator,
500
+ latents=None,
501
+ ):
502
+ height = 2 * (int(height) // self.vae_scale_factor)
503
+ width = 2 * (int(width) // self.vae_scale_factor)
504
+
505
+ shape = (batch_size, num_channels_latents, height, width)
506
+
507
+ if latents is not None:
508
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
509
+ return latents.to(device=device, dtype=dtype), latent_image_ids
510
+
511
+ if isinstance(generator, list) and len(generator) != batch_size:
512
+ raise ValueError(
513
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
514
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
515
+ )
516
+
517
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
518
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
519
+
520
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
521
+
522
+ return latents, latent_image_ids
523
+
524
+ @property
525
+ def guidance_scale(self):
526
+ return self._guidance_scale
527
+
528
+ @property
529
+ def joint_attention_kwargs(self):
530
+ return self._joint_attention_kwargs
531
+
532
+ @property
533
+ def num_timesteps(self):
534
+ return self._num_timesteps
535
+
536
+ @property
537
+ def interrupt(self):
538
+ return self._interrupt
539
+
540
+ @torch.no_grad()
541
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
542
+ def __call__(
543
+ self,
544
+ prompt: Union[str, List[str]] = None,
545
+ prompt_2: Optional[Union[str, List[str]]] = None,
546
+ height: Optional[int] = None,
547
+ width: Optional[int] = None,
548
+ num_inference_steps: int = 28,
549
+ timesteps: List[int] = None,
550
+ guidance_scale: float = 3.5,
551
+ num_images_per_prompt: Optional[int] = 1,
552
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
553
+ latents: Optional[torch.FloatTensor] = None,
554
+ prompt_embeds: Optional[torch.FloatTensor] = None,
555
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
556
+ output_type: Optional[str] = "pil",
557
+ return_dict: bool = True,
558
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
559
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
560
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
561
+ max_sequence_length: int = 512,
562
+ condition_latents=None,
563
+ condition_latents_indices=None,
564
+ condition_diffuse_ratio=1.0,
565
+ ):
566
+ r"""
567
+ Function invoked when calling the pipeline for generation.
568
+
569
+ Args:
570
+ prompt (`str` or `List[str]`, *optional*):
571
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
572
+ instead.
573
+ prompt_2 (`str` or `List[str]`, *optional*):
574
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
575
+ used in all text-encoders.
576
+ height (`int`, *optional*, defaults to 1024):
577
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
578
+ width (`int`, *optional*, defaults to 1024):
579
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
580
+ num_inference_steps (`int`, *optional*, defaults to 28):
581
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
582
+ expense of slower inference.
583
+ timesteps (`List[int]`, *optional*):
584
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
585
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
586
+ passed will be used. Must be in descending order.
587
+ guidance_scale (`float`, *optional*, defaults to 3.5):
588
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
589
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
590
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
591
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
592
+ usually at the expense of lower image quality.
593
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
594
+ The number of images to generate per prompt.
595
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
596
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
597
+ to make generation deterministic.
598
+ latents (`torch.FloatTensor`, *optional*):
599
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
600
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
601
+ tensor will be generated by sampling using the supplied random `generator`.
602
+ prompt_embeds (`torch.FloatTensor`, *optional*):
603
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
604
+ provided, text embeddings will be generated from `prompt` input argument.
605
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
606
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
607
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
608
+ output_type (`str`, *optional*, defaults to `"pil"`):
609
+ The output format of the generated image. Choose between
610
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
611
+ return_dict (`bool`, *optional*, defaults to `True`):
612
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
613
+ joint_attention_kwargs (`dict`, *optional*):
614
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
615
+ `self.processor` in
616
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
617
+ callback_on_step_end (`Callable`, *optional*):
618
+ A function that is called at the end of each denoising step during inference. The function is called
619
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
620
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
621
+ `callback_on_step_end_tensor_inputs`.
622
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
623
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
624
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
625
+ `._callback_tensor_inputs` attribute of your pipeline class.
626
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
627
+
628
+ Examples:
629
+
630
+ Returns:
631
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
632
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
633
+ images.
634
+ """
635
+
636
+ height = height or self.default_sample_size * self.vae_scale_factor
637
+ width = width or self.default_sample_size * self.vae_scale_factor
638
+
639
+ # 1. Check inputs. Raise error if not correct
640
+ self.check_inputs(
641
+ prompt,
642
+ prompt_2,
643
+ height,
644
+ width,
645
+ prompt_embeds=prompt_embeds,
646
+ pooled_prompt_embeds=pooled_prompt_embeds,
647
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
648
+ max_sequence_length=max_sequence_length,
649
+ )
650
+
651
+ self._guidance_scale = guidance_scale
652
+ self._joint_attention_kwargs = joint_attention_kwargs
653
+ self._interrupt = False
654
+
655
+ # 2. Define call parameters
656
+ if prompt is not None and isinstance(prompt, str):
657
+ batch_size = 1
658
+ elif prompt is not None and isinstance(prompt, list):
659
+ batch_size = len(prompt)
660
+ else:
661
+ batch_size = prompt_embeds.shape[0]
662
+
663
+ device = self._execution_device
664
+
665
+ lora_scale = (
666
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
667
+ )
668
+ (
669
+ prompt_embeds,
670
+ pooled_prompt_embeds,
671
+ text_ids,
672
+ ) = self.encode_prompt(
673
+ prompt=prompt,
674
+ prompt_2=prompt_2,
675
+ prompt_embeds=prompt_embeds,
676
+ pooled_prompt_embeds=pooled_prompt_embeds,
677
+ device=device,
678
+ num_images_per_prompt=num_images_per_prompt,
679
+ max_sequence_length=max_sequence_length,
680
+ lora_scale=lora_scale,
681
+ )
682
+
683
+ # 4. Prepare latent variables
684
+ num_channels_latents = self.transformer.config.in_channels // 4
685
+ latents, latent_image_ids = self.prepare_latents(
686
+ batch_size * num_images_per_prompt,
687
+ num_channels_latents,
688
+ height,
689
+ width,
690
+ prompt_embeds.dtype,
691
+ device,
692
+ generator,
693
+ latents,
694
+ )
695
+
696
+ # 5. Prepare timesteps
697
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
698
+ image_seq_len = latents.shape[1]
699
+ mu = calculate_shift(
700
+ image_seq_len,
701
+ self.scheduler.config.base_image_seq_len,
702
+ self.scheduler.config.max_image_seq_len,
703
+ self.scheduler.config.base_shift,
704
+ self.scheduler.config.max_shift,
705
+ )
706
+ timesteps, num_inference_steps = retrieve_timesteps(
707
+ self.scheduler,
708
+ num_inference_steps,
709
+ device,
710
+ timesteps,
711
+ sigmas,
712
+ mu=mu,
713
+ )
714
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
715
+ self._num_timesteps = len(timesteps)
716
+
717
+ # handle guidance
718
+ if self.transformer.config.guidance_embeds:
719
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
720
+ guidance = guidance.expand(latents.shape[0])
721
+ else:
722
+ guidance = None
723
+
724
+ if condition_latents is not None and condition_latents_indices is not None:
725
+ condition_noises = [torch.randn_like(z) for z in condition_latents]
726
+
727
+ # 6. Denoising loop
728
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
729
+ for i, t in enumerate(timesteps):
730
+ if self.interrupt:
731
+ continue
732
+
733
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
734
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
735
+
736
+ # image conditioning
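+ # Image-conditioning step (an addition relative to the stock diffusers pipeline): batch entries
+ # listed in `condition_latents_indices` hold reference latents. At every step they are reset to a
+ # freshly noised copy of the clean reference, with both the noise level and the timestep fed to
+ # the transformer scaled by `condition_diffuse_ratio`, instead of being denoised by the sampler.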
737
+ if condition_latents is not None and condition_latents_indices is not None:
738
+ for z, idx, noise in zip(condition_latents, condition_latents_indices, condition_noises):
739
+ condition_t = timestep[idx] / 1000 * condition_diffuse_ratio
740
+ timestep[idx] = int(timestep[idx] * condition_diffuse_ratio)
741
+ latents[idx] = (1 - condition_t) * z + condition_t * noise
742
+
743
+ noise_pred = self.transformer(
744
+ hidden_states=latents,
745
+ timestep=timestep / 1000,
746
+ guidance=guidance,
747
+ pooled_projections=pooled_prompt_embeds,
748
+ encoder_hidden_states=prompt_embeds,
749
+ txt_ids=text_ids,
750
+ img_ids=latent_image_ids,
751
+ joint_attention_kwargs=self.joint_attention_kwargs,
752
+ return_dict=False,
753
+ )[0]
754
+
755
+ # compute the previous noisy sample x_t -> x_t-1
756
+ latents_dtype = latents.dtype
757
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
758
+
759
+ if latents.dtype != latents_dtype:
760
+ if torch.backends.mps.is_available():
761
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
762
+ latents = latents.to(latents_dtype)
763
+
764
+ if callback_on_step_end is not None:
765
+ callback_kwargs = {}
766
+ for k in callback_on_step_end_tensor_inputs:
767
+ callback_kwargs[k] = locals()[k]
768
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
769
+
770
+ latents = callback_outputs.pop("latents", latents)
771
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
772
+
773
+ # call the callback, if provided
774
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
775
+ progress_bar.update()
776
+
777
+ if XLA_AVAILABLE:
778
+ xm.mark_step()
779
+
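+ # Put the clean reference latents back at the conditioning indices so those batch entries
+ # decode to the unchanged reference images alongside the generated result.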
780
+ if condition_latents is not None and condition_latents_indices is not None:
781
+ for z, idx in zip(condition_latents, condition_latents_indices):
782
+ latents[idx] = z
783
+
784
+ if output_type == "latent":
785
+ image = latents
786
+
787
+ else:
788
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
789
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
790
+ image = self.vae.decode(latents, return_dict=False)[0]
791
+ image = self.image_processor.postprocess(image, output_type=output_type)
792
+
793
+ # Offload all models
794
+ self.maybe_free_model_hooks()
795
+
796
+ if not return_dict:
797
+ return (image,)
798
+
799
+ return FluxPipelineOutput(images=image)
omnitry/pipelines/pipeline_flux_fill.py ADDED
@@ -0,0 +1,510 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
21
+
22
+ from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
23
+ from diffusers.models.autoencoders import AutoencoderKL
24
+ from diffusers.models.transformers import FluxTransformer2DModel
25
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
26
+ from diffusers.utils import (
27
+ is_torch_xla_available,
28
+ logging,
29
+ replace_example_docstring,
30
+ )
31
+ from diffusers.utils.torch_utils import randn_tensor
32
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
33
+
34
+ from .pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps
35
+
36
+ if is_torch_xla_available():
37
+ import torch_xla.core.xla_model as xm
38
+
39
+ XLA_AVAILABLE = True
40
+ else:
41
+ XLA_AVAILABLE = False
42
+
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+ EXAMPLE_DOC_STRING = """
47
+ Examples:
48
+ ```py
49
+ >>> import torch
50
+ >>> from diffusers import FluxPipeline
51
+
52
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
53
+ >>> pipe.to("cuda")
54
+ >>> prompt = "A cat holding a sign that says hello world"
55
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
56
+ >>> # Refer to the pipeline documentation for more details.
57
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
58
+ >>> image.save("flux.png")
59
+ ```
60
+ """
61
+
62
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
63
+ def retrieve_latents(
64
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
65
+ ):
66
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
67
+ return encoder_output.latent_dist.sample(generator)
68
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
69
+ return encoder_output.latent_dist.mode()
70
+ elif hasattr(encoder_output, "latents"):
71
+ return encoder_output.latents
72
+ else:
73
+ raise AttributeError("Could not access latents of provided encoder_output")
74
+
75
+
76
+ class FluxFillPipeline(FluxPipeline):
77
+ r"""
78
+ The Flux Fill pipeline for mask-conditioned image inpainting, built on top of [`FluxPipeline`].
79
+
80
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
81
+
82
+ Args:
83
+ transformer ([`FluxTransformer2DModel`]):
84
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
85
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
86
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
87
+ vae ([`AutoencoderKL`]):
88
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
89
+ text_encoder ([`CLIPTextModel`]):
90
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
91
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
92
+ text_encoder_2 ([`T5EncoderModel`]):
93
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
94
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
95
+ tokenizer (`CLIPTokenizer`):
96
+ Tokenizer of class
97
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
98
+ tokenizer_2 (`T5TokenizerFast`):
99
+ Second Tokenizer of class
100
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
101
+ """
102
+
103
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
104
+ _optional_components = []
105
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
106
+
107
+ def __init__(
108
+ self,
109
+ scheduler: FlowMatchEulerDiscreteScheduler,
110
+ vae: AutoencoderKL,
111
+ text_encoder: CLIPTextModel,
112
+ tokenizer: CLIPTokenizer,
113
+ text_encoder_2: T5EncoderModel,
114
+ tokenizer_2: T5TokenizerFast,
115
+ transformer: FluxTransformer2DModel,
116
+ ):
117
+ super().__init__(
118
+ scheduler=scheduler,
119
+ vae=vae,
120
+ text_encoder=text_encoder,
121
+ tokenizer=tokenizer,
122
+ text_encoder_2=text_encoder_2,
123
+ tokenizer_2=tokenizer_2,
124
+ transformer=transformer
125
+ )
126
+ self.mask_processor = VaeImageProcessor(
127
+ vae_scale_factor=self.vae_scale_factor,
128
+ vae_latent_channels=self.vae.config.latent_channels,
129
+ do_normalize=False,
130
+ do_binarize=True,
131
+ do_convert_grayscale=True,
132
+ )
133
+
134
+ def prepare_mask_latents(
135
+ self,
136
+ mask,
137
+ masked_image,
138
+ batch_size,
139
+ num_channels_latents,
140
+ num_images_per_prompt,
141
+ height,
142
+ width,
143
+ dtype,
144
+ device,
145
+ generator,
146
+ ):
147
+ # 1. calculate the height and width of the latents
148
+ # VAE applies 8x compression on images but we must also account for packing which requires
149
+ # latent height and width to be divisible by 2.
150
+ height = 2 * (int(height) // self.vae_scale_factor)
151
+ width = 2 * (int(width) // self.vae_scale_factor)
152
+
153
+ # 2. encode the masked image
154
+ if masked_image.shape[1] == num_channels_latents:
155
+ masked_image_latents = masked_image
156
+ else:
157
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
158
+
159
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
160
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
161
+
162
+ # 3. duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
163
+ batch_size = batch_size * num_images_per_prompt
164
+ if mask.shape[0] < batch_size:
165
+ if not batch_size % mask.shape[0] == 0:
166
+ raise ValueError(
167
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
168
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
169
+ " of masks that you pass is divisible by the total requested batch size."
170
+ )
171
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
172
+ if masked_image_latents.shape[0] < batch_size:
173
+ if not batch_size % masked_image_latents.shape[0] == 0:
174
+ raise ValueError(
175
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
176
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
177
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
178
+ )
179
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
180
+
181
+ # 4. pack the masked_image_latents
182
+ # batch_size, num_channels_latents, height, width -> batch_size, height//2 * width//2 , num_channels_latents*4
183
+ masked_image_latents = self._pack_latents(
184
+ masked_image_latents,
185
+ batch_size,
186
+ num_channels_latents,
187
+ height,
188
+ width,
189
+ )
190
+
191
+ # 5. resize mask to latents shape so that we can concatenate the mask to the latents
192
+ mask = mask[:, 0, :, :] # batch_size, 8 * height, 8 * width (mask has not been 8x compressed)
193
+ mask = mask.view(
194
+ batch_size, height, self.vae_scale_factor // 2, width, self.vae_scale_factor // 2
195
+ ) # batch_size, height, 8, width, 8
196
+ mask = mask.permute(0, 2, 4, 1, 3) # batch_size, 8, 8, height, width
197
+ mask = mask.reshape(
198
+ batch_size, (self.vae_scale_factor // 2) * (self.vae_scale_factor // 2), height, width
199
+ ) # batch_size, 8*8, height, width
200
+
201
+ # 6. pack the mask:
202
+ # batch_size, 64, height, width -> batch_size, height//2 * width//2 , 64*2*2
203
+ mask = self._pack_latents(
204
+ mask,
205
+ batch_size,
206
+ (self.vae_scale_factor // 2) * (self.vae_scale_factor // 2),
207
+ height,
208
+ width,
209
+ )
210
+ mask = mask.to(device=device, dtype=dtype)
211
+
212
+ return mask, masked_image_latents
213
+
214
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
215
+ def get_timesteps(self, num_inference_steps, strength, device):
216
+ # get the original timestep using init_timestep
217
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
218
+
219
+ t_start = int(max(num_inference_steps - init_timestep, 0))
220
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
221
+ if hasattr(self.scheduler, "set_begin_index"):
222
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
223
+
224
+ return timesteps, num_inference_steps - t_start
225
+
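+ # Img2img-style initialization: encode `image` with the VAE, then noise it to the first retained
+ # timestep via `scheduler.scale_noise`, so `strength < 1.0` preserves part of the input image.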
226
+ def get_latents_with_image(self, image, latent_timestep, batch_size, num_channels_latents, height, width, generator, device, dtype):
227
+ image = image.to(device=device, dtype=dtype)
228
+ image_latents = self.vae.encode(image).latent_dist.sample(generator)
229
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
230
+ batch_size, num_channels_latents, height, width = image_latents.size()
231
+ noise = randn_tensor(image_latents.size(), generator=generator, device=device, dtype=dtype)
232
+ latents = self.scheduler.scale_noise(image_latents, latent_timestep, noise)
233
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
234
+
235
+ return latents
236
+
237
+ @torch.no_grad()
238
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
239
+ def __call__(
240
+ self,
241
+ prompt: Union[str, List[str]] = None,
242
+ prompt_2: Optional[Union[str, List[str]]] = None,
243
+ img_cond: torch.FloatTensor = None,
244
+ mask: torch.FloatTensor = None,
245
+ height: Optional[int] = None,
246
+ width: Optional[int] = None,
247
+ strength: float = 1.0,
248
+ image: PipelineImageInput = None,
249
+ num_inference_steps: int = 28,
250
+ timesteps: List[int] = None,
251
+ guidance_scale: float = 3.5,
252
+ num_images_per_prompt: Optional[int] = 1,
253
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
254
+ latents: Optional[torch.FloatTensor] = None,
255
+ prompt_embeds: Optional[torch.FloatTensor] = None,
256
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
257
+ output_type: Optional[str] = "pil",
258
+ return_dict: bool = True,
259
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
260
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
261
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
262
+ max_sequence_length: int = 512,
263
+ ):
264
+ r"""
265
+ Function invoked when calling the pipeline for generation.
266
+
267
+ Args:
268
+ prompt (`str` or `List[str]`, *optional*):
269
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
270
+ instead.
271
+ prompt_2 (`str` or `List[str]`, *optional*):
272
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
273
+ used in all text-encoders.
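+ img_cond (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
+ Conditioning image for the fill task. It is masked with `mask`, VAE-encoded, packed, and
+ concatenated (together with the packed mask) to the noisy latents along the channel dimension.
+ mask (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
+ Mask in which white (1) marks the region to be generated and black (0) the region kept from
+ `img_cond`.
+ image (`PipelineImageInput`, *optional*):
+ Optional starting image, only used together with `strength < 1.0` for an img2img-style start.
+ strength (`float`, *optional*, defaults to 1.0):
+ How strongly `image` is re-noised before denoising; `1.0` starts from pure noise and ignores
+ `image`.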
274
+ height (`int`, *optional*, defaults to 1024):
275
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
276
+ width (`int`, *optional*, defaults to 1024):
277
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
278
+ num_inference_steps (`int`, *optional*, defaults to 28):
279
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
280
+ expense of slower inference.
281
+ timesteps (`List[int]`, *optional*):
282
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
283
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
284
+ passed will be used. Must be in descending order.
285
+ guidance_scale (`float`, *optional*, defaults to 3.5):
286
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
287
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
288
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
289
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
290
+ usually at the expense of lower image quality.
291
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
292
+ The number of images to generate per prompt.
293
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
294
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
295
+ to make generation deterministic.
296
+ latents (`torch.FloatTensor`, *optional*):
297
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
298
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
299
+ tensor will be generated by sampling using the supplied random `generator`.
300
+ prompt_embeds (`torch.FloatTensor`, *optional*):
301
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
302
+ provided, text embeddings will be generated from `prompt` input argument.
303
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
304
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
305
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
306
+ output_type (`str`, *optional*, defaults to `"pil"`):
307
+ The output format of the generated image. Choose between
308
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
309
+ return_dict (`bool`, *optional*, defaults to `True`):
310
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
311
+ joint_attention_kwargs (`dict`, *optional*):
312
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
313
+ `self.processor` in
314
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
315
+ callback_on_step_end (`Callable`, *optional*):
316
+ A function that is called at the end of each denoising step during inference. The function is called
317
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
318
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
319
+ `callback_on_step_end_tensor_inputs`.
320
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
321
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
322
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
323
+ `._callback_tensor_inputs` attribute of your pipeline class.
324
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
325
+
326
+ Examples:
327
+
328
+ Returns:
329
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
330
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
331
+ images.
332
+ """
333
+
334
+ height = height or self.default_sample_size * self.vae_scale_factor
335
+ width = width or self.default_sample_size * self.vae_scale_factor
336
+
337
+ # 1. Check inputs. Raise error if not correct
338
+ self.check_inputs(
339
+ prompt,
340
+ prompt_2,
341
+ height,
342
+ width,
343
+ prompt_embeds=prompt_embeds,
344
+ pooled_prompt_embeds=pooled_prompt_embeds,
345
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
346
+ max_sequence_length=max_sequence_length,
347
+ )
348
+
349
+ self._guidance_scale = guidance_scale
350
+ self._joint_attention_kwargs = joint_attention_kwargs
351
+ self._interrupt = False
352
+
353
+ # 2. Define call parameters
354
+ if prompt is not None and isinstance(prompt, str):
355
+ batch_size = 1
356
+ elif prompt is not None and isinstance(prompt, list):
357
+ batch_size = len(prompt)
358
+ else:
359
+ batch_size = prompt_embeds.shape[0]
360
+
361
+ device = self._execution_device
362
+
363
+ lora_scale = (
364
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
365
+ )
366
+ (
367
+ prompt_embeds,
368
+ pooled_prompt_embeds,
369
+ text_ids,
370
+ ) = self.encode_prompt(
371
+ prompt=prompt,
372
+ prompt_2=prompt_2,
373
+ prompt_embeds=prompt_embeds,
374
+ pooled_prompt_embeds=pooled_prompt_embeds,
375
+ device=device,
376
+ num_images_per_prompt=num_images_per_prompt,
377
+ max_sequence_length=max_sequence_length,
378
+ lora_scale=lora_scale,
379
+ )
380
+
381
+ # 4. Prepare latent variables
382
+ num_channels_latents = self.vae.config.latent_channels
383
+ latents, latent_image_ids = self.prepare_latents(
384
+ batch_size * num_images_per_prompt,
385
+ num_channels_latents,
386
+ height,
387
+ width,
388
+ prompt_embeds.dtype,
389
+ device,
390
+ generator,
391
+ latents,
392
+ )
393
+
394
+ # 4.5 Prepare masked image latents
395
+ img_cond = self.image_processor.preprocess(img_cond, height=height, width=width)
396
+ mask = self.mask_processor.preprocess(mask, height=height, width=width)
397
+ masked_image = img_cond * (1 - mask)
398
+ masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
399
+
400
+ height, width = masked_image.shape[-2:]
401
+ mask, masked_image_latents = self.prepare_mask_latents(
402
+ mask,
403
+ masked_image,
404
+ batch_size,
405
+ num_channels_latents,
406
+ num_images_per_prompt,
407
+ height,
408
+ width,
409
+ prompt_embeds.dtype,
410
+ device,
411
+ generator,
412
+ )
413
+ masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
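+ # With the default FLUX.1-Fill-dev configuration this concatenation holds, per token, 64 packed
+ # masked-image channels plus 256 packed mask channels; together with the 64 noisy latent channels
+ # passed alongside it in the transformer call below, that gives the 384 input channels the Fill
+ # transformer expects.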
414
+
415
+ # 5. Prepare timesteps
416
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
417
+ image_seq_len = latents.shape[1]
418
+ mu = calculate_shift(
419
+ image_seq_len,
420
+ self.scheduler.config.base_image_seq_len,
421
+ self.scheduler.config.max_image_seq_len,
422
+ self.scheduler.config.base_shift,
423
+ self.scheduler.config.max_shift,
424
+ )
425
+ timesteps, num_inference_steps = retrieve_timesteps(
426
+ self.scheduler,
427
+ num_inference_steps,
428
+ device,
429
+ timesteps,
430
+ sigmas,
431
+ mu=mu,
432
+ )
433
+
434
+ if strength != 1.0 and image is not None:
435
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
436
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
437
+ latents = self.get_latents_with_image(image, latent_timestep, batch_size, num_channels_latents, height, width, generator, device, latents.dtype)
438
+
439
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
440
+ self._num_timesteps = len(timesteps)
441
+
442
+ # handle guidance
443
+ if self.transformer.config.guidance_embeds:
444
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
445
+ guidance = guidance.expand(latents.shape[0])
446
+ else:
447
+ guidance = None
448
+
449
+ # 6. Denoising loop
450
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
451
+ for i, t in enumerate(timesteps):
452
+ if self.interrupt:
453
+ continue
454
+
455
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
456
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
457
+ noise_pred = self.transformer(
458
+ hidden_states=torch.cat((latents, masked_image_latents), dim=-1),
459
+ timestep=timestep / 1000,
460
+ guidance=guidance,
461
+ pooled_projections=pooled_prompt_embeds,
462
+ encoder_hidden_states=prompt_embeds,
463
+ txt_ids=text_ids,
464
+ img_ids=latent_image_ids,
465
+ joint_attention_kwargs=self.joint_attention_kwargs,
466
+ return_dict=False,
467
+ )[0]
468
+
469
+ # compute the previous noisy sample x_t -> x_t-1
470
+ latents_dtype = latents.dtype
471
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
472
+
473
+ if latents.dtype != latents_dtype:
474
+ if torch.backends.mps.is_available():
475
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
476
+ latents = latents.to(latents_dtype)
477
+
478
+ if callback_on_step_end is not None:
479
+ callback_kwargs = {}
480
+ for k in callback_on_step_end_tensor_inputs:
481
+ callback_kwargs[k] = locals()[k]
482
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
483
+
484
+ latents = callback_outputs.pop("latents", latents)
485
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
486
+
487
+ # call the callback, if provided
488
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
489
+ progress_bar.update()
490
+
491
+ if XLA_AVAILABLE:
492
+ xm.mark_step()
493
+
494
+ if output_type == "latent":
495
+ image = latents
496
+
497
+ else:
498
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
499
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
500
+ image = self.vae.decode(latents, return_dict=False)[0]
501
+ image = self.image_processor.postprocess(image, output_type=output_type)
502
+
503
+ # Offload all models
504
+ self.maybe_free_model_hooks()
505
+
506
+ if not return_dict:
507
+ return (image,)
508
+
509
+ return FluxPipelineOutput(images=image)
510
+
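A minimal usage sketch for the `FluxFillPipeline` added above, assuming a local snapshot of FLUX.1-Fill-dev under `./FLUX.1-Fill-dev` and placeholder input files; it relies only on the `__init__` and `__call__` signatures defined in this file and leaves out the OmniTry LoRA adapters, so it is a sketch rather than the project's actual inference path.

```py
import torch
from PIL import Image
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from omnitry.pipelines.pipeline_flux_fill import FluxFillPipeline

root = "./FLUX.1-Fill-dev"  # assumed local snapshot of black-forest-labs/FLUX.1-Fill-dev

# Build the pipeline from individually loaded components, matching FluxFillPipeline.__init__.
pipe = FluxFillPipeline(
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(f"{root}/scheduler"),
    vae=AutoencoderKL.from_pretrained(f"{root}/vae"),
    text_encoder=CLIPTextModel.from_pretrained(f"{root}/text_encoder"),
    tokenizer=CLIPTokenizer.from_pretrained(f"{root}/tokenizer"),
    text_encoder_2=T5EncoderModel.from_pretrained(f"{root}/text_encoder_2"),
    tokenizer_2=T5TokenizerFast.from_pretrained(f"{root}/tokenizer_2"),
    transformer=FluxTransformer2DModel.from_pretrained(f"{root}/transformer"),
).to("cuda", torch.bfloat16)

result = pipe(
    prompt="a person wearing the garment",              # assumed prompt
    img_cond=Image.open("person.jpg").convert("RGB"),   # assumed input file
    mask=Image.open("mask.png").convert("L"),           # white = region to generate
    height=1024,
    width=1024,
    num_inference_steps=28,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
result.save("try_on.jpg")
```

Passing `image` together with `strength < 1.0` switches the same call to an img2img-style start from the provided image instead of pure noise.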
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ gradio==5.6.0
2
+ transformers==4.45.0
3
+ diffusers==0.33.1
4
+ sentencepiece==0.2.0
5
+ peft==0.13.2
6
+ einops
7
+ omegaconf
8
+ safetensors
9
+ torch==2.7.0
10
+ torchvision==0.22.0