Spaces:
Sleeping
Sleeping
import math

import torch
import torch.fft as fft
def get_longpath(BOX_SIZE_H=0.3, BOX_SIZE_W=0.3, input_mode=4):
    """Build a long (64-keyframe-span) box trajectory from a preset layout.

    Each keyframe is ``[frame, h_start, h_end, w_start, w_end]`` with the box
    edges expressed as fractions of the frame size (0-1); ``plan_path``
    linearly interpolates between consecutive keyframes.

    Args:
        BOX_SIZE_H: box height as a fraction of the frame height.
        BOX_SIZE_W: box width as a fraction of the frame width.
        input_mode: preset trajectory to use (1-4).

    Returns:
        List of per-frame ``[h_start, h_end, w_start, w_end]`` boxes.

    Raises:
        ValueError: if ``input_mode`` is not one of 1-4.
    """
    if input_mode == 1:
        # mode 1
        inputs = [[0, 0, 0 + BOX_SIZE_H, 0, 0 + BOX_SIZE_W],
                  [7, 1-BOX_SIZE_H, 1, (1-BOX_SIZE_W) / 15 * 7, (1-BOX_SIZE_W) / 15 * 7 + BOX_SIZE_W],
                  [8, 1-BOX_SIZE_H, 1, (1-BOX_SIZE_W) / 15 * 8, (1-BOX_SIZE_W) / 15 * 8 + BOX_SIZE_W],
                  [15, 0, 0 + BOX_SIZE_H, 1-BOX_SIZE_W, 1],
                  [16, 0.1, 0.1 + BOX_SIZE_H, 0.9-BOX_SIZE_W, 0.9],
                  [25, 0.1, 0.1 + BOX_SIZE_H, 0.1, 0.1 + BOX_SIZE_W],
                  [31, 0.9-BOX_SIZE_H, 0.9, 0.1, 0.1 + BOX_SIZE_W],
                  [32, 1-BOX_SIZE_H, 1, 0, 0 + BOX_SIZE_W],
                  [39, 0, 0 + BOX_SIZE_H, (1-BOX_SIZE_W) / 15 * 7, (1-BOX_SIZE_W) / 15 * 7 + BOX_SIZE_W],
                  [40, 0, 0 + BOX_SIZE_H, (1-BOX_SIZE_W) / 15 * 8, (1-BOX_SIZE_W) / 15 * 8 + BOX_SIZE_W],
                  [47, 1-BOX_SIZE_H, 1, 1-BOX_SIZE_W, 1],
                  [48, 0.9-BOX_SIZE_H, 0.9, 0.9-BOX_SIZE_W, 0.9],
                  [57, 0.9-BOX_SIZE_H, 0.9, 0.1, 0.1 + BOX_SIZE_W],
                  [63, 0.1, 0.1 + BOX_SIZE_H, 0.1, 0.1 + BOX_SIZE_W]]
    elif input_mode == 2:
        # mode 2
        inputs = [[0, 0.1, 0.1 + BOX_SIZE_H, 0.1, 0.1 + BOX_SIZE_W],
                  [6, 0.9-BOX_SIZE_H, 0.9, 0.1, 0.1 + BOX_SIZE_W],
                  [15, 0.9-BOX_SIZE_H, 0.9, 0.9-BOX_SIZE_W, 0.9],
                  [16, 0.9-BOX_SIZE_H, 0.9, 0.9-BOX_SIZE_W, 0.9],
                  [22, 0.1, 0.1 + BOX_SIZE_H, 0.9-BOX_SIZE_W, 0.9],
                  [31, 0.1, 0.1 + BOX_SIZE_H, 0.1, 0.1 + BOX_SIZE_W],
                  [32, 0.1, 0.1 + BOX_SIZE_H, 0.1, 0.1 + BOX_SIZE_W],
                  [41, 0.1, 0.1 + BOX_SIZE_H, 0.9-BOX_SIZE_W, 0.9],
                  [47, 0.9-BOX_SIZE_H, 0.9, 0.9-BOX_SIZE_W, 0.9],
                  [48, 0.9-BOX_SIZE_H, 0.9, 0.9-BOX_SIZE_W, 0.9],
                  [57, 0.9-BOX_SIZE_H, 0.9, 0.1, 0.1 + BOX_SIZE_W],
                  [63, 0.1, 0.1 + BOX_SIZE_H, 0.1, 0.1 + BOX_SIZE_W]]
    elif input_mode == 3:
        # mode 3 ||||  (vertical zig-zag across the width)
        inputs = [[0, 0, 0 + BOX_SIZE_H, 0, 0 + BOX_SIZE_W],
                  [9, 1-BOX_SIZE_H, 1, (1-BOX_SIZE_W) / 7 * 1, (1-BOX_SIZE_W) / 7 * 1 + BOX_SIZE_W],
                  [18, 0, 0 + BOX_SIZE_H, (1-BOX_SIZE_W) / 7 * 2, (1-BOX_SIZE_W) / 7 * 2 + BOX_SIZE_W],
                  [27, 1-BOX_SIZE_H, 1, (1-BOX_SIZE_W) / 7 * 3, (1-BOX_SIZE_W) / 7 * 3 + BOX_SIZE_W],
                  [36, 0, 0 + BOX_SIZE_H, (1-BOX_SIZE_W) / 7 * 4, (1-BOX_SIZE_W) / 7 * 4 + BOX_SIZE_W],
                  [45, 1-BOX_SIZE_H, 1, (1-BOX_SIZE_W) / 7 * 5, (1-BOX_SIZE_W) / 7 * 5 + BOX_SIZE_W],
                  [54, 0, 0 + BOX_SIZE_H, (1-BOX_SIZE_W) / 7 * 6, (1-BOX_SIZE_W) / 7 * 6 + BOX_SIZE_W],
                  [63, 1-BOX_SIZE_H, 1, 1-BOX_SIZE_W, 1]]
    elif input_mode == 4:
        # mode 4 ----  (horizontal zig-zag down the height)
        inputs = [[0, 0, 0 + BOX_SIZE_H, 0, 0 + BOX_SIZE_W],
                  [9, (1-BOX_SIZE_H) / 7 * 1, (1-BOX_SIZE_H) / 7 * 1 + BOX_SIZE_H, 1-BOX_SIZE_W, 1],
                  [18, (1-BOX_SIZE_H) / 7 * 2, (1-BOX_SIZE_H) / 7 * 2 + BOX_SIZE_H, 0, 0 + BOX_SIZE_W],
                  [27, (1-BOX_SIZE_H) / 7 * 3, (1-BOX_SIZE_H) / 7 * 3 + BOX_SIZE_H, 1-BOX_SIZE_W, 1],
                  [36, (1-BOX_SIZE_H) / 7 * 4, (1-BOX_SIZE_H) / 7 * 4 + BOX_SIZE_H, 0, 0 + BOX_SIZE_W],
                  [45, (1-BOX_SIZE_H) / 7 * 5, (1-BOX_SIZE_H) / 7 * 5 + BOX_SIZE_H, 1-BOX_SIZE_W, 1],
                  [54, (1-BOX_SIZE_H) / 7 * 6, (1-BOX_SIZE_H) / 7 * 6 + BOX_SIZE_H, 0, 0 + BOX_SIZE_W],
                  [63, 1-BOX_SIZE_H, 1, 1-BOX_SIZE_W, 1]]
    else:
        # Raise instead of print+exit(): exit() kills the host process and
        # cannot be handled by callers.
        raise ValueError(f"input_mode must be 1-4, got {input_mode!r}")
    outputs = plan_path(inputs)
    return outputs
