| import numpy as np |
| from typing import Callable, Optional, List |
| import torch |
| from ..utils import log |
|
|
| def ordered_halving(val): |
| bin_str = f"{val:064b}" |
| bin_flip = bin_str[::-1] |
| as_int = int(bin_flip, 2) |
|
|
| return as_int / (1 << 64) |
|
|
| def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]: |
| prev_val = -1 |
| for i, val in enumerate(window): |
| val = val % num_frames |
| if val < prev_val: |
| return True, i |
| prev_val = val |
| return False, -1 |
|
|
| def shift_window_to_start(window: list[int], num_frames: int): |
| start_val = window[0] |
| for i in range(len(window)): |
| |
| |
| window[i] = ((window[i] - start_val) + num_frames) % num_frames |
|
|
| def shift_window_to_end(window: list[int], num_frames: int): |
| |
| shift_window_to_start(window, num_frames) |
| end_val = window[-1] |
| end_delta = num_frames - end_val - 1 |
| for i in range(len(window)): |
| |
| window[i] = window[i] + end_delta |
|
|
| def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]: |
| all_indexes = list(range(num_frames)) |
| for w in windows: |
| for val in w: |
| try: |
| all_indexes.remove(val) |
| except ValueError: |
| pass |
| return all_indexes |
|
|
| def uniform_looped( |
| step: int = ..., |
| num_steps: Optional[int] = None, |
| num_frames: int = ..., |
| context_size: Optional[int] = None, |
| context_stride: int = 3, |
| context_overlap: int = 4, |
| closed_loop: bool = True, |
| ): |
| if num_frames <= context_size: |
| yield list(range(num_frames)) |
| return |
|
|
| context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) |
|
|
| for context_step in 1 << np.arange(context_stride): |
| pad = int(round(num_frames * ordered_halving(step))) |
| for j in range( |
| int(ordered_halving(step) * context_step) + pad, |
| num_frames + pad + (0 if closed_loop else -context_overlap), |
| (context_size * context_step - context_overlap), |
| ): |
| yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)] |
|
|
| |
| def uniform_standard( |
| step: int = ..., |
| num_steps: Optional[int] = None, |
| num_frames: int = ..., |
| context_size: Optional[int] = None, |
| context_stride: int = 3, |
| context_overlap: int = 4, |
| closed_loop: bool = True, |
| ): |
| windows = [] |
| if num_frames <= context_size: |
| windows.append(list(range(num_frames))) |
| return windows |
|
|
| context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) |
|
|
| for context_step in 1 << np.arange(context_stride): |
| pad = int(round(num_frames * ordered_halving(step))) |
| for j in range( |
| int(ordered_halving(step) * context_step) + pad, |
| num_frames + pad + (0 if closed_loop else -context_overlap), |
| (context_size * context_step - context_overlap), |
| ): |
| windows.append([e % num_frames for e in range(j, j + context_size * context_step, context_step)]) |
|
|
| |
| delete_idxs = [] |
| win_i = 0 |
| while win_i < len(windows): |
| |
| is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames) |
| if is_roll: |
| roll_val = windows[win_i][roll_idx] |
| shift_window_to_end(windows[win_i], num_frames=num_frames) |
| |
| if roll_val not in windows[(win_i+1) % len(windows)]: |
| |
| windows.insert(win_i+1, list(range(roll_val, roll_val + context_size))) |
| |
| for pre_i in range(0, win_i): |
| if windows[win_i] == windows[pre_i]: |
| delete_idxs.append(win_i) |
| break |
| win_i += 1 |
|
|
| |
| delete_idxs.reverse() |
| for i in delete_idxs: |
| windows.pop(i) |
| return windows |
|
|
| def static_standard( |
| step: int = ..., |
| num_steps: Optional[int] = None, |
| num_frames: int = ..., |
| context_size: Optional[int] = None, |
| context_stride: int = 3, |
| context_overlap: int = 4, |
| closed_loop: bool = True, |
| ): |
| windows = [] |
| if num_frames <= context_size: |
| windows.append(list(range(num_frames))) |
| return windows |
| |
| delta = context_size - context_overlap |
| for start_idx in range(0, num_frames, delta): |
| |
| ending = start_idx + context_size |
| if ending >= num_frames: |
| final_delta = ending - num_frames |
| final_start_idx = start_idx - final_delta |
| windows.append(list(range(final_start_idx, final_start_idx + context_size))) |
| break |
| windows.append(list(range(start_idx, start_idx + context_size))) |
| return windows |
|
|
| def get_context_scheduler(name: str) -> Callable: |
| if name == "uniform_looped": |
| return uniform_looped |
| elif name == "uniform_standard": |
| return uniform_standard |
| elif name == "static_standard": |
| return static_standard |
| else: |
| raise ValueError(f"Unknown context_overlap policy {name}") |
|
|
|
|
| def get_total_steps( |
| scheduler, |
| timesteps: List[int], |
| num_steps: Optional[int] = None, |
| num_frames: int = ..., |
| context_size: Optional[int] = None, |
| context_stride: int = 3, |
| context_overlap: int = 4, |
| closed_loop: bool = True, |
| ): |
| return sum( |
| len( |
| list( |
| scheduler( |
| i, |
| num_steps, |
| num_frames, |
| context_size, |
| context_stride, |
| context_overlap, |
| ) |
| ) |
| ) |
| for i in range(len(timesteps)) |
| ) |
|
|
| def create_window_mask(noise_pred_context, c, latent_video_length, context_overlap, looped=False, window_type="linear"): |
| window_mask = torch.ones_like(noise_pred_context) |
| |
| if window_type == "pyramid": |
| |
| length = noise_pred_context.shape[1] |
| if length % 2 == 0: |
| max_weight = length // 2 |
| weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1)) |
| else: |
| max_weight = (length + 1) // 2 |
| weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)) |
| |
| |
| max_val = max(weight_sequence) |
| weight_sequence = [w / max_val for w in weight_sequence] |
| |
| |
| weights_tensor = torch.tensor(weight_sequence, device=noise_pred_context.device) |
| weights_tensor = weights_tensor.view(1, -1, 1, 1) |
| window_mask = weights_tensor.expand_as(window_mask).clone() |
| |
| |
| if not looped: |
| if min(c) == 0: |
| left_ramp = torch.linspace(0, 1, context_overlap, device=noise_pred_context.device).view(1, -1, 1, 1) |
| |
| left_section = window_mask[:, :context_overlap].clone() |
| window_mask[:, :context_overlap] = torch.maximum(left_section, left_ramp) |
| |
| if max(c) == latent_video_length - 1: |
| right_ramp = torch.linspace(1, 0, context_overlap, device=noise_pred_context.device).view(1, -1, 1, 1) |
| |
| right_section = window_mask[:, -context_overlap:].clone() |
| window_mask[:, -context_overlap:] = torch.maximum(right_section, right_ramp) |
| else: |
| |
| if min(c) > 0 or (looped and max(c) == latent_video_length - 1): |
| ramp_up = torch.linspace(0, 1, context_overlap, device=noise_pred_context.device) |
| ramp_up = ramp_up.view(1, -1, 1, 1) |
| window_mask[:, :context_overlap] = ramp_up |
| |
| |
| if max(c) < latent_video_length - 1 or (looped and min(c) == 0): |
| ramp_down = torch.linspace(1, 0, context_overlap, device=noise_pred_context.device) |
| ramp_down = ramp_down.view(1, -1, 1, 1) |
| window_mask[:, -context_overlap:] = ramp_down |
| |
| return window_mask |
|
|
| class WindowTracker: |
| def __init__(self, verbose=False): |
| self.window_map = {} |
| self.next_id = 0 |
| self.cache_states = {} |
| self.verbose = verbose |
| |
| def get_window_id(self, frames): |
| key = tuple(sorted(frames)) |
| if key not in self.window_map: |
| self.window_map[key] = self.next_id |
| if self.verbose: |
| log.info(f"New window pattern {key} -> ID {self.next_id}") |
| self.next_id += 1 |
| return self.window_map[key] |
| |
| def get_teacache(self, window_id, base_state): |
| if window_id not in self.cache_states: |
| if self.verbose: |
| log.info(f"Initializing persistent teacache for window {window_id}") |
| self.cache_states[window_id] = base_state.copy() |
| return self.cache_states[window_id] |
|
|