# Source page header (egorchistov — "Initial release", commit ac59957)
# https://github.com/XiaoyuShi97/VideoFlow/blob/main/core/Networks/BOFNet/gma.py
import torch
import math
from torch import nn, einsum
from einops import rearrange
class Attention(nn.Module):
    """All-pairs spatial self-similarity attention (GMA-style).

    Projects a feature map to queries and keys with a 1x1 conv and returns
    a softmax-normalized attention matrix over every pair of spatial
    positions.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
    ):
        super().__init__()
        self.heads = heads
        # Standard 1/sqrt(d) dot-product attention temperature.
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        # One conv emits queries and keys together; split in forward().
        self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

    def forward(self, fmap: torch.Tensor) -> torch.Tensor:
        """Compute spatial attention for ``fmap``.

        Args:
            fmap: feature map of shape ``(b, dim, h, w)``.

        Returns:
            Attention of shape ``(b, heads, h*w, h*w)``; each row sums to 1.
        """
        heads = self.heads
        b, _, h, w = fmap.shape
        q, k = self.to_qk(fmap).chunk(2, dim=1)

        # (b, heads*d, h, w) -> (b, heads, h, w, d); native-torch equivalent
        # of einops rearrange "b (h d) x y -> b h x y d".
        def _split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(b, heads, -1, h, w).permute(0, 1, 3, 4, 2)

        q = _split_heads(q)
        k = _split_heads(k)

        # Extra log-base-3 temperature term follows the MemFlow paper.
        q = self.scale * q * math.log(h * w, 3)

        sim = einsum("b h x y d, b h u v d -> b h x y u v", q, k)
        # Flatten both spatial grids: "b h x y u v -> b h (x y) (u v)".
        sim = sim.reshape(b, heads, h * w, h * w)
        return sim.softmax(dim=-1)
class Aggregate(nn.Module):
    """Aggregate value features with a precomputed attention map (GMA-style).

    Applies the attention matrix produced by :class:`Attention` to a value
    projection of the feature map and adds the result back through a learned
    gate. ``gamma`` is initialized to zero, so the module starts as an exact
    identity mapping of ``fmap``.
    """

    def __init__(
        self,
        dim,
        heads=4,
        dim_head=128,
    ):
        super().__init__()
        self.heads = heads
        # NOTE(review): `scale` is never read in forward(); kept only so the
        # attribute surface (and any saved state) stays backward-compatible.
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False)
        # Learned residual gate; zero init => identity at the start of training.
        self.gamma = nn.Parameter(torch.zeros(1))
        # Project back to `dim` channels only when the head dims don't match.
        if dim != inner_dim:
            self.project = nn.Conv2d(inner_dim, dim, 1, bias=False)
        else:
            self.project = None

    def forward(self, attn: torch.Tensor, fmap: torch.Tensor) -> torch.Tensor:
        """Aggregate ``fmap`` values under ``attn`` and gate-add to ``fmap``.

        Args:
            attn: attention of shape ``(b, heads, h*w, h*w)``.
            fmap: feature map of shape ``(b, dim, h, w)``.

        Returns:
            Tensor of shape ``(b, dim, h, w)``.
        """
        heads = self.heads
        b, _, h, w = fmap.shape
        v = self.to_v(fmap)
        # (b, heads*d, h, w) -> (b, heads, h*w, d); native-torch equivalent
        # of einops rearrange "b (h d) x y -> b h (x y) d".
        v = v.view(b, heads, -1, h * w).transpose(2, 3)
        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        # (b, heads, h*w, d) -> (b, heads*d, h, w): "b h (x y) d -> b (h d) x y".
        out = out.transpose(2, 3).reshape(b, -1, h, w)
        if self.project is not None:
            out = self.project(out)
        return fmap + self.gamma * out