def get_path(BOX_SIZE_H=0.3, BOX_SIZE_W=0.3, input_mode=0):
    """Build a short (16-frame) box trajectory from a preset layout.

    Keyframes are ``[frame, h_start, h_end, w_start, w_end]`` with edges as
    fractions of the frame size; ``plan_path`` interpolates between them.

    Args:
        BOX_SIZE_H: box height as a fraction of the frame height.
        BOX_SIZE_W: box width as a fraction of the frame width.
        input_mode: preset trajectory to use (0-7).

    Returns:
        List of per-frame ``[h_start, h_end, w_start, w_end]`` boxes.

    Raises:
        ValueError: if ``input_mode`` is not one of 0-7.
    """
    if input_mode == 0:
        # \ d
        inputs = [[0, 0, 0 + BOX_SIZE_H, 0, 0 + BOX_SIZE_W], [15, 1-BOX_SIZE_H, 1, 1-BOX_SIZE_W, 1]]
    elif input_mode == 1:
        # / re d
        inputs = [[0, 0, 0 + BOX_SIZE_H, 1-BOX_SIZE_W, 1], [15, 1-BOX_SIZE_H, 1, 0, 0 + BOX_SIZE_W]]
    elif input_mode == 2:
        # L
        inputs = [[0, 0.1, 0.1 + BOX_SIZE_H, 0.1, 0.1 + BOX_SIZE_W], [6, 0.9-BOX_SIZE_H, 0.9, 0.1, 0.1 + BOX_SIZE_W], [15, 0.9-BOX_SIZE_H, 0.9, 0.9-BOX_SIZE_W, 0.9]]
    elif input_mode == 3:
        # re L
        inputs = [[0, 0.9-BOX_SIZE_H, 0.9, 0.9-BOX_SIZE_W, 0.9], [6, 0.1, 0.1 + BOX_SIZE_H, 0.9-BOX_SIZE_W, 0.9], [15, 0.1, 0.1 + BOX_SIZE_H, 0.1, 0.1 + BOX_SIZE_W]]
    elif input_mode == 4:
        # V
        inputs = [[0, 0, 0 + BOX_SIZE_H, 0, 0 + BOX_SIZE_W], [7, 1-BOX_SIZE_H, 1, (1-BOX_SIZE_W) / 15 * 7, (1-BOX_SIZE_W) / 15 * 7 + BOX_SIZE_W], [8, 1-BOX_SIZE_H, 1, (1-BOX_SIZE_W) / 15 * 8, (1-BOX_SIZE_W) / 15 * 8 + BOX_SIZE_W], [15, 0, 0 + BOX_SIZE_H, 1-BOX_SIZE_W, 1]]
    elif input_mode == 5:
        # re V
        inputs = [[0, 1-BOX_SIZE_H, 1, 1-BOX_SIZE_W, 1], [7, 0, 0 + BOX_SIZE_H, (1-BOX_SIZE_W) / 15 * 8, (1-BOX_SIZE_W) / 15 * 8 + BOX_SIZE_W], [8, 0, 0 + BOX_SIZE_H, (1-BOX_SIZE_W) / 15 * 7, (1-BOX_SIZE_W) / 15 * 7 + BOX_SIZE_W], [15, 1-BOX_SIZE_H, 1, 0, 0 + BOX_SIZE_W]]
    elif input_mode == 6:
        # -- goback
        inputs = [[0, 0.35, 0.35 + BOX_SIZE_H, 0.1, 0.1 + BOX_SIZE_W], [7, 0.35, 0.35 + BOX_SIZE_H, 0.9-BOX_SIZE_W, 0.9], [8, 0.35, 0.35 + BOX_SIZE_H, 0.9-BOX_SIZE_W, 0.9], [15, 0.35, 0.35 + BOX_SIZE_H, 0.1, 0.1 + BOX_SIZE_W]]
    elif input_mode == 7:
        # tri
        inputs = [[0, 0.1, 0.1 + BOX_SIZE_H, 0.35, 0.35 + BOX_SIZE_W], [5, 0.9-BOX_SIZE_H, 0.9, 0.9-BOX_SIZE_W, 0.9], [10, 0.9-BOX_SIZE_H, 0.9, 0.1, 0.1 + BOX_SIZE_W], [15, 0.1, 0.1 + BOX_SIZE_H, 0.35, 0.35 + BOX_SIZE_W]]
    else:
        # Previously an unknown mode fell through and crashed later with a
        # NameError on `inputs`; fail fast with a clear message instead.
        raise ValueError(f"input_mode must be 0-7, got {input_mode!r}")
    outputs = plan_path(inputs)
    return outputs
