import torch
import torchvision
import os
import gc
import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from lora_w2w import LoRAw2w
from transformers import AutoTokenizer, PretrainedConfig
from PIL import Image
import warnings

warnings.filterwarnings("ignore")

######## Editing Utilities
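# Overview (inferred from the code below): a semantic edit direction is the
# least-squares fit from the dataset's projected model weights to a binary
# attribute label, computed as pinverse @ labels with a precomputed
# pseudoinverse. Directions are solved over the first 1000 coordinates and
# zero-padded to the full projection dimensionality when needed.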

def get_direction(df, label, pinverse, return_dim, device):
    ### get labels: one binary attribute value per model folder in the dataframe
    labels = []
    for folder in list(df.index):
        labels.append(df.loc[folder][label])
    labels = torch.Tensor(labels).to(device).bfloat16()
    ### solve least squares via the precomputed pseudoinverse
    direction = (pinverse @ labels).unsqueeze(0)
    if return_dim == 1000:
        return direction
    else:
        ### zero-pad from the first 1000 coordinates up to return_dim;
        ### pad in bfloat16 so torch.cat does not promote the result to float32
        direction = torch.cat((direction, torch.zeros((1, return_dim - 1000)).to(device).bfloat16()), dim=1)
        return direction
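
# Minimal usage sketch (illustrative, not executed on import): assumes `df` is a
# pandas DataFrame indexed by model folder with a binary "Male" attribute column,
# and `pinverse` is the precomputed pseudoinverse of the stacked weight matrix.
# The file names, column name, and return_dim are hypothetical.
#
#   df = pd.read_csv("identity_df.csv", index_col=0)   # hypothetical path
#   pinverse = torch.load("pinverse.pt").to(device)    # hypothetical path
#   gender = get_direction(df, "Male", pinverse, return_dim=10000, device=device)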

def debias(direction, label, df, pinverse, device):
    ### get labels for the attribute to remove
    labels = []
    for folder in list(df.index):
        labels.append(df.loc[folder][label])
    labels = torch.Tensor(labels).to(device).bfloat16()
    ### solve least squares for that attribute's direction
    d = (pinverse @ labels).unsqueeze(0)
    ### align dimensionalities of the two vectors (zero-pad d in bfloat16)
    if direction.shape[1] != 1000:
        d = torch.cat((d, torch.zeros((1, direction.shape[1] - 1000)).to(device).bfloat16()), dim=1)
    ### remove this component from the direction (project out d)
    direction = direction - ((direction @ d.T) / (torch.norm(d) ** 2)) * d
    return direction
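
# Sketch of composing the two helpers (illustrative only): strip the "Smiling"
# component out of a gender direction so the edit does not also change
# expression. Both attribute column names are hypothetical.
#
#   gender = get_direction(df, "Male", pinverse, return_dim=10000, device=device)
#   gender = debias(gender, "Smiling", df, pinverse, device)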

def edit_inference(network, edited_weights, unet, vae, text_encoder, tokenizer, prompt,
                   negative_prompt, guidance_scale, noise_scheduler, ddim_steps,
                   start_noise, seed, generator, device):
    ### stash the current weights so they can be restored after sampling
    original_weights = network.proj.clone()
    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device
    ).bfloat16()

    ### encode the prompt and the negative prompt for classifier-free guidance
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        ### swap in the edited weights once the timestep drops to start_noise,
        ### so the early (high-noise) steps keep the original identity's layout
        if t <= start_noise:
            network.proj = torch.nn.Parameter(edited_weights)
            network.reset()
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
                              timestep_cond=None).sample
        ### classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    ### decode latents to image space (SD VAE scaling factor)
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)

    ### reset weights back to original
    network.proj = torch.nn.Parameter(original_weights)
    network.reset()
    return image
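
# End-to-end sketch (illustrative, not executed on import): assumes the SD
# components (unet, vae, text_encoder, tokenizer, scheduler) are loaded elsewhere
# in the repo and `network` is the LoRAw2w projection wrapper. The edit magnitude
# 0.8, the prompt, start_noise value, and attribute names are all hypothetical.
#
#   direction = get_direction(df, "Male", pinverse, return_dim=network.proj.shape[1], device=device)
#   direction = debias(direction, "Smiling", df, pinverse, device)
#   edited = network.proj + 0.8 * direction / torch.norm(direction)
#   image = edit_inference(network, edited, unet, vae, text_encoder, tokenizer,
#                          prompt="a photo of a person", negative_prompt="blurry",
#                          guidance_scale=3.0, noise_scheduler=scheduler, ddim_steps=50,
#                          start_noise=800, seed=0, generator=torch.Generator(device=device),
#                          device=device)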