import torch
import numpy as np
import gradio as gr
from PIL import Image
from tqdm.auto import tqdm
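
# The code below relies on tokenizer, text_encoder, unet, vae, scheduler,
# torch_device, token_emb_layer and position_embeddings, none of which are
# defined in this file. A minimal setup sketch follows, assuming the standard
# Stable Diffusion v1.4 components; the exact checkpoints and scheduler
# settings here are assumptions, not confirmed by this file.
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP tokenizer/text encoder and the SD v1.4 VAE and UNet
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(torch_device)
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(torch_device)
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(torch_device)
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

# Embedding sub-layers of the CLIP text model, used below to build prompt
# embeddings with the learned style token swapped in
token_emb_layer = text_encoder.text_model.embeddings.token_embedding
pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
position_ids = torch.arange(tokenizer.model_max_length, device=torch_device).unsqueeze(0)
with torch.no_grad():
    position_embeddings = pos_emb_layer(position_ids)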
def blue_loss(images):
    """
    Guidance loss that rewards variation in the blue channel. `images` is the
    decoded batch from the VAE in NCHW layout with values already in [0, 1],
    so no rescaling (or re-wrapping in torch.tensor, which would detach the
    gradient) is needed.
    """
    # Blue is the last of the three RGB channels (dim 1 in NCHW layout)
    blue_channel = images[:, 2]
    # Negative variance: minimizing this loss maximizes blue-channel variance
    return -torch.var(blue_channel)
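
# contrast_loss is called in the guidance loop below but never defined in this
# file. A minimal sketch, assuming it mirrors blue_loss by rewarding spread in
# pixel intensities; this is an assumption, not the author's original code.
def contrast_loss(images):
    """
    Return the negative standard deviation of the pixel values, so that
    minimizing this loss pushes the decoded images toward higher contrast.
    """
    return -torch.std(images)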
def generate_with_prompt_style_guidance(prompt, style, seed=42):
    # Append the placeholder token "s"; its embedding is swapped below for the
    # learned textual-inversion style embedding
    prompt = prompt + ' in style of s'
    embed = torch.load(style, map_location=torch_device)
    height = 512
    width = 512
    num_inference_steps = 10
    guidance_scale = 8
    generator = torch.manual_seed(seed)
    batch_size = 1
    contrast_loss_scale = 200  # Weight of the contrast guidance loss
    blue_loss_scale = 100      # Weight of the blue-variance guidance loss
    # Prep text
    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    input_ids = text_input.input_ids.to(torch_device)
    # Get token embeddings
    token_embeddings = token_emb_layer(input_ids)
    # The new embedding - the learned style token (first entry in the loaded file)
    replacement_token_embedding = embed[list(embed.keys())[0]].to(torch_device)
    # Insert it at the position of the placeholder token (id 338, the trailing "s")
    token_embeddings[0, torch.where(input_ids[0] == 338)] = replacement_token_embedding
    # Combine with position embeddings
    input_embeddings = token_embeddings + position_embeddings
    # Feed through the rest of the text encoder to get the final output embeddings
    modified_output_embeddings = get_output_embeds(input_embeddings)
    # And the unconditional (empty-prompt) input for classifier-free guidance:
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, modified_output_embeddings])
    # Prep scheduler
    scheduler.set_timesteps(num_inference_steps)
    # Prep latents at 1/8th resolution, matching the VAE downsampling factor
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(torch_device)
    latents = latents * scheduler.init_noise_sigma
    # Sampling loop
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # Expand the latents so one UNet pass covers both the unconditional
        # and the text-conditioned prediction (classifier-free guidance)
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        # Predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
        # Perform CFG
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # Additional guidance: every 5th step, nudge the latents along the
        # gradient of the custom losses computed on the decoded image
        if i % 5 == 0:
            # Track gradients on the latents
            latents = latents.detach().requires_grad_()
            # Get the predicted x0 (fully denoised image estimate)
            latents_x0 = latents - sigma * noise_pred
            # Decode to image space, rescaling from [-1, 1] to [0, 1]
            denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
            # Calculate the weighted guidance losses
            contrast_loss_val = contrast_loss(denoised_images) * contrast_loss_scale
            blue_loss_val = blue_loss(denoised_images) * blue_loss_scale
            # Combine losses
            total_loss = contrast_loss_val + blue_loss_val
            # Gradient of the combined loss w.r.t. the latents
            cond_grad = torch.autograd.grad(total_loss, latents)[0]
            # Modify the latents against this gradient, scaled by sigma**2
            latents = latents.detach() - cond_grad * sigma**2
        # Now step with the scheduler
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents_to_pil(latents)[0]
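
# get_output_embeds and latents_to_pil are used above but not defined in this
# file. Minimal sketches follow. get_output_embeds mimics the usual CLIP text
# model forward pass; the encoder's causal_attention_mask argument may vary
# across transformers versions, so treat this as an assumption.
def get_output_embeds(input_embeddings):
    """
    Run hand-built input embeddings through the rest of the CLIP text model
    (encoder + final layer norm) to get embeddings the UNet can condition on.
    """
    bsz, seq_len = input_embeddings.shape[:2]
    # Causal mask: large negative values above the diagonal
    mask = torch.full((seq_len, seq_len), torch.finfo(input_embeddings.dtype).min, device=input_embeddings.device)
    mask.triu_(1)
    causal_attention_mask = mask[None, None].expand(bsz, 1, seq_len, seq_len)
    encoder_outputs = text_encoder.text_model.encoder(
        inputs_embeds=input_embeddings,
        causal_attention_mask=causal_attention_mask,
    )
    return text_encoder.text_model.final_layer_norm(encoder_outputs[0])

def latents_to_pil(latents):
    """
    Decode a batch of latents with the VAE and convert them to PIL images.
    """
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    return [Image.fromarray((im * 255).round().astype("uint8")) for im in image]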
# Map display names to the learned textual-inversion embedding files
dict_styles = {
    'Dr Strange': 'styles/learned_embeds_dr_strange.bin',
    'GTA-5': 'styles/learned_embeds_gta5.bin',
    'Manga': 'styles/learned_embeds_manga.bin',
    'Pokemon': 'styles/learned_embeds_pokemon.bin',
}
def inference(prompt, style):
    if prompt is not None and style is not None:
        style = dict_styles[style]
        result = generate_with_prompt_style_guidance(prompt, style)
        return np.array(result)
    else:
        return None
title = "Stable Diffusion and Textual Inversion"
description = "A simple Gradio interface that applies learned textual-inversion styles to Stable Diffusion outputs"
examples = [['A man sipping wine wearing a spacesuit on the moon']]
demo = gr.Interface(inference,
                    inputs=[gr.Textbox(label='Prompt'),
                            gr.Dropdown(['Dr Strange', 'GTA-5', 'Manga', 'Pokemon'], label='Style')],
                    outputs=[gr.Image(label="Stable Diffusion Output")],
                    title=title,
                    description=description,
                    examples=examples)
demo.launch()