import math
import os
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
from scipy.linalg import sqrtm

# I3D feature extraction adapted from https://github.com/universome/fvd-comparison
def load_i3d_pretrained(device=torch.device('cpu')):
    # Load the TorchScript I3D network used as the FVD feature extractor,
    # downloading the checkpoint next to this file on first use.
    i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt"
    filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt')
    print(filepath)
    if not os.path.exists(filepath):
        print(f"Downloading {i3D_WEIGHTS_URL}; you can also download it manually.")
        os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
    i3d = torch.jit.load(filepath).eval().to(device)
    i3d = torch.nn.DataParallel(i3d)  # lets large batches spread across multiple GPUs; a no-op without CUDA
    return i3d
def get_feats(videos, detector, device, bs=10):
    # videos: torch.Tensor, shape [B, C, T, H, W], values in [0, 1]
    detector_kwargs = dict(rescale=False, resize=False, return_features=True)  # raw features before the softmax layer
    feats = np.empty((0, 400))  # the I3D features are 400-dimensional
    with torch.no_grad():
        for i in range((len(videos) - 1) // bs + 1):  # ceil(len(videos) / bs) batches
            batch = torch.stack([preprocess_single(video) for video in videos[i * bs:(i + 1) * bs]]).to(device)
            feats = np.vstack([feats, detector(batch, **detector_kwargs).detach().cpu().numpy()])
    return feats
def get_fvd_feats(videos, i3d, device, bs=10):
    # videos: torch.Tensor, shape [B, C, T, H, W], values in [0, 1];
    # each video is preprocessed individually inside get_feats
    embeddings = get_feats(videos, i3d, device, bs)
    return embeddings
def preprocess_single(video, resolution=224, sequence_length=None):
    # video: torch.Tensor, shape [C, T, H, W], values in [0, 1]
    c, t, h, w = video.shape

    # temporal crop
    if sequence_length is not None:
        assert sequence_length <= t
        video = video[:, :sequence_length]

    # scale the shorter spatial side to `resolution`
    scale = resolution / min(h, w)
    if h < w:
        target_size = (resolution, math.ceil(w * scale))
    else:
        target_size = (math.ceil(h * scale), resolution)
    # bilinear interpolation on a 4D CTHW tensor treats C as the batch dim
    # and T as channels, so only H and W are resized
    video = F.interpolate(video, size=target_size, mode='bilinear', align_corners=False)

    # center crop to resolution x resolution
    c, t, h, w = video.shape
    w_start = (w - resolution) // 2
    h_start = (h - resolution) // 2
    video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]

    # rescale [0, 1] -> [-1, 1], the input range the I3D checkpoint expects
    video = (video - 0.5) * 2

    return video.contiguous()
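
# Shape walkthrough for preprocess_single with illustrative (not prescribed)
# input dimensions:
#   input   [3, 16, 240, 320], values in [0, 1]
#   resize  [3, 16, 224, 299]   (shorter side 240 scaled to 224, ceil(320 * 224/240) = 299)
#   crop    [3, 16, 224, 224]   (center crop)
#   range   [-1, 1]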
| """ | |
| Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py | |
| """ | |
| from typing import Tuple | |
| from scipy.linalg import sqrtm | |
| import numpy as np | |
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    mu = feats.mean(axis=0)  # [d]
    sigma = np.cov(feats, rowvar=False)  # [d, d]
    return mu, sigma
def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
    # Frechet distance between Gaussians fitted to the two feature sets:
    # ||mu_gen - mu_real||^2 + Tr(sigma_gen + sigma_real - 2 * sqrt(sigma_gen @ sigma_real))
    mu_gen, sigma_gen = compute_stats(feats_fake)
    mu_real, sigma_real = compute_stats(feats_real)
    m = np.square(mu_gen - mu_real).sum()
    if feats_fake.shape[0] > 1:
        s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False)  # pylint: disable=no-member
        fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
    else:
        # with a single sample the covariance is undefined; keep only the mean term
        fid = np.real(m)
    return float(fid)
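

# ---------------------------------------------------------------------------
# Minimal end-to-end usage sketch (illustrative, not part of the upstream
# repos): random tensors stand in for real and generated videos; in practice
# you would load decoded clips as float tensors in [0, 1]. The batch size,
# frame count, and resolution below are assumptions, not requirements.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    i3d = load_i3d_pretrained(device)

    # 8 fake and 8 real "videos", shape [B, C, T, H, W], values in [0, 1];
    # I3D needs some temporal extent, and 16 frames is a common choice.
    videos_fake = torch.rand(8, 3, 16, 224, 224)
    videos_real = torch.rand(8, 3, 16, 224, 224)

    feats_fake = get_fvd_feats(videos_fake, i3d, device)
    feats_real = get_fvd_feats(videos_real, i3d, device)
    print("FVD:", frechet_distance(feats_fake, feats_real))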