import os
import random
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image
from skimage.color import lab2rgb, rgb2lab
from torchmetrics.functional.image import (
    learned_perceptual_image_patch_similarity,
    peak_signal_noise_ratio,
)
from torchvision import transforms

RESAMPLE_MODE = Image.BICUBIC


def skip_iterator(iterator, skip):
    """Yield every ``skip``-th item of ``iterator``."""
    for i, item in enumerate(iterator):
        if i % skip == 0:
            yield item


def generate_output_structure(output_dir, subfolders=[]):
    """Create the output directory structure and return, for each subfolder,
    a filename template of the form ``<output_dir>/<subfolder>/{}.png``."""
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    output_paths = []
    for subfolder in subfolders:
        output_paths.append(Path(os.path.join(output_dir, subfolder)))
        output_paths[-1].mkdir(exist_ok=True, parents=True)
        output_paths[-1] = os.path.join(output_paths[-1], "{}.png")
    return output_paths


def find_files(path, ext="png"):
    """Return the sorted files with extension ``ext`` if ``path`` is a
    directory, otherwise ``[path]``."""
    if os.path.isdir(path):
        path = Path(path)
        sorted_files = sorted(path.glob(f"*.{ext}"))
        return sorted_files
    else:
        return [path]


def load_guidance_image(path, size=None):
    """Load an image and convert it to a tensor.

    Args:
        path: path to the image.
        size: target size for resize and center crop.

    Returns:
        Tensor of shape (1, 3, H, W) with values in [-1, 1].
    """
    img = Image.open(path)
    img = img.convert("RGB")
    tf = transforms.Compose(
        [
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
        ]
    )
    img = tf(img) * 2 - 1
    return img.unsqueeze(0)


def yield_images(path, ext="png", size=None):
    files = find_files(path, ext)
    for file in files:
        yield load_guidance_image(file, size)


def yield_videos(paths, ext="png", H=None, W=None, n_frames=61):
    for path in paths:
        yield read_video(path, H, W, n_frames)


def read_video(path, H=None, W=None, n_frames=61) -> tuple[torch.Tensor, tuple]:
    """Read a video file or a directory of PNG frames.

    Returns a tensor of shape (1, T, 3, H, W) with values in [-1, 1] together
    with (fps, H, W) metadata. Frame directories carry no fps information, so
    a rate of 10 fps is assumed for them.
    """
    path = Path(path)
    frames = []
    if path.is_dir():
        files = sorted(path.glob("*.png"))
        for file in files[:n_frames]:
            image = Image.open(file)
            image.load()
            if H is not None and W is not None:
                image = image.resize((W, H), resample=Image.BICUBIC)
            # to tensor in [-1, 1]
            image = (
                torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
                / 255.0
                * 2
                - 1
            )
            frames.append(image)
        H, W = frames[0].size()[-2:]
        frames = torch.stack(frames).unsqueeze(0)
        return frames, (10, H, W)

    capture = cv2.VideoCapture(str(path))
    fps = capture.get(cv2.CAP_PROP_FPS)
    while True:
        success, frame = capture.read()
        if not success:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = Image.fromarray(frame)
        if H is not None and W is not None:
            frame = frame.resize((W, H), resample=Image.BICUBIC)
        # to tensor in [-1, 1]
        frame = (
            torch.tensor(np.array(frame), dtype=torch.float32).permute(2, 0, 1)
            / 255.0
            * 2
            - 1
        )
        frames.append(frame)
    capture.release()
    # to torch
    H, W = frames[0].size()[-2:]
    frames = torch.stack(frames).unsqueeze(0)
    # return (fps, H, W) to match the image-directory branch above
    return frames, (fps, H, W)

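# Minimal usage sketch for read_video (illustrative only): the frame directory
# "example_frames/" and the 256x256 target size are assumptions, not paths or
# settings used by this project.
def _example_read_video():
    frames, (fps, height, width) = read_video(
        "example_frames/", H=256, W=256, n_frames=16
    )
    # frames: (1, T, 3, 256, 256) float tensor with values in [-1, 1]
    print(frames.shape, fps, height, width)
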
def resize_video(
    video: list[Image.Image], width, height, resample_mode=RESAMPLE_MODE
) -> list[Image.Image]:
    frames_lr = []
    for frame in video:
        frame_lr = frame.resize((width, height), resample_mode)
        frames_lr.append(frame_lr)
    return frames_lr


def export_to_video(
    video_frames,
    output_video_path=None,
    fps=8,
    put_numbers=False,
    annotations=None,
    fourcc="mp4v",
):
    fourcc = cv2.VideoWriter_fourcc(*fourcc)  # codec
    writer = cv2.VideoWriter(output_video_path, fourcc, fps, video_frames[0].size)
    for i, frame in enumerate(video_frames):
        frame = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
        if put_numbers:
            # draw the 1-based frame index in the top-right corner
            text_position = (frame.shape[1] - 60, 30)
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 1
            font_color = (255, 255, 255)
            line_type = 2
            cv2.putText(
                frame,
                f"{i + 1}",
                text_position,
                font,
                font_scale,
                font_color,
                line_type,
            )
        if annotations:
            annotation = annotations[i]
            # draw_bodypose is expected to be provided elsewhere in the
            # project; it is not defined or imported in this module.
            frame = draw_bodypose(
                frame, annotation["candidates"], annotation["subsets"]
            )
        writer.write(frame)
    writer.release()


def export_images(frames: list[Image.Image], dir_name):
    dir_name = Path(dir_name)
    dir_name.mkdir(exist_ok=True, parents=True)
    for i, frame in enumerate(frames):
        frame.save(dir_name / f"{i:05d}.png")


def vid2tensor(images: list[Image.Image]) -> torch.Tensor:
    """Convert a list of PIL images to a float tensor of shape (T, C, H, W)
    with values in [0, 1]."""
    if not isinstance(images, list):
        raise ValueError("Expected a list of PIL images")
    # PIL to numpy
    images = [np.array(image).astype(np.float32) / 255.0 for image in images]
    images = np.stack(images, axis=0)
    if images.ndim == 3:
        # L mode, add luminance channel
        images = np.expand_dims(images, -1)
    # numpy to torch
    images = torch.from_numpy(images.transpose(0, 3, 1, 2))
    return images


def compute_metrics(
    source: list[Image.Image],
    output: list[Image.Image],
    output_lq: list[Image.Image],
    target: list[Image.Image],
) -> dict:
    psnr_ab = torch.tensor(
        np.mean(compute_color_metrics(output, target)["psnr_ab"])
    ).float()

    source = vid2tensor(source)
    output = vid2tensor(output)
    output_lq = vid2tensor(output_lq)
    target = vid2tensor(target)

    mse = ((output - target) ** 2).mean()
    psnr = peak_signal_noise_ratio(output, target, data_range=1.0, dim=(1, 2, 3))
    # lpips = learned_perceptual_image_patch_similarity(output, target)

    mse_source = ((output_lq - source) ** 2).mean()
    psnr_source = peak_signal_noise_ratio(
        output_lq, source, data_range=1.0, dim=(1, 2, 3)
    )

    return {
        "mse": mse.detach().cpu().item(),
        "psnr": psnr.detach().cpu().item(),
        "psnr_ab": psnr_ab.detach().cpu().item(),
        "mse_source": mse_source.detach().cpu().item(),
        "psnr_source": psnr_source.detach().cpu().item(),
    }

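# Minimal usage sketch for export_to_video (illustrative only): the output path
# "example_out.mp4" and the 16-frame synthetic brightness ramp are assumptions.
def _example_export_video():
    frames = [
        Image.fromarray(np.full((64, 64, 3), i * 16, dtype=np.uint8))
        for i in range(16)
    ]
    export_to_video(frames, "example_out.mp4", fps=8)
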
def compute_psnr_ab(x, y_gt, pp_max=202.33542248):
    r"""Compute the PSNR of the ab color channels.

    Note that the CIE-Lab space is asymmetric; the peak-to-peak range of the
    two ab channels is approximately 202.3354.

    Args:
        pp_max: approximated maximum swing of the ab channels of the CIE-Lab
            color space,
            max_{x \in CIE-Lab} (x_a, x_b) - min_{x \in CIE-Lab} (x_a, x_b).
    """
    assert (
        len(x.shape) == 3
    ), f"Expected data of shape HxWx2 (ab channels of CIE-Lab) but found {x.shape}"
    assert (
        len(y_gt.shape) == 3
    ), f"Expected data of shape HxWx2 (ab channels of CIE-Lab) but found {y_gt.shape}"
    assert (
        x.shape == y_gt.shape
    ), f"Expected identical shapes but found {y_gt.shape} != {x.shape}"
    H, W, C = x.shape
    assert C == 2, "Both x and y_gt must be the ab channels of the CIE-Lab space"

    MSE = np.sum((x - y_gt) ** 2) / (H * W * C)  # C=2, two channels
    MSE = np.clip(MSE, 1e-12, np.inf)
    PSNR_ab = 10 * np.log10(pp_max**2) - 10 * np.log10(MSE)
    return PSNR_ab


def compute_color_metrics(out: list[Image.Image], target: list[Image.Image]):
    if len(out) != len(target):
        raise ValueError("Videos do not have the same length")
    metrics = {"psnr_ab": []}
    for out_frame, target_frame in zip(out, target):
        out_frame, target_frame = np.asarray(out_frame), np.asarray(target_frame)
        out_frame_lab, target_frame_lab = rgb2lab(out_frame), rgb2lab(target_frame)
        psnr_ab = compute_psnr_ab(out_frame_lab[..., 1:3], target_frame_lab[..., 1:3])
        metrics["psnr_ab"].append(psnr_ab.item())
    return metrics


def to_device(sample, device):
    result = {}
    for key, val in sample.items():
        if isinstance(val, torch.Tensor):
            result[key] = val.to(device)
        elif isinstance(val, list):
            new_val = []
            for e in val:
                if isinstance(e, torch.Tensor):
                    new_val.append(e.to(device))
                else:
                    new_val.append(e)
            result[key] = new_val
        else:
            result[key] = val
    return result


def seed_all(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def adapt_unet(unet, lora_rank=None, in_conv_mode="zeros"):
    # adapt conv_in: append four extra input channels to the first convolution
    kernel = unet.conv_in.weight.data
    if in_conv_mode == "zeros":
        new_kernel = torch.zeros(
            320, 4, 3, 3, dtype=kernel.dtype, device=kernel.device
        )
    elif in_conv_mode == "reflect":
        new_kernel = kernel[:, 4:].clone()
    else:
        raise NotImplementedError
    unet.conv_in.weight.data = torch.cat([kernel, new_kernel], dim=1)
    if in_conv_mode == "reflect":
        unet.conv_in.weight.data *= 2.0 / 3.0
    unet.conv_in.in_channels = 12

    if lora_rank is not None:
        from peft import LoraConfig

        types = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Linear, torch.nn.Embedding)
        target_modules = [
            (n, m) for n, m in unet.named_modules() if isinstance(m, types)
        ]
        # identify parameters (not modules) that will not be lora'd
        for _, m in target_modules:
            m.requires_grad_(False)
        not_adapted = [p for p in unet.parameters() if p.requires_grad]

        unet_lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_rank,
            init_lora_weights="gaussian",
            # target_modules=["to_k", "to_q", "to_v", "to_out.0"],
            target_modules=[n for n, _ in target_modules],
        )
        # the following line sets all parameters except the LoRAs to non-trainable
        unet.add_adapter(unet_lora_config)
        unet.conv_in.requires_grad_()
        for p in not_adapted:
            p.requires_grad_()


def repeat_infinite(iterable):
    """Return a zero-argument callable whose generator cycles over
    ``iterable`` forever."""

    def repeated():
        while True:
            yield from iterable

    return repeated

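# Minimal usage sketch for compute_color_metrics / compute_psnr_ab
# (illustrative only): two random 32x32 RGB frames stand in for an output
# frame and its ground truth.
def _example_color_metrics():
    rng = np.random.default_rng(0)
    out = [Image.fromarray(rng.integers(0, 256, (32, 32, 3), dtype=np.uint8))]
    target = [Image.fromarray(rng.integers(0, 256, (32, 32, 3), dtype=np.uint8))]
    return compute_color_metrics(out, target)  # {"psnr_ab": [...]}
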
class CPUAdam:
    """Adam optimizer that keeps its moment estimates in CPU memory to save
    VRAM; gradients are copied to the CPU and the update is pushed back to
    the (CUDA) parameters."""

    def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        self.params = list(params)
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        # keep this in main memory to save VRAM
        self.state = {
            param: {
                "step": 0,
                "exp_avg": torch.zeros_like(param, device="cpu"),
                "exp_avg_sq": torch.zeros_like(param, device="cpu"),
            }
            for param in self.params
        }

    def step(self):
        for param in self.params:
            if param.grad is None:
                continue
            grad = param.grad.data.cpu()
            if self.weight_decay != 0:
                grad.add_(param.data.cpu(), alpha=self.weight_decay)

            state = self.state[param]
            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
            beta1, beta2 = self.betas
            state["step"] += 1

            # update biased first and second moment estimates
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            denom = exp_avg_sq.sqrt().add_(self.eps)
            # bias-corrected step size
            step_size = (
                self.lr
                * (1 - beta2 ** state["step"]) ** 0.5
                / (1 - beta1 ** state["step"])
            )
            # param.data.add_((-step_size * (exp_avg / denom)).cuda())
            param.data.addcdiv_(exp_avg.cuda(), denom.cuda(), value=-step_size)

    def zero_grad(self):
        for param in self.params:
            param.grad = None

    def state_dict(self):
        return self.state

    def set_lr(self, lr):
        self.lr = lr

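# Minimal usage sketch for CPUAdam (illustrative only): a toy linear layer on a
# CUDA device, since CPUAdam.step pushes the update back with .cuda().
def _example_cpu_adam_step():
    layer = torch.nn.Linear(4, 4).cuda()
    optimizer = CPUAdam(layer.parameters(), lr=1e-3)
    loss = layer(torch.randn(2, 4, device="cuda")).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()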