import torch
import spaces
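# `spaces` provides Hugging Face Spaces helpers (e.g., the @spaces.GPU ZeroGPU
# decorator); presumably imported for the Spaces runtime, even though no
# decorator is applied in this file.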
import gradio as gr
import numpy as np
import os
import random
from mapping import reduced_genre_mapping, reduced_style_mapping, reverse_reduced_genre_mapping, reverse_reduced_style_mapping
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from models.DiT import DiT

# Global settings
num_timesteps = 1000
beta_start = 1e-4
beta_end = 0.02
latent_scale_factor = 0.18215  # Same as in DiTTrainer
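# Note: 0.18215 is the standard Stable Diffusion latent scale; the VAE exposes
# the same value as vae.config.scaling_factor, which the decode step below uses.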

# For tracking progress in UI
global_progress = 0

# Enable half precision inference
USE_HALF_PRECISION = True


def load_dit_model(dit_size):
    """Load DiT model of specified size"""
    # ckpt_path = f"./ckpts/DiT_{dit_size}_final.pth"
    ckpt_path = hf_hub_download(
        repo_id="kaupane/DiT-Wikiart",
        filename=f"DiT_{dit_size}_final.pth"
    )
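    # hf_hub_download caches the file locally and returns its filesystem path,
    # raising on download failure, so the existence check below is defensive only.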
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}")

    # Configure model based on size
    if dit_size == "S":
        model = DiT(num_blocks=8, hidden_size=384, num_heads=6)
    elif dit_size == "B":
        model = DiT(num_blocks=12, hidden_size=640, num_heads=10)
    elif dit_size == "L":
        model = DiT(num_blocks=16, hidden_size=896, num_heads=14)
    else:
        raise ValueError(f"Invalid DiT size: {dit_size}")

    # Load checkpoint
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    return model


class DiffusionSampler:
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", use_half=USE_HALF_PRECISION):
        self.device = device
        self.use_half = use_half
        self.vae = None

        # Pre-compute diffusion parameters
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
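        # Standard DDPM quantities: alpha_t = 1 - beta_t, alpha_bar_t = prod_{s<=t} alpha_s,
        # and posterior variance beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t),
        # i.e. the variance of q(x_{t-1} | x_t, x_0) in Ho et al., 2020.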

        # Move to device
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(self.device)
        self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.to(self.device)
        self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
        self.betas = self.betas.to(self.device)
        self.posterior_variance = self.posterior_variance.to(self.device)

        # Convert to half precision if needed
        if self.use_half:
            self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.half()
            self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.half()
            self.sqrt_recip_alphas = self.sqrt_recip_alphas.half()
            self.betas = self.betas.half()
            self.posterior_variance = self.posterior_variance.half()

    def load_vae(self):
        """Load VAE model (done lazily to save memory until needed)"""
        if self.vae is None:
            self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
            self.vae.eval()

    def generate_images(self, model, num_samples, genre, style, seed, progress=gr.Progress()):
        """Generate images with the DiT model"""
        global global_progress
        global_progress = 0

        # Set random seed for reproducibility
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            # Also set CUDA seed if using GPU
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)
        model.to(self.device)
        if self.use_half:
            # Keep the model dtype in sync with the half-precision latents;
            # fp16 inputs through fp32 weights would raise a dtype mismatch.
            model.half()
        model.eval()

        # Convert genre and style to tensors
        g_cond = torch.tensor([genre] * num_samples, device=self.device, dtype=torch.long)
        s_cond = torch.tensor([style] * num_samples, device=self.device, dtype=torch.long)
        g_null = torch.tensor([model.num_genres] * num_samples, device=self.device, dtype=torch.long)
        s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)
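        # Assumption about models.DiT: indices num_genres / num_styles are the
        # learned "null" class embeddings used for the unconditional branch of
        # classifier-free guidance.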

        # Start with random latents
        latents = torch.randn((num_samples, 4, 32, 32), device=self.device)
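        # 4x32x32 latents decode to 256x256 RGB images via the 8x-upsampling SD VAE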
        if self.use_half:
            latents = latents.half()

        # Use classifier-free guidance for better quality
        cfg_scale = 2.5
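        # Classifier-free guidance: eps = eps_uncond + s * (eps_cond - eps_uncond).
        # s = 1.0 reduces to the purely conditional prediction; larger values
        # strengthen conditioning at the cost of sample diversity.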

        # Go through the reverse diffusion process
        timesteps = torch.arange(num_timesteps - 1, -1, -1, device=self.device)
        total_steps = len(timesteps)
        with torch.no_grad():
            for i, t_val in enumerate(timesteps):
                # Update progress
                global_progress = int(100 * i / total_steps)
                progress(global_progress / 100, desc="Generating images...")

                t = torch.full((num_samples,), t_val, device=self.device, dtype=torch.long)
                sqrt_recip_alphas_t = self.sqrt_recip_alphas[t].view(-1, 1, 1, 1)
                sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
                beta_t = self.betas[t].view(-1, 1, 1, 1)
                posterior_variance_t = self.posterior_variance[t].view(-1, 1, 1, 1)

                # Get noise prediction with classifier-free guidance
                eps_theta_cond = model(latents, t, g_cond, s_cond)
                eps_theta_uncond = model(latents, t, g_null, s_null)
                eps_theta = eps_theta_uncond + cfg_scale * (eps_theta_cond - eps_theta_uncond)

                # Update latents
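                # DDPM ancestral step (Ho et al., 2020, Algorithm 2):
                #   mu_t = (1 / sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta)
                #   x_{t-1} = mu_t + sqrt(beta_tilde_t) * z,  z ~ N(0, I)
                # with no noise added at the final step (t = 0).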
                mean = sqrt_recip_alphas_t * (latents - (beta_t / sqrt_one_minus_alphas_cumprod_t) * eps_theta)
                noise = torch.randn_like(latents)
                if t_val == 0:
                    latents = mean
                else:
                    latents = mean + torch.sqrt(posterior_variance_t) * noise

        # Decode latents to images
        self.load_vae()

        # Convert back to float
        if self.use_half:
            latents = latents.float()
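        # Undo the training-time latent scaling before decoding; for
        # stabilityai/sd-vae-ft-ema, vae.config.scaling_factor == 0.18215,
        # the same constant as latent_scale_factor above. The VAE itself stays
        # in fp32, hence the cast back to float.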
        latents = latents / self.vae.config.scaling_factor
        latents = latents.to(self.device)

        progress(0.95, desc="Decoding images...")
        with torch.no_grad():
            images = self.vae.decode(latents).sample
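        # The decoder outputs are roughly in [-1, 1]; map to [0, 1] and reorder
        # NCHW -> NHWC for display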
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.permute(0, 2, 3, 1).cpu().numpy()

        progress(1.0, desc="Done!")
        global_progress = 100

        # Create image gallery with labels
        gallery_images = []
        for i in range(num_samples):
            # Convert to a uint8 array (Gradio galleries accept numpy images directly)
            img = (images[i] * 255).astype(np.uint8)
            caption = f"Genre: {reverse_reduced_genre_mapping[genre]}, Style: {reverse_reduced_style_mapping[style]}"
            if seed is not None:
                caption += f" (Seed: {seed})"
            gallery_images.append((img, caption))
        return gallery_images


# Initialize sampler globally
sampler = DiffusionSampler()


def generate_random_seed():
    """Generate a random seed between 0 and 2^32 - 1"""
    return random.randint(0, 2**32 - 1)


MODEL_SAMPLE_LIMITS = {
    "S": {"min": 1, "max": 18, "default": 4},
    "B": {"min": 1, "max": 9, "default": 4},
    "L": {"min": 1, "max": 3, "default": 1}
}
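# Per-size batch caps, presumably chosen so a single batch fits in the Space's
# GPU memory (larger models allow fewer samples per run).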


def update_sample_slider(dit_size):
    limits = MODEL_SAMPLE_LIMITS[dit_size]
    return gr.update(
        minimum=limits["min"],
        maximum=limits["max"],
        value=limits["default"],
        info=f"How many images to generate ({limits['min']}-{limits['max']})"
    )


def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
    """Main function for the Gradio interface"""
    limits = MODEL_SAMPLE_LIMITS[dit_size]
    if num_samples < limits["min"] or num_samples > limits["max"]:
        return None, gr.update(value=f"Number of samples for the {dit_size} model must be between {limits['min']} and {limits['max']}", visible=True)

    # Get genre and style IDs from mappings
    genre_id = reduced_genre_mapping.get(genre_name)
    style_id = reduced_style_mapping.get(style_name)
    if genre_id is None:
        return None, gr.update(value=f"Unknown genre: {genre_name}", visible=True)
    if style_id is None:
        return None, gr.update(value=f"Unknown style: {style_name}", visible=True)

    try:
        # Load model
        progress(0.05, desc="Loading DiT model...")
        model = load_dit_model(dit_size)
        # Generate images
        gallery_images = sampler.generate_images(model, num_samples, genre_id, style_id, seed, progress)
        return gallery_images, gr.update(value="", visible=False)
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        return None, gr.update(value=error_msg, visible=True)


def clear_gallery():
    """Clear the gallery display"""
    return None, gr.update(value="", visible=False)


# Create the Gradio interface
with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as app:
    gr.Markdown("# DiT Diffusion Model Generator")
    gr.Markdown("Generate art images using a Diffusion Transformer (DiT) model")

    with gr.Row():
        with gr.Column(scale=1):
            dit_size = gr.Radio(
                choices=["S", "B", "L"],
                value="B",
                label="DiT Model Size",
                info="S: Small (fastest), B: Base (balanced), L: Large (best quality but slowest)"
            )
            num_samples = gr.Slider(
                minimum=MODEL_SAMPLE_LIMITS["B"]["min"],
                maximum=MODEL_SAMPLE_LIMITS["B"]["max"],
                value=MODEL_SAMPLE_LIMITS["B"]["default"],
                step=1,
                label="Number of Samples",
                info=f"How many images to generate ({MODEL_SAMPLE_LIMITS['B']['min']}-{MODEL_SAMPLE_LIMITS['B']['max']})"
            )

            genre_names = list(reduced_genre_mapping.keys())
            style_names = list(reduced_style_mapping.keys())
            # Sort alphabetically ("None", if present, lands at the top since
            # uppercase sorts before lowercase)
            genre_names.sort()
            style_names.sort()
            genre = gr.Dropdown(choices=genre_names, value="landscape", label="Art Genre")
            style = gr.Dropdown(choices=style_names, value="impressionism", label="Art Style")

            with gr.Row():
                seed = gr.Number(label="Seed", value=generate_random_seed(), precision=0, info="Set for reproducible results")
                reset_seed_btn = gr.Button("🎲 New Seed")

            with gr.Row():
                generate_btn = gr.Button("Generate Images", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Gallery")
            # Progress is reported through the gr.Progress() default argument on
            # generate_samples; a standalone gr.Progress() placed in the layout
            # has no effect, so none is created here.

        with gr.Column(scale=2):
            output_gallery = gr.Gallery(label="Generated Images", columns=4, rows=4, object_fit="contain", height=600)
            error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")

    dit_size.change(update_sample_slider, inputs=[dit_size], outputs=[num_samples])

    # Seed reset button functionality
    reset_seed_btn.click(generate_random_seed, inputs=[], outputs=[seed])

    # Clear gallery button functionality
    clear_btn.click(clear_gallery, inputs=[], outputs=[output_gallery, error_message])

    # Connect components
    generate_btn.click(
        fn=generate_samples,
        inputs=[num_samples, dit_size, genre, style, seed],
        outputs=[output_gallery, error_message],
    )


if __name__ == "__main__":
    app.launch()