# Authors: Hui Ren (rhfeiyang.github.io)
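"""Stable Diffusion + LoRA "Art Adapter" inference script.

Loads a pretrained diffusion pipeline, attaches a LoRA adapter network at one or
more slider scales, runs sampling (either from pure noise or by partially noising
real validation images and denoising them back), saves the generated images and
prompts, and optionally computes style/content metrics against a folder of
reference paintings.
"""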
import torch
from PIL import Image
import argparse
import os, json, random
import matplotlib.pyplot as plt
import glob, re
from tqdm import tqdm
import numpy as np
import sys
import gc
from transformers import CLIPTextModel, CLIPTokenizer, BertModel, BertTokenizer
from utils.train_util import get_noisy_image, encode_prompts
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler, DDIMScheduler, PNDMScheduler
from typing import Any, Dict, List, Optional, Tuple, Union
from utils.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
import shutil
import yaml
from easydict import EasyDict
from utils.metrics import StyleContentMetric
from torchvision import transforms
from custom_datasets.coco import CustomCocoCaptions
from custom_datasets.imagepair import ImageSet
from custom_datasets import get_dataset
from diffusers.utils.torch_utils import randn_tensor
import pickle
import time
from datetime import datetime


def flush():
    torch.cuda.empty_cache()
    gc.collect()


def get_train_method(lora_weight):
    """Infer the LoRA training method from the weight file name."""
    if lora_weight is None:
        return 'None'
    if 'full' in lora_weight:
        train_method = 'full'
    elif "down_1_up_2_attn" in lora_weight:
        train_method = 'up_2_attn'
        print(f"Using up_2_attn for {lora_weight}")
    elif "down_2_up_1_up_2_attn" in lora_weight:
        train_method = 'down_2_up_2_attn'
    elif "down_2_up_2_attn" in lora_weight:
        train_method = 'down_2_up_2_attn'
    elif "down_2_attn" in lora_weight:
        train_method = 'down_2_attn'
    elif 'noxattn' in lora_weight:
        train_method = 'noxattn'
    elif "xattn" in lora_weight:
        train_method = 'xattn'
    elif "attn" in lora_weight:
        train_method = 'attn'
    elif "all_up" in lora_weight:
        train_method = 'all_up'
    else:
        train_method = 'None'
    return train_method
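
# Example for get_train_method (illustrative file name): a checkpoint called
# "adapter_rank1_alpha1_down_2_attn.pt" maps to train_method "down_2_attn",
# while any name containing "full" maps to "full".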


def get_validation_dataloader(infer_prompts: list[str] = None, infer_images: list[str] = None, resolution=512,
                              batch_size=10, num_workers=4, val_set="laion_pop500"):
    """Build a DataLoader over a named validation set, a user-provided Dataset, or raw prompt/image lists."""
    data_transforms = transforms.Compose(
        [
            transforms.Resize(resolution),
            transforms.CenterCrop(resolution),
        ]
    )

    def preprocess(example):
        ret = {}
        ret["image"] = data_transforms(example["image"]) if "image" in example else None
        if "caption" in example:
            if isinstance(example["caption"][0], list):
                ret["caption"] = example["caption"][0][0]
            else:
                ret["caption"] = example["caption"][0]
        if "seed" in example:
            ret["seed"] = example["seed"]
        if "id" in example:
            ret["id"] = example["id"]
        if "path" in example:
            ret["path"] = example["path"]
        return ret

    def collate_fn(examples):
        out = {}
        if "image" in examples[0]:
            pixel_values = [example["image"] for example in examples]
            out["pixel_values"] = pixel_values
        # notice: only take the first prompt for each image
        if "caption" in examples[0]:
            prompts = [example["caption"] for example in examples]
            out["prompts"] = prompts
        if "seed" in examples[0]:
            seeds = [example["seed"] for example in examples]
            out["seed"] = seeds
        if "path" in examples[0]:
            paths = [example["path"] for example in examples]
            out["path"] = paths
        return out

    if infer_prompts is None:
        if val_set == "lhq500":
            dataset = get_dataset("lhq_sub500", get_val=False)["train"]
        elif val_set == "custom_coco100":
            dataset = get_dataset("custom_coco100", get_val=False)["train"]
        elif val_set == "custom_coco500":
            dataset = get_dataset("custom_coco500", get_val=False)["train"]
        elif os.path.isdir(val_set):
            image_folder = os.path.join(val_set, "paintings")
            caption_folder = os.path.join(val_set, "captions")
            dataset = ImageSet(folder=image_folder, caption=caption_folder, keep_in_mem=True)
        elif "custom_caption" in val_set:
            from custom_datasets.custom_caption import Caption_set
            name = val_set.replace("custom_caption_", "")
            dataset = Caption_set(set_name=name)
        elif val_set == "laion_pop500":
            dataset = get_dataset("laion_pop500", get_val=False)["train"]
        elif val_set == "laion_pop500_first_sentence":
            dataset = get_dataset("laion_pop500_first_sentence", get_val=False)["train"]
        else:
            raise ValueError("Unknown dataset")
        dataset.with_transform(preprocess)
    elif isinstance(infer_prompts, torch.utils.data.Dataset):
        dataset = infer_prompts
        try:
            dataset.with_transform(preprocess)
        except Exception:
            pass
    else:
        class Dataset(torch.utils.data.Dataset):
            def __init__(self, prompts, images=None):
                self.prompts = prompts
                self.images = images
                self.get_img = False
                if images is not None:
                    assert len(prompts) == len(images)
                    self.get_img = True
                    if isinstance(images[0], str):
                        self.images = [Image.open(image).convert("RGB") for image in images]
                else:
                    self.images = [None] * len(prompts)

            def __len__(self):
                return len(self.prompts)

            def __getitem__(self, idx):
                img = self.images[idx]
                if self.get_img and img is not None:
                    img = data_transforms(img)
                return {"caption": self.prompts[idx], "image": img}

        dataset = Dataset(infer_prompts, infer_images)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn,
                                             drop_last=False, num_workers=num_workers, pin_memory=True)
    return dataloader
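
# Usage sketch for get_validation_dataloader (illustrative prompt text):
#   loader = get_validation_dataloader(infer_prompts=["a lighthouse on a rocky coast"], batch_size=1)
#   batch = next(iter(loader))   # -> {"prompts": [...], "pixel_values": [None]} since no images were given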


def get_lora_network(unet, lora_path, train_method="None", rank=1, alpha=1.0, device="cuda", weight_dtype=torch.float32):
    """Create a LoRANetwork for `unet` and optionally load weights from `lora_path`."""
    if train_method in [None, "None"]:
        train_method = get_train_method(lora_path)
    print(f"Train method: {train_method}")

    network_type = "c3lier"
    if train_method == 'xattn':
        network_type = 'lierla'

    modules = DEFAULT_TARGET_REPLACE
    if network_type == "c3lier":
        modules = modules + UNET_TARGET_REPLACE_MODULE_CONV  # avoid mutating the shared default list in place

    alpha = 1
    if lora_path is not None and "rank" in lora_path:
        rank = int(re.search(r'rank(\d+)', lora_path).group(1))
    if lora_path is not None and 'alpha1' in lora_path:
        alpha = 1.0
    print(f"Rank: {rank}, Alpha: {alpha}")

    network = LoRANetwork(
        unet,
        rank=rank,
        multiplier=1.0,
        alpha=alpha,
        train_method=train_method,
    ).to(device, dtype=weight_dtype)

    if lora_path not in [None, "None"]:
        lora_state_dict = torch.load(lora_path)
        miss = network.load_state_dict(lora_state_dict, strict=False)
        print(f"Missing: {miss}")

    ret = {"network": network, "train_method": train_method}
    return ret
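
# Usage sketch for get_lora_network (the .pt path and tensors are placeholders):
#   network = get_lora_network(unet, "path/to/art_adapter_rank1_attn.pt")["network"]
#   network.set_lora_slider(scale=1.0)   # scale 0 corresponds to the base model
#   with network:                        # LoRA weights are only active inside this context
#       noise_pred = unet(latent_in, t, encoder_hidden_states=text_emb).sample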


def get_model(pretrained_ckpt_path, unet_ckpt=None, revision=None, variant=None, lora_path=None,
              weight_dtype=torch.float32, device="cuda"):
    """Load the diffusion pipeline components (UNet, VAE, text encoder, tokenizer) and move them to `device`."""
    modules = {}
    pipe = DiffusionPipeline.from_pretrained(pretrained_ckpt_path, revision=revision, variant=variant)
    if unet_ckpt is not None:
        # from_pretrained is a classmethod returning a new model, so assign it back to the pipeline
        pipe.unet = UNet2DConditionModel.from_pretrained(unet_ckpt, subfolder="unet_ema", revision=revision, variant=variant)
    unet = pipe.unet
    vae = pipe.vae
    text_encoder = pipe.text_encoder
    tokenizer = pipe.tokenizer
    modules["unet"] = unet
    modules["vae"] = vae
    modules["text_encoder"] = text_encoder
    modules["tokenizer"] = tokenizer

    unet.enable_xformers_memory_efficient_attention()
    unet.to(device, dtype=weight_dtype)
    # keep the VAE in fp32 unless running in bf16, for numerical stability when decoding
    if weight_dtype != torch.bfloat16:
        vae.to(device, dtype=torch.float32)
    else:
        vae.to(device, dtype=weight_dtype)
    text_encoder.to(device, dtype=weight_dtype)

    if lora_path is not None:
        network = get_lora_network(unet, lora_path, device=device, weight_dtype=weight_dtype)
        modules["network"] = network
    return modules
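
# Usage sketch for get_model (model id taken from the script's default argument):
#   modules = get_model("rhfeiyang/art-free-diffusion-v1", weight_dtype=torch.float16)
#   unet, vae, tokenizer = modules["unet"], modules["vae"], modules["tokenizer"]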


def inference(network: LoRANetwork, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, vae: AutoencoderKL,
              unet: UNet2DConditionModel, noise_scheduler: LMSDiscreteScheduler, dataloader, height: int, width: int,
              scales: list = np.linspace(0, 2, 5), save_dir: str = None, seed: int = None,
              weight_dtype: torch.dtype = torch.float32, device: torch.device = "cuda", batch_size: int = 1,
              steps: int = 50, guidance_scale: float = 7.5, start_noise: int = 800,
              uncond_prompt: str = None, uncond_embed=None, style_prompt=None, show: bool = False,
              no_load: bool = False, from_scratch=False):
    """Generate images for every batch in `dataloader` at each Art Adapter scale.

    Returns a dict mapping scale -> list of PIL images (plus "Real Image" when
    converting real images) and the list of prompts actually used.
    """
    print(f"current time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"save dir: {save_dir}")
    if start_noise < 0:
        assert from_scratch
    network = network.eval()
    unet = unet.eval()
    vae = vae.eval()

    do_convert = not from_scratch
    if not do_convert:
        try:
            dataloader.dataset.get_img = False
        except Exception:
            pass
        scales = list(scales)
    else:
        scales = ["Real Image"] + list(scales)

    if not no_load and save_dir is not None and os.path.exists(os.path.join(save_dir, "infer_imgs.pickle")):
        with open(os.path.join(save_dir, "infer_imgs.pickle"), 'rb') as f:
            pred_images = pickle.load(f)
        take = True
        for key in scales:
            if key not in pred_images:
                take = False
                break
        if take:
            print(f"Found existing inference results in {save_dir}", flush=True)
            # reload the prompts saved alongside the cached images so the return signature stays consistent
            prompt_file = os.path.join(save_dir, "infer_prompts.txt")
            prompts = []
            if os.path.exists(prompt_file):
                with open(prompt_file, 'r') as f:
                    prompts = [line.strip() for line in f]
            return pred_images, prompts

    max_length = tokenizer.model_max_length
    pred_images = {scale: [] for scale in scales}
    all_seeds = {scale: [] for scale in scales}
    prompts = []
    ori_prompts = []
    if save_dir is not None:
        img_output_dir = os.path.join(save_dir, "outputs")
        os.makedirs(img_output_dir, exist_ok=True)

    if uncond_embed is None:
        if uncond_prompt is None:
            uncond_input_text = [""]
        else:
            uncond_input_text = [uncond_prompt]
        uncond_embed = encode_prompts(tokenizer=tokenizer, text_encoder=text_encoder, prompts=uncond_input_text)

    for batch in dataloader:
        ori_prompt = batch["prompts"]
        image = batch["pixel_values"] if do_convert else None
        if do_convert:
            pred_images["Real Image"] += image
        if isinstance(ori_prompt, list):
            if isinstance(text_encoder, CLIPTextModel):
                # truncate prompts to the first sentence for the CLIP encoder
                ori_prompt = [p.split(".")[0] + "." for p in ori_prompt]
            prompt = [f"{p.strip()[::-1].replace('.', '', 1)[::-1]} in the style of {style_prompt}" for p in ori_prompt] if style_prompt is not None else ori_prompt
        else:
            if isinstance(text_encoder, CLIPTextModel):
                ori_prompt = ori_prompt.split(".")[0] + "."
            # strip the trailing period before appending the style suffix
            prompt = f"{ori_prompt.strip()[::-1].replace('.', '', 1)[::-1]} in the style of {style_prompt}" if style_prompt is not None else ori_prompt
            ori_prompt = [ori_prompt]
            prompt = [prompt]
        bcz = len(prompt)
        single_seed = seed
        if dataloader.batch_size == 1 and seed is None:
            if "seed" in batch:
                single_seed = batch["seed"][0]
        print(f"{prompt}, seed={single_seed}")

        prompts += prompt
        ori_prompts += ori_prompt
        style_embeddings = encode_prompts(tokenizer=tokenizer, text_encoder=text_encoder, prompts=prompt)
        original_embeddings = encode_prompts(tokenizer=tokenizer, text_encoder=text_encoder, prompts=ori_prompt)
        if uncond_embed.shape[0] == 1 and bcz > 1:
            uncond_embeddings = uncond_embed.repeat(bcz, 1, 1)
        else:
            uncond_embeddings = uncond_embed
        style_text_embeddings = torch.cat([uncond_embeddings, style_embeddings]).to(weight_dtype)

        generator = torch.manual_seed(single_seed) if single_seed is not None else None
        noise_scheduler.set_timesteps(steps)
        if do_convert:
            # partially noise the real image up to `start_noise`, then denoise it back under the adapter
            noised_latent, _, _ = get_noisy_image(image, vae, generator, unet, noise_scheduler,
                                                  total_timesteps=int((1000 - start_noise) / 1000 * steps))
        else:
            latent_shape = (bcz, 4, height // 8, width // 8)
            noised_latent = randn_tensor(latent_shape, generator=generator, device=vae.device)
            noised_latent = noised_latent.to(unet.dtype)
            noised_latent = noised_latent * noise_scheduler.init_noise_sigma
        for scale in scales:
            start_time = time.time()
            if not isinstance(scale, float) and not isinstance(scale, int):
                # skip the "Real Image" placeholder entry
                continue
            latents = noised_latent.clone().to(weight_dtype).to(device)
            noise_scheduler.set_timesteps(steps)
            for t in tqdm(noise_scheduler.timesteps):
                if do_convert and t > start_noise:
                    continue
                else:
                    if t > start_noise and start_noise >= 0:
                        current_scale = 0
                    else:
                        current_scale = scale
                network.set_lora_slider(scale=current_scale)
                text_embedding = style_text_embeddings
                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                latent_model_input = torch.cat([latents] * 2).to(weight_dtype)
                latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t).to(weight_dtype)
                # predict the noise residual
                with network:
                    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embedding).sample
                # perform guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                # compute the previous noisy sample x_t -> x_t-1
                if isinstance(noise_scheduler, DDPMScheduler):
                    latents = noise_scheduler.step(noise_pred, t, latents, generator=torch.manual_seed(single_seed + t) if single_seed is not None else None).prev_sample
                else:
                    latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

            # scale and decode the image latents with the VAE
            latents = 1 / 0.18215 * latents.to(vae.dtype)
            with torch.no_grad():
                image = vae.decode(latents).sample
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.detach().cpu().permute(0, 2, 3, 1).to(torch.float32).numpy()
            images = (image * 255).round().astype("uint8")
            pil_images = [Image.fromarray(image) for image in images]
            pred_images[scale] += pil_images
            all_seeds[scale] += [single_seed] * bcz
            end_time = time.time()
            print(f"Time taken for one batch, Art Adapter scale={scale}: {end_time - start_time}", flush=True)
        if save_dir is not None or show:
            end_idx = len(list(pred_images.values())[0])
            for i in range(end_idx - bcz, end_idx):
                keys = list(pred_images.keys())
                images_list = [pred_images[key][i] for key in keys]
                prompt = prompts[i]
                if len(scales) == 1:
                    plt.imshow(images_list[0])
                    plt.axis('off')
                    plt.title(f"{prompt}_{single_seed}_start{start_noise}", fontsize=20)
                else:
                    fig, ax = plt.subplots(1, len(images_list), figsize=(len(scales) * 5, 6), layout="constrained")
                    for idx, a in enumerate(ax):
                        a.imshow(images_list[idx])
                        if isinstance(scales[idx], float) or isinstance(scales[idx], int):
                            a.set_title(f"Art Adapter scale={scales[idx]}", fontsize=20)
                        else:
                            a.set_title(f"{keys[idx]}", fontsize=20)
                        a.axis('off')
                if save_dir is not None:
                    plt.savefig(f"{img_output_dir}/{prompt.replace(' ', '_')[:100]}_seed{single_seed}_start{start_noise}.png")
                if show:
                    plt.show()
                plt.close()
        flush()
    if save_dir is not None:
        with open(os.path.join(save_dir, "infer_imgs.pickle"), 'wb') as f:
            pickle.dump(pred_images, f)
        with open(os.path.join(save_dir, "all_seeds.pickle"), 'wb') as f:
            to_save = {"all_seeds": all_seeds, "batch_size": batch_size}
            pickle.dump(to_save, f)
        for scale, images in pred_images.items():
            subfolder = os.path.join(save_dir, "images", f"{scale}")
            os.makedirs(subfolder, exist_ok=True)
            used_prompt = ori_prompts
            if isinstance(scale, float) or isinstance(scale, int):
                used_prompt = prompts
            for i, image in enumerate(images):
                if scale == "Real Image":
                    suffix = ""
                else:
                    suffix = f"_seed{all_seeds[scale][i]}"
                image.save(os.path.join(subfolder, f"{used_prompt[i].replace(' ', '_')[:100]}{suffix}.jpg"))
        with open(os.path.join(save_dir, "infer_prompts.txt"), 'w') as f:
            for prompt in prompts:
                f.write(f"{prompt}\n")
        with open(os.path.join(save_dir, "ori_prompts.txt"), 'w') as f:
            for prompt in ori_prompts:
                f.write(f"{prompt}\n")
        print(f"Saved inference results to {save_dir}", flush=True)
    return pred_images, prompts
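
# Note: inference() returns (pred_images, prompts), where pred_images maps each Art Adapter
# scale (plus a "Real Image" entry when converting real images) to a list of PIL images that
# is index-aligned with the returned prompts, e.g. pred_images[1.0][0] pairs with prompts[0]
# when scale 1.0 is requested.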


def infer_metric(ref_image_folder, pred_images, prompts, save_dir, start_noise=""):
    """Compute style/content scores for each scale against the reference paintings and save them as JSON."""
    prompts = [prompt.split(" in the style of ")[0] for prompt in prompts]
    scores = {}
    original_images = pred_images["Real Image"] if "Real Image" in pred_images else None
    metric = StyleContentMetric(ref_image_folder)
    for scale, images in pred_images.items():
        score = metric(images, original_images, prompts)
        scores[scale] = score
        print(f"Style transfer score at scale {scale}: {score}")
    scores["ref_path"] = ref_image_folder
    save_name = f"scores_start{start_noise}.json"
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, save_name), 'w') as f:
        json.dump(scores, f, indent=2)
    return scores
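
# The resulting scores_start<start_noise>.json maps each evaluated scale (plus "ref_path")
# to its StyleContentMetric output, so different Art Adapter scales can be compared directly.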


def parse_args():
    parser = argparse.ArgumentParser(description='Inference with LoRA')
    parser.add_argument('--lora_weights', type=str, default=["None"],
                        nargs='+', help='path to your model file')
    parser.add_argument('--prompts', type=str, default=[],
                        nargs='+', help='prompts to try')
    parser.add_argument("--prompt_file", type=str, default=None, help="path to the prompt file")
    parser.add_argument("--prompt_file_key", type=str, default="prompts", help="key to the prompt file")
    parser.add_argument('--resolution', type=int, default=512, help='resolution of the image')
    parser.add_argument('--seed', type=int, default=None, help='seed for the random number generator')
    parser.add_argument("--start_noise", type=int, default=800,
                        help="noise timestep above which the adapter scale is set to 0; negative values force --from_scratch")
    parser.add_argument("--from_scratch", default=False, action="store_true",
                        help="generate from pure noise instead of converting real images")
    parser.add_argument("--ref_image_folder", type=str, default=None, help="folder containing reference images")
    parser.add_argument("--show", action="store_true", help="show the image")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument("--scales", type=float, default=[0., 1.], nargs='+', help="scales to test")
    parser.add_argument("--train_method", type=str, default=None, help="train method")
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="rhfeiyang/art-free-diffusion-v1", help="Path to the pretrained model.")
    parser.add_argument("--unet_ckpt", default=None, type=str, help="Path to the unet checkpoint")
    parser.add_argument("--guidance_scale", type=float, default=5.0, help="guidance scale")
    parser.add_argument("--infer_mode", default="sks_art", help="inference mode")  # e.g. "style", "ori", "artist", "sks_art"
    parser.add_argument("--save_dir", type=str, default="inference_output", help="save directory")
    parser.add_argument("--num_workers", type=int, default=4, help="number of workers")
    parser.add_argument("--no_load", action="store_true", help="do not load pre-inferred results")
    parser.add_argument("--infer_prompts", type=str, default=None, nargs="+", help="prompts to infer")
    parser.add_argument("--infer_images", type=str, default=None, nargs="+", help="images to infer")
    parser.add_argument("--rank", type=int, default=1, help="rank of the lora")
    parser.add_argument("--val_set", type=str, default="laion_pop500", help="validation set")
    parser.add_argument("--folder_name", type=str, default=None, help="folder name")
    parser.add_argument("--scheduler_type", type=str, choices=["ddpm", "ddim", "pndm", "lms"], default="ddpm", help="scheduler type")
    parser.add_argument("--infer_steps", type=int, default=50, help="inference steps")
    parser.add_argument("--weight_dtype", type=str, default="fp32", choices=["fp32", "fp16", "bf16"], help="weight dtype")
    parser.add_argument("--custom_coco_cap", action="store_true", help="use custom coco caption")
    args = parser.parse_args()

    if args.infer_prompts is not None and len(args.infer_prompts) == 1 and os.path.isfile(args.infer_prompts[0]):
        if args.infer_prompts[0].endswith(".txt") and args.custom_coco_cap:
            args.infer_prompts = CustomCocoCaptions(custom_file=args.infer_prompts[0])
        elif args.infer_prompts[0].endswith(".txt"):
            with open(args.infer_prompts[0], 'r') as f:
                args.infer_prompts = f.readlines()
            args.infer_prompts = [prompt.strip() for prompt in args.infer_prompts]
        elif args.infer_prompts[0].endswith(".csv"):
            from custom_datasets.custom_caption import Caption_set
            caption_set = Caption_set(args.infer_prompts[0])
            args.infer_prompts = caption_set

    if args.infer_mode == "style":
        with open(os.path.join(args.ref_image_folder, "style_label.txt"), 'r') as f:
            args.style_label = f.readlines()[0].strip()
    elif args.infer_mode == "artist":
        with open(os.path.join(args.ref_image_folder, "style_label.txt"), 'r') as f:
            args.style_label = f.readlines()[0].strip()
        args.style_label = args.style_label.split(",")[0].strip()
    elif args.infer_mode == "ori":
        args.style_label = None
    else:
        args.style_label = args.infer_mode.replace("_", " ")

    if args.ref_image_folder is not None:
        args.ref_image_folder = os.path.join(args.ref_image_folder, "paintings")
    if args.start_noise < 0:
        args.from_scratch = True
    print(args.__dict__)
    return args
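
# Example invocation (illustrative; the script name, adapter path, and reference folder are placeholders):
#   python inference.py \
#       --lora_weights path/to/art_adapter_rank1_attn.pt \
#       --infer_mode sks_art --val_set laion_pop500 \
#       --scales 0. 1. --start_noise 800 \
#       --scheduler_type ddpm --infer_steps 50 \
#       --ref_image_folder path/to/ref_style_folder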


def main(args):
    lora_weights = args.lora_weights
    if len(lora_weights) == 1 and isinstance(lora_weights[0], str) and os.path.isdir(lora_weights[0]):
        lora_weights = glob.glob(os.path.join(lora_weights[0], "*.pt"))
        lora_weights = sorted(lora_weights, reverse=True)

    width = args.resolution
    height = args.resolution
    steps = args.infer_steps
    revision = None
    device = 'cuda'
    rank = args.rank
    if args.weight_dtype == "fp32":
        weight_dtype = torch.float32
    elif args.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif args.weight_dtype == "bf16":
        weight_dtype = torch.bfloat16
    else:
        raise ValueError(f"Unknown weight dtype: {args.weight_dtype}")

    modules = get_model(args.pretrained_model_name_or_path, unet_ckpt=args.unet_ckpt, revision=revision, variant=None,
                        lora_path=None, weight_dtype=weight_dtype, device=device)

    if args.scheduler_type == "pndm":
        noise_scheduler = PNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
    elif args.scheduler_type == "ddpm":
        noise_scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
    elif args.scheduler_type == "ddim":
        noise_scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            clip_sample=False,
            prediction_type="epsilon",
        )
    elif args.scheduler_type == "lms":
        noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085,
                                               beta_end=0.012,
                                               beta_schedule="scaled_linear",
                                               num_train_timesteps=1000)
    else:
        raise ValueError("Unknown scheduler type")

    cache = EasyDict()
    cache.modules = modules
    unet = modules["unet"]
    vae = modules["vae"]
    text_encoder = modules["text_encoder"]
    tokenizer = modules["tokenizer"]
    # freeze all base-model weights; only the LoRA network is switched at inference time
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    ## dataloader
    dataloader = get_validation_dataloader(infer_prompts=args.infer_prompts, infer_images=args.infer_images,
                                           resolution=args.resolution,
                                           batch_size=args.batch_size, num_workers=args.num_workers,
                                           val_set=args.val_set)
    for lora_weight in lora_weights:
        print(f"Testing {lora_weight}")
        # for different seeds on same prompt
        seed = args.seed
        network_ret = get_lora_network(unet, lora_weight, train_method=args.train_method, rank=rank, alpha=1.0,
                                       device=device, weight_dtype=weight_dtype)
        network = network_ret["network"]
        train_method = network_ret["train_method"]

        if args.save_dir is not None:
            save_dir = args.save_dir
            if args.style_label is not None:
                save_dir = os.path.join(save_dir, f"{args.style_label.replace(' ', '_')}")
            else:
                save_dir = os.path.join(save_dir, f"ori/{args.start_noise}")
        else:
            if args.folder_name is not None:
                folder_name = args.folder_name
            else:
                folder_name = "validation" if args.infer_prompts is None else "validation_prompts"
            save_dir = os.path.join(os.path.dirname(lora_weight), f"{folder_name}/{train_method}",
                                    os.path.basename(lora_weight).replace('.pt', '').split('_')[-1])
            if args.infer_prompts is None:
                save_dir = os.path.join(save_dir, f"{args.val_set}")

        infer_config = f"{args.scheduler_type}{args.infer_steps}_{args.weight_dtype}_guidance{args.guidance_scale}"
        save_dir = os.path.join(save_dir, infer_config)
        os.makedirs(save_dir, exist_ok=True)
        if args.from_scratch:
            save_dir = os.path.join(save_dir, "from_scratch")
        else:
            save_dir = os.path.join(save_dir, "transfer")
        save_dir = os.path.join(save_dir, f"start{args.start_noise}")
        os.makedirs(save_dir, exist_ok=True)

        with open(os.path.join(save_dir, "infer_args.yaml"), 'w') as f:
            yaml.dump(vars(args), f)
        # save code
        code_dir = os.path.join(save_dir, "code")
        os.makedirs(code_dir, exist_ok=True)
        current_file = os.path.basename(__file__)
        shutil.copy(__file__, os.path.join(code_dir, current_file))

        with torch.no_grad():
            pred_images, prompts = inference(network, tokenizer, text_encoder, vae, unet, noise_scheduler, dataloader,
                                             height, width, args.scales, save_dir, seed, weight_dtype, device,
                                             args.batch_size, steps, guidance_scale=args.guidance_scale,
                                             start_noise=args.start_noise, show=args.show, style_prompt=args.style_label,
                                             no_load=args.no_load, from_scratch=args.from_scratch)
        if args.ref_image_folder is not None:
            flush()
            print("Calculating metrics")
            infer_metric(args.ref_image_folder, pred_images, prompts, save_dir, args.start_noise)


if __name__ == "__main__":
    args = parse_args()
    main(args)