# https://github.com/XiaoyuShi97/VideoFlow/blob/main/core/Networks/BOFNet/gma.py
import torch
import math
from torch import nn, einsum
from einops import rearrange
class Attention(nn.Module):
    """Self-similarity attention over a 2-D feature map (GMA-style).

    Projects the input feature map to queries and keys with a 1x1 conv,
    scores every spatial position against every other one, and returns the
    softmax-normalized attention matrix.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
    ):
        """
        Args:
            dim: channel count of the incoming feature map.
            heads: number of attention heads.
            dim_head: channel count per head.
        """
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head
        # Single conv produces queries and keys stacked along channels.
        self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

    def forward(self, fmap):
        """Return attention weights of shape (b, heads, h*w, h*w).

        Args:
            fmap: feature map of shape (b, dim, h, w).
        """
        batch, _, height, width = fmap.shape
        tokens = height * width

        queries, keys = self.to_qk(fmap).chunk(2, dim=1)
        # (b, heads*d, h, w) -> (b, heads, h*w, d)
        queries = queries.view(batch, self.heads, -1, tokens).transpose(2, 3)
        keys = keys.view(batch, self.heads, -1, tokens).transpose(2, 3)

        # Small change based on MemFlow Paper: extra log-length scaling on
        # top of the usual 1/sqrt(dim_head) factor.
        queries = self.scale * queries * math.log(tokens, 3)

        # Dot-product similarity between all position pairs:
        # (b, heads, h*w, d) @ (b, heads, d, h*w) -> (b, heads, h*w, h*w).
        sim = queries @ keys.transpose(-2, -1)
        return sim.softmax(dim=-1)
class Aggregate(nn.Module):
    """Aggregate per-position values with externally supplied attention.

    Projects the feature map to values, mixes them with the given attention
    matrix, and adds the result back to the input through a learned, zero-
    initialized gate (`gamma`), so the module starts as an identity mapping.
    """

    def __init__(
        self,
        dim,
        heads=4,
        dim_head=128,
    ):
        """
        Args:
            dim: channel count of the incoming feature map.
            heads: number of attention heads.
            dim_head: channel count per head.
        """
        super().__init__()
        self.heads = heads
        # NOTE(review): kept for parity with the original module; this
        # factor is never read in forward().
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head
        self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False)
        # Residual gate, zero at init -> module initially passes fmap through.
        self.gamma = nn.Parameter(torch.zeros(1))
        # Project back to `dim` only when the head channels differ from it.
        self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) if dim != inner_dim else None

    def forward(self, attn, fmap):
        """Return fmap plus gated, attention-aggregated values (same shape).

        Args:
            attn: attention weights of shape (b, heads, h*w, h*w).
            fmap: feature map of shape (b, dim, h, w).
        """
        batch, _, height, width = fmap.shape
        tokens = height * width

        # (b, heads*d, h, w) -> (b, heads, h*w, d)
        values = self.to_v(fmap)
        values = values.view(batch, self.heads, -1, tokens).transpose(2, 3)

        # Weighted sum of values per query position:
        # (b, heads, h*w, h*w) @ (b, heads, h*w, d) -> (b, heads, h*w, d).
        mixed = attn @ values

        # (b, heads, h*w, d) -> (b, heads*d, h, w)
        mixed = mixed.transpose(2, 3).reshape(batch, -1, height, width)

        if self.project is not None:
            mixed = self.project(mixed)

        return fmap + self.gamma * mixed