# NestedAttentionEncoder / nested_attention_processor.py
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def my_scaled_dot_product_attention(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
special_token_weight=1.0,
special_token_indices=None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Computes the scaled dot-product attention with additional control over specific tokens.
This function is a re-implementation of the scaled dot-product attention mechanism,
designed to return both the attention map and the output of the attention operation.
It also provides additional control via a scalar that modifies the attention map
for specific tokens.
"""
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
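    # Re-weight the attention logits at the special-token positions. Taking the max
    # of the original and the scaled logit boosts positive logits when
    # special_token_weight > 1 while leaving negative logits unchanged.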
if special_token_indices is not None and special_token_weight != 1.0:
bs = attn_weight.shape[0]
attn_weight[torch.arange(bs), :, :, special_token_indices] = torch.max(
attn_weight[torch.arange(bs), :, :, special_token_indices],
attn_weight[torch.arange(bs), :, :, special_token_indices]
* special_token_weight,
)
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value, attn_weight
class AttnProcessor(torch.nn.Module):
r"""
Processor for implementing scaled dot-product attention.
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn,
hidden_states,
qformer_tokens_out=None,
special_token_indices=None,
inference_mode=None,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
special_token_weight=None,
):
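        # qformer_tokens_out, special_token_indices, inference_mode, and
        # special_token_weight are accepted only for interface compatibility with
        # NestedAttnProcessor; they are unused by this standard processor.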
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
1, 2
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class NestedAttnProcessor(torch.nn.Module):
r"""
    Nested Attention processor for PyTorch 2.0, adapted from the IP-Adapter attention processor.
"""
def __init__(self, hidden_size, cross_attention_dim=None, normalize_factor=1.0):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"NestedAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.normalize_factor = normalize_factor
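        # Trainable projections that map the q-former output tokens into this
        # layer's key/value space for the nested (inner) attention path.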
self.nested_to_k = nn.Linear(
cross_attention_dim or hidden_size, hidden_size, bias=False
)
self.nested_to_v = nn.Linear(
cross_attention_dim or hidden_size, hidden_size, bias=False
)
def __call__(
self,
attn,
hidden_states,
qformer_tokens_out,
special_token_indices,
inference_mode=False,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
special_token_weight=1.0,
):
assert (
special_token_indices.shape[0] > 0
), "special_token_indices should not be empty"
        # If inference_mode is True, the code assumes classifier-free guidance (CFG):
        # the first half of the batch holds the null (unconditional) prompt and the
        # second half holds the prompt.
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
bs = hidden_states.shape[0]
if input_ndim == 4:
bs, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(bs, channel, height * width).transpose(
1, 2
)
bs, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, bs
)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
bs, attn.heads, -1, attention_mask.shape[-1]
)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
1, 2
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
# nested attention
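        # (inner attention): each image query attends over the projected q-former
        # tokens, producing a per-query value that later replaces the contribution
        # of the special token in the outer text cross-attention.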
nested_key = self.nested_to_k(qformer_tokens_out)
nested_value = self.nested_to_v(qformer_tokens_out)
nested_key = nested_key.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
nested_value = nested_value.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
nested_hidden_states = F.scaled_dot_product_attention(
query,
nested_key,
nested_value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
)
# normalize V_q
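        # rescale each per-query nested value to the norm of the textual value
        # vector at the special-token position, scaled by normalize_factor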
textual_values_norms = torch.norm(
value[torch.arange(bs), :, special_token_indices], dim=-1
)
nested_hidden_states = (
torch.nn.functional.normalize(nested_hidden_states, p=2, dim=-1)
* self.normalize_factor
)
nested_hidden_states = (
textual_values_norms.view(bs, -1, 1, 1) * nested_hidden_states
)
# outer attention
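        # standard text cross-attention with the special-token values zeroed out;
        # their contribution is added back below via the per-query nested values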
value_without_special_tokens = value.clone()
if inference_mode:
value_without_special_tokens[bs // 2 : bs, :, special_token_indices, :] = (
0.0
)
else:
value_without_special_tokens[
torch.arange(bs), :, special_token_indices, :
] = 0.0
hidden_states_without_special_tokens, attn_weight = (
my_scaled_dot_product_attention(
query,
key,
value_without_special_tokens,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
special_token_weight=special_token_weight,
special_token_indices=special_token_indices,
)
)
# add the special token values
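        # weight the per-query nested values by the attention assigned to the
        # special token and add them to the outer-attention output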
if inference_mode:
special_token_attn_weight = attn_weight[
bs // 2 : bs, :, :, special_token_indices
]
else:
special_token_attn_weight = attn_weight[
torch.arange(bs), :, :, special_token_indices
]
if inference_mode:
special_token_weighted_values = (
special_token_attn_weight * nested_hidden_states[bs // 2 : bs]
)
else:
special_token_weighted_values = (
special_token_attn_weight.unsqueeze(-1) * nested_hidden_states
)
if inference_mode:
hidden_states = hidden_states_without_special_tokens
hidden_states[bs // 2 : bs] += special_token_weighted_values
else:
hidden_states = (
hidden_states_without_special_tokens + special_token_weighted_values
)
# arrange hidden states
hidden_states = hidden_states.transpose(1, 2).reshape(
bs, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
bs, channel, height, width
)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
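

# A minimal usage sketch, assuming a diffusers UNet2DConditionModel: it follows the
# common IP-Adapter-style wiring in which self-attention layers ("attn1") keep the
# plain AttnProcessor and cross-attention layers ("attn2") get NestedAttnProcessor.
# The helper name and the exact wiring are illustrative assumptions; the authors'
# pipeline may register the processors differently.
def example_set_nested_attention_processors(unet, normalize_factor=1.0):
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # "attn1" processors are self-attention layers and have no cross_attention_dim.
        cross_attention_dim = (
            None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        )
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor(hidden_size=hidden_size)
        else:
            attn_procs[name] = NestedAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                normalize_factor=normalize_factor,
            )
    unet.set_attn_processor(attn_procs)
    return attn_procs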