"""Gradio app: Stable Diffusion with textual inversion and loss-guided sampling.

The app loads a Stable Diffusion pipeline plus a handful of textual-inversion
concepts, then runs a manual LMS diffusion loop in which an optional image
loss (blue / elastic / symmetry / saturation) nudges the latents every few
steps.
"""

import contextlib
import gc
import os
import signal
import threading
import time
import traceback

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from diffusers import StableDiffusionPipeline, DiffusionPipeline
from PIL import Image
from tqdm.auto import tqdm

# Configure constants - optimized for CPU
HEIGHT, WIDTH = 384, 384  # Smaller images use less memory
GUIDANCE_SCALE = 7.5
LOSS_SCALE = 200
NUM_INFERENCE_STEPS = 30  # Reduced from 50
BATCH_SIZE = 1
DEFAULT_PROMPT = "A deadly witcher slinging a sword with a lion medallion in his neck, casting a fire spell from his hand in a snowy forest"

# Latent scaling factor applied before VAE decoding (Stable Diffusion v1).
VAE_SCALING_FACTOR = 0.18215
# The loss guidance is applied on every LOSS_EVERY_N_STEPS-th diffusion step.
LOSS_EVERY_N_STEPS = 5
# CPU-only watchdog timeouts, in seconds.
MODEL_LOAD_TIMEOUT_S = 2100  # 35 minutes
GENERATION_TIMEOUT_S = 600  # 10 minutes

TEXTUAL_INVERSION_CONCEPTS = (
    "sd-concepts-library/rimworld-art-style",
    "sd-concepts-library/hk-goldenlantern",
    "sd-concepts-library/phoenix-01",
    "sd-concepts-library/fractal-flame",
    "sd-concepts-library/scarlet-witch",
)

# Define the device
if torch.cuda.is_available():
    TORCH_DEVICE = "cuda"
elif torch.backends.mps.is_available():
    TORCH_DEVICE = "mps"
else:
    TORCH_DEVICE = "cpu"
print(f"Using device: {TORCH_DEVICE}")

# Initialize the elastic transformer
elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0)


# Timeout handler for CPU processing
def timeout_handler(signum, frame):
    """SIGALRM handler: abort the running operation with a TimeoutError."""
    raise TimeoutError("Image generation took too long")


@contextlib.contextmanager
def _cpu_timeout(seconds):
    """Arm a SIGALRM watchdog for the enclosed block when that is possible.

    The watchdog is only armed when running on CPU, on a platform that has
    ``signal.SIGALRM``, and from the main thread. Python only allows signal
    handlers to be installed from the main thread, and UI callbacks typically
    run in worker threads; in every other situation this is a no-op so the
    wrapped work still runs (just without a timeout). The alarm is always
    cleared and the previous handler restored on exit.
    """
    can_arm = (
        TORCH_DEVICE == "cpu"
        and hasattr(signal, "SIGALRM")
        and threading.current_thread() is threading.main_thread()
    )
    if not can_arm:
        yield
        return

    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)
        signal.signal(signal.SIGALRM, previous_handler)


# Load the model
def load_model():
    """Load the Stable Diffusion pipeline and the textual-inversion concepts.

    Returns:
        The pipeline, moved to ``TORCH_DEVICE`` (float16 on CUDA, float32
        otherwise).

    Raises:
        Exception: any failure while loading the base pipeline is printed
            with its traceback and re-raised. A failure to load an individual
            concept is only reported as a warning.
    """
    try:
        with _cpu_timeout(MODEL_LOAD_TIMEOUT_S):
            pipe = DiffusionPipeline.from_pretrained(
                "CompVis/stable-diffusion-v1-4",
                torch_dtype=torch.float16 if TORCH_DEVICE == "cuda" else torch.float32,
                safety_checker=None,  # Disable safety checker for memory
                low_cpu_mem_usage=True,  # Enable memory optimization
            ).to(TORCH_DEVICE)

            # Load textual inversion for all devices including CPU,
            # one at a time with memory cleanup between each.
            for concept in TEXTUAL_INVERSION_CONCEPTS:
                try:
                    print(f"Loading textual inversion concept: {concept}")
                    pipe.load_textual_inversion(concept, mean_resizing=False)
                    # Clear memory after loading each concept
                    if TORCH_DEVICE == "cpu":
                        gc.collect()
                except TimeoutError:
                    # The watchdog fired: do not downgrade it to a warning.
                    raise
                except Exception as e:
                    print(f"Warning: Could not load textual inversion concept {concept}: {e}")
        return pipe
    except Exception as e:
        print(f"Error loading model: {e}")
        traceback.print_exc()
        raise


