Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn.functional as F | |
import ctypes | |
import numpy as np | |
from einops import rearrange, repeat | |
from scipy.optimize import linear_sum_assignment | |
from typing import Optional, Union, Tuple, List, Callable, Dict | |
from model.modules.dift_utils import gen_nn_map | |
class AttentionBase:
    """Minimal attention processor that counts layer calls and steps.

    Each invocation stands for one attention layer: the work is delegated to
    ``forward`` and the layer counter is then advanced. When the counter
    reaches ``num_att_layers`` it wraps back to 0, ``cur_step`` is incremented
    and the ``after_step`` hook fires. ``num_att_layers`` starts at -1, so the
    step counter never advances until a caller assigns a real layer count.
    """

    def __init__(self):
        self.cur_step = 0          # completed steps (one per num_att_layers calls)
        self.num_att_layers = -1   # layers per step; -1 means "not configured"
        self.cur_att_layer = 0     # index of the next layer within the step

    def after_step(self):
        """Hook invoked once per completed step; does nothing by default."""

    def __call__(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sim: torch.Tensor,
        attn: torch.Tensor,
        is_cross: bool,
        place_in_unet: str,
        num_heads: int,
        **kwargs
    ) -> torch.Tensor:
        """Run ``forward`` for this layer, then advance the layer/step counters."""
        result = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer != self.num_att_layers:
            return result
        # Last layer of the current step: wrap around and move to the next step.
        self.cur_att_layer = 0
        self.cur_step += 1
        self.after_step()
        return result

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sim: torch.Tensor,
        attn: torch.Tensor,
        is_cross: bool,
        place_in_unet: str,
        num_heads: int,
        **kwargs
    ) -> torch.Tensor:
        """Apply the precomputed attention weights ``attn`` to ``v``.

        ``attn`` and ``v`` are batched as ``(batch * heads, ...)``. The
        per-head outputs are merged into the feature axis, giving
        ``(batch, tokens, heads * head_dim)``. ``q``, ``k`` and ``sim`` are
        ignored here.
        """
        mixed = torch.einsum("bij,bjd->bid", attn, v)
        merged_batch, n_tokens, head_dim = mixed.shape
        batch = merged_batch // num_heads
        # Split the leading (batch, heads) axis, move heads next to features, merge them.
        mixed = mixed.reshape(batch, num_heads, n_tokens, head_dim).permute(0, 2, 1, 3)
        return mixed.reshape(batch, n_tokens, num_heads * head_dim)

    def reset(self):
        """Rewind the step and layer counters to the start."""
        self.cur_att_layer = 0
        self.cur_step = 0
