Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import torchvision | |
| import os | |
| import gc | |
| import tqdm | |
| import matplotlib.pyplot as plt | |
| import torchvision.transforms as transforms | |
| from transformers import CLIPTextModel | |
| from peft import PeftModel, LoraConfig | |
| from lora_w2w import LoRAw2w | |
| from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler | |
| from peft.utils.save_and_load import load_peft_weights, set_peft_model_state_dict | |
| from transformers import AutoTokenizer, PretrainedConfig | |
| from PIL import Image | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| from diffusers import ( | |
| AutoencoderKL, | |
| DDPMScheduler, | |
| DiffusionPipeline, | |
| DPMSolverMultistepScheduler, | |
| UNet2DConditionModel, | |
| PNDMScheduler, | |
| StableDiffusionPipeline | |
| ) | |
| ######## Sampling utilities | |
def sample_weights(unet, proj, mean, std, v, device, factor=1.0):
    """Sample random PCA coefficients and wrap them in a ``LoRAw2w`` network.

    Each of the first 1000 coefficients is drawn independently from a normal
    distribution. Its mean and standard deviation are the per-column mean and
    (unbiased) standard deviation of ``proj`` over dim 0, with the standard
    deviation scaled by ``factor``.

    Args:
        unet: Model handed through unchanged to ``LoRAw2w``.
        proj: Tensor whose statistics are taken over dim 0. It must have at
            least 1000 entries along dim 1. Presumably the rows are samples
            and the columns are principal-component coordinates — TODO
            confirm against the caller.
        mean: Handed through unchanged to ``LoRAw2w``.
        std: Handed through unchanged to ``LoRAw2w``.
        v: Handed through unchanged to ``LoRAw2w``.
        device: Device for the sampled coefficients and the returned network.
        factor: Multiplier on each component's standard deviation.
            Values above 1.0 widen the spread; values below 1.0 narrow it.

    Returns:
        A ``LoRAw2w`` built from a ``(1, 1000)`` coefficient tensor, moved to
        ``device`` and cast to ``torch.bfloat16``.

    Note:
        The coefficients are drawn in a single call using the random
        generator of ``device``. Under a fixed seed, the exact values
        therefore differ from a per-element CPU draw, but the distribution
        is the same.
    """
    # Number of principal components consumed by the sampled vector.
    num_pcs = 1000

    # Per-component statistics across the rows of ``proj``.
    m = torch.mean(proj, 0)
    standev = torch.std(proj, 0)
    # Only removes this function's reference. The caller's ``proj`` (if any)
    # stays alive, so the cache flush below may not reclaim its memory.
    del proj
    torch.cuda.empty_cache()

    # Draw all coefficients in one vectorised call. The previous 1000 scalar
    # draws each forced a device->host sync and a host->device copy.
    # Statistics are cast to the default dtype, matching the dtype of the
    # original ``torch.zeros`` buffer.
    dtype = torch.get_default_dtype()
    mu = m[:num_pcs].to(device=device, dtype=dtype)
    sigma = (factor * standev[:num_pcs]).to(device=device, dtype=dtype)
    sample = torch.normal(mu, sigma).unsqueeze(0)  # shape (1, num_pcs)

    # Load the sampled coefficients into the LoRA network.
    network = LoRAw2w(
        sample,
        mean,
        std,
        v,
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)
    return network