def plan_path(input, video_length=16):
    """Expand box keyframes into one box per frame by linear interpolation.

    Args:
        input: list of ``[frame, h_start, h_end, w_start, w_end]`` keyframes,
            sorted by ascending frame index.
        video_length: total number of frames the returned path must cover.

    Returns:
        List of ``[h_start, h_end, w_start, w_end]``, one entry per frame.
        If the first/last keyframe is not at frame 0 / ``video_length - 1``,
        the path is linearly extrapolated (box translated, size preserved).
    """
    len_input = len(input)
    path = [input[0][1:]]
    # Linear interpolation between each pair of consecutive keyframes.
    for i in range(1, len_input):
        start = input[i-1]
        end = input[i]
        start_frame = start[0]
        end_frame = end[0]
        span = end_frame - start_frame
        h_start_change = (end[1] - start[1]) / span
        h_end_change = (end[2] - start[2]) / span
        w_start_change = (end[3] - start[3]) / span
        w_end_change = (end[4] - start[4]) / span
        for j in range(start_frame + 1, end_frame + 1):
            increase_frame = j - start_frame
            path += [[increase_frame * h_start_change + start[1],
                      increase_frame * h_end_change + start[2],
                      increase_frame * w_start_change + start[3],
                      increase_frame * w_end_change + start[4]]]
    # Extrapolate backwards when the first keyframe starts after frame 0.
    if input[0][0] > 0:
        h_change = path[1][0] - path[0][0]
        w_change = path[1][2] - path[0][2]
        for i in range(input[0][0]):
            # BUG FIX: prepend one nested 4-element frame; the old code did
            # `[a, b, c, d] + path`, splicing four bare floats into the path
            # and crashing on the next iteration's path[0][0] indexing.
            path = [[path[0][0] - h_change, path[0][1] - h_change,
                     path[0][2] - w_change, path[0][3] - w_change]] + path
    # Extrapolate forwards when the last keyframe ends before the final frame.
    if input[-1][0] < video_length - 1:
        h_change = path[-1][0] - path[-2][0]
        w_change = path[-1][2] - path[-2][2]
        for i in range(video_length - 1 - input[-1][0]):
            # BUG FIX: append one nested frame (same flattening bug as above).
            path = path + [[path[-1][0] + h_change, path[-1][1] + h_change,
                            path[-1][2] + w_change, path[-1][3] + w_change]]
    return path
