# AAGCN / aagcn.py
# Uploaded by fossbk ("Upload 8 files", commit a6976f4).
import math
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from graph import Graph
import pytorch_lightning as pl
from torchmetrics.classification import MulticlassAccuracy, BinaryAccuracy
import torch.optim as optim
def import_class(name):
    """Resolve a dotted path such as ``'pkg.module.Attr'`` to the object it names.

    The first dotted component is imported with ``__import__``; each later
    component is looked up as an attribute of the object found so far.

    Args:
        name: dotted path string.

    Returns:
        The object reached after walking every component.
    """
    head, *tail = name.split('.')
    obj = __import__(head)
    for attr in tail:
        obj = getattr(obj, attr)
    return obj
def conv_branch_init(conv, branches):
    """Initialise a conv whose output is summed over ``branches`` parallel branches.

    Weights are drawn from N(0, sqrt(2 / (n * k1 * k2 * branches))), where n, k1
    and k2 are the first three dimensions of ``conv.weight``. Dividing by
    ``branches`` compensates for the outputs of the branches being summed.
    The bias, if the layer has one, is set to zero.

    Args:
        conv: convolution module with a ``weight`` of at least 3 dimensions.
        branches: number of branches whose outputs are summed.
    """
    weight = conv.weight
    n = weight.size(0)
    k1 = weight.size(1)
    k2 = weight.size(2)
    nn.init.normal_(weight, 0, math.sqrt(2. / (n * k1 * k2 * branches)))
    # Layers built with bias=False have conv.bias == None; skip instead of crashing.
    if conv.bias is not None:
        nn.init.constant_(conv.bias, 0)
def conv_init(conv):
    """Kaiming-normal (fan_out) initialisation of ``conv.weight``; zero the bias if present.

    Args:
        conv: convolution module with ``weight`` and optional ``bias``.
    """
    nn.init.kaiming_normal_(conv.weight, mode='fan_out')
    # Layers built with bias=False have conv.bias == None; skip instead of crashing.
    if conv.bias is not None:
        nn.init.constant_(conv.bias, 0)
def bn_init(bn, scale):
    """Set a BatchNorm layer's affine parameters: weight to ``scale``, bias to 0.

    Args:
        bn: BatchNorm module (weight and bias must exist, i.e. affine=True).
        scale: constant written into every element of ``bn.weight``.
    """
    for param, value in ((bn.weight, scale), (bn.bias, 0)):
        nn.init.constant_(param, value)