# Helper functions
def image_grid(imgs, rows, cols):
    """Paste equally sized PIL images into a ``rows`` x ``cols`` grid.

    Images are laid out row by row, using the size of the first image as the
    cell size.

    Raises:
        ValueError: if ``imgs`` is empty or ``len(imgs) != rows * cols``.
    """
    if not imgs:
        raise ValueError("image_grid requires at least one image")
    if len(imgs) != rows * cols:
        raise ValueError(
            f"Expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
        )
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def image_loss(images, loss_type):
    """Compute the guidance loss of the given type for a batch of images.

    Args:
        images: image tensor; indexed as ``images[:, 2]`` for the blue channel
            and flipped along dimension 3 for symmetry, i.e. it is treated as
            (batch, channel, height, width).
        loss_type: one of ``'blue'``, ``'elastic'``, ``'symmetry'``,
            ``'saturation'``.

    Returns:
        A scalar tensor. For an unknown ``loss_type`` a message is printed and
        a constant ``0.0`` tensor (not connected to ``images``) is returned.
    """
    if loss_type == 'blue':
        # blue loss: distance of the blue channel from 0.9
        error = torch.abs(images[:, 2] - 0.9).mean()
    elif loss_type == 'elastic':
        # elastic loss: difference to an elastically deformed copy
        transformed_imgs = elastic_transformer(images)
        error = torch.abs(transformed_imgs - images).mean()
    elif loss_type == 'symmetry':
        # symmetry loss: difference to the horizontally flipped image
        flipped_image = torch.flip(images, [3])
        error = F.mse_loss(images, flipped_image)
    elif loss_type == 'saturation':
        # saturation loss: difference to a strongly saturated copy
        transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10)
        error = torch.abs(transformed_imgs - images).mean()
    else:
        print("Error. Loss not defined")
        error = torch.tensor(0.0)
    return error


def latents_to_pil(latents, pipe):
    """Decode a batch of latents with the pipeline's VAE into PIL images."""
    # batch of latents -> list of images
    latents = (1 / VAE_SCALING_FACTOR) * latents
    with torch.no_grad():
        image = pipe.vae.decode(latents).sample
    # Map the decoder output from [-1, 1] to [0, 1], then to uint8 HWC arrays.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images]


