import torch
import numpy as np
from tqdm.auto import tqdm

# NOTE: this script assumes the usual Stable Diffusion components and helpers
# (tokenizer, text_encoder, vae, unet, scheduler, torch_device, token_emb_layer,
# position_embeddings, get_output_embeds, contrast_loss, latents_to_pil) are
# already defined elsewhere, e.g. by the setup code this script accompanies.


def blue_loss(images):
    """
    Custom loss on the blue channel of a float NCHW image batch in [0, 1]
    (as produced by the VAE decode below). Returns the negative variance of
    the blue channel, so minimising it encourages variation in the blue hues.
    """
    # The inputs are already float tensors in [0, 1]; re-wrapping them with
    # torch.tensor() would detach them from the autograd graph and kill the
    # guidance gradient, so we use them directly.
    blue_channel = images[:, 2]  # blue is channel index 2 in RGB, NCHW layout
    # Negative variance: lower loss = more variation in blue
    return -torch.var(blue_channel)


def generate_with_prompt_style_guidance(prompt, style, seed=42):
    # 's' is the placeholder token that gets swapped for the learned style embedding
    prompt = prompt + ' in style of s'
    embed = torch.load(style)

    height = 512
    width = 512
    num_inference_steps = 10
    guidance_scale = 8
    generator = torch.manual_seed(seed)
    batch_size = 1
    contrast_loss_scale = 200
    blue_loss_scale = 100  # Scale for blue loss

    # Prep text: we only need the token ids here; the conditional embeddings
    # are built manually below so the style embedding can be spliced in
    text_input = tokenizer([prompt], padding="max_length",
                           max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    input_ids = text_input.input_ids.to(torch_device)

    # Get token embeddings
    token_embeddings = token_emb_layer(input_ids)

    # The new embedding - the learned style vector stored in the .bin file
    replacement_token_embedding = embed[list(embed.keys())[0]].to(torch_device)

    # Insert it wherever the placeholder appears (338 is the CLIP tokenizer id for 's</w>')
    token_embeddings[0, torch.where(input_ids[0] == 338)] = replacement_token_embedding

    # Combine with positional embeddings
    input_embeddings = token_embeddings + position_embeddings

    # Feed through the rest of the text encoder to get the final output embeddings
    modified_output_embeddings = get_output_embeds(input_embeddings)

    # And the unconditional input as before:
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, modified_output_embeddings])

    # Prep scheduler
    scheduler.set_timesteps(num_inference_steps)

    # Prep latents
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(torch_device)
    latents = latents * scheduler.init_noise_sigma

    # Denoising loop
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # Expand the latents so classifier-free guidance needs only one forward pass
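        # (The doubled batch lets a single UNet call produce both the
        # unconditional and the text-conditioned noise predictions, which
        # are recombined with guidance_scale in the CFG step below.)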
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Perform CFG
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Additional guidance: every 5th step, nudge the latents along the
        # gradient of the custom losses
        if i % 5 == 0:
            # Require grad on the latents
            latents = latents.detach().requires_grad_()

            # Get the predicted x0
            latents_x0 = latents - sigma * noise_pred

            # Decode to image space
            denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5

            # Calculate losses
            contrast_loss_val = contrast_loss(denoised_images) * contrast_loss_scale
            blue_loss_val = blue_loss(denoised_images) * blue_loss_scale

            # Combine losses
            total_loss = contrast_loss_val + blue_loss_val

            # Get gradient
            cond_grad = torch.autograd.grad(total_loss, latents)[0]

            # Modify the latents based on this gradient
            latents = latents.detach() - cond_grad * sigma**2

        # Now step with scheduler
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents_to_pil(latents)[0]


import gradio as gr

dict_styles = {
    'Dr Strange': 'styles/learned_embeds_dr_strange.bin',
    'GTA-5': 'styles/learned_embeds_gta5.bin',
    'Manga': 'styles/learned_embeds_manga.bin',
    'Pokemon': 'styles/learned_embeds_pokemon.bin',
}


def inference(prompt, style):
    if prompt is not None and style is not None:
        style = dict_styles[style]
        result = generate_with_prompt_style_guidance(prompt, style)
        return np.array(result)
    else:
        return None


title = "Stable Diffusion and Textual Inversion"
description = "A simple Gradio interface to stylize Stable Diffusion outputs"
# Each example row needs one value per input component (prompt, style)
examples = [['A man sipping wine wearing a spacesuit on the moon', 'Manga']]

demo = gr.Interface(
    inference,
    inputs=[gr.Textbox(label='Prompt'),
            gr.Dropdown(['Dr Strange', 'GTA-5', 'Manga', 'Pokemon'], label='Style')],
    outputs=[gr.Image(label="Stable Diffusion Output")],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()
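
# Example: calling the generator directly, without the Gradio UI (kept as a
# comment since demo.launch() blocks; the output filename is illustrative):
# img = generate_with_prompt_style_guidance(
#     'A man sipping wine wearing a spacesuit on the moon',
#     dict_styles['Manga'])
# img.save('spaceman_manga.png')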