class unit_tcn(nn.Module):
    """Temporal convolution unit: a (kernel_size x 1) Conv2d followed by BatchNorm2d.

    The convolution slides along the third input axis (the time axis in this
    file's (N, C, T, V) layout) with stride ``stride`` and padding
    ``int((kernel_size - 1) / 2)`` along that axis only. ``forward`` applies no
    activation; the ``relu`` attribute is created but not used here.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        kernel_size: temporal kernel length (default 9).
        stride: temporal stride (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super(unit_tcn, self).__init__()
        time_pad = int((kernel_size - 1) / 2)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(time_pad, 0),
            stride=(stride, 1),
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        conv_init(self.conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        """Return ``bn(conv(x))``."""
        return self.bn(self.conv(x))
class unit_gcn(nn.Module):
    """Graph convolution unit with optional adaptive adjacency and attention (AAGCN).

    ``forward`` unpacks its input as (N, C, T, V) and returns a tensor with
    ``out_channels`` channels and the same T and V.

    Args:
        in_channels: input channel count C.
        out_channels: output channel count.
        A: numpy array of adjacency matrices. ``A[i]`` is used for subset ``i``
            and ``A.shape[-1]`` gives the joint count V. It is presumably shaped
            (num_subset, V, V); TODO confirm against ``graph.Graph``.
        coff_embedding: channel reduction for the data-dependent adjacency
            embeddings (``out_channels // coff_embedding`` channels).
        num_subset: number of adjacency subsets, one branch each.
        adaptive: if True, the adjacency is a learnable parameter ``PA`` plus a
            data-dependent term gated by ``alpha``. If False, ``A`` is fixed.
        attention: if True, apply spatial, temporal and channel attention after
            the graph convolution.
    """

    def __init__(self, in_channels, out_channels, A, coff_embedding=4, num_subset=3, adaptive=True, attention=True):
        super(unit_gcn, self).__init__()
        inter_channels = out_channels // coff_embedding
        self.inter_c = inter_channels
        self.out_c = out_channels
        self.in_c = in_channels
        self.num_subset = num_subset
        num_jpts = A.shape[-1]

        # One 1x1 conv per subset, applied after aggregating over joints.
        self.conv_d = nn.ModuleList()
        for i in range(self.num_subset):
            self.conv_d.append(nn.Conv2d(in_channels, out_channels, 1))

        if adaptive:
            # Learnable adjacency initialised from the graph. alpha gates the
            # data-dependent term and starts at 0, so training begins from A.
            self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
            self.alpha = nn.Parameter(torch.zeros(1))
            # Embedding convs whose products form the data-dependent adjacency.
            self.conv_a = nn.ModuleList()
            self.conv_b = nn.ModuleList()
            for i in range(self.num_subset):
                self.conv_a.append(nn.Conv2d(in_channels, inter_channels, 1))
                self.conv_b.append(nn.Conv2d(in_channels, inter_channels, 1))
        else:
            # Fixed adjacency. A buffer follows .to()/.cuda() with the module;
            # persistent=False keeps the state_dict keys unchanged.
            self.register_buffer('A', torch.from_numpy(A.astype(np.float32)), persistent=False)
        self.adaptive = adaptive

        if attention:
            # Temporal attention. Zero-initialised, so the initial gate is
            # sigmoid(0) = 0.5.
            self.conv_ta = nn.Conv1d(out_channels, 1, 9, padding=4)
            nn.init.constant_(self.conv_ta.weight, 0)
            nn.init.constant_(self.conv_ta.bias, 0)
            # Spatial attention. An odd kernel is used so that this padding
            # keeps the joint dimension at length V.
            ker_jpt = num_jpts - 1 if not num_jpts % 2 else num_jpts
            pad = (ker_jpt - 1) // 2
            self.conv_sa = nn.Conv1d(out_channels, 1, ker_jpt, padding=pad)
            nn.init.xavier_normal_(self.conv_sa.weight)
            nn.init.constant_(self.conv_sa.bias, 0)
            # Channel attention, a squeeze-excitation MLP with reduction rr.
            # fc2c is zero-initialised, so the initial gate is 0.5.
            rr = 2
            self.fc1c = nn.Linear(out_channels, out_channels // rr)
            self.fc2c = nn.Linear(out_channels // rr, out_channels)
            nn.init.kaiming_normal_(self.fc1c.weight)
            nn.init.constant_(self.fc1c.bias, 0)
            nn.init.constant_(self.fc2c.weight, 0)
            nn.init.constant_(self.fc2c.bias, 0)
        self.attention = attention

        # Residual projection. nn.Identity (not a lambda) keeps the module picklable.
        if in_channels != out_channels:
            self.down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.down = nn.Identity()

        self.bn = nn.BatchNorm2d(out_channels)
        self.soft = nn.Softmax(-2)
        self.tan = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)
        # Near-zero BN scale on the GCN output, so the block starts close to
        # its residual path.
        bn_init(self.bn, 1e-6)
        for i in range(self.num_subset):
            conv_branch_init(self.conv_d[i], self.num_subset)

    def forward(self, x):
        """Graph convolution, then BN, residual, ReLU and optional attention.

        Args:
            x: tensor unpacked as (N, C, T, V).

        Returns:
            Tensor of shape (N, out_channels, T, V).
        """
        y = self._aggregate(x)
        y = self.bn(y)
        y += self.down(x)
        y = self.relu(y)
        if self.attention:
            y = self._apply_attention(y)
        return y

    def _aggregate(self, x):
        """Sum over subsets of conv_d[i](x @ adjacency_i), with x of shape (N, C, T, V)."""
        N, C, T, V = x.size()
        x_flat = x.reshape(N, C * T, V)
        if self.adaptive:
            A = self.PA
        else:
            # The buffer normally already matches x; this only guards against
            # device/dtype mismatch.
            A = self.A.to(device=x.device, dtype=x.dtype)
        y = None
        for i in range(self.num_subset):
            if self.adaptive:
                emb_a = self.conv_a[i](x).permute(0, 3, 1, 2).contiguous().view(N, V, self.inter_c * T)
                emb_b = self.conv_b[i](x).view(N, self.inter_c * T, V)
                # Data-dependent (N, V, V) adjacency, squashed to (-1, 1).
                learned = self.tan(torch.matmul(emb_a, emb_b) / emb_a.size(-1))
                adj = A[i] + learned * self.alpha
            else:
                adj = A[i]
            z = self.conv_d[i](torch.matmul(x_flat, adj).view(N, C, T, V))
            y = z if y is None else z + y
        return y

    def _apply_attention(self, y):
        """Apply spatial, temporal, then channel attention, each as ``y * gate + y``."""
        # Spatial: pool over time to (N, C, V), giving one gate per joint.
        se = y.mean(-2)
        se1 = self.sigmoid(self.conv_sa(se))
        y = y * se1.unsqueeze(-2) + y
        # Temporal: pool over joints to (N, C, T), giving one gate per frame.
        se = y.mean(-1)
        se1 = self.sigmoid(self.conv_ta(se))
        y = y * se1.unsqueeze(-1) + y
        # Channel: global pool to (N, C), giving one gate per channel.
        se = y.mean(-1).mean(-1)
        se1 = self.relu(self.fc1c(se))
        se2 = self.sigmoid(self.fc2c(se1))
        y = y * se2.unsqueeze(-1).unsqueeze(-1) + y
        return y
class _NoResidual(nn.Module):
    """Residual branch used when ``residual=False``; it contributes nothing (returns 0)."""

    def forward(self, x):
        return 0


class TCN_GCN_unit(nn.Module):
    """One AAGCN block: ``relu(tcn(gcn(x)) + residual(x))``.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        A: adjacency array forwarded to ``unit_gcn``.
        stride: temporal stride of the TCN, and of the residual when projected.
        residual: if False, no residual connection is added.
        adaptive: forwarded to ``unit_gcn``.
        attention: forwarded to ``unit_gcn``, which applies the attention.
    """

    def __init__(self, in_channels, out_channels, A, stride=1, residual=True, adaptive=True, attention=True):
        super(TCN_GCN_unit, self).__init__()
        self.gcn1 = unit_gcn(in_channels, out_channels, A, adaptive=adaptive, attention=attention)
        self.tcn1 = unit_tcn(out_channels, out_channels, stride=stride)
        self.relu = nn.ReLU(inplace=True)
        self.attention = attention
        # Modules rather than lambdas keep the block picklable. None of the
        # three options has parameters besides the unit_tcn projection.
        if not residual:
            self.residual = _NoResidual()
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = nn.Identity()
        else:
            # Shape-changing residual: strided 1x1 temporal conv + BN.
            self.residual = unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        """Return ``relu(tcn1(gcn1(x)) + residual(x))``.

        Attention, when enabled, is already applied inside ``gcn1``.
        """
        return self.relu(self.tcn1(self.gcn1(x)) + self.residual(x))
class Model(pl.LightningModule):
    """AAGCN skeleton classifier wrapped as a PyTorch Lightning module.

    ``forward`` expects a 5-D tensor unpacked as (N, C, T, V, M), with
    C == in_channels, V == num_point and M == num_person. It returns
    (N, num_class) logits.

    Args:
        num_class: number of output classes.
        num_point: joints per skeleton (V).
        num_person: skeletons per sample (M).
        graph: accepted but ignored; the graph comes from the imported ``Graph``.
        graph_args: keyword arguments for ``Graph``. None means ``{}``.
        in_channels: coordinate channels per joint (C).
        drop_out: dropout probability before the classifier. 0 disables it.
        adaptive: forwarded to every ``TCN_GCN_unit``.
        attention: forwarded to every ``TCN_GCN_unit``.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
    """

    def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, graph_args=None, in_channels=3,
                 drop_out=0, adaptive=True, attention=True, learning_rate=1e-4, weight_decay=1e-4):
        super(Model, self).__init__()
        # Avoid a shared mutable default. The rebinding happens before
        # save_hyperparameters(), so the recorded value should still be {}.
        # Verify against Lightning if this matters.
        if graph_args is None:
            graph_args = {}
        self.graph = Graph(**graph_args)
        A = self.graph.A
        self.num_class = num_class
        # Normalises every (person, joint, channel) feature across batch and time.
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)
        self.l1 = TCN_GCN_unit(in_channels, 64, A, residual=False, adaptive=adaptive, attention=attention)
        self.l2 = TCN_GCN_unit(64, 64, A, adaptive=adaptive, attention=attention)
        self.l3 = TCN_GCN_unit(64, 64, A, adaptive=adaptive, attention=attention)
        self.l4 = TCN_GCN_unit(64, 64, A, adaptive=adaptive, attention=attention)
        self.l5 = TCN_GCN_unit(64, 128, A, stride=2, adaptive=adaptive, attention=attention)
        self.l6 = TCN_GCN_unit(128, 128, A, adaptive=adaptive, attention=attention)
        self.l7 = TCN_GCN_unit(128, 128, A, adaptive=adaptive, attention=attention)
        self.l8 = TCN_GCN_unit(128, 256, A, stride=2, adaptive=adaptive, attention=attention)
        self.l9 = TCN_GCN_unit(256, 256, A, adaptive=adaptive, attention=attention)
        self.l10 = TCN_GCN_unit(256, 256, A, adaptive=adaptive, attention=attention)
        self.fc = nn.Linear(256, num_class)
        nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))
        bn_init(self.data_bn, 1)
        # nn.Identity instead of a lambda keeps the model picklable.
        self.drop_out = nn.Dropout(drop_out) if drop_out else nn.Identity()
        self.loss = nn.CrossEntropyLoss()
        self.metric = MulticlassAccuracy(num_class)
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        # Per-batch validation results, averaged and cleared at epoch end.
        self.validation_step_loss_outputs = []
        self.validation_step_acc_outputs = []
        self.save_hyperparameters()

    def forward(self, x):
        """Map an (N, C, T, V, M) batch to (N, num_class) logits."""
        N, C, T, V, M = x.size()
        # Flatten (M, V, C) into the feature axis so data_bn normalises each one.
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
        x = self.data_bn(x.float())
        # Fold persons into the batch: (N * M, C, T, V).
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.l4(x)
        x = self.l5(x)
        x = self.l6(x)
        x = self.l7(x)
        x = self.l8(x)
        x = self.l9(x)
        x = self.l10(x)
        # Global average pool over time and joints, then average over persons.
        c_new = x.size(1)
        x = x.view(N, M, c_new, -1)
        x = x.mean(3).mean(1)
        x = self.drop_out(x)
        return self.fc(x)

    def _shared_step(self, batch):
        """Run an (inputs, targets) batch; return (targets, predicted classes, accuracy, loss)."""
        inputs, targets = batch
        outputs = self(inputs)
        # argmax of the logits gives the same class as argmax of their softmax.
        y_pred_class = outputs.argmax(dim=1)
        accuracy = self.metric(y_pred_class, targets)
        loss = self.loss(outputs, targets)
        return targets, y_pred_class, accuracy, loss

    def training_step(self, batch, batch_idx):
        _, _, train_accuracy, loss = self._shared_step(batch)
        self.log('train_accuracy', train_accuracy, prog_bar=True, on_epoch=True)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        _, _, valid_accuracy, loss = self._shared_step(batch)
        self.log('valid_accuracy', valid_accuracy, prog_bar=True, on_epoch=True)
        self.log('valid_loss', loss, prog_bar=True, on_epoch=True)
        self.validation_step_loss_outputs.append(loss.detach())
        self.validation_step_acc_outputs.append(valid_accuracy.detach())
        return {"valid_loss" : loss, "valid_accuracy" : valid_accuracy}

    def on_validation_epoch_end(self):
        # torch.stack([]) raises, so skip logging when no validation batch ran.
        if self.validation_step_loss_outputs:
            avg_loss = torch.stack(self.validation_step_loss_outputs).mean()
            avg_acc = torch.stack(self.validation_step_acc_outputs).mean()
            self.log("ptl/val_loss", avg_loss)
            self.log("ptl/val_accuracy", avg_acc)
        self.validation_step_loss_outputs.clear()
        self.validation_step_acc_outputs.clear()

    def test_step(self, batch, batch_idx):
        targets, y_pred_class, test_accuracy, loss = self._shared_step(batch)
        print("Targets : ", targets)
        print("Preds : ", y_pred_class)
        self.log('test_accuracy', test_accuracy, prog_bar=True, on_epoch=True)
        self.log('test_loss', loss, prog_bar=True, on_epoch=True)
        return {"test_loss" : loss, "test_accuracy" : test_accuracy}

    def configure_optimizers(self):
        """Adam + ReduceLROnPlateau maximising the logged ``valid_accuracy``."""
        optimizer = optim.Adam(params=self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')
        return {"optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler, "monitor": "valid_accuracy"}
                }

    def predict_step(self, batch, batch_idx):
        # Accept a bare input tensor or an (inputs, targets) pair, as yielded
        # by the same dataloaders used for training and validation.
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        return self(batch)
if __name__ == "__main__":
    # Smoke test: build a small model, print a summary, run one dummy forward.
    import os
    from torchinfo import summary
    print(os.getcwd())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Model(num_class=20, num_point=18, num_person=1,
                  graph_args={}, in_channels=2).to(device)
    summary(model)
    # Dummy input laid out as (N, C, T, V, M).
    x = torch.randn((1, 2, 80, 18, 1)).to(device)
    # Inference-only check: eval() leaves BatchNorm running stats untouched,
    # and no_grad() skips building an autograd graph.
    model.eval()
    with torch.no_grad():
        y = model(x)
    print(y.shape)