def generate_image(pipe, seed_no, prompts, loss_type, loss_apply=False, progress=gr.Progress()):
    """Run the diffusion loop for one seed and return the final latents.

    Args:
        pipe: loaded diffusion pipeline (tokenizer, text_encoder, unet, vae).
        seed_no: integer seed passed to ``torch.manual_seed``.
        prompts: list of prompt strings.
        loss_type: loss name understood by ``image_loss`` or ``"N/A"``.
        loss_apply: when True (and ``loss_type`` is not ``"N/A"``), apply the
            loss guidance on every ``LOSS_EVERY_N_STEPS``-th step.
        progress: Gradio progress tracker.

    Returns:
        The final latents tensor. On any error the exception is printed and
        an all-zero latents tensor of the expected shape is returned instead.
    """
    try:
        # The watchdog is a no-op unless running on CPU in the main thread.
        with _cpu_timeout(GENERATION_TIMEOUT_S):
            # Initialization and Setup
            generator = torch.manual_seed(seed_no)
            scheduler = LMSDiscreteScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                num_train_timesteps=1000,
            )
            scheduler.set_timesteps(NUM_INFERENCE_STEPS)
            scheduler.timesteps = scheduler.timesteps.to(torch.float32)

            # Text Processing
            text_input = pipe.tokenizer(
                prompts,
                padding='max_length',
                max_length=pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            input_ids = text_input.input_ids.to(TORCH_DEVICE)

            # Convert text inputs to embeddings
            with torch.no_grad():
                text_embeddings = pipe.text_encoder(input_ids)[0]

            # Unconditional (empty prompt) embeddings, padded to the same length
            max_length = text_input.input_ids.shape[-1]
            uncond_input = pipe.tokenizer(
                [""] * BATCH_SIZE,
                padding="max_length",
                max_length=max_length,
                return_tensors="pt",
            )
            with torch.no_grad():
                uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(TORCH_DEVICE))[0]

            # Concatenate unconditioned and text embeddings
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

            # Create random initial latents
            latents = torch.randn(
                (BATCH_SIZE, pipe.unet.config.in_channels, HEIGHT // 8, WIDTH // 8),
                generator=generator,
            )

            # Move latents to device and apply noise scaling
            if TORCH_DEVICE == "cuda":
                latents = latents.to(torch.float16)
            latents = latents.to(TORCH_DEVICE)
            latents = latents * scheduler.init_noise_sigma

            # Diffusion Process
            timesteps = scheduler.timesteps
            total_steps = len(timesteps)
            use_loss = loss_apply and loss_type != "N/A"
            progress(0, desc="Generating")

            for i, t in enumerate(timesteps):
                progress((i + 1) / total_steps, desc=f"Diffusion step {i+1}/{total_steps}")

                # Duplicate the latents: one copy for the unconditional pass,
                # one for the text-conditioned pass.
                latent_model_input = torch.cat([latents] * 2)
                sigma = scheduler.sigmas[i]
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                with torch.no_grad():
                    noise_pred = pipe.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                    )["sample"]

                # Classifier-free guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_text - noise_pred_uncond)

                # Apply loss if requested
                if use_loss and i % LOSS_EVERY_N_STEPS == 0:
                    latents = latents.detach().requires_grad_()
                    # Estimate of the fully denoised latents at this step
                    latents_x0 = latents - sigma * noise_pred

                    # Use VAE to decode the image
                    denoised_images = pipe.vae.decode((1 / VAE_SCALING_FACTOR) * latents_x0).sample / 2 + 0.5

                    # Apply loss
                    loss = image_loss(denoised_images, loss_type) * LOSS_SCALE
                    print(f"Step {i}, Loss: {loss.item()}")

                    # Step the latents against the loss gradient
                    cond_grad = torch.autograd.grad(loss, latents)[0]
                    latents = latents.detach() - cond_grad * sigma**2

                # Update latents using the scheduler
                latents = scheduler.step(noise_pred, t, latents).prev_sample

                # Garbage collect every 5 steps if on CPU
                if TORCH_DEVICE == "cpu" and i % 5 == 0:
                    gc.collect()

        return latents
    except Exception as e:
        print(f"Error in generate_image: {e}")
        traceback.print_exc()
        # Return empty latents as fallback
        return torch.zeros(
            (BATCH_SIZE, pipe.unet.config.in_channels, HEIGHT // 8, WIDTH // 8),
            device=TORCH_DEVICE,
        )


def generate_images(prompt, loss_type, apply_loss, seeds, pipe, progress=gr.Progress()):
    """Generate one image per seed and return them as a single PIL image.

    Args:
        prompt: prompt string.
        loss_type: loss name or ``"N/A"``.
        apply_loss: whether to apply the loss guidance.
        seeds: comma-separated string of integer seeds; empty entries are
            ignored and ``1000`` is used when nothing is left.
        pipe: loaded diffusion pipeline.
        progress: Gradio progress tracker.

    Returns:
        A single image for one seed, or a one-row grid for several seeds.
        A solid red image replaces any seed that failed, and is returned on
        its own if the whole call fails (for example on an unparsable seed).
    """
    try:
        images_list = []

        # Convert comma-separated string to list and clean
        seeds = [int(seed.strip()) for seed in seeds.split(',') if seed.strip()]
        if not seeds:
            seeds = [1000]  # Default seed if none provided

        # Process one seed at a time to save memory
        for i, seed_no in enumerate(seeds):
            progress((i / len(seeds)) * 0.1, desc=f"Starting seed {seed_no}")

            # Clear memory
            if TORCH_DEVICE == "cuda":
                torch.cuda.empty_cache()
            gc.collect()

            try:
                # Generate image
                prompts = [prompt]
                latents = generate_image(pipe, seed_no, prompts, loss_type, loss_apply=apply_loss, progress=progress)
                pil_images = latents_to_pil(latents, pipe)
                images_list.extend(pil_images)
            except Exception as e:
                print(f"Error generating image with seed {seed_no}: {e}")
                # Create an error image; PIL sizes are (width, height)
                error_img = Image.new('RGB', (WIDTH, HEIGHT), color=(255, 0, 0))
                images_list.append(error_img)

            # Force garbage collection
            gc.collect()

        # Create image grid
        if len(images_list) > 1:
            return image_grid(images_list, 1, len(images_list))
        return images_list[0]
    except Exception as e:
        print(f"Error in generate_images: {e}")
        traceback.print_exc()
        # Create an error image
        return Image.new('RGB', (WIDTH, HEIGHT), color=(255, 0, 0))


# Gradio Interface
def create_interface():
    """Build the Gradio Blocks UI and return it (without launching).

    The model is not loaded here; it is loaded on demand by the "Load Model"
    button, which also enables the "Generate Images" button on success.
    """
    with gr.Blocks(title="Stable Diffusion Text Inversion with Loss Functions") as app:
        gr.Markdown("""
        # Stable Diffusion Text Inversion with Loss Functions

        Generate images using Stable Diffusion with various loss functions to guide the diffusion process.
        """)

        if TORCH_DEVICE == "cpu":
            gr.Markdown("""
            ⚠️ **Running on CPU**: Generation will be slow and memory-intensive. Each image may take several minutes to generate.
            """)

        pipe = None  # Initialize to None to avoid loading during interface creation

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    value=DEFAULT_PROMPT,
                    lines=3,
                )
                loss_type = gr.Radio(
                    label="Loss Type",
                    choices=["N/A", "blue", "elastic", "symmetry", "saturation"],
                    value="N/A",
                )
                apply_loss = gr.Checkbox(
                    label="Apply Loss Function",
                    value=False,
                )
                if TORCH_DEVICE == "cpu":
                    seeds = gr.Textbox(
                        label="Seeds (comma-separated) - Use fewer seeds for CPU",
                        value="1000",
                        lines=1,
                    )
                else:
                    seeds = gr.Textbox(
                        label="Seeds (comma-separated)",
                        value="3000,2000,1000",
                        lines=1,
                    )

                # Load model button
                load_model_btn = gr.Button("Load Model")
                model_status = gr.Textbox(label="Model Status", value="Model not loaded", interactive=False)
                generate_btn = gr.Button("Generate Images", interactive=False)

            with gr.Column():
                output_image = gr.Image(label="Generated Image")

        def load_model_fn():
            """Load the pipeline; report status and toggle the generate button."""
            nonlocal pipe
            try:
                pipe = load_model()
            except Exception as e:
                return f"Error loading model: {str(e)}", gr.update(interactive=False)
            # gr.update changes the button's `interactive` property; returning
            # a bare bool would be treated as the button's value instead.
            return "Model loaded successfully", gr.update(interactive=True)

        def generate_fn(p, lt, al, s, progress=gr.Progress()):
            """Forward the UI inputs and the currently loaded pipeline."""
            # The progress tracker is injected through the default argument,
            # so it must not be a required positional parameter.
            if pipe is None:
                raise gr.Error("Model not loaded. Click 'Load Model' first.")
            return generate_images(p, lt, al, s, pipe, progress)

        load_model_btn.click(
            fn=load_model_fn,
            inputs=[],
            outputs=[model_status, generate_btn],
        )

        generate_btn.click(
            fn=generate_fn,
            inputs=[prompt, loss_type, apply_loss, seeds],
            outputs=output_image,
        )

        gr.Markdown("""
        ## About the Loss Functions

        - **Blue**: Encourages more blue tones in the image
        - **Elastic**: Creates distortion effects by minimizing differences with elastically transformed versions
        - **Symmetry**: Encourages symmetrical images by minimizing differences with horizontally flipped versions
        - **Saturation**: Increases color saturation in the image

        Set "N/A" and uncheck "Apply Loss Function" for normal image generation.
        """)

        if TORCH_DEVICE == "cpu":
            gr.Markdown("""
            ## CPU Mode Tips

            - Use smaller prompts
            - Process one seed at a time
            - Be patient, generation can take 5-10 minutes per image
            - If you encounter memory errors, try restarting the app and using even smaller dimensions
            """)

    return app


if __name__ == "__main__":
    # Create and launch the interface
    app = create_interface()
    # Progress tracking relies on the queue (required on older Gradio releases).
    app.queue()
    app.launch()