import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from ...util import append_dims, default

logpy = logging.getLogger(__name__)
class Guider(ABC):
    @abstractmethod
    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        pass

    @abstractmethod
    def prepare_inputs(
        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
    ) -> Tuple[torch.Tensor, float, Dict]:
        pass
class VanillaCFG(Guider):
    def __init__(
        self, scale: float, low_sigma: float = 0.0, high_sigma: float = float("inf")
    ):
        self.scale = scale
        self.low_sigma = low_sigma
        self.high_sigma = high_sigma

    def set_scale(self, scale: float):
        self.scale = scale

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        x_pred = x_u + self.scale * (x_c - x_u)
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
        c_out = dict()

        for k in c:
            if k in [
                "vector",
                "crossattn",
                "concat",
                "audio_emb",
                "image_embeds",
                "landmarks",
                "masks",
                "gt",
                "valence",
                "arousal",
            ]:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            elif k == "reference":
                c_out["reference"] = []
                for i in range(len(c[k])):
                    c_out["reference"].append(torch.cat((uc[k][i], c[k][i]), 0))
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
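
# Usage sketch (not part of the original module): how a sampler step might drive
# VanillaCFG. The `denoiser` argument is a hypothetical stand-in for the model
# wrapper that maps (x, sigma, cond) to a denoised prediction; the scale value is
# illustrative only. `x` and `sigma` are tensors, `c`/`uc` are conditioning dicts.
def _example_vanilla_cfg_step(denoiser, x, sigma, c, uc, scale: float = 3.5):
    guider = VanillaCFG(scale=scale)
    # Duplicate latents/sigmas and stack the conditionings as [uncond; cond].
    x_in, s_in, c_in = guider.prepare_inputs(x, sigma, c, uc)
    pred = denoiser(x_in, s_in, c_in)
    # Recombine the two halves: x_u + scale * (x_c - x_u).
    return guider(pred, sigma)
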
class VanillaSTG(Guider):
    def __init__(
        self,
        scale_spatial: float,
        scale_temporal: float,
        low_sigma: float = 0.0,
        high_sigma: float = float("inf"),
        layer_skip: int = 8,
    ):
        self.scale_spatial = scale_spatial
        self.scale_temporal = scale_temporal
        self.low_sigma = low_sigma
        self.high_sigma = high_sigma
        self.layer_skip = layer_skip

    def set_scale(self, scale_spatial: float, scale_temporal: float):
        self.scale_spatial = scale_spatial
        self.scale_temporal = scale_temporal

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_c, x_spatial, x_temporal = x.chunk(3)
        x_pred = (
            x_c
            + self.scale_spatial * (x_c - x_spatial)
            + self.scale_temporal * (x_c - x_temporal)
        )
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
        c_out = dict()

        for k in c:
            if k in [
                "vector",
                "crossattn",
                "concat",
                "audio_emb",
                "image_embeds",
                "landmarks",
                "masks",
                "gt",
                "valence",
                "arousal",
            ]:
                c_out[k] = torch.cat((c[k], c[k], c[k]), 0)
            elif k == "reference":
                c_out["reference"] = []
                for i in range(len(c[k])):
                    c_out["reference"].append(
                        torch.cat((c[k][i], c[k][i], c[k][i]), 0)
                    )
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]

        c_out["skip_spatial_attention_at"] = [None, self.layer_skip, None]
        c_out["skip_temporal_attention_at"] = [None, None, self.layer_skip]
        return torch.cat([x] * 3), torch.cat([s] * 3), c_out
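
# Sketch (assumption about intended use, not stated in the code): VanillaSTG triples
# the batch; chunk 0 runs the full model, chunk 1 is expected to skip spatial
# attention at `layer_skip`, chunk 2 to skip temporal attention there, and the
# guider then pushes the prediction away from the two degraded branches.
# Minimal check of the combination rule on dummy predictions:
def _example_stg_combination():
    guider = VanillaSTG(scale_spatial=1.0, scale_temporal=1.0, layer_skip=8)
    x_c = torch.randn(2, 4, 8, 8)
    x_spatial = torch.randn_like(x_c)
    x_temporal = torch.randn_like(x_c)
    # sigma is unused by this guider, so None is fine here.
    out = guider(torch.cat([x_c, x_spatial, x_temporal], dim=0), sigma=None)
    expected = x_c + 1.0 * (x_c - x_spatial) + 1.0 * (x_c - x_temporal)
    assert torch.allclose(out, expected)
    return out
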
class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average


def project(v0: torch.Tensor, v1: torch.Tensor):
    dtype = v0.dtype
    v0, v1 = v0.double(), v1.double()
    v1 = F.normalize(v1, dim=[-1, -2, -3])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)


def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional[MomentumBuffer] = None,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
):
    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    diff_parallel, diff_orthogonal = project(diff, pred_cond)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided
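
# Worked sketch (illustrative values only): `normalized_guidance` implements an
# APG-style update. The cond/uncond difference is optionally momentum-averaged,
# norm-clipped to `norm_threshold`, then split into components parallel and
# orthogonal to the conditional prediction; the parallel part is down-weighted by
# `eta` before the usual (guidance_scale - 1) push.
def _example_apg_update():
    pred_cond = torch.randn(1, 4, 16, 16)
    pred_uncond = torch.randn(1, 4, 16, 16)
    buf = MomentumBuffer(momentum=-0.75)
    out = normalized_guidance(
        pred_cond,
        pred_uncond,
        guidance_scale=7.5,
        momentum_buffer=buf,
        eta=0.0,
        norm_threshold=2.5,
    )
    return out
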
class APGGuider(VanillaCFG):
    def __init__(
        self,
        scale: float,
        momentum: float = -0.75,
        eta: float = 0.0,
        norm_threshold: float = 2.5,
    ):
        super().__init__(scale)
        self.momentum_buffer = MomentumBuffer(momentum)
        self.eta = eta
        self.norm_threshold = norm_threshold

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        return normalized_guidance(
            x_c, x_u, self.scale, self.momentum_buffer, self.eta, self.norm_threshold
        )
class VanillaCFGplusplus(VanillaCFG):
    def __call__(
        self, x: torch.Tensor, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x_u, x_c = x.chunk(2)
        x_pred = x_u + self.scale * (x_c - x_u)
        return x_pred, x_u
class KarrasGuider(VanillaCFG):
    def prepare_inputs(self, x, s, c, uc):
        c_out = dict()

        for k in c:
            if k in [
                "vector",
                "crossattn",
                "concat",
                "audio_emb",
                "image_embeds",
                "landmarks",
                "valence",
                "arousal",
            ]:
                c_out[k] = torch.cat((c[k], c[k]), 0)
            elif k == "reference":
                c_out["reference"] = []
                for i in range(len(c[k])):
                    c_out["reference"].append(torch.cat((c[k][i], c[k][i]), 0))
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class MultipleCondVanilla(Guider):
    def __init__(self, scales, condition_names) -> None:
        assert len(scales) == len(condition_names)
        self.scales = scales
        # self.condition_names = condition_names
        self.n_conditions = len(scales)
        self.map_cond_name = {
            "audio_emb": "audio_emb",
            "cond_frames_without_noise": "crossattn",
            "cond_frames": "concat",
        }
        self.condition_names = [
            self.map_cond_name.get(cond_name, cond_name)
            for cond_name in condition_names
        ]
        print("Condition names: ", self.condition_names)

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        outs = x.chunk(self.n_conditions + 1)
        x_full_cond = outs[0]
        x_pred = (1 + sum(self.scales)) * x_full_cond
        for i, scale in enumerate(self.scales):
            x_pred -= scale * outs[i + 1]
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
        c_out = dict()

        # The first element is the full condition
        for k in c:
            if k in [
                "vector",
                "crossattn",
                "concat",
                "audio_emb",
                "image_embeds",
                "landmarks",
                "masks",
                "gt",
            ]:
                c_out[k] = c[k]
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]

        # The rest are the conditions removed from the full condition
        for cond_name in self.condition_names:
            if not isinstance(cond_name, list):
                cond_name = [cond_name]
            for k in c:
                if k in [
                    "vector",
                    "crossattn",
                    "concat",
                    "audio_emb",
                    "image_embeds",
                    "landmarks",
                    "masks",
                    "gt",
                ]:
                    c_out[k] = torch.cat(
                        (c_out[k], uc[k] if k in cond_name else c[k]), 0
                    )

        return (
            torch.cat([x] * (self.n_conditions + 1)),
            torch.cat([s] * (self.n_conditions + 1)),
            c_out,
        )
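
# Note (a reading of the class above, values illustrative): with scales s_i and
# per-condition "condition i dropped" predictions e_i, MultipleCondVanilla computes
#     e = (1 + sum_i s_i) * e_full - sum_i s_i * e_i,
# i.e. the prediction is pushed away from each branch that removes one condition.
# Tiny numeric check of that combination:
def _example_multiple_cond_combination():
    guider = MultipleCondVanilla(
        scales=[2.0, 1.0], condition_names=["audio_emb", "cond_frames"]
    )
    e_full = torch.randn(1, 4, 8, 8)
    e_no_audio = torch.randn_like(e_full)
    e_no_ref = torch.randn_like(e_full)
    out = guider(torch.cat([e_full, e_no_audio, e_no_ref], dim=0), sigma=None)
    expected = (1 + 3.0) * e_full - 2.0 * e_no_audio - 1.0 * e_no_ref
    assert torch.allclose(out, expected)
    return out
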
class AudioRefMultiCondGuider(MultipleCondVanilla):
    def __init__(
        self,
        audio_ratio: float = 5.0,
        ref_ratio: float = 3.0,
        use_normalized: bool = False,
        momentum: float = -0.75,
        eta: float = 0.0,
        norm_threshold: float = 2.5,
    ):
        super().__init__(
            scales=[audio_ratio, ref_ratio], condition_names=["audio_emb", "concat"]
        )
        self.audio_ratio = audio_ratio
        self.ref_ratio = ref_ratio
        self.use_normalized = use_normalized
        print(f"Use normalized: {self.use_normalized}")
        self.momentum_buffer = MomentumBuffer(momentum)
        self.eta = eta
        self.norm_threshold = norm_threshold
        self.momentum_buffer_audio = MomentumBuffer(momentum)
        self.momentum_buffer_ref = MomentumBuffer(momentum)

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        e_uc, e_ref, e_audio_ref = x.chunk(3)

        if self.use_normalized:
            # Normalized (APG-style) guidance version
            # Audio guidance term
            diff_audio = e_audio_ref - e_uc
            if self.momentum_buffer_audio is not None:
                self.momentum_buffer_audio.update(diff_audio)
                diff_audio = self.momentum_buffer_audio.running_average
            if self.norm_threshold > 0:
                ones = torch.ones_like(diff_audio)
                diff_norm = diff_audio.norm(p=2, dim=[-1, -2, -3], keepdim=True)
                scale_factor = torch.minimum(ones, self.norm_threshold / diff_norm)
                diff_audio = diff_audio * scale_factor
            diff_audio_parallel, diff_audio_orthogonal = project(
                diff_audio, e_audio_ref
            )
            normalized_update_audio = (
                diff_audio_orthogonal + self.eta * diff_audio_parallel
            )
            guidance_audio = (self.audio_ratio - 1) * normalized_update_audio

            # Reference guidance term
            diff_ref = e_ref - e_uc
            if self.momentum_buffer_ref is not None:
                self.momentum_buffer_ref.update(diff_ref)
                diff_ref = self.momentum_buffer_ref.running_average
            if self.norm_threshold > 0:
                ones = torch.ones_like(diff_ref)
                diff_norm = diff_ref.norm(p=2, dim=[-1, -2, -3], keepdim=True)
                scale_factor = torch.minimum(ones, self.norm_threshold / diff_norm)
                diff_ref = diff_ref * scale_factor
            diff_ref_parallel, diff_ref_orthogonal = project(diff_ref, e_ref)
            normalized_update_ref = diff_ref_orthogonal + self.eta * diff_ref_parallel
            guidance_ref = (self.ref_ratio - 1) * normalized_update_ref

            e_final = e_uc + guidance_audio + guidance_ref
        else:
            # Original (linear) version
            e_final = (
                self.audio_ratio * (e_audio_ref - e_ref)
                + self.ref_ratio * (e_ref - e_uc)
                + e_uc
            )
        return e_final

    def set_scale(self, scale: torch.Tensor):
        self.audio_ratio = float(scale[0])
        self.ref_ratio = float(scale[1])
        print(f"Audio ratio: {self.audio_ratio}, Ref ratio: {self.ref_ratio}")

    def prepare_inputs(self, x, s, c, uc):
        c_out = dict()

        # Branch 1: e_uc (no audio cross-attention, no reference concat)
        c_base = {k: v for k, v in c.items()}
        c_base["crossattn"] = uc["crossattn"]
        c_base["concat"] = uc["concat"]  # Remove ref concat

        # Branch 2: e_ref (no audio cross-attention, with reference concat)
        c_ref = {k: v for k, v in c.items()}
        c_ref["crossattn"] = uc["crossattn"]

        # Branch 3: e_audio_ref (all conditions)
        c_audio_ref = {k: v for k, v in c.items()}

        # Combine the three branches along the batch dimension
        for k in c:
            if k in [
                "vector",
                "crossattn",
                "concat",
                "audio_emb",
                "image_embeds",
                "landmarks",
                "masks",
                "gt",
            ]:
                c_out[k] = torch.cat((c_base[k], c_ref[k], c_audio_ref[k]), 0)
            else:
                c_out[k] = c[k]

        return torch.cat([x] * 3), torch.cat([s] * 3), c_out
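
# Note (a reading of the non-normalized branch above): with the three batch chunks
# e_uc (no audio, no reference), e_ref (reference only) and e_audio_ref (audio +
# reference), the default combination is a cascaded CFG,
#     e = e_uc + ref_ratio * (e_ref - e_uc) + audio_ratio * (e_audio_ref - e_ref),
# so `ref_ratio` strengthens the reference signal and `audio_ratio` strengthens the
# audio conditioning on top of it. Tiny numeric check (values illustrative):
def _example_audio_ref_combination():
    guider = AudioRefMultiCondGuider(
        audio_ratio=5.0, ref_ratio=3.0, use_normalized=False
    )
    e_uc = torch.randn(1, 4, 8, 8)
    e_ref = torch.randn_like(e_uc)
    e_audio_ref = torch.randn_like(e_uc)
    out = guider(torch.cat([e_uc, e_ref, e_audio_ref], dim=0), sigma=None)
    expected = 5.0 * (e_audio_ref - e_ref) + 3.0 * (e_ref - e_uc) + e_uc
    assert torch.allclose(out, expected)
    return out
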
class IdentityGuider(Guider):
    def __init__(self, *args, **kwargs):
        # self.num_frames = num_frames
        pass

    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        return x

    def prepare_inputs(
        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
    ) -> Tuple[torch.Tensor, float, Dict]:
        c_out = dict()

        for k in c:
            c_out[k] = c[k]

        return x, s, c_out
class LinearPredictionGuider(Guider):
    def __init__(
        self,
        max_scale: float,
        num_frames: int,
        min_scale: float = 1.0,
        additional_cond_keys: Optional[Union[List[str], str]] = None,
        only_first: bool = False,
    ):
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.num_frames = num_frames
        self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
        self.only_first = only_first
        if only_first:
            self.scale = torch.ones_like(self.scale) * max_scale
            self.scale[:, 0] = min_scale

        additional_cond_keys = default(additional_cond_keys, [])
        if isinstance(additional_cond_keys, str):
            additional_cond_keys = [additional_cond_keys]
        self.additional_cond_keys = additional_cond_keys

    def set_scale(self, scale: torch.Tensor):
        self.min_scale = scale
        self.scale = torch.linspace(
            self.min_scale, self.max_scale, self.num_frames
        ).unsqueeze(0)
        if self.only_first:
            self.scale = torch.ones_like(self.scale) * self.max_scale
            self.scale[:, 0] = self.min_scale
        print(self.scale)

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_u, x_c = x.chunk(2)

        x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
        x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
        scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
        scale = append_dims(scale, x_u.ndim).to(x_u.device)

        return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")

    def prepare_inputs(
        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        c_out = dict()

        for k in c:
            if (
                k
                in ["vector", "crossattn", "concat", "audio_emb", "masks", "gt"]
                + self.additional_cond_keys
            ):
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
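
# Sketch (illustrative): LinearPredictionGuider applies a per-frame CFG scale that
# ramps linearly from `min_scale` on the first frame to `max_scale` on the last
# (or, with `only_first=True`, uses `min_scale` only on frame 0). For example:
def _example_linear_frame_scales():
    guider = LinearPredictionGuider(max_scale=3.0, num_frames=5, min_scale=1.0)
    # guider.scale == tensor([[1.0, 1.5, 2.0, 2.5, 3.0]])
    return guider.scale
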
class LinearPredictionGuiderPlus(LinearPredictionGuider):
    def __init__(
        self,
        max_scale: float,
        num_frames: int,
        min_scale: float = 1.0,
        additional_cond_keys: Optional[Union[List[str], str]] = None,
    ):
        super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)

    def __call__(
        self, x: torch.Tensor, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x_u, x_c = x.chunk(2)

        x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
        x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
        scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
        scale = append_dims(scale, x_u.ndim).to(x_u.device)

        return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ..."), x_u
class TrianglePredictionGuider(LinearPredictionGuider):
    def __init__(
        self,
        max_scale: float,
        num_frames: int,
        min_scale: float = 1.0,
        period: Union[float, List[float]] = 1.0,
        period_fusing: Literal["mean", "multiply", "max"] = "max",
        additional_cond_keys: Optional[Union[List[str], str]] = None,
    ):
        super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
        values = torch.linspace(0, 1, num_frames)
        # Constructs a triangle wave
        if isinstance(period, float):
            period = [period]

        scales = []
        for p in period:
            scales.append(self.triangle_wave(values, p))

        if period_fusing == "mean":
            scale = sum(scales) / len(period)
        elif period_fusing == "multiply":
            scale = torch.prod(torch.stack(scales), dim=0)
        elif period_fusing == "max":
            scale = torch.max(torch.stack(scales), dim=0).values
        self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)

    def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
        return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
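
# Sketch (illustrative): TrianglePredictionGuider replaces the linear ramp with a
# triangle wave over the frame axis, so guidance is weakest at the wave's zeros and
# strongest at its peaks; multiple periods can be fused via mean/multiply/max.
def _example_triangle_frame_scales():
    guider = TrianglePredictionGuider(
        max_scale=3.0, num_frames=5, min_scale=1.0, period=1.0
    )
    # triangle_wave(linspace(0, 1, 5), 1.0) -> [0.0, 0.5, 1.0, 0.5, 0.0],
    # rescaled to guider.scale == tensor([[1.0, 2.0, 3.0, 2.0, 1.0]])
    return guider.scale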