Spaces:
mashroo
/
Running on Zero

CRM / imagedream /ldm /modules /attention.py
YoussefAnso's picture
Refactor attention module to improve xformers integration. Renamed availability flag to HAS_XFORMERS and added safe_memory_efficient_attention function for better handling of attention operations across devices. Updated related assertions and calls to ensure compatibility with systems lacking GPU support.
1d3fed2
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from .diffusionmodules.util import checkpoint
try:
import xformers
import xformers.ops
HAS_XFORMERS = True
except ImportError:
HAS_XFORMERS = False
# CrossAttn precision handling
import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
def safe_memory_efficient_attention(q, k, v, attn_bias=None, op=None, p=0.0):
if q.device.type == "cuda" and HAS_XFORMERS:
return xformers.ops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=op, p=p)
else:
# Standard attention for CPU
scale = 1.0 / (q.shape[-1] ** 0.5)
attn = torch.matmul(q * scale, k.transpose(-2, -1))
if attn_bias is not None:
attn = attn + attn_bias
attn = torch.softmax(attn, dim=-1)
attn = torch.nn.functional.dropout(attn, p=p)
return torch.matmul(attn, v)
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
super().__init__()
print(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads."
)
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.with_ip = kwargs.get("with_ip", False)
if self.with_ip and (context_dim is not None):
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
self.ip_dim= kwargs.get("ip_dim", 16)
self.ip_weight = kwargs.get("ip_weight", 1.0)
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, mask=None):
q = self.to_q(x)
has_ip = self.with_ip and (context is not None)
if has_ip:
# context dim [(b frame_num), (77 + img_token), 1024]
token_len = context.shape[1]
context_ip = context[:, -self.ip_dim:, :]
k_ip = self.to_k_ip(context_ip)
v_ip = self.to_v_ip(context_ip)
context = context[:, :(token_len - self.ip_dim), :]
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = safe_memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
if has_ip:
k_ip, v_ip = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(k_ip, v_ip),
)
# actually compute the attention, what we cannot get enough of
out_ip = safe_memory_efficient_attention(
q, k_ip, v_ip, attn_bias=None, op=self.attention_op
)
out = out + self.ip_weight * out_ip
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
**kwargs
):
super().__init__()
assert HAS_XFORMERS, "xformers is not available"
attn_cls = MemoryEfficientCrossAttention
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
**kwargs
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None):
x = (
self.attn1(
self.norm1(x), context=context if self.disable_self_attn else None
)
+ x
)
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
use_linear=False,
use_checkpoint=True,
**kwargs
):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
disable_self_attn=disable_self_attn,
checkpoint=use_checkpoint,
**kwargs
)
for d in range(depth)
]
)
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else:
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
class BasicTransformerBlock3D(BasicTransformerBlock):
def forward(self, x, context=None, num_frames=1):
return checkpoint(
self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None, num_frames=1):
x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
x = (
self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None
)
+ x
)
x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer3D(nn.Module):
"""3D self-attention"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
use_linear=False,
use_checkpoint=True,
**kwargs
):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock3D(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
disable_self_attn=disable_self_attn,
checkpoint=use_checkpoint,
**kwargs
)
for d in range(depth)
]
)
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else:
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.use_linear = use_linear
def forward(self, x, context=None, num_frames=1):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i], num_frames=num_frames)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in