import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from ...util import append_dims, default
logpy = logging.getLogger(__name__)
class Guider(ABC):
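    """Abstract base class for sampling guiders.

    ``prepare_inputs`` builds the batched inputs for all guidance branches
    (e.g. unconditional and conditional), and ``__call__`` recombines the
    corresponding denoiser outputs into a single guided prediction.
    """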
@abstractmethod
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
pass
def prepare_inputs(
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
) -> Tuple[torch.Tensor, float, Dict]:
pass
class VanillaCFG(Guider):
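    """Standard classifier-free guidance.

    The batch holds the unconditional and conditional branches; the guided
    prediction is ``x_u + scale * (x_c - x_u)``.
    """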
def __init__(
self, scale: float, low_sigma: float = 0.0, high_sigma: float = float("inf")
):
self.scale = scale
self.low_sigma = low_sigma
self.high_sigma = high_sigma
def set_scale(self, scale: float):
self.scale = scale
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_u, x_c = x.chunk(2)
x_pred = x_u + self.scale * (x_c - x_u)
return x_pred
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
for k in c:
if k in [
"vector",
"crossattn",
"concat",
"audio_emb",
"image_embeds",
"landmarks",
"masks",
"gt",
"valence",
"arousal",
]:
c_out[k] = torch.cat((uc[k], c[k]), 0)
elif k == "reference":
c_out["reference"] = []
for i in range(len(c[k])):
c_out["reference"].append(torch.cat((uc[k][i], c[k][i]), 0))
else:
assert c[k] == uc[k]
c_out[k] = c[k]
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class VanillaSTG(Guider):
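    """Spatio-temporal guidance (STG-style).

    Three branches are run: the full model, one with the spatial attention
    layer at ``layer_skip`` skipped, and one with the temporal attention layer
    at ``layer_skip`` skipped. The guided prediction moves away from the
    perturbed branches:
    ``x_c + scale_spatial * (x_c - x_spatial) + scale_temporal * (x_c - x_temporal)``.
    """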
def __init__(
self,
scale_spatial: float,
scale_temporal: float,
low_sigma: float = 0.0,
high_sigma: float = float("inf"),
layer_skip: int = 8,
):
self.scale_spatial = scale_spatial
self.scale_temporal = scale_temporal
self.low_sigma = low_sigma
self.high_sigma = high_sigma
self.layer_skip = layer_skip
def set_scale(self, scale_spatial: float, scale_temporal: float):
self.scale_spatial = scale_spatial
self.scale_temporal = scale_temporal
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_c, x_spatial, x_temporal = x.chunk(3)
x_pred = (
x_c
+ self.scale_spatial * (x_c - x_spatial)
+ self.scale_temporal * (x_c - x_temporal)
)
return x_pred
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
for k in c:
if k in [
"vector",
"crossattn",
"concat",
"audio_emb",
"image_embeds",
"landmarks",
"masks",
"gt",
"valence",
"arousal",
]:
c_out[k] = torch.cat((c[k], c[k], c[k]), 0)
elif k == "reference":
c_out["reference"] = []
for i in range(len(c[k])):
c_out["reference"].append(torch.cat((c[k][i], c[k][i], c[k][i]), 0))
else:
assert c[k] == uc[k]
c_out[k] = c[k]
c_out["skip_spatial_attention_at"] = [None, self.layer_skip, None]
c_out["skip_temporal_attention_at"] = [None, None, self.layer_skip]
return torch.cat([x] * 3), torch.cat([s] * 3), c_out
class MomentumBuffer:
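    """Momentum-weighted running average of guidance updates.

    ``update`` sets ``running_average = update_value + momentum * running_average``;
    ``momentum`` is typically negative (see ``APGGuider``'s default of -0.75).
    """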
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def project(v0: torch.Tensor, v1: torch.Tensor):
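    """Split ``v0`` into components parallel and orthogonal to ``v1``."""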
dtype = v0.dtype
v0, v1 = v0.double(), v1.double()
v1 = F.normalize(v1, dim=[-1, -2, -3])
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
    momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
):
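    """Projected, norm-limited guidance update (APG-style).

    The difference between the conditional and unconditional predictions is
    optionally momentum-averaged and norm-clipped, then split into components
    parallel and orthogonal to ``pred_cond``; ``eta`` scales the parallel part
    before the usual ``(guidance_scale - 1)`` update is applied.
    """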
diff = pred_cond - pred_uncond
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
diff_parallel, diff_orthogonal = project(diff, pred_cond)
normalized_update = diff_orthogonal + eta * diff_parallel
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
return pred_guided
class APGGuider(VanillaCFG):
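    """CFG guider that applies the ``normalized_guidance`` update above
    (adaptive-projected-guidance style), with momentum, norm clipping and an
    ``eta``-scaled parallel component."""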
def __init__(
self,
scale: float,
momentum: float = -0.75,
eta: float = 0.0,
norm_threshold: float = 2.5,
):
super().__init__(scale)
self.momentum_buffer = MomentumBuffer(momentum)
self.eta = eta
self.norm_threshold = norm_threshold
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_u, x_c = x.chunk(2)
return normalized_guidance(
x_c, x_u, self.scale, self.momentum_buffer, self.eta, self.norm_threshold
)
class VanillaCFGplusplus(VanillaCFG):
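    """Like ``VanillaCFG``, but also returns the unconditional prediction,
    for samplers that need the unconditional branch (CFG++-style updates)."""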
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_u, x_c = x.chunk(2)
x_pred = x_u + self.scale * (x_c - x_u)
return x_pred, x_u
class KarrasGuider(VanillaCFG):
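    """Same guidance rule as ``VanillaCFG``, but ``prepare_inputs`` feeds the
    conditional inputs to both branches instead of concatenating ``uc`` and ``c``."""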
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
for k in c:
if k in [
"vector",
"crossattn",
"concat",
"audio_emb",
"image_embeds",
"landmarks",
"valence",
"arousal",
]:
c_out[k] = torch.cat((c[k], c[k]), 0)
elif k == "reference":
c_out["reference"] = []
for i in range(len(c[k])):
c_out["reference"].append(torch.cat((c[k][i], c[k][i]), 0))
else:
assert c[k] == uc[k]
c_out[k] = c[k]
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class MultipleCondVanilla(Guider):
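    """Guidance over several conditions at once.

    Branch 0 carries the full conditioning; each further branch has the
    corresponding named condition replaced by its unconditional counterpart.
    The guided prediction is
    ``(1 + sum(scales)) * x_full - sum_i scales[i] * x_without_i``, i.e.
    ``x_full + sum_i scales[i] * (x_full - x_without_i)``.
    """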
def __init__(self, scales, condition_names) -> None:
assert len(scales) == len(condition_names)
self.scales = scales
# self.condition_names = condition_names
self.n_conditions = len(scales)
self.map_cond_name = {
"audio_emb": "audio_emb",
"cond_frames_without_noise": "crossattn",
"cond_frames": "concat",
}
self.condition_names = [
self.map_cond_name.get(cond_name, cond_name)
for cond_name in condition_names
]
print("Condition names: ", self.condition_names)
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
outs = x.chunk(self.n_conditions + 1)
x_full_cond = outs[0]
x_pred = (1 + sum(self.scales)) * x_full_cond
for i, scale in enumerate(self.scales):
x_pred -= scale * outs[i + 1]
return x_pred
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
# The first element is the full condition
for k in c:
if k in [
"vector",
"crossattn",
"concat",
"audio_emb",
"image_embeds",
"landmarks",
"masks",
"gt",
]:
c_out[k] = c[k]
else:
assert c[k] == uc[k]
c_out[k] = c[k]
        # Each remaining branch has one of the named conditions replaced by its unconditional counterpart
for cond_name in self.condition_names:
if not isinstance(cond_name, list):
cond_name = [cond_name]
for k in c:
if k in [
"vector",
"crossattn",
"concat",
"audio_emb",
"image_embeds",
"landmarks",
"masks",
"gt",
]:
c_out[k] = torch.cat(
(c_out[k], uc[k] if k in cond_name else c[k]), 0
)
return (
torch.cat([x] * (self.n_conditions + 1)),
torch.cat([s] * (self.n_conditions + 1)),
c_out,
)
class AudioRefMultiCondGuider(MultipleCondVanilla):
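    """Two-term guidance over audio and reference conditioning.

    Three branches are prepared (see ``prepare_inputs``): a base branch with
    unconditional ``crossattn`` and ``concat``, a reference branch with
    unconditional ``crossattn`` only, and a fully conditioned branch. They are
    combined as
    ``e_uc + ref_ratio * (e_ref - e_uc) + audio_ratio * (e_audio_ref - e_ref)``;
    with ``use_normalized=True`` both terms use the APG-style projected update instead.
    """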
def __init__(
self,
audio_ratio: float = 5.0,
ref_ratio: float = 3.0,
use_normalized: bool = False,
momentum: float = -0.75,
eta: float = 0.0,
norm_threshold: float = 2.5,
):
super().__init__(
scales=[audio_ratio, ref_ratio], condition_names=["audio_emb", "concat"]
)
self.audio_ratio = audio_ratio
self.ref_ratio = ref_ratio
self.use_normalized = use_normalized
print(f"Use normalized: {self.use_normalized}")
self.momentum_buffer = MomentumBuffer(momentum)
self.eta = eta
self.norm_threshold = norm_threshold
self.momentum_buffer_audio = MomentumBuffer(momentum)
self.momentum_buffer_ref = MomentumBuffer(momentum)
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
e_uc, e_ref, c_audio_ref = x.chunk(3)
if self.use_normalized:
# Normalized guidance version
# Compute diff for audio guidance
diff_audio = c_audio_ref - e_uc
if self.momentum_buffer_audio is not None:
self.momentum_buffer_audio.update(diff_audio)
diff_audio = self.momentum_buffer_audio.running_average
if self.norm_threshold > 0:
ones = torch.ones_like(diff_audio)
diff_norm = diff_audio.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale_factor = torch.minimum(ones, self.norm_threshold / diff_norm)
diff_audio = diff_audio * scale_factor
diff_audio_parallel, diff_audio_orthogonal = project(
diff_audio, c_audio_ref
)
normalized_update_audio = (
diff_audio_orthogonal + self.eta * diff_audio_parallel
)
guidance_audio = (self.audio_ratio - 1) * normalized_update_audio
# Compute diff for ref guidance
diff_ref = e_ref - e_uc
if self.momentum_buffer_ref is not None:
self.momentum_buffer_ref.update(diff_ref)
diff_ref = self.momentum_buffer_ref.running_average
if self.norm_threshold > 0:
ones = torch.ones_like(diff_ref)
diff_norm = diff_ref.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale_factor = torch.minimum(ones, self.norm_threshold / diff_norm)
diff_ref = diff_ref * scale_factor
diff_ref_parallel, diff_ref_orthogonal = project(diff_ref, e_ref)
normalized_update_ref = diff_ref_orthogonal + self.eta * diff_ref_parallel
guidance_ref = (self.ref_ratio - 1) * normalized_update_ref
e_final = e_uc + guidance_audio + guidance_ref
else:
            # Standard (non-normalized) multi-condition guidance
e_final = (
self.audio_ratio * (c_audio_ref - e_ref)
+ self.ref_ratio * (e_ref - e_uc)
+ e_uc
)
return e_final
def set_scale(self, scale: torch.Tensor):
self.audio_ratio = float(scale[0])
self.ref_ratio = float(scale[1])
print(f"Audio ratio: {self.audio_ratio}, Ref ratio: {self.ref_ratio}")
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
        # e_uc branch: replace both "crossattn" and "concat" with their unconditional counterparts
        c_base = {k: v for k, v in c.items()}
        c_base["crossattn"] = uc["crossattn"]
        c_base["concat"] = uc["concat"]
        # e_audio_ref branch: keep the full conditioning (audio and reference)
        c_audio_ref = {k: v for k, v in c.items()}
        # e_ref branch: keep the reference "concat" but use the unconditional "crossattn"
        c_ref = {k: v for k, v in c.items()}
        c_ref["crossattn"] = uc["crossattn"]
# Combine all conditions
for k in c:
if k in [
"vector",
"crossattn",
"concat",
"audio_emb",
"image_embeds",
"landmarks",
"masks",
"gt",
]:
c_out[k] = torch.cat((c_base[k], c_ref[k], c_audio_ref[k]), 0)
else:
c_out[k] = c[k]
return torch.cat([x] * 3), torch.cat([s] * 3), c_out
class IdentityGuider(Guider):
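    """No-op guider: returns the model output and conditioning unchanged."""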
def __init__(self, *args, **kwargs):
# self.num_frames = num_frames
pass
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
return x
def prepare_inputs(
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
) -> Tuple[torch.Tensor, float, Dict]:
c_out = dict()
for k in c:
c_out[k] = c[k]
return x, s, c_out
class LinearPredictionGuider(Guider):
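    """Per-frame classifier-free guidance for video.

    The guidance scale increases linearly from ``min_scale`` to ``max_scale``
    across the ``num_frames`` frames; with ``only_first=True`` every frame uses
    ``max_scale`` except the first, which uses ``min_scale``.
    """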
def __init__(
self,
max_scale: float,
num_frames: int,
min_scale: float = 1.0,
additional_cond_keys: Optional[Union[List[str], str]] = None,
only_first=False,
):
self.min_scale = min_scale
self.max_scale = max_scale
self.num_frames = num_frames
self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
self.only_first = only_first
if only_first:
self.scale = torch.ones_like(self.scale) * max_scale
self.scale[:, 0] = min_scale
additional_cond_keys = default(additional_cond_keys, [])
if isinstance(additional_cond_keys, str):
additional_cond_keys = [additional_cond_keys]
self.additional_cond_keys = additional_cond_keys
def set_scale(self, scale: torch.Tensor):
self.min_scale = scale
self.scale = torch.linspace(
self.min_scale, self.max_scale, self.num_frames
).unsqueeze(0)
if self.only_first:
self.scale = torch.ones_like(self.scale) * self.max_scale
self.scale[:, 0] = self.min_scale
print(self.scale)
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_u, x_c = x.chunk(2)
x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
scale = append_dims(scale, x_u.ndim).to(x_u.device)
return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
def prepare_inputs(
self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
c_out = dict()
for k in c:
if (
k
in ["vector", "crossattn", "concat", "audio_emb", "masks", "gt"]
+ self.additional_cond_keys
):
c_out[k] = torch.cat((uc[k], c[k]), 0)
else:
assert c[k] == uc[k]
c_out[k] = c[k]
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class LinearPredictionGuiderPlus(LinearPredictionGuider):
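    """``LinearPredictionGuider`` variant that also returns the unconditional
    prediction (kept in ``b t ...`` layout) alongside the guided output."""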
def __init__(
self,
max_scale: float,
num_frames: int,
min_scale: float = 1.0,
additional_cond_keys: Optional[Union[List[str], str]] = None,
):
super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_u, x_c = x.chunk(2)
x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
scale = append_dims(scale, x_u.ndim).to(x_u.device)
return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ..."), x_u
class TrianglePredictionGuider(LinearPredictionGuider):
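    """Per-frame guidance scale shaped as a triangle wave over the frame axis.

    One wave is built per entry of ``period``; the waves are fused by mean,
    product, or maximum, then rescaled to ``[min_scale, max_scale]``.
    """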
def __init__(
self,
max_scale: float,
num_frames: int,
min_scale: float = 1.0,
period: float | List[float] = 1.0,
period_fusing: Literal["mean", "multiply", "max"] = "max",
additional_cond_keys: Optional[Union[List[str], str]] = None,
):
super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
values = torch.linspace(0, 1, num_frames)
# Constructs a triangle wave
if isinstance(period, float):
period = [period]
scales = []
for p in period:
scales.append(self.triangle_wave(values, p))
if period_fusing == "mean":
scale = sum(scales) / len(period)
elif period_fusing == "multiply":
scale = torch.prod(torch.stack(scales), dim=0)
elif period_fusing == "max":
scale = torch.max(torch.stack(scales), dim=0).values
self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)
def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
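# Minimal usage sketch (illustrative only; ``denoiser``, ``x``, ``sigma``,
# ``cond`` and ``uncond`` are assumed to come from the surrounding sampler and
# are not defined in this module):
#
#     guider = VanillaCFG(scale=7.5)
#     x_in, s_in, c_in = guider.prepare_inputs(x, sigma, cond, uncond)
#     denoised = denoiser(x_in, s_in, c_in)   # batched uncond + cond forward pass
#     guided = guider(denoised, sigma)        # recombine into the guided prediction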