import os import torch import gradio as gr from PIL import Image from diffusers import StableDiffusionPipeline, DiffusionPipeline from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel from tqdm.auto import tqdm import torchvision.transforms as T import torch.nn.functional as F import gc # Configure constants HEIGHT, WIDTH = 512, 512 GUIDANCE_SCALE = 8 LOSS_SCALE = 200 NUM_INFERENCE_STEPS = 50 BATCH_SIZE = 1 DEFAULT_PROMPT = "A deadly witcher slinging a sword with a lion medallion in his neck, casting a fire spell from his hand in a snowy forest" # Define the device TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" # Initialize the elastic transformer elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0) # Load the model def load_model(): pipe = DiffusionPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 if TORCH_DEVICE == "cuda" else torch.float32 ).to(TORCH_DEVICE) # Load textual inversion concepts try: pipe.load_textual_inversion("sd-concepts-library/rimworld-art-style", mean_resizing=False) pipe.load_textual_inversion("sd-concepts-library/hk-goldenlantern", mean_resizing=False) pipe.load_textual_inversion("sd-concepts-library/phoenix-01", mean_resizing=False) pipe.load_textual_inversion("sd-concepts-library/fractal-flame", mean_resizing=False) pipe.load_textual_inversion("sd-concepts-library/scarlet-witch", mean_resizing=False) except Exception as e: print(f"Warning: Could not load all textual inversion concepts: {e}") return pipe # Helper functions def image_grid(imgs, rows, cols): assert len(imgs) == rows*cols w, h = imgs[0].size grid = Image.new('RGB', size=(cols*w, rows*h)) for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h)) return grid def image_loss(images, loss_type): if loss_type == 'blue': # blue loss error = torch.abs(images[:,2] - 0.9).mean() elif loss_type == 'elastic': # elastic loss transformed_imgs = elastic_transformer(images) error = torch.abs(transformed_imgs - images).mean() elif loss_type == 'symmetry': flipped_image = torch.flip(images, [3]) error = F.mse_loss(images, flipped_image) elif loss_type == 'saturation': # saturation loss transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10) error = torch.abs(transformed_imgs - images).mean() else: print("Error. Loss not defined") error = torch.tensor(0.0) return error def latents_to_pil(latents, pipe): # batch of latents -> list of images latents = (1 / 0.18215) * latents with torch.no_grad(): image = pipe.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.detach().cpu().permute(0, 2, 3, 1).numpy() images = (image * 255).round().astype("uint8") pil_images = [Image.fromarray(image) for image in images] return pil_images def generate_image(pipe, seed_no, prompts, loss_type, loss_apply=False, progress=gr.Progress()): # Initialization and Setup generator = torch.manual_seed(seed_no) scheduler = LMSDiscreteScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 ) scheduler.set_timesteps(NUM_INFERENCE_STEPS) scheduler.timesteps = scheduler.timesteps.to(torch.float32) # Text Processing text_input = pipe.tokenizer( prompts, padding='max_length', max_length=pipe.tokenizer.model_max_length, truncation=True, return_tensors="pt" ) input_ids = text_input.input_ids.to(TORCH_DEVICE) # Convert text inputs to embeddings with torch.no_grad(): text_embeddings = pipe.text_encoder(input_ids)[0] # Handle padding and truncation of text inputs max_length = text_input.input_ids.shape[-1] uncond_input = pipe.tokenizer( [""] * BATCH_SIZE, padding="max_length", max_length=max_length, return_tensors="pt" ) with torch.no_grad(): uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(TORCH_DEVICE))[0] # Concatenate unconditioned and text embeddings text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) # Create random initial latents latents = torch.randn( (BATCH_SIZE, pipe.unet.config.in_channels, HEIGHT // 8, WIDTH // 8), generator=generator, ) # Move latents to device and apply noise scaling if TORCH_DEVICE == "cuda": latents = latents.to(torch.float16) latents = latents.to(TORCH_DEVICE) latents = latents * scheduler.init_noise_sigma # Diffusion Process for i, t in progress.tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)): # Process the latent model input latent_model_input = torch.cat([latents] * 2) sigma = scheduler.sigmas[i] latent_model_input = scheduler.scale_model_input(latent_model_input, t) with torch.no_grad(): noise_pred = pipe.unet( latent_model_input, t, encoder_hidden_states=text_embeddings )["sample"] # Apply noise prediction noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_text - noise_pred_uncond) # Apply loss if requested if loss_apply and i % 5 == 0: latents = latents.detach().requires_grad_() latents_x0 = latents - sigma * noise_pred # Use VAE to decode the image denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # Apply loss loss = image_loss(denoised_images, loss_type) * LOSS_SCALE print(f"Step {i}, Loss: {loss.item()}") # Compute gradients for optimization cond_grad = torch.autograd.grad(loss, latents)[0] latents = latents.detach() - cond_grad * sigma**2 # Update latents using the scheduler latents = scheduler.step(noise_pred, t, latents).prev_sample return latents def generate_images(prompt, loss_type, apply_loss, seeds, pipe): latents_collect = [] # Convert comma-separated string to list and clean seeds = [int(seed.strip()) for seed in seeds.split(',') if seed.strip()] if not seeds: seeds = [1000] # Default seed if none provided # List of SD concepts (can be empty if not used) sdconcepts = [''] * len(seeds) # Generate images for each seed for seed_no, sd in zip(seeds, sdconcepts): # Clear CUDA cache if TORCH_DEVICE == "cuda": torch.cuda.empty_cache() gc.collect() torch.cuda.empty_cache() # Generate image prompts = [f'{prompt} {sd}'] latents = generate_image(pipe, seed_no, prompts, loss_type, loss_apply=apply_loss) latents_collect.append(latents) # Stack latents and convert to images latents_collect = torch.vstack(latents_collect) images = latents_to_pil(latents_collect, pipe) # Create image grid if len(images) > 1: result = image_grid(images, 1, len(images)) return result else: return images[0] # Gradio Interface def create_interface(): pipe = load_model() with gr.Blocks(title="Stable Diffusion Text Inversion with Loss Functions") as app: gr.Markdown(""" # Stable Diffusion Text Inversion with Loss Functions Generate images using Stable Diffusion with various loss functions to guide the diffusion process. """) with gr.Row(): with gr.Column(): prompt = gr.Textbox( label="Prompt", value=DEFAULT_PROMPT, lines=3 ) loss_type = gr.Radio( label="Loss Type", choices=["N/A", "blue", "elastic", "symmetry", "saturation"], value="N/A" ) apply_loss = gr.Checkbox( label="Apply Loss Function", value=False ) seeds = gr.Textbox( label="Seeds (comma-separated)", value="3000,2000,1000", lines=1 ) generate_btn = gr.Button("Generate Images") with gr.Column(): output_image = gr.Image(label="Generated Image") generate_btn.click( fn=lambda p, lt, al, s: generate_images(p, lt, al, s, pipe), inputs=[prompt, loss_type, apply_loss, seeds], outputs=output_image ) gr.Markdown(""" ## About the Loss Functions - **Blue**: Encourages more blue tones in the image - **Elastic**: Creates distortion effects by minimizing differences with elastically transformed versions - **Symmetry**: Encourages symmetrical images by minimizing differences with horizontally flipped versions - **Saturation**: Increases color saturation in the image Set "N/A" and uncheck "Apply Loss Function" for normal image generation. """) return app if __name__ == "__main__": # Create and launch the interface app = create_interface() app.launch()