Upload 5 files
Initial commit.
- app.py +245 -0
- ckpts/DiT_S_final.pth +3 -0
- mapping.py +235 -0
- models/DiT.py +184 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,245 @@
import torch
import gradio as gr
import numpy as np
import os
import random
from mapping import reduced_genre_mapping, reduced_style_mapping, reverse_reduced_genre_mapping, reverse_reduced_style_mapping
from diffusers import AutoencoderKL
from models.DiT import DiT

# Global settings
num_timesteps = 1000
beta_start = 1e-4
beta_end = 0.02
latent_scale_factor = 0.18215  # Same as in DiTTrainer

# For tracking progress in UI
global_progress = 0

def load_dit_model(dit_size):
    """Load a DiT model of the specified size."""
    ckpt_path = f"./ckpts/DiT_{dit_size}_final.pth"
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}")

    # Configure model based on size
    if dit_size == "S":
        model = DiT(num_blocks=8, hidden_size=384, num_heads=6)
    elif dit_size == "B":
        model = DiT(num_blocks=12, hidden_size=640, num_heads=10)
    elif dit_size == "L":
        model = DiT(num_blocks=16, hidden_size=896, num_heads=14)
    else:
        raise ValueError(f"Invalid DiT size: {dit_size}")

    # Load checkpoint
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])

    return model

class DiffusionSampler:
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.device = device
        self.vae = None

        # Pre-compute diffusion parameters (linear beta schedule)
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)

        # Move to device
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(self.device)
        self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.to(self.device)
        self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
        self.betas = self.betas.to(self.device)
        self.posterior_variance = self.posterior_variance.to(self.device)

    def load_vae(self):
        """Load the VAE decoder (done lazily to save memory until needed)."""
        if self.vae is None:
            self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
            self.vae.eval()

    def generate_images(self, model, num_samples, genre, style, seed, progress=gr.Progress()):
        """Generate images with the DiT model."""
        global global_progress
        global_progress = 0

        # Set random seeds for reproducibility
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            # Also set CUDA seeds if using GPU
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)

        model.to(self.device)
        model.eval()

        # Convert genre and style IDs to tensors; the extra embedding index
        # (num_genres / num_styles) is the null condition used for guidance
        g_cond = torch.tensor([genre] * num_samples, device=self.device, dtype=torch.long)
        s_cond = torch.tensor([style] * num_samples, device=self.device, dtype=torch.long)
        g_null = torch.tensor([model.num_genres] * num_samples, device=self.device, dtype=torch.long)
        s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)

        # Start from pure Gaussian noise in latent space
        latents = torch.randn((num_samples, 4, 32, 32), device=self.device)

        # Classifier-free guidance scale for better quality
        cfg_scale = 2.5

        # Reverse diffusion: iterate t = T-1, ..., 0
        timesteps = torch.arange(num_timesteps - 1, -1, -1, device=self.device)
        total_steps = len(timesteps)

        with torch.no_grad():
            for i, t_val in enumerate(timesteps):
                # Update progress
                global_progress = int(100 * i / total_steps)
                progress(global_progress / 100, desc="Generating images...")

                t = torch.full((num_samples,), t_val, device=self.device, dtype=torch.long)

                sqrt_recip_alphas_t = self.sqrt_recip_alphas[t].view(-1, 1, 1, 1)
                sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
                beta_t = self.betas[t].view(-1, 1, 1, 1)
                posterior_variance_t = self.posterior_variance[t].view(-1, 1, 1, 1)

                # Noise prediction with classifier-free guidance
                eps_theta_cond = model(latents, t, g_cond, s_cond)
                eps_theta_uncond = model(latents, t, g_null, s_null)
                eps_theta = eps_theta_uncond + cfg_scale * (eps_theta_cond - eps_theta_uncond)

                # DDPM posterior step: compute the mean, add noise except at t = 0
                mean = sqrt_recip_alphas_t * (latents - (beta_t / sqrt_one_minus_alphas_cumprod_t) * eps_theta)
                noise = torch.randn_like(latents)
                if t_val == 0:
                    latents = mean
                else:
                    latents = mean + torch.sqrt(posterior_variance_t) * noise

        # Decode latents to images
        self.load_vae()
        latents = latents / self.vae.config.scaling_factor
        latents = latents.to(self.device)

        progress(0.95, desc="Decoding images...")
        with torch.no_grad():
            images = self.vae.decode(latents).sample
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.permute(0, 2, 3, 1).cpu().numpy()

        progress(1.0, desc="Done!")
        global_progress = 100

        # Create image gallery with labels
        gallery_images = []
        for i in range(num_samples):
            # Convert to a uint8 numpy image (gr.Gallery accepts numpy arrays)
            img = (images[i] * 255).astype(np.uint8)
            caption = f"Genre: {reverse_reduced_genre_mapping[genre]}, Style: {reverse_reduced_style_mapping[style]}"
            if seed is not None:
                caption += f" (Seed: {seed})"
            gallery_images.append((img, caption))

        return gallery_images

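# Note: the loop in generate_images above is the standard DDPM ancestral update.
# With eps_hat the guided noise prediction, each step computes
#   x_{t-1} = (1 / sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alphabar_t) * eps_hat)
#             + sqrt(posterior_variance_t) * z,   z ~ N(0, I),
# with no noise added at t = 0, and
#   eps_hat = eps_uncond + cfg_scale * (eps_cond - eps_uncond).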
# Initialize sampler globally
sampler = DiffusionSampler()

def generate_random_seed():
    """Generate a random seed between 0 and 2^32 - 1."""
    return random.randint(0, 2**32 - 1)

def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
    """Main entry point for the Gradio interface."""
    if num_samples < 1 or num_samples > 16:
        return None, gr.update(value="Number of samples must be between 1 and 16", visible=True)

    # Get genre and style IDs from the mappings
    genre_id = reduced_genre_mapping.get(genre_name)
    style_id = reduced_style_mapping.get(style_name)

    if genre_id is None:
        return None, gr.update(value=f"Unknown genre: {genre_name}", visible=True)
    if style_id is None:
        return None, gr.update(value=f"Unknown style: {style_name}", visible=True)

    try:
        # Load model
        progress(0.05, desc="Loading DiT model...")
        model = load_dit_model(dit_size)

        # Generate images
        gallery_images = sampler.generate_images(model, num_samples, genre_id, style_id, seed, progress)

        return gallery_images, gr.update(value="", visible=False)
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        return None, gr.update(value=error_msg, visible=True)

def clear_gallery():
    """Clear the gallery display."""
    return None, gr.update(value="", visible=False)

# Create the Gradio interface
with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as app:
    gr.Markdown("# DiT Diffusion Model Generator")
    gr.Markdown("Generate art images using a Diffusion Transformer (DiT) model")

    with gr.Row():
        with gr.Column(scale=1):
            num_samples = gr.Slider(minimum=1, maximum=16, value=4, step=1, label="Number of Samples", info="How many images to generate (1-16)")
            dit_size = gr.Radio(choices=["S", "B", "L"], value="S", label="DiT Model Size", info="Larger models produce better quality but take longer")

            genre_names = list(reduced_genre_mapping.keys())
            style_names = list(reduced_style_mapping.keys())

            # Sort choices alphabetically
            genre_names.sort()
            style_names.sort()

            genre = gr.Dropdown(choices=genre_names, value="landscape", label="Art Genre")
            style = gr.Dropdown(choices=style_names, value="impressionism", label="Art Style")

            with gr.Row():
                seed = gr.Number(label="Seed", value=generate_random_seed(), precision=0, info="Set for reproducible results")
                reset_seed_btn = gr.Button("🎲 New Seed")

            with gr.Row():
                generate_btn = gr.Button("Generate Images", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Gallery")

        with gr.Column(scale=2):
            output_gallery = gr.Gallery(label="Generated Images", columns=4, rows=4, object_fit="contain", height=600)
            error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")

    # Seed reset button functionality
    reset_seed_btn.click(generate_random_seed, inputs=[], outputs=[seed])

    # Clear gallery button functionality
    clear_btn.click(clear_gallery, inputs=[], outputs=[output_gallery, error_message])

    # Connect the generate button
    generate_btn.click(
        fn=generate_samples,
        inputs=[num_samples, dit_size, genre, style, seed],
        outputs=[output_gallery, error_message],
    )


if __name__ == "__main__":
    app.launch()
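For reference, a minimal headless sketch of driving the sampler without the UI (an assumption-laden example, not part of the commit: it presumes the repo root as working directory, the S checkpoint at ./ckpts/DiT_S_final.pth, and uses a no-op callable in place of gr.Progress):

# Headless usage sketch (hypothetical; not part of app.py)
from app import load_dit_model, sampler
from mapping import reduced_genre_mapping, reduced_style_mapping

model = load_dit_model("S")
gallery = sampler.generate_images(
    model,
    num_samples=2,
    genre=reduced_genre_mapping["landscape"],        # 20
    style=reduced_style_mapping["impressionism"],    # 44
    seed=1234,
    progress=lambda *args, **kwargs: None,           # no-op stand-in for gr.Progress
)
# gallery is a list of (HxWx3 uint8 numpy array, caption) pairs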
ckpts/DiT_S_final.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6f20bce4f40e4112f73fcd89dd2d5d9b7e4a5560f265ca0741af134d1a7ab355
size 264237994
mapping.py
ADDED
@@ -0,0 +1,235 @@
genre_mapping = {
    'None': 0,
    'abstract': 1,
    'advertisement': 2,
    'allegorical painting': 3,
    'animal painting': 4,
    'battle painting': 5,
    'bijinga': 6,
    'bird-and-flower painting': 7,
    'calligraphy': 8,
    'capriccio': 9,
    'caricature': 10,
    'cityscape': 11,
    'cloudscape': 12,
    'design': 13,
    'figurative': 14,
    'flower painting': 15,
    'genre painting': 16,
    'history painting': 17,
    'illustration': 18,
    'interior': 19,
    'landscape': 20,
    'literary painting': 21,
    'marina': 22,
    'miniature': 23,
    'mythological painting': 24,
    'nude painting (nu)': 25,
    'panorama': 26,
    'pastorale': 27,
    'portrait': 28,
    'poster': 29,
    'quadratura': 30,
    'religious painting': 31,
    'self-portrait': 32,
    'shan shui': 33,
    'sketch and study': 34,
    'still life': 35,
    'symbolic painting': 36,
    'tessellation': 37,
    'urushi-e': 38,
    'vanitas': 39,
    'veduta': 40,
    'wildlife painting': 41,
    'yakusha-e': 42
}

style_mapping = {
    'Abstract Art': 0,
    'Abstract Expressionism': 1,
    'Academicism': 2,
    'Action painting': 3,
    'American Realism': 4,
    'Analytical Cubism': 5,
    'Analytical\xa0Realism': 6,
    'Art Brut': 7,
    'Art Deco': 8,
    'Art Informel': 9,
    'Art Nouveau (Modern)': 10,
    'Automatic Painting': 11,
    'Baroque': 12,
    'Biedermeier': 13,
    'Byzantine': 14,
    'Cartographic Art': 15,
    'Classicism': 16,
    'Cloisonism': 17,
    'Color Field Painting': 18,
    'Conceptual Art': 19,
    'Concretism': 20,
    'Constructivism': 21,
    'Contemporary Realism': 22,
    'Costumbrismo': 23,
    'Cubism': 24,
    'Cubo-Expressionism': 25,
    'Cubo-Futurism': 26,
    'Dada': 27,
    'Divisionism': 28,
    'Early Renaissance': 29,
    'Environmental (Land) Art': 30,
    'Existential Art': 31,
    'Expressionism': 32,
    'Fantastic Realism': 33,
    'Fauvism': 34,
    'Feminist Art': 35,
    'Figurative Expressionism': 36,
    'Futurism': 37,
    'Gongbi': 38,
    'Gothic': 39,
    'Hard Edge Painting': 40,
    'High Renaissance': 41,
    'Hyper-Realism': 42,
    'Ilkhanid': 43,
    'Impressionism': 44,
    'Indian Space painting': 45,
    'Ink and wash painting': 46,
    'International Gothic': 47,
    'Intimism': 48,
    'Japonism': 49,
    'Joseon Dynasty': 50,
    'Kinetic Art': 51,
    'Kitsch': 52,
    'Lettrism': 53,
    'Light and Space': 54,
    'Luminism': 55,
    'Lyrical Abstraction': 56,
    'Magic Realism': 57,
    'Mail Art': 58,
    'Mannerism (Late Renaissance)': 59,
    'Mechanistic Cubism': 60,
    'Metaphysical art': 61,
    'Minimalism': 62,
    'Miserabilism': 63,
    'Modernismo': 64,
    'Mosan art': 65,
    'Muralism': 66,
    'Nanga (Bunjinga)': 67,
    'Nats-Taliq': 68,
    'Native Art': 69,
    'Naturalism': 70,
    'Naïve Art (Primitivism)': 71,
    'Neo-Byzantine': 72,
    'Neo-Concretism': 73,
    'Neo-Dada': 74,
    'Neo-Expressionism': 75,
    'Neo-Figurative Art': 76,
    'Neo-Rococo': 77,
    'Neo-Romanticism': 78,
    'Neo-baroque': 79,
    'Neoclassicism': 80,
    'Neoplasticism': 81,
    'New Casualism': 82,
    'New European Painting': 83,
    'New Realism': 84,
    'Nihonga': 85,
    'None': 86,
    'Northern Renaissance': 87,
    'Nouveau Réalisme': 88,
    'Op Art': 89,
    'Orientalism': 90,
    'Orphism': 91,
    'Ottoman Period': 92,
    'Outsider art': 93,
    'Perceptism': 94,
    'Photorealism': 95,
    'Pointillism': 96,
    'Pop Art': 97,
    'Post-Impressionism': 98,
    'Post-Minimalism': 99,
    'Post-Painterly Abstraction': 100,
    'Poster Art Realism': 101,
    'Precisionism': 102,
    'Primitivism': 103,
    'Proto Renaissance': 104,
    'Purism': 105,
    'Rayonism': 106,
    'Realism': 107,
    'Regionalism': 108,
    'Renaissance': 109,
    'Rococo': 110,
    'Romanesque': 111,
    'Romanticism': 112,
    'Safavid Period': 113,
    'Shin-hanga': 114,
    'Social Realism': 115,
    'Socialist Realism': 116,
    'Spatialism': 117,
    'Spectralism': 118,
    'Street art': 119,
    'Suprematism': 120,
    'Surrealism': 121,
    'Symbolism': 122,
    'Synchromism': 123,
    'Synthetic Cubism': 124,
    'Synthetism': 125,
    'Sōsaku hanga': 126,
    'Tachisme': 127,
    'Tenebrism': 128,
    'Timurid Period': 129,
    'Tonalism': 130,
    'Transautomatism': 131,
    'Tubism': 132,
    'Ukiyo-e': 133,
    'Verism': 134,
    'Yamato-e': 135,
    'Zen': 136
}

reverse_genre_mapping = {v: k for k, v in genre_mapping.items()}
reverse_style_mapping = {v: k for k, v in style_mapping.items()}

reduced_genre_mapping = {
    'abstract': 1,
    'capriccio': 9,
    'cityscape': 11,
    'cloudscape': 12,
    'flower painting': 15,
    'genre painting': 16,
    'interior': 19,
    'landscape': 20,
    'marina': 22,
    'panorama': 26,
    'pastorale': 27,
    'quadratura': 30,
    'shan shui': 33,
    'sketch and study': 34,
    'still life': 35,
    'symbolic painting': 36,
    'tessellation': 37,
    'veduta': 40
}

reduced_style_mapping = {
    'abstract expressionism': 1,
    'art deco': 8,
    'art nouveau': 10,
    'baroque': 12,
    'conceptual art': 19,
    'cubism': 24,
    'expressionism': 32,
    'gothic': 39,
    'impressionism': 44,
    'minimalism': 62,
    'modernism': 64,
    'neoclassicism': 80,
    'pop-art': 97,
    'post-impressionism': 98,
    'renaissance': 109,
    'rococo': 110,
    'romanticism': 112,
    'surrealism': 121
}

reverse_reduced_genre_mapping = {v: k for k, v in reduced_genre_mapping.items()}
reverse_reduced_style_mapping = {v: k for k, v in reduced_style_mapping.items()}
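A quick sanity check of how the tables relate (a usage sketch, not part of mapping.py): the reduced mappings reuse the full mapping's IDs, so labels round-trip through the reverse tables.

# Sanity-check sketch: reduced IDs are a subset of the full mapping's IDs
from mapping import genre_mapping, reduced_genre_mapping, reverse_reduced_genre_mapping

gid = reduced_genre_mapping["landscape"]            # 20
assert gid == genre_mapping["landscape"]            # same ID as the full table
assert reverse_reduced_genre_mapping[gid] == "landscape"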
models/DiT.py
ADDED
@@ -0,0 +1,184 @@
import torch
import torch.nn as nn
import math

class TimestepEmbedder(nn.Module):
    """Module to create the timestep embedding."""
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size)
        )
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t):
        half = self.frequency_embedding_size // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(start=0, end=half) / half
        ).to(device=t.device)
        # Cast t to float so the einsum dtypes match (t arrives as a LongTensor)
        args = torch.einsum('i,j->ij', t.float(), freqs)
        freqs = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return self.mlp(freqs)

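# Note: TimestepEmbedder computes the usual sinusoidal features before the MLP:
# freqs_j = 10000^(-j/half) for j in [0, half), args = outer(t, freqs), and the
# embedding is [cos(args), sin(args)], giving frequency_embedding_size features.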
class ViTAttn(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, bias=True, add_bias_kv=True, batch_first=True)

    def forward(self, x):
        attn, _ = self.attn(x, x, x)
        return attn

class DiTBlock(nn.Module):
    """
    DiT block with adaptive layer norm zero (adaLN-Zero) conditioning,
    using post-norm.
    """
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = ViTAttn(hidden_size, num_heads)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * hidden_size, hidden_size)
        )
        self.adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size)
        )

    def forward(self, x, c):
        gamma_1, beta_1, alpha_1, gamma_2, beta_2, alpha_2 = self.adaLN(c).chunk(6, dim=1)
        x = self.norm1(x + alpha_1.unsqueeze(1) * self.attn(x))
        x = x * (1 + gamma_1.unsqueeze(1)) + beta_1.unsqueeze(1)
        x = self.norm2(x + alpha_2.unsqueeze(1) * self.mlp(x))
        x = x * (1 + gamma_2.unsqueeze(1)) + beta_2.unsqueeze(1)
        return x

+
class DiT(nn.Module):
|
65 |
+
def __init__(self,
|
66 |
+
num_blocks=10,
|
67 |
+
hidden_size=640,
|
68 |
+
num_heads=10,
|
69 |
+
patch_size=2,
|
70 |
+
num_channels=4,
|
71 |
+
img_size=32,
|
72 |
+
num_genres=42,
|
73 |
+
num_styles=137):
|
74 |
+
super().__init__()
|
75 |
+
self.hidden_size = hidden_size
|
76 |
+
self.patch_size = patch_size
|
77 |
+
self.num_channels = num_channels
|
78 |
+
self.seq_len = (img_size // patch_size)**2
|
79 |
+
self.img_size = img_size
|
80 |
+
self.blocks = nn.ModuleList(
|
81 |
+
DiTBlock(hidden_size,num_heads) for _ in range(num_blocks)
|
82 |
+
)
|
83 |
+
self.timestep_embed = TimestepEmbedder(hidden_size)
|
84 |
+
|
85 |
+
self.num_genres = num_genres
|
86 |
+
self.num_styles = num_styles
|
87 |
+
self.genre_condition = nn.Embedding(num_genres+1,hidden_size) # +1 for null condition
|
88 |
+
self.style_condition = nn.Embedding(num_styles+1,hidden_size)
|
89 |
+
|
90 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, self.seq_len, hidden_size))
|
91 |
+
|
92 |
+
patch_dim = num_channels * patch_size * patch_size
|
93 |
+
self.proj_in = nn.Linear(patch_dim,hidden_size)
|
94 |
+
self.proj_out = nn.Linear(hidden_size,patch_dim)
|
95 |
+
|
96 |
+
self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
97 |
+
self.adaLN_final = nn.Sequential(
|
98 |
+
nn.SiLU(),
|
99 |
+
nn.Linear(hidden_size, 2*hidden_size)
|
100 |
+
)
|
101 |
+
|
102 |
+
self.initialize_weights()
|
103 |
+
|
104 |
+
def initialize_weights(self):
|
105 |
+
nn.init.normal_(self.pos_embed, std=0.02)
|
106 |
+
nn.init.normal_(self.proj_out.weight, std=0.02)
|
107 |
+
nn.init.zeros_(self.proj_out.bias)
|
108 |
+
nn.init.normal_(self.proj_in.weight, std=0.02)
|
109 |
+
nn.init.zeros_(self.proj_in.bias)
|
110 |
+
|
111 |
+
nn.init.normal_(self.timestep_embed.mlp[0].weight, std=0.02)
|
112 |
+
nn.init.zeros_(self.timestep_embed.mlp[0].bias)
|
113 |
+
nn.init.normal_(self.timestep_embed.mlp[2].weight, std=0.02)
|
114 |
+
nn.init.zeros_(self.timestep_embed.mlp[2].bias)
|
115 |
+
|
116 |
+
for block in self.blocks:
|
117 |
+
nn.init.zeros_(block.adaLN[-1].weight)
|
118 |
+
nn.init.zeros_(block.adaLN[-1].bias)
|
119 |
+
|
120 |
+
nn.init.zeros_(self.adaLN_final[-1].weight)
|
121 |
+
nn.init.zeros_(self.adaLN_final[-1].bias)
|
122 |
+
|
123 |
+
nn.init.normal_(self.genre_condition.weight, std=0.02)
|
124 |
+
nn.init.normal_(self.style_condition.weight, std=0.02)
|
125 |
+
|
126 |
+
def patchify(self,z):
|
127 |
+
"""
|
128 |
+
from (batch_size,6,32,32) -> (batch_size,256,24) -> (batch_size,256,hidden_size)
|
129 |
+
"""
|
130 |
+
b,_,_,_ = z.shape
|
131 |
+
c = self.num_channels
|
132 |
+
p = self.patch_size
|
133 |
+
z = z.unfold(2,p,p).unfold(3,p,p) # (b,c,h//p,p,w//p,p)
|
134 |
+
z = z.contiguous().view(b,c,-1,p,p) # (b,c,hw//p**2,p,p)
|
135 |
+
z = torch.einsum('bcapq->bacpq',z).contiguous().view(b,-1,c*p**2) # (b,hw//p**2,c*p**2)
|
136 |
+
return self.proj_in(z) # (b,hw//p**2,hidden_size)
|
137 |
+
|
138 |
+
def unpatchify(self,z):
|
139 |
+
"""
|
140 |
+
from (batch_size,256,hidden_size) -> (batch_size,256,24) -> (batch_size,6,32,32)
|
141 |
+
"""
|
142 |
+
b,_,_ = z.shape
|
143 |
+
c = self.num_channels
|
144 |
+
p = self.patch_size
|
145 |
+
s = int(self.seq_len ** 0.5)
|
146 |
+
i = self.img_size
|
147 |
+
z = self.proj_out(z) # (b,hw//p**2,c*p**2)
|
148 |
+
z = z.view(b,s,s,c,p,p) # (b,h/p,w/p,c,p,p)
|
149 |
+
z = torch.einsum('befcpq->bcepfq',z) # (b,c,h/p,p,w/p,p)
|
150 |
+
z = z.contiguous().view(b,c,i,i)
|
151 |
+
return z
|
152 |
+
|
153 |
+
|
154 |
+
def forward(self,z,t,g,s):
|
155 |
+
t_embed = self.timestep_embed(t) # t_embed: (batch_size, hidden_size)
|
156 |
+
g_embed = self.genre_condition(g)
|
157 |
+
s_embed = self.style_condition(s)
|
158 |
+
|
159 |
+
c = t_embed + g_embed + s_embed
|
160 |
+
|
161 |
+
z = self.patchify(z)
|
162 |
+
z = z + self.pos_embed
|
163 |
+
|
164 |
+
for block in self.blocks:
|
165 |
+
z = block(z,c)
|
166 |
+
|
167 |
+
gamma, beta = self.adaLN_final(c).chunk(2,dim=-1)
|
168 |
+
z = self.norm_out(z)
|
169 |
+
z = z * (1+gamma.unsqueeze(1)) + beta.unsqueeze(1)
|
170 |
+
|
171 |
+
return self.unpatchify(z)
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
if __name__ == "__main__":
|
177 |
+
model = DiT(1,768,12,2,6,32)
|
178 |
+
z = torch.randn(2,6,32,32)
|
179 |
+
c = torch.randn(2,768)
|
180 |
+
t = torch.randint(0,1000,(2,))
|
181 |
+
output = model(z,c,t)
|
182 |
+
print(z.shape,c.shape,t.shape,output.shape)
|
183 |
+
output_cfg = model.forward_cfg(z,t)
|
184 |
+
print(output_cfg.shape)
|
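app.py applies classifier-free guidance inline with two forward passes; the same logic as a standalone helper, for reference (a sketch under that assumption, not a method defined in models/DiT.py):

# CFG helper sketch (hypothetical; two-pass guidance as in app.py)
import torch
from models.DiT import DiT

def forward_cfg(model: DiT, z, t, g, s, cfg_scale=2.5):
    """Guided noise prediction: uncond + cfg_scale * (cond - uncond)."""
    g_null = torch.full_like(g, model.num_genres)   # last embedding row = null genre
    s_null = torch.full_like(s, model.num_styles)   # last embedding row = null style
    eps_cond = model(z, t, g, s)
    eps_uncond = model(z, t, g_null, s_null)
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)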
requirements.txt
ADDED
@@ -0,0 +1,6 @@
torch
diffusers
gradio
numpy
tqdm
matplotlib