import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from ...util import append_dims, default
logpy = logging.getLogger(__name__)
class Guider(ABC):
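    """Base interface for sampling guiders: `prepare_inputs` assembles the batched model
    inputs (latents, sigmas, conditioning) and `__call__` recombines the stacked model
    outputs into a single guided prediction."""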
@abstractmethod
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
pass
def prepare_inputs(
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
) -> Tuple[torch.Tensor, float, Dict]:
pass
class VanillaCFG(Guider):
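    """Classifier-free guidance over an [uncond, cond] batch:
    x_pred = x_u + scale * (x_c - x_u)."""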
def __init__(
self, scale: float, low_sigma: float = 0.0, high_sigma: float = float("inf")
):
self.scale = scale
self.low_sigma = low_sigma
self.high_sigma = high_sigma
def set_scale(self, scale: float):
self.scale = scale
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_u, x_c = x.chunk(2)
x_pred = x_u + self.scale * (x_c - x_u)
return x_pred
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
for k in c:
if k in [
"vector",
"crossattn",
"concat",
"audio_emb",
"image_embeds",
"landmarks",
"masks",
"gt",
"valence",
"arousal",
]:
c_out[k] = torch.cat((uc[k], c[k]), 0)
elif k == "reference":
c_out["reference"] = []
for i in range(len(c[k])):
c_out["reference"].append(torch.cat((uc[k][i], c[k][i]), 0))
else:
assert c[k] == uc[k]
c_out[k] = c[k]
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class VanillaSTG(Guider):
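    """Spatio-temporal guidance over three conditional branches (full, spatial attention
    skipped at `layer_skip`, temporal attention skipped at `layer_skip`): the prediction
    is pushed away from both perturbed branches."""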
def __init__(
self,
scale_spatial: float,
scale_temporal: float,
low_sigma: float = 0.0,
high_sigma: float = float("inf"),
layer_skip: int = 8,
):
self.scale_spatial = scale_spatial
self.scale_temporal = scale_temporal
self.low_sigma = low_sigma
self.high_sigma = high_sigma
self.layer_skip = layer_skip
def set_scale(self, scale_spatial: float, scale_temporal: float):
self.scale_spatial = scale_spatial
self.scale_temporal = scale_temporal
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_c, x_spatial, x_temporal = x.chunk(3)
x_pred = (
x_c
+ self.scale_spatial * (x_c - x_spatial)
+ self.scale_temporal * (x_c - x_temporal)
)
return x_pred
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
for k in c:
if k in [
"vector",
"crossattn",
"concat",
"audio_emb",
"image_embeds",
"landmarks",
"masks",
"gt",
"valence",
"arousal",
]:
c_out[k] = torch.cat((c[k], c[k], c[k]), 0)
elif k == "reference":
c_out["reference"] = []
for i in range(len(c[k])):
c_out["reference"].append(torch.cat((c[k][i], c[k][i], c[k][i]), 0))
else:
assert c[k] == uc[k]
c_out[k] = c[k]
c_out["skip_spatial_attention_at"] = [None, self.layer_skip, None]
c_out["skip_temporal_attention_at"] = [None, None, self.layer_skip]
return torch.cat([x] * 3), torch.cat([s] * 3), c_out
class MomentumBuffer:
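    """Momentum-weighted running sum of guidance updates:
    running_average = update + momentum * running_average."""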
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
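# Split v0 into its components parallel and orthogonal to v1 (per sample, over the last three dims).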
def project(v0: torch.Tensor, v1: torch.Tensor):
dtype = v0.dtype
v0, v1 = v0.double(), v1.double()
v1 = F.normalize(v1, dim=[-1, -2, -3])
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
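# Guidance update with optional momentum, an L2 clamp on the update norm (norm_threshold),
# and the component parallel to the conditional prediction scaled by eta; used by APGGuider.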
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
    momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
):
diff = pred_cond - pred_uncond
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
diff_parallel, diff_orthogonal = project(diff, pred_cond)
normalized_update = diff_orthogonal + eta * diff_parallel
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
return pred_guided
class APGGuider(VanillaCFG):
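    """CFG variant that combines the [uncond, cond] pair with `normalized_guidance`
    (momentum, norm clamping, projected update) instead of the plain linear rule."""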
def __init__(
self,
scale: float,
momentum: float = -0.75,
eta: float = 0.0,
norm_threshold: float = 2.5,
):
super().__init__(scale)
self.momentum_buffer = MomentumBuffer(momentum)
self.eta = eta
self.norm_threshold = norm_threshold
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_u, x_c = x.chunk(2)
return normalized_guidance(
x_c, x_u, self.scale, self.momentum_buffer, self.eta, self.norm_threshold
)
class VanillaCFGplusplus(VanillaCFG):
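    """Same combination as VanillaCFG, but also returns the unconditional branch x_u
    for samplers that consume it."""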
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_u, x_c = x.chunk(2)
x_pred = x_u + self.scale * (x_c - x_u)
return x_pred, x_u
class KarrasGuider(VanillaCFG):
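    """VanillaCFG variant whose `prepare_inputs` fills both halves of the batch with the
    conditional inputs (the unconditional dict is only used for the equality checks)."""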
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
for k in c:
if k in [
"vector",
"crossattn",
"concat",
"audio_emb",
"image_embeds",
"landmarks",
"valence",
"arousal",
]:
c_out[k] = torch.cat((c[k], c[k]), 0)
elif k == "reference":
c_out["reference"] = []
for i in range(len(c[k])):
c_out["reference"].append(torch.cat((c[k][i], c[k][i]), 0))
else:
assert c[k] == uc[k]
c_out[k] = c[k]
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class MultipleCondVanilla(Guider):
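    """Multi-condition guidance: the batch holds the fully conditioned branch followed by
    one branch per entry of `condition_names` with that condition replaced by its
    unconditional counterpart; the output is
    (1 + sum(scales)) * x_full_cond - sum_i scales[i] * x_partial_i."""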
def __init__(self, scales, condition_names) -> None:
assert len(scales) == len(condition_names)
self.scales = scales
# self.condition_names = condition_names
self.n_conditions = len(scales)
self.map_cond_name = {
"audio_emb": "audio_emb",
"cond_frames_without_noise": "crossattn",
"cond_frames": "concat",
}
self.condition_names = [
self.map_cond_name.get(cond_name, cond_name)
for cond_name in condition_names
]
print("Condition names: ", self.condition_names)
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
outs = x.chunk(self.n_conditions + 1)
x_full_cond = outs[0]
x_pred = (1 + sum(self.scales)) * x_full_cond
for i, scale in enumerate(self.scales):
x_pred -= scale * outs[i + 1]
return x_pred
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
# The first element is the full condition
for k in c:
if k in [
"vector",
"crossattn",
"concat",
"audio_emb",
"image_embeds",
"landmarks",
"masks",
"gt",
]:
c_out[k] = c[k]
else:
assert c[k] == uc[k]
c_out[k] = c[k]
# The rest are the conditions removed from the full condition
for cond_name in self.condition_names:
if not isinstance(cond_name, list):
cond_name = [cond_name]
for k in c:
if k in [
"vector",
"crossattn",
"concat",
"audio_emb",
"image_embeds",
"landmarks",
"masks",
"gt",
]:
c_out[k] = torch.cat(
(c_out[k], uc[k] if k in cond_name else c[k]), 0
)
return (
torch.cat([x] * (self.n_conditions + 1)),
torch.cat([s] * (self.n_conditions + 1)),
c_out,
)
class AudioRefMultiCondGuider(MultipleCondVanilla):
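    """Two-scale guidance over three branches (unconditional, reference-only, audio + reference):
    either the plain linear rule or, with `use_normalized=True`, APG-style normalized updates
    with per-branch momentum buffers."""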
def __init__(
self,
audio_ratio: float = 5.0,
ref_ratio: float = 3.0,
use_normalized: bool = False,
momentum: float = -0.75,
eta: float = 0.0,
norm_threshold: float = 2.5,
):
super().__init__(
scales=[audio_ratio, ref_ratio], condition_names=["audio_emb", "concat"]
)
self.audio_ratio = audio_ratio
self.ref_ratio = ref_ratio
self.use_normalized = use_normalized
print(f"Use normalized: {self.use_normalized}")
self.momentum_buffer = MomentumBuffer(momentum)
self.eta = eta
self.norm_threshold = norm_threshold
self.momentum_buffer_audio = MomentumBuffer(momentum)
self.momentum_buffer_ref = MomentumBuffer(momentum)
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
e_uc, e_ref, c_audio_ref = x.chunk(3)
if self.use_normalized:
# Normalized guidance version
# Compute diff for audio guidance
diff_audio = c_audio_ref - e_uc
if self.momentum_buffer_audio is not None:
self.momentum_buffer_audio.update(diff_audio)
diff_audio = self.momentum_buffer_audio.running_average
if self.norm_threshold > 0:
ones = torch.ones_like(diff_audio)
diff_norm = diff_audio.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale_factor = torch.minimum(ones, self.norm_threshold / diff_norm)
diff_audio = diff_audio * scale_factor
diff_audio_parallel, diff_audio_orthogonal = project(
diff_audio, c_audio_ref
)
normalized_update_audio = (
diff_audio_orthogonal + self.eta * diff_audio_parallel
)
guidance_audio = (self.audio_ratio - 1) * normalized_update_audio
# Compute diff for ref guidance
diff_ref = e_ref - e_uc
if self.momentum_buffer_ref is not None:
self.momentum_buffer_ref.update(diff_ref)
diff_ref = self.momentum_buffer_ref.running_average
if self.norm_threshold > 0:
ones = torch.ones_like(diff_ref)
diff_norm = diff_ref.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale_factor = torch.minimum(ones, self.norm_threshold / diff_norm)
diff_ref = diff_ref * scale_factor
diff_ref_parallel, diff_ref_orthogonal = project(diff_ref, e_ref)
normalized_update_ref = diff_ref_orthogonal + self.eta * diff_ref_parallel
guidance_ref = (self.ref_ratio - 1) * normalized_update_ref
e_final = e_uc + guidance_audio + guidance_ref
else:
# Original version
e_final = (
self.audio_ratio * (c_audio_ref - e_ref)
+ self.ref_ratio * (e_ref - e_uc)
+ e_uc
)
return e_final
def set_scale(self, scale: torch.Tensor):
self.audio_ratio = float(scale[0])
self.ref_ratio = float(scale[1])
print(f"Audio ratio: {self.audio_ratio}, Ref ratio: {self.ref_ratio}")
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
        # Branch 1, e_uc (no audio, no ref concat): both keys taken from the unconditional dict
        c_base = {k: v for k, v in c.items()}
        c_base["crossattn"] = uc["crossattn"]
        c_base["concat"] = uc["concat"]
        # Branch 3, c_audio_ref (all conditions: audio + ref): every key kept from c
        c_audio_ref = {k: v for k, v in c.items()}
        # Branch 2, e_ref (no audio, with ref concat): crossattn taken from the unconditional dict
        c_ref = {k: v for k, v in c.items()}
        c_ref["crossattn"] = uc["crossattn"]
        # Stack the three branches along the batch dimension (order matches __call__'s chunk)
for k in c:
if k in [
"vector",
"crossattn",
"concat",
"audio_emb",
"image_embeds",
"landmarks",
"masks",
"gt",
]:
c_out[k] = torch.cat((c_base[k], c_ref[k], c_audio_ref[k]), 0)
else:
c_out[k] = c[k]
return torch.cat([x] * 3), torch.cat([s] * 3), c_out
class IdentityGuider(Guider):
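    """No-op guider: the model output and the conditioning are passed through unchanged."""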
def __init__(self, *args, **kwargs):
# self.num_frames = num_frames
pass
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
return x
def prepare_inputs(
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
) -> Tuple[torch.Tensor, float, Dict]:
c_out = dict()
for k in c:
c_out[k] = c[k]
return x, s, c_out
class LinearPredictionGuider(Guider):
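    """Per-frame classifier-free guidance: the scale is interpolated linearly from
    `min_scale` to `max_scale` across `num_frames` (with `only_first=True`, the first
    frame uses `min_scale` and all other frames use `max_scale`)."""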
def __init__(
self,
max_scale: float,
num_frames: int,
min_scale: float = 1.0,
additional_cond_keys: Optional[Union[List[str], str]] = None,
only_first=False,
):
self.min_scale = min_scale
self.max_scale = max_scale
self.num_frames = num_frames
self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
self.only_first = only_first
if only_first:
self.scale = torch.ones_like(self.scale) * max_scale
self.scale[:, 0] = min_scale
additional_cond_keys = default(additional_cond_keys, [])
if isinstance(additional_cond_keys, str):
additional_cond_keys = [additional_cond_keys]
self.additional_cond_keys = additional_cond_keys
def set_scale(self, scale: torch.Tensor):
self.min_scale = scale
self.scale = torch.linspace(
self.min_scale, self.max_scale, self.num_frames
).unsqueeze(0)
if self.only_first:
self.scale = torch.ones_like(self.scale) * self.max_scale
self.scale[:, 0] = self.min_scale
        logpy.info(f"Per-frame guidance scale: {self.scale}")
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_u, x_c = x.chunk(2)
x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
scale = append_dims(scale, x_u.ndim).to(x_u.device)
return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
def prepare_inputs(
self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
c_out = dict()
for k in c:
if (
k
in ["vector", "crossattn", "concat", "audio_emb", "masks", "gt"]
+ self.additional_cond_keys
):
c_out[k] = torch.cat((uc[k], c[k]), 0)
else:
assert c[k] == uc[k]
c_out[k] = c[k]
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class LinearPredictionGuiderPlus(LinearPredictionGuider):
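    """LinearPredictionGuider that additionally returns the unconditional branch x_u."""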
def __init__(
self,
max_scale: float,
num_frames: int,
min_scale: float = 1.0,
additional_cond_keys: Optional[Union[List[str], str]] = None,
):
super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_u, x_c = x.chunk(2)
x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
scale = append_dims(scale, x_u.ndim).to(x_u.device)
return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ..."), x_u
class TrianglePredictionGuider(LinearPredictionGuider):
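    """LinearPredictionGuider whose per-frame scale follows a triangle wave over the clip
    (one or several periods, fused by mean/multiply/max), mapped to [min_scale, max_scale]."""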
def __init__(
self,
max_scale: float,
num_frames: int,
min_scale: float = 1.0,
period: float | List[float] = 1.0,
period_fusing: Literal["mean", "multiply", "max"] = "max",
additional_cond_keys: Optional[Union[List[str], str]] = None,
):
super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
values = torch.linspace(0, 1, num_frames)
# Constructs a triangle wave
if isinstance(period, float):
period = [period]
scales = []
for p in period:
scales.append(self.triangle_wave(values, p))
if period_fusing == "mean":
scale = sum(scales) / len(period)
elif period_fusing == "multiply":
scale = torch.prod(torch.stack(scales), dim=0)
elif period_fusing == "max":
scale = torch.max(torch.stack(scales), dim=0).values
self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)
def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
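
if __name__ == "__main__":
    # Minimal sanity-check sketch (not part of the module's API): the tensors below are
    # placeholders standing in for real latents, sigmas, conditioning and the denoiser output.
    guider = VanillaCFG(scale=3.0)
    x = torch.randn(2, 4, 32, 32)
    sigma = torch.ones(2)
    c = {"crossattn": torch.randn(2, 77, 1024)}
    uc = {"crossattn": torch.zeros(2, 77, 1024)}
    # prepare_inputs doubles the batch to [uncond, cond]
    x_in, s_in, c_in = guider.prepare_inputs(x, sigma, c, uc)
    denoised = torch.randn_like(x_in)  # stand-in for denoiser(x_in, s_in, cond=c_in)
    guided = guider(denoised, s_in)
    assert guided.shape == x.shape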