# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def my_scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    special_token_weight=1.0,
    special_token_indices=None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Computes scaled dot-product attention with additional control over specific tokens.

    This is a re-implementation of the scaled dot-product attention mechanism that
    returns both the attention output and the attention map, as a tuple
    ``(attn_output, attn_weight)``. It also exposes a scalar ``special_token_weight``
    that modifies the attention logits of the tokens listed in ``special_token_indices``.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias

    if special_token_indices is not None and special_token_weight != 1.0:
        # Boost the logits of the special tokens. For special_token_weight > 1 the max
        # scales up positive logits while leaving negative logits unchanged.
        bs = attn_weight.shape[0]
        attn_weight[torch.arange(bs), :, :, special_token_indices] = torch.max(
            attn_weight[torch.arange(bs), :, :, special_token_indices],
            attn_weight[torch.arange(bs), :, :, special_token_indices] * special_token_weight,
        )

    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value, attn_weight
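
# --- Illustrative sanity check (not part of the original file) -----------------
# A minimal sketch of how `my_scaled_dot_product_attention` behaves: with the
# default `special_token_weight=1.0` it should match PyTorch's fused
# F.scaled_dot_product_attention, and with a larger weight it can only increase
# the attention mass on the chosen token columns. All tensor shapes and the
# special token indices below are assumptions for illustration only.
def _demo_my_scaled_dot_product_attention():
    """Illustrative only; call manually (e.g. from a REPL)."""
    bs, heads, q_len, kv_len, head_dim = 2, 8, 16, 77, 64
    query = torch.randn(bs, heads, q_len, head_dim)
    key = torch.randn(bs, heads, kv_len, head_dim)
    value = torch.randn(bs, heads, kv_len, head_dim)

    out, attn_map = my_scaled_dot_product_attention(query, key, value)
    ref = F.scaled_dot_product_attention(query, key, value)
    assert torch.allclose(out, ref, atol=1e-4)  # matches the fused kernel
    # each query's attention weights sum to 1
    assert torch.allclose(attn_map.sum(dim=-1), torch.ones(bs, heads, q_len), atol=1e-4)

    # One (hypothetical) special token index per batch item; a weight > 1
    # can only increase the attention paid to that token's column.
    special_idx = torch.tensor([5, 7])
    _, boosted_map = my_scaled_dot_product_attention(
        query,
        key,
        value,
        special_token_weight=2.0,
        special_token_indices=special_idx,
    )
    base = attn_map[torch.arange(bs), :, :, special_idx]
    boosted = boosted_map[torch.arange(bs), :, :, special_idx]
    assert (boosted + 1e-6 >= base).all()
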
class AttnProcessor(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn,
        hidden_states,
        qformer_tokens_out=None,
        special_token_indices=None,
        inference_mode=None,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        special_token_weight=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
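
# --- Hedged usage sketch (not part of the original file) -----------------------
# Shows how `AttnProcessor` plugs into a diffusers `Attention` block, which
# dispatches its forward pass to the processor's `__call__`. The import path,
# constructor arguments, and tensor sizes are assumptions based on recent
# diffusers releases, not code from the original repository.
def _demo_attn_processor():
    """Illustrative only; requires the `diffusers` package."""
    from diffusers.models.attention_processor import Attention  # assumed import path

    attn_block = Attention(
        query_dim=320,
        cross_attention_dim=768,
        heads=8,
        dim_head=40,
        processor=AttnProcessor(),
    )
    hidden_states = torch.randn(2, 64, 320)  # (batch, image tokens, channels)
    text_embeds = torch.randn(2, 77, 768)    # (batch, text tokens, cross-attn dim)
    out = attn_block(hidden_states, encoder_hidden_states=text_embeds)
    assert out.shape == hidden_states.shape
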
class NestedAttnProcessor(torch.nn.Module):
    r"""
    Nested attention processor in the style of IP-Adapter, for PyTorch 2.0.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, normalize_factor=1.0):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "NestedAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.normalize_factor = normalize_factor

        self.nested_to_k = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )
        self.nested_to_v = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )

    def __call__(
        self,
        attn,
        hidden_states,
        qformer_tokens_out,
        special_token_indices,
        inference_mode=False,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        special_token_weight=1.0,
    ):
        assert (
            special_token_indices.shape[0] > 0
        ), "special_token_indices should not be empty"

        # If inference_mode is True, the code assumes that CFG is used: the first half
        # of the batch holds the null prompt and the second half holds the prompt.
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        bs = hidden_states.shape[0]

        if input_ndim == 4:
            bs, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(bs, channel, height * width).transpose(
                1, 2
            )

        bs, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, bs
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                bs, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states
                )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(bs, -1, attn.heads, head_dim).transpose(1, 2)

        # nested attention: the image queries attend to the q-former tokens
        nested_key = self.nested_to_k(qformer_tokens_out)
        nested_value = self.nested_to_v(qformer_tokens_out)

        nested_key = nested_key.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
        nested_value = nested_value.view(bs, -1, attn.heads, head_dim).transpose(1, 2)

        nested_hidden_states = F.scaled_dot_product_attention(
            query,
            nested_key,
            nested_value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
        )

        # normalize V_q: rescale the nested attention output to the norm of the
        # special token's textual value vector (times normalize_factor)
        textual_values_norms = torch.norm(
            value[torch.arange(bs), :, special_token_indices], dim=-1
        )
        nested_hidden_states = (
            torch.nn.functional.normalize(nested_hidden_states, p=2, dim=-1)
            * self.normalize_factor
        )
        nested_hidden_states = (
            textual_values_norms.view(bs, -1, 1, 1) * nested_hidden_states
        )

        # outer attention: zero out the special-token values so that their
        # contribution can be replaced by the nested attention output below
        value_without_special_tokens = value.clone()
        if inference_mode:
            value_without_special_tokens[
                bs // 2 : bs, :, special_token_indices, :
            ] = 0.0
        else:
            value_without_special_tokens[
                torch.arange(bs), :, special_token_indices, :
            ] = 0.0

        hidden_states_without_special_tokens, attn_weight = (
            my_scaled_dot_product_attention(
                query,
                key,
                value_without_special_tokens,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False,
                special_token_weight=special_token_weight,
                special_token_indices=special_token_indices,
            )
        )

        # add the special token values
        if inference_mode:
            special_token_attn_weight = attn_weight[
                bs // 2 : bs, :, :, special_token_indices
            ]
        else:
            special_token_attn_weight = attn_weight[
                torch.arange(bs), :, :, special_token_indices
            ]

        if inference_mode:
            special_token_weighted_values = (
                special_token_attn_weight * nested_hidden_states[bs // 2 : bs]
            )
        else:
            special_token_weighted_values = (
                special_token_attn_weight.unsqueeze(-1) * nested_hidden_states
            )

        if inference_mode:
            hidden_states = hidden_states_without_special_tokens
            hidden_states[bs // 2 : bs] += special_token_weighted_values
        else:
            hidden_states = (
                hidden_states_without_special_tokens + special_token_weighted_values
            )

        # arrange hidden states
        hidden_states = hidden_states.transpose(1, 2).reshape(
            bs, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                bs, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
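
# --- Hedged wiring sketch (not part of the original file) ----------------------
# One way such processors are typically registered on a Stable Diffusion UNet:
# self-attention layers ("attn1") keep the plain `AttnProcessor`, while
# cross-attention layers ("attn2") get a `NestedAttnProcessor`. The naming
# conventions and config fields (`attn_processors`, `block_out_channels`,
# `cross_attention_dim`, `set_attn_processor`) follow standard diffusers usage;
# the helper itself is an illustrative assumption, not the original repo's setup.
def _build_attn_processors(unet, normalize_factor=1.0):
    """Illustrative only; requires a diffusers `UNet2DConditionModel`."""
    attn_procs = {}
    for name in unet.attn_processors.keys():
        is_self_attn = name.endswith("attn1.processor")
        cross_attention_dim = None if is_self_attn else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        if is_self_attn:
            attn_procs[name] = AttnProcessor(hidden_size=hidden_size)
        else:
            attn_procs[name] = NestedAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                normalize_factor=normalize_factor,
            )
    return attn_procs


# Typical registration and call pattern (sketch):
#
#   unet.set_attn_processor(_build_attn_processors(unet))
#   noise_pred = unet(
#       latents, timestep, encoder_hidden_states=text_embeds,
#       cross_attention_kwargs={
#           "qformer_tokens_out": qformer_tokens_out,
#           "special_token_indices": special_token_indices,
#           # optionally "inference_mode": True during CFG sampling
#       },
#   ).sample
#
# The extra kwargs reach every processor; `AttnProcessor` simply ignores them,
# which is why its `__call__` accepts the same keyword arguments.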