def gaussian_2d(x=0, y=0, mx=0, my=0, sx=1, sy=1):
    """Evaluate a 2-D Gaussian at (x, y) and rescale so its peak equals 1.

    x, y are expected to be tensors (e.g. a meshgrid); mx/my are the means
    and sx/sy the standard deviations along each axis.
    """
    amplitude = 1 / (2 * math.pi * sx * sy)
    exponent = (x - mx) ** 2 / (2 * sx ** 2) + (y - my) ** 2 / (2 * sy ** 2)
    gaussian_map = amplitude * torch.exp(-exponent)
    # Normalize in place so the maximum entry is exactly 1.
    gaussian_map.div_(gaussian_map.max())
    return gaussian_map
def gaussian_weight(height=32, width=32, KERNEL_DIVISION=3.0):
    """Return a (height, width) half-precision Gaussian weight patch.

    The Gaussian is centred on the patch; KERNEL_DIVISION controls how many
    standard deviations span each dimension (larger value -> tighter peak).
    """
    xs = torch.linspace(0, height, height)
    ys = torch.linspace(0, width, width)
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing="ij")
    patch = gaussian_2d(
        grid_x,
        grid_y,
        mx=int(height / 2),
        my=int(width / 2),
        sx=float(height / KERNEL_DIVISION),
        sy=float(width / KERNEL_DIVISION),
    )
    return patch.half()
def freq_mix_3d(x, noise, LPF):
    """
    Noise reinitialization.
    Args:
        x: diffused latent
        noise: randomly sampled noise
        LPF: low pass filter
    """
    dims = (-3, -2, -1)
    # Move both signals to the (centred) frequency domain.
    x_freq = fft.fftshift(fft.fftn(x, dim=dims), dim=dims)
    noise_freq = fft.fftshift(fft.fftn(noise, dim=dims), dim=dims)
    # Keep the low band of x and the complementary high band of the noise.
    mixed_freq = x_freq * LPF + noise_freq * (1 - LPF)
    # Back to the spatial/temporal domain; inputs are real, so drop the
    # (numerically tiny) imaginary part.
    mixed_freq = fft.ifftshift(mixed_freq, dim=dims)
    return fft.ifftn(mixed_freq, dim=dims).real
def get_freq_filter(shape, device, filter_type, n, d_s, d_t):
    """
    Form the frequency filter for noise reinitialization.
    Args:
        shape: shape of latent (B, C, T, H, W)
        filter_type: type of the freq filter
        n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)
    """
    # Build the mask on CPU, then move it to the requested device once.
    if filter_type == "gaussian":
        mask = gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t)
    elif filter_type == "ideal":
        mask = ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t)
    elif filter_type == "box":
        mask = box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t)
    elif filter_type == "butterworth":
        mask = butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t)
    else:
        raise NotImplementedError
    return mask.to(device)
