"""
DDP evaluation of VAE reconstruction quality: encodes and decodes images with a
Stable Diffusion VAE, saves the reconstructions as .png files, reports average
PSNR/SSIM, and packs the samples into a single .npz for FID evaluation.
"""
import torch
torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 matmuls on Ampere+ GPUs
torch.backends.cudnn.allow_tf32 = True
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
from tqdm import tqdm
import os
import itertools
from PIL import Image
import numpy as np
import argparse
import random
from skimage.metrics import peak_signal_noise_ratio as psnr_loss
from skimage.metrics import structural_similarity as ssim_loss
from diffusers.models import AutoencoderKL
class SingleFolderDataset(Dataset):
    """A flat directory of images with no class subfolders (e.g. COCO val images).

    Returns a dummy label so it matches ImageFolder's (image, label) interface.
    """
    def __init__(self, directory, transform=None):
        super().__init__()
        self.directory = directory
        self.transform = transform
        # Sort for a deterministic order across filesystems (os.listdir order is arbitrary):
        self.image_paths = sorted(os.path.join(directory, file_name) for file_name in os.listdir(directory)
                                  if os.path.isfile(os.path.join(directory, file_name)))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(0)  # dummy label
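# A minimal usage sketch (hypothetical path), assuming a flat folder of images:
#   ds = SingleFolderDataset("/path/to/coco/val2017", transform=transform)
#   image, _ = ds[0]  # label is always 0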
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
        sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
    random.shuffle(samples)  # Shuffling is very important for the Inception Score (IS)!
    samples = np.stack(samples)
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
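# A minimal usage sketch (hypothetical path): given sequentially numbered samples
# reconstructions/run/000000.png ... 049999.png,
#   npz_path = create_npz_from_sample_folder("reconstructions/run", num=50_000)
# writes reconstructions/run.npz with the stacked samples under the key `arr_0`,
# the layout read by ADM's guided-diffusion evaluation suite for FID/IS.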
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
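# A worked example, assuming a hypothetical 1280x960 input and image_size=256:
# the BOX loop halves once (min side 960 >= 512) to 640x480, the bicubic resize
# scales by 256/480 to 341x256, and the final center crop keeps the middle
# 256x256 window (crop_x = 42, crop_y = 0).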
def main(args):
    # Setup PyTorch:
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU; use sample.py for CPU-only usage."
    torch.set_grad_enabled(False)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Load the VAE:
    vae = AutoencoderKL.from_pretrained(f"stabilityai/{args.vae}").to(device)

    # Create folder to save samples:
    folder_name = f"stabilityai-{args.vae}-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)  # map to [-1, 1]
    ])
    if args.dataset == 'imagenet':
        dataset = ImageFolder(args.data_path, transform=transform)
        num_fid_samples = 50000
    elif args.dataset == 'coco':
        dataset = SingleFolderDataset(args.data_path, transform=transform)
        num_fid_samples = 5000
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=False,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=args.per_proc_batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False
    )
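    # Note: with drop_last=False, DistributedSampler pads the index list with
    # repeated samples so every rank sees the same count; when the dataset size
    # is not divisible by the world size, a few reconstructions are duplicated
    # and the averaged PSNR/SSIM is very slightly biased.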
    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()

    psnr_val_rgb = []
    ssim_val_rgb = []
    loader = tqdm(loader) if rank == 0 else loader
    total = 0
    for x, _ in loader:
        rgb_gts = x
        rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0  # rgb_gt values lie in [0, 1]
        x = x.to(device)
        with torch.no_grad():
            # Map input images to latent space and normalize latents with the
            # SD VAE scaling factor 0.18215:
            latent = vae.encode(x).latent_dist.sample().mul_(0.18215)
            # Reconstruct:
            samples = vae.decode(latent / 0.18215).sample  # output values lie in [-1, 1]
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files:
        for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)):
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
            # Metrics (both images lie in [0, 1], hence data_range=1.0; the
            # deprecated multichannel flag is subsumed by channel_axis):
            rgb_restored = sample.astype(np.float32) / 255.
            psnr = psnr_loss(rgb_restored, rgb_gt, data_range=1.0)
            ssim = ssim_loss(rgb_restored, rgb_gt, data_range=1.0, channel_axis=-1)
            psnr_val_rgb.append(psnr)
            ssim_val_rgb.append(ssim)
        total += global_batch_size
    # ------------------------------------
    # Summary
    # ------------------------------------
    # Make sure all processes have finished saving their samples:
    dist.barrier()
    world_size = dist.get_world_size()
    gather_psnr_val = [None for _ in range(world_size)]
    gather_ssim_val = [None for _ in range(world_size)]
    dist.all_gather_object(gather_psnr_val, psnr_val_rgb)
    dist.all_gather_object(gather_ssim_val, ssim_val_rgb)

    if rank == 0:
        gather_psnr_val = list(itertools.chain(*gather_psnr_val))
        gather_ssim_val = list(itertools.chain(*gather_ssim_val))
        psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val)
        ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val)
        print("PSNR: %f, SSIM: %f" % (psnr_val_rgb, ssim_val_rgb))
        result_file = f"{sample_folder_dir}_results.txt"
        print("writing results to {}".format(result_file))
        with open(result_file, 'w') as f:
            print("PSNR: %f, SSIM: %f" % (psnr_val_rgb, ssim_val_rgb), file=f)
        create_npz_from_sample_folder(sample_folder_dir, num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet')
    parser.add_argument("--vae", type=str, choices=["sdxl-vae", "sd-vae-ft-mse"], default="sd-vae-ft-mse")
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--sample-dir", type=str, default="reconstructions")
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=4)
    args = parser.parse_args()
    main(args)
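# A minimal launch sketch (hypothetical script name and paths), assuming one node
# with 8 GPUs; torchrun supplies the env vars that init_process_group("nccl") reads:
#   torchrun --nproc_per_node=8 val_ddp.py \
#       --data-path /path/to/imagenet/val --dataset imagenet --image-size 256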