kaupane committed on
Commit c1b39e1 · verified · 1 Parent(s): 1af442b

Upload 5 files


Initial commit.

Files changed (5)
  1. app.py +245 -0
  2. ckpts/DiT_S_final.pth +3 -0
  3. mapping.py +235 -0
  4. models/DiT.py +184 -0
  5. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,245 @@
+ import torch
+ import gradio as gr
+ import numpy as np
+ import os
+ import random
+ from mapping import reduced_genre_mapping, reduced_style_mapping, reverse_reduced_genre_mapping, reverse_reduced_style_mapping
+ from diffusers import AutoencoderKL
+ from models.DiT import DiT
+
+ # Global settings
+ num_timesteps = 1000
+ beta_start = 1e-4
+ beta_end = 0.02
+ latent_scale_factor = 0.18215  # Same as in DiTTrainer
+
+ # For tracking progress in UI
+ global_progress = 0
+
+ def load_dit_model(dit_size):
+     """Load DiT model of specified size"""
+     ckpt_path = f"./ckpts/DiT_{dit_size}_final.pth"
+     if not os.path.exists(ckpt_path):
+         raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}")
+
+     # Configure model based on size
+     if dit_size == "S":
+         model = DiT(num_blocks=8, hidden_size=384, num_heads=6)
+     elif dit_size == "B":
+         model = DiT(num_blocks=12, hidden_size=640, num_heads=10)
+     elif dit_size == "L":
+         model = DiT(num_blocks=16, hidden_size=896, num_heads=14)
+     else:
+         raise ValueError(f"Invalid DiT size: {dit_size}")
+
+     # Load checkpoint
+     checkpoint = torch.load(ckpt_path, map_location="cpu")
+     model.load_state_dict(checkpoint["model_state_dict"])
+
+     return model
+
+ class DiffusionSampler:
+     def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
+         self.device = device
+         self.vae = None
+
+         # Pre-compute diffusion parameters (linear beta schedule)
+         self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
+         self.alphas = 1.0 - self.betas
+         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+         self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
+         self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
+         self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
+         self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
+         self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+
+         # Move to device
+         self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(self.device)
+         self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.to(self.device)
+         self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
+         self.betas = self.betas.to(self.device)
+         self.posterior_variance = self.posterior_variance.to(self.device)
+
+     def load_vae(self):
+         """Load VAE model (done lazily to save memory until needed)"""
+         if self.vae is None:
+             self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
+             self.vae.eval()
+
+     def generate_images(self, model, num_samples, genre, style, seed, progress=gr.Progress()):
+         """Generate images with the DiT model"""
+         global global_progress
+         global_progress = 0
+
+         # Set random seed for reproducibility
+         if seed is not None:
+             torch.manual_seed(seed)
+             np.random.seed(seed)
+             random.seed(seed)
+             # Also set CUDA seed if using GPU
+             if torch.cuda.is_available():
+                 torch.cuda.manual_seed(seed)
+                 torch.cuda.manual_seed_all(seed)
+
+         model.to(self.device)
+         model.eval()
+
+         # Convert genre and style to tensors
+         g_cond = torch.tensor([genre] * num_samples, device=self.device, dtype=torch.long)
+         s_cond = torch.tensor([style] * num_samples, device=self.device, dtype=torch.long)
+         g_null = torch.tensor([model.num_genres] * num_samples, device=self.device, dtype=torch.long)
+         s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)
+
+         # Start with random latents
+         latents = torch.randn((num_samples, 4, 32, 32), device=self.device)
+
+         # Use classifier-free guidance for better quality
+         cfg_scale = 2.5
+
+         # Go through the reverse diffusion process
+         timesteps = torch.arange(num_timesteps - 1, -1, -1, device=self.device)
+         total_steps = len(timesteps)
+
+         with torch.no_grad():
+             for i, t_val in enumerate(timesteps):
+                 # Update progress
+                 global_progress = int(100 * i / total_steps)
+                 progress(global_progress / 100, desc="Generating images...")
+
+                 t = torch.full((num_samples,), t_val, device=self.device, dtype=torch.long)
+
+                 sqrt_recip_alphas_t = self.sqrt_recip_alphas[t].view(-1, 1, 1, 1)
+                 sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
+                 beta_t = self.betas[t].view(-1, 1, 1, 1)
+                 posterior_variance_t = self.posterior_variance[t].view(-1, 1, 1, 1)
+
+                 # Get noise prediction with classifier-free guidance
+                 eps_theta_cond = model(latents, t, g_cond, s_cond)
+                 eps_theta_uncond = model(latents, t, g_null, s_null)
+                 eps_theta = eps_theta_uncond + cfg_scale * (eps_theta_cond - eps_theta_uncond)
+
+                 # Update latents (posterior mean plus noise, except at the final step)
+                 mean = sqrt_recip_alphas_t * (latents - (beta_t / sqrt_one_minus_alphas_cumprod_t) * eps_theta)
+                 noise = torch.randn_like(latents)
+                 if t_val == 0:
+                     latents = mean
+                 else:
+                     latents = mean + torch.sqrt(posterior_variance_t) * noise
+
+         # Decode latents to images
+         self.load_vae()
+         latents = latents / self.vae.config.scaling_factor
+         latents = latents.to(self.device)
+
+         progress(0.95, desc="Decoding images...")
+         with torch.no_grad():
+             images = self.vae.decode(latents).sample
+             images = (images / 2 + 0.5).clamp(0, 1)
+             images = images.permute(0, 2, 3, 1).cpu().numpy()
+
+         progress(1.0, desc="Done!")
+         global_progress = 100
+
+         # Create image gallery with labels
+         gallery_images = []
+         for i in range(num_samples):
+             # Convert numpy array to a uint8 image for the gallery
+             img = (images[i] * 255).astype(np.uint8)
+             caption = f"Genre: {reverse_reduced_genre_mapping[genre]}, Style: {reverse_reduced_style_mapping[style]}"
+             if seed is not None:
+                 caption += f" (Seed: {seed})"
+             gallery_images.append((img, caption))
+
+         return gallery_images
+
+ # Initialize sampler globally
+ sampler = DiffusionSampler()
+
+ def generate_random_seed():
+     """Generate a random seed between 0 and 2^32 - 1"""
+     return random.randint(0, 2**32 - 1)
+
+ def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
+     """Main function for Gradio interface"""
+     if num_samples < 1 or num_samples > 16:
+         return None, gr.update(value="Number of samples must be between 1 and 16", visible=True)
+
+     # Get genre and style IDs from mappings
+     genre_id = reduced_genre_mapping.get(genre_name)
+     style_id = reduced_style_mapping.get(style_name)
+
+     if genre_id is None:
+         return None, gr.update(value=f"Unknown genre: {genre_name}", visible=True)
+     if style_id is None:
+         return None, gr.update(value=f"Unknown style: {style_name}", visible=True)
+
+     try:
+         # Load model
+         progress(0.05, desc="Loading DiT model...")
+         model = load_dit_model(dit_size)
+
+         # Generate images
+         gallery_images = sampler.generate_images(model, num_samples, genre_id, style_id, seed, progress)
+
+         return gallery_images, gr.update(value="", visible=False)
+     except Exception as e:
+         error_msg = f"Error: {str(e)}"
+         return None, gr.update(value=error_msg, visible=True)
+
+ def clear_gallery():
+     """Clear the gallery display"""
+     return None, gr.update(value="", visible=False)
+
+ # Create the Gradio interface
+ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as app:
+     gr.Markdown("# DiT Diffusion Model Generator")
+     gr.Markdown("Generate art images using a Diffusion Transformer (DiT) model")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             num_samples = gr.Slider(minimum=1, maximum=16, value=4, step=1, label="Number of Samples", info="How many images to generate (1-16)")
+             dit_size = gr.Radio(choices=["S", "B", "L"], value="S", label="DiT Model Size", info="Larger models produce better quality but take longer")
+
+             genre_names = list(reduced_genre_mapping.keys())
+             style_names = list(reduced_style_mapping.keys())
+
+             # Sort dropdown choices alphabetically
+             genre_names.sort()
+             style_names.sort()
+
+             genre = gr.Dropdown(choices=genre_names, value="landscape", label="Art Genre")
+             style = gr.Dropdown(choices=style_names, value="impressionism", label="Art Style")
+
+             with gr.Row():
+                 seed = gr.Number(label="Seed", value=generate_random_seed(), precision=0, info="Set for reproducible results")
+                 reset_seed_btn = gr.Button("🎲 New Seed")
+
+             with gr.Row():
+                 generate_btn = gr.Button("Generate Images", variant="primary")
+                 clear_btn = gr.Button("🗑️ Clear Gallery")
+
+             progress_bar = gr.Progress(track_tqdm=True)
+
+         with gr.Column(scale=2):
+             output_gallery = gr.Gallery(label="Generated Images", columns=4, rows=4, object_fit="contain", height=600)
+             error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")
+
+     # Seed reset button functionality
+     reset_seed_btn.click(generate_random_seed, inputs=[], outputs=[seed])
+
+     # Clear gallery button functionality
+     clear_btn.click(clear_gallery, inputs=[], outputs=[output_gallery, error_message])
+
+     # Connect the generate button
+     generate_btn.click(
+         fn=generate_samples,
+         inputs=[num_samples, dit_size, genre, style, seed],
+         outputs=[output_gallery, error_message],
+     )
+
+
+ if __name__ == "__main__":
+     app.launch()
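
For reference, the denoising loop in DiffusionSampler is a standard DDPM ancestral-sampling update with classifier-free guidance; in the notation of the code above (a transcription of the loop, not an additional feature):

\[
\tilde{\epsilon} = \epsilon_\theta(x_t, t, \varnothing) + w\,\big(\epsilon_\theta(x_t, t, g, s) - \epsilon_\theta(x_t, t, \varnothing)\big), \qquad w = 2.5
\]
\[
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\tilde{\epsilon}\Big) + \sqrt{\tilde{\beta}_t}\, z, \qquad z \sim \mathcal{N}(0, I), \quad \tilde{\beta}_t = \frac{\beta_t\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t},
\]

with the noise term dropped at the final step (t = 0). The resulting latents are rescaled by the VAE scaling factor (0.18215) and decoded with the sd-vae-ft-ema VAE.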
ckpts/DiT_S_final.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6f20bce4f40e4112f73fcd89dd2d5d9b7e4a5560f265ca0741af134d1a7ab355
+ size 264237994
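
This file is a Git LFS pointer rather than the weights themselves; the actual DiT-S checkpoint (~264 MB) is fetched by Git LFS. A minimal sketch (assuming the weights have already been pulled locally) for checking a copy against the pointer above:

import hashlib
import os

path = "ckpts/DiT_S_final.pth"
print(os.path.getsize(path))  # expected: 264237994 bytes, per the pointer

sha256 = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha256.update(chunk)
print(sha256.hexdigest())  # expected to match the oid above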
mapping.py ADDED
@@ -0,0 +1,235 @@
+
+ genre_mapping = {
+     'None': 0,
+     'abstract': 1,
+     'advertisement': 2,
+     'allegorical painting': 3,
+     'animal painting': 4,
+     'battle painting': 5,
+     'bijinga': 6,
+     'bird-and-flower painting': 7,
+     'calligraphy': 8,
+     'capriccio': 9,
+     'caricature': 10,
+     'cityscape': 11,
+     'cloudscape': 12,
+     'design': 13,
+     'figurative': 14,
+     'flower painting': 15,
+     'genre painting': 16,
+     'history painting': 17,
+     'illustration': 18,
+     'interior': 19,
+     'landscape': 20,
+     'literary painting': 21,
+     'marina': 22,
+     'miniature': 23,
+     'mythological painting': 24,
+     'nude painting (nu)': 25,
+     'panorama': 26,
+     'pastorale': 27,
+     'portrait': 28,
+     'poster': 29,
+     'quadratura': 30,
+     'religious painting': 31,
+     'self-portrait': 32,
+     'shan shui': 33,
+     'sketch and study': 34,
+     'still life': 35,
+     'symbolic painting': 36,
+     'tessellation': 37,
+     'urushi-e': 38,
+     'vanitas': 39,
+     'veduta': 40,
+     'wildlife painting': 41,
+     'yakusha-e': 42
+ }
+
+ style_mapping = {
+     'Abstract Art': 0,
+     'Abstract Expressionism': 1,
+     'Academicism': 2,
+     'Action painting': 3,
+     'American Realism': 4,
+     'Analytical Cubism': 5,
+     'Analytical\xa0Realism': 6,
+     'Art Brut': 7,
+     'Art Deco': 8,
+     'Art Informel': 9,
+     'Art Nouveau (Modern)': 10,
+     'Automatic Painting': 11,
+     'Baroque': 12,
+     'Biedermeier': 13,
+     'Byzantine': 14,
+     'Cartographic Art': 15,
+     'Classicism': 16,
+     'Cloisonism': 17,
+     'Color Field Painting': 18,
+     'Conceptual Art': 19,
+     'Concretism': 20,
+     'Constructivism': 21,
+     'Contemporary Realism': 22,
+     'Costumbrismo': 23,
+     'Cubism': 24,
+     'Cubo-Expressionism': 25,
+     'Cubo-Futurism': 26,
+     'Dada': 27,
+     'Divisionism': 28,
+     'Early Renaissance': 29,
+     'Environmental (Land) Art': 30,
+     'Existential Art': 31,
+     'Expressionism': 32,
+     'Fantastic Realism': 33,
+     'Fauvism': 34,
+     'Feminist Art': 35,
+     'Figurative Expressionism': 36,
+     'Futurism': 37,
+     'Gongbi': 38,
+     'Gothic': 39,
+     'Hard Edge Painting': 40,
+     'High Renaissance': 41,
+     'Hyper-Realism': 42,
+     'Ilkhanid': 43,
+     'Impressionism': 44,
+     'Indian Space painting': 45,
+     'Ink and wash painting': 46,
+     'International Gothic': 47,
+     'Intimism': 48,
+     'Japonism': 49,
+     'Joseon Dynasty': 50,
+     'Kinetic Art': 51,
+     'Kitsch': 52,
+     'Lettrism': 53,
+     'Light and Space': 54,
+     'Luminism': 55,
+     'Lyrical Abstraction': 56,
+     'Magic Realism': 57,
+     'Mail Art': 58,
+     'Mannerism (Late Renaissance)': 59,
+     'Mechanistic Cubism': 60,
+     'Metaphysical art': 61,
+     'Minimalism': 62,
+     'Miserabilism': 63,
+     'Modernismo': 64,
+     'Mosan art': 65,
+     'Muralism': 66,
+     'Nanga (Bunjinga)': 67,
+     'Nats-Taliq': 68,
+     'Native Art': 69,
+     'Naturalism': 70,
+     'Naïve Art (Primitivism)': 71,
+     'Neo-Byzantine': 72,
+     'Neo-Concretism': 73,
+     'Neo-Dada': 74,
+     'Neo-Expressionism': 75,
+     'Neo-Figurative Art': 76,
+     'Neo-Rococo': 77,
+     'Neo-Romanticism': 78,
+     'Neo-baroque': 79,
+     'Neoclassicism': 80,
+     'Neoplasticism': 81,
+     'New Casualism': 82,
+     'New European Painting': 83,
+     'New Realism': 84,
+     'Nihonga': 85,
+     'None': 86,
+     'Northern Renaissance': 87,
+     'Nouveau Réalisme': 88,
+     'Op Art': 89,
+     'Orientalism': 90,
+     'Orphism': 91,
+     'Ottoman Period': 92,
+     'Outsider art': 93,
+     'Perceptism': 94,
+     'Photorealism': 95,
+     'Pointillism': 96,
+     'Pop Art': 97,
+     'Post-Impressionism': 98,
+     'Post-Minimalism': 99,
+     'Post-Painterly Abstraction': 100,
+     'Poster Art Realism': 101,
+     'Precisionism': 102,
+     'Primitivism': 103,
+     'Proto Renaissance': 104,
+     'Purism': 105,
+     'Rayonism': 106,
+     'Realism': 107,
+     'Regionalism': 108,
+     'Renaissance': 109,
+     'Rococo': 110,
+     'Romanesque': 111,
+     'Romanticism': 112,
+     'Safavid Period': 113,
+     'Shin-hanga': 114,
+     'Social Realism': 115,
+     'Socialist Realism': 116,
+     'Spatialism': 117,
+     'Spectralism': 118,
+     'Street art': 119,
+     'Suprematism': 120,
+     'Surrealism': 121,
+     'Symbolism': 122,
+     'Synchromism': 123,
+     'Synthetic Cubism': 124,
+     'Synthetism': 125,
+     'Sōsaku hanga': 126,
+     'Tachisme': 127,
+     'Tenebrism': 128,
+     'Timurid Period': 129,
+     'Tonalism': 130,
+     'Transautomatism': 131,
+     'Tubism': 132,
+     'Ukiyo-e': 133,
+     'Verism': 134,
+     'Yamato-e': 135,
+     'Zen': 136
+ }
+
+ reverse_genre_mapping = {v: k for k, v in genre_mapping.items()}
+ reverse_style_mapping = {v: k for k, v in style_mapping.items()}
+
+ reduced_genre_mapping = {
+     'abstract': 1,
+     'capriccio': 9,
+     'cityscape': 11,
+     'cloudscape': 12,
+     'flower painting': 15,
+     'genre painting': 16,
+     'interior': 19,
+     'landscape': 20,
+     'marina': 22,
+     'panorama': 26,
+     'pastorale': 27,
+     'quadratura': 30,
+     'shan shui': 33,
+     'sketch and study': 34,
+     'still life': 35,
+     'symbolic painting': 36,
+     'tessellation': 37,
+     'veduta': 40
+ }
+
+ reduced_style_mapping = {
+     'abstract expressionism': 1,
+     'art deco': 8,
+     'art nouveau': 10,
+     'baroque': 12,
+     'conceptual art': 19,
+     'cubism': 24,
+     'expressionism': 32,
+     'gothic': 39,
+     'impressionism': 44,
+     'minimalism': 62,
+     'modernism': 64,
+     'neoclassicism': 80,
+     'pop-art': 97,
+     'post-impressionism': 98,
+     'renaissance': 109,
+     'rococo': 110,
+     'romanticism': 112,
+     'surrealism': 121
+ }
+
+ reverse_reduced_genre_mapping = {v: k for k, v in reduced_genre_mapping.items()}
+ reverse_reduced_style_mapping = {v: k for k, v in reduced_style_mapping.items()}
+
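
For orientation, the reduced_* dictionaries are what the Gradio dropdowns expose (a curated subset of the full label set), while the reverse_* dictionaries turn class IDs back into names for the gallery captions. A minimal illustrative sketch of how app.py uses them (not part of the commit):

from mapping import reduced_genre_mapping, reverse_reduced_genre_mapping

genre_id = reduced_genre_mapping["landscape"]     # 20 -- the class ID passed to the DiT
print(reverse_reduced_genre_mapping[genre_id])    # "landscape" -- used in gallery captions
# IDs 42 (genres) and 137 (styles) are reserved as the null condition for classifier-free guidance.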
models/DiT.py ADDED
@@ -0,0 +1,184 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+ class TimestepEmbedder(nn.Module):
+     """Module to create the timestep embedding (sinusoidal features followed by an MLP)."""
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size)
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     def forward(self, t):
+         half = self.frequency_embedding_size // 2
+         freqs = torch.exp(
+             -math.log(10000) * torch.arange(start=0, end=half) / half
+         ).to(device=t.device)
+         # Cast t to float so the outer product with the float frequencies is well defined
+         args = torch.einsum('i,j->ij', t.float(), freqs)
+         freqs = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         return self.mlp(freqs)
+
+ class ViTAttn(nn.Module):
28
+ def __init__(self,hidden_size,num_heads):
29
+ super().__init__()
30
+ self.attn = nn.MultiheadAttention(hidden_size,num_heads,bias=True,add_bias_kv=True,batch_first=True)
31
+
32
+ def forward(self,x):
33
+ attn, _ = self.attn(x,x,x)
34
+ return attn
35
+
36
+ class DiTBlock(nn.Module):
37
+ """
38
+ DiT Block with adaptive layer norm zero (adaLN-Zero) conditioning.
39
+ Using post-norm
40
+ """
41
+ def __init__(self,hidden_size,num_heads):
42
+ super().__init__()
43
+ self.norm1 = nn.LayerNorm(hidden_size,elementwise_affine=False,eps=1e-6)
44
+ self.attn = ViTAttn(hidden_size,num_heads)
45
+ self.norm2 = nn.LayerNorm(hidden_size,elementwise_affine=False,eps=1e-6)
46
+ self.mlp = nn.Sequential(
47
+ nn.Linear(hidden_size,4*hidden_size),
48
+ nn.GELU(approximate="tanh"),
49
+ nn.Linear(4*hidden_size,hidden_size)
50
+ )
51
+ self.adaLN = nn.Sequential(
52
+ nn.SiLU(),
53
+ nn.Linear(hidden_size,6*hidden_size)
54
+ )
55
+
56
+ def forward(self,x,c):
57
+ gamma_1,beta_1,alpha_1,gamma_2,beta_2,alpha_2 = self.adaLN(c).chunk(6,dim=1)
58
+ x = self.norm1(x + alpha_1.unsqueeze(1) * self.attn(x))
59
+ x = x * (1+gamma_1.unsqueeze(1)) + beta_1.unsqueeze(1)
60
+ x = self.norm2(x + alpha_2.unsqueeze(1) * self.mlp(x))
61
+ x = x * (1+gamma_2.unsqueeze(1)) + beta_2.unsqueeze(1)
62
+ return x
63
+
64
+ class DiT(nn.Module):
65
+ def __init__(self,
66
+ num_blocks=10,
67
+ hidden_size=640,
68
+ num_heads=10,
69
+ patch_size=2,
70
+ num_channels=4,
71
+ img_size=32,
72
+ num_genres=42,
73
+ num_styles=137):
74
+ super().__init__()
75
+ self.hidden_size = hidden_size
76
+ self.patch_size = patch_size
77
+ self.num_channels = num_channels
78
+ self.seq_len = (img_size // patch_size)**2
79
+ self.img_size = img_size
80
+ self.blocks = nn.ModuleList(
81
+ DiTBlock(hidden_size,num_heads) for _ in range(num_blocks)
82
+ )
83
+ self.timestep_embed = TimestepEmbedder(hidden_size)
84
+
85
+ self.num_genres = num_genres
86
+ self.num_styles = num_styles
87
+ self.genre_condition = nn.Embedding(num_genres+1,hidden_size) # +1 for null condition
88
+ self.style_condition = nn.Embedding(num_styles+1,hidden_size)
89
+
90
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.seq_len, hidden_size))
91
+
92
+ patch_dim = num_channels * patch_size * patch_size
93
+ self.proj_in = nn.Linear(patch_dim,hidden_size)
94
+ self.proj_out = nn.Linear(hidden_size,patch_dim)
95
+
96
+ self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
97
+ self.adaLN_final = nn.Sequential(
98
+ nn.SiLU(),
99
+ nn.Linear(hidden_size, 2*hidden_size)
100
+ )
101
+
102
+ self.initialize_weights()
103
+
104
+ def initialize_weights(self):
105
+ nn.init.normal_(self.pos_embed, std=0.02)
106
+ nn.init.normal_(self.proj_out.weight, std=0.02)
107
+ nn.init.zeros_(self.proj_out.bias)
108
+ nn.init.normal_(self.proj_in.weight, std=0.02)
109
+ nn.init.zeros_(self.proj_in.bias)
110
+
111
+ nn.init.normal_(self.timestep_embed.mlp[0].weight, std=0.02)
112
+ nn.init.zeros_(self.timestep_embed.mlp[0].bias)
113
+ nn.init.normal_(self.timestep_embed.mlp[2].weight, std=0.02)
114
+ nn.init.zeros_(self.timestep_embed.mlp[2].bias)
115
+
116
+ for block in self.blocks:
117
+ nn.init.zeros_(block.adaLN[-1].weight)
118
+ nn.init.zeros_(block.adaLN[-1].bias)
119
+
120
+ nn.init.zeros_(self.adaLN_final[-1].weight)
121
+ nn.init.zeros_(self.adaLN_final[-1].bias)
122
+
123
+ nn.init.normal_(self.genre_condition.weight, std=0.02)
124
+ nn.init.normal_(self.style_condition.weight, std=0.02)
125
+
126
+ def patchify(self,z):
127
+ """
128
+ from (batch_size,6,32,32) -> (batch_size,256,24) -> (batch_size,256,hidden_size)
129
+ """
130
+ b,_,_,_ = z.shape
131
+ c = self.num_channels
132
+ p = self.patch_size
133
+ z = z.unfold(2,p,p).unfold(3,p,p) # (b,c,h//p,p,w//p,p)
134
+ z = z.contiguous().view(b,c,-1,p,p) # (b,c,hw//p**2,p,p)
135
+ z = torch.einsum('bcapq->bacpq',z).contiguous().view(b,-1,c*p**2) # (b,hw//p**2,c*p**2)
136
+ return self.proj_in(z) # (b,hw//p**2,hidden_size)
137
+
138
+ def unpatchify(self,z):
139
+ """
140
+ from (batch_size,256,hidden_size) -> (batch_size,256,24) -> (batch_size,6,32,32)
141
+ """
142
+ b,_,_ = z.shape
143
+ c = self.num_channels
144
+ p = self.patch_size
145
+ s = int(self.seq_len ** 0.5)
146
+ i = self.img_size
147
+ z = self.proj_out(z) # (b,hw//p**2,c*p**2)
148
+ z = z.view(b,s,s,c,p,p) # (b,h/p,w/p,c,p,p)
149
+ z = torch.einsum('befcpq->bcepfq',z) # (b,c,h/p,p,w/p,p)
150
+ z = z.contiguous().view(b,c,i,i)
151
+ return z
152
+
153
+
154
+ def forward(self,z,t,g,s):
155
+ t_embed = self.timestep_embed(t) # t_embed: (batch_size, hidden_size)
156
+ g_embed = self.genre_condition(g)
157
+ s_embed = self.style_condition(s)
158
+
159
+ c = t_embed + g_embed + s_embed
160
+
161
+ z = self.patchify(z)
162
+ z = z + self.pos_embed
163
+
164
+ for block in self.blocks:
165
+ z = block(z,c)
166
+
167
+ gamma, beta = self.adaLN_final(c).chunk(2,dim=-1)
168
+ z = self.norm_out(z)
169
+ z = z * (1+gamma.unsqueeze(1)) + beta.unsqueeze(1)
170
+
171
+ return self.unpatchify(z)
172
+
173
+
174
+
175
+
176
+ if __name__ == "__main__":
177
+ model = DiT(1,768,12,2,6,32)
178
+ z = torch.randn(2,6,32,32)
179
+ c = torch.randn(2,768)
180
+ t = torch.randint(0,1000,(2,))
181
+ output = model(z,c,t)
182
+ print(z.shape,c.shape,t.shape,output.shape)
183
+ output_cfg = model.forward_cfg(z,t)
184
+ print(output_cfg.shape)
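
For reference, each DiTBlock applies the conditioning vector c = t_embed + g_embed + s_embed through adaLN-Zero modulation in a post-norm arrangement; written out, its forward pass is

\[
x \leftarrow (1+\gamma_1)\odot \mathrm{LN}\big(x + \alpha_1 \odot \mathrm{Attn}(x)\big) + \beta_1,
\qquad
x \leftarrow (1+\gamma_2)\odot \mathrm{LN}\big(x + \alpha_2 \odot \mathrm{MLP}(x)\big) + \beta_2,
\]

where \((\gamma_1,\beta_1,\alpha_1,\gamma_2,\beta_2,\alpha_2) = \mathrm{Linear}(\mathrm{SiLU}(c))\). Because these modulation layers are zero-initialized, each block initially reduces to a plain LayerNorm of its input and the conditioning pathway is learned from zero.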
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ diffusers
+ gradio
+ numpy
+ tqdm
+ matplotlib