def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25):
    """
    Compute the gaussian low pass filter mask.

    Vectorized over (T, H, W); the previous triple Python loop was O(T*H*W)
    interpreter iterations. The grid is computed in float64 to match the
    Python-float arithmetic of the original per-element formula.

    Args:
        shape: shape of the filter (volume), last three dims are (T, H, W)
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)
    Returns:
        Tensor of `shape` with the mask broadcast over any leading dims.
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        # Degenerate cutoff: all-zero mask (as in the original loop version).
        return mask
    t = torch.arange(T, dtype=torch.float64)
    h = torch.arange(H, dtype=torch.float64)
    w = torch.arange(W, dtype=torch.float64)
    gt, gh, gw = torch.meshgrid(t, h, w, indexing="ij")
    # Squared normalized distance from the centre of the shifted spectrum,
    # with the temporal axis rescaled by d_s/d_t.
    d_square = (((d_s / d_t) * (2 * gt / T - 1)) ** 2
                + (2 * gh / H - 1) ** 2
                + (2 * gw / W - 1) ** 2)
    mask[...] = torch.exp(-1 / (2 * d_s ** 2) * d_square)
    return mask
def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25):
    """
    Compute the butterworth low pass filter mask.

    Vectorized over (T, H, W); the previous triple Python loop was O(T*H*W)
    interpreter iterations. The grid is computed in float64 to match the
    Python-float arithmetic of the original per-element formula.

    Args:
        shape: shape of the filter (volume), last three dims are (T, H, W)
        n: order of the filter, larger n ~ ideal, smaller n ~ gaussian
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)
    Returns:
        Tensor of `shape` with the mask broadcast over any leading dims.
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        # Degenerate cutoff: all-zero mask (as in the original loop version).
        return mask
    t = torch.arange(T, dtype=torch.float64)
    h = torch.arange(H, dtype=torch.float64)
    w = torch.arange(W, dtype=torch.float64)
    gt, gh, gw = torch.meshgrid(t, h, w, indexing="ij")
    d_square = (((d_s / d_t) * (2 * gt / T - 1)) ** 2
                + (2 * gh / H - 1) ** 2
                + (2 * gw / W - 1) ** 2)
    mask[...] = 1 / (1 + (d_square / d_s ** 2) ** n)
    return mask
def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25):
    """
    Compute the ideal low pass filter mask.

    Vectorized over (T, H, W); the previous triple Python loop was O(T*H*W)
    interpreter iterations.

    Args:
        shape: shape of the filter (volume), last three dims are (T, H, W)
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)
    Returns:
        Tensor of `shape`: 1 inside the cutoff radius, 0 outside.
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        return mask
    t = torch.arange(T, dtype=torch.float64)
    h = torch.arange(H, dtype=torch.float64)
    w = torch.arange(W, dtype=torch.float64)
    gt, gh, gw = torch.meshgrid(t, h, w, indexing="ij")
    d_square = (((d_s / d_t) * (2 * gt / T - 1)) ** 2
                + (2 * gh / H - 1) ** 2
                + (2 * gw / W - 1) ** 2)
    # BUG FIX: the squared distance must be compared against the SQUARED
    # cutoff radius d_s**2 (the old code used `d_s*2`, i.e. 2*d_s), keeping
    # the cutoff semantics consistent with the gaussian/butterworth filters.
    mask[...] = (d_square <= d_s ** 2).to(mask.dtype)
    return mask
def box_low_pass_filter(shape, d_s=0.25, d_t=0.25):
    """
    Compute the ideal low pass filter mask (approximated version).
    Args:
        shape: shape of the filter (volume)
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        return mask
    # Half-widths of the pass-band box along the spatial and temporal axes.
    # NOTE(review): the spatial threshold is derived from H only and reused
    # for W, matching the original behaviour.
    half_s = round(int(H // 2) * d_s)
    half_t = round(T // 2 * d_t)
    # Centre of the (fftshift-ed) spectrum.
    cframe = T // 2
    crow = H // 2
    ccol = W // 2
    mask[...,
         cframe - half_t:cframe + half_t,
         crow - half_s:crow + half_s,
         ccol - half_s:ccol + half_s] = 1.0
    return mask