class DirectionalAttentionControl(AttentionBase):
    """Self-attention controller that steers a target branch with a source branch.

    The attention batch is split with ``chunk(3)`` into (source, middle, target)
    groups. Source and middle run ordinary self-attention. The target's keys and
    values are replaced, concatenated or blended with the source's according to
    ``mode``, but only for steps in ``step_idx`` and layers whose
    ``cur_att_layer // 2`` is in ``layer_idx``.

    Recognised ``**kwargs`` (defaults in parentheses):
        structural_alignment (False): re-order target queries to match source
            queries at ``cur_step <= 0`` (see ``align_queries_via_matching``).
        style_transfer_only (False): stored; not read by this class.
        alpha (0.5): interpolation weight for the lerp/slerp modes.
        beta (0.5): weight of the appearance cost vs. the positional cost in
            query matching; matching is skipped when ``beta <= 0``.
        support_new_object (True): forwarded as ``return_newness`` to the DIFT
            correspondence step.
        mode ("normal"): one of "normal", "masa", "concat", "concat_dift",
            "lerp", "lerp_dift", "slerp", "log_slerp", "slerp_dift",
            "log_slerp_dift". Unknown values behave like "normal".
        concat_key_scale (0.85): factor applied to target keys in "concat" mode.
    """

    # Upper bound of the default ``layer_idx`` range per architecture; unknown
    # model types fall back to the "SD" value.
    MODEL_TYPE = {"SD": 16, "SDXL": 70}

    def __init__(
        self,
        start_step: int = 4,
        start_layer: int = 10,
        layer_idx: Optional[List[int]] = None,
        step_idx: Optional[List[int]] = None,
        total_steps: int = 50,
        model_type: str = "SD",
        **kwargs
    ):
        """Configure where (layers) and when (steps) the control is active.

        Args:
            start_step: first step of the default ``step_idx`` range.
            start_layer: first index of the default ``layer_idx`` range.
            layer_idx: explicit layer indices, compared against
                ``cur_att_layer // 2``; defaults to
                ``range(start_layer, total_layers)``.
            step_idx: explicit step indices; defaults to
                ``range(start_step, total_steps)``.
            total_steps: end (exclusive) of the default ``step_idx`` range.
            model_type: key into ``MODEL_TYPE``.
            **kwargs: see the class docstring.
        """
        super().__init__()
        self.total_steps = total_steps
        self.total_layers = self.MODEL_TYPE.get(model_type, 16)
        self.start_step = start_step
        self.start_layer = start_layer
        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
        self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
        self.w = 1.0  # not read by this class
        self.structural_alignment = kwargs.get("structural_alignment", False)
        self.style_transfer_only = kwargs.get("style_transfer_only", False)
        self.alpha = kwargs.get("alpha", 0.5)
        self.beta = kwargs.get("beta", 0.5)
        self.newness = kwargs.get("support_new_object", True)
        self.mode = kwargs.get("mode", "normal")
        self.concat_key_scale = kwargs.get("concat_key_scale", 0.85)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sim: torch.Tensor,
        attn: torch.Tensor,
        is_cross: bool,
        place_in_unet: str,
        num_heads: int,
        **kwargs
    ) -> torch.Tensor:
        """Attention for the (source, middle, target) batch groups.

        Cross-attention, steps outside ``step_idx`` and layers outside
        ``layer_idx`` use the plain readout from ``AttentionBase``. Otherwise
        the inputs are split with ``chunk(3)``: source and middle attend to
        themselves, and the target is handled by ``apply_mode``. The outputs are
        concatenated back along dim 0 in the same order.
        """
        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)

        q_s, q_middle, q_t = q.chunk(3)
        k_s, k_middle, k_t = k.chunk(3)
        v_s, v_middle, v_t = v.chunk(3)
        attn_s, attn_middle, attn_t = attn.chunk(3)

        out_s = self.attn_batch(q_s, k_s, v_s, sim, attn_s, is_cross, place_in_unet, num_heads, **kwargs)
        out_middle = self.attn_batch(q_middle, k_middle, v_middle, sim, attn_middle, is_cross, place_in_unet, num_heads, **kwargs)

        # Query re-ordering only happens at cur_step <= 0, i.e. only if such a
        # step is part of step_idx. Keep the result on q_s's device so it can
        # be combined with k_s/v_s (which live there) inside attn_batch.
        if self.cur_step <= 0 and self.beta > 0 and self.structural_alignment:
            q_t = self.align_queries_via_matching(q_s, q_t, beta=self.beta, device=q_s.device)

        out_t = self.apply_mode(q_t, k_s, k_t, v_s, v_t, attn_t, sim, is_cross, place_in_unet, num_heads, **kwargs)
        return torch.cat([out_s, out_middle, out_t], dim=0)

    def apply_mode(
        self,
        q_t: torch.Tensor,
        k_s: torch.Tensor,
        k_t: torch.Tensor,
        v_s: torch.Tensor,
        v_t: torch.Tensor,
        attn_t: torch.Tensor,
        sim: torch.Tensor,
        is_cross: bool,
        place_in_unet: str,
        num_heads: int,
        **kwargs
    ) -> torch.Tensor:
        """Compute the target output using the keys/values selected by ``self.mode``.

        DIFT-based modes are downgraded to "normal" while ``cur_step <= 0``.
        The DIFT modes read ``kwargs["dift_features"]``.
        """
        mode = self.mode
        if 'dift' in mode and self.cur_step <= 0:
            mode = 'normal'

        if mode == "concat":
            # Concatenation is along dim 0; attn_batch folds that axis into the
            # token axis, so target queries attend over source + scaled target keys.
            out_t = self.attn_batch(
                q_t, torch.cat([k_s, self.concat_key_scale * k_t]), torch.cat([v_s, v_t]),
                sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs
            )
        elif mode == "concat_dift":
            updated_k_s, updated_v_s, _ = self.process_dift_features(kwargs.get("dift_features"), k_s, k_t, v_s, v_t)
            out_t = self.attn_batch(
                q_t, torch.cat([updated_k_s, k_t]), torch.cat([updated_v_s, v_t]),
                sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs
            )
        elif mode == "masa":
            # Target queries read the source keys/values only.
            out_t = self.attn_batch(q_t, k_s, v_s, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs)
        elif mode == "normal":
            out_t = self.attn_batch(q_t, k_t, v_t, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs)
        elif mode == "lerp":
            time = self.alpha
            k_lerp = k_s + time * (k_t - k_s)
            v_lerp = v_s + time * (v_t - v_s)
            out_t = self.attn_batch(q_t, k_lerp, v_lerp, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs)
        elif mode == "lerp_dift":
            updated_k_s, updated_v_s, newness = self.process_dift_features(
                kwargs.get("dift_features"), k_s, k_t, v_s, v_t, return_newness=self.newness
            )
            out_t = self.apply_lerp_dift(q_t, k_s, k_t, v_s, v_t, updated_k_s, updated_v_s, newness, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs)
        elif mode in ("slerp", "log_slerp"):
            time = self.alpha
            # Keys keep the source norm; values also interpolate their magnitude.
            k_slerp = self.slerp_fixed_length_batch(k_s, k_t, t=time)
            v_slerp = self.slerp_batch(v_s, v_t, t=time, log_slerp="log" in mode)
            out_t = self.attn_batch(q_t, k_slerp, v_slerp, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs)
        elif mode in ("slerp_dift", "log_slerp_dift"):
            out_t = self.apply_slerp_dift(q_t, k_s, k_t, v_s, v_t, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs)
        else:
            out_t = self.attn_batch(q_t, k_t, v_t, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs)
        return out_t

    def attn_batch(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sim: torch.Tensor,
        attn: torch.Tensor,
        is_cross: bool,
        place_in_unet: str,
        num_heads: int,
        **kwargs
    ) -> torch.Tensor:
        """Softmax attention with the batch folded into the token axis.

        Inputs are ``(batch * heads, tokens, dim)``. They are rearranged to
        ``(heads, batch * tokens, dim)``, so every query attends over the keys of
        all batch items of ``k``/``v``; ``k``/``v`` may therefore carry a
        different batch size than ``q``. ``sim`` and ``attn`` are ignored. The
        logits are multiplied by ``kwargs["scale"]`` (1.0, i.e. unscaled, when
        absent). Returns ``(batch, tokens, heads * dim)`` with batch taken from ``q``.
        """
        b = q.shape[0] // num_heads
        q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
        k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
        v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
        scale = kwargs.get("scale", 1.0)
        sim_batched = torch.einsum("h i d, h j d -> h i j", q, k) * scale
        attn_batched = sim_batched.softmax(-1)
        out = torch.einsum("h i j, h j d -> h i d", attn_batched, v)
        out = rearrange(out, "h (b n) d -> b n (h d)", b=b)
        return out

    def slerp(self, x: torch.Tensor, y: torch.Tensor, t: float = 0.5) -> torch.Tensor:
        """Spherical interpolation of whole tensors, keeping the norm of ``x``.

        ``y`` is rescaled to ``x``'s global L2 norm before interpolating.
        Returns ``x`` unchanged when either input has (near) zero norm or the
        two are parallel.
        """
        x_norm = x.norm(p=2)
        y_norm = y.norm(p=2)
        # A zero-length x would divide by zero below; zero is also the correct
        # fixed-length result in that case.
        if x_norm < 1e-12 or y_norm < 1e-12:
            return x
        y_same_length = (y / y_norm) * x_norm
        dot_xy = (x * y_same_length).sum()
        cos_theta = torch.clamp(dot_xy / (x_norm * x_norm), -1.0, 1.0)
        theta = torch.acos(cos_theta)
        # torch.isclose requires both operands to share a dtype, so compare
        # against a zero of theta's own dtype/device (works for fp16/fp64 too).
        if torch.isclose(theta, torch.zeros_like(theta)):
            return x
        sin_theta = torch.sin(theta)
        s1 = torch.sin((1.0 - t) * theta) / sin_theta
        s2 = torch.sin(t * theta) / sin_theta
        return s1 * x + s2 * y_same_length

    def slerp_batch(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        t: float = 0.5,
        eps: float = 1e-12,
        log_slerp: bool = False
    ) -> torch.Tensor:
        """Row-wise SLERP over the last dim with interpolated magnitudes.

        Directions are interpolated spherically. Magnitudes are interpolated
        linearly, ``(1 - t)|x| + t|y|``, or geometrically when ``log_slerp`` is
        set, ``|x|^(1-t) |y|^t``. Rows where ``y`` has (near) zero length return
        ``x`` unchanged. Parallel rows keep ``x``'s direction with the
        interpolated magnitude (the continuous limit of the general formula).
        """
        x_norm = x.norm(p=2, dim=-1, keepdim=True)
        y_norm = y.norm(p=2, dim=-1, keepdim=True)
        y_zero_mask = y_norm < eps

        x_unit = x / (x_norm + eps)
        y_unit = y / (y_norm + eps)
        cos_theta = torch.clamp((x_unit * y_unit).sum(dim=-1, keepdim=True), -1.0, 1.0)
        theta = torch.acos(cos_theta)
        sin_theta = torch.sin(theta)
        theta_zero_mask = theta.abs() < 1e-7
        sin_theta_safe = torch.where(sin_theta.abs() < eps, torch.ones_like(sin_theta), sin_theta)
        s1 = torch.sin((1.0 - t) * theta) / sin_theta_safe
        s2 = torch.sin(t * theta) / sin_theta_safe
        dir_interp = s1 * x_unit + s2 * y_unit

        if log_slerp:
            mag_interp = (x_norm ** (1.0 - t)) * (y_norm ** t)
        else:
            mag_interp = (1.0 - t) * x_norm + t * y_norm

        out = mag_interp * dir_interp
        # Parallel rows: the direction is x's, but the magnitude must still be
        # interpolated (previously x was returned as-is, a discontinuity).
        out = torch.where(theta_zero_mask, mag_interp * x_unit, out)
        # A zero-length y has no direction to move towards.
        out = torch.where(y_zero_mask, x, out)
        return out

    def slerp_fixed_length_batch(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        t: float = 0.5,
        eps: float = 1e-12
    ) -> torch.Tensor:
        """Row-wise SLERP over the last dim that preserves the norm of each ``x`` row.

        Each ``y`` row is rescaled to the length of the matching ``x`` row
        before interpolating. Rows where ``y`` has (near) zero length or is
        parallel to ``x`` return ``x``.
        """
        x_norm = x.norm(p=2, dim=-1, keepdim=True)
        y_norm = y.norm(p=2, dim=-1, keepdim=True)
        y_zero_mask = (y_norm < eps)
        y_normalized = y / (y_norm + eps)
        y_same_length = y_normalized * x_norm
        dot_xy = (x * y_same_length).sum(dim=-1, keepdim=True)
        cos_theta = torch.clamp(dot_xy / (x_norm * x_norm + eps), -1.0, 1.0)
        theta = torch.acos(cos_theta)
        sin_theta = torch.sin(theta)
        sin_theta_safe = torch.where(sin_theta.abs() < eps, torch.ones_like(sin_theta), sin_theta)
        s1 = torch.sin((1.0 - t) * theta) / sin_theta_safe
        s2 = torch.sin(t * theta) / sin_theta_safe
        out = s1 * x + s2 * y_same_length
        theta_zero_mask = (theta.abs() < 1e-7)
        out = torch.where(y_zero_mask | theta_zero_mask, x, out)
        return out

    def apply_lerp_dift(
        self,
        q_t: torch.Tensor,
        k_s: torch.Tensor,
        k_t: torch.Tensor,
        v_s: torch.Tensor,
        v_t: torch.Tensor,
        updated_k_s: torch.Tensor,
        updated_v_s: torch.Tensor,
        newness: torch.Tensor,
        sim: torch.Tensor,
        attn_t: torch.Tensor,
        is_cross: bool,
        place_in_unet: str,
        num_heads: int,
        **kwargs
    ) -> torch.Tensor:
        """Target attention over a ``newness``-gated linear blend of source/target K/V.

        With ``alpha > 0``, positions are weighted by ``newness``: they take
        ``k_t``/``v_t`` in proportion to ``newness`` and the lerp between source
        and target for the remainder. With ``alpha <= 0`` the raw source K/V are
        used, which matches "masa" mode.
        """
        # NOTE(review): updated_k_s / updated_v_s (the DIFT-warped source) are
        # accepted but not used; the lerp mixes the *unwarped* k_s/v_s. Confirm
        # this is intended (apply_slerp_dift does use the warped tensors).
        alpha = self.alpha
        k_lerp = k_s + alpha * (k_t - k_s)
        v_lerp = v_s + alpha * (v_t - v_s)
        if alpha > 0:
            k_t_new = newness * k_t + (1 - newness) * k_lerp
            v_t_new = newness * v_t + (1 - newness) * v_lerp
        else:
            k_t_new = k_s
            v_t_new = v_s
        out_t = self.attn_batch(q_t, k_t_new, v_t_new, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs)
        return out_t

    def apply_slerp_dift(
        self,
        q_t: torch.Tensor,
        k_s: torch.Tensor,
        k_t: torch.Tensor,
        v_s: torch.Tensor,
        v_t: torch.Tensor,
        sim: torch.Tensor,
        attn_t: torch.Tensor,
        is_cross: bool,
        place_in_unet: str,
        num_heads: int,
        **kwargs
    ) -> torch.Tensor:
        """Target attention over a ``newness``-gated SLERP toward the DIFT-warped source.

        Reads ``kwargs["dift_features"]``. Uses log-magnitude interpolation for
        values when ``self.mode`` contains "log". With ``alpha <= 0`` the raw
        source K/V are used, which matches "masa" mode.
        """
        updated_k_s, updated_v_s, newness = self.process_dift_features(
            kwargs.get("dift_features"), k_s, k_t, v_s, v_t, return_newness=self.newness
        )
        alpha = self.alpha
        log_slerp = "log" in self.mode
        # Interpolate k_t -> updated_k_s with t = 1 - alpha: alpha = 1 keeps
        # k_t/v_t, and smaller alpha moves toward the warped source.
        k_slerp = self.slerp_fixed_length_batch(k_t, updated_k_s, t=1-alpha)
        v_slerp = self.slerp_batch(v_t, updated_v_s, t=1-alpha, log_slerp=log_slerp)
        if alpha > 0:
            k_t_new = newness * k_t + (1 - newness) * k_slerp
            v_t_new = newness * v_t + (1 - newness) * v_slerp
        else:
            k_t_new = k_s
            v_t_new = v_s
        out_t = self.attn_batch(q_t, k_t_new, v_t_new, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs)
        return out_t

    def process_dift_features(
        self,
        dift_features: torch.Tensor,
        k_s: torch.Tensor,
        k_t: torch.Tensor,
        v_s: torch.Tensor,
        v_t: torch.Tensor,
        return_newness: bool = True
    ):
        """Remap source keys/values using DIFT feature correspondences.

        ``dift_features`` is split with ``chunk(3)``; only the first element of
        the source and target groups is used. The token count of ``k_s``/``v_s``
        must be a perfect square, since tokens are reshaped to a
        ``sqrt(N) x sqrt(N)`` grid. The DIFT maps are bilinearly resized to that
        grid and handed to ``gen_nn_map`` (defined elsewhere; presumably a
        nearest-neighbour remapping — verify).

        Returns:
            ``(updated_k_s, updated_v_s, newness)``. The first two have the same
            shape as ``k_s``/``v_s``. ``newness`` is reshaped to ``(1, -1, 1)``
            and is all zeros when ``return_newness`` is False.

        Raises:
            ValueError: if ``dift_features`` is None.
        """
        if dift_features is None:
            raise ValueError("DIFT-based attention modes require 'dift_features' to be passed in kwargs")
        dift_s, _, dift_t = dift_features.chunk(3)
        # (b*h, N, d) -> (b*h, d, sqrt(N), sqrt(N)) -> (b*h*d, sqrt(N), sqrt(N))
        k_s1 = k_s.permute(0, 2, 1).reshape(k_s.shape[0], k_s.shape[2], int(k_s.shape[1]**0.5), -1)
        v_s1 = v_s.permute(0, 2, 1).reshape(v_s.shape[0], v_s.shape[2], int(v_s.shape[1]**0.5), -1)
        k_s1 = k_s1.reshape(-1, k_s1.shape[-2], k_s1.shape[-1])
        v_s1 = v_s1.reshape(-1, v_s1.shape[-2], v_s1.shape[-1])
        ################# uncomment only for visualization #################
        # result = gen_nn_map(
        #     [dift_s[0], dift_s[0]],
        #     dift_s[0],
        #     dift_t[0],
        #     kernel_size=1,
        #     stride=1,
        #     device=k_s.device,
        #     timestep=self.cur_step,
        #     visualize=True,
        #     return_newness=return_newness
        # )
        #####################################################################
        resized_src = F.interpolate(dift_s[0].unsqueeze(0), size=k_s1.shape[-1], mode='bilinear', align_corners=False).squeeze(0)
        resized_tgt = F.interpolate(dift_t[0].unsqueeze(0), size=k_s1.shape[-1], mode='bilinear', align_corners=False).squeeze(0)
        result = gen_nn_map(
            [k_s1, v_s1],
            resized_src,
            resized_tgt,
            kernel_size=1,
            stride=1,
            device=k_s.device,
            timestep=self.cur_step,
            return_newness=return_newness
        )
        if return_newness:
            updated_k_s, updated_v_s, newness = result
        else:
            updated_k_s, updated_v_s = result
            newness = torch.zeros_like(updated_k_s[:1]).to(k_s.device)
        # (1, N, 1) so it broadcasts over the (b*h) and feature axes.
        newness = newness.view(-1).unsqueeze(0).unsqueeze(-1)
        updated_k_s = updated_k_s.reshape(k_s.shape[0], k_s.shape[2], -1).permute(0, 2, 1)
        updated_v_s = updated_v_s.reshape(v_s.shape[0], v_s.shape[2], -1).permute(0, 2, 1)
        return updated_k_s, updated_v_s, newness

    def sinkhorn(self, cost_matrix, max_iter=50, epsilon=1e-8):
        """Sinkhorn scaling of ``exp(-cost / std(cost))``; not called in this class.

        The scalings are updated as ``u = 1 / (K v)`` and ``v = 1 / (K^T u)``
        (unit targets rather than ``1/n`` and ``1/m``). Iteration stops after
        ``max_iter`` rounds or once ``u`` changes by less than ``epsilon``.
        Returns ``diag(u) K diag(v)``.
        """
        n, m = cost_matrix.shape
        K = torch.exp(-cost_matrix / cost_matrix.std())  # kernelized cost matrix
        u = torch.ones(n, device=cost_matrix.device) / n
        v = torch.ones(m, device=cost_matrix.device) / m
        for _ in range(max_iter):
            u_prev = u.clone()
            u = 1.0 / (K @ v)
            v = 1.0 / (K.T @ u)
            if torch.max(torch.abs(u - u_prev)) < epsilon:
                break
        P = torch.diag(u) @ K @ torch.diag(v)
        return P

    def align_queries_via_matching(
        self,
        q_s: torch.Tensor,
        q_t: torch.Tensor,
        beta: float = 0.5,
        device: Optional[Union[str, torch.device]] = "cuda"
    ):
        """Re-order each batch item's target queries by a one-to-one matching to the source.

        For each leading index ``b``, the cost of pairing source token ``i``
        with target token ``j`` is
        ``beta * z(-cos(q_s[b, i], q_t[b, j])) + (1 - beta) * z(sqrt(|i - j|))``,
        where ``z`` standardizes a matrix to zero mean and unit std. The
        Hungarian algorithm (``scipy.optimize.linear_sum_assignment``) picks the
        assignment, and row ``i`` of the result is the target query assigned to
        source token ``i``.

        Args:
            q_s, q_t: 3-D query tensors with matching leading size; the token
                counts must agree so the two costs can be summed.
            beta: weight of the appearance (cosine) cost vs. the positional cost.
            device: where to compute and return the result. ``None`` means
                ``q_s.device``. A CUDA device is also replaced by ``q_s.device``
                when CUDA is unavailable, instead of crashing.
        """
        if device is None or (str(device).startswith("cuda") and not torch.cuda.is_available()):
            device = q_s.device
        q_s = q_s.to(device)
        q_t = q_t.to(device)
        B = q_s.shape[0]

        # Positional cost is identical for every batch item: compute it once.
        positions = torch.arange(q_t.shape[1], device=device)
        pos_cost = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1)).float() ** 0.5
        pos_cost_std = (pos_cost - pos_cost.mean()) / (pos_cost.std() + 1e-8)

        # Unit-normalise once so cosine similarity becomes a matmul; this avoids
        # materialising an (N, N, d) broadcast tensor per batch item.
        q_s_unit = F.normalize(q_s, p=2, dim=-1, eps=1e-8)
        q_t_unit = F.normalize(q_t, p=2, dim=-1, eps=1e-8)

        q_t_updated = torch.zeros_like(q_t, device=device)
        for b in range(B):
            appearance_cost = -(q_s_unit[b] @ q_t_unit[b].T)
            appearance_std = (appearance_cost - appearance_cost.mean()) / (appearance_cost.std() + 1e-8)
            cost_matrix = beta * appearance_std + (1.0 - beta) * pos_cost_std
            _, col_ind = linear_sum_assignment(cost_matrix.detach().cpu().numpy())
            q_t_updated[b] = q_t[b][col_ind]
            # Alternative (unused): soft matching via self.sinkhorn(cost_matrix),
            # taking argmax per row or the barycentric P @ q_t[b].
        return q_t_updated