import torch
import numpy as np
from tqdm import tqdm
from torchvision.transforms.functional import to_tensor
from PIL import Image
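
# NOTE: the globals used below -- tokenizer, text_encoder, token_emb_layer,
# position_embeddings, get_output_embeds, scheduler, unet, vae, torch_device
# and latents_to_pil -- are assumed to be initialized elsewhere in this Space;
# the model-loading code is not shown in this section.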
def blue_loss(images):
    """
    Custom loss used to steer the blue content of the generated images.
    Expects a float tensor of shape (batch, channels, height, width) in [0, 1],
    as produced by the VAE decode step in the guidance loop below.
    """
    # Extract the blue channel (index 2 in RGB, channels-first layout)
    blue_channel = images[:, 2, :, :]
    # Variance measures how much the blue values vary across the image
    variance = torch.var(blue_channel)
    # Return negative variance: minimizing this loss encourages blue variation
    return -variance
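
# NOTE: contrast_loss is used in the guidance loop below but is not defined in
# this section. A minimal stand-in sketch, assuming it mirrors the
# variance-based pattern of blue_loss (reward overall pixel-value spread);
# the Space's actual definition may differ:
def contrast_loss(images):
    """Stand-in: encourage contrast by rewarding overall pixel-value variance."""
    return -torch.var(images)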
def generate_with_prompt_style_guidance(prompt, style, seed=42):
    # Append a placeholder token 's' whose embedding is swapped out below
    prompt = prompt + ' in style of s'
    # Load the learned textual-inversion embedding for the chosen style
    embed = torch.load(style)
    height = 512
    width = 512
    num_inference_steps = 10
    guidance_scale = 8
    generator = torch.manual_seed(seed)
    batch_size = 1
    contrast_loss_scale = 200
    blue_loss_scale = 100  # Scale for blue loss
    # Prep text
    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
    input_ids = text_input.input_ids.to(torch_device)
    # Get token embeddings
    token_embeddings = token_emb_layer(input_ids)
    # The learned style embedding that replaces the placeholder token
    replacement_token_embedding = embed[list(embed.keys())[0]].to(torch_device)
    # Insert it wherever the placeholder token (id 338, the trailing 's') appears
    token_embeddings[0, torch.where(input_ids[0] == 338)] = replacement_token_embedding
    # Combine with positional embeddings
    input_embeddings = token_embeddings + position_embeddings
    # Feed through the rest of the text encoder to get the final output embeddings
    modified_output_embeddings = get_output_embeds(input_embeddings)
    # And the unconditional input as before:
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, modified_output_embeddings])
    # Prep scheduler
    scheduler.set_timesteps(num_inference_steps)
    # Prep latents
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(torch_device)
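    # Scale the initial noise by the scheduler's starting sigma so it matches
    # the noise level expected at the first timestep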
    latents = latents * scheduler.init_noise_sigma
    # Sampling loop
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # Expand the latents so classifier-free guidance needs only one forward pass
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        # Predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
        # Perform classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
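        # With guidance_scale > 1 this pushes the prediction away from the
        # unconditional output and toward the text-conditioned one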
        # Additional guidance: every 5 steps, nudge the latents along the
        # gradient of the custom losses
        if i % 5 == 0:
            # Require grad on the latents
            latents = latents.detach().requires_grad_()
            # Get the predicted x0 (estimate of the fully denoised latents)
            latents_x0 = latents - sigma * noise_pred
            # Decode to image space, mapping the VAE output from [-1, 1] to [0, 1]
            denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
            # Calculate the individual losses
            contrast_loss_val = contrast_loss(denoised_images) * contrast_loss_scale
            blue_loss_val = blue_loss(denoised_images) * blue_loss_scale
            # Combine losses
            total_loss = contrast_loss_val + blue_loss_val
            # Get the gradient of the combined loss w.r.t. the latents
            cond_grad = torch.autograd.grad(total_loss, latents)[0]
            # Step the latents against the gradient, scaled by sigma**2
            latents = latents.detach() - cond_grad * sigma**2
        # Now step with the scheduler
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents_to_pil(latents)[0]
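
# Example usage (assumes the models and helpers above are loaded and the
# style files listed below exist; the prompt and path here are illustrative):
#   image = generate_with_prompt_style_guidance(
#       'A man sipping wine wearing a spacesuit on the moon',
#       'styles/learned_embeds_manga.bin')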
import gradio as gr

dict_styles = {
    'Dr Strange': 'styles/learned_embeds_dr_strange.bin',
    'GTA-5': 'styles/learned_embeds_gta5.bin',
    'Manga': 'styles/learned_embeds_manga.bin',
    'Pokemon': 'styles/learned_embeds_pokemon.bin',
}
def inference(prompt, style):
    # Gradio passes an empty string (not None) for a blank textbox, so check truthiness
    if prompt and style:
        style_path = dict_styles[style]
        result = generate_with_prompt_style_guidance(prompt, style_path)
        return np.array(result)
    return None
title = "Stable Diffusion and Textual Inversion"
description = "A simple Gradio interface to stylize Stable Diffusion outputs"
examples = [['A man sipping wine wearing a spacesuit on the moon']]

demo = gr.Interface(inference,
                    inputs=[gr.Textbox(label='Prompt'),
                            gr.Dropdown(['Dr Strange', 'GTA-5', 'Manga', 'Pokemon'], label='Style')],
                    outputs=[gr.Image(label="Stable Diffusion Output")],
                    title=title,
                    description=description,
                    examples=examples)
demo.launch()