QingyuShi committed (verified)
Commit 7c8069d · 1 Parent(s): 7f353c0

Upload folder using huggingface_hub

README.md CHANGED
@@ -1,13 +1,100 @@
1
  ---
2
- title: Muddit
3
- emoji: 🌖
4
- colorFrom: red
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.38.2
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: Muddit Interface
3
+ emoji: 🎨
4
+ colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.0.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
+ # 🎨 Muddit Interface
14
+
15
+ A unified model interface for **Text-to-Image generation** and **Visual Question Answering (VQA)**, powered by a single multimodal transformer shared across both tasks.
16
+
17
+ ## ✨ Features
18
+
19
+ ### 🖼️ Text-to-Image Generation
20
+ - Generate high-quality images from detailed text descriptions
21
+ - Customizable parameters (resolution, inference steps, CFG scale, seed)
22
+ - Support for negative prompts to avoid unwanted elements
23
+ - Real-time generation with progress tracking
24
+
25
+ ### ❓ Visual Question Answering
26
+ - Upload images and ask natural language questions
27
+ - Get detailed descriptions and answers about image content
28
+ - Support for various question types (counting, description, identification)
29
+ - Advanced visual understanding capabilities
30
+
31
+ ## 🚀 How to Use
32
+
33
+ ### Text-to-Image
34
+ 1. Go to the **"🖼️ Text-to-Image"** tab
35
+ 2. Enter your text description in the **Prompt** field
36
+ 3. Optionally add a **Negative Prompt** to exclude unwanted elements
37
+ 4. Adjust parameters as needed:
38
+ - **Width/Height**: Image resolution (256-1024px)
39
+ - **Inference Steps**: Quality vs speed (1-100)
40
+ - **CFG Scale**: Prompt adherence (1.0-20.0)
41
+ - **Seed**: For reproducible results
42
+ 5. Click **"🎨 Generate Image"**
43
+
44
+ ### Visual Question Answering
45
+ 1. Go to the **"❓ Visual Question Answering"** tab
46
+ 2. **Upload an image** using the image input
47
+ 3. **Ask a question** about the image
48
+ 4. Adjust processing parameters if needed
49
+ 5. Click **"🤔 Ask Question"** to get an answer
50
+
51
+ ## 📝 Example Prompts
52
+
53
+ ### Text-to-Image Examples:
54
+ - "A majestic night sky awash with billowing clouds, sparkling with a million twinkling stars"
55
+ - "A hyper realistic image of a chimpanzee with a glass-enclosed brain on his head, standing amidst lush, bioluminescent foliage"
56
+ - "A samurai in a stylized cyberpunk outfit adorned with intricate steampunk gear and floral accents"
57
+
58
+ ### VQA Examples:
59
+ - "What objects do you see in this image?"
60
+ - "How many people are in the picture?"
61
+ - "What is the main subject of this image?"
62
+ - "Describe the scene in detail"
63
+ - "What colors dominate this image?"
64
+
65
+ ## 🛠️ Technical Details
66
+
67
+ - **Architecture**: A single transformer shared by the text-to-image and VQA paths
68
+ - **Text Encoder**: CLIP text encoder for prompt/question conditioning
69
+ - **Image Tokenizer**: VQ-VAE that maps images to and from discrete tokens
70
+ - **Generation**: Iterative masked-token decoding with classifier-free guidance
71
+ - **VQA**: Answer tokens decoded by the same transformer, conditioned on the image tokens
72
+
73
+ ## ⚙️ Parameters Guide
74
+
75
+ | Parameter | Description | Recommended Range |
76
+ |-----------|-------------|-------------------|
77
+ | **Inference Steps** | More steps = higher quality, slower generation | 20-64 |
78
+ | **CFG Scale** | How closely to follow the prompt | 7.0-12.0 |
79
+ | **Resolution** | Output image size | 512x512 to 1024x1024 |
80
+ | **Seed** | For reproducible results | Any integer or -1 for random |
81
+
82
+ ## 🎯 Use Cases
83
+
84
+ - **Creative Content**: Generate artwork, illustrations, concepts
85
+ - **Visual Analysis**: Analyze and understand image content
86
+ - **Education**: Learn about visual AI and multimodal models
87
+ - **Research**: Explore capabilities of unified vision-language models
88
+ - **Accessibility**: Describe images for visually impaired users
89
+
90
+ ## 📄 License
91
+
92
+ This project is licensed under the Apache 2.0 License.
93
+
94
+ ## 🤝 Contributing
95
+
96
+ Feedback and contributions are welcome! Please feel free to submit issues or pull requests.
97
+
98
+ ---
99
+
100
+ *Powered by Gradio and Hugging Face Spaces* 🤗
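
Below is a minimal, hedged sketch of using the same pipeline programmatically (outside Gradio). It simply mirrors how `app.py` below assembles `UnifiedPipeline` from the `MeissonFlow/Meissonic` components and the `QingyuShi/Muddit` transformer, and maps the Parameters Guide onto the call arguments; adjust paths, device, and parameters to taste.

```python
# Sketch only: mirrors the model assembly in app.py.
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers.models.autoencoders.vq_model import VQModel

from src.transformer import SymmetricTransformer2DModel
from src.pipeline import UnifiedPipeline
from src.scheduler import Scheduler

model_path = "MeissonFlow/Meissonic"      # base components (VQ-VAE, CLIP, scheduler)
transformer_path = "QingyuShi/Muddit"     # unified transformer weights

pipe = UnifiedPipeline(
    vqvae=VQModel.from_pretrained(model_path, subfolder="vqvae"),
    tokenizer=CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer"),
    text_encoder=CLIPTextModelWithProjection.from_pretrained(model_path, subfolder="text_encoder"),
    transformer=SymmetricTransformer2DModel.from_pretrained(transformer_path, subfolder="transformer"),
    scheduler=Scheduler.from_pretrained(model_path, subfolder="scheduler"),
).to("cuda" if torch.cuda.is_available() else "cpu")

# Parameters Guide -> call arguments
image = pipe(
    prompt=["A majestic night sky awash with billowing clouds"],
    height=1024, width=1024,           # Resolution
    num_inference_steps=64,            # Inference Steps
    guidance_scale=9.0,                # CFG Scale
    generator=torch.manual_seed(42),   # Seed (omit for a random seed)
).images[0]
image.save("muddit_sample.png")
```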
app.py ADDED
@@ -0,0 +1,322 @@
1
+ import sys
2
+ import os
3
+ import warnings
4
+ import tempfile
5
+
6
+ import gradio as gr
7
+ import torch
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+ from transformers import (
12
+ CLIPTextModelWithProjection,
13
+ CLIPTokenizer,
14
+ )
15
+ from diffusers.models.autoencoders.vq_model import VQModel
16
+
17
+ from src.transformer import SymmetricTransformer2DModel
18
+ from src.pipeline import UnifiedPipeline
19
+ from src.scheduler import Scheduler
20
+ from train.trainer_utils import load_images_to_tensor
21
+
22
+ # Suppress FutureWarnings to reduce clutter
23
+ warnings.filterwarnings("ignore", category=FutureWarning)
24
+
25
+ # Set Gradio temp directory to a writable location
26
+ def setup_gradio_temp_dir():
27
+ """Setup a writable temp directory for Gradio with fallback options"""
28
+ possible_dirs = [
29
+ os.path.join(os.getcwd(), "gradio_tmp"), # Project directory
30
+ os.path.join(os.path.expanduser("~"), ".gradio_tmp"), # Home directory
31
+ tempfile.mkdtemp(prefix="gradio_") # System temp with unique name
32
+ ]
33
+
34
+ for temp_dir in possible_dirs:
35
+ try:
36
+ os.makedirs(temp_dir, exist_ok=True)
37
+ # Test write permission
38
+ test_file = os.path.join(temp_dir, "test_write.tmp")
39
+ with open(test_file, "w") as f:
40
+ f.write("test")
41
+ os.remove(test_file)
42
+
43
+ os.environ["GRADIO_TEMP_DIR"] = temp_dir
44
+ print(f"✅ Gradio temp directory set to: {temp_dir}")
45
+ return temp_dir
46
+ except (PermissionError, OSError) as e:
47
+ print(f"⚠️ Cannot use {temp_dir}: {e}")
48
+ continue
49
+
50
+ raise RuntimeError("Could not find a writable directory for Gradio temp files")
51
+
52
+ setup_gradio_temp_dir()
53
+
54
+
55
+
56
+ class MudditInterface:
57
+ def __init__(self, model_path="MeissonFlow/Meissonic", transformer_path="QingyuShi/Muddit"):
58
+ if torch.cuda.is_available():
59
+ device = "cuda"
60
+ else:
61
+ device = "cpu"
62
+ self.device = device
63
+ self.model_path = model_path
64
+ self.transformer_path = transformer_path or model_path
65
+
66
+ print("Loading models...")
67
+ self.load_models()
68
+ print("Models loaded successfully!")
69
+
70
+ def load_models(self):
71
+ """Load all required models"""
72
+ try:
73
+ print("📥 Loading transformer model...")
74
+ self.model = SymmetricTransformer2DModel.from_pretrained(
75
+ self.transformer_path,
76
+ subfolder="transformer",
77
+ )
78
+ print("📥 Loading VQ model...")
79
+ self.vq_model = VQModel.from_pretrained(
80
+ self.model_path,
81
+ subfolder="vqvae"
82
+ )
83
+ print("📥 Loading text encoder...")
84
+ self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
85
+ self.model_path,
86
+ subfolder="text_encoder"
87
+ )
88
+ print("📥 Loading tokenizer...")
89
+ self.tokenizer = CLIPTokenizer.from_pretrained(
90
+ self.model_path,
91
+ subfolder="tokenizer"
92
+ )
93
+ print("📥 Loading scheduler...")
94
+ self.scheduler = Scheduler.from_pretrained(
95
+ self.model_path,
96
+ subfolder="scheduler"
97
+ )
98
+
99
+ print("🔧 Assembling pipeline...")
100
+ self.pipe = UnifiedPipeline(
101
+ vqvae=self.vq_model,
102
+ tokenizer=self.tokenizer,
103
+ text_encoder=self.text_encoder,
104
+ transformer=self.model,
105
+ scheduler=self.scheduler,
106
+ )
107
+ print(f"🚀 Moving models to {self.device}...")
108
+ self.pipe.to(self.device)
109
+ except Exception as e:
110
+ print(f"❌ Error loading models: {str(e)}")
111
+ raise
112
+
113
+ def text_to_image(self, prompt, negative_prompt, height, width, steps, cfg_scale, seed):
114
+ """Generate image from text prompt"""
115
+ try:
116
+ if seed == -1:
117
+ generator = None
118
+ else:
119
+ generator = torch.manual_seed(seed)
120
+
121
+ if not negative_prompt:
122
+ negative_prompt = "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark"
123
+
124
+ output = self.pipe(
125
+ prompt=[prompt],
126
+ negative_prompt=negative_prompt,
127
+ height=height,
128
+ width=width,
129
+ guidance_scale=cfg_scale,
130
+ num_inference_steps=steps,
131
+ mask_token_embedding=None,
132
+ generator=generator
133
+ )
134
+
135
+ if hasattr(output, 'images') and len(output.images) > 0:
136
+ return output.images[0]
137
+ else:
138
+ return None
139
+
140
+ except Exception as e:
141
+ raise gr.Error(f"Error generating image: {str(e)}")
143
+
144
+ def image_to_text(self, image, question, height, width, steps, cfg_scale):
145
+ """Answer question about the image"""
146
+ try:
147
+ if image is None:
148
+ return "Please upload an image."
149
+
150
+ # Convert PIL image to tensor
151
+ if isinstance(image, np.ndarray):
152
+ image = Image.fromarray(image)
153
+
154
+ # Save image temporarily and load using the existing function
155
+ temp_path = "temp_image.jpg"
156
+ image.save(temp_path)
157
+
158
+ try:
159
+ images = load_images_to_tensor(temp_path, target_size=(height, width))
160
+ finally:
161
+ if os.path.exists(temp_path):
162
+ os.remove(temp_path)
163
+
164
+ if images is None:
165
+ return "Failed to process the image."
166
+
167
+ questions = [question] * len(images)
168
+
169
+ output = self.pipe(
170
+ prompt=questions,
171
+ image=images,
172
+ height=height,
173
+ width=width,
174
+ guidance_scale=cfg_scale,
175
+ num_inference_steps=steps,
176
+ mask_token_embedding=None,
177
+ )
178
+
179
+ if hasattr(output, 'prompts') and len(output.prompts) > 0:
180
+ return output.prompts[0]
181
+ else:
182
+ return "No response generated."
183
+
184
+ except Exception as e:
185
+ return f"Error processing image: {str(e)}"
186
+
187
+
188
+ def create_muddit_interface():
189
+ # Initialize the model interface
190
+ interface = MudditInterface()
191
+
192
+ with gr.Blocks(title="Muddit Interface", theme=gr.themes.Soft()) as demo:
193
+ gr.Markdown("# 🎨 Muddit Interface")
194
+ gr.Markdown("Generate images from text or ask questions about images using Muddit.")
195
+
196
+ with gr.Tabs():
197
+ # Text-to-Image Tab
198
+ with gr.TabItem("🖼️ Text-to-Image"):
199
+ gr.Markdown("### Generate images from text descriptions")
200
+
201
+ with gr.Row():
202
+ with gr.Column(scale=1):
203
+ t2i_prompt = gr.Textbox(
204
+ label="Prompt",
205
+ placeholder="A majestic night sky awash with billowing clouds, sparkling with a million twinkling stars",
206
+ lines=3
207
+ )
208
+ t2i_negative = gr.Textbox(
209
+ label="Negative Prompt (optional)",
210
+ placeholder="worst quality, low quality, blurry...",
211
+ lines=2
212
+ )
213
+
214
+ with gr.Row():
215
+ t2i_width = gr.Slider(
216
+ minimum=256, maximum=1024, value=1024, step=64,
217
+ label="Width"
218
+ )
219
+ t2i_height = gr.Slider(
220
+ minimum=256, maximum=1024, value=1024, step=64,
221
+ label="Height"
222
+ )
223
+
224
+ with gr.Row():
225
+ t2i_steps = gr.Slider(
226
+ minimum=1, maximum=100, value=64, step=1,
227
+ label="Inference Steps"
228
+ )
229
+ t2i_cfg = gr.Slider(
230
+ minimum=1.0, maximum=20.0, value=9.0, step=0.5,
231
+ label="CFG Scale"
232
+ )
233
+
234
+ t2i_seed = gr.Number(
235
+ label="Seed (-1 for random)",
236
+ value=42,
237
+ precision=0
238
+ )
239
+
240
+ t2i_generate = gr.Button("🎨 Generate Image", variant="primary")
241
+
242
+ with gr.Column(scale=1):
243
+ t2i_output = gr.Image(label="Generated Image", type="pil")
244
+
245
+ t2i_generate.click(
246
+ fn=interface.text_to_image,
247
+ inputs=[t2i_prompt, t2i_negative, t2i_height, t2i_width, t2i_steps, t2i_cfg, t2i_seed],
248
+ outputs=[t2i_output]
249
+ )
250
+
251
+ # Visual Question Answering Tab
252
+ with gr.TabItem("❓ Visual Question Answering"):
253
+ gr.Markdown("### Ask questions about images")
254
+
255
+ with gr.Row():
256
+ with gr.Column(scale=1):
257
+ vqa_image = gr.Image(
258
+ label="Upload Image",
259
+ type="pil"
260
+ )
261
+ vqa_question = gr.Textbox(
262
+ label="Question",
263
+ placeholder="What do you see in this image?",
264
+ lines=2
265
+ )
266
+
267
+ with gr.Row():
268
+ vqa_width = gr.Slider(
269
+ minimum=256, maximum=1024, value=1024, step=64,
270
+ label="Width"
271
+ )
272
+ vqa_height = gr.Slider(
273
+ minimum=256, maximum=1024, value=1024, step=64,
274
+ label="Height"
275
+ )
276
+
277
+ with gr.Row():
278
+ vqa_steps = gr.Slider(
279
+ minimum=1, maximum=100, value=64, step=1,
280
+ label="Inference Steps"
281
+ )
282
+ vqa_cfg = gr.Slider(
283
+ minimum=1.0, maximum=20.0, value=9.0, step=0.5,
284
+ label="CFG Scale"
285
+ )
286
+
287
+ vqa_submit = gr.Button("🤔 Ask Question", variant="primary")
288
+
289
+ with gr.Column(scale=1):
290
+ vqa_output = gr.Textbox(
291
+ label="Answer",
292
+ lines=5,
293
+ interactive=False
294
+ )
295
+
296
+ vqa_submit.click(
297
+ fn=interface.image_to_text,
298
+ inputs=[vqa_image, vqa_question, vqa_height, vqa_width, vqa_steps, vqa_cfg],
299
+ outputs=[vqa_output]
300
+ )
301
+
302
+ # Example section
303
+ with gr.Accordion("📝 Examples", open=False):
304
+ gr.Markdown("""
305
+ ### Text-to-Image Examples:
306
+ - "A majestic night sky awash with billowing clouds, sparkling with a million twinkling stars"
307
+ - "A hyper realistic image of a chimpanzee with a glass-enclosed brain on his head"
308
+ - "A samurai in a stylized cyberpunk outfit adorned with intricate steampunk gear"
309
+
310
+ ### VQA Examples:
311
+ - "What objects do you see in this image?"
312
+ - "How many people are in the picture?"
313
+ - "What is the main subject of this image?"
314
+ - "Describe the scene in detail"
315
+ """)
316
+
317
+ return demo
318
+
319
+
320
+ if __name__ == "__main__":
321
+ demo = create_muddit_interface()
322
+ demo.launch()
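
For shared GPU hardware, a queued launch is a common variant of the entry point above. This is a hedged sketch using standard Gradio `queue()`/`launch()` options, not settings taken from this commit.

```python
if __name__ == "__main__":
    demo = create_muddit_interface()
    # Queue requests so long generations don't block concurrent users;
    # bind to 0.0.0.0 when running inside a container.
    demo.queue(max_size=8).launch(server_name="0.0.0.0")
```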
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ transformers>=4.40.0
5
+ diffusers>=0.30.0
6
+ pillow>=9.0.0
7
+ numpy>=1.21.0
8
+ accelerate>=0.20.0
9
+ safetensors>=0.3.0
src/__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (27.8 kB)
 
src/__pycache__/pipeline.cpython-39.pyc ADDED
Binary file (27.5 kB)
 
src/__pycache__/scheduler.cpython-310.pyc ADDED
Binary file (5.12 kB)
 
src/__pycache__/scheduler.cpython-39.pyc ADDED
Binary file (5.09 kB)
 
src/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (37.5 kB)
 
src/__pycache__/transformer.cpython-39.pyc ADDED
Binary file (37.4 kB)
 
src/pipeline.py ADDED
@@ -0,0 +1,1231 @@
1
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ from dataclasses import dataclass
17
+
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+ import PIL.Image
20
+ import torch
21
+ import PIL
22
+ import numpy as np
23
+
24
+ from transformers import (
25
+ CLIPTextModelWithProjection,
26
+ CLIPTokenizer,
27
+ CLIPImageProcessor,
28
+ CLIPVisionModelWithProjection,
29
+ )
30
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
31
+ from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
32
+
33
+ from diffusers.image_processor import VaeImageProcessor
34
+ from diffusers.models import VQModel
35
+ from diffusers.utils import replace_example_docstring
36
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
37
+ from diffusers.utils import BaseOutput
38
+
39
+ from src.scheduler import Scheduler
40
+ from src.transformer import SymmetricTransformer2DModel
41
+
42
+
43
+ EXAMPLE_DOC_STRING = """
44
+ Examples:
45
+ ```py
46
+ >>> image = pipe(prompt).images[0]
47
+ ```
48
+ """
49
+
50
+
51
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
52
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
53
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
54
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
55
+
56
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
57
+
58
+ latent_image_ids = latent_image_ids.reshape(
59
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
60
+ )
61
+
62
+ return latent_image_ids.to(device=device, dtype=dtype)
63
+
64
+ def dedup_consecutive_words(text: str) -> str:
65
+ """
66
+ >>> dedup_consecutive_words("hello hello world world world")
67
+ 'hello world'
68
+ """
69
+ words = text.split()
70
+ if not words:
71
+ return text
72
+
73
+ out = [words[0]]
74
+ for w in words[1:]:
75
+ if w != out[-1]:
76
+ out.append(w)
77
+ return " ".join(out)
78
+
79
+ def keep_upto_last_period(text: str) -> str:
80
+ """
81
+ Return the substring up to (and including) the last period-mark.
82
+
83
+ The function searches first for the Chinese full stop “。”;
84
+ if none is found, it falls back to the ASCII dot “.”.
85
+
86
+ Parameters
87
+ ----------
88
+ text : str
89
+ Input string.
90
+
91
+ Returns
92
+ -------
93
+ str
94
+ Substring ending at the final period-mark. If no period is present,
95
+ the original string is returned unchanged.
96
+ """
97
+ # Workaround for a decoding artifact: drop spurious "such is" fragments
98
+ text = text.replace("is such is", "").replace("such is", "")
99
+ # Search for the Chinese full stop first, then fall back to the ASCII period
100
+ idx = text.rfind("。")
101
+ if idx == -1:
102
+ idx = text.rfind(".")
103
+ # If still not found, return original text
104
+ if idx == -1:
105
+ return text
106
+ # Keep everything up to (and including) the last period
107
+ return text[:idx + 1]
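+ # Illustration of the truncation behaviour (hedged example values):
+ #   keep_upto_last_period("A cat. A dog sitting") == "A cat."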
108
+
109
+ @dataclass
110
+ class UnifiedPipelineOutput(BaseOutput):
111
+ """
112
+ Output class for image pipelines.
113
+
114
+ Args:
115
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
116
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
117
+ num_channels)`.
118
+ prompts (`List[str]`)
+ Decoded text outputs (VQA answers) for image-to-text calls; placeholder strings for text-to-image.
+ """
119
+
120
+ images: Union[List[PIL.Image.Image], np.ndarray]
121
+ prompts: List[str]
122
+
123
+
124
+ class UnifiedPipeline(DiffusionPipeline):
125
+ image_processor: VaeImageProcessor
126
+ vqvae: VQModel
127
+ tokenizer: CLIPTokenizer
128
+ tokenizer_2: GemmaTokenizerFast
129
+ text_encoder: CLIPTextModelWithProjection
130
+ text_encoder_2: Gemma2Model
131
+ transformer: SymmetricTransformer2DModel
132
+ scheduler: Scheduler
133
+ model_cpu_offload_seq = "text_encoder->transformer->vqvae"
134
+
135
+ def __init__(
136
+ self,
137
+ vqvae: VQModel,
138
+ tokenizer: CLIPTokenizer,
139
+ text_encoder: CLIPTextModelWithProjection,
140
+ transformer: SymmetricTransformer2DModel,
141
+ scheduler: Scheduler,
142
+ tokenizer_2: GemmaTokenizerFast = None,
143
+ text_encoder_2: Gemma2Model = None,
144
+ ):
145
+ super().__init__()
146
+
147
+ self.register_modules(
148
+ vqvae=vqvae,
149
+ tokenizer=tokenizer,
150
+ tokenizer_2=tokenizer_2,
151
+ text_encoder=text_encoder,
152
+ text_encoder_2=text_encoder_2,
153
+ transformer=transformer,
154
+ scheduler=scheduler,
155
+ )
156
+ self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
157
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
158
+
159
+ @torch.no_grad()
160
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
161
+ def __call__(
162
+ self,
163
+ prompt: Optional[Union[List[str], str]] = None,
164
+ height: Optional[int] = 1024,
165
+ width: Optional[int] = 1024,
166
+ image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None,
167
+ num_inference_steps: int = 48,
168
+ guidance_scale: float = 9.0,
169
+ negative_prompt: Optional[Union[str, List[str]]] = None,
170
+ num_images_per_prompt: Optional[int] = 1,
171
+ generator: Optional[torch.Generator] = None,
172
+ latents: Optional[torch.IntTensor] = None,
173
+ prompt_embeds: Optional[torch.Tensor] = None,
174
+ encoder_hidden_states: Optional[torch.Tensor] = None,
175
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
176
+ negative_encoder_hidden_states: Optional[torch.Tensor] = None,
177
+ output_type = "pil",
178
+ return_dict: bool = True,
179
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
180
+ callback_steps: int = 1,
181
+ micro_conditioning_aesthetic_score: int = 6,
182
+ micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
183
+ temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
184
+ mask_token_embedding: Optional[str] = None,
185
+ ):
186
+ """
187
+ The call function to the pipeline for generation.
188
+
189
+ Args:
190
+ prompt (`str` or `List[str]`, *optional*):
191
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
192
+ height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
193
+ The height in pixels of the generated image.
194
+ width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
195
+ The width in pixels of the generated image.
196
+ num_inference_steps (`int`, *optional*, defaults to 48):
197
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
198
+ expense of slower inference.
199
+ guidance_scale (`float`, *optional*, defaults to 9.0):
200
+ A higher guidance scale value encourages the model to generate images closely linked to the text
201
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
202
+ negative_prompt (`str` or `List[str]`, *optional*):
203
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
204
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
205
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
206
+ The number of images to generate per prompt.
207
+ generator (`torch.Generator`, *optional*):
208
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
209
+ generation deterministic.
210
+ latents (`torch.IntTensor`, *optional*):
211
+ Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image
212
+ generation. If not provided, the starting latents will be completely masked.
213
+ prompt_embeds (`torch.Tensor`, *optional*):
214
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
215
+ provided, text embeddings are generated from the `prompt` input argument. A single vector from the
216
+ pooled and projected final hidden states.
217
+ encoder_hidden_states (`torch.Tensor`, *optional*):
218
+ Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
219
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
220
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
221
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
222
+ negative_encoder_hidden_states (`torch.Tensor`, *optional*):
223
+ Analogous to `encoder_hidden_states` for the positive prompt.
224
+ output_type (`str`, *optional*, defaults to `"pil"`):
225
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
226
+ return_dict (`bool`, *optional*, defaults to `True`):
227
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
228
+ plain tuple.
229
+ callback (`Callable`, *optional*):
230
+ A function that calls every `callback_steps` steps during inference. The function is called with the
231
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
232
+ callback_steps (`int`, *optional*, defaults to 1):
233
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
234
+ every step.
235
+ cross_attention_kwargs (`dict`, *optional*):
236
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
237
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
238
+ micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
239
+ The targeted aesthetic score according to the laion aesthetic classifier. See
240
+ https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
241
+ https://arxiv.org/abs/2307.01952.
242
+ micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
243
+ The targeted height, width crop coordinates. See the micro-conditioning section of
244
+ https://arxiv.org/abs/2307.01952.
245
+ temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
246
+ Configures the temperature scheduler on `self.scheduler` see `Scheduler#set_timesteps`.
247
+
248
+ Examples:
249
+
250
+ Returns:
251
+ [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
252
+ If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
253
+ `tuple` is returned where the first element is a list with the generated images.
254
+ """
255
+ if (prompt_embeds is not None and encoder_hidden_states is None) or (
256
+ prompt_embeds is None and encoder_hidden_states is not None
257
+ ):
258
+ raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
259
+
260
+ if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
261
+ negative_prompt_embeds is None and negative_encoder_hidden_states is not None
262
+ ):
263
+ raise ValueError(
264
+ "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
265
+ )
266
+
267
+ if self.text_encoder_2 is not None:
268
+ self.text_encoder_2.to(self._execution_device)
269
+
270
+ text2image = image is None
271
+ image2text = image is not None
272
+
273
+ if image2text:
274
+ if self.text_encoder_2 is not None:
275
+ prompt = "<extra_id_0>" * 256
276
+ prompt = [prompt] * len(image)
277
+
278
+ text_encoder_2_mask_id = self.tokenizer_2.convert_tokens_to_ids("<extra_id_0>")
279
+ self.scheduler.config.mask_token_id = text_encoder_2_mask_id
280
+ else:
281
+ mask_token = "<mask>"
282
+ self.tokenizer.add_tokens(mask_token, special_tokens=False)
283
+ clip_mask_id = self.tokenizer.convert_tokens_to_ids(mask_token)
284
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
285
+
286
+ if mask_token_embedding is not None:
287
+ try:
288
+ if mask_token_embedding.endswith(".pth"):
289
+ mask_token_embedding = torch.load(mask_token_embedding)
290
+ else:
291
+ mask_token_embedding_path = os.path.join(mask_token_embedding, "mask_token_embedding.pth")
292
+ assert os.path.exists(mask_token_embedding_path), f"{mask_token_embedding_path} doesn't exist!"
293
+ mask_token_embedding = torch.load(mask_token_embedding_path)
294
+
295
+ mask_token_embedding = mask_token_embedding.to(self._execution_device, dtype=self.text_encoder.dtype)
296
+ self.text_encoder.get_input_embeddings().weight.data[clip_mask_id].copy_(mask_token_embedding)
297
+
298
+ except Exception as e:
299
+ print(f"Error loading mask token embedding: {e}")
300
+ print("Using random initialized mask token embedding")
301
+ mask_token_embedding = None
302
+
303
+ self.scheduler.config.mask_token_id = clip_mask_id
304
+
305
+ input_ids = torch.ones(
306
+ size=(len(image), self.tokenizer.model_max_length),
307
+ dtype=torch.int64,
308
+ device=self._execution_device
309
+ )
310
+ input_ids = input_ids * clip_mask_id
311
+
312
+ question_len = []
313
+ if prompt is None:
314
+ question_len = [0] * len(image)
315
+ elif isinstance(prompt, str):
316
+ question_ids = torch.LongTensor([self.tokenizer.encode(prompt)])
317
+ question_ids = question_ids.repeat(len(image), 1)
318
+
319
+ q_len = len(question_ids[0]) - 1 # remove <eos> token
320
+ question_len = [q_len] * len(image)
321
+
322
+ input_ids[:, :q_len] = question_ids[:, :-1]
323
+ else:
324
+ assert isinstance(prompt, list), "prompt must be None, str, or list!"
325
+ assert len(prompt) == len(image), "VQA requires an equal number of images and prompts!"
326
+ for i, p in enumerate(prompt):
327
+ question_ids = torch.LongTensor([self.tokenizer.encode(p)])
328
+
329
+ q_len = len(question_ids[0]) - 1
330
+ question_len.append(q_len)
331
+
332
+ input_ids[i, :q_len] = question_ids[0, :-1]
333
+ else:
334
+ self.scheduler.config.mask_token_id = self.transformer.config.vocab_size - 1
335
+
336
+ if isinstance(prompt, str):
337
+ prompt = [prompt]
338
+
339
+ if image is not None:
340
+ batch_size = len(image)
341
+ else:
342
+ batch_size = len(prompt)
343
+
344
+ if height is None:
345
+ height = self.transformer.config.sample_size * self.vae_scale_factor
346
+
347
+ if width is None:
348
+ width = self.transformer.config.sample_size * self.vae_scale_factor
349
+
350
+ if isinstance(self.text_encoder, CLIPTextModelWithProjection):
351
+ text_encoder_type = "open_clip"
352
+ if isinstance(self.text_encoder_2, Gemma2Model):
353
+ text_encoder_type = "gemma"
354
+
355
+ if prompt_embeds is None:
356
+ if text_encoder_type == "t5_clip":
357
+ if text2image:
358
+ input_ids_clip = self.tokenizer(
359
+ prompt,
360
+ return_tensors="pt",
361
+ padding="max_length",
362
+ truncation=True,
363
+ add_special_tokens=True,
364
+ max_length=77,
365
+ ).input_ids.to(self._execution_device)
366
+ outputs = self.text_encoder(input_ids_clip, return_dict=True, output_hidden_states=True)
367
+ prompt_embeds = outputs.text_embeds
368
+
369
+ input_ids_t5 = self.tokenizer_2(
370
+ prompt,
371
+ return_tensors="pt",
372
+ padding="max_length",
373
+ truncation=True,
374
+ add_special_tokens=True,
375
+ max_length=256,
376
+ ).input_ids.to(self._execution_device)
377
+
378
+ outputs_2 = self.text_encoder_2(input_ids_t5, return_dict=True, output_hidden_states=True)
379
+ encoder_hidden_states = outputs_2.last_hidden_state
380
+ elif text_encoder_type == "open_clip":
381
+ if text2image:
382
+ input_ids = self.tokenizer(
383
+ prompt,
384
+ return_tensors="pt",
385
+ padding="max_length",
386
+ truncation=True,
387
+ add_special_tokens=True,
388
+ max_length=77,
389
+ ).input_ids.to(self._execution_device)
390
+
391
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
392
+ prompt_embeds = outputs.text_embeds
393
+ encoder_hidden_states = outputs.hidden_states[-2]
394
+ elif text_encoder_type == "gemma":
395
+ if text2image:
396
+ input_ids_clip = self.tokenizer(
397
+ prompt,
398
+ return_tensors="pt",
399
+ padding="max_length",
400
+ truncation=True,
401
+ add_special_tokens=True,
402
+ max_length=77,
403
+ ).input_ids.to(self._execution_device)
404
+ outputs = self.text_encoder(input_ids_clip, return_dict=True, output_hidden_states=True)
405
+ prompt_embeds = outputs.text_embeds
406
+
407
+ input_ids_2 = self.tokenizer_2(
408
+ prompt,
409
+ truncation=True,
410
+ padding="max_length",
411
+ max_length=256,
412
+ return_tensors="pt",
413
+ ).input_ids.to(self._execution_device)
414
+
415
+ outputs_2 = self.text_encoder_2(input_ids_2, return_dict=True, output_hidden_states=True)
416
+ encoder_hidden_states = outputs_2.last_hidden_state
417
+
418
+ prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
419
+ encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
420
+
421
+ if guidance_scale > 1.0 and text2image:
422
+ if negative_prompt_embeds is None:
423
+ if negative_prompt is None:
424
+ negative_prompt = [""] * len(prompt)
425
+
426
+ if isinstance(negative_prompt, str):
427
+ negative_prompt = [negative_prompt] * len(prompt)
428
+
429
+ if text_encoder_type == "t5_clip":
430
+ input_ids = self.tokenizer(
431
+ negative_prompt,
432
+ return_tensors="pt",
433
+ padding="max_length",
434
+ truncation=True,
435
+ add_special_tokens=True,
436
+ max_length=77,
437
+ ).input_ids.to(self._execution_device)
438
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
439
+ negative_prompt_embeds = outputs.text_embeds
440
+
441
+ input_ids_2 = self.tokenizer_2(
442
+ negative_prompt,
443
+ return_tensors="pt",
444
+ padding="max_length",
445
+ truncation=True,
446
+ add_special_tokens=True,
447
+ max_length=256,
448
+ ).input_ids.to(self._execution_device)
449
+ outputs_2 = self.text_encoder_2(input_ids_2, return_dict=True, output_hidden_states=True)
450
+ negative_encoder_hidden_states = outputs_2.last_hidden_state
451
+
452
+ elif text_encoder_type == "open_clip":
453
+ input_ids = self.tokenizer(
454
+ negative_prompt,
455
+ return_tensors="pt",
456
+ padding="max_length",
457
+ truncation=True,
458
+ add_special_tokens=True,
459
+ max_length=77,
460
+ ).input_ids.to(self._execution_device)
461
+
462
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
463
+
464
+ negative_prompt_embeds = outputs.text_embeds
465
+ negative_encoder_hidden_states = outputs.hidden_states[-2]
466
+
467
+ elif text_encoder_type == "gemma":
468
+ input_ids = self.tokenizer(
469
+ negative_prompt,
470
+ return_tensors="pt",
471
+ padding="max_length",
472
+ truncation=True,
473
+ add_special_tokens=True,
474
+ max_length=77,
475
+ ).input_ids.to(self._execution_device)
476
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
477
+ negative_prompt_embeds = outputs.text_embeds
478
+
479
+ input_ids_2 = self.tokenizer_2(
480
+ negative_prompt,
481
+ truncation=True,
482
+ padding="max_length",
483
+ max_length=256,
484
+ return_tensors="pt",
485
+ ).input_ids.to(self._execution_device)
486
+ outputs_2 = self.text_encoder_2(input_ids_2, return_dict=True, output_hidden_states=True)
487
+ negative_encoder_hidden_states = outputs_2.last_hidden_state
488
+
489
+ negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
490
+ negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
491
+
492
+ prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
493
+ encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
494
+
495
+ # Note that the micro conditionings _do_ flip the order of width, height for the original size
496
+ # and the crop coordinates. This is how it was done in the original code base
497
+ micro_conds = torch.tensor(
498
+ [
499
+ width,
500
+ height,
501
+ micro_conditioning_crop_coord[0],
502
+ micro_conditioning_crop_coord[1],
503
+ micro_conditioning_aesthetic_score,
504
+ ],
505
+ device=self._execution_device,
506
+ dtype=encoder_hidden_states.dtype,
507
+ )
508
+ micro_conds = micro_conds.unsqueeze(0)
509
+ micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 and text2image else batch_size, -1)
510
+
511
+ shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)
512
+
513
+ if latents is None and text2image:
514
+ latents = torch.full(
515
+ shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device
516
+ )
517
+ elif image2text:
518
+ if text_encoder_type in ("t5_clip", "gemma"):
519
+ latents = input_ids_2 # [b, l]
520
+ else:
521
+ latents = input_ids
522
+
523
+ model_input = None
524
+
525
+ step_by_step = []
526
+
527
+ self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
528
+ num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order
529
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
530
+ for i, timestep in enumerate(self.scheduler.timesteps):
531
+ if guidance_scale > 1.0 and text2image:
532
+ model_input = torch.cat([latents] * 2)
533
+ encoder_hidden_states = encoder_hidden_states
534
+ elif image2text:
535
+ if model_input is None:
536
+ model_input = self.vqvae.quantize(
537
+ self.vqvae.encode(image.to(self._execution_device, dtype=self.vqvae.dtype)).latents
538
+ )[2][2].reshape(batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)
539
+
540
+ if text_encoder_type in ("t5_clip", "gemma"):
541
+ outputs_t5 = self.text_encoder_2(latents, return_dict=True)
542
+ encoder_hidden_states = outputs_t5.last_hidden_state
543
+
544
+ batch_prompt = []
545
+ for j in range(latents.size(0)):  # separate index so the outer denoising-loop variable `i` is not shadowed
546
+ masked_prompt_input_id = latents[j].tolist()
547
+ prompt = self.tokenizer_2.decode(masked_prompt_input_id, skip_special_tokens=True)
548
+ batch_prompt.append(prompt)
549
+
550
+ masked_prompt_input_ids_clip = self.tokenizer(
551
+ batch_prompt,
552
+ truncation=True,
553
+ padding="max_length",
554
+ max_length=77,
555
+ return_tensors="pt"
556
+ ).input_ids
557
+ masked_prompt_input_ids_clip = masked_prompt_input_ids_clip.to(self._execution_device)
558
+ outputs_clip = self.text_encoder(input_ids=masked_prompt_input_ids_clip, return_dict=True)
559
+ prompt_embeds = outputs_clip.text_embeds
560
+
561
+ else:
562
+ outputs = self.text_encoder(latents, return_dict=True, output_hidden_states=True)
563
+ prompt_embeds = outputs.text_embeds
564
+ encoder_hidden_states = outputs.hidden_states[-2]
565
+ else:
566
+ model_input = latents
567
+ encoder_hidden_states = encoder_hidden_states
568
+
569
+ # Positional ids for the image token grid (the same for all resolutions)
570
+ img_ids = _prepare_latent_image_ids(
571
+ model_input.shape[0],
572
+ model_input.shape[-2],
573
+ model_input.shape[-1],
574
+ model_input.device,
575
+ model_input.dtype
576
+ )
585
+ txt_ids = torch.zeros(encoder_hidden_states.shape[1], 3).to(
586
+ device=encoder_hidden_states.device,
587
+ dtype=encoder_hidden_states.dtype
588
+ )
589
+
590
+ # timestep_ = int(timestep / num_inference_steps * 1000)
591
+ model_output, encoder_hidden_states_tmp = self.transformer(
592
+ hidden_states=model_input,
593
+ micro_conds=micro_conds,
594
+ pooled_projections=prompt_embeds,
595
+ encoder_hidden_states=encoder_hidden_states,
596
+ img_ids=img_ids,
597
+ txt_ids=txt_ids,
598
+ timestep=torch.tensor([timestep / num_inference_steps], device=model_input.device),
599
+ )
600
+
601
+ if image2text:
602
+ encoder_hidden_states = encoder_hidden_states_tmp.clone()
603
+
604
+ if guidance_scale > 1.0 and text2image:
605
+ uncond_logits, cond_logits = model_output.chunk(2)
606
+ to_scheduler = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
607
+ elif image2text:
608
+ to_scheduler = encoder_hidden_states
609
+ else:
610
+ to_scheduler = model_output
611
+
612
+ latents = self.scheduler.step(
613
+ model_output=to_scheduler,
614
+ timestep=timestep,
615
+ sample=latents,
616
+ generator=generator,
617
+ ).prev_sample
618
+
619
+ # this line will print the intermediate results of the image-to-text generation
620
+ # step_by_step.append(self.tokenizer.decode(latents[0].tolist(), skip_special_tokens=True))
621
+
622
+ # this line will print the intermediate results of the text-to-image generation
623
+ # output = self.vqvae.decode(
624
+ # latents,
625
+ # force_not_quantize=True,
626
+ # shape=(
627
+ # batch_size,
628
+ # height // self.vae_scale_factor,
629
+ # width // self.vae_scale_factor,
630
+ # self.vqvae.config.latent_channels,
631
+ # ),
632
+ # ).sample.clip(0, 1)
633
+ # output = self.image_processor.postprocess(output, output_type) # output is a list of PIL.Image, you need to save it.
634
+
635
+ if i == len(self.scheduler.timesteps) - 1 or (
636
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
637
+ ):
638
+ progress_bar.update()
639
+ if callback is not None and i % callback_steps == 0:
640
+ step_idx = i // getattr(self.scheduler, "order", 1)
641
+ callback(step_idx, timestep, latents)
642
+
643
+ # with open("step_by_step.txt", "w") as file:
644
+ # for prompt in step_by_step:
645
+ # file.write(prompt + "\n")
646
+
647
+ if guidance_scale > 1.0 and text2image:
648
+ decoded_input_ids = encoder_hidden_states[encoder_hidden_states.shape[0] // 2:].argmax(-1)
649
+ else:
650
+ decoded_input_ids = encoder_hidden_states.argmax(-1)
651
+
652
+ prompts = []
653
+ for i, prompt in enumerate(decoded_input_ids):
654
+ if image2text:
655
+ q_len = question_len[i]
656
+ prompt = self.tokenizer.decode(prompt.tolist()[q_len:], skip_special_tokens=True)
657
+ prompts.append(keep_upto_last_period(dedup_consecutive_words(prompt)))
658
+ else:
659
+ prompts.append("Placeholder")
660
+
661
+ if output_type == "latent":
662
+ output = latents
663
+ else:
664
+ needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
665
+
666
+ if needs_upcasting:
667
+ self.vqvae.float()
668
+
669
+ if text2image:
670
+ to_vqvae = latents
671
+ else:
672
+ to_vqvae = model_input
673
+
674
+ output = self.vqvae.decode(
675
+ to_vqvae,
676
+ force_not_quantize=True,
677
+ shape=(
678
+ batch_size,
679
+ height // self.vae_scale_factor,
680
+ width // self.vae_scale_factor,
681
+ self.vqvae.config.latent_channels,
682
+ ),
683
+ ).sample.clip(0, 1)
684
+ output = self.image_processor.postprocess(output, output_type)
685
+
686
+ if needs_upcasting:
687
+ self.vqvae.half()
688
+
689
+ self.maybe_free_model_hooks()
690
+
691
+ if not return_dict:
692
+ return (output,)
693
+
694
+ return UnifiedPipelineOutput(images=output, prompts=prompts)
695
+
696
+
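+ # Illustrative usage sketch, mirroring how app.py calls this pipeline (example values only):
+ # Text-to-image:
+ #   out = pipe(prompt=["a scenic mountain lake"], height=1024, width=1024,
+ #              guidance_scale=9.0, num_inference_steps=64)
+ #   out.images[0].save("sample.png")
+ # Image-to-text (VQA), where `images` is a float image tensor of shape [B, 3, H, W]
+ # (as produced by train.trainer_utils.load_images_to_tensor in app.py):
+ #   out = pipe(prompt=["What is in this image?"], image=images, height=1024, width=1024,
+ #              guidance_scale=9.0, num_inference_steps=64)
+ #   answer = out.prompts[0]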
697
+ class UnifiedPipeline_new(DiffusionPipeline):
698
+ image_processor: VaeImageProcessor
699
+ vqvae: VQModel
700
+ tokenizer: CLIPTokenizer
701
+ tokenizer_2: GemmaTokenizerFast
702
+ text_encoder: CLIPTextModelWithProjection
703
+ text_encoder_2: Gemma2Model
704
+ image_encoder: CLIPVisionModelWithProjection
705
+ clip_image_processor: CLIPImageProcessor
706
+ transformer: SymmetricTransformer2DModel
707
+ scheduler: Scheduler
708
+
709
+ def __init__(
710
+ self,
711
+ vqvae: VQModel,
712
+ tokenizer: CLIPTokenizer,
713
+ text_encoder: CLIPTextModelWithProjection,
714
+ transformer: SymmetricTransformer2DModel,
715
+ scheduler: Scheduler,
716
+ tokenizer_2: Optional[GemmaTokenizerFast]=None,
717
+ text_encoder_2: Optional[Gemma2Model]=None,
718
+ image_encoder: Optional[CLIPVisionModelWithProjection]=None,
719
+ clip_image_processor: Optional[CLIPImageProcessor]=None,
720
+ ):
721
+ super().__init__()
722
+
723
+ self.register_modules(
724
+ vqvae=vqvae,
725
+ tokenizer=tokenizer,
726
+ tokenizer_2=tokenizer_2,
727
+ text_encoder=text_encoder,
728
+ text_encoder_2=text_encoder_2,
729
+ image_encoder=image_encoder,
730
+ clip_image_processor=clip_image_processor,
731
+ transformer=transformer,
732
+ scheduler=scheduler,
733
+ )
734
+ self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
735
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
736
+
737
+ @torch.no_grad()
738
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
739
+ def __call__(
740
+ self,
741
+ prompt: Optional[Union[List[str], str]] = None,
742
+ height: Optional[int] = 1024,
743
+ width: Optional[int] = 1024,
744
+ image: Optional[torch.Tensor] = None,
745
+ num_inference_steps: int = 48,
746
+ guidance_scale: float = 9.0,
747
+ negative_prompt: Optional[Union[str, List[str]]] = None,
748
+ num_images_per_prompt: Optional[int] = 1,
749
+ generator: Optional[torch.Generator] = None,
750
+ latents: Optional[torch.IntTensor] = None,
751
+ prompt_embeds: Optional[torch.Tensor] = None,
752
+ encoder_hidden_states: Optional[torch.Tensor] = None,
753
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
754
+ negative_encoder_hidden_states: Optional[torch.Tensor] = None,
755
+ output_type = "pil",
756
+ return_dict: bool = True,
757
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
758
+ callback_steps: int = 1,
759
+ micro_conditioning_aesthetic_score: int = 6,
760
+ micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
761
+ temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
762
+ mask_token_embedding: Optional[str] = None,
763
+ ):
764
+ """
765
+ The call function to the pipeline for generation.
766
+
767
+ Args:
768
+ prompt (`str` or `List[str]`, *optional*):
769
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
770
+ height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
771
+ The height in pixels of the generated image.
772
+ width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
773
+ The width in pixels of the generated image.
774
+ num_inference_steps (`int`, *optional*, defaults to 48):
775
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
776
+ expense of slower inference.
777
+ guidance_scale (`float`, *optional*, defaults to 9.0):
778
+ A higher guidance scale value encourages the model to generate images closely linked to the text
779
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
780
+ negative_prompt (`str` or `List[str]`, *optional*):
781
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
782
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
783
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
784
+ The number of images to generate per prompt.
785
+ generator (`torch.Generator`, *optional*):
786
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
787
+ generation deterministic.
788
+ latents (`torch.IntTensor`, *optional*):
789
+ Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image
790
+ generation. If not provided, the starting latents will be completely masked.
791
+ prompt_embeds (`torch.Tensor`, *optional*):
792
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
793
+ provided, text embeddings are generated from the `prompt` input argument. A single vector from the
794
+ pooled and projected final hidden states.
795
+ encoder_hidden_states (`torch.Tensor`, *optional*):
796
+ Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
797
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
798
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
799
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
800
+ negative_encoder_hidden_states (`torch.Tensor`, *optional*):
801
+ Analogous to `encoder_hidden_states` for the positive prompt.
802
+ output_type (`str`, *optional*, defaults to `"pil"`):
803
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
804
+ return_dict (`bool`, *optional*, defaults to `True`):
805
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
806
+ plain tuple.
807
+ callback (`Callable`, *optional*):
808
+ A function that calls every `callback_steps` steps during inference. The function is called with the
809
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
810
+ callback_steps (`int`, *optional*, defaults to 1):
811
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
812
+ every step.
813
+ cross_attention_kwargs (`dict`, *optional*):
814
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
815
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
816
+ micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
817
+ The targeted aesthetic score according to the laion aesthetic classifier. See
818
+ https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
819
+ https://arxiv.org/abs/2307.01952.
820
+ micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
821
+ The targeted height, width crop coordinates. See the micro-conditioning section of
822
+ https://arxiv.org/abs/2307.01952.
823
+ temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
824
+ Configures the temperature scheduler on `self.scheduler` see `Scheduler#set_timesteps`.
825
+
826
+ Examples:
827
+
828
+ Returns:
829
+ [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
830
+ If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
831
+ `tuple` is returned where the first element is a list with the generated images.
832
+ """
833
+ if (prompt_embeds is not None and encoder_hidden_states is None) or (
834
+ prompt_embeds is None and encoder_hidden_states is not None
835
+ ):
836
+ raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
837
+
838
+ if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
839
+ negative_prompt_embeds is None and negative_encoder_hidden_states is not None
840
+ ):
841
+ raise ValueError(
842
+ "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
843
+ )
844
+
845
+ if self.text_encoder_2 is not None:
846
+ self.text_encoder_2.to(self._execution_device)
847
+
848
+ text2image = image is None
849
+ image2text = image is not None
850
+
851
+ if image2text:
852
+ if self.text_encoder_2 is not None:
853
+ prompt = "<mask>" * 256
854
+ prompt = [prompt] * image.shape[0]
855
+
856
+ text_encoder_2_mask_id = self.tokenizer_2.convert_tokens_to_ids("<mask>")
857
+ self.scheduler.config.mask_token_id = text_encoder_2_mask_id
858
+ else:
859
+ mask_token = "<mask>"
860
+ self.tokenizer.add_tokens(mask_token, special_tokens=False)
861
+ clip_mask_id = self.tokenizer.convert_tokens_to_ids(mask_token)
862
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
863
+
864
+ if mask_token_embedding is not None:
865
+ try:
866
+ if mask_token_embedding.endswith(".pth"):
867
+ mask_token_embedding = torch.load(mask_token_embedding)
868
+ else:
869
+ mask_token_embedding = os.path.dirname(mask_token_embedding)
870
+ mask_token_embedding_path = os.path.join(mask_token_embedding, "mask_token_embedding.pth")
871
+ assert os.path.exists(mask_token_embedding_path), f"{mask_token_embedding_path} doesn't exist!"
872
+ mask_token_embedding = torch.load(mask_token_embedding_path)
873
+
874
+ mask_token_embedding = mask_token_embedding.to(self._execution_device, dtype=self.text_encoder.dtype)
875
+ self.text_encoder.get_input_embeddings().weight.data[clip_mask_id].copy_(mask_token_embedding)
876
+
877
+ except Exception as e:
878
+ print(f"Error loading mask token embedding: {e}")
879
+ print("Using random initialized mask token embedding")
880
+ mask_token_embedding = None
881
+
882
+ self.scheduler.config.mask_token_id = clip_mask_id
883
+
884
+ input_ids = torch.ones(
885
+ size=(image.shape[0], self.tokenizer.model_max_length),
886
+ dtype=torch.int64,
887
+ device=self._execution_device
888
+ )
889
+ input_ids = input_ids * clip_mask_id
890
+
891
+ question_len = []
892
+ if prompt is None:
893
+ question_len = [0] * image.shape[0]
894
+ elif isinstance(prompt, str):
895
+ question_ids = torch.LongTensor([self.tokenizer.encode(prompt)])
896
+ question_ids = question_ids.repeat(image.shape[0], 1)
897
+
898
+ q_len = len(question_ids[0]) - 1 # remove <eos> token
899
+ question_len = [q_len] * image.shape[0]
900
+
901
+ input_ids[:, :q_len] = question_ids[:, :-1]
902
+ else:
903
+ assert isinstance(prompt, list), "prompt must be None, str, or list!"
904
+ assert len(prompt) == image.shape[0], "VQA requires an equal number of images and prompts!"
905
+ for i, p in enumerate(prompt):
906
+ question_ids = torch.LongTensor([self.tokenizer.encode(p)])
907
+
908
+ q_len = len(question_ids[0]) - 1
909
+ question_len.append(q_len)
910
+
911
+ input_ids[i, :q_len] = question_ids[0, :-1]
912
+ else:
913
+ self.scheduler.config.mask_token_id = self.transformer.config.vocab_size - 1
914
+
915
+ if image is not None:
916
+ batch_size = image.shape[0]
917
+ elif isinstance(prompt, list):
918
+ batch_size = len(prompt)
919
+ elif isinstance(prompt, str):
920
+ batch_size = 1
921
+ prompt = [prompt]
922
+ else:
923
+ raise ValueError("prompt must be None or str or list!")
924
+
925
+ if height is None:
926
+ height = self.transformer.config.sample_size * self.vae_scale_factor
927
+
928
+ if width is None:
929
+ width = self.transformer.config.sample_size * self.vae_scale_factor
930
+
931
+ if isinstance(self.text_encoder, CLIPTextModelWithProjection):
932
+ text_encoder_type = "open_clip"
933
+ if isinstance(self.text_encoder_2, Gemma2Model):
934
+ text_encoder_type = "gemma"
935
+
936
+ if prompt_embeds is None and text2image:
937
+ if text_encoder_type == "open_clip":
938
+ input_ids = self.tokenizer(
939
+ prompt,
940
+ return_tensors="pt",
941
+ padding="max_length",
942
+ truncation=True,
943
+ add_special_tokens=True,
944
+ max_length=77,
945
+ ).input_ids.to(self._execution_device)
946
+
947
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
948
+ prompt_embeds = outputs.text_embeds
949
+ encoder_hidden_states = outputs.hidden_states[-2]
950
+ elif text_encoder_type == "gemma":
951
+ input_ids_clip = self.tokenizer(
952
+ prompt,
953
+ return_tensors="pt",
954
+ padding="max_length",
955
+ truncation=True,
956
+ add_special_tokens=True,
957
+ max_length=77,
958
+ ).input_ids.to(self._execution_device)
959
+ outputs = self.text_encoder(input_ids_clip, return_dict=True, output_hidden_states=True)
960
+ prompt_embeds = outputs.text_embeds
961
+
962
+ input_ids_2 = self.tokenizer_2(
963
+ prompt,
964
+ truncation=True,
965
+ padding="max_length",
966
+ max_length=256,
967
+ return_tensors="pt",
968
+ ).input_ids.to(self._execution_device)
969
+
970
+ outputs_2 = self.text_encoder_2(input_ids_2, return_dict=True, output_hidden_states=True)
971
+ encoder_hidden_states = outputs_2.last_hidden_state
972
+
973
+ prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
974
+ encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
975
+
976
+ if guidance_scale > 1.0 and text2image:
977
+ if negative_prompt_embeds is None:
978
+ if negative_prompt is None:
979
+ negative_prompt = [""] * len(prompt)
980
+
981
+ if isinstance(negative_prompt, str):
982
+ negative_prompt = [negative_prompt] * len(prompt)
983
+
984
+ if text_encoder_type == "t5_clip":
985
+ input_ids = self.tokenizer(
986
+ negative_prompt,
987
+ return_tensors="pt",
988
+ padding="max_length",
989
+ truncation=True,
990
+ add_special_tokens=True,
991
+ max_length=77,
992
+ ).input_ids.to(self._execution_device)
993
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
994
+ negative_prompt_embeds = outputs.text_embeds
995
+
996
+ input_ids_2 = self.tokenizer_2(
997
+ negative_prompt,
998
+ return_tensors="pt",
999
+ padding="max_length",
1000
+ truncation=True,
1001
+ add_special_tokens=True,
1002
+ max_length=256,
1003
+ ).input_ids.to(self._execution_device)
1004
+ outputs_2 = self.text_encoder_2(input_ids_2, return_dict=True, output_hidden_states=True)
1005
+ negative_encoder_hidden_states = outputs_2.last_hidden_state
1006
+
1007
+ elif text_encoder_type == "open_clip":
1008
+ input_ids = self.tokenizer(
1009
+ negative_prompt,
1010
+ return_tensors="pt",
1011
+ padding="max_length",
1012
+ truncation=True,
1013
+ add_special_tokens=True,
1014
+ max_length=77,
1015
+ ).input_ids.to(self._execution_device)
1016
+
1017
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
1018
+
1019
+ negative_prompt_embeds = outputs.text_embeds
1020
+ negative_encoder_hidden_states = outputs.hidden_states[-2]
1021
+
1022
+ elif text_encoder_type == "gemma":
1023
+ input_ids = self.tokenizer(
1024
+ negative_prompt,
1025
+ return_tensors="pt",
1026
+ padding="max_length",
1027
+ truncation=True,
1028
+ add_special_tokens=True,
1029
+ max_length=77,
1030
+ ).input_ids.to(self._execution_device)
1031
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
1032
+ negative_prompt_embeds = outputs.text_embeds
1033
+
1034
+ input_ids_2 = self.tokenizer_2(
1035
+ negative_prompt,
1036
+ truncation=True,
1037
+ padding="max_length",
1038
+ max_length=256,
1039
+ return_tensors="pt",
1040
+ ).input_ids.to(self._execution_device)
1041
+ outputs_2 = self.text_encoder_2(input_ids_2, return_dict=True, output_hidden_states=True)
1042
+ negative_encoder_hidden_states = outputs_2.last_hidden_state
1043
+
1044
+ negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
1045
+ negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
1046
+
1047
+ prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
1048
+ encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
1049
+
1050
+ # Note that the micro conditionings _do_ flip the order of width, height for the original size
1051
+ # and the crop coordinates. This is how it was done in the original code base
1052
+ micro_conds = torch.tensor(
1053
+ [
1054
+ width,
1055
+ height,
1056
+ micro_conditioning_crop_coord[0],
1057
+ micro_conditioning_crop_coord[1],
1058
+ micro_conditioning_aesthetic_score,
1059
+ ],
1060
+ device=self._execution_device,
1061
+ dtype=self.transformer.dtype,
1062
+ )
1063
+ micro_conds = micro_conds.unsqueeze(0)
1064
+ micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 and text2image else batch_size, -1)
1065
+
1066
+ shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)
1067
+
1068
+ if latents is None and text2image:
1069
+ latents = torch.full(
1070
+ shape,
1071
+ self.scheduler.config.mask_token_id,
1072
+ dtype=torch.long,
1073
+ device=self._execution_device
1074
+ )
1075
+ elif image2text:
1076
+ if text_encoder_type in ("t5_clip", "gemma"):
1077
+ latents = input_ids_2 # [b, l]
1078
+ else:
1079
+ latents = input_ids
1080
+
1081
+ model_input = None
1082
+ step_by_step = []
1083
+
1084
+ self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
1085
+ num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order
1086
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1087
+ for i, timestep in enumerate(self.scheduler.timesteps):
1088
+ if guidance_scale > 1.0 and text2image:
1089
+ model_input = torch.cat([latents] * 2)
1090
+ elif image2text:
1091
+ if model_input is None:
1092
+ model_input = self.vqvae.quantize(
1093
+ self.vqvae.encode(image.to(self._execution_device, dtype=self.vqvae.dtype)).latents
1094
+ )[2][2].reshape(batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)
1095
+
1096
+ prompt_embeds = self.image_encoder(
1097
+ self.clip_image_processor(
1098
+ image,
1099
+ return_tensors="pt",
1100
+ do_rescale=False,
1101
+ do_resize=True,
1102
+ do_normalize=True,
1103
+ ).to(self._execution_device, dtype=self.image_encoder.dtype).pixel_values
1104
+ ).image_embeds # [b, 1024]
1105
+
1106
+ if text_encoder_type in ("t5_clip", "gemma"):
1107
+ outputs = self.text_encoder_2(latents, return_dict=True)
1108
+ encoder_hidden_states = outputs.last_hidden_state
1109
+ else:
1110
+ outputs = self.text_encoder(latents, return_dict=True, output_hidden_states=True)
1111
+ encoder_hidden_states = outputs.hidden_states[-2]
1112
+ else:
1113
+ model_input = latents
1114
+
1115
+ img_ids = _prepare_latent_image_ids(
1116
+ model_input.shape[0],
1117
+ model_input.shape[-2],
1118
+ model_input.shape[-1],
1119
+ self._execution_device,
1120
+ self.transformer.dtype
1121
+ )
1122
+ txt_ids = torch.zeros(encoder_hidden_states.shape[1], 3).to(
1123
+ device=self._execution_device,
1124
+ dtype=self.transformer.dtype
1125
+ )
1126
+
1127
+ # timestep_ = int(timestep / num_inference_steps * 1000)
1128
+ model_output, encoder_hidden_states_tmp = self.transformer(
1129
+ hidden_states=model_input,
1130
+ micro_conds=micro_conds,
1131
+ pooled_projections=prompt_embeds,
1132
+ encoder_hidden_states=encoder_hidden_states,
1133
+ img_ids=img_ids,
1134
+ txt_ids=txt_ids,
1135
+ timestep=torch.tensor([timestep], device=self._execution_device),
1136
+ )
1137
+
1138
+ if image2text:
1139
+ encoder_hidden_states = encoder_hidden_states_tmp.clone()
1140
+
1141
+ if guidance_scale > 1.0 and text2image:
1142
+ uncond_logits, cond_logits = model_output.chunk(2)
1143
+ to_scheduler = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
1144
+ elif image2text:
1145
+ to_scheduler = encoder_hidden_states
1146
+ else:
1147
+ to_scheduler = model_output
1148
+
1149
+ latents = self.scheduler.step(
1150
+ model_output=to_scheduler,
1151
+ timestep=timestep,
1152
+ sample=latents,
1153
+ generator=generator,
1154
+ ).prev_sample
1155
+
1156
+ # uncommenting the next line records the intermediate image-to-text decodes at each step
1157
+ # step_by_step.append(self.tokenizer.decode(latents[0].tolist(), skip_special_tokens=True))
1158
+
1159
+ # uncommenting the block below decodes the intermediate text-to-image latents at each step
1160
+ # output = self.vqvae.decode(
1161
+ # latents,
1162
+ # force_not_quantize=True,
1163
+ # shape=(
1164
+ # batch_size,
1165
+ # height // self.vae_scale_factor,
1166
+ # width // self.vae_scale_factor,
1167
+ # self.vqvae.config.latent_channels,
1168
+ # ),
1169
+ # ).sample.clip(0, 1)
1170
+ # output = self.image_processor.postprocess(output, output_type) # output is a list of PIL.Image, you need to save it.
1171
+
1172
+ if i == len(self.scheduler.timesteps) - 1 or (
1173
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1174
+ ):
1175
+ progress_bar.update()
1176
+ if callback is not None and i % callback_steps == 0:
1177
+ step_idx = i // getattr(self.scheduler, "order", 1)
1178
+ callback(step_idx, timestep, latents)
1179
+
1180
+ # with open("step_by_step.txt", "w") as file:
1181
+ # for prompt in step_by_step:
1182
+ # file.write(prompt + "\n")
1183
+
1184
+ if guidance_scale > 1.0 and text2image:
1185
+ decoded_input_ids = encoder_hidden_states[encoder_hidden_states.shape[0] // 2:].argmax(-1)
1186
+ else:
1187
+ decoded_input_ids = encoder_hidden_states.argmax(-1)
1188
+
1189
+ prompts = []
1190
+ for i, prompt in enumerate(decoded_input_ids):
1191
+ if image2text:
1192
+ q_len = question_len[i]
1193
+ prompt = self.tokenizer.decode(prompt.tolist()[q_len:], skip_special_tokens=True)
1194
+ prompts.append(keep_upto_last_period(dedup_consecutive_words(prompt)))
1195
+ else:
1196
+ prompts.append("Placeholder")
1197
+
1198
+ if output_type == "latent":
1199
+ output = latents
1200
+ else:
1201
+ needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
1202
+
1203
+ if needs_upcasting:
1204
+ self.vqvae.float()
1205
+
1206
+ if text2image:
1207
+ to_vqvae = latents
1208
+ else:
1209
+ to_vqvae = model_input
1210
+
1211
+ output = self.vqvae.decode(
1212
+ to_vqvae,
1213
+ force_not_quantize=True,
1214
+ shape=(
1215
+ batch_size,
1216
+ height // self.vae_scale_factor,
1217
+ width // self.vae_scale_factor,
1218
+ self.vqvae.config.latent_channels,
1219
+ ),
1220
+ ).sample.clip(0, 1)
1221
+ output = self.image_processor.postprocess(output, output_type)
1222
+
1223
+ if needs_upcasting:
1224
+ self.vqvae.half()
1225
+
1226
+ self.maybe_free_model_hooks()
1227
+
1228
+ if not return_dict:
1229
+ return (output,)
1230
+
1231
+ return UnifiedPipelineOutput(images=output, prompts=prompts)
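A rough usage sketch of the two modes implemented by this `__call__` follows. The pipeline class name, module path, checkpoint id, and `from_pretrained` loading are assumptions based on the usual diffusers pattern and are not taken from this file; only the argument names and the `images` / `prompts` fields of `UnifiedPipelineOutput` mirror the code above.

```python
# Illustrative sketch only: "UnifiedPipeline", the import path, and the checkpoint id are placeholders.
import torch
from src.pipeline import UnifiedPipeline  # assumed module/class name for the pipeline defined above

pipe = UnifiedPipeline.from_pretrained("MeissonFlow/Muddit", torch_dtype=torch.float16).to("cuda")

# Text-to-image: leaving `image=None` selects the text2image branch.
images = pipe(
    prompt="a samurai in a stylized cyberpunk outfit",
    height=512, width=512,
    num_inference_steps=48, guidance_scale=9.0,
).images

# VQA / image-to-text: passing an image tensor selects the image2text branch, where the
# answer tokens start as "<mask>" and are progressively unmasked by the scheduler.
image = torch.rand(1, 3, 512, 512)  # placeholder batch in [0, 1]
answers = pipe(
    prompt="What is the main subject of this image?",
    image=image,
    num_inference_steps=32,
).prompts
```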
src/scheduler.py ADDED
@@ -0,0 +1,175 @@
1
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.utils import BaseOutput
22
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
23
+
24
+
25
+ def gumbel_noise(t, generator=None):
26
+ device = generator.device if generator is not None else t.device
27
+ noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
28
+ return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
29
+
30
+
31
+ def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
32
+ confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
33
+ sorted_confidence = torch.sort(confidence, dim=-1).values
34
+ cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
35
+ masking = confidence < cut_off
36
+ return masking
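In words, `mask_by_random_topk` perturbs each token's log-probability with temperature-scaled Gumbel noise and re-masks the `mask_len` lowest-confidence positions in each row. Written out (matching the two functions above):

$$\text{confidence}_i = \log p_i + \tau\, g_i, \qquad g_i = -\log(-\log u_i), \quad u_i \sim \mathcal{U}(0, 1),$$

and position $i$ is re-masked whenever its confidence falls below the `mask_len`-th smallest confidence of its row.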
37
+
38
+
39
+ @dataclass
40
+ class SchedulerOutput(BaseOutput):
41
+ """
42
+ Output class for the scheduler's `step` function output.
43
+
44
+ Args:
45
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
46
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
47
+ denoising loop.
48
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
49
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
50
+ `pred_original_sample` can be used to preview progress or for guidance.
51
+ """
52
+
53
+ prev_sample: torch.Tensor
54
+ pred_original_sample: torch.Tensor = None
55
+
56
+
57
+ class Scheduler(SchedulerMixin, ConfigMixin):
58
+ order = 1
59
+
60
+ temperatures: torch.Tensor
61
+
62
+ @register_to_config
63
+ def __init__(
64
+ self,
65
+ mask_token_id: int,
66
+ masking_schedule: str = "cosine",
67
+ ):
68
+ self.temperatures = None
69
+ self.timesteps = None
70
+
71
+ def set_timesteps(
72
+ self,
73
+ num_inference_steps: int,
74
+ temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
75
+ device: Union[str, torch.device] = None,
76
+ ):
77
+ self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
78
+
79
+ if isinstance(temperature, (tuple, list)):
80
+ self.temperatures = torch.linspace(temperature[0], temperature[1], num_inference_steps, device=device)
81
+ else:
82
+ self.temperatures = torch.linspace(temperature, 0.01, num_inference_steps, device=device)
83
+
84
+ def step(
85
+ self,
86
+ model_output: torch.Tensor,
87
+ timestep: torch.long,
88
+ sample: torch.LongTensor,
89
+ starting_mask_ratio: int = 1,
90
+ generator: Optional[torch.Generator] = None,
91
+ return_dict: bool = True,
92
+ ) -> Union[SchedulerOutput, Tuple]:
93
+ two_dim_input = sample.ndim == 3 and model_output.ndim == 4
94
+
95
+ if two_dim_input:
96
+ batch_size, codebook_size, height, width = model_output.shape
97
+ sample = sample.reshape(batch_size, height * width)
98
+ model_output = model_output.reshape(batch_size, codebook_size, height * width).permute(0, 2, 1)
99
+
100
+ unknown_map = sample == self.config.mask_token_id
101
+
102
+ probs = model_output.softmax(dim=-1)
103
+
104
+ device = probs.device
105
+ probs_ = probs.to(generator.device) if generator is not None else probs # handles when generator is on CPU
106
+ if probs_.device.type == "cpu" and probs_.dtype != torch.float32:
107
+ probs_ = probs_.float() # multinomial is not implemented for cpu half precision
108
+ probs_ = probs_.reshape(-1, probs.size(-1))
109
+ pred_original_sample = torch.multinomial(probs_, 1, generator=generator).to(device=device)
110
+ pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1])
111
+ pred_original_sample = torch.where(unknown_map, pred_original_sample, sample)
112
+
113
+ if timestep == 0:
114
+ prev_sample = pred_original_sample
115
+ else:
116
+ seq_len = sample.shape[1]
117
+ step_idx = (self.timesteps == timestep).nonzero()
118
+ ratio = (step_idx + 1) / len(self.timesteps)
119
+
120
+ if self.config.masking_schedule == "cosine":
121
+ mask_ratio = torch.cos(ratio * math.pi / 2)
122
+ elif self.config.masking_schedule == "linear":
123
+ mask_ratio = 1 - ratio
124
+ else:
125
+ raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")
126
+
127
+ mask_ratio = starting_mask_ratio * mask_ratio
128
+
129
+ mask_len = (seq_len * mask_ratio).floor()
130
+ # do not mask more than amount previously masked
131
+ mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
132
+ # mask at least one
133
+ mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len)
134
+
135
+ selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0]
136
+ # Ignores the tokens given in the input by overwriting their confidence.
137
+ selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
138
+
139
+ masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx], generator)
140
+
141
+ # Masks tokens with lower confidence.
142
+ prev_sample = torch.where(masking, self.config.mask_token_id, pred_original_sample)
143
+
144
+ if two_dim_input:
145
+ prev_sample = prev_sample.reshape(batch_size, height, width)
146
+ pred_original_sample = pred_original_sample.reshape(batch_size, height, width)
147
+
148
+ if not return_dict:
149
+ return (prev_sample, pred_original_sample)
150
+
151
+ return SchedulerOutput(prev_sample, pred_original_sample)
152
+
153
+ def add_noise(self, sample, timesteps, generator=None):
154
+ step_idx = (self.timesteps == timesteps).nonzero()
155
+ ratio = (step_idx + 1) / len(self.timesteps)
156
+
157
+ if self.config.masking_schedule == "cosine":
158
+ mask_ratio = torch.cos(ratio * math.pi / 2)
159
+ elif self.config.masking_schedule == "linear":
160
+ mask_ratio = 1 - ratio
161
+ else:
162
+ raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")
163
+
164
+ mask_indices = (
165
+ torch.rand(
166
+ sample.shape, device=generator.device if generator is not None else sample.device, generator=generator
167
+ ).to(sample.device)
168
+ < mask_ratio
169
+ )
170
+
171
+ masked_sample = sample.clone()
172
+
173
+ masked_sample[mask_indices] = self.config.mask_token_id
174
+
175
+ return masked_sample
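For orientation, here is a minimal sketch of how this scheduler is driven in practice (shapes and hyperparameters are illustrative, not taken from the file above): `set_timesteps` builds the reversed step indices plus a temperature ramp, and each `step` call samples token predictions from the logits and re-masks the lowest-confidence positions according to the cosine schedule until `timestep == 0`.

```python
# Minimal sketch (illustrative only): drive the masked-token Scheduler with random logits.
import torch
from src.scheduler import Scheduler  # assumed import path for the file added above

mask_id, codebook_size, seq_len, steps = 8255, 8192, 16, 4
scheduler = Scheduler(mask_token_id=mask_id, masking_schedule="cosine")
scheduler.set_timesteps(steps, temperature=(2, 0), device="cpu")

sample = torch.full((1, seq_len), mask_id, dtype=torch.long)  # start fully masked
for t in scheduler.timesteps:
    logits = torch.randn(1, seq_len, codebook_size)  # stand-in for the transformer output
    sample = scheduler.step(model_output=logits, timestep=t, sample=sample).prev_sample

assert not (sample == mask_id).any()  # every position is unmasked after the final step
```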
src/transformer.py ADDED
@@ -0,0 +1,1459 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team, The InstantX Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional, Tuple, Union, List
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25
+ from diffusers.models.attention import FeedForward, BasicTransformerBlock, SkipFFTransformerBlock
26
+ from diffusers.models.attention_processor import (
27
+ Attention,
28
+ AttentionProcessor,
29
+ FluxAttnProcessor2_0,
30
+ )
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, GlobalResponseNorm, RMSNorm
33
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
34
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
35
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings,TimestepEmbedding, get_timestep_embedding #,FluxPosEmbed
36
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
37
+ from diffusers.models.resnet import Downsample2D, Upsample2D
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+
44
+ def get_3d_rotary_pos_embed(
45
+ embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
46
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
47
+ """
48
+ RoPE for video tokens with 3D structure.
49
+
50
+ Args:
51
+ embed_dim: (`int`):
52
+ The embedding dimension size, corresponding to hidden_size_head.
53
+ crops_coords (`Tuple[int]`):
54
+ The top-left and bottom-right coordinates of the crop.
55
+ grid_size (`Tuple[int]`):
56
+ The grid size of the spatial positional embedding (height, width).
57
+ temporal_size (`int`):
58
+ The size of the temporal dimension.
59
+ theta (`float`):
60
+ Scaling factor for frequency computation.
61
+ use_real (`bool`):
62
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
63
+
64
+ Returns:
65
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
66
+ """
67
+ start, stop = crops_coords
68
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
69
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
70
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
71
+
72
+ # Compute dimensions for each axis
73
+ dim_t = embed_dim // 4
74
+ dim_h = embed_dim // 8 * 3
75
+ dim_w = embed_dim // 8 * 3
76
+
77
+ # Temporal frequencies
78
+ freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
79
+ grid_t = torch.from_numpy(grid_t).float()
80
+ freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
81
+ freqs_t = freqs_t.repeat_interleave(2, dim=-1)
82
+
83
+ # Spatial frequencies for height and width
84
+ freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
85
+ freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
86
+ grid_h = torch.from_numpy(grid_h).float()
87
+ grid_w = torch.from_numpy(grid_w).float()
88
+ freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
89
+ freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
90
+ freqs_h = freqs_h.repeat_interleave(2, dim=-1)
91
+ freqs_w = freqs_w.repeat_interleave(2, dim=-1)
92
+
93
+ # Broadcast and concatenate tensors along specified dimension
94
+ def broadcast(tensors, dim=-1):
95
+ num_tensors = len(tensors)
96
+ shape_lens = {len(t.shape) for t in tensors}
97
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
98
+ shape_len = list(shape_lens)[0]
99
+ dim = (dim + shape_len) if dim < 0 else dim
100
+ dims = list(zip(*(list(t.shape) for t in tensors)))
101
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
102
+ assert all(
103
+ [*(len(set(t[1])) <= 2 for t in expandable_dims)]
104
+ ), "invalid dimensions for broadcastable concatenation"
105
+ max_dims = [(t[0], max(t[1])) for t in expandable_dims]
106
+ expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
107
+ expanded_dims.insert(dim, (dim, dims[dim]))
108
+ expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
109
+ tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
110
+ return torch.cat(tensors, dim=dim)
111
+
112
+ freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
113
+
114
+ t, h, w, d = freqs.shape
115
+ freqs = freqs.view(t * h * w, d)
116
+
117
+ # Generate sine and cosine components
118
+ sin = freqs.sin()
119
+ cos = freqs.cos()
120
+
121
+ if use_real:
122
+ return cos, sin
123
+ else:
124
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
125
+ return freqs_cis
126
+
127
+
128
+ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
129
+ """
130
+ RoPE for image tokens with 2d structure.
131
+
132
+ Args:
133
+ embed_dim: (`int`):
134
+ The embedding dimension size
135
+ crops_coords (`Tuple[int]`):
136
+ The top-left and bottom-right coordinates of the crop.
137
+ grid_size (`Tuple[int]`):
138
+ The grid size of the positional embedding.
139
+ use_real (`bool`):
140
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
141
+
142
+ Returns:
143
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
144
+ """
145
+ start, stop = crops_coords
146
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
147
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
148
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
149
+ grid = np.stack(grid, axis=0) # [2, W, H]
150
+
151
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
152
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
153
+ return pos_embed
154
+
155
+
156
+ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
157
+ assert embed_dim % 4 == 0
158
+
159
+ # use half of dimensions to encode grid_h
160
+ emb_h = get_1d_rotary_pos_embed(
161
+ embed_dim // 2, grid[0].reshape(-1), use_real=use_real
162
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
163
+ emb_w = get_1d_rotary_pos_embed(
164
+ embed_dim // 2, grid[1].reshape(-1), use_real=use_real
165
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
166
+
167
+ if use_real:
168
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
169
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
170
+ return cos, sin
171
+ else:
172
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
173
+ return emb
174
+
175
+
176
+ def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
177
+ assert embed_dim % 4 == 0
178
+
179
+ emb_h = get_1d_rotary_pos_embed(
180
+ embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
181
+ ) # (H, D/4)
182
+ emb_w = get_1d_rotary_pos_embed(
183
+ embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
184
+ ) # (W, D/4)
185
+ emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
186
+ emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
187
+
188
+ emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
189
+ return emb
190
+
191
+
192
+ def get_1d_rotary_pos_embed(
193
+ dim: int,
194
+ pos: Union[np.ndarray, int],
195
+ theta: float = 10000.0,
196
+ use_real=False,
197
+ linear_factor=1.0,
198
+ ntk_factor=1.0,
199
+ repeat_interleave_real=True,
200
+ freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
201
+ ):
202
+ """
203
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
204
+
205
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
206
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
207
+ data type.
208
+
209
+ Args:
210
+ dim (`int`): Dimension of the frequency tensor.
211
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
212
+ theta (`float`, *optional*, defaults to 10000.0):
213
+ Scaling factor for frequency computation. Defaults to 10000.0.
214
+ use_real (`bool`, *optional*):
215
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
216
+ linear_factor (`float`, *optional*, defaults to 1.0):
217
+ Scaling factor for the context extrapolation. Defaults to 1.0.
218
+ ntk_factor (`float`, *optional*, defaults to 1.0):
219
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
220
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
221
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
222
+ Otherwise, they are concatenated with themselves.
223
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
224
+ the dtype of the frequency tensor.
225
+ Returns:
226
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
227
+ """
228
+ assert dim % 2 == 0
229
+
230
+ if isinstance(pos, int):
231
+ pos = np.arange(pos)
232
+ theta = theta * ntk_factor
233
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
234
+ t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
235
+ freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
236
+ if use_real and repeat_interleave_real:
237
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
238
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
239
+ return freqs_cos, freqs_sin
240
+ elif use_real:
241
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
242
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
243
+ return freqs_cos, freqs_sin
244
+ else:
245
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2]
246
+ return freqs_cis
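For reference, the frequencies computed here are the standard RoPE frequencies with the NTK and linear factors applied exactly as in the code: with $d$ = `dim`, position $p_s$, and $j = 0, \dots, d/2 - 1$,

$$\theta_j = \frac{(\theta \cdot \text{ntk\_factor})^{-2j/d}}{\text{linear\_factor}}, \qquad \text{freqs}[s, j] = p_s\, \theta_j,$$

so the complex branch returns $e^{i p_s \theta_j}$ while the real branch returns the cosines and sines of the same angles, interleaved or concatenated to width $d$.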
247
+
248
+
249
+ class FluxPosEmbed(nn.Module):
250
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
251
+ def __init__(self, theta: int, axes_dim: Tuple[int]):
252
+ super().__init__()
253
+ self.theta = theta
254
+ self.axes_dim = axes_dim
255
+
256
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
257
+ n_axes = ids.shape[-1]
258
+ cos_out = []
259
+ sin_out = []
260
+ pos = ids.squeeze().float().cpu().numpy()
261
+ is_mps = ids.device.type == "mps"
262
+ freqs_dtype = torch.float32 if is_mps else torch.float64
263
+ for i in range(n_axes):
264
+ cos, sin = get_1d_rotary_pos_embed(
265
+ self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
266
+ )
267
+ cos_out.append(cos)
268
+ sin_out.append(sin)
269
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
270
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
271
+ return freqs_cos, freqs_sin
272
+
273
+
274
+
275
+ class FusedFluxAttnProcessor2_0:
276
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
277
+
278
+ def __init__(self):
279
+ if not hasattr(F, "scaled_dot_product_attention"):
280
+ raise ImportError(
281
+ "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
282
+ )
283
+
284
+ def __call__(
285
+ self,
286
+ attn: Attention,
287
+ hidden_states: torch.FloatTensor,
288
+ encoder_hidden_states: torch.FloatTensor = None,
289
+ attention_mask: Optional[torch.FloatTensor] = None,
290
+ image_rotary_emb: Optional[torch.Tensor] = None,
291
+ ) -> torch.FloatTensor:
292
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
293
+
294
+ # `sample` projections.
295
+ qkv = attn.to_qkv(hidden_states)
296
+ split_size = qkv.shape[-1] // 3
297
+ query, key, value = torch.split(qkv, split_size, dim=-1)
298
+
299
+ inner_dim = key.shape[-1]
300
+ head_dim = inner_dim // attn.heads
301
+
302
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
303
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
304
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
305
+
306
+ if attn.norm_q is not None:
307
+ query = attn.norm_q(query)
308
+ if attn.norm_k is not None:
309
+ key = attn.norm_k(key)
310
+
311
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
312
+ # `context` projections.
313
+ if encoder_hidden_states is not None:
314
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
315
+ split_size = encoder_qkv.shape[-1] // 3
316
+ (
317
+ encoder_hidden_states_query_proj,
318
+ encoder_hidden_states_key_proj,
319
+ encoder_hidden_states_value_proj,
320
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
321
+
322
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
323
+ batch_size, -1, attn.heads, head_dim
324
+ ).transpose(1, 2)
325
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
326
+ batch_size, -1, attn.heads, head_dim
327
+ ).transpose(1, 2)
328
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
329
+ batch_size, -1, attn.heads, head_dim
330
+ ).transpose(1, 2)
331
+
332
+ if attn.norm_added_q is not None:
333
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
334
+ if attn.norm_added_k is not None:
335
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
336
+
337
+ # attention
338
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
339
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
340
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
341
+
342
+ if image_rotary_emb is not None:
343
+ from diffusers.models.embeddings import apply_rotary_emb
344
+
345
+ query = apply_rotary_emb(query, image_rotary_emb)
346
+ key = apply_rotary_emb(key, image_rotary_emb)
347
+
348
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
349
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
350
+ hidden_states = hidden_states.to(query.dtype)
351
+
352
+ if encoder_hidden_states is not None:
353
+ encoder_hidden_states, hidden_states = (
354
+ hidden_states[:, : encoder_hidden_states.shape[1]],
355
+ hidden_states[:, encoder_hidden_states.shape[1] :],
356
+ )
357
+
358
+ # linear proj
359
+ hidden_states = attn.to_out[0](hidden_states)
360
+ # dropout
361
+ hidden_states = attn.to_out[1](hidden_states)
362
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
363
+
364
+ return hidden_states, encoder_hidden_states
365
+ else:
366
+ return hidden_states
367
+
368
+
369
+
370
+ @maybe_allow_in_graph
371
+ class SingleTransformerBlock(nn.Module):
372
+ r"""
373
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
374
+
375
+ Reference: https://arxiv.org/abs/2403.03206
376
+
377
+ Parameters:
378
+ dim (`int`): The number of channels in the input and output.
379
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
380
+ attention_head_dim (`int`): The number of channels in each head.
381
+ mlp_ratio (`float`, defaults to 4.0): The ratio of the MLP hidden dimension to the
382
+ embedding dimension.
383
+ """
384
+
385
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
386
+ super().__init__()
387
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
388
+
389
+ self.norm = AdaLayerNormZeroSingle(dim)
390
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
391
+ self.act_mlp = nn.GELU(approximate="tanh")
392
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
393
+
394
+ processor = FluxAttnProcessor2_0()
395
+ self.attn = Attention(
396
+ query_dim=dim,
397
+ cross_attention_dim=None,
398
+ dim_head=attention_head_dim,
399
+ heads=num_attention_heads,
400
+ out_dim=dim,
401
+ bias=True,
402
+ processor=processor,
403
+ qk_norm="rms_norm",
404
+ eps=1e-6,
405
+ pre_only=True,
406
+ )
407
+
408
+ def forward(
409
+ self,
410
+ hidden_states: torch.FloatTensor,
411
+ temb: torch.FloatTensor,
412
+ image_rotary_emb=None,
413
+ ):
414
+ residual = hidden_states
415
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
416
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
417
+
418
+ attn_output = self.attn(
419
+ hidden_states=norm_hidden_states,
420
+ image_rotary_emb=image_rotary_emb,
421
+ )
422
+
423
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
424
+ gate = gate.unsqueeze(1)
425
+ hidden_states = gate * self.proj_out(hidden_states)
426
+ hidden_states = residual + hidden_states
427
+ if hidden_states.dtype == torch.float16:
428
+ hidden_states = hidden_states.clip(-65504, 65504)
429
+
430
+ return hidden_states
431
+
432
+ @maybe_allow_in_graph
433
+ class TransformerBlock(nn.Module):
434
+ r"""
435
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
436
+
437
+ Reference: https://arxiv.org/abs/2403.03206
438
+
439
+ Parameters:
440
+ dim (`int`): The number of channels in the input and output.
441
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
442
+ attention_head_dim (`int`): The number of channels in each head.
443
+ qk_norm (`str`, defaults to `"rms_norm"`): The normalization applied to the query and key projections.
444
+ eps (`float`, defaults to 1e-6): The epsilon used by the normalization layers.
445
+ """
446
+
447
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
448
+ super().__init__()
449
+
450
+ self.norm1 = AdaLayerNormZero(dim)
451
+
452
+ self.norm1_context = AdaLayerNormZero(dim)
453
+
454
+ if hasattr(F, "scaled_dot_product_attention"):
455
+ processor = FluxAttnProcessor2_0()
456
+ else:
457
+ raise ValueError(
458
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
459
+ )
460
+ self.attn = Attention(
461
+ query_dim=dim,
462
+ cross_attention_dim=None,
463
+ added_kv_proj_dim=dim,
464
+ dim_head=attention_head_dim,
465
+ heads=num_attention_heads,
466
+ out_dim=dim,
467
+ context_pre_only=False,
468
+ bias=True,
469
+ processor=processor,
470
+ qk_norm=qk_norm,
471
+ eps=eps,
472
+ )
473
+
474
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
475
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
476
+ # self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu")
477
+
478
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
479
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
480
+ # self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu")
481
+
482
+ # let chunk size default to None
483
+ self._chunk_size = None
484
+ self._chunk_dim = 0
485
+
486
+ def forward(
487
+ self,
488
+ hidden_states: torch.FloatTensor,
489
+ encoder_hidden_states: torch.FloatTensor,
490
+ temb: torch.FloatTensor,
491
+ image_rotary_emb=None,
492
+ ):
493
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
494
+
495
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
496
+ encoder_hidden_states, emb=temb
497
+ )
498
+ # Attention.
499
+ attn_output, context_attn_output = self.attn(
500
+ hidden_states=norm_hidden_states,
501
+ encoder_hidden_states=norm_encoder_hidden_states,
502
+ image_rotary_emb=image_rotary_emb,
503
+ )
504
+
505
+ # Process attention outputs for the `hidden_states`.
506
+ attn_output = gate_msa.unsqueeze(1) * attn_output
507
+ hidden_states = hidden_states + attn_output
508
+
509
+ norm_hidden_states = self.norm2(hidden_states)
510
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
511
+
512
+ ff_output = self.ff(norm_hidden_states)
513
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
514
+
515
+ hidden_states = hidden_states + ff_output
516
+
517
+ # Process attention outputs for the `encoder_hidden_states`.
518
+
519
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
520
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
521
+
522
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
523
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
524
+
525
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
526
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
527
+ if encoder_hidden_states.dtype == torch.float16:
528
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
529
+
530
+ return encoder_hidden_states, hidden_states
531
+
532
+
533
+ class UVit2DConvEmbed(nn.Module):
534
+ def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias):
535
+ super().__init__()
536
+ self.embeddings = nn.Embedding(vocab_size, in_channels)
537
+ self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine)
538
+ self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias)
539
+
540
+ def forward(self, input_ids):
541
+ embeddings = self.embeddings(input_ids)
542
+ embeddings = self.layer_norm(embeddings)
543
+ embeddings = embeddings.permute(0, 3, 1, 2)
544
+ embeddings = self.conv(embeddings)
545
+ return embeddings
546
+
547
+ class ConvMlmLayer(nn.Module):
548
+ def __init__(
549
+ self,
550
+ block_out_channels: int,
551
+ in_channels: int,
552
+ use_bias: bool,
553
+ ln_elementwise_affine: bool,
554
+ layer_norm_eps: float,
555
+ codebook_size: int,
556
+ ):
557
+ super().__init__()
558
+ self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias)
559
+ self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine)
560
+ self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias)
561
+
562
+ def forward(self, hidden_states):
563
+ hidden_states = self.conv1(hidden_states)
564
+ hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
565
+ logits = self.conv2(hidden_states)
566
+ return logits
567
+
568
+ class SwiGLU(nn.Module):
569
+ r"""
570
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
571
+ but uses SiLU / Swish instead of GeLU.
572
+
573
+ Parameters:
574
+ dim_in (`int`): The number of channels in the input.
575
+ dim_out (`int`): The number of channels in the output.
576
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
577
+ """
578
+
579
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
580
+ super().__init__()
581
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
582
+ self.activation = nn.SiLU()
583
+
584
+ def forward(self, hidden_states):
585
+ hidden_states = self.proj(hidden_states)
586
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
587
+ return hidden_states * self.activation(gate)
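Concretely, the projection doubles the width and the output is the elementwise product of one half with the SiLU of the other:

$$\text{SwiGLU}(x) = h \odot \mathrm{SiLU}(g), \qquad [h \,\|\, g] = W x + b .$$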
588
+
589
+ class ConvNextBlock(nn.Module):
590
+ def __init__(
591
+ self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4
592
+ ):
593
+ super().__init__()
594
+ self.depthwise = nn.Conv2d(
595
+ channels,
596
+ channels,
597
+ kernel_size=3,
598
+ padding=1,
599
+ groups=channels,
600
+ bias=use_bias,
601
+ )
602
+ self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine)
603
+ self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias)
604
+ self.channelwise_act = nn.GELU()
605
+ self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
606
+ self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias)
607
+ self.channelwise_dropout = nn.Dropout(hidden_dropout)
608
+ self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)
609
+
610
+ def forward(self, x, cond_embeds):
611
+ x_res = x
612
+
613
+ x = self.depthwise(x)
614
+
615
+ x = x.permute(0, 2, 3, 1)
616
+ x = self.norm(x)
617
+
618
+ x = self.channelwise_linear_1(x)
619
+ x = self.channelwise_act(x)
620
+ x = self.channelwise_norm(x)
621
+ x = self.channelwise_linear_2(x)
622
+ x = self.channelwise_dropout(x)
623
+
624
+ x = x.permute(0, 3, 1, 2)
625
+
626
+ x = x + x_res
627
+
628
+ scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
629
+ x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
630
+
631
+ return x
632
+
633
+ class Simple_UVitBlock(nn.Module):
634
+ def __init__(
635
+ self,
636
+ channels,
637
+ ln_elementwise_affine,
638
+ layer_norm_eps,
639
+ use_bias,
640
+ downsample: bool,
641
+ upsample: bool,
642
+ ):
643
+ super().__init__()
644
+
645
+ if downsample:
646
+ self.downsample = Downsample2D(
647
+ channels,
648
+ use_conv=True,
649
+ padding=0,
650
+ name="Conv2d_0",
651
+ kernel_size=2,
652
+ norm_type="rms_norm",
653
+ eps=layer_norm_eps,
654
+ elementwise_affine=ln_elementwise_affine,
655
+ bias=use_bias,
656
+ )
657
+ else:
658
+ self.downsample = None
659
+
660
+ if upsample:
661
+ self.upsample = Upsample2D(
662
+ channels,
663
+ use_conv_transpose=True,
664
+ kernel_size=2,
665
+ padding=0,
666
+ name="conv",
667
+ norm_type="rms_norm",
668
+ eps=layer_norm_eps,
669
+ elementwise_affine=ln_elementwise_affine,
670
+ bias=use_bias,
671
+ interpolate=False,
672
+ )
673
+ else:
674
+ self.upsample = None
675
+
676
+ def forward(self, x):
677
+ # print("before,", x.shape)
678
+ if self.downsample is not None:
679
+ # print('downsample')
680
+ x = self.downsample(x)
681
+
682
+ if self.upsample is not None:
683
+ # print('upsample')
684
+ x = self.upsample(x)
685
+ # print("after,", x.shape)
686
+ return x
687
+
688
+ class Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
689
+ """
690
+ A Transformer model based on the architecture introduced in Flux.
691
+
692
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
693
+
694
+ Parameters:
695
+ patch_size (`int`): Patch size to turn the input data into small patches.
696
+ in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
697
+ num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
698
+ num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
699
+ attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
700
+ num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
701
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
702
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
703
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
704
+ """
705
+
706
+ _supports_gradient_checkpointing = False #True
707
+ # Due to NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph.
708
+ # Please refer to this issue - https://github.com/pytorch/pytorch/issues/104674.
709
+ _no_split_modules = ["TransformerBlock", "SingleTransformerBlock"]
710
+
711
+ @register_to_config
712
+ def __init__(
713
+ self,
714
+ patch_size: int = 1,
715
+ in_channels: int = 64,
716
+ num_layers: int = 19,
717
+ num_single_layers: int = 38,
718
+ attention_head_dim: int = 128,
719
+ num_attention_heads: int = 24,
720
+ joint_attention_dim: int = 4096,
721
+ pooled_projection_dim: int = 768,
722
+ guidance_embeds: bool = False, # unused in our implementation
723
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
724
+ vocab_size: int = 8256,
725
+ codebook_size: int = 8192,
726
+ downsample: bool = False,
727
+ upsample: bool = False,
728
+ ):
729
+ super().__init__()
730
+ self.out_channels = in_channels
731
+ self.inner_dim = self.num_attention_heads * self.attention_head_dim
732
+
733
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
734
+ text_time_guidance_cls = (
735
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
736
+ )
737
+ self.time_text_embed = text_time_guidance_cls(
738
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.pooled_projection_dim
739
+ )
740
+
741
+ self.context_embedder = nn.Linear(self.joint_attention_dim, self.inner_dim)
742
+
743
+ self.transformer_blocks = nn.ModuleList(
744
+ [
745
+ TransformerBlock(
746
+ dim=self.inner_dim,
747
+ num_attention_heads=self.num_attention_heads,
748
+ attention_head_dim=self.attention_head_dim,
749
+ )
750
+ for i in range(self.num_layers)
751
+ ]
752
+ )
753
+
754
+ self.single_transformer_blocks = nn.ModuleList(
755
+ [
756
+ SingleTransformerBlock(
757
+ dim=self.inner_dim,
758
+ num_attention_heads=self.num_attention_heads,
759
+ attention_head_dim=self.attention_head_dim,
760
+ )
761
+ for i in range(self.num_single_layers)
762
+ ]
763
+ )
764
+
765
+
766
+ self.gradient_checkpointing = False
767
+
768
+ in_channels_embed = self.inner_dim
769
+ ln_elementwise_affine = True
770
+ layer_norm_eps = 1e-06
771
+ use_bias = False
772
+ micro_cond_embed_dim = 1280
773
+ self.embed = UVit2DConvEmbed(
774
+ in_channels_embed, self.inner_dim, self.vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
775
+ )
776
+ self.mlm_layer = ConvMlmLayer(
777
+ self.inner_dim, in_channels_embed, use_bias, ln_elementwise_affine, layer_norm_eps, self.codebook_size
778
+ )
779
+ self.cond_embed = TimestepEmbedding(
780
+ micro_cond_embed_dim + self.pooled_projection_dim, self.inner_dim, sample_proj_bias=use_bias
781
+ )
782
+ self.encoder_proj_layer_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
783
+ self.project_to_hidden_norm = RMSNorm(in_channels_embed, layer_norm_eps, ln_elementwise_affine)
784
+ self.project_to_hidden = nn.Linear(in_channels_embed, self.inner_dim, bias=use_bias)
785
+ self.project_from_hidden_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
786
+ self.project_from_hidden = nn.Linear(self.inner_dim, in_channels_embed, bias=use_bias)
787
+
788
+ self.down_block = Simple_UVitBlock(
789
+ self.inner_dim,
790
+ ln_elementwise_affine,
791
+ layer_norm_eps,
792
+ use_bias,
793
+ downsample,
794
+ False,
795
+ )
796
+ self.up_block = Simple_UVitBlock(
797
+ self.inner_dim, #block_out_channels,
798
+ ln_elementwise_affine,
799
+ layer_norm_eps,
800
+ use_bias,
801
+ False,
802
+ upsample=upsample,
803
+ )
804
+
805
+ # self.fuse_qkv_projections()
806
+
807
+ @property
808
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
809
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
810
+ r"""
811
+ Returns:
812
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
813
+ indexed by its weight name.
814
+ """
815
+ # set recursively
816
+ processors = {}
817
+
818
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
819
+ if hasattr(module, "get_processor"):
820
+ processors[f"{name}.processor"] = module.get_processor()
821
+
822
+ for sub_name, child in module.named_children():
823
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
824
+
825
+ return processors
826
+
827
+ for name, module in self.named_children():
828
+ fn_recursive_add_processors(name, module, processors)
829
+
830
+ return processors
831
+
832
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
833
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
834
+ r"""
835
+ Sets the attention processor to use to compute attention.
836
+
837
+ Parameters:
838
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
839
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
840
+ for **all** `Attention` layers.
841
+
842
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
843
+ processor. This is strongly recommended when setting trainable attention processors.
844
+
845
+ """
846
+ count = len(self.attn_processors.keys())
847
+
848
+ if isinstance(processor, dict) and len(processor) != count:
849
+ raise ValueError(
850
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
851
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
852
+ )
853
+
854
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
855
+ if hasattr(module, "set_processor"):
856
+ if not isinstance(processor, dict):
857
+ module.set_processor(processor)
858
+ else:
859
+ module.set_processor(processor.pop(f"{name}.processor"))
860
+
861
+ for sub_name, child in module.named_children():
862
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
863
+
864
+ for name, module in self.named_children():
865
+ fn_recursive_attn_processor(name, module, processor)
866
+
867
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
868
+ def fuse_qkv_projections(self):
869
+ """
870
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
871
+ are fused. For cross-attention modules, key and value projection matrices are fused.
872
+
873
+ <Tip warning={true}>
874
+
875
+ This API is 🧪 experimental.
876
+
877
+ </Tip>
878
+ """
879
+ self.original_attn_processors = None
880
+
881
+ for _, attn_processor in self.attn_processors.items():
882
+ if "Added" in str(attn_processor.__class__.__name__):
883
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
884
+
885
+ self.original_attn_processors = self.attn_processors
886
+
887
+ for module in self.modules():
888
+ if isinstance(module, Attention):
889
+ module.fuse_projections(fuse=True)
890
+
891
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
892
+
893
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
894
+ def unfuse_qkv_projections(self):
895
+ """Disables the fused QKV projection if enabled.
896
+
897
+ <Tip warning={true}>
898
+
899
+ This API is 🧪 experimental.
900
+
901
+ </Tip>
902
+
903
+ """
904
+ if self.original_attn_processors is not None:
905
+ self.set_attn_processor(self.original_attn_processors)
906
+
907
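These processor helpers mirror the diffusers UNet API: `attn_processors` walks the module tree and returns a dict keyed by weight name, `set_attn_processor` accepts a single processor or such a dict, and `fuse_qkv_projections` / `unfuse_qkv_projections` toggle the experimental fused projections. A minimal sketch of how they compose is below; `model` is assumed to be an already-constructed instance of this class, so the snippet is illustrative rather than standalone.

```python
# Sketch only: `model` is assumed to be an instantiated instance of the transformer above.
processors = model.attn_processors          # dict keyed by weight name of every Attention layer
print(f"{len(processors)} attention processors registered")

model.fuse_qkv_projections()                # experimental: swaps in FusedFluxAttnProcessor2_0
# ... run inference with fused q/k/v projections ...
model.unfuse_qkv_projections()              # restores the processors captured before fusing
```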
+ def _set_gradient_checkpointing(self, module, value=False):
908
+ if hasattr(module, "gradient_checkpointing"):
909
+ module.gradient_checkpointing = value
910
+
911
+ def forward(
912
+ self,
913
+ hidden_states: torch.Tensor,
914
+ encoder_hidden_states: torch.Tensor = None,
915
+ pooled_projections: torch.Tensor = None,
916
+ timestep: torch.LongTensor = None,
917
+ img_ids: torch.Tensor = None,
918
+ txt_ids: torch.Tensor = None,
919
+ guidance: torch.Tensor = None,
920
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
921
+ controlnet_block_samples= None,
922
+ controlnet_single_block_samples=None,
923
+ return_dict: bool = True,
924
+ micro_conds: torch.Tensor = None,
925
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
926
+ """
927
+ The forward method of this transformer (adapted from [`FluxTransformer2DModel`]).
928
+
929
+ Args:
930
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
931
+ Input `hidden_states`.
932
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
933
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
934
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
935
+ from the embeddings of input conditions.
936
+ timestep ( `torch.LongTensor`):
937
+ Used to indicate denoising step.
938
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
939
+ A list of tensors that if specified are added to the residuals of transformer blocks.
940
+ joint_attention_kwargs (`dict`, *optional*):
941
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
942
+ `self.processor` in
943
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
944
+ return_dict (`bool`, *optional*, defaults to `True`):
945
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
946
+ tuple.
947
+
948
+ Returns:
949
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
950
+ `tuple` where the first element is the sample tensor.
951
+ """
952
+ micro_cond_encode_dim = 256 # same as self.micro_cond_encode_dim = 256 from amused
953
+ micro_cond_embeds = get_timestep_embedding(
954
+ micro_conds.flatten(), micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
955
+ )
956
+ micro_cond_embeds = micro_cond_embeds.reshape((hidden_states.shape[0], -1))
957
+
958
+ pooled_projections = torch.cat([pooled_projections, micro_cond_embeds], dim=1)
959
+ pooled_projections = pooled_projections.to(dtype=self.dtype)
960
+ pooled_projections = self.cond_embed(pooled_projections).to(encoder_hidden_states.dtype)
961
+
962
+
963
+ hidden_states = self.embed(hidden_states)
964
+
965
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
966
+ encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
967
+ hidden_states = self.down_block(hidden_states)
968
+
969
+ batch_size, channels, height, width = hidden_states.shape
970
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
971
+ hidden_states = self.project_to_hidden_norm(hidden_states)
972
+ hidden_states = self.project_to_hidden(hidden_states)
973
+
974
+
975
+ if joint_attention_kwargs is not None:
976
+ joint_attention_kwargs = joint_attention_kwargs.copy()
977
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
978
+ else:
979
+ lora_scale = 1.0
980
+
981
+ if USE_PEFT_BACKEND:
982
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
983
+ scale_lora_layers(self, lora_scale)
984
+ else:
985
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
986
+ logger.warning(
987
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
988
+ )
989
+
990
+ timestep = timestep.to(hidden_states.dtype) * 1000
991
+ if guidance is not None:
992
+ guidance = guidance.to(hidden_states.dtype) * 1000
993
+ else:
994
+ guidance = None
995
+ temb = (
996
+ self.time_text_embed(timestep, pooled_projections)
997
+ if guidance is None
998
+ else self.time_text_embed(timestep, guidance, pooled_projections)
999
+ )
1000
+
1001
+ if txt_ids.ndim == 3:
1002
+ logger.warning(
1003
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
1004
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1005
+ )
1006
+ txt_ids = txt_ids[0]
1007
+ if img_ids.ndim == 3:
1008
+ logger.warning(
1009
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
1010
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1011
+ )
1012
+ img_ids = img_ids[0]
1013
+ ids = torch.cat((txt_ids, img_ids), dim=0)
1014
+
1015
+ image_rotary_emb = self.pos_embed(ids)
1016
+
1017
+ for index_block, block in enumerate(self.transformer_blocks):
1018
+ if self.training and self.gradient_checkpointing:
1019
+
1020
+ def create_custom_forward(module, return_dict=None):
1021
+ def custom_forward(*inputs):
1022
+ if return_dict is not None:
1023
+ return module(*inputs, return_dict=return_dict)
1024
+ else:
1025
+ return module(*inputs)
1026
+
1027
+ return custom_forward
1028
+
1029
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1030
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
1031
+ create_custom_forward(block),
1032
+ hidden_states,
1033
+ encoder_hidden_states,
1034
+ temb,
1035
+ image_rotary_emb,
1036
+ **ckpt_kwargs,
1037
+ )
1038
+
1039
+ else:
1040
+ encoder_hidden_states, hidden_states = block(
1041
+ hidden_states=hidden_states,
1042
+ encoder_hidden_states=encoder_hidden_states,
1043
+ temb=temb,
1044
+ image_rotary_emb=image_rotary_emb,
1045
+ )
1046
+
1047
+
1048
+ # controlnet residual
1049
+ if controlnet_block_samples is not None:
1050
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
1051
+ interval_control = int(np.ceil(interval_control))
1052
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
1053
+
1054
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1055
+
1056
+ for index_block, block in enumerate(self.single_transformer_blocks):
1057
+ if self.training and self.gradient_checkpointing:
1058
+
1059
+ def create_custom_forward(module, return_dict=None):
1060
+ def custom_forward(*inputs):
1061
+ if return_dict is not None:
1062
+ return module(*inputs, return_dict=return_dict)
1063
+ else:
1064
+ return module(*inputs)
1065
+
1066
+ return custom_forward
1067
+
1068
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1069
+ hidden_states = torch.utils.checkpoint.checkpoint(
1070
+ create_custom_forward(block),
1071
+ hidden_states,
1072
+ temb,
1073
+ image_rotary_emb,
1074
+ **ckpt_kwargs,
1075
+ )
1076
+
1077
+ else:
1078
+ hidden_states = block(
1079
+ hidden_states=hidden_states,
1080
+ temb=temb,
1081
+ image_rotary_emb=image_rotary_emb,
1082
+ )
1083
+
1084
+ # controlnet residual
1085
+ if controlnet_single_block_samples is not None:
1086
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
1087
+ interval_control = int(np.ceil(interval_control))
1088
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
1089
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1090
+ + controlnet_single_block_samples[index_block // interval_control]
1091
+ )
1092
+
1093
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1094
+
1095
+
1096
+ hidden_states = self.project_from_hidden_norm(hidden_states)
1097
+ hidden_states = self.project_from_hidden(hidden_states)
1098
+
1099
+
1100
+ hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
1101
+
1102
+ hidden_states = self.up_block(hidden_states)
1103
+
1104
+ if USE_PEFT_BACKEND:
1105
+ # remove `lora_scale` from each PEFT layer
1106
+ unscale_lora_layers(self, lora_scale)
1107
+
1108
+ output = self.mlm_layer(hidden_states)
1109
+ # self.unfuse_qkv_projections()
1110
+ if not return_dict:
1111
+ return (output,)
1112
+
1113
+
1114
+ return output
1115
+
1116
+
1117
+ class SymmetricTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
1118
+ """
1119
+ A symmetric image-text transformer adapted from the Transformer model introduced in Flux.
1120
+
1121
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
1122
+
1123
+ Parameters:
1124
+ patch_size (`int`): Patch size to turn the input data into small patches.
1125
+ in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
+ num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
+ num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
+ attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
+ num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
1130
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
1131
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
1132
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
1133
+ """
1134
+
1135
+ _supports_gradient_checkpointing = False #True
1136
+ # Due to NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph.
1137
+ # Please refer to this issue - https://github.com/pytorch/pytorch/issues/104674.
1138
+ _no_split_modules = ["TransformerBlock", "SingleTransformerBlock"]
1139
+
1140
+ @register_to_config
1141
+ def __init__(
1142
+ self,
1143
+ patch_size: int = 1,
1144
+ in_channels: int = 64,
1145
+ num_layers: int = 19,
1146
+ num_single_layers: int = 38,
1147
+ attention_head_dim: int = 128,
1148
+ num_attention_heads: int = 24,
1149
+ joint_attention_dim: int = 4096,
1150
+ pooled_projection_dim: int = 768,
1151
+ guidance_embeds: bool = False, # unused in our implementation
1152
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
1153
+ vocab_size: int = 8256,
1154
+ codebook_size: int = 8192,
1155
+ tokenizer_vocab_size: Optional[int] = None,
1156
+ t5_dim: Optional[int] = None,
1157
+ downsample: bool = False,
1158
+ upsample: bool = False,
1159
+ ):
1160
+ super().__init__()
1161
+ self.out_channels = in_channels
1162
+ self.inner_dim = self.num_attention_heads * self.attention_head_dim
1163
+
1164
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
1165
+ text_time_guidance_cls = (
1166
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
1167
+ )
1168
+ self.time_text_embed = text_time_guidance_cls(
1169
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.inner_dim
1170
+ )
1171
+
1172
+ if t5_dim is not None:
1173
+ self.adapter = nn.Sequential(
1174
+ nn.LayerNorm(t5_dim, elementwise_affine=False, eps=1e-6),
1175
+ nn.Linear(t5_dim, self.joint_attention_dim, bias=False)
1176
+ )
1177
+ else:
1178
+ self.adapter = None
1179
+
1180
+ self.context_embedder = nn.Linear(self.joint_attention_dim, self.inner_dim)
1181
+
1182
+ self.transformer_blocks = nn.ModuleList(
1183
+ [
1184
+ TransformerBlock(
1185
+ dim=self.inner_dim,
1186
+ num_attention_heads=self.num_attention_heads,
1187
+ attention_head_dim=self.attention_head_dim,
1188
+ )
1189
+ for i in range(self.num_layers)
1190
+ ]
1191
+ )
1192
+
1193
+ self.single_transformer_blocks = nn.ModuleList(
1194
+ [
1195
+ SingleTransformerBlock(
1196
+ dim=self.inner_dim,
1197
+ num_attention_heads=self.num_attention_heads,
1198
+ attention_head_dim=self.attention_head_dim,
1199
+ )
1200
+ for i in range(self.num_single_layers)
1201
+ ]
1202
+ )
1203
+
1204
+ self.gradient_checkpointing = False
1205
+
1206
+ in_channels_embed = self.inner_dim
1207
+ ln_elementwise_affine = True
1208
+ layer_norm_eps = 1e-06
1209
+ use_bias = False
1210
+ micro_cond_embed_dim = 1280
1211
+ self.embed = UVit2DConvEmbed(
1212
+ in_channels_embed, self.inner_dim, self.vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
1213
+ )
1214
+ self.mlm_layer = ConvMlmLayer(
1215
+ self.inner_dim, in_channels_embed, use_bias, ln_elementwise_affine, layer_norm_eps, self.codebook_size
1216
+ )
1217
+ self.cond_embed = TimestepEmbedding(
1218
+ micro_cond_embed_dim + self.pooled_projection_dim, self.inner_dim, sample_proj_bias=use_bias
1219
+ )
1220
+ self.encoder_proj_layer_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
1221
+ self.project_to_hidden_norm = RMSNorm(in_channels_embed, layer_norm_eps, ln_elementwise_affine)
1222
+ self.project_to_hidden = nn.Linear(in_channels_embed, self.inner_dim, bias=use_bias)
1223
+ self.project_from_hidden_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
1224
+ self.project_from_hidden = nn.Linear(self.inner_dim, in_channels_embed, bias=use_bias)
1225
+
1226
+ self.down_block = Simple_UVitBlock(
1227
+ self.inner_dim,
1228
+ ln_elementwise_affine,
1229
+ layer_norm_eps,
1230
+ use_bias,
1231
+ downsample,
1232
+ False,
1233
+ )
1234
+ self.up_block = Simple_UVitBlock(
1235
+ self.inner_dim,
1236
+ ln_elementwise_affine,
1237
+ layer_norm_eps,
1238
+ use_bias,
1239
+ False,
1240
+ upsample=upsample,
1241
+ )
1242
+
1243
+ if tokenizer_vocab_size is not None:
1244
+ self.text_decoder = nn.Sequential(
1245
+ nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6),
1246
+ nn.Linear(self.inner_dim, tokenizer_vocab_size, bias=use_bias)
1247
+ )
1248
+ else:
1249
+ self.text_decoder = None
1250
+
1251
+
1252
+ def forward(
1253
+ self,
1254
+ hidden_states: torch.Tensor,
1255
+ encoder_hidden_states: torch.Tensor = None,
1256
+ pooled_projections: torch.Tensor = None,
1257
+ timestep: torch.LongTensor = None,
1258
+ img_ids: torch.Tensor = None,
1259
+ txt_ids: torch.Tensor = None,
1260
+ guidance: torch.Tensor = None,
1261
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1262
+ controlnet_block_samples= None,
1263
+ controlnet_single_block_samples=None,
1264
+ return_dict: bool = True,
1265
+ micro_conds: torch.Tensor = None,
1266
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
1267
+ """
1268
+ The forward method of this transformer (adapted from [`FluxTransformer2DModel`]).
1269
+
1270
+ Args:
1271
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
1272
+ Input `hidden_states`.
1273
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
1274
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
1275
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
1276
+ from the embeddings of input conditions.
1277
+ timestep ( `torch.LongTensor`):
1278
+ Used to indicate denoising step.
1279
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
1280
+ A list of tensors that if specified are added to the residuals of transformer blocks.
1281
+ joint_attention_kwargs (`dict`, *optional*):
1282
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1283
+ `self.processor` in
1284
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1285
+ return_dict (`bool`, *optional*, defaults to `True`):
1286
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
1287
+ tuple.
1288
+
1289
+ Returns:
1290
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
1291
+ `tuple` where the first element is the sample tensor.
1292
+ """
1293
+ micro_cond_encode_dim = 256 # same as self.micro_cond_encode_dim = 256 from amused
1294
+ micro_cond_embeds = get_timestep_embedding(
1295
+ micro_conds.flatten(), micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
1296
+ )
1297
+ micro_cond_embeds = micro_cond_embeds.reshape((hidden_states.shape[0], -1))
1298
+
1299
+ if self.adapter is not None:
1300
+ encoder_hidden_states = self.adapter(encoder_hidden_states)
1301
+
1302
+ pooled_projections = torch.cat([pooled_projections, micro_cond_embeds], dim=1)
1303
+ pooled_projections = pooled_projections.to(dtype=self.dtype)
1304
+ pooled_projections = self.cond_embed(pooled_projections).to(encoder_hidden_states.dtype)
1305
+
1306
+ hidden_states = self.embed(hidden_states)
1307
+
1308
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
1309
+ encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
1310
+ hidden_states = self.down_block(hidden_states)
1311
+
1312
+ batch_size, channels, height, width = hidden_states.shape
1313
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
1314
+ hidden_states = self.project_to_hidden_norm(hidden_states)
1315
+ hidden_states = self.project_to_hidden(hidden_states)
1316
+
1317
+
1318
+ if joint_attention_kwargs is not None:
1319
+ joint_attention_kwargs = joint_attention_kwargs.copy()
1320
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
1321
+ else:
1322
+ lora_scale = 1.0
1323
+
1324
+ if USE_PEFT_BACKEND:
1325
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1326
+ scale_lora_layers(self, lora_scale)
1327
+ else:
1328
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
1329
+ logger.warning(
1330
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
1331
+ )
1332
+
1333
+ timestep = timestep.to(hidden_states.dtype) * 1000
1334
+ if guidance is not None:
1335
+ guidance = guidance.to(hidden_states.dtype) * 1000
1336
+ else:
1337
+ guidance = None
1338
+ temb = (
1339
+ self.time_text_embed(timestep, pooled_projections)
1340
+ if guidance is None
1341
+ else self.time_text_embed(timestep, guidance, pooled_projections)
1342
+ )
1343
+
1344
+ if txt_ids.ndim == 3:
1345
+ logger.warning(
1346
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
1347
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1348
+ )
1349
+ txt_ids = txt_ids[0]
1350
+ if img_ids.ndim == 3:
1351
+ logger.warning(
1352
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
1353
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1354
+ )
1355
+ img_ids = img_ids[0]
1356
+ ids = torch.cat((txt_ids, img_ids), dim=0)
1357
+
1358
+ image_rotary_emb = self.pos_embed(ids)
1359
+
1360
+ for index_block, block in enumerate(self.transformer_blocks):
1361
+ if self.training and self.gradient_checkpointing:
1362
+
1363
+ def create_custom_forward(module, return_dict=None):
1364
+ def custom_forward(*inputs):
1365
+ if return_dict is not None:
1366
+ return module(*inputs, return_dict=return_dict)
1367
+ else:
1368
+ return module(*inputs)
1369
+
1370
+ return custom_forward
1371
+
1372
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1373
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
1374
+ create_custom_forward(block),
1375
+ hidden_states,
1376
+ encoder_hidden_states,
1377
+ temb,
1378
+ image_rotary_emb,
1379
+ **ckpt_kwargs,
1380
+ )
1381
+
1382
+ else:
1383
+ encoder_hidden_states, hidden_states = block(
1384
+ hidden_states=hidden_states,
1385
+ encoder_hidden_states=encoder_hidden_states,
1386
+ temb=temb,
1387
+ image_rotary_emb=image_rotary_emb,
1388
+ )
1389
+
1390
+
1391
+ # controlnet residual
1392
+ if controlnet_block_samples is not None:
1393
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
1394
+ interval_control = int(np.ceil(interval_control))
1395
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
1396
+
1397
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1398
+
1399
+ for index_block, block in enumerate(self.single_transformer_blocks):
1400
+ if self.training and self.gradient_checkpointing:
1401
+
1402
+ def create_custom_forward(module, return_dict=None):
1403
+ def custom_forward(*inputs):
1404
+ if return_dict is not None:
1405
+ return module(*inputs, return_dict=return_dict)
1406
+ else:
1407
+ return module(*inputs)
1408
+
1409
+ return custom_forward
1410
+
1411
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1412
+ hidden_states = torch.utils.checkpoint.checkpoint(
1413
+ create_custom_forward(block),
1414
+ hidden_states,
1415
+ temb,
1416
+ image_rotary_emb,
1417
+ **ckpt_kwargs,
1418
+ )
1419
+
1420
+ else:
1421
+ hidden_states = block(
1422
+ hidden_states=hidden_states,
1423
+ temb=temb,
1424
+ image_rotary_emb=image_rotary_emb,
1425
+ )
1426
+
1427
+ # controlnet residual
1428
+ if controlnet_single_block_samples is not None:
1429
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
1430
+ interval_control = int(np.ceil(interval_control))
1431
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
1432
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1433
+ + controlnet_single_block_samples[index_block // interval_control]
1434
+ )
1435
+
1436
+ encoder_hidden_states = hidden_states[:, :encoder_hidden_states.shape[1], ...]
1437
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:, ...]
1438
+
1439
+ if self.text_decoder is not None:
1440
+ encoder_hidden_states = self.text_decoder(encoder_hidden_states)
1441
+
1442
+ hidden_states = self.project_from_hidden_norm(hidden_states)
1443
+ hidden_states = self.project_from_hidden(hidden_states)
1444
+
1445
+ hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
1446
+
1447
+ hidden_states = self.up_block(hidden_states)
1448
+
1449
+ if USE_PEFT_BACKEND:
1450
+ # remove `lora_scale` from each PEFT layer
1451
+ unscale_lora_layers(self, lora_scale)
1452
+
1453
+ output = self.mlm_layer(hidden_states)
1454
+ # self.unfuse_qkv_projections()
1455
+ if not return_dict:
1456
+ return (output, encoder_hidden_states)
1457
+
1458
+
1459
+ return output, encoder_hidden_states # [b, l, tokenizer_vocab_size]
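For orientation, the sketch below exercises this symmetric forward pass end to end with deliberately tiny, made-up dimensions (layer counts, head sizes, vocab sizes, and sequence lengths are all illustrative assumptions, not the shipped configuration). It relies on the `src.transformer` import path used by `train/instruction_tuning.py` further down, and on the helper blocks defined earlier in this file accepting these shapes: the image branch takes a grid of VQ token ids and returns codebook logits, while the text branch is decoded to `tokenizer_vocab_size` logits.

```python
# Hedged smoke test with toy dimensions; not the released Muddit configuration.
import torch
from src.transformer import SymmetricTransformer2DModel  # same import as train/instruction_tuning.py

model = SymmetricTransformer2DModel(
    num_layers=1, num_single_layers=1,
    num_attention_heads=4, attention_head_dim=32,
    axes_dims_rope=(8, 12, 12),              # must sum to attention_head_dim
    joint_attention_dim=64, pooled_projection_dim=64,
    vocab_size=32, codebook_size=16, tokenizer_vocab_size=100,
)

B, H, W, L = 1, 8, 8, 7
image_tokens = torch.randint(0, 32, (B, H, W))               # VQ token ids fed to the conv embed
text_feats = torch.randn(B, L, 64)                           # encoder hidden states (joint_attention_dim)
pooled = torch.randn(B, 64)                                  # pooled text embedding (pooled_projection_dim)
micro_conds = torch.tensor([[512.0, 512.0, 0.0, 0.0, 6.0]])  # orig_w, orig_h, crop_top, crop_left, hps_score

image_logits, text_logits = model(
    hidden_states=image_tokens,
    encoder_hidden_states=text_feats,
    pooled_projections=pooled,
    timestep=torch.tensor([0.5]),
    img_ids=torch.zeros(H * W, 3),
    txt_ids=torch.zeros(L, 3),
    micro_conds=micro_conds,
)
print(image_logits.shape)  # expected: (B, codebook_size, H, W)
print(text_logits.shape)   # expected: (B, L, tokenizer_vocab_size)
```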
train/__pycache__/dataset_utils.cpython-39.pyc ADDED
Binary file (12.5 kB)

train/__pycache__/trainer_utils.cpython-310.pyc ADDED
Binary file (2.8 kB)

train/__pycache__/trainer_utils.cpython-39.pyc ADDED
Binary file (2.81 kB)

train/dataset_utils.py ADDED
@@ -0,0 +1,472 @@
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import torch
17
+ from torch.utils.data import Dataset
18
+ from torchvision import transforms
19
+ from PIL.ImageOps import exif_transpose
20
+ from PIL import Image
21
+ import io
22
+ import json
23
+ import numpy as np
24
+ import pyarrow.parquet as pq
25
+ import random
26
+ import bisect
27
+ import pyarrow.fs as fs
28
+
29
+
30
+ @torch.no_grad()
31
+ def tokenize_prompt(
32
+ tokenizer,
33
+ prompt,
34
+ text_encoder_architecture='open_clip',
35
+ padding='max_length',
36
+ max_length=77,
37
+ max_length_t5=256,
38
+ ):
39
+ if text_encoder_architecture == 'CLIP' or text_encoder_architecture == 'open_clip':
40
+ input_ids = tokenizer(
41
+ prompt,
42
+ truncation=True,
43
+ padding=padding,
44
+ max_length=max_length,
45
+ return_tensors="pt",
46
+ ).input_ids
47
+ return input_ids
48
+ elif text_encoder_architecture == 't5_clip': # we have two tokenizers, 1st for CLIP, 2nd for T5
49
+ input_ids = []
50
+ input_ids.append(tokenizer[0](
51
+ prompt,
52
+ truncation=True,
53
+ padding=padding,
54
+ max_length=max_length,
55
+ return_tensors="pt",
56
+ ).input_ids)
57
+ input_ids.append(tokenizer[1](
58
+ prompt,
59
+ truncation=True,
60
+ padding=padding,
61
+ max_length=max_length_t5,
62
+ return_tensors="pt",
63
+ ).input_ids)
64
+ return input_ids
65
+ elif text_encoder_architecture == "gemma":
66
+ input_ids = []
67
+ input_ids.append(tokenizer[0](
68
+ prompt,
69
+ truncation=True,
70
+ padding=padding,
71
+ max_length=max_length,
72
+ return_tensors="pt",
73
+ ).input_ids)
74
+ input_ids.append(tokenizer[1](
75
+ prompt,
76
+ truncation=True,
77
+ padding=padding,
78
+ max_length=max_length_t5,
79
+ return_tensors="pt",
80
+ ).input_ids)
81
+ return input_ids
82
+ else:
83
+ raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")
84
+
85
+
86
+ def encode_prompt(
87
+ text_encoder,
88
+ input_ids,
89
+ text_encoder_architecture='open_clip'
90
+ ):
91
+ if text_encoder_architecture == 'CLIP' or text_encoder_architecture == 'open_clip':
92
+ outputs = text_encoder(input_ids=input_ids, return_dict=True, output_hidden_states=True)
93
+ encoder_hidden_states = outputs.hidden_states[-2]
94
+ cond_embeds = outputs[0]
95
+ return encoder_hidden_states, cond_embeds
96
+ elif text_encoder_architecture == 't5_clip':
97
+ outputs_clip = text_encoder[0](
98
+ input_ids=input_ids[0],
99
+ return_dict=True,
100
+ output_hidden_states=True
101
+ )
102
+ outputs_t5 = text_encoder[1](
103
+ input_ids=input_ids[1],
104
+ return_dict=True,
105
+ output_hidden_states=True
106
+ )
107
+ encoder_hidden_states = outputs_t5.last_hidden_state
108
+ cond_embeds = outputs_clip.text_embeds
109
+ return encoder_hidden_states, cond_embeds
110
+ elif text_encoder_architecture == "gemma":
111
+ outputs_clip = text_encoder[0](
112
+ input_ids=input_ids[0],
113
+ return_dict=True,
114
+ output_hidden_states=True
115
+ )
116
+ outputs_gemma = text_encoder[1](
117
+ input_ids=input_ids[1],
118
+ return_dict=True,
119
+ output_hidden_states=True
120
+ )
121
+ encoder_hidden_states = outputs_gemma.last_hidden_state
122
+ cond_embeds = outputs_clip.text_embeds
123
+ return encoder_hidden_states, cond_embeds
124
+ else:
125
+ raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")
126
+
127
+
128
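To make the intended composition of `tokenize_prompt` and `encode_prompt` concrete, here is a hedged sketch of the single-encoder `'CLIP'`/`'open_clip'` path; the checkpoint id is purely illustrative, not necessarily the one used for training.

```python
# Illustrative only: the checkpoint id is an assumption; any CLIPTextModelWithProjection behaves the same way here.
import torch
from transformers import CLIPTokenizer, CLIPTextModelWithProjection

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

input_ids = tokenize_prompt(tokenizer, ["a cat sitting on a windowsill"], text_encoder_architecture="CLIP")
with torch.no_grad():
    encoder_hidden_states, cond_embeds = encode_prompt(text_encoder, input_ids, "CLIP")

print(encoder_hidden_states.shape)  # (1, 77, hidden_size): penultimate hidden states for cross-conditioning
print(cond_embeds.shape)            # (1, projection_dim): pooled text embedding used as pooled_projections
```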
+ def process_image(image, size, Norm=False, hps_score=6.0):
129
+ image = exif_transpose(image)
130
+
131
+ if not image.mode == "RGB":
132
+ image = image.convert("RGB")
133
+
134
+ orig_height = image.height
135
+ orig_width = image.width
136
+
137
+ image = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)(image)
138
+
139
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size))
140
+ image = transforms.functional.crop(image, c_top, c_left, size, size)
141
+ image = transforms.ToTensor()(image)
142
+
143
+ if Norm:
144
+ image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
145
+
146
+ micro_conds = torch.tensor(
147
+ [orig_width, orig_height, c_top, c_left, hps_score],
148
+ )
149
+
150
+ return {"image": image, "micro_conds": micro_conds}
151
+
152
+
153
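A tiny usage sketch of `process_image`, mainly to make the returned micro-conditioning vector explicit; the synthetic grey image is just for illustration.

```python
from PIL import Image

sample = process_image(Image.new("RGB", (640, 480), "gray"), size=256)
print(sample["image"].shape)   # torch.Size([3, 256, 256])
print(sample["micro_conds"])   # tensor([640., 480., c_top, c_left, 6.0]): original size, crop offsets, default hps_score
```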
+ class ImageCaptionLargeDataset(Dataset):
154
+ def __init__(
155
+ self,
156
+ root_dir,
157
+ tokenizer,
158
+ size,
159
+ text_encoder_architecture="CLIP",
160
+ norm=False
161
+ ):
162
+ self.root_dir = root_dir
163
+ self.tokenizer = tokenizer
164
+ self.size = size
165
+ self.text_encoder_architecture = text_encoder_architecture
166
+ self.norm = norm
167
+
168
+ self.data_list = []
169
+ for root, dirnames, filenames in os.walk(root_dir):
170
+ for filename in filenames:
171
+ if filename.endswith(".jpg") or filename.endswith(".png"):
172
+ base_name = os.path.splitext(filename)[0]
173
+ txt_file = os.path.join(root, base_name + ".txt")
174
+ if os.path.exists(txt_file):
175
+ self.data_list.append((root, base_name + ".txt", filename))
176
+
177
+ def __len__(self):
178
+ return len(self.data_list)
179
+
180
+ def __getitem__(self, idx):
181
+ try:
182
+ sub_dir, txtfilename, imgfilename = self.data_list[idx]
183
+ img_path = os.path.join(sub_dir, imgfilename)
184
+ caption_path = os.path.join(sub_dir, txtfilename)
185
+
186
+ image = Image.open(img_path).convert("RGB")
187
+ ret = process_image(image, self.size, self.norm)
188
+
189
+ with open(caption_path, "r", encoding="utf-8") as f:
190
+ caption = f.read().strip()
191
+
192
+ ret["prompt_input_ids"] = tokenize_prompt(
193
+ self.tokenizer, caption, self.text_encoder_architecture
194
+ )
195
+
196
+ return ret
197
+
198
+ except Exception as e:
+ print("===========================================")
+ print(f"[Warning] Error at index {idx}: {self.data_list[idx]} ({e})")
+ print("===========================================")
+ # skip the broken sample; wrap around so the last index cannot recurse onto itself
+ return self.__getitem__((idx + 1) % len(self.data_list))
206
+
207
+
208
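For completeness, a hedged sketch of wiring `ImageCaptionLargeDataset` into a `DataLoader`; the root directory and tokenizer checkpoint are placeholders, and the directory is expected to hold matching image/.txt pairs exactly as scanned in `__init__`.

```python
from torch.utils.data import DataLoader
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")   # illustrative checkpoint
dataset = ImageCaptionLargeDataset(
    root_dir="/path/to/image_caption_pairs",   # hypothetical folder of paired .jpg/.png and .txt files
    tokenizer=tokenizer,
    size=512,
    text_encoder_architecture="CLIP",
)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
print(batch["image"].shape)              # (4, 3, 512, 512)
print(batch["micro_conds"].shape)        # (4, 5)
print(batch["prompt_input_ids"].shape)   # (4, 1, 77) with the default CLIP max_length
```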
+ class MultiSourceVLDataset(Dataset):
209
+ """
210
+ A unified dataloader mixing understanding (VQA/captioning) and generation samples from
+ • LLaVA-Instruct-150K
+ • MMMU (multiple-choice QA)
+ • VQAv2
+ • GQA
+ • COCO 2014 captions / QA
+ • MG-LLaVA
+ • Local image/.txt caption folders (`caption_dir` for understanding, `pdd3_dir` for generation)
215
+ """
216
+
217
+ def __init__(
218
+ self,
219
+ tokenizer,
220
+ size: int,
221
+ text_encoder_architecture: str = "CLIP",
222
+ norm: bool = False,
223
+ # ----- paths -----
224
+ llava_json: str = None, llava_img_root: str = None,
225
+ mmmu_json: str = None, mmmu_img_root: str = None,
226
+ vqa_ann_json: str = None, vqa_img_root: str = None,
227
+ gqa_json: str = None, gqa_img_root: str = None,
228
+ coco_json: str = None, coco_img_root: str = None,
229
+ coco_qa_json: str = None,
230
+ mg_llava_json: str = None, mg_llava_root: str = None,
231
+ pdd3_dir: str = None, caption_dir: str = None,
232
+ ):
233
+ self.tokenizer = tokenizer
234
+ self.size = size
235
+ self.arch = text_encoder_architecture
236
+ self.norm = norm
237
+
238
+ self.gen_samples = [] # [(img_path, prompt), ...]
239
+ self.mmu_samples = [] # [(img_path, question, answer), ...]
240
+
241
+ if llava_json:
242
+ self._load_llava(llava_json, llava_img_root)
243
+ if mmmu_json:
244
+ self._load_mmmu(mmmu_json, mmmu_img_root)
245
+ if vqa_ann_json:
246
+ self._load_vqav2(vqa_ann_json, vqa_img_root)
247
+ if coco_json:
248
+ self._load_coco2014_captions(coco_json, coco_img_root)
249
+ if coco_qa_json:
250
+ self._load_coco2014_qa(coco_qa_json, coco_img_root)
251
+ if gqa_json:
252
+ self._load_gqa(gqa_json, gqa_img_root)
253
+ if mg_llava_json:
254
+ self._load_mg_llava(mg_llava_json, mg_llava_root)
255
+ if caption_dir:
256
+ self._load_caption(caption_dir)
257
+ if pdd3_dir:
258
+ self._load_pdd3(pdd3_dir)
259
+
260
+ self.len_mmu = len(self.mmu_samples)
261
+ self.len_gen = len(self.gen_samples)
262
+
263
+ # ------------------------------------------------------------------ #
264
+ # dataset parsers #
265
+ # ------------------------------------------------------------------ #
266
+ def _load_llava(self, json_path, img_root):
267
+ with open(json_path, "r", encoding="utf-8") as f:
268
+ data = json.load(f)
269
+
270
+ for ex in data:
271
+ img_file = os.path.join(img_root, ex["image"])
272
+
273
+ human_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "human")
274
+ gpt_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "gpt")
275
+
276
+ self.mmu_samples.append((img_file, human_msg.strip(), gpt_msg.strip()))
277
+
278
+ def _load_mmmu(self, json_path, img_root):
279
+ with open(json_path, "r", encoding="utf-8") as f:
280
+ data = json.load(f)
281
+
282
+ for ex in data:
283
+ img_file = os.path.join(img_root, ex["image"])
284
+ choices = "\n".join([f"{chr(65+i)}. {c}" for i, c in enumerate(ex["choices"])])
285
+
286
+ question = f"{ex['question'].strip()}\n{choices}"
287
+ answer = f"{ex['answer']}"
288
+
289
+ self.mmu_samples.append((img_file, question, answer))
290
+
291
+ def _load_coco2014_qa(self, ann_jsonl, img_root):
292
+ with open(ann_jsonl, "r", encoding="utf-8") as file:
293
+ data = [json.loads(line) for line in file if line.strip()]
294
+
295
+ for ann in data:
296
+ image = ann["image"]
297
+ question = ann["question"]
298
+ answer = ann["label"]
299
+
300
+ image_path = os.path.join(img_root, image)
301
+ self.mmu_samples.append((image_path, question, answer))
302
+
303
+
304
+ def _load_coco2014_captions(self, ann_json, img_root):
305
+ """
306
+ Load COCO 2014 image-caption pairs from caption annotation file.
307
+
308
+ Args:
309
+ ann_json (str): Path to COCO-style captions JSON (e.g., captions_train2014.json)
310
+ img_root (str): Directory containing COCO images (should include 'train2014/' and 'val2014/' subdirs)
311
+ """
312
+ with open(ann_json, "r") as f:
313
+ data = json.load(f)
314
+
315
+ is_train = "train" in os.path.basename(ann_json).lower()
316
+ img_subdir = "train2014" if is_train else "val2014"
317
+ prefix = "COCO_train2014_" if is_train else "COCO_val2014_"
318
+
319
+ for ann in data["annotations"]:
320
+ image_id = ann["image_id"]
321
+ caption = ann["caption"]
322
+
323
+ image_filename = f"{prefix}{image_id:012d}.jpg"
324
+ image_path = os.path.join(img_root, img_subdir, image_filename)
325
+
326
+ question = "Please describe this image concisely."
327
+ self.mmu_samples.append((image_path, question, caption))
328
+
329
+ def _load_vqav2(self, ann_json, img_root):
330
+ with open(ann_json, "r") as file:
331
+ annos = json.load(file)
332
+
333
+ for ann in annos:
334
+ q = ann["question"]
335
+ answer = ann["answer"]
336
+ img_path = ann["image"]
337
+ img_file = os.path.join(
338
+ img_root,
339
+ img_path # if val, modify to val2014
340
+ )
341
+
342
+ self.mmu_samples.append((img_file, q, answer))
343
+
344
+ def _load_gqa(self, ann_json_root, img_root):
345
+ annos = {}
346
+
347
+ for jsonfile in os.listdir(ann_json_root):
348
+ jsonpath = os.path.join(ann_json_root, jsonfile)
349
+ with open(jsonpath, "r") as file:
350
+ anno = json.load(file)
351
+ annos.update(anno)
352
+
353
+ for ann in annos.values():
354
+ q = ann["question"]
355
+ answer = ann["fullAnswer"]
356
+ img_name = ann["imageId"] + ".jpg"
357
+ img_path = os.path.join(
358
+ img_root,
359
+ img_name
360
+ )
361
+
362
+ self.mmu_samples.append((img_path, q, answer))
363
+
364
+ def _load_mg_llava(self, json_path, img_root):
365
+ with open(json_path, "r", encoding="utf-8") as f:
366
+ data = json.load(f)
367
+
368
+ for ex in data:
369
+ image = ex.get("image", None)
370
+ if image is not None:
371
+ img_file = os.path.join(img_root, ex["image"])
372
+ if os.path.exists(img_file):
373
+ human_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "human")
374
+ gpt_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "gpt")
375
+
376
+ self.mmu_samples.append((img_file, human_msg.strip(), gpt_msg.strip()))
377
+
378
+ def _load_caption(self, root_dir):
379
+ for root, _, files in os.walk(root_dir):
380
+ for f in files:
381
+ if f.lower().endswith((".jpg", ".png")):
382
+ base = os.path.splitext(f)[0]
383
+ txt_path = os.path.join(root, base + ".txt")
384
+ if os.path.exists(txt_path):
385
+ with open(txt_path, "r") as file:
386
+ caption = file.read().strip()
387
+ q = "Please describe this image."
388
+ self.mmu_samples.append((os.path.join(root, f), q, caption))
389
+
390
+ def _load_pdd3(self, root_dir):
391
+ for root, _, files in os.walk(root_dir):
392
+ for f in files:
393
+ if f.lower().endswith((".jpg", ".png")):
394
+ base = os.path.splitext(f)[0]
395
+ txt_path = os.path.join(root, base + ".txt")
396
+ if os.path.exists(txt_path):
397
+ with open(txt_path, "r") as file:
398
+ caption = file.read().strip()
399
+ self.gen_samples.append((os.path.join(root, f), caption))
400
+
401
+ # ------------------------------------------------------------------ #
402
+ # PyTorch Dataset API #
403
+ # ------------------------------------------------------------------ #
404
+ def __len__(self):
405
+ return max(self.len_gen, self.len_mmu)
406
+
407
+ def __getitem__(self, idx):
408
+ get_mmu_data = False
409
+ get_gen_data = False
410
+
411
+ while not get_mmu_data:
412
+ try:
413
+ mmu_img_path, question, answer = self.mmu_samples[idx]
414
+ get_mmu_data = True
415
+ except IndexError:
416
+ idx = random.randint(0, self.len_mmu - 1)
417
+
418
+ while not get_gen_data:
419
+ try:
420
+ gen_img_path, prompt = self.gen_samples[idx]
421
+ get_gen_data = True
422
+ except IndexError:
423
+ idx = random.randint(0, self.len_gen - 1)
424
+
425
+ try:
426
+ # ---- image ----
427
+ mmu_image = Image.open(mmu_img_path).convert("RGB")
428
+ mmu_ret = process_image(mmu_image, self.size, self.norm)
429
+
430
+ gen_image = Image.open(gen_img_path).convert("RGB")
431
+ gen_ret = process_image(gen_image, self.size, self.norm)
432
+
433
+ ret = dict(
434
+ gen_image=gen_ret["image"],
435
+ gen_micro_conds=gen_ret["micro_conds"],
436
+ mmu_image=mmu_ret["image"],
437
+ mmu_micro_conds=mmu_ret["micro_conds"]
438
+ )
439
+
440
+ # ---- text ----
441
+ question = question.replace("<image>", "").replace("\n", "")
442
+ question_ids = tokenize_prompt(
443
+ self.tokenizer,
444
+ question,
445
+ self.arch,
446
+ padding=False,
447
+ )
448
+ question_ids = question_ids[:, :-1]
449
+ q_len = len(question_ids[0])
450
+ if answer:
451
+ full_prompt = question + " " + answer
452
+ else:
453
+ full_prompt = question
454
+ mmu_input_ids = tokenize_prompt(self.tokenizer, full_prompt, self.arch)
455
+
456
+ gen_input_ids = tokenize_prompt(self.tokenizer, prompt, self.arch)
457
+
458
+ ret.update({
459
+ "gen_input_ids": gen_input_ids,
460
+ "mmu_input_ids": mmu_input_ids,
461
+ "question_len": torch.LongTensor([q_len])
462
+ })
463
+ return ret
464
+ except Exception as e:
+ print("================================================================")
+ print(f"There is something wrong with {mmu_img_path} or {gen_img_path}: {e}")
+ print("================================================================")
468
+ if idx < self.len_gen - 1 or idx < self.len_mmu - 1:
469
+ return self.__getitem__(idx + 1)
470
+ else:
471
+ idx = random.randint(0, self.len_gen - 1)
472
+ return self.__getitem__(idx)
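Below is a hedged sketch of how this multi-source dataset might be constructed and what a single item contains; every path is a placeholder, and any subset of the sources can be passed.

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")   # illustrative checkpoint
dataset = MultiSourceVLDataset(
    tokenizer=tokenizer,
    size=512,
    text_encoder_architecture="CLIP",
    llava_json="/path/to/llava_instruct_150k.json",    # hypothetical paths
    llava_img_root="/path/to/llava_images",
    pdd3_dir="/path/to/pdd3",                          # supplies the text-to-image (gen) samples
)

item = dataset[0]
# understanding branch: mmu_image, mmu_micro_conds, mmu_input_ids, question_len
# generation branch:    gen_image, gen_micro_conds, gen_input_ids
print(sorted(item.keys()))
```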
train/instruction_tuning.py ADDED
@@ -0,0 +1,1246 @@
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import copy
17
+ import logging
18
+ import math
19
+ import os
20
+ from pathlib import Path
21
+ import sys
22
+ sys.path.append(os.getcwd())
23
+ import json
24
+ import gc
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ from torch import nn
29
+
30
+ from accelerate import Accelerator
31
+ from accelerate.logging import get_logger
32
+ from accelerate.utils import ProjectConfiguration, set_seed
33
+ from peft import LoraConfig
34
+ from peft.utils import get_peft_model_state_dict
35
+ from torch.utils.data import DataLoader
36
+ from torchvision import transforms
37
+
38
+ from transformers import (
39
+ CLIPTextModelWithProjection,
40
+ CLIPTokenizer,
41
+ T5EncoderModel,
42
+ T5Tokenizer,
43
+ )
44
+
45
+ import diffusers.optimization
46
+ from diffusers import VQModel
47
+
48
+ from src.scheduler import Scheduler
49
+ from diffusers.loaders import LoraLoaderMixin
50
+ from diffusers.utils import is_wandb_available
51
+ from src.pipeline import UnifiedPipeline
52
+ from torchvision.utils import save_image, make_grid
53
+ from train.trainer_utils import save_checkpoint
54
+ from train.dataset_utils import ImageCaptionLargeDataset, MultiSourceVLDataset
55
+ from train.dataset_utils import tokenize_prompt, encode_prompt
56
+ from src.transformer import SymmetricTransformer2DModel
57
+ from train.trainer_utils import load_images_to_tensor
58
+
59
+ if is_wandb_available():
60
+ import wandb
61
+ # wandb.login(key="")
62
+
63
+ logger = get_logger(__name__, log_level="INFO")
64
+
65
+ import torch._dynamo
66
+ torch._dynamo.config.verbose = True
67
+
68
+ # Optionally suppress errors to fall back to eager execution
69
+ torch._dynamo.config.suppress_errors = True
70
+
71
+ def parse_args():
72
+ parser = argparse.ArgumentParser()
73
+ parser.add_argument(
74
+ "--pretrained_model_name_or_path",
75
+ type=str,
76
+ default=None,
77
+ required=True,
78
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
79
+ )
80
+ parser.add_argument(
81
+ "--pretrained_transformer_path",
82
+ type=str,
83
+ default=None,
84
+ required=True,
85
+ help="Path to pretrained transformer.",
86
+ )
87
+ parser.add_argument(
88
+ "--text_encoder_architecture",
89
+ type=str,
90
+ default="open_clip",
91
+ required=False,
92
+ help="The architecture of the text encoder. One of ['CLIP', 'open_clip', 'flan-t5-base','Qwen2-0.5B','gemini-2b',long_t5_clip','t5_clip']",
93
+ )
94
+ parser.add_argument(
95
+ "--dataset_type",
96
+ type=str,
97
+ default=None,
98
+ required=False,
99
+ help="The type of the dataset.",
100
+ )
101
+ parser.add_argument(
102
+ "--instance_data_dir",
103
+ type=str,
104
+ default=None,
105
+ required=False,
106
+ help="A folder containing the training data of instance images.",
107
+ )
108
+ parser.add_argument(
109
+ "--caption_dir",
110
+ type=str,
111
+ default=None,
112
+ required=False,
113
+ help="A folder containing the training data of instance images.",
114
+ )
115
+ parser.add_argument(
116
+ "--llava_json_path",
117
+ type=str,
118
+ default=None,
119
+ required=False,
120
+ help="A folder containing the training data of instance images.",
121
+ )
122
+ parser.add_argument(
123
+ "--llava_image_root",
124
+ type=str,
125
+ default=None,
126
+ required=False,
127
+ help="A folder containing the training data of instance images.",
128
+ )
129
+ parser.add_argument(
130
+ "--mmmu_json_path",
131
+ type=str,
132
+ default=None,
133
+ required=False,
134
+ help="A folder containing the training data of instance images.",
135
+ )
136
+ parser.add_argument(
137
+ "--mmmu_image_root",
138
+ type=str,
139
+ default=None,
140
+ required=False,
141
+ help="A folder containing the training data of instance images.",
142
+ )
143
+ parser.add_argument(
144
+ "--vqa_ann_json_path",
145
+ type=str,
146
+ default=None,
147
+ required=False,
148
+ help="A folder containing the training data of instance images.",
149
+ )
150
+ parser.add_argument(
151
+ "--vqa_image_root",
152
+ type=str,
153
+ default=None,
154
+ required=False,
155
+ help="A folder containing the training data of instance images.",
156
+ )
157
+ parser.add_argument(
158
+ "--coco_json",
159
+ type=str,
160
+ default=None,
161
+ required=False,
162
+ help="A folder containing the training data of instance images.",
163
+ )
164
+ parser.add_argument(
165
+ "--coco_qa_json",
166
+ type=str,
167
+ default=None,
168
+ required=False,
169
+ help="A folder containing the training data of instance images.",
170
+ )
171
+ parser.add_argument(
172
+ "--coco_img_root",
173
+ type=str,
174
+ default=None,
175
+ required=False,
176
+ help="A folder containing the training data of instance images.",
177
+ )
178
+ parser.add_argument(
179
+ "--gqa_json_root",
180
+ type=str,
181
+ default=None,
182
+ required=False,
183
+ help="A folder containing the training data of instance images.",
184
+ )
185
+ parser.add_argument(
186
+ "--gqa_image_root",
187
+ type=str,
188
+ default=None,
189
+ required=False,
190
+ help="A folder containing the training data of instance images.",
191
+ )
192
+ parser.add_argument(
193
+ "--mg_llava_json",
194
+ type=str,
195
+ default=None,
196
+ required=False,
197
+ help="A folder containing the training data of instance images.",
198
+ )
199
+ parser.add_argument(
200
+ "--mg_llava_root",
201
+ type=str,
202
+ default=None,
203
+ required=False,
204
+ help="A folder containing the training data of instance images.",
205
+ )
206
+ parser.add_argument(
207
+ "--training_from_scratch",
208
+ type=bool,
209
+ default=False,
210
+ required=False
211
+ )
212
+ parser.add_argument(
213
+ "--revision",
214
+ type=str,
215
+ default=None,
216
+ required=False,
217
+ help="Revision of pretrained model identifier from huggingface.co/models.",
218
+ )
219
+ parser.add_argument(
220
+ "--variant",
221
+ type=str,
222
+ default=None,
223
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
224
+ )
225
+ parser.add_argument(
226
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
227
+ )
228
+ parser.add_argument(
229
+ "--dataloader_num_workers",
230
+ type=int,
231
+ default=0,
232
+ help=(
233
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
234
+ ),
235
+ )
236
+ parser.add_argument(
237
+ "--allow_tf32",
238
+ action="store_true",
239
+ help=(
240
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
241
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
242
+ ),
243
+ )
244
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
245
+ parser.add_argument("--ema_decay", type=float, default=0.9999)
246
+ parser.add_argument("--ema_update_after_step", type=int, default=0)
247
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
248
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
249
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
250
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
251
+ parser.add_argument(
252
+ "--output_dir",
253
+ type=str,
254
+ default="muse_training",
255
+ help="The output directory where the model predictions and checkpoints will be written.",
256
+ )
257
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
258
+ parser.add_argument(
259
+ "--logging_dir",
260
+ type=str,
261
+ default="logs",
262
+ help=(
263
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
264
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
265
+ ),
266
+ )
267
+ parser.add_argument(
268
+ "--max_train_steps",
269
+ type=int,
270
+ default=None,
271
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
272
+ )
273
+ parser.add_argument(
274
+ "--checkpointing_steps",
275
+ type=int,
276
+ default=500,
277
+ help=(
278
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
279
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
280
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
281
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
282
+ "instructions."
283
+ ),
284
+ )
285
+ parser.add_argument(
286
+ "--logging_steps",
287
+ type=int,
288
+ default=50,
289
+ )
290
+ parser.add_argument(
291
+ "--checkpoints_total_limit",
292
+ type=int,
293
+ default=None,
294
+ help=(
295
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
296
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
297
+ " for more details"
298
+ ),
299
+ )
300
+ parser.add_argument(
301
+ "--resume_from_checkpoint",
302
+ type=str,
303
+ default=None,
304
+ help=(
305
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
306
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
307
+ ),
308
+ )
309
+ parser.add_argument(
310
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
311
+ )
312
+ parser.add_argument(
313
+ "--gradient_accumulation_steps",
314
+ type=int,
315
+ default=1,
316
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
317
+ )
318
+ parser.add_argument(
319
+ "--text_loss_weight",
320
+ type=float,
321
+ default=0.2,
322
+ )
323
+ parser.add_argument(
324
+ "--learning_rate",
325
+ type=float,
326
+ default=0.0003,
327
+ help="Initial learning rate (after the potential warmup period) to use.",
328
+ )
329
+ parser.add_argument(
330
+ "--scale_lr",
331
+ action="store_true",
332
+ default=False,
333
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
334
+ )
335
+ parser.add_argument(
336
+ "--lr_scheduler",
337
+ type=str,
338
+ default="constant",
339
+ help=(
340
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
341
+ ' "constant", "constant_with_warmup"]'
342
+ ),
343
+ )
344
+ parser.add_argument(
345
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
346
+ )
347
+ parser.add_argument(
348
+ "--validation_steps",
349
+ type=int,
350
+ default=100,
351
+ help=(
352
+ "Run validation every X steps. Validation consists of running the prompt"
353
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
354
+ " and logging the images."
355
+ ),
356
+ )
357
+ parser.add_argument(
358
+ "--mixed_precision",
359
+ type=str,
360
+ default=None,
361
+ choices=["no", "fp16", "bf16"],
362
+ help=(
363
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
364
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
365
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
366
+ ),
367
+ )
368
+ parser.add_argument(
369
+ "--report_to",
370
+ type=str,
371
+ default="wandb",
372
+ help=(
373
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
374
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
375
+ ),
376
+ )
377
+ parser.add_argument("--validation_prompts", type=str, nargs="*")
378
+ parser.add_argument("--validation_vqa_prompts", type=str, default=None)
379
+ parser.add_argument("--validation_images", type=str, default=None)
380
+ parser.add_argument(
381
+ "--resolution",
382
+ type=int,
383
+ default=512,
384
+ help=(
385
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
386
+ " resolution"
387
+ ),
388
+ )
389
+ parser.add_argument("--split_vae_encode", type=int, required=False, default=None)
390
+ parser.add_argument("--min_masking_rate", type=float, default=0.0)
391
+ parser.add_argument("--cond_dropout_prob", type=float, default=0.0)
392
+ parser.add_argument("--max_grad_norm", default=50.0, type=float, help="Max gradient norm.", required=False)
393
+ parser.add_argument("--use_lora", action="store_true", help="Fine tune the model using LoRa")
394
+ parser.add_argument("--text_encoder_use_lora", action="store_true", help="Fine tune the model using LoRa")
395
+ parser.add_argument("--lora_r", default=16, type=int)
396
+ parser.add_argument("--lora_alpha", default=32, type=int)
397
+ parser.add_argument("--lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
398
+ parser.add_argument("--text_encoder_lora_r", default=16, type=int)
399
+ parser.add_argument("--text_encoder_lora_alpha", default=32, type=int)
400
+ parser.add_argument("--text_encoder_lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
401
+ parser.add_argument("--train_text_encoder", action="store_true")
402
+ parser.add_argument("--image_to_text_only", action="store_true")
403
+ parser.add_argument("--image_key", type=str, required=False)
404
+ parser.add_argument("--prompt_key", type=str, required=False)
405
+ parser.add_argument(
406
+ "--gradient_checkpointing",
407
+ action="store_true",
408
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
409
+ )
410
+ parser.add_argument("--prompt_prefix", type=str, required=False, default=None)
411
+
412
+ args = parser.parse_args()
413
+
414
+ if args.report_to == "wandb":
415
+ if not is_wandb_available():
416
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
417
+
418
+ if args.instance_data_dir is not None:
419
+ if not os.path.exists(args.instance_data_dir):
420
+ raise ValueError(f"Does not exist: `--args.instance_data_dir` {args.instance_data_dir}")
421
+
422
+ return args
423
+
424
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
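+ # Builds 3-channel positional ids for the latent grid: the grid is half the token
+ # resolution in each direction, channel 1 carries the row index, channel 2 the column
+ # index, and channel 0 stays zero. The flattened (H/2 * W/2, 3) tensor is later passed
+ # to the transformer as `img_ids` for its rotary position embeddings.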
425
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
426
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
427
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
428
+
429
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
430
+
431
+ latent_image_ids = latent_image_ids.reshape(
432
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
433
+ )
434
+
435
+ return latent_image_ids.to(device=device, dtype=dtype)
436
+
437
+ def main(args):
438
+ if args.allow_tf32:
439
+ torch.backends.cuda.matmul.allow_tf32 = True
440
+
441
+ logging_dir = Path(args.output_dir, args.logging_dir)
442
+
443
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
444
+
445
+ accelerator = Accelerator(
446
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
447
+ mixed_precision=args.mixed_precision,
448
+ log_with=args.report_to,
449
+ project_config=accelerator_project_config,
450
+ )
451
+
452
+ if accelerator.is_main_process:
453
+ os.makedirs(args.output_dir, exist_ok=True)
454
+
455
+ # Make one log on every process with the configuration for debugging.
456
+ logging.basicConfig(
457
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
458
+ datefmt="%m/%d/%Y %H:%M:%S",
459
+ level=logging.INFO,
460
+ )
461
+ logger.info(accelerator.state, main_process_only=False)
462
+
463
+ if accelerator.is_main_process:
464
+ accelerator.init_trackers("meissonic", config=vars(copy.deepcopy(args)))
465
+
466
+ if args.seed is not None:
467
+ set_seed(args.seed)
468
+
469
+ if args.text_encoder_architecture == "open_clip":
470
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
471
+ args.pretrained_model_name_or_path, subfolder="text_encoder", variant=args.variant
472
+ )
473
+ tokenizer = CLIPTokenizer.from_pretrained(
474
+ args.pretrained_model_name_or_path, subfolder="tokenizer", variant=args.variant
475
+ )
476
+ tokenizer_2 = None
477
+ text_encoder_2 = None
478
+
479
+ text_encoder.requires_grad_(False)
480
+ elif args.text_encoder_architecture == "t5_clip":
481
+ tokenizer = CLIPTokenizer.from_pretrained(
482
+ args.pretrained_model_name_or_path, subfolder="tokenizer", variant=args.variant
483
+ )
484
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
485
+ args.pretrained_model_name_or_path, subfolder="text_encoder", variant=args.variant
486
+ )
487
+
488
+ tokenizer_2 = T5Tokenizer.from_pretrained(
489
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", variant=args.variant,
490
+ )
491
+ text_encoder_2 = T5EncoderModel.from_pretrained(
492
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", variant=args.variant,
493
+ )
494
+
495
+ text_encoder.requires_grad_(False)
496
+ text_encoder_2.requires_grad_(False)
497
+ else:
498
+ raise ValueError(f"Unknown text encoder architecture: {args.text_encoder_architecture}")
499
+
500
+ vq_model = VQModel.from_pretrained(
501
+ args.pretrained_model_name_or_path, subfolder="vqvae", revision=args.revision, variant=args.variant
502
+ )
503
+ vq_model.requires_grad_(False)
504
+
505
+ model = SymmetricTransformer2DModel.from_pretrained(
506
+ args.pretrained_transformer_path,
507
+ subfolder="transformer",
508
+ low_cpu_mem_usage=False,
509
+ device_map=None
510
+ )
511
+
512
+ if model.config.tokenizer_vocab_size is None:
513
+ if args.text_encoder_architecture == "open_clip":
514
+ model.register_to_config(tokenizer_vocab_size=len(tokenizer))
515
+ # model.config.tokenizer_vocab_size = len(tokenizer) # We exclude the mask token in the predicted logits
516
+ elif args.text_encoder_architecture == "t5_clip":
517
+ model.register_to_config(tokenizer_vocab_size=len(tokenizer_2))
518
+ # model.config.tokenizer_vocab_size = len(tokenizer_2) # We don't need to add new token
519
+ if model.adapter is None:
520
+ raise ValueError(f"The MMDiT must has adapter if you want to use t5_clip mode!!!")
521
+ else:
522
+ raise ValueError(f"Unknown text encoder architecture!")
523
+
524
+ print(f"model's tokenizer vocab size is {model.config.tokenizer_vocab_size}")
525
+ model.text_decoder = nn.Sequential(
526
+ nn.LayerNorm(model.inner_dim, elementwise_affine=False, eps=1e-6),
527
+ nn.Linear(model.inner_dim, model.config.tokenizer_vocab_size, bias=False)
528
+ )
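+ # The text decoder is a lightweight head (LayerNorm + Linear) that maps the transformer's
+ # hidden states to tokenizer-vocabulary logits; it is what lets the shared backbone emit
+ # text (VQA / captioning) alongside the image-token predictions.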
529
+
530
+ model = torch.compile(model)
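+ # torch.compile() wraps the module and prefixes parameter names with '_orig_mod.'; the
+ # `adap_compile` helper in `load_model_hook` below re-adds that prefix when loading an
+ # uncompiled checkpoint into the compiled model.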
531
+
532
+ if args.use_lora:
533
+ lora_config = LoraConfig(
534
+ r=args.lora_r,
535
+ lora_alpha=args.lora_alpha,
536
+ target_modules=args.lora_target_modules,
537
+ )
538
+ model.add_adapter(lora_config)
539
+
540
+ model.train()
541
+
542
+ if args.image_to_text_only:
543
+ frozen_keys = ["project_from_hidden", "up_block", "mlm_layer"]
544
+ for n, p in model.named_parameters():
545
+ if any([frozen_key in n for frozen_key in frozen_keys]):
546
+ p.requires_grad_(False)
547
+ else:
548
+ p.requires_grad_(True)
549
+ else:
550
+ model.requires_grad_(True)
551
+
552
+ if args.gradient_checkpointing:
553
+ model.enable_gradient_checkpointing()
554
+
555
+ def save_model_hook(models, weights, output_dir):
556
+ if accelerator.is_main_process:
557
+ transformer_lora_layers_to_save = None
558
+ text_encoder_lora_layers_to_save = None
559
+
560
+ for model_ in models:
561
+ if isinstance(model_, type(accelerator.unwrap_model(model))):
562
+ if args.use_lora:
563
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model_)
564
+ else:
565
+ model_.save_pretrained(os.path.join(output_dir, "transformer"))
566
+ elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
567
+ if args.text_encoder_use_lora:
568
+ text_encoder_lora_layers_to_save = get_peft_model_state_dict(model_)
569
+ else:
570
+ model_.save_pretrained(os.path.join(output_dir, "text_encoder"))
571
+ else:
572
+ raise ValueError(f"unexpected save model: {model_.__class__}")
573
+
574
+ # make sure to pop weight so that corresponding model is not saved again
575
+ weights.pop()
576
+
577
+ if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
578
+ LoraLoaderMixin.save_lora_weights(
579
+ output_dir,
580
+ unet_lora_layers=transformer_lora_layers_to_save,
581
+ text_encoder_lora_layers=text_encoder_lora_layers_to_save,
582
+ )
583
+
584
+
585
+ def load_model_hook(models, input_dir):
586
+ transformer = None
587
+ text_encoder_ = None
588
+
589
+ # this helper keeps state-dict keys consistent when the model is wrapped with torch.compile()
590
+ def adap_compile(ori_dict):  # add the '_orig_mod.' prefix to each key
591
+ new_dict = {}
592
+ for k, v in ori_dict.items():
593
+ new_dict['_orig_mod.' + k] = v
594
+ return new_dict
595
+
596
+ while len(models) > 0:
597
+ model_ = models.pop()
598
+
599
+ if isinstance(model_, type(accelerator.unwrap_model(model))):
600
+ if args.use_lora:
601
+ transformer = model_
602
+ else:
603
+ load_model = SymmetricTransformer2DModel.from_pretrained(os.path.join(input_dir, "transformer"), low_cpu_mem_usage=False, device_map=None)
604
+ model_.load_state_dict(adap_compile(load_model.state_dict()))
605
+ del load_model
606
+ elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
607
+ if args.text_encoder_use_lora:
608
+ text_encoder_ = model_
609
+ else:
610
+ try:
611
+ load_model = CLIPTextModelWithProjection.from_pretrained(os.path.join(input_dir, "text_encoder"))
612
+ model_.load_state_dict(load_model.state_dict())
613
+ # print('finished loading text encoder!')
614
+ except Exception:
615
+ print("Text encoder not found in the checkpoint folder; downloading one from the Hugging Face Hub instead.")
616
+ load_model = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
617
+ model_.load_state_dict(load_model.state_dict())
618
+ del load_model
619
+ else:
620
+ raise ValueError(f"unexpected save model: {model.__class__}")
621
+
622
+ if transformer is not None or text_encoder_ is not None:
623
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
624
+ LoraLoaderMixin.load_lora_into_text_encoder(
625
+ lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
626
+ )
627
+ LoraLoaderMixin.load_lora_into_transformer(
628
+ lora_state_dict, network_alphas=network_alphas, transformer=transformer
629
+ )
630
+
631
+ accelerator.register_load_state_pre_hook(load_model_hook)
632
+ accelerator.register_save_state_pre_hook(save_model_hook)
633
+
634
+ if args.scale_lr:
635
+ args.learning_rate = (
636
+ args.learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
637
+ )
638
+
639
+ if args.use_8bit_adam:
640
+ try:
641
+ import bitsandbytes as bnb
642
+ except ImportError:
643
+ raise ImportError(
644
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
645
+ )
646
+
647
+ optimizer_cls = bnb.optim.AdamW8bit
648
+ else:
649
+ optimizer_cls = torch.optim.AdamW
650
+
651
+ optimizer_grouped_parameters = [
652
+ {
653
+ "params": [p for p in model.parameters() if p.requires_grad],
654
+ "weight_decay": args.adam_weight_decay,
655
+ }
656
+ ]
657
+ optimizer = optimizer_cls(
658
+ optimizer_grouped_parameters,
659
+ lr=args.learning_rate,
660
+ betas=(args.adam_beta1, args.adam_beta2),
661
+ weight_decay=args.adam_weight_decay,
662
+ eps=args.adam_epsilon,
663
+ )
664
+
665
+ logger.info("Creating dataloaders and lr_scheduler")
666
+
667
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
668
+
669
+ if args.text_encoder_architecture == "t5_clip":
670
+ tokenizer_for_dataset = [tokenizer, tokenizer_2]
671
+ else:
672
+ tokenizer_for_dataset = tokenizer
673
+
674
+ if args.dataset_type == "ImageCaptionLargeDataset":
675
+ dataset = ImageCaptionLargeDataset(
676
+ root_dir=args.instance_data_dir,
677
+ tokenizer=tokenizer_for_dataset,
678
+ size=args.resolution,
679
+ text_encoder_architecture=args.text_encoder_architecture
680
+ )
681
+ elif args.dataset_type == "MultiSourceVLDataset":
682
+ dataset = MultiSourceVLDataset(
683
+ tokenizer=tokenizer_for_dataset,
684
+ size=args.resolution,
685
+ text_encoder_architecture=args.text_encoder_architecture,
686
+ norm=False,
687
+ llava_json=args.llava_json_path,
688
+ llava_img_root=args.llava_image_root,
689
+ mmmu_json=args.mmmu_json_path,
690
+ mmmu_img_root=args.mmmu_image_root,
691
+ vqa_ann_json=args.vqa_ann_json_path,
692
+ vqa_img_root=args.vqa_image_root,
693
+ coco_json=args.coco_json,
694
+ coco_qa_json=args.coco_qa_json,
695
+ coco_img_root=args.coco_img_root,
696
+ gqa_json=args.gqa_json_root,
697
+ gqa_img_root=args.gqa_image_root,
698
+ mg_llava_json=args.mg_llava_json,
699
+ mg_llava_root=args.mg_llava_root,
700
+ caption_dir=args.caption_dir,
701
+ pdd3_dir=args.instance_data_dir,
702
+ )
703
+ elif args.dataset_type == "DATA_TYPE":
704
+ raise NotImplementedError("DATA_TYPE is not yet supported")
705
+ else:
706
+ raise ValueError(f"Unknown dataset type: {args.dataset_type}")
707
+
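+ # Each dataset sample carries two views: a text-to-image ("gen") pair used for the masked
+ # image-token objective and an image-to-text ("mmu") pair used for the masked text
+ # objective. The collate function below stacks both so every batch trains both directions.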
708
+ def collate_fn(samples):
709
+ gen_images = [sample["gen_image"] for sample in samples]
710
+ mmu_images = [sample["mmu_image"] for sample in samples]
711
+
712
+ gen_micro_conds = [sample["gen_micro_conds"] for sample in samples]
713
+ mmu_micro_conds = [sample["mmu_micro_conds"] for sample in samples]
714
+
715
+ gen_images = torch.stack(gen_images, dim=0)
716
+ mmu_images = torch.stack(mmu_images, dim=0)
717
+
718
+ gen_micro_conds = torch.stack(gen_micro_conds, dim=0)
719
+ mmu_micro_conds = torch.stack(mmu_micro_conds, dim=0)
720
+
721
+ if isinstance(samples[0]["gen_input_ids"], list):
722
+ gen_input_ids = [sample["gen_input_ids"][0] for sample in samples]
723
+ gen_input_ids_2 = [sample["gen_input_ids"][1] for sample in samples]
724
+
725
+ gen_input_ids = torch.cat(gen_input_ids, dim=0)
726
+ gen_input_ids_2 = torch.cat(gen_input_ids_2, dim=0)
727
+ gen_input_ids = [gen_input_ids, gen_input_ids_2]
728
+ else:
729
+ gen_input_ids = [sample["gen_input_ids"] for sample in samples]
730
+ mmu_input_ids = [sample["mmu_input_ids"] for sample in samples]
731
+
732
+ gen_input_ids = torch.cat(gen_input_ids, dim=0)
733
+ mmu_input_ids = torch.cat(mmu_input_ids, dim=0)
734
+
735
+ if samples[0].get("question_len", None) is not None:
736
+ question_len = [sample["question_len"] for sample in samples]
737
+
738
+ question_len = torch.cat(question_len, dim=0) # [B, ]
739
+ else:
740
+ question_len = None
741
+
742
+ ret = dict(
743
+ gen_images=gen_images,
744
+ mmu_images=mmu_images,
745
+ gen_micro_conds=gen_micro_conds,
746
+ mmu_micro_conds=mmu_micro_conds,
747
+ gen_input_ids=gen_input_ids,
748
+ mmu_input_ids=mmu_input_ids,
749
+ question_len=question_len
750
+ )
751
+
752
+ return ret
753
+
754
+ train_dataloader = DataLoader(
755
+ dataset,
756
+ batch_size=args.train_batch_size,
757
+ shuffle=True,
758
+ num_workers=args.dataloader_num_workers,
759
+ collate_fn=collate_fn,
760
+ pin_memory=True,
761
+ )
762
+ train_dataloader.num_batches = len(train_dataloader)
763
+
764
+ lr_scheduler = diffusers.optimization.get_scheduler(
765
+ args.lr_scheduler,
766
+ optimizer=optimizer,
767
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
768
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
769
+ )
770
+
771
+ logger.info("Preparing model, optimizer and dataloaders")
772
+
773
+ model, optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
774
+ model, optimizer, lr_scheduler, train_dataloader
775
+ )
776
+
777
+ train_dataloader.num_batches = len(train_dataloader)
778
+
779
+ weight_dtype = torch.float32
780
+ if accelerator.mixed_precision == "fp16":
781
+ weight_dtype = torch.float16
782
+ elif accelerator.mixed_precision == "bf16":
783
+ weight_dtype = torch.bfloat16
784
+
785
+ if args.text_encoder_architecture == "t5_clip":
786
+ text_encoder.to(device=accelerator.device, dtype=weight_dtype)
787
+ text_encoder_2.to(device=accelerator.device, dtype=weight_dtype)
788
+ else:
789
+ text_encoder.to(device=accelerator.device, dtype=weight_dtype)
790
+
791
+ vq_model.to(device=accelerator.device)
792
+
793
+ with torch.no_grad():
794
+ if args.text_encoder_architecture == "t5_clip":
795
+ _input_ids_tmp_ = tokenize_prompt([tokenizer, tokenizer_2], "", args.text_encoder_architecture)
796
+ _input_ids_tmp_[0] = _input_ids_tmp_[0].to(accelerator.device)
797
+ _input_ids_tmp_[1] = _input_ids_tmp_[1].to(accelerator.device)
798
+ empty_embeds, empty_clip_embeds = encode_prompt(
799
+ [text_encoder, text_encoder_2],
800
+ _input_ids_tmp_,
801
+ args.text_encoder_architecture
802
+ )
803
+ else:
804
+ _input_ids_tmp_ = tokenize_prompt(tokenizer, "", args.text_encoder_architecture)
805
+ _input_ids_tmp_ = _input_ids_tmp_.to(accelerator.device)
806
+ empty_embeds, empty_clip_embeds = encode_prompt(
807
+ text_encoder,
808
+ _input_ids_tmp_,
809
+ args.text_encoder_architecture
810
+ )
811
+
812
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
813
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
814
+ # Afterwards we recalculate our number of training epochs.
815
+ # Note: We are not doing epoch based training here, but just using this for book keeping and being able to
816
+ # reuse the same training loop with other datasets/loaders.
817
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
818
+
819
+ # Train!
820
+ logger.info("***** Running training *****")
821
+ logger.info(f" Num training steps = {args.max_train_steps}")
822
+ logger.info(f" Instantaneous batch size per device = { args.train_batch_size}")
823
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
824
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
825
+
826
+ resume_from_checkpoint = args.resume_from_checkpoint
827
+ if resume_from_checkpoint:
828
+ if resume_from_checkpoint == "latest":
829
+ # Get the most recent checkpoint
830
+ dirs = os.listdir(args.output_dir)
831
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
832
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
833
+ if len(dirs) > 0:
834
+ resume_from_checkpoint = os.path.join(args.output_dir, dirs[-1])
835
+ else:
836
+ resume_from_checkpoint = None
837
+
838
+ if resume_from_checkpoint is None:
839
+ accelerator.print(
840
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
841
+ )
842
+ else:
843
+ accelerator.print(f"Resuming from checkpoint {resume_from_checkpoint}")
844
+
845
+ if resume_from_checkpoint is None:
846
+ global_step = 0
847
+ first_epoch = 0
848
+ else:
849
+ accelerator.load_state(resume_from_checkpoint)
850
+ global_step = int(os.path.basename(resume_from_checkpoint).split("-")[1])
851
+ first_epoch = global_step // num_update_steps_per_epoch
852
+
853
+ # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to
854
+ # reuse the same training loop with other datasets/loaders.
855
+ for epoch in range(first_epoch, num_train_epochs):
856
+ for batch in train_dataloader:
857
+ torch.cuda.empty_cache()
858
+ with torch.no_grad():
859
+ gen_pixel_values = batch["gen_images"].to(accelerator.device, non_blocking=True)
860
+ mmu_pixel_values = batch["mmu_images"].to(accelerator.device, non_blocking=True)
861
+
862
+ gen_micro_conds = batch["gen_micro_conds"].to(accelerator.device, non_blocking=True)
863
+ mmu_micro_conds = batch["mmu_micro_conds"].to(accelerator.device, non_blocking=True)
864
+
865
+ # ====================== tokenize images ======================
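+ # The VQ-VAE encode/quantize pass is optionally chunked via --split_vae_encode to bound
+ # peak memory; quantize() returns (z_q, loss, info) and info[2] holds the codebook
+ # indices, which become the discrete image tokens. Note that the reshape below assumes
+ # every chunk contains exactly `split_batch_size` samples (i.e. the batch divides evenly).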
866
+ pixel_values = torch.cat([gen_pixel_values, mmu_pixel_values], dim=0)
867
+ batch_size = pixel_values.shape[0]
868
+
869
+ split_batch_size = args.split_vae_encode if args.split_vae_encode is not None else batch_size
870
+ num_splits = math.ceil(batch_size / split_batch_size)
871
+ image_tokens = []
872
+ for i in range(num_splits):
873
+ start_idx = i * split_batch_size
874
+ end_idx = min((i + 1) * split_batch_size, batch_size)
875
+ image_tokens.append(
876
+ vq_model.quantize(
877
+ vq_model.encode(pixel_values[start_idx:end_idx]).latents
878
+ )[2][2].reshape(split_batch_size, -1)
879
+ )
880
+ image_tokens = torch.cat(image_tokens, dim=0)
881
+ gen_image_tokens, mmu_image_tokens = image_tokens.chunk(2, dim=0)
882
+ # ====================== tokenize images ======================
883
+
884
+
885
+ # ====================== encode clean text prompts ======================
886
+ if args.text_encoder_architecture == "t5_clip":
887
+ gen_input_ids_clip = batch["gen_input_ids"][0].to(accelerator.device, non_blocking=True)
888
+ gen_input_ids_t5 = batch["gen_input_ids"][1].to(accelerator.device, non_blocking=True)
889
+ encoder_hidden_states, cond_embeds = encode_prompt(
890
+ [text_encoder, text_encoder_2],
891
+ [gen_input_ids_clip, gen_input_ids_t5],
892
+ args.text_encoder_architecture
893
+ )
894
+ else:
895
+ gen_input_ids = batch["gen_input_ids"].to(accelerator.device, non_blocking=True)
896
+ gen_encoder_hidden_states, gen_cond_embeds = encode_prompt(
897
+ text_encoder,
898
+ gen_input_ids,
899
+ args.text_encoder_architecture
900
+ )
901
+ gen_encoder_hidden_states = gen_encoder_hidden_states.to(accelerator.device, dtype=accelerator.unwrap_model(model).dtype)
902
+ gen_cond_embeds = gen_cond_embeds.to(accelerator.device, dtype=accelerator.unwrap_model(model).dtype)
903
+ # ====================== encode clean text prompts ======================
904
+
905
+
906
+ # ====================== image perturbation ======================
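+ # Masking follows a cosine schedule (as in MaskGIT-style masked-token training): a mask
+ # ratio of cos(sigma * pi / 2) is drawn per sample (e.g. sigma = 0.5 masks roughly
+ # cos(pi/4) ~ 71% of the image tokens), masked positions are replaced with the last
+ # codebook entry (mask_id), and only masked positions contribute to the loss
+ # (labels are set to -100 elsewhere).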
907
+ half_batch_size, seq_len = gen_image_tokens.shape
908
+ sigma = torch.rand(half_batch_size, device=gen_image_tokens.device)
909
+ image_mask_prob = torch.cos(sigma * math.pi * 0.5)
910
+ image_mask_prob = image_mask_prob.clip(args.min_masking_rate)
911
+
912
+ num_token_masked = (seq_len * image_mask_prob).round().clamp(min=1)
913
+ batch_randperm = torch.rand(half_batch_size, seq_len, device=gen_image_tokens.device).argsort(dim=-1)
914
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
915
+
916
+ mask_id = accelerator.unwrap_model(model).config.vocab_size - 1
917
+ gen_image_ids = torch.where(mask, mask_id, gen_image_tokens)
918
+ image_labels = torch.where(mask, gen_image_tokens, -100)
919
+ # ====================== image perturbation ======================
920
+
921
+
922
+ # ====================== text perturbation ======================
923
+ if args.text_encoder_architecture == "t5_clip":
924
+ mmu_input_ids_clip = batch["mmu_input_ids"][0].to(accelerator.device, non_blocking=True)
925
+ mmu_input_ids_t5 = batch["mmu_input_ids"][1].to(accelerator.device, non_blocking=True)
926
+ half_batch_size, seq_len = mmu_input_ids_t5.shape
927
+ sigma = torch.rand(half_batch_size, device=mmu_image_tokens.device)
928
+ text_mask_prob = torch.cos(sigma * math.pi * 0.5)
929
+ text_mask_prob = text_mask_prob.clip(args.min_masking_rate)
930
+ text_timestep = text_mask_prob.clone().clamp(min=1e-3)
931
+
932
+ num_token_masked = (seq_len * text_mask_prob).round().clamp(min=1)
933
+ batch_randperm = torch.rand(half_batch_size, seq_len, device=mmu_image_tokens.device).argsort(dim=-1)
934
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
935
+
936
+ extra_id_0_token = "<extra_id_0>"
937
+ t5_mask_id = tokenizer_2.convert_tokens_to_ids(extra_id_0_token)
938
+ masked_prompt_input_ids_t5 = torch.where(mask, t5_mask_id, mmu_input_ids_t5)
939
+ text_labels = torch.where(mask, mmu_input_ids_t5, -100)
940
+
941
+ # prepare input_ids for clip model
942
+ batch_prompt_2 = []
943
+ for i in range(masked_prompt_input_ids_t5.size(0)):
944
+ masked_prompt_input_id = masked_prompt_input_ids_t5[i].tolist()
945
+ prompt_2 = tokenizer_2.decode(masked_prompt_input_id, skip_special_tokens=True)
946
+ batch_prompt_2.append(prompt_2)
947
+
948
+ masked_prompt_input_ids_clip = tokenizer(
949
+ batch_prompt_2,
950
+ truncation=True,
951
+ padding="max_length",
952
+ max_length=77,
953
+ return_tensors="pt"
954
+ ).input_ids
955
+ masked_prompt_input_ids_clip = masked_prompt_input_ids_clip.to(accelerator.device)
956
+ else:
957
+ extra_id_0_token = "<extra_id_0>"
958
+ num_new_tokens = tokenizer.add_tokens(extra_id_0_token)
959
+ clip_mask_id = tokenizer.convert_tokens_to_ids(extra_id_0_token)
960
+ if num_new_tokens > 0:
961
+ text_encoder.resize_token_embeddings(len(tokenizer))
962
+ mask_token_embedding = text_encoder.get_input_embeddings().weight[clip_mask_id]
963
+ mask_token_embedding = mask_token_embedding.clone().detach().cpu().float()
964
+ if accelerator.is_main_process:
965
+ print("Saving masked token embedding...")
966
+ torch.save(mask_token_embedding, os.path.join(args.output_dir, "mask_token_embedding.pth"))
967
+
968
+
969
+ mmu_input_ids = batch["mmu_input_ids"].to(accelerator.device, non_blocking=True) # [B, L]
970
+ question_len = batch["question_len"] # [B, ]
971
+ if question_len is not None:
972
+ question_len = question_len.to(accelerator.device, non_blocking=True)
973
+
974
+ half_batch_size, seq_len = mmu_input_ids.shape
975
+ answer_len = seq_len - question_len
976
+
977
+ sigma = torch.rand(half_batch_size, device=mmu_image_tokens.device)
978
+ text_mask_prob = torch.cos(sigma * math.pi * 0.5)
979
+ text_mask_prob = text_mask_prob.clip(args.min_masking_rate)
980
+ text_timestep = text_mask_prob.clone().clamp(min=1e-3)
981
+
982
+ num_token_masked = ((seq_len - question_len) * text_mask_prob).round().clamp(min=1) # [B, ]
983
+ num_token_masked = torch.minimum(num_token_masked, answer_len)
984
+
985
+ seq_idx = torch.arange(seq_len, device=mmu_image_tokens.device).unsqueeze(0).repeat(half_batch_size, 1)
986
+ answer_region = seq_idx >= question_len.unsqueeze(1)
987
+
988
+ rand_value = torch.rand(half_batch_size, seq_len, device=mmu_image_tokens.device)
989
+ rand_value = rand_value.masked_fill(~answer_region, float("inf"))
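+ # The double argsort converts the random draws into per-position ranks; positions whose
+ # rank falls below num_token_masked get masked, and the +inf fill above guarantees that
+ # question tokens are never selected, so only answer tokens are masked.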
990
+
991
+ order = rand_value.argsort(dim=-1)
992
+ order = order.argsort(dim=-1)
993
+ mask = order < num_token_masked.unsqueeze(-1)
994
+
995
+ # mask = torch.zeros_like(mmu_input_ids)
996
+ # for b in range(half_batch_size):
997
+ # ans_len = seq_len - question_len[b]
998
+ # batch_randperm = torch.rand(1, ans_len, device=mmu_image_tokens.device).argsort(dim=-1)
999
+ # mask[b, question_len[b]:] = batch_randperm < num_token_masked[b].unsqueeze(-1)
1000
+
1001
+ mmu_input_ids_clip = torch.where(mask, clip_mask_id, mmu_input_ids)
1002
+ text_labels = torch.where(mask, mmu_input_ids, -100)
1003
+ # ====================== text perturbation ======================
1004
+
1005
+
1006
+ # ====================== encode masked text prompts ======================
1007
+ if args.text_encoder_architecture == "t5_clip":
1008
+ masked_encoder_hidden_states, masked_cond_embeds = encode_prompt(
1009
+ [text_encoder, text_encoder_2],
1010
+ [masked_prompt_input_ids_clip, masked_prompt_input_ids_t5],
1011
+ args.text_encoder_architecture
1012
+ )
1013
+ else:
1014
+ mmu_encoder_hidden_states, mmu_cond_embeds = encode_prompt(
1015
+ text_encoder,
1016
+ mmu_input_ids_clip,
1017
+ args.text_encoder_architecture
1018
+ )
1019
+ mmu_encoder_hidden_states = mmu_encoder_hidden_states.to(accelerator.device, dtype=accelerator.unwrap_model(model).dtype)
1020
+ mmu_cond_embeds = mmu_cond_embeds.to(accelerator.device, dtype=accelerator.unwrap_model(model).dtype)
1021
+ # ====================== encode masked text prompts ======================
1022
+
1023
+
1024
+ # for CFG
1025
+ if args.cond_dropout_prob > 0.0:
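+ # Classifier-free-guidance dropout: a per-sample Bernoulli(cond_dropout_prob) mask decides
+ # whether a caption keeps its own embeddings or falls back to the empty-prompt embeddings
+ # computed earlier. Note that `encoder_hidden_states` / `cond_embeds` are only bound on the
+ # t5_clip path above, so this branch assumes that text-encoder architecture.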
1026
+ assert encoder_hidden_states is not None
1027
+
1028
+ batch_size = encoder_hidden_states.shape[0]
1029
+
1030
+ mask = (
1031
+ torch.zeros((batch_size, 1, 1), device=encoder_hidden_states.device).float().uniform_(0, 1)
1032
+ < args.cond_dropout_prob
1033
+ )
1034
+
1035
+ empty_embeds_ = empty_embeds.expand(batch_size, -1, -1)
1036
+ encoder_hidden_states = torch.where(
1037
+ (encoder_hidden_states * mask).bool(), encoder_hidden_states, empty_embeds_
1038
+ )
1039
+
1040
+ empty_clip_embeds_ = empty_clip_embeds.expand(batch_size, -1)
1041
+ cond_embeds = torch.where((cond_embeds * mask.squeeze(-1)).bool(), cond_embeds, empty_clip_embeds_)
1042
+
1043
+ vae_scale_factor = 2 ** (len(vq_model.config.block_out_channels) - 1)
1044
+ resolution = args.resolution // vae_scale_factor
1045
+ gen_image_ids = gen_image_ids.reshape(half_batch_size, resolution, resolution)
1046
+ mmu_image_ids = mmu_image_tokens.reshape(half_batch_size, resolution, resolution)
1047
+
1048
+
1049
+ # Train Step
1050
+ with accelerator.accumulate(model):
1051
+ codebook_size = accelerator.unwrap_model(model).config.codebook_size
1052
+ if args.resolution == 1024: # only stage 3 and stage 4 do not apply 2*
1053
+ img_ids = _prepare_latent_image_ids(
1054
+ gen_image_ids.shape[0],
1055
+ gen_image_ids.shape[-2],
1056
+ gen_image_ids.shape[-1],
1057
+ gen_image_ids.device,
1058
+ gen_image_ids.dtype
1059
+ )
1060
+ else:
1061
+ img_ids = _prepare_latent_image_ids(
1062
+ gen_image_ids.shape[0],
1063
+ gen_image_ids.shape[-2],
1064
+ gen_image_ids.shape[-1],
1065
+ gen_image_ids.device,
1066
+ gen_image_ids.dtype
1067
+ )
1068
+
1069
+ txt_ids = torch.zeros(gen_encoder_hidden_states.shape[1], 3).to(device=gen_image_ids.device, dtype=gen_image_ids.dtype)
1070
+
1071
+ image_logits = model(
1072
+ hidden_states=gen_image_ids, # should be (batch size, channel, height, width)
1073
+ encoder_hidden_states=gen_encoder_hidden_states, # should be (batch size, sequence_len, embed_dims)
1074
+ micro_conds=gen_micro_conds,
1075
+ pooled_projections=gen_cond_embeds, # should be (batch_size, projection_dim)
1076
+ img_ids=img_ids,
1077
+ txt_ids=txt_ids,
1078
+ timestep=image_mask_prob * 1000,
1079
+ )[0]
1080
+ image_logits = image_logits.reshape(half_batch_size, codebook_size, -1)
1081
+ image_logits = image_logits.permute(0, 2, 1)
1082
+ image_logits = image_logits.reshape(-1, codebook_size)
1083
+
1084
+ image_loss = F.cross_entropy(
1085
+ image_logits,
1086
+ image_labels.view(-1),
1087
+ ignore_index=-100,
1088
+ reduction="mean",
1089
+ )
1090
+
1091
+ text_logits = model(
1092
+ hidden_states=mmu_image_ids, # should be (batch size, channel, height, width)
1093
+ encoder_hidden_states=mmu_encoder_hidden_states, # should be (batch size, sequence_len, embed_dims)
1094
+ micro_conds=mmu_micro_conds,
1095
+ pooled_projections=mmu_cond_embeds, # should be (batch_size, projection_dim)
1096
+ img_ids=img_ids,
1097
+ txt_ids=txt_ids,
1098
+ timestep=text_timestep * 1000,
1099
+ )[1]
1100
+ text_logits = text_logits.reshape(-1, accelerator.unwrap_model(model).config.tokenizer_vocab_size)
1101
+
1102
+ text_loss = F.cross_entropy(
1103
+ text_logits,
1104
+ text_labels.view(-1),
1105
+ ignore_index=-100,
1106
+ reduction="none",
1107
+ )
1108
+ text_loss = text_loss.reshape(half_batch_size, -1).mean(-1)
1109
+ text_loss = text_loss / text_timestep
1110
+ text_loss = text_loss.mean()
1111
+
1112
+ loss = image_loss + args.text_loss_weight * text_loss
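+ # Total loss: plain cross-entropy over masked image tokens plus a weighted text term. The
+ # per-sample text loss is divided by its masking rate (text_timestep), so lightly masked
+ # samples are up-weighted; e.g. a sample with text_mask_prob = 0.1 is scaled by 1/0.1 = 10
+ # before the text_loss_weight (default 0.2) is applied.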
1113
+
1114
+ # Gather the losses across all processes for logging (if we use distributed training).
1115
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
1116
+ avg_masking_rate = accelerator.gather(text_mask_prob.repeat(args.train_batch_size)).mean()
1117
+
1118
+ accelerator.backward(loss)
1119
+
1120
+ # Temporarily add this to identify unused parameters
1121
+ # for name, param in accelerator.unwrap_model(model).named_parameters():
1122
+ # if param.grad is None:
1123
+ # print(f"Unused parameter: {name}")
1124
+
1125
+ if args.max_grad_norm is not None and accelerator.sync_gradients:
1126
+ accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
1127
+
1128
+ optimizer.step()
1129
+ lr_scheduler.step()
1130
+
1131
+ optimizer.zero_grad(set_to_none=True)
1132
+
1133
+ # Checks if the accelerator has performed an optimization step behind the scenes
1134
+ if accelerator.sync_gradients:
1135
+ if (global_step + 1) % args.logging_steps == 0:
1136
+ logs = {
1137
+ "step_loss": avg_loss.item(),
1138
+ "lr": lr_scheduler.get_last_lr()[0],
1139
+ "avg_masking_rate": avg_masking_rate.item(),
1140
+ }
1141
+ accelerator.log(logs, step=global_step + 1)
1142
+
1143
+ logger.info(
1144
+ f"Step: {global_step + 1} "
1145
+ f"Loss: {avg_loss.item():0.4f} "
1146
+ f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}"
1147
+ )
1148
+
1149
+ if (global_step + 1) % args.checkpointing_steps == 0:
1150
+ save_checkpoint(args, accelerator, global_step + 1, logger)
1151
+
1152
+ if (global_step + 1) % args.validation_steps == 0 and accelerator.is_main_process:
1153
+
1154
+ with torch.no_grad():
1155
+ logger.info("Generating images...")
1156
+
1157
+ model.eval()
1158
+
1159
+ scheduler = Scheduler.from_pretrained(
1160
+ args.pretrained_model_name_or_path,
1161
+ subfolder="scheduler",
1162
+ revision=args.revision,
1163
+ variant=args.variant,
1164
+ )
1165
+
1166
+ pipe = UnifiedPipeline(
1167
+ transformer=accelerator.unwrap_model(model),
1168
+ tokenizer=tokenizer,
1169
+ text_encoder=text_encoder,
1170
+ vqvae=vq_model,
1171
+ scheduler=scheduler,
1172
+ tokenizer_2=tokenizer_2,
1173
+ text_encoder_2=text_encoder_2,
1174
+ )
1175
+
1176
+ if not args.image_to_text_only:
1177
+ output = pipe(
1178
+ prompt=args.validation_prompts,
1179
+ height=args.resolution,
1180
+ width=args.resolution,
1181
+ guidance_scale=9,
1182
+ num_inference_steps=64,
1183
+ )
1184
+ pil_images = output.images
1185
+
1186
+ result = []
1187
+ for img in pil_images:
1188
+ if not isinstance(img, torch.Tensor):
1189
+ img = transforms.ToTensor()(img)
1190
+ result.append(img.unsqueeze(0))
1191
+ result = torch.cat(result, dim=0)
1192
+ result = make_grid(result, nrow=3)
1193
+ save_image(result, os.path.join(args.output_dir, str(global_step) + '_text2image_1024_CFG-9.png'))
1194
+
1195
+ output_data = {
1196
+ "step": global_step,
1197
+ "prompts": args.validation_prompts,
1198
+ "images": [f"{global_step}_text2image_1024_CFG-9_{i}.png" for i in range(len(pil_images))]
1199
+ }
1200
+
1201
+ with open(os.path.join(args.output_dir, f"text2image_{global_step}.json"), "w") as f:
1202
+ json.dump(output_data, f, indent=2)
1203
+
1204
+ image = load_images_to_tensor(args.validation_images, target_size=(args.resolution, args.resolution))
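+ # Second validation pass: passing `image=` switches UnifiedPipeline to image-to-text, so
+ # the VQA prompts are completed from the image and returned via `output.prompts`. This is
+ # an assumption about src/pipeline.py, which is not shown in this diff.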
1205
+ output = pipe(
1206
+ prompt=args.validation_vqa_prompts,
1207
+ height=args.resolution,
1208
+ width=args.resolution,
1209
+ guidance_scale=9,
1210
+ image=image,
1211
+ num_inference_steps=64
1212
+ )
1213
+ prompts = output.prompts
1214
+
1215
+ output_data = {
1216
+ "step": global_step,
1217
+ "prompts": prompts,
1218
+ }
1219
+
1220
+ with open(os.path.join(args.output_dir, f"image2text_{global_step}.json"), "w") as f:
1221
+ json.dump(output_data, f, indent=2)
1222
+
1223
+ model.train()
1224
+
1225
+ global_step += 1
1226
+
1227
+ # Stop training if max steps is reached
1228
+ if global_step >= args.max_train_steps:
1229
+ break
1230
+ # End for
1231
+
1232
+ accelerator.wait_for_everyone()
1233
+
1234
+ # Evaluate and save checkpoint at the end of training
1235
+ save_checkpoint(args, accelerator, global_step, logger)
1236
+
1237
+ # Save the final trained checkpoint
1238
+ if accelerator.is_main_process:
1239
+ model = accelerator.unwrap_model(model)
1240
+ model.save_pretrained(args.output_dir)
1241
+
1242
+ accelerator.end_training()
1243
+
1244
+
1245
+ if __name__ == "__main__":
1246
+ main(parse_args())
train/instruction_tuning.sh ADDED
@@ -0,0 +1,37 @@
1
+ # run this script from the repository root
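+ # effective batch size per optimizer step: 8 (per device) x 8 GPUs x 8 grad-accumulation steps = 512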
2
+ PYTHONPATH='./' accelerate launch --multi_gpu --gpu_ids '0,1,2,3,4,5,6,7' --main_process_port 25000 --num_processes 8 train/instruction_tuning.py \
3
+ --output_dir "/path/to/output/dir" \
4
+ --train_batch_size 8 \
5
+ --gradient_accumulation_steps 8 \
6
+ --learning_rate 1e-4 \
7
+ --text_loss_weight 0.2 \
8
+ --max_grad_norm 10 \
9
+ --pretrained_model_name_or_path "MeissonFlow/Meissonic" \
10
+ --pretrained_transformer_path "MeissonFlow/Meissonic" \
11
+ --text_encoder_architecture 'open_clip' \
12
+ --dataset_type 'MultiSourceVLDataset' \
13
+ --instance_data_dir '/path/to/data' \
14
+ --llava_json_path '/path/to/llava_instruct_150k.json' \
15
+ --llava_image_root '/path/to/coco/train2017' \
16
+ --resolution 512 \
17
+ --mixed_precision fp16 \
18
+ --lr_scheduler constant \
19
+ --use_8bit_adam \
20
+ --dataloader_num_workers 4 \
21
+ --validation_prompts \
22
+ 'a boy' \
23
+ 'A serene mountain landscape with towering snow-capped peaks, a crystal-clear blue lake reflecting the mountains, dense pine forests, and a vibrant orange sunrise illuminating the sky.' \
24
+ 'A playful golden retriever puppy with a shiny coat, bounding through a meadow filled with colorful wildflowers, under a bright, clear blue sky.' \
25
+ 'A bustling city street at night, illuminated by vibrant neon signs in various colors, with busy pedestrians, street vendors, and a light rain creating reflective puddles on the pavement.' \
26
+ 'A majestic, medieval castle perched on a rugged cliffside, overlooking a vast, calm ocean at sunset, with the sky painted in hues of pink, orange, and purple.' \
27
+ 'An elegant ballerina in a white tutu, dancing gracefully on a grand stage with ornate, gold-trimmed curtains, under a spotlight that casts a soft glow.' \
28
+ 'A cozy, rustic log cabin nestled in a snow-covered forest, with smoke rising from the stone chimney, warm lights glowing from the windows, and a path of footprints leading to the front door.' \
29
+ 'A Cute Cat' \
30
+ 'A Snow Mountain' \
31
+ --validation_images '/path/to/validation/images/dir' \
32
+ --validation_vqa_prompts 'Please describe this image.' \
33
+ --max_train_steps 100000 \
34
+ --checkpointing_steps 100 \
35
+ --validation_steps 100 \
36
+ --report_to 'wandb' \
37
+ --logging_steps 10
train/train_text_decoder.py ADDED
@@ -0,0 +1,1017 @@
1
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import copy
17
+ import logging
18
+ import math
19
+ import os
20
+ import json
21
+ from contextlib import nullcontext
22
+ from pathlib import Path
23
+ import sys
24
+ sys.path.append(os.getcwd())
25
+ import gc
26
+ import torch
27
+ import torch.nn.functional as F
28
+ from accelerate import Accelerator
29
+ from accelerate.logging import get_logger
30
+ from accelerate.utils import ProjectConfiguration, set_seed
31
+ from peft import LoraConfig
32
+ from peft.utils import get_peft_model_state_dict
33
+ from torch.utils.data import DataLoader, default_collate
34
+ from torchvision import transforms
35
+ from transformers import (
36
+ CLIPTextModelWithProjection,
37
+ CLIPTokenizer,
38
+ )
39
+ import diffusers.optimization
40
+ from diffusers import EMAModel, VQModel
41
+ from src.scheduler import Scheduler
42
+ from diffusers.loaders import LoraLoaderMixin
43
+ from diffusers.utils import is_wandb_available
44
+ from src.pipeline import UnifiedPipeline
45
+ from torchvision.utils import save_image,make_grid
46
+ from datasets import load_dataset
47
+ from train.trainer_utils import save_checkpoint
48
+ from train.dataset_utils import ImageCaptionDataset, HuggingFaceDataset
49
+ from train.dataset_utils import tokenize_prompt, encode_prompt
50
+ from src.transformer import SymmetricTransformer2DModel, Transformer2DModel
51
+
52
+ if is_wandb_available():
53
+ import wandb
54
+ # wandb.login(key="")
55
+
56
+ logger = get_logger(__name__, log_level="INFO")
57
+
58
+ import torch._dynamo
59
+ torch._dynamo.config.verbose = True
60
+
61
+ # Optionally suppress errors to fall back to eager execution
62
+ torch._dynamo.config.suppress_errors = True
63
+
64
+ def parse_args():
65
+ parser = argparse.ArgumentParser()
66
+ parser.add_argument(
67
+ "--pretrained_model_architecture",
68
+ type=str,
69
+ default="Meissonic",
70
+ required=False
71
+ )
72
+ parser.add_argument(
73
+ "--text_encoder_architecture",
74
+ type=str,
75
+ default="open_clip",
76
+ required=False,
77
+ help="The architecture of the text encoder. One of ['CLIP', 'open_clip', 'flan-t5-base','Qwen2-0.5B','gemini-2b',long_CLIP_T5_base','CLIP_T5_base']",
78
+ )
79
+ parser.add_argument(
80
+ "--instance_dataset",
81
+ type=str,
82
+ default=None,
83
+ required=False,
84
+ help="The dataset to use for training. One of ['MSCOCO600K', 'PickaPicV2']",
85
+ )
86
+ parser.add_argument(
87
+ "--training_from_scratch",
88
+ type=bool,
89
+ default=False,
90
+ required=False
91
+ )
92
+ parser.add_argument(
93
+ "--pretrained_model_name_or_path",
94
+ type=str,
95
+ default=None,
96
+ required=True,
97
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
98
+ )
99
+ parser.add_argument(
100
+ "--revision",
101
+ type=str,
102
+ default=None,
103
+ required=False,
104
+ help="Revision of pretrained model identifier from huggingface.co/models.",
105
+ )
106
+ parser.add_argument(
107
+ "--variant",
108
+ type=str,
109
+ default=None,
110
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
111
+ )
112
+ parser.add_argument(
113
+ "--instance_data_dataset",
114
+ type=str,
115
+ default=None,
116
+ required=False,
117
+ help="A Hugging Face dataset containing the training images",
118
+ )
119
+ parser.add_argument(
120
+ "--instance_data_dir",
121
+ type=str,
122
+ default=None,
123
+ required=False,
124
+ help="A folder containing the training data of instance images.",
125
+ )
126
+ parser.add_argument(
127
+ "--instance_data_image", type=str, default=None, required=False, help="A single training image"
128
+ )
129
+ parser.add_argument(
130
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
131
+ )
132
+ parser.add_argument(
133
+ "--dataloader_num_workers",
134
+ type=int,
135
+ default=0,
136
+ help=(
137
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
138
+ ),
139
+ )
140
+ parser.add_argument(
141
+ "--allow_tf32",
142
+ action="store_true",
143
+ help=(
144
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
145
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
146
+ ),
147
+ )
148
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
149
+ parser.add_argument("--ema_decay", type=float, default=0.9999)
150
+ parser.add_argument("--ema_update_after_step", type=int, default=0)
151
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
152
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
153
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
154
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
155
+ parser.add_argument(
156
+ "--output_dir",
157
+ type=str,
158
+ default="muse_training",
159
+ help="The output directory where the model predictions and checkpoints will be written.",
160
+ )
161
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
162
+ parser.add_argument(
163
+ "--logging_dir",
164
+ type=str,
165
+ default="logs",
166
+ help=(
167
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
168
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
169
+ ),
170
+ )
171
+ parser.add_argument(
172
+ "--max_train_steps",
173
+ type=int,
174
+ default=None,
175
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
176
+ )
177
+ parser.add_argument(
178
+ "--checkpointing_steps",
179
+ type=int,
180
+ default=500,
181
+ help=(
182
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
183
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
184
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
185
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
186
+ "instructions."
187
+ ),
188
+ )
189
+ parser.add_argument(
190
+ "--logging_steps",
191
+ type=int,
192
+ default=50,
193
+ )
194
+ parser.add_argument(
195
+ "--checkpoints_total_limit",
196
+ type=int,
197
+ default=None,
198
+ help=(
199
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
200
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
201
+ " for more details"
202
+ ),
203
+ )
204
+ parser.add_argument(
205
+ "--resume_from_checkpoint",
206
+ type=str,
207
+ default=None,
208
+ help=(
209
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
210
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
211
+ ),
212
+ )
213
+ parser.add_argument(
214
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
215
+ )
216
+ parser.add_argument(
217
+ "--gradient_accumulation_steps",
218
+ type=int,
219
+ default=1,
220
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
221
+ )
222
+ parser.add_argument(
223
+ "--learning_rate",
224
+ type=float,
225
+ default=0.0003,
226
+ help="Initial learning rate (after the potential warmup period) to use.",
227
+ )
228
+ parser.add_argument(
229
+ "--scale_lr",
230
+ action="store_true",
231
+ default=False,
232
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
233
+ )
234
+ parser.add_argument(
235
+ "--lr_scheduler",
236
+ type=str,
237
+ default="constant",
238
+ help=(
239
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
240
+ ' "constant", "constant_with_warmup"]'
241
+ ),
242
+ )
243
+ parser.add_argument(
244
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
245
+ )
246
+ parser.add_argument(
247
+ "--validation_steps",
248
+ type=int,
249
+ default=100,
250
+ help=(
251
+ "Run validation every X steps. Validation consists of running the prompt"
252
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
253
+ " and logging the images."
254
+ ),
255
+ )
256
+ parser.add_argument(
257
+ "--mixed_precision",
258
+ type=str,
259
+ default=None,
260
+ choices=["no", "fp16", "bf16"],
261
+ help=(
262
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
263
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
264
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
265
+ ),
266
+ )
267
+ parser.add_argument(
268
+ "--report_to",
269
+ type=str,
270
+ default="wandb",
271
+ help=(
272
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
273
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
274
+ ),
275
+ )
276
+ parser.add_argument("--validation_prompts", type=str, nargs="*")
277
+ parser.add_argument(
278
+ "--resolution",
279
+ type=int,
280
+ default=512,
281
+ help=(
282
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
283
+ " resolution"
284
+ ),
285
+ )
286
+ parser.add_argument("--split_vae_encode", type=int, required=False, default=None)
287
+ parser.add_argument("--min_masking_rate", type=float, default=0.0)
288
+ parser.add_argument("--cond_dropout_prob", type=float, default=0.0)
289
+ parser.add_argument("--max_grad_norm", default=50.0, type=float, help="Max gradient norm.", required=False)
290
+ parser.add_argument("--use_lora", action="store_true", help="Fine tune the model using LoRa")
291
+ parser.add_argument("--text_encoder_use_lora", action="store_true", help="Fine tune the model using LoRa")
292
+ parser.add_argument("--lora_r", default=16, type=int)
293
+ parser.add_argument("--lora_alpha", default=32, type=int)
294
+ parser.add_argument("--lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
295
+ parser.add_argument("--text_encoder_lora_r", default=16, type=int)
296
+ parser.add_argument("--text_encoder_lora_alpha", default=32, type=int)
297
+ parser.add_argument("--text_encoder_lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
298
+ parser.add_argument("--train_text_encoder", action="store_true")
299
+ parser.add_argument("--image_key", type=str, required=False)
300
+ parser.add_argument("--prompt_key", type=str, required=False)
301
+ parser.add_argument(
302
+ "--gradient_checkpointing",
303
+ action="store_true",
304
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
305
+ )
306
+ parser.add_argument("--prompt_prefix", type=str, required=False, default=None)
307
+
308
+ args = parser.parse_args()
309
+
310
+ if args.report_to == "wandb":
311
+ if not is_wandb_available():
312
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
313
+
314
+ num_datasources = sum(
315
+ [x is not None for x in [args.instance_data_dir, args.instance_data_image, args.instance_data_dataset]]
316
+ )
317
+
318
+ if num_datasources != 1:
319
+ raise ValueError(
320
+ "provide one and only one of `--instance_data_dir`, `--instance_data_image`, or `--instance_data_dataset`"
321
+ )
322
+
323
+ if args.instance_data_dir is not None:
324
+ if not os.path.exists(args.instance_data_dir):
325
+ raise ValueError(f"Does not exist: `--args.instance_data_dir` {args.instance_data_dir}")
326
+
327
+ if args.instance_data_image is not None:
328
+ if not os.path.exists(args.instance_data_image):
329
+ raise ValueError(f"Does not exist: `--args.instance_data_image` {args.instance_data_image}")
330
+
331
+ if args.instance_data_dataset is not None and (args.image_key is None or args.prompt_key is None):
332
+ raise ValueError("`--instance_data_dataset` requires setting `--image_key` and `--prompt_key`")
333
+
334
+ return args
335
+
336
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
337
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
338
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
339
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
340
+
341
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
342
+
343
+ latent_image_ids = latent_image_ids.reshape(
344
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
345
+ )
346
+ # latent_image_ids = latent_image_ids.unsqueeze(0).repeat(batch_size, 1, 1)
347
+
348
+ return latent_image_ids.to(device=device, dtype=dtype)
349
+
350
+ def main(args):
351
+ if args.allow_tf32:
352
+ torch.backends.cuda.matmul.allow_tf32 = True
353
+
354
+ # if args.pretrained_model_architecture == "Meissonic":
355
+ # from src.pipeline import Pipeline
356
+ # else:
357
+ # raise ValueError(f"Unknown model architecture: {args.pretrained_model_architecture}")
358
+
359
+
360
+ logging_dir = Path(args.output_dir, args.logging_dir)
361
+
362
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
363
+
364
+ accelerator = Accelerator(
365
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
366
+ mixed_precision=args.mixed_precision,
367
+ log_with=args.report_to,
368
+ project_config=accelerator_project_config,
369
+ )
370
+
371
+ if accelerator.is_main_process:
372
+ os.makedirs(args.output_dir, exist_ok=True)
373
+
374
+ # Make one log on every process with the configuration for debugging.
375
+ logging.basicConfig(
376
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
377
+ datefmt="%m/%d/%Y %H:%M:%S",
378
+ level=logging.INFO,
379
+ )
380
+ logger.info(accelerator.state, main_process_only=False)
381
+
382
+ if accelerator.is_main_process:
383
+ accelerator.init_trackers("meissonic", config=vars(copy.deepcopy(args)))
384
+
385
+ if args.seed is not None:
386
+ set_seed(args.seed)
387
+
388
+ if args.text_encoder_architecture == "open_clip":
389
+ if args.resume_from_checkpoint:
390
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
391
+ args.resume_from_checkpoint, subfolder="text_encoder", variant=args.variant
392
+ )
393
+ tokenizer = CLIPTokenizer.from_pretrained(
394
+ args.resume_from_checkpoint, subfolder="tokenizer", variant=args.variant
395
+ )
396
+ else:
397
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
398
+ args.pretrained_model_name_or_path, subfolder="text_encoder", variant=args.variant
399
+ )
400
+ tokenizer = CLIPTokenizer.from_pretrained(
401
+ args.pretrained_model_name_or_path, subfolder="tokenizer", variant=args.variant
402
+ )
403
+ else:
404
+ raise ValueError(f"Unknown text encoder architecture: {args.text_encoder_architecture}")
405
+
406
+ vq_model = VQModel.from_pretrained(
407
+ args.pretrained_model_name_or_path, subfolder="vqvae", revision=args.revision, variant=args.variant
408
+ )
409
+
410
+ if args.train_text_encoder:
411
+ if args.text_encoder_use_lora:
412
+ lora_config = LoraConfig(
413
+ r=args.text_encoder_lora_r,
414
+ lora_alpha=args.text_encoder_lora_alpha,
415
+ target_modules=args.text_encoder_lora_target_modules,
416
+ )
417
+ text_encoder.add_adapter(lora_config)
418
+ if args.text_encoder_architecture == "CLIP_T5_base": # Not support yet. Only support open_clip
419
+ text_encoder[0].train()
420
+ text_encoder[0].requires_grad_(True)
421
+ text_encoder[1].train()
422
+ text_encoder[1].requires_grad_(True)
423
+ else:
424
+ text_encoder.train()
425
+ text_encoder.requires_grad_(True)
426
+ else:
427
+ if args.text_encoder_architecture == "CLIP_T5_base": # Not support yet. Only support open_clip
428
+ text_encoder[0].eval()
429
+ text_encoder[0].requires_grad_(False)
430
+ text_encoder[1].eval()
431
+ text_encoder[1].requires_grad_(False)
432
+ else:
433
+ text_encoder.eval()
434
+ text_encoder.requires_grad_(False)
435
+
436
+ vq_model.requires_grad_(False)
437
+
438
+ if args.pretrained_model_architecture == "Meissonic":
439
+ if args.training_from_scratch:
440
+ model = SymmetricTransformer2DModel(
441
+ patch_size = 1,
442
+ in_channels = 64,
443
+ num_layers = 14,
444
+ num_single_layers = 28,
445
+ attention_head_dim = 128,
446
+ num_attention_heads = 8,
447
+ joint_attention_dim = 1024,
448
+ pooled_projection_dim = 1024,
449
+ guidance_embeds = False,
450
+ axes_dims_rope = (16, 56, 56),
451
+ downsample= True,
452
+ upsample= True,
453
+ )
454
+ else:
455
+ orig_model = Transformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer", low_cpu_mem_usage=False, device_map=None)
456
+ orig_config = orig_model.config
457
+ config = {}
458
+ for k, v in orig_config.items():
459
+ if k.startswith("_"):
460
+ continue
461
+ config[k] = v
462
+ config["tokenizer_vocab_size"] = tokenizer.vocab_size
463
+
464
+ model = SymmetricTransformer2DModel(**config)
465
+ model.load_state_dict(orig_model.state_dict(), strict=False)
466
+
467
+ del orig_model
468
+ gc.collect()
469
+ torch.cuda.empty_cache()
470
+ else:
471
+ raise ValueError(f"Unknown model architecture: {args.pretrained_model_architecture}")
472
+
473
+ model = torch.compile(model)
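+ # note: torch.compile wraps the module, so its state-dict keys gain an '_orig_mod.' prefix;
+ # the adap_compile helper in load_model_hook below re-adds this prefix when restoring checkpoints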
474
+
475
+ if args.use_lora:
476
+ lora_config = LoraConfig(
477
+ r=args.lora_r,
478
+ lora_alpha=args.lora_alpha,
479
+ target_modules=args.lora_target_modules,
480
+ )
481
+ model.add_adapter(lora_config)
482
+
483
+ model.train()
484
+ for n, p in model.named_parameters():
485
+ if "text_decoder" in n:
486
+ p.requires_grad_(True)
487
+ else:
488
+ p.requires_grad_(False)
489
+
490
+ if args.gradient_checkpointing:
491
+ model.enable_gradient_checkpointing()
492
+ if args.train_text_encoder:
493
+ if args.text_encoder_architecture == "CLIP_T5_base":  # not supported yet; only open_clip is supported
494
+ text_encoder[0].gradient_checkpointing_enable()
495
+ text_encoder[1].gradient_checkpointing_enable()
496
+ else:
497
+ text_encoder.gradient_checkpointing_enable()
498
+
499
+ if args.use_ema:  # note: the robustness of this branch has not been verified
500
+ ema = EMAModel(
501
+ model.parameters(),
502
+ decay=args.ema_decay,
503
+ update_after_step=args.ema_update_after_step,
504
+ model_cls= Transformer2DModel,
505
+ model_config=model.config,
506
+ )
507
+
508
+ def save_model_hook(models, weights, output_dir):
509
+ if accelerator.is_main_process:
510
+ transformer_lora_layers_to_save = None
511
+ text_encoder_lora_layers_to_save = None
512
+
513
+ for model_ in models:
514
+ if isinstance(model_, type(accelerator.unwrap_model(model))):
515
+ if args.use_lora:
516
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model_)
517
+ else:
518
+ model_.save_pretrained(os.path.join(output_dir, "transformer"))
519
+ elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
520
+ if args.text_encoder_use_lora:
521
+ text_encoder_lora_layers_to_save = get_peft_model_state_dict(model_)
522
+ else:
523
+ model_.save_pretrained(os.path.join(output_dir, "text_encoder"))
524
+ else:
525
+ raise ValueError(f"unexpected save model: {model_.__class__}")
526
+
527
+ # make sure to pop weight so that corresponding model is not saved again
528
+ weights.pop()
529
+
530
+ if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
531
+ LoraLoaderMixin.save_lora_weights(
532
+ output_dir,
533
+ unet_lora_layers=transformer_lora_layers_to_save,
534
+ text_encoder_lora_layers=text_encoder_lora_layers_to_save,
535
+ )
536
+
537
+ if args.use_ema:
538
+ ema.save_pretrained(os.path.join(output_dir, "ema_model"))
539
+
540
+ def load_model_hook(models, input_dir):
541
+ transformer = None
542
+ text_encoder_ = None
543
+
544
+ # this helper keeps state-dict keys consistent after the model has been wrapped with torch.compile()
545
+ def adap_compile(ori_dict):  # add the '_orig_mod.' prefix introduced by torch.compile to each key
546
+ new_dict = {}
547
+ for k,v in ori_dict.items():
548
+ new_dict['_orig_mod.'+k] = v
549
+ return new_dict
550
+
551
+ while len(models) > 0:
552
+ model_ = models.pop()
553
+
554
+ if isinstance(model_, type(accelerator.unwrap_model(model))):
555
+ if args.use_lora:
556
+ transformer = model_
557
+ else:
558
+ if args.pretrained_model_architecture == "Meissonic":
559
+ load_model = SymmetricTransformer2DModel.from_pretrained(os.path.join(input_dir, "transformer"), low_cpu_mem_usage=False, device_map=None)
560
+ else:
561
+ raise ValueError(f"Unknown model architecture: {args.pretrained_model_architecture}")
562
+ model_.load_state_dict(adap_compile(load_model.state_dict()))
563
+ del load_model
564
+ elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
565
+ if args.text_encoder_use_lora:
566
+ text_encoder_ = model_
567
+ else:
568
+ try:
569
+ load_model = CLIPTextModelWithProjection.from_pretrained(os.path.join(input_dir, "text_encoder"))
570
+ model_.load_state_dict(load_model.state_dict())
571
+ # print('finished loading text encoder!')
572
+ except:
573
+ print('No text encoder found in the checkpoint folder; downloading one from the Hugging Face Hub instead.')
574
+ load_model = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
575
+ model_.load_state_dict(load_model.state_dict())
576
+ del load_model
577
+ else:
578
+ raise ValueError(f"unexpected save model: {model.__class__}")
579
+
580
+ if transformer is not None or text_encoder_ is not None:
581
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
582
+ LoraLoaderMixin.load_lora_into_text_encoder(
583
+ lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
584
+ )
585
+ LoraLoaderMixin.load_lora_into_transformer(
586
+ lora_state_dict, network_alphas=network_alphas, transformer=transformer
587
+ )
588
+
589
+ if args.use_ema:
590
+ load_from = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"), model_cls=Transformer2DModel)
591
+ ema.load_state_dict(adap_compile(load_from.state_dict()))
592
+ del load_from
593
+
594
+ accelerator.register_load_state_pre_hook(load_model_hook)
595
+ accelerator.register_save_state_pre_hook(save_model_hook)
596
+
597
+ if args.scale_lr:
598
+ args.learning_rate = (
599
+ args.learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
600
+ )
601
+
602
+ if args.use_8bit_adam:
603
+ try:
604
+ import bitsandbytes as bnb
605
+ except ImportError:
606
+ raise ImportError(
607
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
608
+ )
609
+
610
+ optimizer_cls = bnb.optim.AdamW8bit
611
+ else:
612
+ optimizer_cls = torch.optim.AdamW
613
+
614
+
615
+ optimizer_parameters = [p for p in model.parameters() if p.requires_grad]
616
+
617
+ optimizer = optimizer_cls(
618
+ optimizer_parameters,
619
+ lr=args.learning_rate,
620
+ betas=(args.adam_beta1, args.adam_beta2),
621
+ weight_decay=args.adam_weight_decay,
622
+ eps=args.adam_epsilon,
623
+ )
624
+
625
+ logger.info("Creating dataloaders and lr_scheduler")
626
+
627
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
628
+
629
+ if args.instance_dataset == "MyParquetDataset":
630
+ dataset = ImageCaptionDataset(
631
+ root_dir=args.instance_data_dir, # something like '../parquets_father_dir/'
632
+ tokenizer=tokenizer,
633
+ size=args.resolution,
634
+ text_encoder_architecture=args.text_encoder_architecture
635
+ )
636
+ elif args.instance_dataset == 'HuggingFaceDataset': # you can try this first, just download dataset from huggingface
637
+ dataset = HuggingFaceDataset(
638
+ hf_dataset=load_dataset(args.instance_data_dir, split="train"), # something like './parquets_father_dir/'
639
+ tokenizer=tokenizer,
640
+ image_key='image',
641
+ prompt_key='caption',
642
+ prompt_prefix=args.prompt_prefix,
643
+ size=args.resolution,
644
+ text_encoder_architecture=args.text_encoder_architecture
645
+ )
646
+ elif args.instance_dataset == "DATA_TYPE":
647
+ raise NotImplementedError("DATA_TYPE is not yet supported")
648
+ else:
649
+ raise ValueError(f"Unknown instance_dataset: {args.instance_dataset}")
650
+
651
+ train_dataloader = DataLoader(
652
+ dataset,
653
+ batch_size=args.train_batch_size,
654
+ shuffle=True,
655
+ num_workers=args.dataloader_num_workers,
656
+ collate_fn=default_collate,
657
+ pin_memory=True,
658
+ )
659
+ train_dataloader.num_batches = len(train_dataloader)
660
+
661
+ lr_scheduler = diffusers.optimization.get_scheduler(
662
+ args.lr_scheduler,
663
+ optimizer=optimizer,
664
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
665
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
666
+ )
667
+
668
+ logger.info("Preparing model, optimizer and dataloaders")
669
+
670
+ if args.train_text_encoder:
671
+ if args.text_encoder_architecture == "CLIP_T5_base":  # not supported yet; only open_clip is supported
672
+ model, optimizer, lr_scheduler, train_dataloader, text_encoder[0], text_encoder[1] = accelerator.prepare(
673
+ model, optimizer, lr_scheduler, train_dataloader, text_encoder[0], text_encoder[1]
674
+ )
675
+ else:
676
+ model, optimizer, lr_scheduler, train_dataloader, text_encoder = accelerator.prepare(
677
+ model, optimizer, lr_scheduler, train_dataloader, text_encoder
678
+ )
679
+ else:
680
+ model, optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
681
+ model, optimizer, lr_scheduler, train_dataloader
682
+ )
683
+
684
+ train_dataloader.num_batches = len(train_dataloader)
685
+
686
+ weight_dtype = torch.float32
687
+ if accelerator.mixed_precision == "fp16":
688
+ weight_dtype = torch.float16
689
+ elif accelerator.mixed_precision == "bf16":
690
+ weight_dtype = torch.bfloat16
691
+
692
+ if not args.train_text_encoder:
693
+ if args.text_encoder_architecture == "CLIP_T5_base":  # not supported yet; only open_clip is supported
694
+ text_encoder[0].to(device=accelerator.device, dtype=weight_dtype)
695
+ text_encoder[1].to(device=accelerator.device, dtype=weight_dtype)
696
+ else:
697
+ text_encoder.to(device=accelerator.device, dtype=weight_dtype)
698
+
699
+ vq_model.to(device=accelerator.device)
700
+
701
+ if args.use_ema:
702
+ ema.to(accelerator.device)
703
+
704
+ with nullcontext() if args.train_text_encoder else torch.no_grad():
705
+ if args.text_encoder_architecture == "CLIP_T5_base":  # not supported yet; only open_clip is supported
706
+ _input_ids_tmp_ = tokenize_prompt(tokenizer, "", args.text_encoder_architecture)
707
+ _input_ids_tmp_[0] = _input_ids_tmp_[0].to(accelerator.device, non_blocking=True)
708
+ _input_ids_tmp_[1] = _input_ids_tmp_[1].to(accelerator.device, non_blocking=True)
709
+ empty_embeds, empty_clip_embeds = encode_prompt(
710
+ text_encoder, _input_ids_tmp_, args.text_encoder_architecture
711
+ )
712
+ else:
713
+ empty_embeds, empty_clip_embeds = encode_prompt(
714
+ text_encoder, tokenize_prompt(tokenizer, "", args.text_encoder_architecture).to(accelerator.device, non_blocking=True), args.text_encoder_architecture
715
+ )
716
+
717
+ # There is a single image, we can just pre-encode the single prompt
718
+ if args.instance_data_image is not None:
719
+ prompt = os.path.splitext(os.path.basename(args.instance_data_image))[0]
720
+ if args.text_encoder_architecture == "CLIP_T5_base":  # not supported yet; only open_clip is supported
721
+ _input_ids_tmp_ = tokenize_prompt(tokenizer, prompt, args.text_encoder_architecture)
722
+ _input_ids_tmp_[0] = _input_ids_tmp_[0].to(accelerator.device, non_blocking=True)
723
+ _input_ids_tmp_[1] = _input_ids_tmp_[1].to(accelerator.device, non_blocking=True)
724
+ empty_embeds, empty_clip_embeds = encode_prompt(
725
+ text_encoder, _input_ids_tmp_, args.text_encoder_architecture
726
+ )
727
+ else:
728
+ encoder_hidden_states, cond_embeds = encode_prompt(
729
+ text_encoder, tokenize_prompt(tokenizer, prompt, args.text_encoder_architecture).to(accelerator.device, non_blocking=True), args.text_encoder_architecture
730
+ )
731
+ encoder_hidden_states = encoder_hidden_states.repeat(args.train_batch_size, 1, 1)
732
+ cond_embeds = cond_embeds.repeat(args.train_batch_size, 1)
733
+
734
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
735
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
736
+ # Afterwards we recalculate our number of training epochs.
737
+ # Note: We are not doing epoch based training here, but just using this for book keeping and being able to
738
+ # reuse the same training loop with other datasets/loaders.
739
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
740
+
741
+ # Train!
742
+ logger.info("***** Running training *****")
743
+ logger.info(f" Num training steps = {args.max_train_steps}")
744
+ logger.info(f" Instantaneous batch size per device = { args.train_batch_size}")
745
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
746
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
747
+
748
+ resume_from_checkpoint = args.resume_from_checkpoint
749
+ if resume_from_checkpoint:
750
+ if resume_from_checkpoint == "latest":
751
+ # Get the most recent checkpoint
752
+ dirs = os.listdir(args.output_dir)
753
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
754
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
755
+ if len(dirs) > 0:
756
+ resume_from_checkpoint = os.path.join(args.output_dir, dirs[-1])
757
+ else:
758
+ resume_from_checkpoint = None
759
+
760
+ if resume_from_checkpoint is None:
761
+ accelerator.print(
762
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
763
+ )
764
+ else:
765
+ accelerator.print(f"Resuming from checkpoint {resume_from_checkpoint}")
766
+
767
+ if resume_from_checkpoint is None:
768
+ global_step = 0
769
+ first_epoch = 0
770
+ else:
771
+ accelerator.load_state(resume_from_checkpoint)
772
+ global_step = int(os.path.basename(resume_from_checkpoint).split("-")[1])
773
+ first_epoch = global_step // num_update_steps_per_epoch
774
+
775
+ # This is to solve the inconsistent tensor device issue
776
+ if args.use_ema:
777
+ ema.shadow_params = [p.to(accelerator.device) for p in ema.shadow_params]
778
+
779
+ # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to
780
+ # reuse the same training loop with other datasets/loaders.
781
+ for epoch in range(first_epoch, num_train_epochs):
782
+ for batch in train_dataloader:
783
+ torch.cuda.empty_cache()
784
+ with torch.no_grad():
785
+ micro_conds = batch["micro_conds"].to(accelerator.device, non_blocking=True)
786
+ pixel_values = batch["image"].to(accelerator.device, non_blocking=True)
787
+
788
+ batch_size = pixel_values.shape[0]
789
+
790
+ split_batch_size = args.split_vae_encode if args.split_vae_encode is not None else batch_size
791
+ num_splits = math.ceil(batch_size / split_batch_size)
792
+ image_tokens = []
793
+ for i in range(num_splits):
794
+ start_idx = i * split_batch_size
795
+ end_idx = min((i + 1) * split_batch_size, batch_size)
796
+ bs = pixel_values.shape[0]
797
+ image_tokens.append(
798
+ vq_model.quantize(
799
+ vq_model.encode(
800
+ pixel_values[start_idx: end_idx]
801
+ ).latents
802
+ )[2][2].reshape(end_idx - start_idx, -1)
803
+ )
804
+ image_tokens = torch.cat(image_tokens, dim=0)
805
+
806
+ batch_size, seq_len = image_tokens.shape
807
+
808
+ timesteps = torch.ones(batch_size, device=image_tokens.device)
809
+ mask_prob = torch.cos(timesteps * math.pi * 0.5)
810
+ mask_prob = mask_prob.clip(args.min_masking_rate)
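+ # note: timesteps is fixed to 1.0 above, so cos(pi/2) is ~0 and the image-masking rate
+ # effectively equals args.min_masking_rate; the loss below appears to be computed on the
+ # text tokens instead (labels = prompt_input_ids), consistent with training a text decoder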
811
+
812
+ num_token_masked = (seq_len * mask_prob).round().clamp(min=1)
813
+ batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1)
814
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
815
+
816
+ mask_id = accelerator.unwrap_model(model).config.vocab_size - 1
817
+ input_ids = torch.where(mask, mask_id, image_tokens)
818
+ # labels = torch.where(mask, image_tokens, -100)
819
+
820
+ if "prompt_input_ids" in batch:
821
+ prompt_input_ids = batch["prompt_input_ids"]
822
+ labels = prompt_input_ids
823
+ with nullcontext() if args.train_text_encoder else torch.no_grad():
824
+ if args.text_encoder_architecture == "CLIP_T5_base":  # not supported yet; only open_clip is supported
825
+ batch["prompt_input_ids"][0] = batch["prompt_input_ids"][0].to(accelerator.device, non_blocking=True)
826
+ batch["prompt_input_ids"][1] = batch["prompt_input_ids"][1].to(accelerator.device, non_blocking=True)
827
+ encoder_hidden_states, cond_embeds = encode_prompt(
828
+ text_encoder, batch["prompt_input_ids"], args.text_encoder_architecture
829
+ )
830
+ else:
831
+ encoder_hidden_states, cond_embeds = encode_prompt(
832
+ text_encoder, batch["prompt_input_ids"].to(accelerator.device, non_blocking=True), args.text_encoder_architecture
833
+ )
834
+
835
+ if args.cond_dropout_prob > 0.0:
836
+ assert encoder_hidden_states is not None
837
+
838
+ batch_size = encoder_hidden_states.shape[0]
839
+
840
+ mask = (
841
+ torch.zeros((batch_size, 1, 1), device=encoder_hidden_states.device).float().uniform_(0, 1)
842
+ < args.cond_dropout_prob
843
+ )
844
+
845
+ empty_embeds_ = empty_embeds.expand(batch_size, -1, -1)
846
+ encoder_hidden_states = torch.where(
847
+ (encoder_hidden_states * mask).bool(), encoder_hidden_states, empty_embeds_
848
+ )
849
+
850
+ empty_clip_embeds_ = empty_clip_embeds.expand(batch_size, -1)
851
+ cond_embeds = torch.where((cond_embeds * mask.squeeze(-1)).bool(), cond_embeds, empty_clip_embeds_)
852
+
853
+ bs = input_ids.shape[0]
854
+ vae_scale_factor = 2 ** (len(vq_model.config.block_out_channels) - 1)
855
+ resolution = args.resolution // vae_scale_factor
856
+ input_ids = input_ids.reshape(bs, resolution, resolution)
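+ # the VQ-VAE downsamples by 2 ** (num_blocks - 1); e.g. with 5 block_out_channels a 1024px
+ # image yields a 64x64 grid of discrete tokens, which is the spatial shape restored here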
857
+
858
+ # Train Step
859
+ with accelerator.accumulate(model):
860
+ if args.pretrained_model_architecture == 'Meissonic':
861
+
862
+ if args.resolution == 1024: # only stage 3 and stage 4 do not apply 2*
863
+ img_ids = _prepare_latent_image_ids(input_ids.shape[0], input_ids.shape[-2], input_ids.shape[-1], input_ids.device, input_ids.dtype)
864
+ else:
865
+ img_ids = _prepare_latent_image_ids(input_ids.shape[0], 2 * input_ids.shape[-2], 2 * input_ids.shape[-1], input_ids.device, input_ids.dtype)
866
+
867
+ txt_ids = torch.zeros(encoder_hidden_states.shape[1], 3).to(device = input_ids.device, dtype = input_ids.dtype)
868
+
869
+ logits = model(
870
+ hidden_states=input_ids, # should be (batch size, channel, height, width)
871
+ encoder_hidden_states=encoder_hidden_states, # should be (batch size, sequence_len, embed_dims)
872
+ micro_conds=micro_conds, #
873
+ pooled_projections=cond_embeds, # should be (batch_size, projection_dim)
874
+ img_ids=img_ids,
875
+ txt_ids=txt_ids,
876
+ timestep=mask_prob * 1000,
877
+ )[1]
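+ # output index [1] is assumed to be the text-branch logits (sized to the tokenizer vocabulary),
+ # matching the text-token labels above; index [0] holds the image-token logits used in
+ # train_text_encoder.py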
878
+ # print(logits.shape)
879
+ logits = logits.reshape(-1, tokenizer.vocab_size)
880
+
881
+ else:
882
+ raise ValueError(f"Unknown model architecture: {args.pretrained_model_architecture}")
883
+
884
+ loss = F.cross_entropy(
885
+ logits,
886
+ labels.view(-1),
887
+ ignore_index=-100,
888
+ reduction="mean",
889
+ )
890
+
891
+ # Gather the losses across all processes for logging (if we use distributed training).
892
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
893
+ avg_masking_rate = accelerator.gather(mask_prob.repeat(args.train_batch_size)).mean()
894
+
895
+ accelerator.backward(loss)
896
+
897
+ if args.max_grad_norm is not None and accelerator.sync_gradients:
898
+ accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
899
+
900
+ optimizer.step()
901
+ lr_scheduler.step()
902
+
903
+ optimizer.zero_grad(set_to_none=True)
904
+
905
+ # Checks if the accelerator has performed an optimization step behind the scenes
906
+ if accelerator.sync_gradients:
907
+ if args.use_ema:
908
+ ema.step(model.parameters())
909
+
910
+ if (global_step + 1) % args.logging_steps == 0:
911
+ logs = {
912
+ "step_loss": avg_loss.item(),
913
+ "lr": lr_scheduler.get_last_lr()[0],
914
+ "avg_masking_rate": avg_masking_rate.item(),
915
+ }
916
+ accelerator.log(logs, step=global_step + 1)
917
+
918
+ logger.info(
919
+ f"Step: {global_step + 1} "
920
+ f"Loss: {avg_loss.item():0.4f} "
921
+ f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}"
922
+ )
923
+
924
+ if (global_step + 1) % args.checkpointing_steps == 0:
925
+ save_checkpoint(args, accelerator, global_step + 1, logger)
926
+
927
+ if (global_step + 1) % args.validation_steps == 0 and accelerator.is_main_process:
928
+ if args.use_ema:
929
+ ema.store(model.parameters())
930
+ ema.copy_to(model.parameters())
931
+
932
+ with torch.no_grad():
933
+ logger.info("Generating images...")
934
+
935
+ model.eval()
936
+
937
+ scheduler = Scheduler.from_pretrained(
938
+ args.pretrained_model_name_or_path,
939
+ subfolder="scheduler",
940
+ revision=args.revision,
941
+ variant=args.variant,
942
+ )
943
+
944
+ pipe = UnifiedPipeline(
945
+ transformer=accelerator.unwrap_model(model),
946
+ tokenizer=tokenizer,
947
+ text_encoder=text_encoder,
948
+ vqvae=vq_model,
949
+ scheduler=scheduler,
950
+ )
951
+
952
+ output = pipe(
953
+ prompt=args.validation_prompts,
954
+ height=args.resolution,
955
+ width=args.resolution,
956
+ guidance_scale=9,
957
+ num_inference_steps=64
958
+ )
959
+ pil_images = output.images
960
+ prompts = output.prompts
961
+ print(prompts)
962
+
963
+ result=[]
964
+ for img in pil_images:
965
+ if not isinstance(img, torch.Tensor):
966
+ img = transforms.ToTensor()(img)
967
+ result.append(img.unsqueeze(0))
968
+ result = torch.cat(result,dim=0)
969
+ result = make_grid(result, nrow=3)
970
+ save_image(result,os.path.join(args.output_dir, str(global_step)+'_text2image_1024_CFG-9.png'))
971
+
972
+ # Save the prompts and image filenames as JSON
973
+ output_data = {
974
+ "step": global_step,
975
+ "prompts": prompts,
976
+ "images": [f"{global_step}_text2image_1024_CFG-9_{i}.png" for i in range(len(pil_images))]
977
+ }
978
+
979
+ with open(os.path.join(args.output_dir, f"prompts_{global_step}.json"), "w") as f:
980
+ json.dump(output_data, f, indent=2)
981
+
982
+ model.train()
983
+
984
+ if args.train_text_encoder:
985
+ if args.text_encoder_architecture == "CLIP_T5_base":  # not supported yet; only open_clip is supported
986
+ text_encoder[0].train()
987
+ text_encoder[1].train()
988
+ else:
989
+ text_encoder.train()
990
+
991
+ if args.use_ema:
992
+ ema.restore(model.parameters())
993
+
994
+ global_step += 1
995
+
996
+ # Stop training if max steps is reached
997
+ if global_step >= args.max_train_steps:
998
+ break
999
+ # End for
1000
+
1001
+ accelerator.wait_for_everyone()
1002
+
1003
+ # Evaluate and save checkpoint at the end of training
1004
+ save_checkpoint(args, accelerator, global_step, logger)
1005
+
1006
+ # Save the final trained checkpoint
1007
+ if accelerator.is_main_process:
1008
+ model = accelerator.unwrap_model(model)
1009
+ if args.use_ema:
1010
+ ema.copy_to(model.parameters())
1011
+ model.save_pretrained(args.output_dir)
1012
+
1013
+ accelerator.end_training()
1014
+
1015
+
1016
+ if __name__ == "__main__":
1017
+ main(parse_args())
train/train_text_decoder.sh ADDED
@@ -0,0 +1,31 @@
1
+ # Run this script from the repository root
2
+ PYTHON_PATH='./' accelerate launch --multi_gpu --gpu_ids '2,3' --main_process_port 25001 --num_processes 2 train/train_text_decoder.py \
3
+ --output_dir "./outputs/football/" \
4
+ --train_batch_size 16 \
5
+ --gradient_accumulation_steps 2 \
6
+ --learning_rate 1e-5 \
7
+ --max_grad_norm 10 \
8
+ --pretrained_model_name_or_path "meissonflow/meissonic" \
9
+ --text_encoder_architecture 'open_clip' \
10
+ --pretrained_model_architecture 'Meissonic' \
11
+ --instance_dataset 'MyParquetDataset' \
12
+ --instance_data_dir '/data/sqy/0000/' \
13
+ --resolution 1024 \
14
+ --mixed_precision fp16 \
15
+ --lr_scheduler constant \
16
+ --use_8bit_adam \
17
+ --dataloader_num_workers 0 \
18
+ --validation_prompts \
19
+ 'a boy' \
20
+ 'A serene mountain landscape with towering snow-capped peaks, a crystal-clear blue lake reflecting the mountains, dense pine forests, and a vibrant orange sunrise illuminating the sky.' \
21
+ 'A playful golden retriever puppy with a shiny coat, bounding through a meadow filled with colorful wildflowers, under a bright, clear blue sky.' \
22
+ 'A bustling city street at night, illuminated by vibrant neon signs in various colors, with busy pedestrians, street vendors, and a light rain creating reflective puddles on the pavement.' \
23
+ 'A majestic, medieval castle perched on a rugged cliffside, overlooking a vast, calm ocean at sunset, with the sky painted in hues of pink, orange, and purple.' \
24
+ 'An elegant ballerina in a white tutu, dancing gracefully on a grand stage with ornate, gold-trimmed curtains, under a spotlight that casts a soft glow.' \
25
+ 'A cozy, rustic log cabin nestled in a snow-covered forest, with smoke rising from the stone chimney, warm lights glowing from the windows, and a path of footprints leading to the front door.' \
26
+ 'A Cute Cat' \
27
+ 'A Snow Mountain' \
28
+ --max_train_steps 30000 \
29
+ --checkpointing_steps 1000 \
30
+ --validation_steps 100 \
31
+ --logging_steps 10
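+ # effective batch size with these flags: 16 per device x 2 processes x 2 accumulation steps = 64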
train/train_text_encoder.py ADDED
@@ -0,0 +1,903 @@
1
+ import argparse
2
+ import copy
3
+ import logging
4
+ import math
5
+ import os
6
+ import json
7
+ from contextlib import nullcontext
8
+ from pathlib import Path
9
+ import sys
10
+ sys.path.append(os.getcwd())
11
+ import gc
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ from accelerate import Accelerator
17
+ from accelerate.logging import get_logger
18
+ from accelerate.utils import ProjectConfiguration, set_seed
19
+ from peft import LoraConfig
20
+ from peft.utils import get_peft_model_state_dict
21
+
22
+ from torch.utils.data import DataLoader, default_collate
23
+ from torchvision import transforms
24
+ from torchvision.utils import save_image,make_grid
25
+
26
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
27
+ from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
28
+ from transformers import (
29
+ CLIPTextModelWithProjection,
30
+ CLIPTokenizer,
31
+ )
32
+
33
+ import diffusers.optimization
34
+ from diffusers import EMAModel, VQModel
35
+ from diffusers.loaders import LoraLoaderMixin
36
+ from diffusers.utils import is_wandb_available
37
+
38
+ from src.scheduler import Scheduler
39
+ from src.pipeline import UnifiedPipeline
40
+
41
+ from train.trainer_utils import save_checkpoint
42
+ from train.dataset_utils import ImageCaptionLargeDataset
43
+ from train.dataset_utils import tokenize_prompt, encode_prompt
44
+ from src.transformer import SymmetricTransformer2DModel
45
+
46
+ if is_wandb_available():
47
+ import wandb
48
+ # wandb.login(key="")
49
+
50
+ logger = get_logger(__name__, log_level="INFO")
51
+
52
+ import torch._dynamo
53
+ torch._dynamo.config.verbose = True
54
+
55
+ # Optionally suppress errors to fall back to eager execution
56
+ torch._dynamo.config.suppress_errors = True
57
+
58
+ def parse_args():
59
+ parser = argparse.ArgumentParser()
60
+ parser.add_argument(
61
+ "--pretrained_model_name_or_path",
62
+ type=str,
63
+ default=None,
64
+ required=True,
65
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
66
+ )
67
+ parser.add_argument(
68
+ "--pretrained_transformer_path",
69
+ type=str,
70
+ default=None,
71
+ required=True,
72
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
73
+ )
74
+ parser.add_argument(
75
+ "--text_encoder_architecture",
76
+ type=str,
77
+ default="open_clip",
78
+ required=False,
79
+ help="The architecture of the text encoder. One of ['CLIP', 'open_clip', 'flan-t5-base','Qwen2-0.5B','gemini-2b', 'CLIP_T5_base']",
80
+ )
81
+ parser.add_argument(
82
+ "--text_encoder_name_or_path",
83
+ type=str,
84
+ default=None,
85
+ required=True,
86
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
87
+ )
88
+ parser.add_argument(
89
+ "--instance_dataset",
90
+ type=str,
91
+ default=None,
92
+ required=False,
93
+ help="The dataset to use for training. One of ['MSCOCO600K', 'PickaPicV2']",
94
+ )
95
+ parser.add_argument(
96
+ "--instance_data_dir",
97
+ type=str,
98
+ default=None,
99
+ required=False,
100
+ help="A folder containing the training data of instance images.",
101
+ )
102
+ parser.add_argument(
103
+ "--training_from_scratch",
104
+ type=bool,
105
+ default=False,
106
+ required=False
107
+ )
108
+ parser.add_argument(
109
+ "--revision",
110
+ type=str,
111
+ default=None,
112
+ required=False,
113
+ help="Revision of pretrained model identifier from huggingface.co/models.",
114
+ )
115
+ parser.add_argument(
116
+ "--variant",
117
+ type=str,
118
+ default=None,
119
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
120
+ )
121
+ parser.add_argument(
122
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
123
+ )
124
+ parser.add_argument(
125
+ "--dataloader_num_workers",
126
+ type=int,
127
+ default=0,
128
+ help=(
129
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
130
+ ),
131
+ )
132
+ parser.add_argument(
133
+ "--allow_tf32",
134
+ action="store_true",
135
+ help=(
136
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
137
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
138
+ ),
139
+ )
140
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
141
+ parser.add_argument("--ema_decay", type=float, default=0.9999)
142
+ parser.add_argument("--ema_update_after_step", type=int, default=0)
143
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
144
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
145
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
146
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
147
+ parser.add_argument(
148
+ "--output_dir",
149
+ type=str,
150
+ default="muse_training",
151
+ help="The output directory where the model predictions and checkpoints will be written.",
152
+ )
153
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
154
+ parser.add_argument(
155
+ "--logging_dir",
156
+ type=str,
157
+ default="logs",
158
+ help=(
159
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
160
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
161
+ ),
162
+ )
163
+ parser.add_argument(
164
+ "--max_train_steps",
165
+ type=int,
166
+ default=None,
167
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
168
+ )
169
+ parser.add_argument(
170
+ "--checkpointing_steps",
171
+ type=int,
172
+ default=500,
173
+ help=(
174
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
175
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
176
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
177
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
178
+ "instructions."
179
+ ),
180
+ )
181
+ parser.add_argument(
182
+ "--logging_steps",
183
+ type=int,
184
+ default=50,
185
+ )
186
+ parser.add_argument(
187
+ "--checkpoints_total_limit",
188
+ type=int,
189
+ default=None,
190
+ help=(
191
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
192
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
193
+ " for more details"
194
+ ),
195
+ )
196
+ parser.add_argument(
197
+ "--resume_from_checkpoint",
198
+ type=str,
199
+ default=None,
200
+ help=(
201
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
202
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
203
+ ),
204
+ )
205
+ parser.add_argument(
206
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
207
+ )
208
+ parser.add_argument(
209
+ "--gradient_accumulation_steps",
210
+ type=int,
211
+ default=1,
212
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
213
+ )
214
+ parser.add_argument(
215
+ "--learning_rate",
216
+ type=float,
217
+ default=0.0003,
218
+ help="Initial learning rate (after the potential warmup period) to use.",
219
+ )
220
+ parser.add_argument(
221
+ "--scale_lr",
222
+ action="store_true",
223
+ default=False,
224
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
225
+ )
226
+ parser.add_argument(
227
+ "--lr_scheduler",
228
+ type=str,
229
+ default="constant",
230
+ help=(
231
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
232
+ ' "constant", "constant_with_warmup"]'
233
+ ),
234
+ )
235
+ parser.add_argument(
236
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
237
+ )
238
+ parser.add_argument(
239
+ "--validation_steps",
240
+ type=int,
241
+ default=100,
242
+ help=(
243
+ "Run validation every X steps. Validation consists of running the prompt"
244
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
245
+ " and logging the images."
246
+ ),
247
+ )
248
+ parser.add_argument(
249
+ "--mixed_precision",
250
+ type=str,
251
+ default=None,
252
+ choices=["no", "fp16", "bf16"],
253
+ help=(
254
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
255
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
256
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
257
+ ),
258
+ )
259
+ parser.add_argument(
260
+ "--report_to",
261
+ type=str,
262
+ default="wandb",
263
+ help=(
264
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
265
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
266
+ ),
267
+ )
268
+ parser.add_argument("--validation_prompts", type=str, nargs="*")
269
+ parser.add_argument(
270
+ "--resolution",
271
+ type=int,
272
+ default=512,
273
+ help=(
274
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
275
+ " resolution"
276
+ ),
277
+ )
278
+ parser.add_argument("--split_vae_encode", type=int, required=False, default=None)
279
+ parser.add_argument("--min_masking_rate", type=float, default=0.0)
280
+ parser.add_argument("--cond_dropout_prob", type=float, default=0.0)
281
+ parser.add_argument("--max_grad_norm", default=50.0, type=float, help="Max gradient norm.", required=False)
282
+ parser.add_argument("--use_lora", action="store_true", help="Fine tune the model using LoRa")
283
+ parser.add_argument("--text_encoder_use_lora", action="store_true", help="Fine tune the model using LoRa")
284
+ parser.add_argument("--lora_r", default=16, type=int)
285
+ parser.add_argument("--lora_alpha", default=32, type=int)
286
+ parser.add_argument("--lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
287
+ parser.add_argument("--text_encoder_lora_r", default=16, type=int)
288
+ parser.add_argument("--text_encoder_lora_alpha", default=32, type=int)
289
+ parser.add_argument("--text_encoder_lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
290
+ parser.add_argument("--train_text_encoder", action="store_true")
291
+ parser.add_argument("--image_key", type=str, required=False)
292
+ parser.add_argument("--prompt_key", type=str, required=False)
293
+ parser.add_argument(
294
+ "--gradient_checkpointing",
295
+ action="store_true",
296
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
297
+ )
298
+ parser.add_argument("--prompt_prefix", type=str, required=False, default=None)
299
+
300
+ args = parser.parse_args()
301
+
302
+ if args.report_to == "wandb":
303
+ if not is_wandb_available():
304
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
305
+
306
+ if args.instance_data_dir is not None:
307
+ if not os.path.exists(args.instance_data_dir):
308
+ raise ValueError(f"Does not exist: `--args.instance_data_dir` {args.instance_data_dir}")
309
+
310
+ return args
311
+
312
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
313
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
314
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
315
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
316
+
317
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
318
+
319
+ latent_image_ids = latent_image_ids.reshape(
320
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
321
+ )
322
+ # latent_image_ids = latent_image_ids.unsqueeze(0).repeat(batch_size, 1, 1)
323
+
324
+ return latent_image_ids.to(device=device, dtype=dtype)
325
+
326
+ def main(args):
327
+ if args.allow_tf32:
328
+ torch.backends.cuda.matmul.allow_tf32 = True
329
+
330
+
331
+ logging_dir = Path(args.output_dir, args.logging_dir)
332
+
333
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
334
+
335
+ accelerator = Accelerator(
336
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
337
+ mixed_precision=args.mixed_precision,
338
+ log_with=args.report_to,
339
+ project_config=accelerator_project_config,
340
+ )
341
+
342
+ if accelerator.is_main_process:
343
+ os.makedirs(args.output_dir, exist_ok=True)
344
+
345
+ # Make one log on every process with the configuration for debugging.
346
+ logging.basicConfig(
347
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
348
+ datefmt="%m/%d/%Y %H:%M:%S",
349
+ level=logging.INFO,
350
+ )
351
+ logger.info(accelerator.state, main_process_only=False)
352
+
353
+ if accelerator.is_main_process:
354
+ accelerator.init_trackers("muddit", config=vars(copy.deepcopy(args)))
355
+
356
+ if args.seed is not None:
357
+ set_seed(args.seed)
358
+
359
+ if args.text_encoder_architecture == "gemma":
360
+ text_encoder_one = CLIPTextModelWithProjection.from_pretrained(
361
+ args.pretrained_model_name_or_path, subfolder="text_encoder", variant=args.variant
362
+ )
363
+ tokenizer_one = CLIPTokenizer.from_pretrained(
364
+ args.pretrained_model_name_or_path, subfolder="tokenizer", variant=args.variant
365
+ )
366
+
367
+ text_encoder_two = Gemma2Model.from_pretrained(
368
+ args.text_encoder_name_or_path, variant=args.variant
369
+ )
370
+ tokenizer_two = GemmaTokenizerFast.from_pretrained(
371
+ args.text_encoder_name_or_path, variant=args.variant
372
+ )
373
+ t5_dim = text_encoder_two.config.hidden_size
374
+
375
+ text_encoder = [text_encoder_one, text_encoder_two]
376
+ tokenizer = [tokenizer_one, tokenizer_two]
377
+
378
+ text_encoder_one.requires_grad_(False)
379
+ text_encoder_two.requires_grad_(False)
380
+ else:
381
+ raise ValueError(f"Unknown text encoder architecture: {args.text_encoder_architecture}")
382
+
383
+ vq_model = VQModel.from_pretrained(
384
+ args.pretrained_model_name_or_path, subfolder="vqvae", revision=args.revision, variant=args.variant
385
+ )
386
+ vq_model.requires_grad_(False)
387
+
388
+ model = SymmetricTransformer2DModel.from_pretrained(
389
+ args.pretrained_model_name_or_path if args.pretrained_transformer_path is None else args.pretrained_transformer_path,
390
+ subfolder="transformer",
391
+ low_cpu_mem_usage=False,
392
+ device_map=None,
393
+ )
394
+
395
+ if args.pretrained_transformer_path is None and model.adapter is None:
396
+ model.register_to_config(t5_dim=t5_dim)
397
+ model.adapter = nn.Sequential(
398
+ nn.LayerNorm(t5_dim, elementwise_affine=False, eps=1e-6),
399
+ nn.Linear(t5_dim, model.config.joint_attention_dim, bias=False)
400
+ )
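+ # the adapter is assumed to project Gemma-2 hidden states (t5_dim keeps its name from an earlier
+ # T5 setup) into the transformer's joint_attention_dim so they can serve as encoder_hidden_states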
401
+
402
+ model.requires_grad_(True)
403
+ model.train()
404
+ model = torch.compile(model)
405
+
406
+ if args.gradient_checkpointing:
407
+ model.enable_gradient_checkpointing()
408
+
409
+ if args.use_ema:  # note: the robustness of this branch has not been verified
410
+ ema = EMAModel(
411
+ model.parameters(),
412
+ decay=args.ema_decay,
413
+ update_after_step=args.ema_update_after_step,
414
+ model_cls=SymmetricTransformer2DModel,
415
+ model_config=model.config,
416
+ )
417
+
418
+ def save_model_hook(models, weights, output_dir):
419
+ if accelerator.is_main_process:
420
+ transformer_lora_layers_to_save = None
421
+ text_encoder_lora_layers_to_save = None
422
+
423
+ for model_ in models:
424
+ if isinstance(model_, type(accelerator.unwrap_model(model))):
425
+ if args.use_lora:
426
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model_)
427
+ else:
428
+ model_.save_pretrained(os.path.join(output_dir, "transformer"))
429
+ elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
430
+ if args.text_encoder_use_lora:
431
+ text_encoder_lora_layers_to_save = get_peft_model_state_dict(model_)
432
+ else:
433
+ model_.save_pretrained(os.path.join(output_dir, "text_encoder"))
434
+ else:
435
+ raise ValueError(f"unexpected save model: {model_.__class__}")
436
+
437
+ # make sure to pop weight so that corresponding model is not saved again
438
+ weights.pop()
439
+
440
+ if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
441
+ LoraLoaderMixin.save_lora_weights(
442
+ output_dir,
443
+ unet_lora_layers=transformer_lora_layers_to_save,
444
+ text_encoder_lora_layers=text_encoder_lora_layers_to_save,
445
+ )
446
+
447
+ if args.use_ema:
448
+ ema.save_pretrained(os.path.join(output_dir, "ema_model"))
449
+
450
+ def load_model_hook(models, input_dir):
451
+ transformer = None
452
+ text_encoder_ = None
453
+
454
+ # this helper keeps state-dict keys consistent after the model has been wrapped with torch.compile()
455
+ def adap_compile(ori_dict):  # add the '_orig_mod.' prefix introduced by torch.compile to each key
456
+ new_dict = {}
457
+ for k,v in ori_dict.items():
458
+ new_dict['_orig_mod.'+k] = v
459
+ return new_dict
460
+
461
+ while len(models) > 0:
462
+ model_ = models.pop()
463
+
464
+ if isinstance(model_, type(accelerator.unwrap_model(model))):
465
+ if args.use_lora:
466
+ transformer = model_
467
+ else:
468
+ load_model = SymmetricTransformer2DModel.from_pretrained(os.path.join(input_dir, "transformer"), low_cpu_mem_usage=False, device_map=None)
469
+ model_.load_state_dict(adap_compile(load_model.state_dict()))
470
+ del load_model
471
+ else:
472
+ raise ValueError(f"unexpected save model: {model.__class__}")
473
+
474
+ if transformer is not None or text_encoder_ is not None:
475
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
476
+ LoraLoaderMixin.load_lora_into_text_encoder(
477
+ lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
478
+ )
479
+ LoraLoaderMixin.load_lora_into_transformer(
480
+ lora_state_dict, network_alphas=network_alphas, transformer=transformer
481
+ )
482
+
483
+ if args.use_ema:
484
+ load_from = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"), model_cls=SymmetricTransformer2DModel)
485
+ ema.load_state_dict(adap_compile(load_from.state_dict()))
486
+ del load_from
487
+
488
+ accelerator.register_load_state_pre_hook(load_model_hook)
489
+ accelerator.register_save_state_pre_hook(save_model_hook)
490
+
491
+ if args.scale_lr:
492
+ args.learning_rate = (
493
+ args.learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
494
+ )
495
+
496
+ if args.use_8bit_adam:
497
+ try:
498
+ import bitsandbytes as bnb
499
+ except ImportError:
500
+ raise ImportError(
501
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
502
+ )
503
+
504
+ optimizer_cls = bnb.optim.AdamW8bit
505
+ else:
506
+ optimizer_cls = torch.optim.AdamW
507
+
508
+ optimizer_parameters = [p for p in model.parameters() if p.requires_grad]
509
+ optimizer = optimizer_cls(
510
+ optimizer_parameters,
511
+ lr=args.learning_rate,
512
+ betas=(args.adam_beta1, args.adam_beta2),
513
+ weight_decay=args.adam_weight_decay,
514
+ eps=args.adam_epsilon,
515
+ )
516
+
517
+ logger.info("Creating dataloaders and lr_scheduler")
518
+
519
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
520
+
521
+ if args.instance_dataset == "ImageCaptionLargeDataset":
522
+ dataset = ImageCaptionLargeDataset(
523
+ root_dir=args.instance_data_dir,
524
+ tokenizer=tokenizer,
525
+ size=args.resolution,
526
+ text_encoder_architecture=args.text_encoder_architecture
527
+ )
528
+ elif args.instance_dataset == "DATA_TYPE":
529
+ raise NotImplementedError("DATA_TYPE is not yet supported")
530
+ else:
531
+ raise ValueError(f"Unknown instance_dataset: {args.instance_dataset}")
532
+
533
+ def collate_fn(samples):
534
+ images = [sample["image"] for sample in samples]
535
+ micro_conds = [sample["micro_conds"] for sample in samples]
536
+
537
+ images = torch.stack(images, dim=0)
538
+ micro_conds = torch.stack(micro_conds, dim=0)
539
+
540
+ if isinstance(samples[0]["prompt_input_ids"], list):
541
+ input_ids = [sample["prompt_input_ids"][0] for sample in samples]
542
+ input_ids_2 = [sample["prompt_input_ids"][1] for sample in samples]
543
+
544
+ input_ids = torch.cat(input_ids, dim=0)
545
+ input_ids_2 = torch.cat(input_ids_2, dim=0)
546
+ prompt_input_ids = [input_ids, input_ids_2]
547
+
548
+ elif isinstance(samples[0]["prompt_input_ids"], torch.Tensor):
549
+ input_ids = [sample["prompt_input_ids"] for sample in samples]
550
+
551
+ input_ids = torch.cat(input_ids, dim=0)
552
+ prompt_input_ids = input_ids
553
+
554
+ ret = dict(
555
+ images=images,
556
+ micro_conds=micro_conds,
557
+ prompt_input_ids=prompt_input_ids,
558
+ )
559
+
560
+ return ret
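+ # collate_fn handles both the dual-tokenizer case (prompt_input_ids is a [CLIP, Gemma] pair of
+ # tensors) and the single-tokenizer case (a plain tensor), stacking images and micro_conds as usual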
561
+
562
+ train_dataloader = DataLoader(
563
+ dataset,
564
+ batch_size=args.train_batch_size,
565
+ shuffle=True,
566
+ num_workers=args.dataloader_num_workers,
567
+ collate_fn=collate_fn,
568
+ pin_memory=True,
569
+ )
570
+ train_dataloader.num_batches = len(train_dataloader)
571
+
572
+ lr_scheduler = diffusers.optimization.get_scheduler(
573
+ args.lr_scheduler,
574
+ optimizer=optimizer,
575
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
576
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
577
+ )
578
+
579
+ logger.info("Preparing model, optimizer and dataloaders")
580
+
581
+ if args.train_text_encoder:
582
+ if args.text_encoder_architecture == "CLIP_T5_base":  # not supported yet; only open_clip is supported
583
+ model, optimizer, lr_scheduler, train_dataloader, text_encoder[0], text_encoder[1] = accelerator.prepare(
584
+ model, optimizer, lr_scheduler, train_dataloader, text_encoder[0], text_encoder[1]
585
+ )
586
+ else:
587
+ model, optimizer, lr_scheduler, train_dataloader, text_encoder = accelerator.prepare(
588
+ model, optimizer, lr_scheduler, train_dataloader, text_encoder
589
+ )
590
+ else:
591
+ model, optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
592
+ model, optimizer, lr_scheduler, train_dataloader
593
+ )
594
+
595
+ train_dataloader.num_batches = len(train_dataloader)
596
+
597
+ weight_dtype = torch.float32
598
+ if accelerator.mixed_precision == "fp16":
599
+ weight_dtype = torch.float16
600
+ elif accelerator.mixed_precision == "bf16":
601
+ weight_dtype = torch.bfloat16
602
+
603
+ if not args.train_text_encoder:
604
+ if args.text_encoder_architecture in ("t5_clip", "gemma"): # Not support yet. Only support open_clip
605
+ text_encoder[0].to(device=accelerator.device, dtype=weight_dtype)
606
+ text_encoder[1].to(device=accelerator.device, dtype=weight_dtype)
607
+ else:
608
+ text_encoder.to(device=accelerator.device, dtype=weight_dtype)
609
+
610
+ vq_model.to(device=accelerator.device)
611
+
612
+ if args.use_ema:
613
+ ema.to(accelerator.device)
614
+
615
+ with nullcontext() if args.train_text_encoder else torch.no_grad():
616
+ if args.text_encoder_architecture in ("t5_clip", "gemma"): # Not support yet. Only support open_clip
617
+ _input_ids_tmp_ = tokenize_prompt(tokenizer, "", args.text_encoder_architecture)
618
+ _input_ids_tmp_[0] = _input_ids_tmp_[0].to(accelerator.device, non_blocking=True)
619
+ _input_ids_tmp_[1] = _input_ids_tmp_[1].to(accelerator.device, non_blocking=True)
620
+ empty_embeds, empty_clip_embeds = encode_prompt(
621
+ text_encoder,
622
+ _input_ids_tmp_,
623
+ args.text_encoder_architecture
624
+ )
625
+ else:
626
+ empty_embeds, empty_clip_embeds = encode_prompt(
627
+ text_encoder,
628
+ tokenize_prompt(tokenizer, "", args.text_encoder_architecture).to(accelerator.device, non_blocking=True),
629
+ args.text_encoder_architecture
630
+ )
631
+
632
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
633
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
634
+ # Afterwards we recalculate our number of training epochs.
635
+ # Note: We are not doing epoch based training here, but just using this for book keeping and being able to
636
+ # reuse the same training loop with other datasets/loaders.
637
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
638
+
639
+ # Train!
640
+ logger.info("***** Running training *****")
641
+ logger.info(f" Num training steps = {args.max_train_steps}")
642
+ logger.info(f" Instantaneous batch size per device = { args.train_batch_size}")
643
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
644
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
645
+
646
+ resume_from_checkpoint = args.resume_from_checkpoint
647
+ if resume_from_checkpoint:
648
+ if resume_from_checkpoint == "latest":
649
+ # Get the most recent checkpoint
650
+ dirs = os.listdir(args.output_dir)
651
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
652
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
653
+ if len(dirs) > 0:
654
+ resume_from_checkpoint = os.path.join(args.output_dir, dirs[-1])
655
+ else:
656
+ resume_from_checkpoint = None
657
+
658
+ if resume_from_checkpoint is None:
659
+ accelerator.print(
660
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
661
+ )
662
+ else:
663
+ accelerator.print(f"Resuming from checkpoint {resume_from_checkpoint}")
664
+
665
+ if resume_from_checkpoint is None:
666
+ global_step = 0
667
+ first_epoch = 0
668
+ else:
669
+ accelerator.load_state(resume_from_checkpoint)
670
+ global_step = int(os.path.basename(resume_from_checkpoint).split("-")[1])
671
+ first_epoch = global_step // num_update_steps_per_epoch
672
+
673
+ # This is to solve the inconsistent tensor device issue
674
+ if args.use_ema:
675
+ ema.shadow_params = [p.to(accelerator.device) for p in ema.shadow_params]
676
+
677
+ # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to
678
+ # reuse the same training loop with other datasets/loaders.
679
+ for epoch in range(first_epoch, num_train_epochs):
680
+ for batch in train_dataloader:
681
+ torch.cuda.empty_cache()
682
+ with torch.no_grad():
683
+ micro_conds = batch["micro_conds"].to(accelerator.device, non_blocking=True)
684
+ pixel_values = batch["images"].to(accelerator.device, non_blocking=True)
685
+
686
+ batch_size = pixel_values.shape[0]
687
+
688
+ # ====================== tokenize images ======================
689
+ split_batch_size = args.split_vae_encode if args.split_vae_encode is not None else batch_size
690
+ num_splits = math.ceil(batch_size / split_batch_size)
691
+ image_tokens = []
692
+ for i in range(num_splits):
693
+ start_idx = i * split_batch_size
694
+ end_idx = min((i + 1) * split_batch_size, batch_size)
695
+ bs = pixel_values.shape[0]
696
+ image_tokens.append(
697
+ vq_model.quantize(
698
+ vq_model.encode(
699
+ pixel_values[start_idx: end_idx]
700
+ ).latents
701
+ )[2][2].reshape(end_idx - start_idx, -1)
702
+ )
703
+ image_tokens = torch.cat(image_tokens, dim=0)
704
+ # ====================== tokenize images ======================
705
+
706
+ batch_size, seq_len = image_tokens.shape
707
+ timesteps = torch.rand(batch_size, device=image_tokens.device)
708
+ mask_prob = torch.cos(timesteps * math.pi * 0.5)
709
+ mask_prob = mask_prob.clip(args.min_masking_rate)
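+ # cosine masking schedule: t ~ U(0, 1) maps to a masking probability cos(pi * t / 2) in (0, 1],
+ # clipped from below by args.min_masking_rate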
710
+
711
+ num_token_masked = (seq_len * mask_prob).round().clamp(min=1)
712
+ batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1)
713
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
714
+
715
+ mask_id = accelerator.unwrap_model(model).config.vocab_size - 1
716
+ input_ids = torch.where(mask, mask_id, image_tokens)
717
+ labels = torch.where(mask, image_tokens, -100)
718
+
719
+ if "prompt_input_ids" in batch:
720
+ with nullcontext() if args.train_text_encoder else torch.no_grad():
721
+ if args.text_encoder_architecture in ("t5_clip", "gemma"): # Not support yet. Only support open_clip
722
+ batch["prompt_input_ids"][0] = batch["prompt_input_ids"][0].to(accelerator.device, non_blocking=True)
723
+ batch["prompt_input_ids"][1] = batch["prompt_input_ids"][1].to(accelerator.device, non_blocking=True)
724
+ encoder_hidden_states, cond_embeds = encode_prompt(
725
+ text_encoder,
726
+ batch["prompt_input_ids"],
727
+ args.text_encoder_architecture
728
+ )
729
+ else:
730
+ encoder_hidden_states, cond_embeds = encode_prompt(
731
+ text_encoder,
732
+ batch["prompt_input_ids"].to(accelerator.device, non_blocking=True),
733
+ args.text_encoder_architecture
734
+ )
735
+
736
+ if args.cond_dropout_prob > 0.0:
737
+ assert encoder_hidden_states is not None
738
+
739
+ batch_size = encoder_hidden_states.shape[0]
740
+
741
+ mask = (
742
+ torch.zeros((batch_size, 1, 1), device=encoder_hidden_states.device).float().uniform_(0, 1)
743
+ < args.cond_dropout_prob
744
+ )
745
+
746
+ empty_embeds_ = empty_embeds.expand(batch_size, -1, -1)
747
+ encoder_hidden_states = torch.where(
748
+ (encoder_hidden_states * mask).bool(), encoder_hidden_states, empty_embeds_
749
+ )
750
+
751
+ empty_clip_embeds_ = empty_clip_embeds.expand(batch_size, -1)
752
+ cond_embeds = torch.where((cond_embeds * mask.squeeze(-1)).bool(), cond_embeds, empty_clip_embeds_)
753
+
754
+ bs = input_ids.shape[0]
755
+ vae_scale_factor = 2 ** (len(vq_model.config.block_out_channels) - 1)
756
+ resolution = args.resolution // vae_scale_factor
757
+ input_ids = input_ids.reshape(bs, resolution, resolution)
758
+
759
+ # Train Step
760
+ with accelerator.accumulate(model):
761
+ codebook_size = accelerator.unwrap_model(model).config.codebook_size
762
+ if args.resolution == 1024: # only the 1024px stages (stage 3 and stage 4) skip the 2x position-id scaling
763
+ img_ids = _prepare_latent_image_ids(input_ids.shape[0], input_ids.shape[-2], input_ids.shape[-1], input_ids.device, input_ids.dtype)
764
+ else:
765
+ img_ids = _prepare_latent_image_ids(input_ids.shape[0], 2 * input_ids.shape[-2], 2 * input_ids.shape[-1], input_ids.device, input_ids.dtype)
766
+
767
+ txt_ids = torch.zeros(encoder_hidden_states.shape[1], 3).to(device = input_ids.device, dtype = input_ids.dtype)
768
+
769
+ logits = model(
770
+ hidden_states=input_ids, # should be (batch size, channel, height, width)
771
+ encoder_hidden_states=encoder_hidden_states, # should be (batch size, sequence_len, embed_dims)
772
+ micro_conds=micro_conds,
773
+ pooled_projections=cond_embeds, # should be (batch_size, projection_dim)
774
+ img_ids=img_ids,
775
+ txt_ids=txt_ids,
776
+ timestep=mask_prob,
777
+ )[0]
778
+ logits = logits.reshape(batch_size, codebook_size, -1).permute(0, 2, 1)
779
+ logits = logits.reshape(-1, codebook_size)
780
+ loss = F.cross_entropy(
781
+ logits,
782
+ labels.view(-1),
783
+ ignore_index=-100,
784
+ reduction="mean",
785
+ )
786
+
787
+ # Gather the losses across all processes for logging (if we use distributed training).
788
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
789
+ avg_masking_rate = accelerator.gather(mask_prob.repeat(args.train_batch_size)).mean()
790
+
791
+ accelerator.backward(loss)
792
+
793
+ if args.max_grad_norm is not None and accelerator.sync_gradients:
794
+ accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
795
+
796
+ optimizer.step()
797
+ lr_scheduler.step()
798
+
799
+ optimizer.zero_grad(set_to_none=True)
800
+
801
+ # Checks if the accelerator has performed an optimization step behind the scenes
802
+ if accelerator.sync_gradients:
803
+ if args.use_ema:
804
+ ema.step(model.parameters())
805
+
806
+ if (global_step + 1) % args.logging_steps == 0:
807
+ logs = {
808
+ "step_loss": avg_loss.item(),
809
+ "lr": lr_scheduler.get_last_lr()[0],
810
+ "avg_masking_rate": avg_masking_rate.item(),
811
+ }
812
+ accelerator.log(logs, step=global_step + 1)
813
+
814
+ logger.info(
815
+ f"Step: {global_step + 1} "
816
+ f"Loss: {avg_loss.item():0.4f} "
817
+ f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}"
818
+ )
819
+
820
+ if (global_step + 1) % args.checkpointing_steps == 0:
821
+ save_checkpoint(args, accelerator, global_step + 1, logger)
822
+
823
+ if (global_step + 1) % args.validation_steps == 0 and accelerator.is_main_process:
824
+ if args.use_ema:
825
+ ema.store(model.parameters())
826
+ ema.copy_to(model.parameters())
827
+
828
+ with torch.no_grad():
829
+ logger.info("Generating images...")
830
+
831
+ model.eval()
832
+
833
+ scheduler = Scheduler.from_pretrained(
834
+ args.pretrained_model_name_or_path,
835
+ subfolder="scheduler",
836
+ revision=args.revision,
837
+ variant=args.variant,
838
+ )
839
+
840
+ pipe = UnifiedPipeline(
841
+ transformer=accelerator.unwrap_model(model),
842
+ tokenizer=tokenizer_one,
843
+ tokenizer_2=tokenizer_two,
844
+ text_encoder=text_encoder_one,
845
+ text_encoder_2=text_encoder_two,
846
+ vqvae=vq_model,
847
+ scheduler=scheduler,
848
+ )
849
+
850
+ output = pipe(
851
+ prompt=args.validation_prompts,
852
+ height=args.resolution,
853
+ width=args.resolution,
854
+ guidance_scale=9,
855
+ num_inference_steps=64
856
+ )
857
+ pil_images = output.images
858
+
859
+ wandb_images = [
860
+ wandb.Image(image, caption=args.validation_prompts[i])
861
+ for i, image in enumerate(pil_images)
862
+ ]
863
+
864
+ wandb.log({"generated_images": wandb_images}, step=global_step + 1)
865
+
866
+ result=[]
867
+ for img in pil_images:
868
+ if not isinstance(img, torch.Tensor):
869
+ img = transforms.ToTensor()(img)
870
+ result.append(img.unsqueeze(0))
871
+ result = torch.cat(result,dim=0)
872
+ result = make_grid(result, nrow=3)
873
+ save_image(result,os.path.join(args.output_dir, str(global_step)+'_text2image_1024_CFG-9.png'))
874
+
875
+ model.train()
876
+
877
+ if args.use_ema:
878
+ ema.restore(model.parameters())
879
+
880
+ global_step += 1
881
+
882
+ # Stop training if max steps is reached
883
+ if global_step >= args.max_train_steps:
884
+ break
885
+ # End for
886
+
887
+ accelerator.wait_for_everyone()
888
+
889
+ # Evaluate and save checkpoint at the end of training
890
+ save_checkpoint(args, accelerator, global_step, logger)
891
+
892
+ # Save the final trained checkpoint
893
+ if accelerator.is_main_process:
894
+ model = accelerator.unwrap_model(model)
895
+ if args.use_ema:
896
+ ema.copy_to(model.parameters())
897
+ model.save_pretrained(args.output_dir)
898
+
899
+ accelerator.end_training()
900
+
901
+
902
+ if __name__ == "__main__":
903
+ main(parse_args())
train/train_text_encoder.sh ADDED
@@ -0,0 +1,34 @@
1
+ export TOKENIZERS_PARALLELISM=false
2
+
3
+ # run this script from the repository root
4
+ PYTHON_PATH='./' accelerate launch --multi_gpu --gpu_ids '2,3' --main_process_port 25000 --num_processes 2 train/train_text_encoder.py \
5
+ --output_dir "./outputs/debug/" \
6
+ --train_batch_size 8 \
7
+ --gradient_accumulation_steps 2 \
8
+ --learning_rate 1e-4 \
9
+ --max_grad_norm 10 \
10
+ --pretrained_model_name_or_path "MeissonFlow/Meissonic" \
11
+ --text_encoder_architecture 'gemma' \
12
+ --text_encoder_name_or_path "google/gemma-2-2b-it" \
13
+ --instance_dataset 'ImageCaptionLargeDataset' \
14
+ --instance_data_dir '/data/sqy/0000/' \
15
+ --resolution 1024 \
16
+ --mixed_precision fp16 \
17
+ --lr_scheduler constant \
18
+ --use_8bit_adam \
19
+ --dataloader_num_workers 4 \
20
+ --validation_prompts \
21
+ 'a boy' \
22
+ 'A serene mountain landscape with towering snow-capped peaks, a crystal-clear blue lake reflecting the mountains, dense pine forests, and a vibrant orange sunrise illuminating the sky.' \
23
+ 'A playful golden retriever puppy with a shiny coat, bounding through a meadow filled with colorful wildflowers, under a bright, clear blue sky.' \
24
+ 'A bustling city street at night, illuminated by vibrant neon signs in various colors, with busy pedestrians, street vendors, and a light rain creating reflective puddles on the pavement.' \
25
+ 'A majestic, medieval castle perched on a rugged cliffside, overlooking a vast, calm ocean at sunset, with the sky painted in hues of pink, orange, and purple.' \
26
+ 'An elegant ballerina in a white tutu, dancing gracefully on a grand stage with ornate, gold-trimmed curtains, under a spotlight that casts a soft glow.' \
27
+ 'A cozy, rustic log cabin nestled in a snow-covered forest, with smoke rising from the stone chimney, warm lights glowing from the windows, and a path of footprints leading to the front door.'\
28
+ 'A Cute Cat' \
29
+ 'A Snow Mountain'\
30
+ --max_train_steps 30000 \
31
+ --checkpointing_steps 1000 \
32
+ --validation_steps 100 \
33
+ --logging_steps 10 \
34
+ --report_to "wandb"
train/train_unified.py ADDED
@@ -0,0 +1,1141 @@
1
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import copy
17
+ import logging
18
+ import math
19
+ import os
20
+ from pathlib import Path
21
+ import sys
22
+ sys.path.append(os.getcwd())
23
+ import json
24
+ import gc
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ from torch import nn
29
+
30
+ from accelerate import Accelerator
31
+ from accelerate.logging import get_logger
32
+ from accelerate.utils import ProjectConfiguration, set_seed
33
+ from peft import LoraConfig
34
+ from peft.utils import get_peft_model_state_dict
35
+ from torch.utils.data import DataLoader
36
+ from torchvision import transforms
37
+
38
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
39
+ from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
40
+ from transformers import (
41
+ CLIPTextModelWithProjection,
42
+ CLIPTokenizer,
43
+ T5EncoderModel,
44
+ T5Tokenizer,
45
+ )
46
+
47
+ import diffusers.optimization
48
+ from diffusers import VQModel
49
+
50
+ from src.scheduler import Scheduler
51
+ from diffusers.loaders import LoraLoaderMixin
52
+ from diffusers.utils import is_wandb_available
53
+ from src.pipeline import UnifiedPipeline
54
+ from torchvision.utils import save_image, make_grid
55
+ from train.trainer_utils import save_checkpoint
56
+ from train.dataset_utils import ImageCaptionLargeDataset
57
+ from train.dataset_utils import tokenize_prompt, encode_prompt
58
+ from src.transformer import Transformer2DModel, SymmetricTransformer2DModel
59
+ from train.trainer_utils import load_images_to_tensor
60
+
61
+ if is_wandb_available():
62
+ import wandb
63
+ # wandb.login(key="")
64
+
65
+ logger = get_logger(__name__, log_level="INFO")
66
+
67
+ import torch._dynamo
68
+ torch._dynamo.config.verbose = True
69
+
70
+ # Optionally suppress errors to fall back to eager execution
71
+ torch._dynamo.config.suppress_errors = True
72
+
73
+ def parse_args():
74
+ parser = argparse.ArgumentParser()
75
+ parser.add_argument(
76
+ "--pretrained_model_name_or_path",
77
+ type=str,
78
+ default=None,
79
+ required=True,
80
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
81
+ )
82
+ parser.add_argument(
83
+ "--pretrained_transformer_path",
84
+ type=str,
85
+ default=None,
86
+ required=True,
87
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
88
+ )
89
+ parser.add_argument(
90
+ "--text_encoder_architecture",
91
+ type=str,
92
+ default="open_clip",
93
+ required=False,
94
+ help="The architecture of the text encoder. One of ['CLIP', 'open_clip', 'flan-t5-base','Qwen2-0.5B','gemini-2b',long_t5_clip','t5_clip']",
95
+ )
96
+ parser.add_argument(
97
+ "--text_encoder_name_or_path",
98
+ type=str,
99
+ default=None,
100
+ required=True,
101
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
102
+ )
103
+ parser.add_argument(
104
+ "--remove_pooled_embeddings",
105
+ type=bool,
106
+ default=False,
107
+ required=False,
108
+ help="Whether to remove the pooled embeddings from the text encoder.",
109
+ )
110
+ parser.add_argument(
111
+ "--instance_dataset",
112
+ type=str,
113
+ default=None,
114
+ required=False,
115
+ help="The dataset to use for training. One of ['MSCOCO600K', 'PickaPicV2']",
116
+ )
117
+ parser.add_argument(
118
+ "--instance_data_dir",
119
+ type=str,
120
+ default=None,
121
+ required=False,
122
+ help="A folder containing the training data of instance images.",
123
+ )
124
+ parser.add_argument(
125
+ "--training_from_scratch",
126
+ type=bool,
127
+ default=False,
128
+ required=False
129
+ )
130
+ parser.add_argument(
131
+ "--revision",
132
+ type=str,
133
+ default=None,
134
+ required=False,
135
+ help="Revision of pretrained model identifier from huggingface.co/models.",
136
+ )
137
+ parser.add_argument(
138
+ "--variant",
139
+ type=str,
140
+ default=None,
141
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
142
+ )
143
+ parser.add_argument(
144
+ "--instance_data_image", type=str, default=None, required=False, help="A single training image"
145
+ )
146
+ parser.add_argument(
147
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
148
+ )
149
+ parser.add_argument(
150
+ "--dataloader_num_workers",
151
+ type=int,
152
+ default=0,
153
+ help=(
154
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
155
+ ),
156
+ )
157
+ parser.add_argument(
158
+ "--allow_tf32",
159
+ action="store_true",
160
+ help=(
161
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
162
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
163
+ ),
164
+ )
165
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
166
+ parser.add_argument("--ema_decay", type=float, default=0.9999)
167
+ parser.add_argument("--ema_update_after_step", type=int, default=0)
168
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
169
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
170
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
171
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
172
+ parser.add_argument(
173
+ "--output_dir",
174
+ type=str,
175
+ default="muse_training",
176
+ help="The output directory where the model predictions and checkpoints will be written.",
177
+ )
178
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
179
+ parser.add_argument(
180
+ "--logging_dir",
181
+ type=str,
182
+ default="logs",
183
+ help=(
184
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
185
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
186
+ ),
187
+ )
188
+ parser.add_argument(
189
+ "--max_train_steps",
190
+ type=int,
191
+ default=None,
192
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
193
+ )
194
+ parser.add_argument(
195
+ "--checkpointing_steps",
196
+ type=int,
197
+ default=500,
198
+ help=(
199
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
200
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
201
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
202
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
203
+ "instructions."
204
+ ),
205
+ )
206
+ parser.add_argument(
207
+ "--logging_steps",
208
+ type=int,
209
+ default=50,
210
+ )
211
+ parser.add_argument(
212
+ "--checkpoints_total_limit",
213
+ type=int,
214
+ default=None,
215
+ help=(
216
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
217
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
218
+ " for more details"
219
+ ),
220
+ )
221
+ parser.add_argument(
222
+ "--resume_from_checkpoint",
223
+ type=str,
224
+ default=None,
225
+ help=(
226
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
227
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
228
+ ),
229
+ )
230
+ parser.add_argument(
231
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
232
+ )
233
+ parser.add_argument(
234
+ "--gradient_accumulation_steps",
235
+ type=int,
236
+ default=1,
237
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
238
+ )
239
+ parser.add_argument(
240
+ "--text_loss_weight",
241
+ type=float,
242
+ default=0.2,
243
+ )
244
+ parser.add_argument(
245
+ "--learning_rate",
246
+ type=float,
247
+ default=0.0003,
248
+ help="Initial learning rate (after the potential warmup period) to use.",
249
+ )
250
+ parser.add_argument(
251
+ "--scale_lr",
252
+ action="store_true",
253
+ default=False,
254
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
255
+ )
256
+ parser.add_argument(
257
+ "--lr_scheduler",
258
+ type=str,
259
+ default="constant",
260
+ help=(
261
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
262
+ ' "constant", "constant_with_warmup"]'
263
+ ),
264
+ )
265
+ parser.add_argument(
266
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
267
+ )
268
+ parser.add_argument(
269
+ "--validation_steps",
270
+ type=int,
271
+ default=100,
272
+ help=(
273
+ "Run validation every X steps. Validation consists of running the prompt"
274
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
275
+ " and logging the images."
276
+ ),
277
+ )
278
+ parser.add_argument(
279
+ "--mixed_precision",
280
+ type=str,
281
+ default=None,
282
+ choices=["no", "fp16", "bf16"],
283
+ help=(
284
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
285
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
286
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
287
+ ),
288
+ )
289
+ parser.add_argument(
290
+ "--report_to",
291
+ type=str,
292
+ default="wandb",
293
+ help=(
294
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
295
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
296
+ ),
297
+ )
298
+ parser.add_argument("--validation_prompts", type=str, nargs="*")
299
+ parser.add_argument("--validation_images", type=str, default="./validation_images")
300
+ parser.add_argument(
301
+ "--resolution",
302
+ type=int,
303
+ default=512,
304
+ help=(
305
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
306
+ " resolution"
307
+ ),
308
+ )
309
+ parser.add_argument("--split_vae_encode", type=int, required=False, default=None)
310
+ parser.add_argument("--min_masking_rate", type=float, default=0.0)
311
+ parser.add_argument("--cond_dropout_prob", type=float, default=0.0)
312
+ parser.add_argument("--max_grad_norm", default=50.0, type=float, help="Max gradient norm.", required=False)
313
+ parser.add_argument("--use_lora", action="store_true", help="Fine tune the model using LoRa")
314
+ parser.add_argument("--text_encoder_use_lora", action="store_true", help="Fine tune the model using LoRa")
315
+ parser.add_argument("--lora_r", default=16, type=int)
316
+ parser.add_argument("--lora_alpha", default=32, type=int)
317
+ parser.add_argument("--lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
318
+ parser.add_argument("--text_encoder_lora_r", default=16, type=int)
319
+ parser.add_argument("--text_encoder_lora_alpha", default=32, type=int)
320
+ parser.add_argument("--text_encoder_lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
321
+ parser.add_argument("--train_text_encoder", action="store_true")
322
+ parser.add_argument("--image_to_text_only", action="store_true")
323
+ parser.add_argument("--image_key", type=str, required=False)
324
+ parser.add_argument("--prompt_key", type=str, required=False)
325
+ parser.add_argument(
326
+ "--gradient_checkpointing",
327
+ action="store_true",
328
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
329
+ )
330
+ parser.add_argument("--prompt_prefix", type=str, required=False, default=None)
331
+
332
+ args = parser.parse_args()
333
+
334
+ if args.report_to == "wandb":
335
+ if not is_wandb_available():
336
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
337
+
338
+ num_datasources = sum(
339
+ [x is not None for x in [args.instance_data_dir, args.instance_data_image]]
340
+ )
341
+
342
+ if num_datasources != 1:
343
+ raise ValueError(
344
+ "provide one and only one of `--instance_data_dir`, `--instance_data_image`, or `--instance_data_dataset`"
345
+ )
346
+
347
+ if args.instance_data_dir is not None:
348
+ if not os.path.exists(args.instance_data_dir):
349
+ raise ValueError(f"Does not exist: `--args.instance_data_dir` {args.instance_data_dir}")
350
+
351
+ if args.instance_data_image is not None:
352
+ if not os.path.exists(args.instance_data_image):
353
+ raise ValueError(f"Does not exist: `--args.instance_data_image` {args.instance_data_image}")
354
+
355
+ return args
356
+
357
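+ # Builds a (height//2 * width//2, 3) grid of positional ids for the transformer's img_ids input: column 0 stays zero, columns 1 and 2 hold the row and column indices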
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
358
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
359
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
360
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
361
+
362
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
363
+
364
+ latent_image_ids = latent_image_ids.reshape(
365
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
366
+ )
367
+
368
+ return latent_image_ids.to(device=device, dtype=dtype)
369
+
370
+ def main(args):
371
+ if args.allow_tf32:
372
+ torch.backends.cuda.matmul.allow_tf32 = True
373
+
374
+ logging_dir = Path(args.output_dir, args.logging_dir)
375
+
376
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
377
+
378
+ accelerator = Accelerator(
379
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
380
+ mixed_precision=args.mixed_precision,
381
+ log_with=args.report_to,
382
+ project_config=accelerator_project_config,
383
+ )
384
+
385
+ if accelerator.is_main_process:
386
+ os.makedirs(args.output_dir, exist_ok=True)
387
+
388
+ # Make one log on every process with the configuration for debugging.
389
+ logging.basicConfig(
390
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
391
+ datefmt="%m/%d/%Y %H:%M:%S",
392
+ level=logging.INFO,
393
+ )
394
+ logger.info(accelerator.state, main_process_only=False)
395
+
396
+ if accelerator.is_main_process:
397
+ accelerator.init_trackers("meissonic", config=vars(copy.deepcopy(args)))
398
+
399
+ if args.seed is not None:
400
+ set_seed(args.seed)
401
+
402
+ if args.text_encoder_architecture == "open_clip":
403
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
404
+ args.pretrained_model_name_or_path, subfolder="text_encoder", variant=args.variant
405
+ )
406
+ tokenizer = CLIPTokenizer.from_pretrained(
407
+ args.pretrained_model_name_or_path, subfolder="tokenizer", variant=args.variant
408
+ )
409
+ tokenizer_2 = None
410
+ text_encoder_2 = None
411
+
412
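+ # Reserve "<extra_id_0>" as the text [MASK] token, resize the embedding table, and persist the mask token embedding to mask_token_embedding.pth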
+ extra_id_0_token = "<extra_id_0>"
413
+ num_new_tokens = tokenizer.add_tokens(extra_id_0_token)
414
+ mask_id_1 = tokenizer.convert_tokens_to_ids(extra_id_0_token)
415
+ if num_new_tokens > 0:
416
+ text_encoder.resize_token_embeddings(len(tokenizer))
417
+ mask_token_embedding = text_encoder.get_input_embeddings().weight[mask_id_1]
418
+ mask_token_embedding = mask_token_embedding.clone().detach().cpu().float()
419
+ if accelerator.is_main_process:
420
+ print("Saving masked token embedding...")
421
+ torch.save(mask_token_embedding, os.path.join(args.output_dir, "mask_token_embedding.pth"))
422
+
423
+ text_encoder.requires_grad_(False)
424
+ elif args.text_encoder_architecture == "t5_clip":
425
+ tokenizer = CLIPTokenizer.from_pretrained(
426
+ args.pretrained_model_name_or_path, subfolder="tokenizer", variant=args.variant
427
+ )
428
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
429
+ args.pretrained_model_name_or_path, subfolder="text_encoder", variant=args.variant
430
+ )
431
+
432
+ tokenizer_2 = T5Tokenizer.from_pretrained(
433
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", variant=args.variant,
434
+ )
435
+ text_encoder_2 = T5EncoderModel.from_pretrained(
436
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", variant=args.variant,
437
+ )
438
+
439
+ text_encoder.requires_grad_(False)
440
+ text_encoder_2.requires_grad_(False)
441
+ elif args.text_encoder_architecture == "gemma":
442
+ tokenizer = CLIPTokenizer.from_pretrained(
443
+ args.pretrained_model_name_or_path, subfolder="tokenizer", variant=args.variant
444
+ )
445
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
446
+ args.pretrained_model_name_or_path, subfolder="text_encoder", variant=args.variant
447
+ )
448
+
449
+ tokenizer_2 = GemmaTokenizerFast.from_pretrained(
450
+ args.text_encoder_name_or_path, variant=args.variant,
451
+ )
452
+ text_encoder_2 = Gemma2Model.from_pretrained(
453
+ args.text_encoder_name_or_path, variant=args.variant,
454
+ )
455
+
456
+ extra_id_0_token = "<extra_id_0>"
457
+
458
+ tokenizer.add_tokens(extra_id_0_token)
459
+ tokenizer_2.add_tokens(extra_id_0_token)
460
+ mask_id_1 = tokenizer.convert_tokens_to_ids(extra_id_0_token)
461
+ mask_id_2 = tokenizer_2.convert_tokens_to_ids(extra_id_0_token)
462
+
463
+ text_encoder.resize_token_embeddings(len(tokenizer))
464
+ text_encoder_2.resize_token_embeddings(len(tokenizer_2))
465
+
466
+ text_encoder.requires_grad_(False)
467
+ text_encoder_2.requires_grad_(False)
468
+ else:
469
+ raise ValueError(f"Unknown text encoder architecture: {args.text_encoder_architecture}")
470
+
471
+ vq_model = VQModel.from_pretrained(
472
+ args.pretrained_model_name_or_path, subfolder="vqvae", revision=args.revision, variant=args.variant
473
+ )
474
+ vq_model.requires_grad_(False)
475
+
476
+ model = SymmetricTransformer2DModel.from_pretrained(
477
+ args.pretrained_transformer_path,
478
+ subfolder="transformer",
479
+ low_cpu_mem_usage=False,
480
+ device_map=None
481
+ )
482
+
483
+ if model.config.tokenizer_vocab_size is None:
484
+ if args.text_encoder_architecture == "open_clip":
485
+ model.register_to_config(tokenizer_vocab_size=len(tokenizer))
486
+ elif args.text_encoder_architecture in ("t5_clip", "gemma"):
487
+ model.register_to_config(tokenizer_vocab_size=len(tokenizer_2))
488
+ if model.adapter is None:
489
+ raise ValueError(f"The MMDiT must has adapter if you want to use t5_clip mode!!!")
490
+ else:
491
+ raise ValueError(f"Unknown text encoder architecture!")
492
+
493
+ if accelerator.is_main_process:
494
+ print(f"model's tokenizer vocab size is {model.config.tokenizer_vocab_size}")
495
+
496
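+ # Attach a lightweight text head (LayerNorm + Linear over the tokenizer vocabulary) that maps transformer hidden states to text logits for the image-to-text objective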
+ model.text_decoder = nn.Sequential(
497
+ nn.LayerNorm(model.inner_dim, elementwise_affine=False, eps=1e-6),
498
+ nn.Linear(model.inner_dim, model.config.tokenizer_vocab_size, bias=False)
499
+ )
500
+
501
+ model = torch.compile(model)
502
+
503
+ if args.use_lora:
504
+ lora_config = LoraConfig(
505
+ r=args.lora_r,
506
+ lora_alpha=args.lora_alpha,
507
+ target_modules=args.lora_target_modules,
508
+ )
509
+ model.add_adapter(lora_config)
510
+
511
+ model.train()
512
+
513
+ if args.image_to_text_only:
514
+ frozen_keys = ["project_from_hidden", "up_block", "mlm_layer"]
515
+ for n, p in model.named_parameters():
516
+ if any([frozen_key in n for frozen_key in frozen_keys]):
517
+ p.requires_grad_(False)
518
+ else:
519
+ p.requires_grad_(True)
520
+ else:
521
+ model.requires_grad_(True)
522
+
523
+ if args.gradient_checkpointing:
524
+ model.enable_gradient_checkpointing()
525
+
526
+ def save_model_hook(models, weights, output_dir):
527
+ if accelerator.is_main_process:
528
+ transformer_lora_layers_to_save = None
529
+ text_encoder_lora_layers_to_save = None
530
+
531
+ for model_ in models:
532
+ if isinstance(model_, type(accelerator.unwrap_model(model))):
533
+ if args.use_lora:
534
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model_)
535
+ else:
536
+ model_.save_pretrained(os.path.join(output_dir, "transformer"))
537
+ elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
538
+ if args.text_encoder_use_lora:
539
+ text_encoder_lora_layers_to_save = get_peft_model_state_dict(model_)
540
+ else:
541
+ model_.save_pretrained(os.path.join(output_dir, "text_encoder"))
542
+ else:
543
+ raise ValueError(f"unexpected save model: {model_.__class__}")
544
+
545
+ # make sure to pop weight so that corresponding model is not saved again
546
+ weights.pop()
547
+
548
+ if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
549
+ LoraLoaderMixin.save_lora_weights(
550
+ output_dir,
551
+ unet_lora_layers=transformer_lora_layers_to_save,
552
+ text_encoder_lora_layers=text_encoder_lora_layers_to_save,
553
+ )
554
+
555
+
556
+ def load_model_hook(models, input_dir):
557
+ transformer = None
558
+ text_encoder_ = None
559
+
560
+ # torch.compile wraps the model, so checkpoint state_dict keys need the '_orig_mod.' prefix before they can be loaded into the compiled module
561
+ def adap_compile(ori_dict):  # add the '_orig_mod.' prefix to each key
562
+ new_dict = {}
563
+ for k,v in ori_dict.items():
564
+ new_dict['_orig_mod.' + k] = v
565
+ return new_dict
566
+
567
+ while len(models) > 0:
568
+ model_ = models.pop()
569
+
570
+ if isinstance(model_, type(accelerator.unwrap_model(model))):
571
+ if args.use_lora:
572
+ transformer = model_
573
+ else:
574
+ load_model = SymmetricTransformer2DModel.from_pretrained(os.path.join(input_dir, "transformer"), low_cpu_mem_usage=False, device_map=None)
575
+ model_.load_state_dict(adap_compile(load_model.state_dict()))
576
+ del load_model
577
+ elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
578
+ if args.text_encoder_use_lora:
579
+ text_encoder_ = model_
580
+ else:
581
+ try:
582
+ load_model = CLIPTextModelWithProjection.from_pretrained(os.path.join(input_dir, "text_encoder"))
583
+ model_.load_state_dict(load_model.state_dict())
584
+ # print('finished loading text encoder!')
585
+ except Exception:
586
+ print('No text encoder found in the checkpoint folder; downloading CLIP-ViT-H-14 from the Hugging Face Hub instead.')
587
+ load_model = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
588
+ model_.load_state_dict(load_model.state_dict())
589
+ del load_model
590
+ else:
591
+ raise ValueError(f"unexpected save model: {model.__class__}")
592
+
593
+ if transformer is not None or text_encoder_ is not None:
594
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
595
+ LoraLoaderMixin.load_lora_into_text_encoder(
596
+ lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
597
+ )
598
+ LoraLoaderMixin.load_lora_into_transformer(
599
+ lora_state_dict, network_alphas=network_alphas, transformer=transformer
600
+ )
601
+
602
+ accelerator.register_load_state_pre_hook(load_model_hook)
603
+ accelerator.register_save_state_pre_hook(save_model_hook)
604
+
605
+ if args.scale_lr:
606
+ args.learning_rate = (
607
+ args.learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
608
+ )
609
+
610
+ if args.use_8bit_adam:
611
+ try:
612
+ import bitsandbytes as bnb
613
+ except ImportError:
614
+ raise ImportError(
615
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
616
+ )
617
+
618
+ optimizer_cls = bnb.optim.AdamW8bit
619
+ else:
620
+ optimizer_cls = torch.optim.AdamW
621
+
622
+ optimizer_grouped_parameters = [
623
+ {
624
+ "params": [p for p in model.parameters() if p.requires_grad],
625
+ "weight_decay": args.adam_weight_decay,
626
+ }
627
+ ]
628
+ optimizer = optimizer_cls(
629
+ optimizer_grouped_parameters,
630
+ lr=args.learning_rate,
631
+ betas=(args.adam_beta1, args.adam_beta2),
632
+ weight_decay=args.adam_weight_decay,
633
+ eps=args.adam_epsilon,
634
+ )
635
+
636
+ logger.info("Creating dataloaders and lr_scheduler")
637
+
638
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
639
+
640
+ if args.text_encoder_architecture in ("t5_clip", "gemma"):
641
+ tokenizer_for_dataset = [tokenizer, tokenizer_2]
642
+ else:
643
+ tokenizer_for_dataset = tokenizer
644
+
645
+ if args.instance_dataset == "ImageCaptionLargeDataset":
646
+ dataset = ImageCaptionLargeDataset(
647
+ root_dir=args.instance_data_dir,
648
+ tokenizer=tokenizer_for_dataset,
649
+ size=args.resolution,
650
+ text_encoder_architecture=args.text_encoder_architecture
651
+ )
652
+ elif args.instance_dataset == "DATA_TYPE":
653
+ raise NotImplementedError("DATA_TYPE is not yet supported")
654
+ else:
655
+ raise ValueError(f"Unknown instance dataset: {args.instance_dataset}")
656
+
657
+ def collate_fn(samples):
658
+ images = [sample["image"] for sample in samples]
659
+ micro_conds = [sample["micro_conds"] for sample in samples]
660
+
661
+ images = torch.stack(images, dim=0)
662
+ micro_conds = torch.stack(micro_conds, dim=0)
663
+
664
+ if isinstance(samples[0]["prompt_input_ids"], list):
665
+ input_ids = [sample["prompt_input_ids"][0] for sample in samples]
666
+ input_ids_2 = [sample["prompt_input_ids"][1] for sample in samples]
667
+
668
+ input_ids = torch.cat(input_ids, dim=0)
669
+ input_ids_2 = torch.cat(input_ids_2, dim=0)
670
+ prompt_input_ids = [input_ids, input_ids_2]
671
+ else:
672
+ input_ids = [sample["prompt_input_ids"] for sample in samples]
673
+
674
+ input_ids = torch.cat(input_ids, dim=0)
675
+ prompt_input_ids = input_ids
676
+
677
+ ret = dict(
678
+ images=images,
679
+ micro_conds=micro_conds,
680
+ prompt_input_ids=prompt_input_ids,
681
+ )
682
+
683
+ return ret
684
+
685
+ train_dataloader = DataLoader(
686
+ dataset,
687
+ batch_size=args.train_batch_size,
688
+ shuffle=True,
689
+ num_workers=args.dataloader_num_workers,
690
+ collate_fn=collate_fn,
691
+ pin_memory=True,
692
+ )
693
+ train_dataloader.num_batches = len(train_dataloader)
694
+
695
+ lr_scheduler = diffusers.optimization.get_scheduler(
696
+ args.lr_scheduler,
697
+ optimizer=optimizer,
698
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
699
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
700
+ )
701
+
702
+ logger.info("Preparing model, optimizer and dataloaders")
703
+
704
+ model, optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
705
+ model, optimizer, lr_scheduler, train_dataloader
706
+ )
707
+
708
+ train_dataloader.num_batches = len(train_dataloader)
709
+
710
+ weight_dtype = torch.float32
711
+ if accelerator.mixed_precision == "fp16":
712
+ weight_dtype = torch.float16
713
+ elif accelerator.mixed_precision == "bf16":
714
+ weight_dtype = torch.bfloat16
715
+
716
+ if args.text_encoder_architecture in ("t5_clip", "gemma"):
717
+ text_encoder.to(device=accelerator.device, dtype=weight_dtype)
718
+ text_encoder_2.to(device=accelerator.device, dtype=weight_dtype)
719
+ else:
720
+ text_encoder.to(device=accelerator.device, dtype=weight_dtype)
721
+
722
+ vq_model.to(device=accelerator.device)
723
+
724
+ with torch.no_grad():
725
+ if args.text_encoder_architecture in ("t5_clip", "gemma"):
726
+ _input_ids_tmp_ = tokenize_prompt([tokenizer, tokenizer_2], "", args.text_encoder_architecture)
727
+ _input_ids_tmp_[0] = _input_ids_tmp_[0].to(accelerator.device)
728
+ _input_ids_tmp_[1] = _input_ids_tmp_[1].to(accelerator.device)
729
+ empty_embeds, empty_clip_embeds = encode_prompt(
730
+ [text_encoder, text_encoder_2],
731
+ _input_ids_tmp_,
732
+ args.text_encoder_architecture
733
+ )
734
+ else:
735
+ _input_ids_tmp_ = tokenize_prompt(tokenizer, "", args.text_encoder_architecture)
736
+ _input_ids_tmp_ = _input_ids_tmp_.to(accelerator.device)
737
+ empty_embeds, empty_clip_embeds = encode_prompt(
738
+ text_encoder,
739
+ _input_ids_tmp_,
740
+ args.text_encoder_architecture
741
+ )
742
+
743
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
744
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
745
+ # Afterwards we recalculate our number of training epochs.
746
+ # Note: We are not doing epoch based training here, but just using this for book keeping and being able to
747
+ # reuse the same training loop with other datasets/loaders.
748
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
749
+
750
+ # Train!
751
+ logger.info("***** Running training *****")
752
+ logger.info(f" Num training steps = {args.max_train_steps}")
753
+ logger.info(f" Instantaneous batch size per device = { args.train_batch_size}")
754
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
755
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
756
+
757
+ resume_from_checkpoint = args.resume_from_checkpoint
758
+ if resume_from_checkpoint:
759
+ if resume_from_checkpoint == "latest":
760
+ # Get the most recent checkpoint
761
+ dirs = os.listdir(args.output_dir)
762
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
763
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
764
+ if len(dirs) > 0:
765
+ resume_from_checkpoint = os.path.join(args.output_dir, dirs[-1])
766
+ else:
767
+ resume_from_checkpoint = None
768
+
769
+ if resume_from_checkpoint is None:
770
+ accelerator.print(
771
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
772
+ )
773
+ else:
774
+ accelerator.print(f"Resuming from checkpoint {resume_from_checkpoint}")
775
+
776
+ if resume_from_checkpoint is None:
777
+ global_step = 0
778
+ first_epoch = 0
779
+ else:
780
+ accelerator.load_state(resume_from_checkpoint)
781
+ global_step = int(os.path.basename(resume_from_checkpoint).split("-")[1])
782
+ first_epoch = global_step // num_update_steps_per_epoch
783
+
784
+ # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to
785
+ # reuse the same training loop with other datasets/loaders.
786
+ for epoch in range(first_epoch, num_train_epochs):
787
+ for batch in train_dataloader:
788
+ torch.cuda.empty_cache()
789
+ with torch.no_grad():
790
+ micro_conds = batch["micro_conds"].to(accelerator.device, non_blocking=True)
791
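+ # Each batch is split into two halves: the first half trains the masked-image (text-to-image) objective with clean text, the second half trains the masked-text (image-to-text) objective with clean image tokens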
+ image_micro_conds, text_micro_conds = micro_conds.chunk(2, dim=0)
792
+
793
+ pixel_values = batch["images"].to(accelerator.device, non_blocking=True)
794
+ batch_size = pixel_values.shape[0]
795
+
796
+ # ====================== tokenize images ======================
797
+ split_batch_size = args.split_vae_encode if args.split_vae_encode is not None else batch_size
798
+ num_splits = math.ceil(batch_size / split_batch_size)
799
+ image_tokens = []
800
+ for i in range(num_splits):
801
+ start_idx = i * split_batch_size
802
+ end_idx = min((i + 1) * split_batch_size, batch_size)
803
+ image_tokens.append(
804
+ vq_model.quantize(
805
+ vq_model.encode(pixel_values[start_idx:end_idx]).latents
806
+ )[2][2].reshape(end_idx - start_idx, -1)
807
+ )
808
+ image_tokens = torch.cat(image_tokens, dim=0)
809
+ # ====================== tokenize images ======================
810
+
811
+
812
+ # ====================== tokenize text prompts ======================
813
+ if args.text_encoder_architecture in ("t5_clip", "gemma"):
814
+ prompt_input_ids_clip = batch["prompt_input_ids"][0].to(accelerator.device, non_blocking=True)
815
+ prompt_input_ids_t5 = batch["prompt_input_ids"][1].to(accelerator.device, non_blocking=True)
816
+ prompt_input_ids_clip_1, prompt_input_ids_clip_2 = prompt_input_ids_clip.chunk(2, dim=0)
817
+ prompt_input_ids_t5_1, prompt_input_ids_t5_2 = prompt_input_ids_t5.chunk(2, dim=0)
818
+ encoder_hidden_states, cond_embeds = encode_prompt(
819
+ [text_encoder, text_encoder_2],
820
+ [prompt_input_ids_clip_1, prompt_input_ids_t5_1],
821
+ args.text_encoder_architecture
822
+ )
823
+ else:
824
+ prompt_input_ids = batch["prompt_input_ids"].to(accelerator.device, non_blocking=True)
825
+ prompt_input_ids_clip_1, prompt_input_ids_clip_2 = prompt_input_ids.chunk(2, dim=0)
826
+ encoder_hidden_states, cond_embeds = encode_prompt(
827
+ text_encoder,
828
+ prompt_input_ids_clip_1,
829
+ args.text_encoder_architecture
830
+ )
831
+ encoder_hidden_states = encoder_hidden_states.to(accelerator.device, dtype=accelerator.unwrap_model(model).dtype)
832
+ cond_embeds = cond_embeds.to(accelerator.device, dtype=accelerator.unwrap_model(model).dtype)
833
+ # ====================== tokenize text prompts ======================
834
+
835
+
836
+ # ====================== image perturbation ======================
837
+ image_tokens_1, image_tokens_2 = image_tokens.chunk(2, dim=0) # (b // 2, seq_len)
838
+ half_batch_size, seq_len = image_tokens_1.shape
839
+ sigma = torch.rand(half_batch_size, device=image_tokens_1.device)
840
+ image_mask_prob = torch.cos(sigma * math.pi * 0.5)
841
+ image_mask_prob = image_mask_prob.clip(args.min_masking_rate)
842
+
843
+ num_token_masked = (seq_len * image_mask_prob).round().clamp(min=1)
844
+ batch_randperm = torch.rand(half_batch_size, seq_len, device=image_tokens_1.device).argsort(dim=-1)
845
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
846
+
847
+ mask_id = accelerator.unwrap_model(model).config.vocab_size - 1
848
+ masked_image_ids = torch.where(mask, mask_id, image_tokens_1)
849
+ image_labels = torch.where(mask, image_tokens_1, -100)
850
+ # ====================== image perturbation ======================
851
+
852
+
853
+ # ====================== text perturbation ======================
854
+ if args.text_encoder_architecture in ("t5_clip", "gemma"):
855
+ half_batch_size, seq_len = prompt_input_ids_t5_2.shape
856
+ sigma = torch.rand(half_batch_size, device=image_tokens_1.device)
857
+ text_mask_prob = torch.cos(sigma * math.pi * 0.5)
858
+ text_mask_prob = text_mask_prob.clip(args.min_masking_rate)
859
+ text_timestep = text_mask_prob.clone().clamp(min=1e-3)
860
+
861
+ num_token_masked = (seq_len * text_mask_prob).round().clamp(min=1)
862
+ batch_randperm = torch.rand(half_batch_size, seq_len, device=image_tokens_1.device).argsort(dim=-1)
863
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
864
+
865
+ masked_prompt_input_ids_t5 = torch.where(mask, mask_id_2, prompt_input_ids_t5_2)
866
+ text_labels = torch.where(mask, prompt_input_ids_t5_2, -100)
867
+
868
+ # Decode the masked T5/Gemma ids back to text and re-tokenize with the CLIP tokenizer so both encoders see the same masked prompt
869
+ batch_prompt_2 = []
870
+ for i in range(masked_prompt_input_ids_t5.size(0)):
871
+ masked_prompt_input_id = masked_prompt_input_ids_t5[i].tolist()
872
+ prompt_2 = tokenizer_2.decode(masked_prompt_input_id, skip_special_tokens=True)
873
+ batch_prompt_2.append(prompt_2)
874
+
875
+ masked_prompt_input_ids_clip = tokenizer(
876
+ batch_prompt_2,
877
+ truncation=True,
878
+ padding="max_length",
879
+ max_length=77,
880
+ return_tensors="pt"
881
+ ).input_ids
882
+ masked_prompt_input_ids_clip = masked_prompt_input_ids_clip.to(accelerator.device)
883
+ else:
884
+ half_batch_size, seq_len = prompt_input_ids_clip_2.shape
885
+ sigma = torch.rand(half_batch_size, device=image_tokens_1.device)
886
+ text_mask_prob = torch.cos(sigma * math.pi * 0.5)
887
+ text_mask_prob = text_mask_prob.clip(args.min_masking_rate)
888
+ text_timestep = text_mask_prob.clone().clamp(min=1e-3)
889
+
890
+ num_token_masked = (seq_len * text_mask_prob).round().clamp(min=1)
891
+ batch_randperm = torch.rand(half_batch_size, seq_len, device=image_tokens_1.device).argsort(dim=-1)
892
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
893
+
894
+ masked_prompt_input_ids_clip = torch.where(mask, mask_id_1, prompt_input_ids_clip_2)
895
+ text_labels = torch.where(mask, prompt_input_ids_clip_2, -100)
896
+ # ====================== text perturbation ======================
897
+
898
+
899
+ # ====================== encode masked text prompts ======================
900
+ if args.text_encoder_architecture in ("t5_clip", "gemma"):
901
+ masked_encoder_hidden_states, masked_cond_embeds = encode_prompt(
902
+ [text_encoder, text_encoder_2],
903
+ [masked_prompt_input_ids_clip, masked_prompt_input_ids_t5],
904
+ args.text_encoder_architecture
905
+ )
906
+ else:
907
+ masked_encoder_hidden_states, masked_cond_embeds = encode_prompt(
908
+ text_encoder,
909
+ masked_prompt_input_ids_clip,
910
+ args.text_encoder_architecture
911
+ )
912
+ masked_encoder_hidden_states = masked_encoder_hidden_states.to(accelerator.device, dtype=accelerator.unwrap_model(model).dtype)
913
+ masked_cond_embeds = masked_cond_embeds.to(accelerator.device, dtype=accelerator.unwrap_model(model).dtype)
914
+ # ====================== encode masked text prompts ======================
915
+
916
+
917
+ # Classifier-free guidance training: per sample, randomly replace the text conditioning with the empty-prompt embeddings (controlled by cond_dropout_prob)
918
+ if args.cond_dropout_prob > 0.0:
919
+ assert encoder_hidden_states is not None
920
+
921
+ batch_size = encoder_hidden_states.shape[0]
922
+
923
+ mask = (
924
+ torch.zeros((batch_size, 1, 1), device=encoder_hidden_states.device).float().uniform_(0, 1)
925
+ < args.cond_dropout_prob
926
+ )
927
+
928
+ empty_embeds_ = empty_embeds.expand(batch_size, -1, -1)
929
+ encoder_hidden_states = torch.where(
930
+ (encoder_hidden_states * mask).bool(), encoder_hidden_states, empty_embeds_
931
+ )
932
+
933
+ empty_clip_embeds_ = empty_clip_embeds.expand(batch_size, -1)
934
+ cond_embeds = torch.where((cond_embeds * mask.squeeze(-1)).bool(), cond_embeds, empty_clip_embeds_)
935
+
936
+ vae_scale_factor = 2 ** (len(vq_model.config.block_out_channels) - 1)
937
+ resolution = args.resolution // vae_scale_factor
938
+ masked_image_ids = masked_image_ids.reshape(half_batch_size, resolution, resolution)
939
+ image_ids = image_tokens_2.reshape(half_batch_size, resolution, resolution)
940
+
941
+
942
+ # Train Step
943
+ with accelerator.accumulate(model):
944
+ codebook_size = accelerator.unwrap_model(model).config.codebook_size
945
+ if args.resolution == 1024: # only stage 3 and stage 4 do not apply 2*
946
+ img_ids = _prepare_latent_image_ids(
947
+ masked_image_ids.shape[0],
948
+ masked_image_ids.shape[-2],
949
+ masked_image_ids.shape[-1],
950
+ masked_image_ids.device,
951
+ masked_image_ids.dtype
952
+ )
953
+ else:
954
+ img_ids = _prepare_latent_image_ids(
955
+ masked_image_ids.shape[0],
956
+ masked_image_ids.shape[-2],
957
+ masked_image_ids.shape[-1],
958
+ masked_image_ids.device,
959
+ masked_image_ids.dtype
960
+ )
961
+
962
+ txt_ids = torch.zeros(encoder_hidden_states.shape[1], 3).to(device=masked_image_ids.device, dtype=masked_image_ids.dtype)
963
+
964
+ image_logits = (
965
+ model(
966
+ hidden_states=masked_image_ids, # should be (batch size, channel, height, width)
967
+ encoder_hidden_states=encoder_hidden_states, # should be (batch size, sequence_len, embed_dims)
968
+ micro_conds=image_micro_conds,
969
+ pooled_projections=cond_embeds, # should be (batch_size, projection_dim)
970
+ img_ids=img_ids,
971
+ txt_ids=txt_ids,
972
+ timestep=image_mask_prob,
973
+ )[0]
974
+ .reshape(half_batch_size, codebook_size, -1)
975
+ .permute(0, 2, 1)
976
+ .reshape(-1, codebook_size)
977
+ )
978
+
979
+ image_loss = F.cross_entropy(
980
+ image_logits,
981
+ image_labels.view(-1),
982
+ ignore_index=-100,
983
+ reduction="mean",
984
+ )
985
+
986
+ text_logits = model(
987
+ hidden_states=image_ids, # should be (batch size, channel, height, width)
988
+ encoder_hidden_states=masked_encoder_hidden_states, # should be (batch size, sequence_len, embed_dims)
989
+ micro_conds=text_micro_conds,
990
+ pooled_projections=masked_cond_embeds, # should be (batch_size, projection_dim)
991
+ img_ids=img_ids,
992
+ txt_ids=txt_ids,
993
+ timestep=text_timestep,
994
+ )[1]
995
+ text_logits = text_logits.reshape(-1, accelerator.unwrap_model(model).config.tokenizer_vocab_size)
996
+
997
+ text_loss = F.cross_entropy(
998
+ text_logits,
999
+ text_labels.view(-1),
1000
+ ignore_index=-100,
1001
+ reduction="none",
1002
+ )
1003
+ text_loss = text_loss.reshape(half_batch_size, -1).mean(-1)
1004
+ text_loss = text_loss / text_timestep
1005
+ text_loss = text_loss.mean()
1006
+
1007
+ loss = image_loss + args.text_loss_weight * text_loss
1008
+
1009
+ # Gather the losses across all processes for logging (if we use distributed training).
1010
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
1011
+ avg_masking_rate = accelerator.gather(text_mask_prob.repeat(args.train_batch_size)).mean()
1012
+
1013
+ accelerator.backward(loss)
1014
+
1015
+ # Temporarily add this to identify unused parameters
1016
+ # for name, param in accelerator.unwrap_model(model).named_parameters():
1017
+ # if param.grad is None:
1018
+ # print(f"Unused parameter: {name}")
1019
+
1020
+ if args.max_grad_norm is not None and accelerator.sync_gradients:
1021
+ accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
1022
+
1023
+ optimizer.step()
1024
+ lr_scheduler.step()
1025
+
1026
+ optimizer.zero_grad(set_to_none=True)
1027
+
1028
+ # Checks if the accelerator has performed an optimization step behind the scenes
1029
+ if accelerator.sync_gradients:
1030
+ if (global_step + 1) % args.logging_steps == 0:
1031
+ logs = {
1032
+ "step_loss": avg_loss.item(),
1033
+ "lr": lr_scheduler.get_last_lr()[0],
1034
+ "avg_masking_rate": avg_masking_rate.item(),
1035
+ }
1036
+ accelerator.log(logs, step=global_step + 1)
1037
+
1038
+ logger.info(
1039
+ f"Step: {global_step + 1} "
1040
+ f"Loss: {avg_loss.item():0.4f} "
1041
+ f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}"
1042
+ )
1043
+
1044
+ if (global_step + 1) % args.checkpointing_steps == 0:
1045
+ save_checkpoint(args, accelerator, global_step + 1, logger)
1046
+
1047
+ if (global_step + 1) % args.validation_steps == 0 and accelerator.is_main_process:
1048
+
1049
+ with torch.no_grad():
1050
+ logger.info("Generating images...")
1051
+
1052
+ model.eval()
1053
+
1054
+ scheduler = Scheduler.from_pretrained(
1055
+ args.pretrained_model_name_or_path,
1056
+ subfolder="scheduler",
1057
+ revision=args.revision,
1058
+ variant=args.variant,
1059
+ )
1060
+
1061
+ pipe = UnifiedPipeline(
1062
+ transformer=accelerator.unwrap_model(model),
1063
+ tokenizer=tokenizer,
1064
+ text_encoder=text_encoder,
1065
+ vqvae=vq_model,
1066
+ scheduler=scheduler,
1067
+ tokenizer_2=tokenizer_2,
1068
+ text_encoder_2=text_encoder_2,
1069
+ )
1070
+
1071
+ if not args.image_to_text_only:
1072
+ output = pipe(
1073
+ prompt=args.validation_prompts,
1074
+ height=args.resolution,
1075
+ width=args.resolution,
1076
+ guidance_scale=9,
1077
+ num_inference_steps=64,
1078
+ )
1079
+ pil_images = output.images
1080
+
1081
+ result=[]
1082
+ for img in pil_images:
1083
+ if not isinstance(img, torch.Tensor):
1084
+ img = transforms.ToTensor()(img)
1085
+ result.append(img.unsqueeze(0))
1086
+ result = torch.cat(result,dim=0)
1087
+ result = make_grid(result, nrow=3)
1088
+ save_image(result,os.path.join(args.output_dir, str(global_step)+'_text2image_1024_CFG-9.png'))
1089
+
1090
+ output_data = {
1091
+ "step": global_step,
1092
+ "prompts": args.validation_prompts,
1093
+ "images": [f"{global_step}_text2image_1024_CFG-9_{i}.png" for i in range(len(pil_images))]
1094
+ }
1095
+
1096
+ with open(os.path.join(args.output_dir, f"text2image_{global_step}.json"), "w") as f:
1097
+ json.dump(output_data, f, indent=2)
1098
+
1099
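+ # Image-to-text validation: run the validation images through the pipeline and dump the generated prompts to a JSON file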
+ image = load_images_to_tensor(args.validation_images, target_size=(args.resolution, args.resolution))
1100
+ output = pipe(
1101
+ prompt=args.validation_prompts,
1102
+ height=args.resolution,
1103
+ width=args.resolution,
1104
+ guidance_scale=9,
1105
+ image=image,
1106
+ num_inference_steps=64
1107
+ )
1108
+ prompts = output.prompts
1109
+
1110
+ output_data = {
1111
+ "step": global_step,
1112
+ "prompts": prompts,
1113
+ }
1114
+
1115
+ with open(os.path.join(args.output_dir, f"image2text_{global_step}.json"), "w") as f:
1116
+ json.dump(output_data, f, indent=2)
1117
+
1118
+ model.train()
1119
+
1120
+ global_step += 1
1121
+
1122
+ # Stop training if max steps is reached
1123
+ if global_step >= args.max_train_steps:
1124
+ break
1125
+ # End for
1126
+
1127
+ accelerator.wait_for_everyone()
1128
+
1129
+ # Evaluate and save checkpoint at the end of training
1130
+ save_checkpoint(args, accelerator, global_step, logger)
1131
+
1132
+ # Save the final trained checkpoint
1133
+ if accelerator.is_main_process:
1134
+ model = accelerator.unwrap_model(model)
1135
+ model.save_pretrained(args.output_dir)
1136
+
1137
+ accelerator.end_training()
1138
+
1139
+
1140
+ if __name__ == "__main__":
1141
+ main(parse_args())
train/train_unified.sh ADDED
@@ -0,0 +1,33 @@
1
+ # run this script from the repository root
2
+ PYTHON_PATH='./' accelerate launch --multi_gpu --gpu_ids '0,1,2,3,4,5,6,7' --main_process_port 25000 --num_processes 8 train/train_unified.py \
3
+ --output_dir "/path/to/output/dir" \
4
+ --train_batch_size 8 \
5
+ --gradient_accumulation_steps 8 \
6
+ --learning_rate 1e-4 \
7
+ --text_loss_weight 0 \
8
+ --max_grad_norm 10 \
9
+ --pretrained_model_name_or_path "MeissonFlow/Meissonic" \
10
+ --pretrained_transformer_path "MeissonFlow/Meissonic" \
11
+ --text_encoder_architecture 'open_clip' \
12
+ --instance_dataset 'ImageCaptionLargeDataset' \
13
+ --instance_data_dir '/path/to/data/' \
14
+ --resolution 512 \
15
+ --mixed_precision fp16 \
16
+ --lr_scheduler constant \
17
+ --use_8bit_adam \
18
+ --dataloader_num_workers 4 \
19
+ --validation_prompts \
20
+ 'a boy' \
21
+ 'A serene mountain landscape with towering snow-capped peaks, a crystal-clear blue lake reflecting the mountains, dense pine forests, and a vibrant orange sunrise illuminating the sky.' \
22
+ 'A playful golden retriever puppy with a shiny coat, bounding through a meadow filled with colorful wildflowers, under a bright, clear blue sky.' \
23
+ 'A bustling city street at night, illuminated by vibrant neon signs in various colors, with busy pedestrians, street vendors, and a light rain creating reflective puddles on the pavement.' \
24
+ 'A majestic, medieval castle perched on a rugged cliffside, overlooking a vast, calm ocean at sunset, with the sky painted in hues of pink, orange, and purple.' \
25
+ 'An elegant ballerina in a white tutu, dancing gracefully on a grand stage with ornate, gold-trimmed curtains, under a spotlight that casts a soft glow.' \
26
+ 'A cozy, rustic log cabin nestled in a snow-covered forest, with smoke rising from the stone chimney, warm lights glowing from the windows, and a path of footprints leading to the front door.'\
27
+ 'A Cute Cat' \
28
+ 'A Snow Mountain'\
29
+ --max_train_steps 100000 \
30
+ --checkpointing_steps 1000 \
31
+ --validation_steps 100 \
32
+ --report_to 'wandb' \
33
+ --logging_steps 10
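
For reference, the effective global batch size implied by the flags above is `num_processes × train_batch_size × gradient_accumulation_steps`, the same product the training script logs as `total_batch_size`. A quick sanity check:

```python
# Quick sanity check of the effective global batch size implied by the launch flags above.
num_processes = 8                # --num_processes
train_batch_size = 8             # --train_batch_size (per device)
gradient_accumulation_steps = 8  # --gradient_accumulation_steps

global_batch_size = num_processes * train_batch_size * gradient_accumulation_steps
print(global_batch_size)  # 512 samples per optimizer update
```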
train/train_unified_new.py ADDED
@@ -0,0 +1,1078 @@
1
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import copy
17
+ import logging
18
+ import math
19
+ import os
20
+ from pathlib import Path
21
+ import sys
22
+ sys.path.append(os.getcwd())
23
+ import json
24
+ import gc
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ from torch import nn
29
+
30
+ from accelerate import Accelerator
31
+ from accelerate.logging import get_logger
32
+ from accelerate.utils import ProjectConfiguration, set_seed
33
+ from peft import LoraConfig
34
+ from peft.utils import get_peft_model_state_dict
35
+ from torch.utils.data import DataLoader
36
+ from torchvision import transforms
37
+
38
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
39
+ from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
40
+ from transformers import (
41
+ CLIPTextModelWithProjection,
42
+ CLIPTokenizer,
43
+ CLIPImageProcessor,
44
+ CLIPVisionModelWithProjection,
45
+ )
46
+
47
+ import diffusers.optimization
48
+ from diffusers import VQModel
49
+
50
+ from src.scheduler import Scheduler
51
+ from diffusers.loaders import LoraLoaderMixin
52
+ from diffusers.utils import is_wandb_available
53
+ from src.pipeline import UnifiedPipeline_new
54
+ from torchvision.utils import save_image, make_grid
55
+ from train.trainer_utils import save_checkpoint
56
+ from train.dataset_utils import ImageCaptionLargeDataset
57
+ from train.dataset_utils import encode_prompt
58
+ from src.transformer import SymmetricTransformer2DModel
59
+ from train.trainer_utils import load_images_to_tensor
60
+
61
+ if is_wandb_available():
62
+ import wandb
63
+ # wandb.login(key="")
64
+
65
+ logger = get_logger(__name__, log_level="INFO")
66
+
67
+ import torch._dynamo
68
+ torch._dynamo.config.verbose = True
69
+
70
+ # Optionally suppress errors to fall back to eager execution
71
+ torch._dynamo.config.suppress_errors = True
72
+
73
+ def parse_args():
74
+ parser = argparse.ArgumentParser()
75
+ parser.add_argument(
76
+ "--pretrained_model_name_or_path",
77
+ type=str,
78
+ default=None,
79
+ required=True,
80
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
81
+ )
82
+ parser.add_argument(
83
+ "--pretrained_transformer_path",
84
+ type=str,
85
+ default=None,
86
+ required=True,
87
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
88
+ )
89
+ parser.add_argument(
90
+ "--text_encoder_architecture",
91
+ type=str,
92
+ default="open_clip",
93
+ required=False,
94
+ help="The architecture of the text encoder. One of ['open_clip', 'gemma']",
95
+ )
96
+ parser.add_argument(
97
+ "--clip_model_name_or_path",
98
+ type=str,
99
+ default=None,
100
+ required=True,
101
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
102
+ )
103
+ parser.add_argument(
104
+ "--text_encoder_2_name_or_path",
105
+ type=str,
106
+ default=None,
107
+ required=False,
108
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
109
+ )
110
+ parser.add_argument(
111
+ "--instance_dataset",
112
+ type=str,
113
+ default=None,
114
+ required=False,
115
+ help="The dataset to use for training. One of ['MSCOCO600K', 'PickaPicV2']",
116
+ )
117
+ parser.add_argument(
118
+ "--instance_data_dir",
119
+ type=str,
120
+ default=None,
121
+ required=False,
122
+ help="A folder containing the training data of instance images.",
123
+ )
124
+ parser.add_argument(
125
+ "--training_from_scratch",
126
+ type=bool,
127
+ default=False,
128
+ required=False
129
+ )
130
+ parser.add_argument(
131
+ "--revision",
132
+ type=str,
133
+ default=None,
134
+ required=False,
135
+ help="Revision of pretrained model identifier from huggingface.co/models.",
136
+ )
137
+ parser.add_argument(
138
+ "--variant",
139
+ type=str,
140
+ default=None,
141
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
142
+ )
143
+ parser.add_argument(
144
+ "--instance_data_image", type=str, default=None, required=False, help="A single training image"
145
+ )
146
+ parser.add_argument(
147
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
148
+ )
149
+ parser.add_argument(
150
+ "--dataloader_num_workers",
151
+ type=int,
152
+ default=0,
153
+ help=(
154
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
155
+ ),
156
+ )
157
+ parser.add_argument(
158
+ "--allow_tf32",
159
+ action="store_true",
160
+ help=(
161
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
162
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
163
+ ),
164
+ )
165
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
166
+ parser.add_argument("--ema_decay", type=float, default=0.9999)
167
+ parser.add_argument("--ema_update_after_step", type=int, default=0)
168
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
169
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
170
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
171
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
172
+ parser.add_argument(
173
+ "--output_dir",
174
+ type=str,
175
+ default="muse_training",
176
+ help="The output directory where the model predictions and checkpoints will be written.",
177
+ )
178
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
179
+ parser.add_argument(
180
+ "--logging_dir",
181
+ type=str,
182
+ default="logs",
183
+ help=(
184
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
185
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
186
+ ),
187
+ )
188
+ parser.add_argument(
189
+ "--max_train_steps",
190
+ type=int,
191
+ default=None,
192
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
193
+ )
194
+ parser.add_argument(
195
+ "--checkpointing_steps",
196
+ type=int,
197
+ default=500,
198
+ help=(
199
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
200
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
201
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
202
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
203
+ "instructions."
204
+ ),
205
+ )
206
+ parser.add_argument(
207
+ "--logging_steps",
208
+ type=int,
209
+ default=50,
210
+ )
211
+ parser.add_argument(
212
+ "--checkpoints_total_limit",
213
+ type=int,
214
+ default=None,
215
+ help=(
216
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
217
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
218
+ " for more details"
219
+ ),
220
+ )
221
+ parser.add_argument(
222
+ "--resume_from_checkpoint",
223
+ type=str,
224
+ default=None,
225
+ help=(
226
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
227
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
228
+ ),
229
+ )
230
+ parser.add_argument(
231
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
232
+ )
233
+ parser.add_argument(
234
+ "--gradient_accumulation_steps",
235
+ type=int,
236
+ default=1,
237
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
238
+ )
239
+ parser.add_argument(
240
+ "--text_loss_reduction",
241
+ type=str,
242
+ default="mean",
243
+ help="The reduction method for the text loss. One of ['mean', 'reweighted']",
244
+ )
245
+ parser.add_argument(
246
+ "--text_loss_weight",
247
+ type=float,
248
+ default=0.2,
249
+ )
250
+ parser.add_argument(
251
+ "--learning_rate",
252
+ type=float,
253
+ default=0.0003,
254
+ help="Initial learning rate (after the potential warmup period) to use.",
255
+ )
256
+ parser.add_argument(
257
+ "--scale_lr",
258
+ action="store_true",
259
+ default=False,
260
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
261
+ )
262
+ parser.add_argument(
263
+ "--lr_scheduler",
264
+ type=str,
265
+ default="constant",
266
+ help=(
267
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
268
+ ' "constant", "constant_with_warmup"]'
269
+ ),
270
+ )
271
+ parser.add_argument(
272
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
273
+ )
274
+ parser.add_argument(
275
+ "--validation_steps",
276
+ type=int,
277
+ default=100,
278
+ help=(
279
+ "Run validation every X steps. Validation consists of running the prompt"
280
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
281
+ " and logging the images."
282
+ ),
283
+ )
284
+ parser.add_argument(
285
+ "--mixed_precision",
286
+ type=str,
287
+ default=None,
288
+ choices=["no", "fp16", "bf16"],
289
+ help=(
290
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
291
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
292
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
293
+ ),
294
+ )
295
+ parser.add_argument(
296
+ "--report_to",
297
+ type=str,
298
+ default="wandb",
299
+ help=(
300
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
301
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
302
+ ),
303
+ )
304
+ parser.add_argument("--validation_prompts", type=str, nargs="*")
305
+ parser.add_argument("--validation_images", type=str, default="./assets")
306
+ parser.add_argument(
307
+ "--resolution",
308
+ type=int,
309
+ default=512,
310
+ help=(
311
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
312
+ " resolution"
313
+ ),
314
+ )
315
+ parser.add_argument("--split_vae_encode", type=int, required=False, default=None)
316
+ parser.add_argument("--min_masking_rate", type=float, default=0.0)
317
+ parser.add_argument("--cond_dropout_prob", type=float, default=0.0)
318
+ parser.add_argument("--max_grad_norm", default=50.0, type=float, help="Max gradient norm.", required=False)
319
+ parser.add_argument("--use_lora", action="store_true", help="Fine tune the model using LoRa")
320
+ parser.add_argument("--text_encoder_use_lora", action="store_true", help="Fine tune the model using LoRa")
321
+ parser.add_argument("--lora_r", default=16, type=int)
322
+ parser.add_argument("--lora_alpha", default=32, type=int)
323
+ parser.add_argument("--lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
324
+ parser.add_argument("--text_encoder_lora_r", default=16, type=int)
325
+ parser.add_argument("--text_encoder_lora_alpha", default=32, type=int)
326
+ parser.add_argument("--text_encoder_lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
327
+ parser.add_argument("--train_text_encoder", action="store_true")
328
+ parser.add_argument("--image_to_text_only", action="store_true")
329
+ parser.add_argument("--image_key", type=str, required=False)
330
+ parser.add_argument("--prompt_key", type=str, required=False)
331
+ parser.add_argument(
332
+ "--gradient_checkpointing",
333
+ action="store_true",
334
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
335
+ )
336
+ parser.add_argument("--prompt_prefix", type=str, required=False, default=None)
337
+
338
+ args = parser.parse_args()
339
+
340
+ if args.report_to == "wandb":
341
+ if not is_wandb_available():
342
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
343
+
344
+ num_datasources = sum(
345
+ [x is not None for x in [args.instance_data_dir, args.instance_data_image]]
346
+ )
347
+
348
+ if num_datasources != 1:
349
+ raise ValueError(
350
+ "provide one and only one of `--instance_data_dir`, `--instance_data_image`, or `--instance_data_dataset`"
351
+ )
352
+
353
+ if args.instance_data_dir is not None:
354
+ if not os.path.exists(args.instance_data_dir):
355
+ raise ValueError(f"Does not exist: `--args.instance_data_dir` {args.instance_data_dir}")
356
+
357
+ if args.instance_data_image is not None:
358
+ if not os.path.exists(args.instance_data_image):
359
+ raise ValueError(f"Does not exist: `--args.instance_data_image` {args.instance_data_image}")
360
+
361
+ return args
362
+
363
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
364
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
365
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
366
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
367
+
368
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
369
+
370
+ latent_image_ids = latent_image_ids.reshape(
371
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
372
+ )
373
+
374
+ return latent_image_ids.to(device=device, dtype=dtype)
375
+
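
As a side note, the helper above builds positional ids for the latent grid: channel 1 carries the row index and channel 2 the column index of each 2×2 latent patch, flattened to a `(H/2 · W/2, 3)` tensor. A tiny illustration for `height = width = 4` (values only; the `batch_size`, `device`, and `dtype` arguments are ignored here):

```python
# Tiny illustration of what _prepare_latent_image_ids returns for height = width = 4.
import torch

ids = torch.zeros(2, 2, 3)
ids[..., 1] += torch.arange(2)[:, None]  # row index of each 2x2 patch
ids[..., 2] += torch.arange(2)[None, :]  # column index of each 2x2 patch
print(ids.reshape(4, 3))
# tensor([[0., 0., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.],
#         [0., 1., 1.]])
```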
376
+ def main(args):
377
+ if args.allow_tf32:
378
+ torch.backends.cuda.matmul.allow_tf32 = True
379
+
380
+ logging_dir = Path(args.output_dir, args.logging_dir)
381
+
382
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
383
+
384
+ accelerator = Accelerator(
385
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
386
+ mixed_precision=args.mixed_precision,
387
+ log_with=args.report_to,
388
+ project_config=accelerator_project_config,
389
+ )
390
+
391
+ if accelerator.is_main_process:
392
+ os.makedirs(args.output_dir, exist_ok=True)
393
+
394
+ # Make one log on every process with the configuration for debugging.
395
+ logging.basicConfig(
396
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
397
+ datefmt="%m/%d/%Y %H:%M:%S",
398
+ level=logging.INFO,
399
+ )
400
+ logger.info(accelerator.state, main_process_only=False)
401
+
402
+ # if accelerator.is_main_process:
403
+ # accelerator.init_trackers("meissonic", config=vars(copy.deepcopy(args)))
404
+
405
+ if args.seed is not None:
406
+ set_seed(args.seed)
407
+
408
+ # Initialize image processor and image encoder (CLIP Vision Model with Projection)
409
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
410
+ args.clip_model_name_or_path
411
+ )
412
+ image_processor = CLIPImageProcessor.from_pretrained(
413
+ args.clip_model_name_or_path
414
+ )
415
+
416
+ # Freeze image encoder parameters
417
+ image_encoder.requires_grad_(False)
418
+
419
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
420
+ args.clip_model_name_or_path
421
+ )
422
+ tokenizer = CLIPTokenizer.from_pretrained(
423
+ args.clip_model_name_or_path
424
+ )
425
+
426
+ text_encoder.requires_grad_(False)
427
+
428
+ if args.text_encoder_architecture == "open_clip":
429
+ tokenizer_2 = None
430
+ text_encoder_2 = None
431
+
432
+ mask_token = "<mask>"
433
+ num_new_tokens = tokenizer.add_tokens(mask_token)
434
+ mask_id_1 = tokenizer.convert_tokens_to_ids(mask_token)
435
+ if num_new_tokens > 0:
436
+ text_encoder.resize_token_embeddings(len(tokenizer))
437
+ mask_token_embedding = text_encoder.get_input_embeddings().weight[mask_id_1]
438
+ mask_token_embedding = mask_token_embedding.clone().detach().cpu().float()
439
+ if accelerator.is_main_process:
440
+ print("Saving masked token embedding...")
441
+ torch.save(mask_token_embedding, os.path.join(args.output_dir, "mask_token_embedding.pth"))
442
+
443
+ elif args.text_encoder_architecture == "gemma":
444
+ tokenizer_2 = GemmaTokenizerFast.from_pretrained(
445
+ args.text_encoder_2_name_or_path,
446
+ )
447
+ text_encoder_2 = Gemma2Model.from_pretrained(
448
+ args.text_encoder_2_name_or_path,
449
+ )
450
+
451
+ mask_token = "<mask>"
452
+
453
+ num_new_tokens = tokenizer_2.add_tokens(mask_token)
454
+ mask_id_2 = tokenizer_2.convert_tokens_to_ids(mask_token)
455
+
456
+ if num_new_tokens > 0:
457
+ text_encoder_2.resize_token_embeddings(len(tokenizer_2))
458
+ mask_token_embedding = text_encoder_2.get_input_embeddings().weight[mask_id_2]
459
+ mask_token_embedding = mask_token_embedding.clone().detach().cpu().float()
460
+ if accelerator.is_main_process:
461
+ print("Saving masked token embedding...")
462
+ torch.save(mask_token_embedding, os.path.join(args.output_dir, "mask_token_embedding.pth"))
463
+
464
+ text_encoder_2.requires_grad_(False)
465
+ else:
466
+ raise ValueError(f"Unknown text encoder architecture: {args.text_encoder_architecture}")
467
+
468
+ vq_model = VQModel.from_pretrained(
469
+ args.pretrained_model_name_or_path, subfolder="vqvae", revision=args.revision, variant=args.variant
470
+ )
471
+ vq_model.requires_grad_(False)
472
+
473
+ model = SymmetricTransformer2DModel.from_pretrained(
474
+ args.pretrained_transformer_path,
475
+ subfolder="transformer",
476
+ low_cpu_mem_usage=False,
477
+ device_map=None
478
+ )
479
+
480
+ if model.config.tokenizer_vocab_size is None:
481
+ if args.text_encoder_architecture == "open_clip":
482
+ model.register_to_config(tokenizer_vocab_size=len(tokenizer))
483
+ elif args.text_encoder_architecture == "gemma":
484
+ model.register_to_config(tokenizer_vocab_size=len(tokenizer_2))
485
+ else:
486
+ raise ValueError(f"Unknown text encoder architecture!")
487
+
488
+ if accelerator.is_main_process:
489
+ print(f"model's tokenizer vocab size is {model.config.tokenizer_vocab_size}")
490
+
491
+ model.text_decoder = nn.Sequential(
492
+ nn.LayerNorm(model.inner_dim, elementwise_affine=False, eps=1e-6),
493
+ nn.Linear(model.inner_dim, model.config.tokenizer_vocab_size, bias=False)
494
+ )
495
+
496
+ model = torch.compile(model)
497
+
498
+ if args.use_lora:
499
+ lora_config = LoraConfig(
500
+ r=args.lora_r,
501
+ lora_alpha=args.lora_alpha,
502
+ target_modules=args.lora_target_modules,
503
+ )
504
+ model.add_adapter(lora_config)
505
+
506
+ model.train()
507
+
508
+ if args.image_to_text_only:
509
+ frozen_keys = ["project_from_hidden", "up_block", "mlm_layer"]
510
+ for n, p in model.named_parameters():
511
+ if any([frozen_key in n for frozen_key in frozen_keys]):
512
+ p.requires_grad_(False)
513
+ else:
514
+ p.requires_grad_(True)
515
+ else:
516
+ model.requires_grad_(True)
517
+
518
+ if args.gradient_checkpointing:
519
+ model.enable_gradient_checkpointing()
520
+
521
+ def save_model_hook(models, weights, output_dir):
522
+ if accelerator.is_main_process:
523
+ transformer_lora_layers_to_save = None
524
+ text_encoder_lora_layers_to_save = None
525
+
526
+ for model_ in models:
527
+ if isinstance(model_, type(accelerator.unwrap_model(model))):
528
+ if args.use_lora:
529
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model_)
530
+ else:
531
+ model_.save_pretrained(os.path.join(output_dir, "transformer"))
532
+ elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
533
+ if args.text_encoder_use_lora:
534
+ text_encoder_lora_layers_to_save = get_peft_model_state_dict(model_)
535
+ else:
536
+ model_.save_pretrained(os.path.join(output_dir, "text_encoder"))
537
+ else:
538
+ raise ValueError(f"unexpected save model: {model_.__class__}")
539
+
540
+ # make sure to pop weight so that corresponding model is not saved again
541
+ weights.pop()
542
+
543
+ if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
544
+ LoraLoaderMixin.save_lora_weights(
545
+ output_dir,
546
+ unet_lora_layers=transformer_lora_layers_to_save,
547
+ text_encoder_lora_layers=text_encoder_lora_layers_to_save,
548
+ )
549
+
550
+
551
+ def load_model_hook(models, input_dir):
552
+ transformer = None
553
+ text_encoder_ = None
554
+
555
+ # keep state_dict keys consistent with the torch.compile()-wrapped model
+ def adap_compile(ori_dict):  # prepend '_orig_mod.' to each key
557
+ new_dict = {}
558
+ for k,v in ori_dict.items():
559
+ new_dict['_orig_mod.' + k] = v
560
+ return new_dict
561
+
562
+ while len(models) > 0:
563
+ model_ = models.pop()
564
+
565
+ if isinstance(model_, type(accelerator.unwrap_model(model))):
566
+ if args.use_lora:
567
+ transformer = model_
568
+ else:
569
+ load_model = SymmetricTransformer2DModel.from_pretrained(os.path.join(input_dir, "transformer"), low_cpu_mem_usage=False, device_map=None)
570
+ model_.load_state_dict(adap_compile(load_model.state_dict()))
571
+ del load_model
572
+ elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
573
+ if args.text_encoder_use_lora:
574
+ text_encoder_ = model_
575
+ else:
576
+ try:
577
+ load_model = CLIPTextModelWithProjection.from_pretrained(os.path.join(input_dir, "text_encoder"))
578
+ model_.load_state_dict(load_model.state_dict())
579
+ # print('finished loading text encoder!')
580
+ except Exception:
+ print('Text encoder not found in the checkpoint folder; downloading the default CLIP text encoder instead.')
582
+ load_model = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
583
+ model_.load_state_dict(load_model.state_dict())
584
+ del load_model
585
+ else:
586
+ raise ValueError(f"unexpected save model: {model.__class__}")
587
+
588
+ if transformer is not None or text_encoder_ is not None:
589
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
590
+ LoraLoaderMixin.load_lora_into_text_encoder(
591
+ lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
592
+ )
593
+ LoraLoaderMixin.load_lora_into_transformer(
594
+ lora_state_dict, network_alphas=network_alphas, transformer=transformer
595
+ )
596
+
597
+ accelerator.register_load_state_pre_hook(load_model_hook)
598
+ accelerator.register_save_state_pre_hook(save_model_hook)
599
+
600
+ if args.scale_lr:
601
+ args.learning_rate = (
602
+ args.learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
603
+ )
604
+
605
+ if args.use_8bit_adam:
606
+ try:
607
+ import bitsandbytes as bnb
608
+ except ImportError:
609
+ raise ImportError(
610
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
611
+ )
612
+
613
+ optimizer_cls = bnb.optim.AdamW8bit
614
+ else:
615
+ optimizer_cls = torch.optim.AdamW
616
+
617
+ optimizer_grouped_parameters = [
618
+ {
619
+ "params": [p for p in model.parameters() if p.requires_grad],
620
+ "weight_decay": args.adam_weight_decay,
621
+ }
622
+ ]
623
+ optimizer = optimizer_cls(
624
+ optimizer_grouped_parameters,
625
+ lr=args.learning_rate,
626
+ betas=(args.adam_beta1, args.adam_beta2),
627
+ weight_decay=args.adam_weight_decay,
628
+ eps=args.adam_epsilon,
629
+ )
630
+
631
+ logger.info("Creating dataloaders and lr_scheduler")
632
+
633
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
634
+
635
+ if args.text_encoder_architecture in ("t5_clip", "gemma"):
636
+ tokenizer_for_dataset = [tokenizer, tokenizer_2]
637
+ else:
638
+ tokenizer_for_dataset = tokenizer
639
+
640
+ if args.instance_dataset == "ImageCaptionLargeDataset":
641
+ dataset = ImageCaptionLargeDataset(
642
+ root_dir=args.instance_data_dir,
643
+ tokenizer=tokenizer_for_dataset,
644
+ size=args.resolution,
645
+ text_encoder_architecture=args.text_encoder_architecture
646
+ )
647
+ elif args.instance_dataset == "DATA_TYPE":
648
+ raise NotImplementedError("DATA_TYPE is not yet supported")
649
+ else:
650
+ raise ValueError(f"Unknown instance_dataset: {args.instance_dataset}")
651
+
652
+ def collate_fn(samples):
653
+ images = [sample["image"] for sample in samples]
654
+ micro_conds = [sample["micro_conds"] for sample in samples]
655
+
656
+ images = torch.stack(images, dim=0)
657
+ micro_conds = torch.stack(micro_conds, dim=0)
658
+
659
+ if isinstance(samples[0]["prompt_input_ids"], list):
660
+ input_ids = [sample["prompt_input_ids"][0] for sample in samples]
661
+ input_ids_2 = [sample["prompt_input_ids"][1] for sample in samples]
662
+
663
+ input_ids = torch.cat(input_ids, dim=0)
664
+ input_ids_2 = torch.cat(input_ids_2, dim=0)
665
+ prompt_input_ids = [input_ids, input_ids_2]
666
+ else:
667
+ input_ids = [sample["prompt_input_ids"] for sample in samples]
668
+
669
+ input_ids = torch.cat(input_ids, dim=0)
670
+ prompt_input_ids = input_ids
671
+
672
+ ret = dict(
673
+ images=images,
674
+ micro_conds=micro_conds,
675
+ prompt_input_ids=prompt_input_ids,
676
+ )
677
+
678
+ return ret
679
+
680
+ train_dataloader = DataLoader(
681
+ dataset,
682
+ batch_size=args.train_batch_size,
683
+ shuffle=True,
684
+ num_workers=args.dataloader_num_workers,
685
+ collate_fn=collate_fn,
686
+ pin_memory=True,
687
+ )
688
+ train_dataloader.num_batches = len(train_dataloader)
689
+
690
+ lr_scheduler = diffusers.optimization.get_scheduler(
691
+ args.lr_scheduler,
692
+ optimizer=optimizer,
693
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
694
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
695
+ )
696
+
697
+ logger.info("Preparing model, optimizer and dataloaders")
698
+
699
+ model, optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
700
+ model, optimizer, lr_scheduler, train_dataloader
701
+ )
702
+
703
+ train_dataloader.num_batches = len(train_dataloader)
704
+
705
+ weight_dtype = torch.float32
706
+ if accelerator.mixed_precision == "fp16":
707
+ weight_dtype = torch.float16
708
+ elif accelerator.mixed_precision == "bf16":
709
+ weight_dtype = torch.bfloat16
710
+
711
+ image_encoder.to(device=accelerator.device, dtype=weight_dtype)
712
+
713
+ if args.text_encoder_architecture in ("t5_clip", "gemma"):
714
+ text_encoder.to(device=accelerator.device, dtype=weight_dtype)
715
+ text_encoder_2.to(device=accelerator.device, dtype=weight_dtype)
716
+ else:
717
+ text_encoder.to(device=accelerator.device, dtype=weight_dtype)
718
+
719
+ vq_model.to(device=accelerator.device)
720
+
721
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
722
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
723
+ # Afterwards we recalculate our number of training epochs.
724
+ # Note: We are not doing epoch based training here, but just using this for book keeping and being able to
725
+ # reuse the same training loop with other datasets/loaders.
726
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
727
+
728
+ # Train!
729
+ logger.info("***** Running training *****")
730
+ logger.info(f" Num training steps = {args.max_train_steps}")
731
+ logger.info(f" Instantaneous batch size per device = { args.train_batch_size}")
732
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
733
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
734
+
735
+ resume_from_checkpoint = args.resume_from_checkpoint
736
+ if resume_from_checkpoint:
737
+ if resume_from_checkpoint == "latest":
738
+ # Get the most recent checkpoint
739
+ dirs = os.listdir(args.output_dir)
740
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
741
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
742
+ if len(dirs) > 0:
743
+ resume_from_checkpoint = os.path.join(args.output_dir, dirs[-1])
744
+ else:
745
+ resume_from_checkpoint = None
746
+
747
+ if resume_from_checkpoint is None:
748
+ accelerator.print(
749
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
750
+ )
751
+ else:
752
+ accelerator.print(f"Resuming from checkpoint {resume_from_checkpoint}")
753
+
754
+ if resume_from_checkpoint is None:
755
+ global_step = 0
756
+ first_epoch = 0
757
+ else:
758
+ accelerator.load_state(resume_from_checkpoint)
759
+ global_step = int(os.path.basename(resume_from_checkpoint).split("-")[1])
760
+ first_epoch = global_step // num_update_steps_per_epoch
761
+
762
+ # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to
763
+ # reuse the same training loop with other datasets/loaders.
764
+ for epoch in range(first_epoch, num_train_epochs):
765
+ for batch in train_dataloader:
766
+ torch.cuda.empty_cache()
767
+ with torch.no_grad():
768
+ micro_conds = batch["micro_conds"].to(accelerator.device, non_blocking=True)
769
+ gen_micro_conds, und_micro_conds = micro_conds.chunk(2, dim=0)
770
+
771
+ pixel_values = batch["images"].to(accelerator.device, non_blocking=True) # [b, 3, res, res]
772
+ batch_size = pixel_values.shape[0]
773
+ half_batch_size = batch_size // 2
774
+
775
+ # ====================== tokenize images ======================
776
+ image_tokens = vq_model.quantize(
777
+ vq_model.encode(pixel_values).latents
778
+ )[2][2].reshape(batch_size, -1) # [b, seq_len]
779
+ # ====================== tokenize images ======================
780
+
781
+
782
+ # ====================== tokenize text prompts ======================
783
+ if args.text_encoder_architecture == "gemma":
784
+ prompt_input_ids_1 = batch["prompt_input_ids"][0].to(accelerator.device, non_blocking=True)
785
+ prompt_input_ids_2 = batch["prompt_input_ids"][1].to(accelerator.device, non_blocking=True)
786
+
787
+ prompt_input_ids_gen_1, prompt_input_ids_und_1 = prompt_input_ids_1.chunk(2, dim=0)
788
+ prompt_input_ids_gen_2, prompt_input_ids_und_2 = prompt_input_ids_2.chunk(2, dim=0)
789
+ else:
790
+ prompt_input_ids = batch["prompt_input_ids"].to(accelerator.device, non_blocking=True)
791
+ prompt_input_ids_gen, prompt_input_ids_und = prompt_input_ids.chunk(2, dim=0)
792
+ # ====================== tokenize text prompts ======================
793
+
794
+
795
+ # ====================== image perturbation ======================
796
+ image_tokens_gen, image_tokens_und = image_tokens.chunk(2, dim=0) # (b // 2, seq_len)
797
+ _, seq_len = image_tokens_gen.shape
798
+ sigma = torch.rand(half_batch_size, device=image_tokens_gen.device)
799
+ gen_mask_prob = torch.cos(sigma * math.pi * 0.5)
800
+ gen_mask_prob = gen_mask_prob.clip(args.min_masking_rate)
801
+
802
+ num_token_masked = (seq_len * gen_mask_prob).round().clamp(min=1)
803
+ batch_randperm = torch.rand(half_batch_size, seq_len, device=image_tokens_gen.device).argsort(dim=-1)
804
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
805
+
806
+ mask_id = accelerator.unwrap_model(model).config.vocab_size - 1
807
+ masked_image_ids = torch.where(mask, mask_id, image_tokens_gen)
808
+ image_labels = torch.where(mask, image_tokens_gen, -100)
809
+ # ====================== image perturbation ======================
810
+
811
+
812
+ # ====================== text perturbation ======================
813
+ if args.text_encoder_architecture in ("t5_clip", "gemma"):
814
+ half_batch_size, seq_len = prompt_input_ids_und_2.shape
815
+ sigma = torch.rand(half_batch_size, device=image_tokens_gen.device)
816
+ text_mask_prob = torch.cos(sigma * math.pi * 0.5)
817
+ text_mask_prob = text_mask_prob.clip(args.min_masking_rate)
818
+ text_timestep = text_mask_prob.clone()
819
+
820
+ num_token_masked = (seq_len * text_mask_prob).round().clamp(min=1)
821
+ batch_randperm = torch.rand(half_batch_size, seq_len, device=image_tokens_gen.device).argsort(dim=-1)
822
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
823
+
824
+ masked_prompt_input_ids_und = torch.where(mask, mask_id_2, prompt_input_ids_und_2)
825
+ text_labels = torch.where(mask, prompt_input_ids_und_2, -100)
826
+
827
+ else:
828
+ half_batch_size, seq_len = prompt_input_ids_und.shape
829
+ sigma = torch.rand(half_batch_size, device=image_tokens_gen.device)
830
+ text_mask_prob = torch.cos(sigma * math.pi * 0.5)
831
+ text_mask_prob = text_mask_prob.clip(args.min_masking_rate)
832
+ text_timestep = text_mask_prob.clone().clamp(min=1e-3)
833
+
834
+ num_token_masked = (seq_len * text_mask_prob).round().clamp(min=1)
835
+ batch_randperm = torch.rand(half_batch_size, seq_len, device=image_tokens_gen.device).argsort(dim=-1)
836
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
837
+
838
+ masked_prompt_input_ids_und = torch.where(mask, mask_id_1, prompt_input_ids_und)
839
+ text_labels = torch.where(mask, prompt_input_ids_und, -100)
840
+ # ====================== text perturbation ======================
841
+
842
+
843
+ # ====================== encode text prompts ======================
844
+ if args.text_encoder_architecture == "gemma":
845
+ masked_encoder_hidden_states, _ = encode_prompt(
846
+ [text_encoder, text_encoder_2],
847
+ [prompt_input_ids_und_1, masked_prompt_input_ids_und],
848
+ args.text_encoder_architecture
849
+ )
850
+ encoder_hidden_states, text_pooled_embeds = encode_prompt(
851
+ [text_encoder, text_encoder_2],
852
+ [prompt_input_ids_gen_1, prompt_input_ids_gen_2],
853
+ args.text_encoder_architecture
854
+ )
855
+ else:
856
+ masked_encoder_hidden_states, _ = encode_prompt(
857
+ text_encoder,
858
+ masked_prompt_input_ids_und,
859
+ args.text_encoder_architecture
860
+ )
861
+ encoder_hidden_states, text_pooled_embeds = encode_prompt(
862
+ text_encoder,
863
+ prompt_input_ids_gen,
864
+ args.text_encoder_architecture
865
+ )
866
+ # Obtain cond_embeds by passing pixel_values[half_batch_size:] through the image encoder; the CLIP pooled embedding is used as cond_embeds.
+ # pixel_values must first go through the CLIP image processor; their values range from 0 to 1.
868
+ with torch.no_grad():
869
+ processed_pixel_values = image_processor(
870
+ pixel_values[half_batch_size:],
871
+ do_rescale=False,
872
+ do_resize=True,
873
+ do_normalize=True,
874
+ return_tensors="pt"
875
+ )["pixel_values"].to(image_encoder.device, dtype=image_encoder.dtype)
876
+ image_pooled_embeds = image_encoder(processed_pixel_values).image_embeds
877
+
878
+ # for text-to-image
879
+ encoder_hidden_states = encoder_hidden_states.to(accelerator.device, dtype=accelerator.unwrap_model(model).dtype)
880
+ text_pooled_embeds = text_pooled_embeds.to(accelerator.device, dtype=accelerator.unwrap_model(model).dtype)
881
+ # for image-to-text
882
+ masked_encoder_hidden_states = masked_encoder_hidden_states.to(accelerator.device, dtype=accelerator.unwrap_model(model).dtype)
883
+ image_pooled_embeds = image_pooled_embeds.to(accelerator.device, dtype=accelerator.unwrap_model(model).dtype)
884
+ # ====================== encode text prompts ======================
885
+
886
+ vae_scale_factor = 2 ** (len(vq_model.config.block_out_channels) - 1)
887
+ resolution = args.resolution // vae_scale_factor
888
+ masked_image_ids = masked_image_ids.reshape(half_batch_size, resolution, resolution)
889
+ image_ids = image_tokens_und.reshape(half_batch_size, resolution, resolution)
890
+
891
+ # Train Step
892
+ with accelerator.accumulate(model):
893
+ codebook_size = accelerator.unwrap_model(model).config.codebook_size
894
+ img_ids = _prepare_latent_image_ids(
895
+ masked_image_ids.shape[0],
896
+ masked_image_ids.shape[-2],
897
+ masked_image_ids.shape[-1],
898
+ masked_image_ids.device,
899
+ masked_image_ids.dtype
900
+ )
901
+
902
+ txt_ids = torch.zeros(encoder_hidden_states.shape[1], 3).to(device=masked_image_ids.device, dtype=masked_image_ids.dtype)
903
+
904
+ image_logits = (
905
+ model(
906
+ hidden_states=masked_image_ids, # should be (batch size, channel, height, width)
907
+ encoder_hidden_states=encoder_hidden_states, # should be (batch size, sequence_len, embed_dims)
908
+ micro_conds=gen_micro_conds,
909
+ pooled_projections=text_pooled_embeds, # should be (batch_size, projection_dim)
910
+ img_ids=img_ids,
911
+ txt_ids=txt_ids,
912
+ timestep=gen_mask_prob,
913
+ )[0]
914
+ .reshape(half_batch_size, codebook_size, -1)
915
+ .permute(0, 2, 1)
916
+ .reshape(-1, codebook_size)
917
+ )
918
+
919
+ image_loss = F.cross_entropy(
920
+ image_logits,
921
+ image_labels.view(-1),
922
+ ignore_index=-100,
923
+ reduction="mean",
924
+ )
925
+
926
+ text_logits = model(
927
+ hidden_states=image_ids, # should be (batch size, channel, height, width)
928
+ encoder_hidden_states=masked_encoder_hidden_states, # should be (batch size, sequence_len, embed_dims)
929
+ micro_conds=und_micro_conds,
930
+ pooled_projections=image_pooled_embeds, # should be (batch_size, projection_dim)
931
+ img_ids=img_ids,
932
+ txt_ids=txt_ids,
933
+ timestep=text_mask_prob,
934
+ )[1]
935
+ text_logits = text_logits.reshape(-1, accelerator.unwrap_model(model).config.tokenizer_vocab_size)
936
+
937
+ if args.text_loss_reduction == "mean":
938
+ text_loss = F.cross_entropy(
939
+ text_logits,
940
+ text_labels.view(-1),
941
+ ignore_index=-100,
942
+ reduction="mean",
943
+ )
944
+ elif args.text_loss_reduction == "reweighted":
945
+ text_loss = F.cross_entropy(
946
+ text_logits,
947
+ text_labels.view(-1),
948
+ ignore_index=-100,
949
+ reduction="none",
950
+ )
951
+ text_loss = text_loss.reshape(half_batch_size, -1).mean(-1)
952
+ text_loss = text_loss / text_timestep
953
+ text_loss = text_loss.mean()
954
+ else:
955
+ raise ValueError(f"Unknown text_loss_reduction: {args.text_loss_reduction}")
956
+
957
+ loss = image_loss + args.text_loss_weight * text_loss
958
+
959
+ # Gather the losses across all processes for logging (if we use distributed training).
960
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
961
+ avg_masking_rate = accelerator.gather(gen_mask_prob.repeat(args.train_batch_size)).mean()
962
+
963
+ accelerator.backward(loss)
964
+
965
+ if args.max_grad_norm is not None and accelerator.sync_gradients:
966
+ accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
967
+
968
+ optimizer.step()
969
+ lr_scheduler.step()
970
+
971
+ optimizer.zero_grad(set_to_none=True)
972
+
973
+ # Checks if the accelerator has performed an optimization step behind the scenes
974
+ if accelerator.sync_gradients:
975
+ if (global_step + 1) % args.logging_steps == 0:
976
+ logs = {
977
+ "step_loss": avg_loss.item(),
978
+ "lr": lr_scheduler.get_last_lr()[0],
979
+ "avg_masking_rate": avg_masking_rate.item(),
980
+ }
981
+ accelerator.log(logs, step=global_step + 1)
982
+
983
+ logger.info(
984
+ f"Step: {global_step + 1} "
985
+ f"Loss: {avg_loss.item():0.4f} "
986
+ f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}"
987
+ )
988
+
989
+ if (global_step + 1) % args.checkpointing_steps == 0:
990
+ save_checkpoint(args, accelerator, global_step + 1, logger)
991
+
992
+ if (global_step + 1) % args.validation_steps == 0 and accelerator.is_main_process:
993
+
994
+ with torch.no_grad():
995
+ logger.info("Evaluating...")
996
+
997
+ model.eval()
998
+
999
+ scheduler = Scheduler.from_pretrained(
1000
+ args.pretrained_model_name_or_path,
1001
+ subfolder="scheduler",
1002
+ revision=args.revision,
1003
+ variant=args.variant,
1004
+ )
1005
+
1006
+ pipe = UnifiedPipeline_new(
1007
+ transformer=accelerator.unwrap_model(model),
1008
+ tokenizer=tokenizer,
1009
+ text_encoder=text_encoder,
1010
+ vqvae=vq_model,
1011
+ scheduler=scheduler,
1012
+ tokenizer_2=tokenizer_2,
1013
+ text_encoder_2=text_encoder_2,
1014
+ clip_image_processor=image_processor,
1015
+ image_encoder=image_encoder,
1016
+ )
1017
+
1018
+ if not args.image_to_text_only:
1019
+ output = pipe(
1020
+ prompt=args.validation_prompts,
1021
+ height=args.resolution,
1022
+ width=args.resolution,
1023
+ guidance_scale=9,
1024
+ num_inference_steps=64,
1025
+ )
1026
+ pil_images = output.images
1027
+
1028
+ result=[]
1029
+ for img in pil_images:
1030
+ if not isinstance(img, torch.Tensor):
1031
+ img = transforms.ToTensor()(img)
1032
+ result.append(img.unsqueeze(0))
1033
+ result = torch.cat(result,dim=0)
1034
+ result = make_grid(result, nrow=3)
1035
+ save_image(result,os.path.join(args.output_dir, str(global_step)+'_text2image_1024_CFG-9.png'))
1036
+
1037
+ image = load_images_to_tensor(args.validation_images, target_size=(args.resolution, args.resolution))
1038
+ output = pipe(
1039
+ height=args.resolution,
1040
+ width=args.resolution,
1041
+ guidance_scale=9,
1042
+ image=image,
1043
+ num_inference_steps=64
1044
+ )
1045
+ prompts = output.prompts
1046
+
1047
+ output_data = {
1048
+ "step": global_step,
1049
+ "prompts": prompts,
1050
+ }
1051
+
1052
+ with open(os.path.join(args.output_dir, f"image2text_{global_step}.json"), "w") as f:
1053
+ json.dump(output_data, f, indent=2)
1054
+
1055
+ model.train()
1056
+
1057
+ global_step += 1
1058
+
1059
+ # Stop training if max steps is reached
1060
+ if global_step >= args.max_train_steps:
1061
+ break
1062
+ # End for
1063
+
1064
+ accelerator.wait_for_everyone()
1065
+
1066
+ # Evaluate and save checkpoint at the end of training
1067
+ save_checkpoint(args, accelerator, global_step, logger)
1068
+
1069
+ # Save the final trained checkpoint
1070
+ if accelerator.is_main_process:
1071
+ model = accelerator.unwrap_model(model)
1072
+ model.save_pretrained(args.output_dir)
1073
+
1074
+ accelerator.end_training()
1075
+
1076
+
1077
+ if __name__ == "__main__":
1078
+ main(parse_args())
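
The training loop in `train_unified_new.py` above perturbs image tokens with a cosine masking schedule: it samples `sigma ~ U(0, 1)`, masks a `cos(sigma · π/2)` fraction of positions chosen via a random permutation, and computes cross-entropy only on the masked positions. A standalone sketch of that step, with placeholder shapes and a placeholder `mask_id` (in the script the real value is `model.config.vocab_size - 1`):

```python
# Standalone sketch of the image-token masking step used in the training loop above.
# Shapes and mask_id are placeholders; real values come from the VQ codebook and model config.
import math
import torch

batch_size, seq_len = 4, 1024          # placeholder: half_batch_size and VQ token sequence length
mask_id = 8255                          # placeholder for model.config.vocab_size - 1
image_tokens = torch.randint(0, 8192, (batch_size, seq_len))

sigma = torch.rand(batch_size)
mask_prob = torch.cos(sigma * math.pi * 0.5).clip(0.0)   # args.min_masking_rate = 0.0 here
num_masked = (seq_len * mask_prob).round().clamp(min=1)

batch_randperm = torch.rand(batch_size, seq_len).argsort(dim=-1)
mask = batch_randperm < num_masked.unsqueeze(-1)

masked_ids = torch.where(mask, mask_id, image_tokens)    # model input
labels = torch.where(mask, image_tokens, -100)           # cross-entropy targets; -100 is ignored
```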
train/trainer_utils.py ADDED
@@ -0,0 +1,101 @@
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import shutil
+ from pathlib import Path
+
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+
+
+ def save_checkpoint(args, accelerator, global_step, logger):
+ output_dir = args.output_dir
+
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if accelerator.is_main_process and args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = Path(output_dir) / f"checkpoint-{global_step}"
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+
+ def load_images_to_tensor(path, target_size=(1024, 1024)):
+ """
+ Args:
+ path: a list of image paths, a directory containing images, or a single image path
+ target_size: (height, width)
+
+ Returns:
+ torch.Tensor: [B, 3, H, W] in [0, 1]
+ """
+ valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp')
+
+ if isinstance(path, list):
+ image_files = path
+ elif isinstance(path, str) and os.path.isdir(path):
+ image_files = [f for f in os.listdir(path) if f.lower().endswith(valid_extensions)]
+ elif isinstance(path, str):
+ image_files = [path]
+ else:
+ raise ValueError(f"Unsupported path type: {type(path)}")
+
+ if not image_files:
+ raise ValueError(f"No valid images found in {path}")
+
+ transform = transforms.Compose([
+ transforms.Resize(target_size),
+ transforms.ToTensor(),
+ ])
+
+ tensors = []
+ for img_file in image_files:
+ try:
+ if isinstance(path, str) and os.path.isdir(path):
+ img_path = os.path.join(path, img_file)
+ else:
+ img_path = img_file
+ img = Image.open(img_path).convert('RGB')
+ tensor = transform(img)
+ tensors.append(tensor)
+ except Exception as e:
+ print(f"Error processing {img_file}: {e}")
+
+ if not tensors:
+ raise ValueError("No images could be loaded")
+
+ batch_tensor = torch.stack(tensors)
+
+ assert batch_tensor.shape[1:] == (3, *target_size), \
+ f"Output shape is {batch_tensor.shape}, expected (B, 3, {target_size[0]}, {target_size[1]})"
+
+ return batch_tensor
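
A short usage sketch for `load_images_to_tensor`, mirroring how the training scripts call it for image-to-text validation; the file names below are placeholders:

```python
# Usage sketch for load_images_to_tensor (paths are placeholders).
from train.trainer_utils import load_images_to_tensor

# A directory of images, matching the --validation_images default of "./assets":
batch = load_images_to_tensor("./assets", target_size=(512, 512))
print(batch.shape)  # torch.Size([N, 3, 512, 512]), values in [0, 1]

# A single file or an explicit list of file paths also works:
single = load_images_to_tensor("./assets/example.png", target_size=(512, 512))
```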