# FLAIR: src/flair/utils.py
import random
from pathlib import Path
import os
import torch
from torchmetrics.functional.image import (
peak_signal_noise_ratio,
learned_perceptual_image_patch_similarity,
)
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
import numpy as np
import cv2
from torchvision import transforms
RESAMPLE_MODE = Image.BICUBIC
def skip_iterator(iterator, skip):
for i, item in enumerate(iterator):
if i % skip == 0:
yield item
def generate_output_structure(output_dir, subfolders=[]):
"""
Generate a directory structure for the output. and return the paths to the subfolders. as template
"""
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
output_paths = []
    for subfolder in subfolders:
        subfolder_path = output_dir / subfolder
        subfolder_path.mkdir(exist_ok=True, parents=True)
        output_paths.append(os.path.join(subfolder_path, "{}.png"))
return output_paths
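# Illustrative usage (directory names are made up):
#   color_tmpl, gray_tmpl = generate_output_structure("results", ["color", "gray"])
#   color_tmpl -> "results/color/{}.png", filled in with e.g. color_tmpl.format("00000")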
def find_files(path, ext="png"):
if os.path.isdir(path):
path = Path(path)
sorted_files = sorted(list(path.glob(f"*.{ext}")))
return sorted_files
else:
return [path]
def load_guidance_image(path, size=None):
    """
    Load an image and convert it to a normalized tensor.

    Args:
        path: path to the image file.
        size: optional target size; when given, the image is resized and center-cropped.

    Returns:
        Tensor of shape (1, 3, H, W) with values in [-1, 1].
    """
    img = Image.open(path)
    img = img.convert("RGB")
    ops = []
    if size is not None:
        ops.extend([transforms.Resize(size), transforms.CenterCrop(size)])
    ops.append(transforms.ToTensor())
    tf = transforms.Compose(ops)
img = tf(img) * 2 - 1
return img.unsqueeze(0)
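# Illustrative usage (the file name is made up):
#   guidance = load_guidance_image("guidance/frame_0000.png", size=512)
#   guidance.shape -> torch.Size([1, 3, 512, 512]), values in [-1, 1]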
def yield_images(path, ext="png", size=None):
files = find_files(path, ext)
for file in files:
yield load_guidance_image(file, size)
def yield_videos(paths, ext="png", H=None, W=None, n_frames=61):
for path in paths:
yield read_video(path, H, W, n_frames)
def read_video(path, H=None, W=None, n_frames=61) -> tuple[torch.Tensor, tuple]:
    path = Path(path)
    frames = []
    if path.is_dir():
files = sorted(list(path.glob("*.png")))
for file in files[:n_frames]:
image = Image.open(file)
image.load()
if H is not None and W is not None:
image = image.resize((W, H), resample=Image.BICUBIC)
# to tensor
image = (
torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
/ 255.0
* 2
- 1
)
frames.append(image)
        H, W = frames[0].size()[-2:]
        frames = torch.stack(frames).unsqueeze(0)
        # image sequences carry no timing information, so fall back to a default of 10 fps
        return frames, (10, H, W)
capture = cv2.VideoCapture(str(path))
fps = capture.get(cv2.CAP_PROP_FPS)
while True:
success, frame = capture.read()
if not success:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
if H is not None and W is not None:
frame = frame.resize((W, H), resample=Image.BICUBIC)
# to tensor
frame = (
torch.tensor(np.array(frame), dtype=torch.float32).permute(2, 0, 1)
/ 255.0
* 2
- 1
)
frames.append(frame)
capture.release()
# to torch
frames = torch.stack(frames).unsqueeze(0)
return frames, (fps, W, H)
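# Illustrative usage (the clip path is made up):
#   frames, meta = read_video("clips/demo.mp4", H=256, W=448, n_frames=61)
#   frames.shape -> torch.Size([1, T, 3, 256, 448]), values in [-1, 1]
#   meta carries the fps and the frame size of the decoded clip.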
def resize_video(
    video: list[Image.Image], width, height, resample_mode=RESAMPLE_MODE
) -> list[Image.Image]:
frames_lr = []
for frame in video:
frame_lr = frame.resize((width, height), resample_mode)
frames_lr.append(frame_lr)
return frames_lr
def export_to_video(
video_frames,
output_video_path=None,
fps=8,
put_numbers=False,
annotations=None,
fourcc="mp4v",
):
fourcc = cv2.VideoWriter_fourcc(*fourcc) # codec
writer = cv2.VideoWriter(output_video_path, fourcc, fps, video_frames[0].size)
for i, frame in enumerate(video_frames):
frame = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
        if put_numbers:
            text_position = (frame.shape[1] - 60, 30)
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 1
            font_color = (255, 255, 255)
            thickness = 2
            cv2.putText(
                frame,
                f"{i + 1}",
                text_position,
                font,
                font_scale,
                font_color,
                thickness,
            )
        if annotations:
            annotation = annotations[i]
            # draw_bodypose is assumed to be provided by the project's pose-drawing utilities
            frame = draw_bodypose(
                frame, annotation["candidates"], annotation["subsets"]
            )
writer.write(frame)
writer.release()
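# Illustrative usage (frames and file name are made up):
#   pil_frames = [Image.new("RGB", (448, 256), (i, i, i)) for i in range(16)]
#   export_to_video(pil_frames, "demo.mp4", fps=8)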
def export_images(frames: list[Image.Image], dir_name):
dir_name = Path(dir_name)
dir_name.mkdir(exist_ok=True, parents=True)
for i, frame in enumerate(frames):
frame.save(dir_name / f"{i:05d}.png")
def vid2tensor(images: list[Image.Image]) -> torch.Tensor:
    # PIL to numpy
    if not isinstance(images, list):
        raise ValueError("vid2tensor expects a list of PIL images")
images = [np.array(image).astype(np.float32) / 255.0 for image in images]
images = np.stack(images, axis=0)
if images.ndim == 3:
# L mode, add luminance channel
images = np.expand_dims(images, -1)
# numpy to torch
images = torch.from_numpy(images.transpose(0, 3, 1, 2))
return images
def compute_metrics(
    source: list[Image.Image],
    output: list[Image.Image],
    output_lq: list[Image.Image],
    target: list[Image.Image],
) -> dict:
psnr_ab = torch.tensor(
np.mean(compute_color_metrics(output, target)["psnr_ab"])
).float()
source = vid2tensor(source)
output = vid2tensor(output)
output_lq = vid2tensor(output_lq)
target = vid2tensor(target)
mse = ((output - target) ** 2).mean()
psnr = peak_signal_noise_ratio(output, target, data_range=1.0, dim=(1, 2, 3))
# lpips = learned_perceptual_image_patch_similarity(output, target)
mse_source = ((output_lq - source) ** 2).mean()
psnr_source = peak_signal_noise_ratio(
output_lq, source, data_range=1.0, dim=(1, 2, 3)
)
return {
"mse": mse.detach().cpu().item(),
"psnr": psnr.detach().cpu().item(),
"psnr_ab": psnr_ab.detach().cpu().item(),
"mse_source": mse_source.detach().cpu().item(),
"psnr_source": psnr_source.detach().cpu().item(),
}
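# Illustrative usage (frame lists are made up; all four lists hold PIL images of the
# same size and length):
#   scores = compute_metrics(source=src_frames, output=out_frames,
#                            output_lq=out_lq_frames, target=gt_frames)
#   scores -> {"mse", "psnr", "psnr_ab", "mse_source", "psnr_source"}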
def compute_psnr_ab(x, y_gt, pp_max=202.33542248):
    r"""Computes the PSNR of the ab color channels.

    Note that the CIE-Lab space is asymmetric: the maximum swing of the two ab
    channels is approximately 202.3354.

    pp_max: approximate maximum swing of the ab channels of the CIE-Lab color space,
        max_{x \in CIE-Lab} (x_a, x_b) - min_{x \in CIE-Lab} (x_a, x_b)
    """
assert (
len(x.shape) == 3
), f"Expecting data of the size HW2 but found {x.shape}; This should be a,b channels of CIE-Lab Space"
assert (
len(y_gt.shape) == 3
), f"Expecting data of the size HW2 but found {y_gt.shape}; This should be a,b channels of CIE-Lab Space"
assert (
x.shape == y_gt.shape
), f"Expecting data to have identical shape but found {y_gt.shape} != {x.shape}"
H, W, C = x.shape
    assert (
        C == 2
    ), "This function assumes that both x & y are the ab channels of the CIE-Lab space"
MSE = np.sum((x - y_gt) ** 2) / (H * W * C) # C=2, two channels
MSE = np.clip(MSE, 1e-12, np.inf)
PSNR_ab = 10 * np.log10(pp_max**2) - 10 * np.log10(MSE)
return PSNR_ab
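# Illustrative usage (file names are made up; the inputs are the a,b channels of CIE-Lab):
#   pred_lab = rgb2lab(np.asarray(Image.open("pred.png").convert("RGB")))
#   gt_lab = rgb2lab(np.asarray(Image.open("gt.png").convert("RGB")))
#   psnr_ab = compute_psnr_ab(pred_lab[..., 1:3], gt_lab[..., 1:3])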
def compute_color_metrics(out: list[Image.Image], target: list[Image.Image]):
if len(out) != len(target):
raise ValueError("Videos do not have same length")
metrics = {"psnr_ab": []}
for out_frame, target_frame in zip(out, target):
out_frame, target_frame = np.asarray(out_frame), np.asarray(target_frame)
out_frame_lab, target_frame_lab = rgb2lab(out_frame), rgb2lab(target_frame)
psnr_ab = compute_psnr_ab(out_frame_lab[..., 1:3], target_frame_lab[..., 1:3])
metrics["psnr_ab"].append(psnr_ab.item())
return metrics
def to_device(sample, device):
result = {}
for key, val in sample.items():
if isinstance(val, torch.Tensor):
result[key] = val.to(device)
elif isinstance(val, list):
new_val = []
for e in val:
if isinstance(e, torch.Tensor):
new_val.append(e.to(device))
                else:
                    new_val.append(e)
result[key] = new_val
else:
result[key] = val
return result
def seed_all(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def adapt_unet(unet, lora_rank=None, in_conv_mode="zeros"):
# adapt conv_in
    kernel = unet.conv_in.weight.data
    if in_conv_mode == "zeros":
        # new input channels start with zero weights, so the adapted conv_in initially
        # ignores the concatenated conditioning
        new_kernel = torch.zeros(320, 4, 3, 3, dtype=kernel.dtype, device=kernel.device)
    elif in_conv_mode == "reflect":
        # reuse the weights of the last four pretrained input channels for the new ones
        new_kernel = kernel[:, 4:].clone()
    else:
        raise NotImplementedError
    unet.conv_in.weight.data = torch.cat([kernel, new_kernel], dim=1)
    if in_conv_mode == "reflect":
        # rescale to compensate for the duplicated input channels
        unet.conv_in.weight.data *= 2.0 / 3.0
unet.conv_in.in_channels = 12
if lora_rank is not None:
from peft import LoraConfig
types = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Linear, torch.nn.Embedding)
target_modules = [
(n, m) for n, m in unet.named_modules() if isinstance(m, types)
]
# identify parameters (not modules) that will not be lora'd
for _, m in target_modules:
m.requires_grad_(False)
not_adapted = [p for p in unet.parameters() if p.requires_grad]
unet_lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_rank,
init_lora_weights="gaussian",
# target_modules=["to_k", "to_q", "to_v", "to_out.0"],
target_modules=[n for n, _ in target_modules],
)
# the following line sets all parameters except the loras to non-trainable
unet.add_adapter(unet_lora_config)
unet.conv_in.requires_grad_()
for p in not_adapted:
p.requires_grad_()
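# Illustrative usage (assumes a diffusers UNet whose conv_in takes 8 input channels and
# has 320 output channels, e.g. a Stable Video Diffusion UNet, so the adapted conv_in
# ends up with 12 input channels):
#   adapt_unet(unet, lora_rank=16, in_conv_mode="zeros")
#   trainable = [n for n, p in unet.named_parameters() if p.requires_grad]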
def repeat_infinite(iterable):
    # returns a zero-argument callable that yields the iterable's items endlessly
    def repeated():
        while True:
            yield from iterable
    return repeated
class CPUAdam:
def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
self.params = list(params)
self.lr = lr
self.betas = betas
self.eps = eps
self.weight_decay = weight_decay
# keep this in main memory to save VRAM
self.state = {
param: {
"step": 0,
"exp_avg": torch.zeros_like(param, device="cpu"),
"exp_avg_sq": torch.zeros_like(param, device="cpu"),
}
for param in self.params
}
def step(self):
for param in self.params:
if param.grad is None:
continue
grad = param.grad.data.cpu()
if self.weight_decay != 0:
grad.add_(param.data, alpha=self.weight_decay)
state = self.state[param]
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = self.betas
state["step"] += 1
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = exp_avg_sq.sqrt().add_(self.eps)
step_size = (
self.lr
* (1 - beta2 ** state["step"]) ** 0.5
/ (1 - beta1 ** state["step"])
)
            # optimizer state stays on the CPU; parameters are assumed to live on a CUDA device
            param.data.addcdiv_(exp_avg.cuda(), denom.cuda(), value=-step_size)
def zero_grad(self):
for param in self.params:
param.grad = None
def state_dict(self):
return self.state
def set_lr(self, lr):
self.lr = lr
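# Illustrative usage (assumes the optimized parameters live on a CUDA device, since
# step() moves the CPU state back with .cuda(); names are made up):
#   opt = CPUAdam((p for p in unet.parameters() if p.requires_grad), lr=1e-4)
#   loss.backward()
#   opt.step()
#   opt.zero_grad()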