import torch
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
import gradio as gr
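# ---------------------------------------------------------------------------
# NOTE: this script uses several globals (torch_device, tokenizer, text_encoder,
# vae, unet, scheduler, token_emb_layer, position_embeddings, get_output_embeds)
# whose setup is not shown in this file. The block below is a minimal sketch of
# the usual Stable Diffusion v1.x setup via the diffusers and transformers
# libraries; the model ids and the hand-built causal mask are assumptions, not
# the original configuration.
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP ViT-L/14 tokenizer and text encoder (the text stack used by SD v1.x)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(torch_device)

# VAE (latents <-> images) and UNet (noise predictor)
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(torch_device)
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(torch_device)

# K-LMS scheduler with the standard SD v1.x beta schedule
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
                                 beta_schedule="scaled_linear", num_train_timesteps=1000)

# Token and position embedding layers of the CLIP text model
token_emb_layer = text_encoder.text_model.embeddings.token_embedding
pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
position_ids = torch.arange(tokenizer.model_max_length).unsqueeze(0).to(torch_device)
position_embeddings = pos_emb_layer(position_ids)

def get_output_embeds(input_embeddings):
    # Run hand-built input embeddings through the remainder of the CLIP text
    # model (causally masked encoder + final layer norm), mirroring what
    # text_encoder.forward does after its embedding layer.
    bsz, seq_len = input_embeddings.shape[:2]
    causal_mask = torch.full((seq_len, seq_len), float("-inf"), dtype=input_embeddings.dtype)
    causal_mask = causal_mask.triu(1).expand(bsz, 1, seq_len, seq_len).to(torch_device)
    encoder_outputs = text_encoder.text_model.encoder(
        inputs_embeds=input_embeddings,
        attention_mask=None,
        causal_attention_mask=causal_mask,
    )
    return text_encoder.text_model.final_layer_norm(encoder_outputs[0])
# ---------------------------------------------------------------------------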
def blue_loss(images):
    """
    Custom loss to penalize or encourage the presence of blue hues in the images.
    Expects a float tensor of shape (batch, 3, height, width) with values in [0, 1],
    as produced by the VAE decode in the guidance loop below.
    """
    # Extract the blue channel (index 2 in RGB; channels are dim 1 in NCHW)
    blue_channel = images[:, 2]
    # Calculate variance of the blue channel
    variance = torch.var(blue_channel)
    # Negative variance as the loss: minimizing it pushes toward more varied blue tones
    return -variance
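# NOTE: contrast_loss and latents_to_pil are referenced below but not defined
# in this file. These are minimal sketches: contrast_loss is assumed to reward
# higher pixel variance (more contrast), and latents_to_pil is the standard
# SD latent-to-PIL decoding recipe.
def contrast_loss(images):
    # Negative pixel variance: minimizing this loss pushes toward higher contrast.
    # (Assumed definition; the original implementation is not shown.)
    return -torch.var(images)

def latents_to_pil(latents):
    # Decode SD latents into a list of PIL images
    # (1 / 0.18215 undoes the latent scaling applied during training)
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    return [Image.fromarray(im) for im in images]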
def generate_with_prompt_style_guidance(prompt, style, seed=42):
    # 's' acts as the placeholder token whose embedding is replaced by the learned style
    prompt = prompt + ' in style of s'
    embed = torch.load(style, map_location=torch_device)
    height = 512
    width = 512
    num_inference_steps = 10
    guidance_scale = 8
    generator = torch.manual_seed(seed)
    batch_size = 1
    contrast_loss_scale = 200
    blue_loss_scale = 100  # Scale for blue loss

    # Prep text
    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    input_ids = text_input.input_ids.to(torch_device)

    # Get token embeddings
    token_embeddings = token_emb_layer(input_ids)

    # The new embedding - the learned style vector from the textual-inversion file
    replacement_token_embedding = embed[list(embed.keys())[0]].to(torch_device)

    # Insert this into the token embeddings at the placeholder position
    # (338 is the id this code uses for the 's' placeholder token)
    token_embeddings[0, torch.where(input_ids[0] == 338)] = replacement_token_embedding

    # Combine with position embeddings
    input_embeddings = token_embeddings + position_embeddings

    # Feed through the rest of the text encoder to get the final output embeddings
    modified_output_embeddings = get_output_embeds(input_embeddings)

    # And the unconditional input as before:
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, modified_output_embeddings])
    # Prep scheduler
    scheduler.set_timesteps(num_inference_steps)

    # Prep latents
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(torch_device)
    latents = latents * scheduler.init_noise_sigma

    # Denoising loop
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # Expand the latents so classifier-free guidance needs only one forward pass
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Perform classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # Additional guidance: every fifth step, nudge the latents along the loss gradient
        if i % 5 == 0:
            # Require grad on the latents
            latents = latents.detach().requires_grad_()

            # Get the predicted x0
            latents_x0 = latents - sigma * noise_pred

            # Decode to image space (0.18215 is the SD v1.x latent scaling factor)
            denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5

            # Calculate losses
            contrast_loss_val = contrast_loss(denoised_images) * contrast_loss_scale
            blue_loss_val = blue_loss(denoised_images) * blue_loss_scale

            # Combine losses
            total_loss = contrast_loss_val + blue_loss_val

            # Get the gradient of the loss w.r.t. the latents
            cond_grad = torch.autograd.grad(total_loss, latents)[0]

            # Move the latents against this gradient, scaled by sigma^2
            latents = latents.detach() - cond_grad * sigma**2

        # Now step with the scheduler
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents_to_pil(latents)[0]
dict_styles = {
    'Dr Strange': 'styles/learned_embeds_dr_strange.bin',
    'GTA-5': 'styles/learned_embeds_gta5.bin',
    'Manga': 'styles/learned_embeds_manga.bin',
    'Pokemon': 'styles/learned_embeds_pokemon.bin',
}
def inference(prompt, style):
    if prompt is not None and style is not None:
        style = dict_styles[style]
        result = generate_with_prompt_style_guidance(prompt, style)
        return np.array(result)
    else:
        return None
title = "Stable Diffusion and Textual Inversion"
description = "A simple Gradio interface to stylize Stable Diffusion outputs"
examples = [['A man sipping wine wearing a spacesuit on the moon']]

demo = gr.Interface(inference,
                    inputs=[gr.Textbox(label='Prompt'),
                            gr.Dropdown(['Dr Strange', 'GTA-5', 'Manga', 'Pokemon'], label='Style')],
                    outputs=[gr.Image(label="Stable Diffusion Output")],
                    title=title,
                    description=description,
                    examples=examples)
demo.launch()