import torch import torch.nn.functional as F import ctypes import numpy as np from einops import rearrange, repeat from scipy.optimize import linear_sum_assignment from typing import Optional, Union, Tuple, List, Callable, Dict from model.modules.dift_utils import gen_nn_map class AttentionBase: def __init__(self): self.cur_step = 0 self.num_att_layers = -1 self.cur_att_layer = 0 def after_step(self): pass def __call__( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sim: torch.Tensor, attn: torch.Tensor, is_cross: bool, place_in_unet: str, num_heads: int, **kwargs ) -> torch.Tensor: out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) self.cur_att_layer += 1 if self.cur_att_layer == self.num_att_layers: self.cur_att_layer = 0 self.cur_step += 1 self.after_step() return out def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sim: torch.Tensor, attn: torch.Tensor, is_cross: bool, place_in_unet: str, num_heads: int, **kwargs ) -> torch.Tensor: out = torch.einsum('b i j, b j d -> b i d', attn, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads) return out def reset(self): self.cur_step = 0 self.cur_att_layer = 0 class DirectionalAttentionControl(AttentionBase): MODEL_TYPE = {"SD": 16, "SDXL": 70} def __init__( self, start_step: int = 4, start_layer: int = 10, layer_idx: Optional[List[int]] = None, step_idx: Optional[List[int]] = None, total_steps: int = 50, model_type: str = "SD", **kwargs ): super().__init__() self.total_steps = total_steps self.total_layers = self.MODEL_TYPE.get(model_type, 16) self.start_step = start_step self.start_layer = start_layer self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers)) self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps)) self.w = 1.0 self.structural_alignment = kwargs.get("structural_alignment", False) self.style_transfer_only = kwargs.get("style_transfer_only", False) self.alpha = kwargs.get("alpha", 0.5) self.beta = kwargs.get("beta", 0.5) self.newness = kwargs.get("support_new_object", True) self.mode = kwargs.get("mode", "normal") def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sim: torch.Tensor, attn: torch.Tensor, is_cross: bool, place_in_unet: str, num_heads: int, **kwargs ) -> torch.Tensor: if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) q_s, q_middle, q_t = q.chunk(3) k_s, k_middle, k_t = k.chunk(3) v_s, v_middle, v_t = v.chunk(3) attn_s, attn_middle, attn_t = attn.chunk(3) out_s = self.attn_batch(q_s, k_s, v_s, sim, attn_s, is_cross, place_in_unet, num_heads, **kwargs) out_middle = self.attn_batch(q_middle, k_middle, v_middle, sim, attn_middle, is_cross, place_in_unet, num_heads, **kwargs) if self.cur_step <= 0 and self.beta > 0 and \ self.structural_alignment: q_t = self.align_queries_via_matching(q_s, q_t, beta=self.beta) out_t = self.apply_mode(q_t, k_s, k_t, v_s, v_t, attn_t, sim, is_cross, place_in_unet, num_heads, **kwargs) out = torch.cat([out_s, out_middle, out_t], dim=0) return out def apply_mode( self, q_t: torch.Tensor, k_s: torch.Tensor, k_t: torch.Tensor, v_s: torch.Tensor, v_t: torch.Tensor, attn_t: torch.Tensor, sim: torch.Tensor, is_cross: bool, place_in_unet: str, num_heads: int, **kwargs ) -> torch.Tensor: mode = self.mode if 'dift' in mode and self.cur_step <= 0: mode = 'normal' if mode == "concat": out_t = self.attn_batch( q_t, torch.cat([k_s, 0.85 * k_t]), torch.cat([v_s, v_t]), sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs ) elif mode == "concat_dift": updated_k_s, updated_v_s, _ = self.process_dift_features(kwargs.get("dift_features"), k_s, k_t, v_s, v_t) out_t = self.attn_batch( q_t, torch.cat([updated_k_s, k_t]), torch.cat([updated_v_s, v_t]), sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs ) elif mode == "masa": out_t = self.attn_batch(q_t, k_s, v_s, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs) elif mode == "normal": out_t = self.attn_batch(q_t, k_t, v_t, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs) elif mode == "lerp": time = self.alpha k_lerp = k_s + time * (k_t - k_s) v_lerp = v_s + time * (v_t - v_s) out_t = self.attn_batch(q_t, k_lerp, v_lerp, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs) elif mode == "lerp_dift": updated_k_s, updated_v_s, newness = self.process_dift_features( kwargs.get("dift_features"), k_s, k_t, v_s, v_t, return_newness=self.newness ) out_t = self.apply_lerp_dift(q_t, k_s, k_t, v_s, v_t, updated_k_s, updated_v_s, newness, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs) elif mode in ("slerp", "log_slerp"): time = self.alpha k_slerp = self.slerp_fixed_length_batch(k_s, k_t, t=time) v_slerp = self.slerp_batch(v_s, v_t, t=time, log_slerp="log" in mode) out_t = self.attn_batch(q_t, k_slerp, v_slerp, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs) elif mode in ("slerp_dift", "log_slerp_dift"): out_t = self.apply_slerp_dift(q_t, k_s, k_t, v_s, v_t, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs) else: out_t = self.attn_batch(q_t, k_t, v_t, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs) return out_t def attn_batch( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sim: torch.Tensor, attn: torch.Tensor, is_cross: bool, place_in_unet: str, num_heads: int, **kwargs ) -> torch.Tensor: b = q.shape[0] // num_heads q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads) v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads) scale = kwargs.get("scale", 1.0) sim_batched = torch.einsum("h i d, h j d -> h i j", q, k) * scale attn_batched = sim_batched.softmax(-1) out = torch.einsum("h i j, h j d -> h i d", attn_batched, v) out = rearrange(out, "h (b n) d -> b n (h d)", b=b) return out def slerp(self, x: torch.Tensor, y: torch.Tensor, t: float = 0.5) -> torch.Tensor: x_norm = x.norm(p=2) y_norm = y.norm(p=2) if y_norm < 1e-12: return x y_normalized = y / y_norm y_same_length = y_normalized * x_norm dot_xy = (x * y_same_length).sum() cos_theta = torch.clamp(dot_xy / (x_norm * x_norm), -1.0, 1.0) theta = torch.acos(cos_theta) if torch.isclose(theta, torch.tensor(0.0)): return x sin_theta = torch.sin(theta) s1 = torch.sin((1.0 - t) * theta) / sin_theta s2 = torch.sin(t * theta) / sin_theta return s1 * x + s2 * y_same_length def slerp_batch( self, x: torch.Tensor, y: torch.Tensor, t: float = 0.5, eps: float = 1e-12, log_slerp: bool = False ) -> torch.Tensor: """ Variation of SLERP for batches that allows for linear or logarithmic interpolation of magnitudes. """ x_norm = x.norm(p=2, dim=-1, keepdim=True) y_norm = y.norm(p=2, dim=-1, keepdim=True) y_zero_mask = (y_norm < eps) x_unit = x / (x_norm + eps) y_unit = y / (y_norm + eps) dot_xy = (x_unit * y_unit).sum(dim=-1, keepdim=True) cos_theta = torch.clamp(dot_xy, -1.0, 1.0) theta = torch.acos(cos_theta) sin_theta = torch.sin(theta) theta_zero_mask = (theta.abs() < 1e-7) sin_theta_safe = torch.where(sin_theta.abs() < eps, torch.ones_like(sin_theta), sin_theta) s1 = torch.sin((1.0 - t) * theta) / sin_theta_safe s2 = torch.sin(t * theta) / sin_theta_safe dir_interp = s1 * x_unit + s2 * y_unit if not log_slerp: mag_interp = (1.0 - t) * x_norm + t * y_norm else: mag_interp = (x_norm ** (1.0 - t)) * (y_norm ** t) out = mag_interp * dir_interp out = torch.where(y_zero_mask | theta_zero_mask, x, out) return out def slerp_fixed_length_batch( self, x: torch.Tensor, y: torch.Tensor, t: float = 0.5, eps: float = 1e-12 ) -> torch.Tensor: """ performing SLERP while preserving the norm of the source tensor x """ x_norm = x.norm(p=2, dim=-1, keepdim=True) y_norm = y.norm(p=2, dim=-1, keepdim=True) y_zero_mask = (y_norm < eps) y_normalized = y / (y_norm + eps) y_same_length = y_normalized * x_norm dot_xy = (x * y_same_length).sum(dim=-1, keepdim=True) cos_theta = torch.clamp(dot_xy / (x_norm * x_norm + eps), -1.0, 1.0) theta = torch.acos(cos_theta) sin_theta = torch.sin(theta) sin_theta_safe = torch.where(sin_theta.abs() < eps, torch.ones_like(sin_theta), sin_theta) s1 = torch.sin((1.0 - t) * theta) / sin_theta_safe s2 = torch.sin(t * theta) / sin_theta_safe out = s1 * x + s2 * y_same_length theta_zero_mask = (theta.abs() < 1e-7) out = torch.where(y_zero_mask | theta_zero_mask, x, out) return out def apply_lerp_dift( self, q_t: torch.Tensor, k_s: torch.Tensor, k_t: torch.Tensor, v_s: torch.Tensor, v_t: torch.Tensor, updated_k_s: torch.Tensor, updated_v_s: torch.Tensor, newness: torch.Tensor, sim: torch.Tensor, attn_t: torch.Tensor, is_cross: bool, place_in_unet: str, num_heads: int, **kwargs ) -> torch.Tensor: alpha = self.alpha k_lerp = k_s + alpha * (k_t - k_s) v_lerp = v_s + alpha * (v_t - v_s) if alpha > 0: k_t_new = newness * k_t + (1 - newness) * k_lerp v_t_new = newness * v_t + (1 - newness) * v_lerp else: k_t_new = k_s v_t_new = v_s out_t = self.attn_batch(q_t, k_t_new, v_t_new, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs) return out_t def apply_slerp_dift( self, q_t: torch.Tensor, k_s: torch.Tensor, k_t: torch.Tensor, v_s: torch.Tensor, v_t: torch.Tensor, sim: torch.Tensor, attn_t: torch.Tensor, is_cross: bool, place_in_unet: str, num_heads: int, **kwargs ) -> torch.Tensor: updated_k_s, updated_v_s, newness = self.process_dift_features( kwargs.get("dift_features"), k_s, k_t, v_s, v_t, return_newness=self.newness ) alpha = self.alpha log_slerp = "log" in self.mode # Interpolate from k_t->updated_k_s so that if alpha=0, we get k_t k_slerp = self.slerp_fixed_length_batch(k_t, updated_k_s, t=1-alpha) v_slerp = self.slerp_batch(v_t, updated_v_s, t=1-alpha, log_slerp=log_slerp) if alpha > 0: k_t_new = newness * k_t + (1 - newness) * k_slerp v_t_new = newness * v_t + (1 - newness) * v_slerp else: k_t_new = k_s v_t_new = v_s out_t = self.attn_batch(q_t, k_t_new, v_t_new, sim, attn_t, is_cross, place_in_unet, num_heads, **kwargs) return out_t def process_dift_features( self, dift_features: torch.Tensor, k_s: torch.Tensor, k_t: torch.Tensor, v_s: torch.Tensor, v_t: torch.Tensor, return_newness: bool = True ): dift_s, _, dift_t = dift_features.chunk(3) k_s1 = k_s.permute(0, 2, 1).reshape(k_s.shape[0], k_s.shape[2], int(k_s.shape[1]**0.5), -1) v_s1 = v_s.permute(0, 2, 1).reshape(v_s.shape[0], v_s.shape[2], int(v_s.shape[1]**0.5), -1) k_s1 = k_s1.reshape(-1, k_s1.shape[-2], k_s1.shape[-1]) v_s1 = v_s1.reshape(-1, v_s1.shape[-2], v_s1.shape[-1]) ################# uncomment only for visualization ################# # result = gen_nn_map( # [dift_s[0], dift_s[0]], # dift_s[0], # dift_t[0], # kernel_size=1, # stride=1, # device=k_s.device, # timestep=self.cur_step, # visualize=True, # return_newness=return_newness # ) ##################################################################### resized_src = F.interpolate(dift_s[0].unsqueeze(0), size=k_s1.shape[-1], mode='bilinear', align_corners=False).squeeze(0) resized_tgt = F.interpolate(dift_t[0].unsqueeze(0), size=k_s1.shape[-1], mode='bilinear', align_corners=False).squeeze(0) result = gen_nn_map( [k_s1, v_s1], resized_src, resized_tgt, kernel_size=1, stride=1, device=k_s.device, timestep=self.cur_step, return_newness=return_newness ) if return_newness: updated_k_s, updated_v_s, newness = result else: updated_k_s, updated_v_s = result newness = torch.zeros_like(updated_k_s[:1]).to(k_s.device) newness = newness.view(-1).unsqueeze(0).unsqueeze(-1) updated_k_s = updated_k_s.reshape(k_s.shape[0], k_s.shape[2], -1).permute(0, 2, 1) updated_v_s = updated_v_s.reshape(v_s.shape[0], v_s.shape[2], -1).permute(0, 2, 1) return updated_k_s, updated_v_s, newness def sinkhorn(self, cost_matrix, max_iter=50, epsilon=1e-8): n, m = cost_matrix.shape K = torch.exp(-cost_matrix / cost_matrix.std()) # Kernelized cost matrix u = torch.ones(n, device=cost_matrix.device) / n v = torch.ones(m, device=cost_matrix.device) / m for _ in range(max_iter): u_prev = u.clone() u = 1.0 / (K @ v) v = 1.0 / (K.T @ u) if torch.max(torch.abs(u - u_prev)) < epsilon: break P = torch.diag(u) @ K @ torch.diag(v) return P def align_queries_via_matching(self, q_s: torch.Tensor, q_t: torch.Tensor, beta: float = 0.5, device: str = "cuda"): q_s = q_s.to(device) q_t = q_t.to(device) B, _, _ = q_s.shape q_t_updated = torch.zeros_like(q_t, device=device) for b in range(B): ########################### L2 ############################## # cost_matrix1 = (q_s[b].unsqueeze(1) - q_t[b].unsqueeze(0)).pow(2).sum(dim=-1) ######################### cosine ############################ cost_matrix1 = - F.cosine_similarity( q_s[b].unsqueeze(1), q_t[b].unsqueeze(0), dim=-1) ############################################################# # cost_matrix2 = (q_t[b].unsqueeze(1) - q_t[b].unsqueeze(0)).pow(2).sum(dim=-1) cost_matrix2 = torch.abs(torch.arange(q_t[b].shape[0], device=device).unsqueeze(0) - torch.arange(q_t[b].shape[0], device=device).unsqueeze(1)).float() cost_matrix2 = cost_matrix2 ** 0.5 # cost_matrix2 = torch.where(cost_matrix2 > 0, 1.0, 0.0) mean1 = cost_matrix1.mean() std1 = cost_matrix1.std() mean2 = cost_matrix2.mean() std2 = cost_matrix2.std() cost_func_1_std = (cost_matrix1 - mean1) / (std1 + 1e-8) cost_func_2_std = (cost_matrix2 - mean2) / (std2 + 1e-8) cost_matrix = beta * cost_func_1_std + (1.0 - beta) * cost_func_2_std cost_np = cost_matrix.detach().cpu().numpy() row_ind, col_ind = linear_sum_assignment(cost_np) q_t_updated[b] = q_t[b][col_ind] # P = self.sinkhorn(cost_matrix) # col_ind = P.argmax(dim=1) # idea 1 # q_t_updated[b] = q_t[b][col_ind] # idea 2 # q_t_updated[b] = P @ q_t[b] return q_t_updated