import torch
from safetensors.torch import load_file
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline
from torch.nn.utils import remove_weight_norm
def load_ckpt_state_dict(ckpt_path, prefix=None):
    if ckpt_path.endswith(".safetensors"):
        state_dict = load_file(ckpt_path)
    else:
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    # Keep only keys starting with `prefix` and strip the prefix from them
    filtered_state_dict = {k.replace(f'{prefix}', ''): v for k, v in state_dict.items() if k.startswith(prefix)} if prefix is not None else state_dict

    return filtered_state_dict
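# Usage sketch (illustrative; the checkpoint path and prefix below are assumptions, not from this file):
#   state_dict = load_ckpt_state_dict("checkpoint.safetensors", prefix="model.")
#   model.load_state_dict(state_dict)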
def remove_weight_norm_from_model(model):
    for module in model.modules():
        if hasattr(module, "weight"):
            print(f"Removing weight norm from {module}")
            remove_weight_norm(module)

    return model
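# Usage sketch (illustrative): strip the weight_norm reparameterization before inference/export.
# This assumes every module exposing a `weight` attribute was weight-normed;
# torch.nn.utils.remove_weight_norm raises ValueError otherwise.
#   model = remove_weight_norm_from_model(model)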
# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
# License can be found in LICENSES/LICENSE_META.txt

def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keyword args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """
    if num_samples == 1:
        # Exponential-race (Gumbel-max style) trick: argmax of p_i / E_i with E_i ~ Exp(1)
        # is equivalent to one multinomial draw, but works over arbitrary leading dims
        q = torch.empty_like(input).exponential_(1, generator=generator)
        return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)

    input_ = input.reshape(-1, input.shape[-1])
    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
    output = output_.reshape(*list(input.shape[:-1]), -1)
    return output
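# Usage sketch (illustrative; shapes are assumptions): works with any number of leading
# dimensions, with candidate probabilities on the last one.
#   probs = torch.softmax(torch.randn(2, 4, 10), dim=-1)  # [batch, time, vocab]
#   idx = multinomial(probs, num_samples=1)               # [batch, time, 1] int64 indices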
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample next token from top K values along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in "top-k".
    Returns:
        torch.Tensor: Sampled tokens.
    """
    top_k_value, _ = torch.topk(probs, k, dim=-1)
    min_value_top_k = top_k_value[..., [-1]]
    probs *= (probs >= min_value_top_k).float()
    probs.div_(probs.sum(dim=-1, keepdim=True))
    next_token = multinomial(probs, num_samples=1)
    return next_token
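# Usage sketch (illustrative; k=250 is an arbitrary example value): keep the k most likely
# candidates, renormalize, then sample one token per position. Note that probs is modified
# in place (probs *= ..., probs.div_), so pass a copy if the original is still needed.
#   probs = torch.softmax(logits, dim=-1)            # logits: [B, T, vocab] (assumed)
#   next_token = sample_top_k(probs.clone(), k=250)  # [B, T, 1]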
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample next token from top P probabilities along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in "top-p".
    Returns:
        torch.Tensor: Sampled tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort *= (~mask).float()
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = multinomial(probs_sort, num_samples=1)
    # Map the sampled positions (in sorted order) back to original token indices
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
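# Usage sketch (illustrative; p=0.95 is an arbitrary example value): keep the smallest set of
# candidates whose cumulative probability reaches p, renormalize, then sample.
#   next_token = sample_top_p(probs, p=0.95)  # [B, T, 1], indices into the original vocabulary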
def next_power_of_two(n):
    return 2 ** (n - 1).bit_length()

def next_multiple_of_64(n):
    return ((n + 63) // 64) * 64
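# Examples:
#   next_power_of_two(100)   -> 128
#   next_multiple_of_64(100) -> 128
#   next_multiple_of_64(130) -> 192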
# mask construction helpers

def mask_from_start_end_indices(
    seq_len: int,
    start: Tensor,
    end: Tensor
):
    assert start.shape == end.shape
    device = start.device

    seq = torch.arange(seq_len, device = device, dtype = torch.long)
    seq = seq.reshape(*((-1,) * start.ndim), seq_len)
    seq = seq.expand(*start.shape, seq_len)

    mask = seq >= start[..., None].long()
    mask &= seq < end[..., None].long()
    return mask
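# Usage sketch (illustrative; values are assumptions): mask is True where start <= t < end.
#   start = torch.tensor([2, 0])
#   end = torch.tensor([5, 3])
#   mask = mask_from_start_end_indices(8, start, end)  # [2, 8] bool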
def mask_from_frac_lengths(
    seq_len: int,
    frac_lengths: Tensor
):
    device = frac_lengths.device

    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths

    rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1)
    start = (max_start * rand).clamp(min = 0)
    end = start + lengths

    return mask_from_start_end_indices(seq_len, start, end)
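# Usage sketch (illustrative; the fractions are assumptions): mask a randomly placed
# contiguous span covering the given fraction of each sequence.
#   frac = torch.tensor([0.5, 0.25])
#   mask = mask_from_frac_lengths(100, frac)  # [2, 100] bool, spans of 50 and 25 True values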
def _build_spline(video_feat, video_t, target_t):
    # Core of the cubic-spline interpolation: fit a natural cubic spline over the time axis
    # and evaluate it at the target timestamps
    coeffs = natural_cubic_spline_coeffs(video_t, video_feat.permute(0, 2, 1))
    spline = NaturalCubicSpline(coeffs)
    return spline.evaluate(target_t).permute(0, 2, 1)

def resample(video_feat, audio_latent):
    """Resample video features along time to match the audio latent length (9 s clips).

    Args:
        video_feat: [B, 72, D] video features.
        audio_latent: [B, D', 194] audio latent tensor, or an int giving the target length.
    """
    B, Tv, D = video_feat.shape

    if isinstance(audio_latent, torch.Tensor):
        # audio_latent is a tensor: if dim 1 differs from the video feature dim D, treat it
        # as the time dim; otherwise assume a channel-first layout and take dim 2
        if audio_latent.shape[1] != D:
            Ta = audio_latent.shape[1]
        else:
            Ta = audio_latent.shape[2]
    elif isinstance(audio_latent, int):
        # audio_latent directly gives the target number of frames
        Ta = audio_latent
    else:
        raise TypeError("audio_latent must be either a tensor or an int")

    # Build timestamps for both streams on a shared 0-9 s axis (key improvement)
    video_time = torch.linspace(0, 9, Tv, device=video_feat.device)
    audio_time = torch.linspace(0, 9, Ta, device=video_feat.device)

    # Rearrange to (Batch, Feature, Time)
    video_feat = video_feat.permute(0, 2, 1)  # [B, D, Tv]

    # Cubic spline interpolation onto the audio timestamps
    aligned_video = _build_spline(video_feat, video_time, audio_time)  # [B, D, Ta]

    return aligned_video.permute(0, 2, 1)  # [B, Ta, D]
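# Usage sketch (illustrative; feature sizes are assumptions): align 72 video frames to the
# 194 audio-latent frames of a 9-second clip via cubic-spline resampling.
#   video_feat = torch.randn(1, 72, 512)
#   audio_latent = torch.randn(1, 512, 194)        # channel dim matches D, so time is dim 2
#   aligned = resample(video_feat, audio_latent)   # [1, 194, 512]
#   aligned = resample(video_feat, 194)            # equivalent, with an int target length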