Upload 8 files
Browse files# AAGCN-VSL-Detection
# Introction
In this project, we modify a model to better classify Vietnamese sign language (VSL). We use Mediapipe to extract keypoints. Then apply a bilinear preprocessing technique in order to reconstruct the missing keypoints. The model is built based on the Adaptive Graph Convolutional Networks by combining with Attention Mechanisms, including spatial, temporal, and channel attention.
The model is test on our self-collected VSL dataset, which was collected from the school for hearing-impaired children in Hanoi, Vietnam. The data consists of 5,572 videos of 28 actors, with 199 classes each, representing the most frequently used spoken Vietnamese. Before training on this dataset, we pre-train the model with the Ankara University Turkish Sign Language Dataset (AUTSL) first to obtain the weights, learning rate for further training.
# Setup
Download all the files and run the jupyter file Total (VSL). This file will do almost the entire training process, ranging from extracting the keypoints to performing interpolation, training the model. All the other files are the supplement for the model part.
Change the path of the input dataset. The code will read all videos in the folder.
In this coding file (Total (VSL)), there are 2 options:
1. Train new.
2. Train continuously from the model that was trained on the AUTSL dataset.
Uncomment the part of the code that you want to use.
# Evaluation
We use k-fold in this project with k is defined as 10. The 10 trained models will be stored in the "checkpoints" folder. The final output will print out the best accuracy.
# Acknoledgement
This project is built by iBME lab at Hanoi University of Science and Technology, Vietnam. It is funded by Hanoi University of Science and Technology (HUST) under project number T2023-PC-028.
- 1_1000_label.csv +0 -0
- Total (VSL).ipynb +883 -0
- aagcn.py +440 -0
- agcn.py +275 -0
- augumentation.py +73 -0
- epoch=55-valid_loss=0.41-valid_accuracy=0.85-autsl-aagcn.ckpt +3 -0
- feeder.py +72 -0
- graph.py +287 -0
The diff for this file is too large to render.
See raw diff
|
|
@@ -0,0 +1,883 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"pre-train with AUTSL dataset"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "markdown",
|
12 |
+
"metadata": {},
|
13 |
+
"source": [
|
14 |
+
"This is the code to train the model with the AUTSL dataset"
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": 1,
|
20 |
+
"metadata": {},
|
21 |
+
"outputs": [
|
22 |
+
{
|
23 |
+
"data": {
|
24 |
+
"text/plain": [
|
25 |
+
"'\\nfrom torch.utils.data import DataLoader\\nimport torch\\nfrom torchinfo import summary\\nfrom feeder import FeederINCLUDE\\nfrom aagcn import Model\\nimport pytorch_lightning as pl\\nfrom pytorch_lightning.loggers import WandbLogger # Importing here\\nfrom pytorch_lightning.callbacks import ModelCheckpoint\\nimport wandb\\nfrom augumentation import Rotate, Compose\\nfrom torch.utils.data import random_split\\n\\n\\nif __name__ == \\'__main__\\':\\n\\n # Hyper parameter tuning : batch_size, learning_rate, weight_decay\\n config = {\\'batch_size\\': 150, \\'learning_rate\\': 0.0137296, \\'weight_decay\\': 0.000150403}\\n \\n # Load device\\n device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\\n\\n # Initialize wandb\\n wandb.finish()\\n wandb.init(project=\"GCN_VSL\", config=config) \\n wandb_config = wandb.config # Access the config parameters\\n\\n try:\\n # Your training or evaluation code here\\n print(\"WandB initialized successfully.\")\\n\\n finally:\\n # Fnish the WandB run\\n wandb.finish()\\n\\n # Load model\\n model = Model(num_class=226, num_point=46, num_person=1, in_channels=2,\\n graph_args={\"layout\": \"mediapipe_two_hand\", \"strategy\": \"spatial\"},\\n learning_rate=wandb_config.learning_rate, weight_decay=wandb_config.weight_decay)\\n\\n # Callback PL\\n callbacks = [\\n ModelCheckpoint(\\n dirpath=\"checkpoints\",\\n monitor=\"valid_loss\",\\n mode=\"min\",\\n every_n_epochs=2,\\n filename=\\'{epoch}-{valid_loss:.2f}-{valid_accuracy:.2f}-autsl-aagcn\\'\\n ),\\n ]\\n\\n # Augmentation\\n transforms = Compose([\\n Rotate(15, 80, 25, (0.5, 0.5))\\n ])\\n\\n %cd /home/ibmelab/Documents/GG/VSLRecognition/AUTSL/AAGCN\\n # Dataset class\\n train_dataset = FeederINCLUDE(data_path=f\"autsl_train_data_preprocess.npy\", label_path=f\"train_label_preprocess.npy\",\\n transform=transforms)\\n test_dataset = FeederINCLUDE(data_path=f\"autsl_test_data_preprocess.npy\", label_path=f\"test_label_preprocess.npy\")\\n valid_dataset = FeederINCLUDE(data_path=f\"autsl_valid_data_preprocess.npy\", label_path=f\"valid_label_preprocess.npy\")\\n\\n # DataLoader\\n train_dataloader = DataLoader(train_dataset, batch_size=wandb_config.batch_size, shuffle=True)\\n test_dataloader = DataLoader(test_dataset, batch_size=wandb_config.batch_size, shuffle=False)\\n val_dataloader = DataLoader(valid_dataset, batch_size=wandb_config.batch_size, shuffle=False)\\n\\n # Wandb Logger\\n wandb_logger = WandbLogger(log_model=\\'all\\')\\n\\n %cd /media/ibmelab/ibme21/Test\\n # Trainer PL\\n trainer = pl.Trainer(max_epochs=120, accelerator=\"auto\", check_val_every_n_epoch=1,\\n devices=1, callbacks=callbacks, logger=wandb_logger) # Added logger here\\n\\n trainer.fit(model, train_dataloader, val_dataloader)\\n\\n # Optional: Uncomment this when you want to test\\n # trainer.test(model, test_dataloader, ckpt_path=\"checkpoints/your_checkpoint.ckpt\", verbose=True)\\n'"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
"execution_count": 1,
|
29 |
+
"metadata": {},
|
30 |
+
"output_type": "execute_result"
|
31 |
+
}
|
32 |
+
],
|
33 |
+
"source": [
|
34 |
+
"'''\n",
|
35 |
+
"from torch.utils.data import DataLoader\n",
|
36 |
+
"import torch\n",
|
37 |
+
"from torchinfo import summary\n",
|
38 |
+
"from feeder import FeederINCLUDE\n",
|
39 |
+
"from aagcn import Model\n",
|
40 |
+
"import pytorch_lightning as pl\n",
|
41 |
+
"from pytorch_lightning.loggers import WandbLogger # Importing here\n",
|
42 |
+
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
|
43 |
+
"import wandb\n",
|
44 |
+
"from augumentation import Rotate, Compose\n",
|
45 |
+
"from torch.utils.data import random_split\n",
|
46 |
+
"\n",
|
47 |
+
"\n",
|
48 |
+
"if __name__ == '__main__':\n",
|
49 |
+
"\n",
|
50 |
+
" # Hyper parameter tuning : batch_size, learning_rate, weight_decay\n",
|
51 |
+
" config = {'batch_size': 150, 'learning_rate': 0.0137296, 'weight_decay': 0.000150403}\n",
|
52 |
+
" \n",
|
53 |
+
" # Load device\n",
|
54 |
+
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
55 |
+
"\n",
|
56 |
+
" # Initialize wandb\n",
|
57 |
+
" wandb.finish()\n",
|
58 |
+
" wandb.init(project=\"GCN_VSL\", config=config) \n",
|
59 |
+
" wandb_config = wandb.config # Access the config parameters\n",
|
60 |
+
"\n",
|
61 |
+
" try:\n",
|
62 |
+
" # Your training or evaluation code here\n",
|
63 |
+
" print(\"WandB initialized successfully.\")\n",
|
64 |
+
"\n",
|
65 |
+
" finally:\n",
|
66 |
+
" # Fnish the WandB run\n",
|
67 |
+
" wandb.finish()\n",
|
68 |
+
"\n",
|
69 |
+
" # Load model\n",
|
70 |
+
" model = Model(num_class=226, num_point=46, num_person=1, in_channels=2,\n",
|
71 |
+
" graph_args={\"layout\": \"mediapipe_two_hand\", \"strategy\": \"spatial\"},\n",
|
72 |
+
" learning_rate=wandb_config.learning_rate, weight_decay=wandb_config.weight_decay)\n",
|
73 |
+
"\n",
|
74 |
+
" # Callback PL\n",
|
75 |
+
" callbacks = [\n",
|
76 |
+
" ModelCheckpoint(\n",
|
77 |
+
" dirpath=\"checkpoints\",\n",
|
78 |
+
" monitor=\"valid_loss\",\n",
|
79 |
+
" mode=\"min\",\n",
|
80 |
+
" every_n_epochs=2,\n",
|
81 |
+
" filename='{epoch}-{valid_loss:.2f}-{valid_accuracy:.2f}-autsl-aagcn'\n",
|
82 |
+
" ),\n",
|
83 |
+
" ]\n",
|
84 |
+
"\n",
|
85 |
+
" # Augmentation\n",
|
86 |
+
" transforms = Compose([\n",
|
87 |
+
" Rotate(15, 80, 25, (0.5, 0.5))\n",
|
88 |
+
" ])\n",
|
89 |
+
"\n",
|
90 |
+
" %cd /home/ibmelab/Documents/GG/VSLRecognition/AUTSL/AAGCN\n",
|
91 |
+
" # Dataset class\n",
|
92 |
+
" train_dataset = FeederINCLUDE(data_path=f\"autsl_train_data_preprocess.npy\", label_path=f\"train_label_preprocess.npy\",\n",
|
93 |
+
" transform=transforms)\n",
|
94 |
+
" test_dataset = FeederINCLUDE(data_path=f\"autsl_test_data_preprocess.npy\", label_path=f\"test_label_preprocess.npy\")\n",
|
95 |
+
" valid_dataset = FeederINCLUDE(data_path=f\"autsl_valid_data_preprocess.npy\", label_path=f\"valid_label_preprocess.npy\")\n",
|
96 |
+
"\n",
|
97 |
+
" # DataLoader\n",
|
98 |
+
" train_dataloader = DataLoader(train_dataset, batch_size=wandb_config.batch_size, shuffle=True)\n",
|
99 |
+
" test_dataloader = DataLoader(test_dataset, batch_size=wandb_config.batch_size, shuffle=False)\n",
|
100 |
+
" val_dataloader = DataLoader(valid_dataset, batch_size=wandb_config.batch_size, shuffle=False)\n",
|
101 |
+
"\n",
|
102 |
+
" # Wandb Logger\n",
|
103 |
+
" wandb_logger = WandbLogger(log_model='all')\n",
|
104 |
+
"\n",
|
105 |
+
" %cd /media/ibmelab/ibme21/Test\n",
|
106 |
+
" # Trainer PL\n",
|
107 |
+
" trainer = pl.Trainer(max_epochs=120, accelerator=\"auto\", check_val_every_n_epoch=1,\n",
|
108 |
+
" devices=1, callbacks=callbacks, logger=wandb_logger) # Added logger here\n",
|
109 |
+
"\n",
|
110 |
+
" trainer.fit(model, train_dataloader, val_dataloader)\n",
|
111 |
+
"\n",
|
112 |
+
" # Optional: Uncomment this when you want to test\n",
|
113 |
+
" # trainer.test(model, test_dataloader, ckpt_path=\"checkpoints/your_checkpoint.ckpt\", verbose=True)\n",
|
114 |
+
"'''\n",
|
115 |
+
" \n"
|
116 |
+
]
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"cell_type": "code",
|
120 |
+
"execution_count": 2,
|
121 |
+
"metadata": {},
|
122 |
+
"outputs": [],
|
123 |
+
"source": [
|
124 |
+
"import pandas as pd\n",
|
125 |
+
"import mediapipe as mp\n",
|
126 |
+
"import cv2\n",
|
127 |
+
"from collections import defaultdict\n",
|
128 |
+
"from joblib import Parallel, delayed\n",
|
129 |
+
"from tqdm import tqdm\n",
|
130 |
+
"import ast\n",
|
131 |
+
"import os\n",
|
132 |
+
"import csv\n",
|
133 |
+
"import re\n",
|
134 |
+
"from sklearn.model_selection import KFold\n",
|
135 |
+
"import numpy as np"
|
136 |
+
]
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"cell_type": "markdown",
|
140 |
+
"metadata": {},
|
141 |
+
"source": [
|
142 |
+
"Load the videos to videos_list.csv (columns: file (path), label, gloss, video name, actor)"
|
143 |
+
]
|
144 |
+
},
|
145 |
+
{
|
146 |
+
"cell_type": "code",
|
147 |
+
"execution_count": null,
|
148 |
+
"metadata": {},
|
149 |
+
"outputs": [
|
150 |
+
{
|
151 |
+
"name": "stdout",
|
152 |
+
"output_type": "stream",
|
153 |
+
"text": [
|
154 |
+
"Video names have been written to videos_list.csv\n",
|
155 |
+
"Minimum label: 20\n",
|
156 |
+
"Labels have been updated and saved.\n"
|
157 |
+
]
|
158 |
+
}
|
159 |
+
],
|
160 |
+
"source": [
|
161 |
+
"folder_path = r'path_to_dataset_folder'\n",
|
162 |
+
"csv_file_path = 'videos_list.csv'\n",
|
163 |
+
"labels_file_path = '1_1000_label.csv'\n",
|
164 |
+
"final_file_path = 'temp_videos_list.csv'\n",
|
165 |
+
"\n",
|
166 |
+
"label_to_gloss = {}\n",
|
167 |
+
"with open(labels_file_path, mode='r', encoding='utf-8') as labels_file:\n",
|
168 |
+
" csv_reader = csv.DictReader(labels_file)\n",
|
169 |
+
" for row in csv_reader:\n",
|
170 |
+
" label = int(row['id_label_in_documents'])\n",
|
171 |
+
" gloss = row['name']\n",
|
172 |
+
" label_to_gloss[label] = gloss\n",
|
173 |
+
"\n",
|
174 |
+
"with open(csv_file_path, mode='w', newline='', encoding='utf-8') as csv_file:\n",
|
175 |
+
" csv_writer = csv.writer(csv_file)\n",
|
176 |
+
" csv_writer.writerow(['file', 'label', 'gloss', 'video_name', 'actor'])\n",
|
177 |
+
"\n",
|
178 |
+
" for filename in os.listdir(folder_path):\n",
|
179 |
+
" if filename.lower().endswith(('.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv')):\n",
|
180 |
+
" actor = filename.split('_')[0]\n",
|
181 |
+
" \n",
|
182 |
+
" match = re.search(r'_(\\d+)\\.', filename)\n",
|
183 |
+
" if match:\n",
|
184 |
+
" label = int(match.group(1))\n",
|
185 |
+
" gloss = label_to_gloss.get(label, 'Unknown')\n",
|
186 |
+
" else:\n",
|
187 |
+
" label = 'N/A'\n",
|
188 |
+
" gloss = 'Unknown'\n",
|
189 |
+
"\n",
|
190 |
+
" if label != 200:\n",
|
191 |
+
" full_filename = os.path.join(folder_path, filename)\n",
|
192 |
+
" csv_writer.writerow([full_filename, label, gloss, filename, actor]) \n",
|
193 |
+
"\n",
|
194 |
+
"print(f'Video names have been written to {csv_file_path}')\n",
|
195 |
+
"\n",
|
196 |
+
"# Find min label\n",
|
197 |
+
"with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csv_file:\n",
|
198 |
+
" csv_reader = csv.DictReader(csv_file)\n",
|
199 |
+
" labels = [int(row[\"label\"]) for row in csv_reader if row[\"label\"].isdigit()] \n",
|
200 |
+
" min_label = min(labels) if labels else None\n",
|
201 |
+
"\n",
|
202 |
+
"print(\"Minimum label:\", min_label)\n",
|
203 |
+
"\n",
|
204 |
+
"# Normalize labels\n",
|
205 |
+
"with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csv_file, \\\n",
|
206 |
+
" open(final_file_path, mode='w', newline='', encoding='utf-8') as final_file:\n",
|
207 |
+
" \n",
|
208 |
+
" csv_reader = csv.DictReader(csv_file)\n",
|
209 |
+
" fieldnames = csv_reader.fieldnames\n",
|
210 |
+
" \n",
|
211 |
+
" csv_writer = csv.DictWriter(final_file, fieldnames=fieldnames)\n",
|
212 |
+
" csv_writer.writeheader()\n",
|
213 |
+
" \n",
|
214 |
+
" for row in csv_reader:\n",
|
215 |
+
" if row['label'].isdigit(): # Check if label is a digit before converting\n",
|
216 |
+
" row['label'] = str(int(row['label']) - min_label) \n",
|
217 |
+
" csv_writer.writerow(row)\n",
|
218 |
+
"\n",
|
219 |
+
"# Replace the original file with the updated file\n",
|
220 |
+
"os.replace(final_file_path, csv_file_path)\n",
|
221 |
+
"\n",
|
222 |
+
"print(\"Labels have been updated and saved.\")\n"
|
223 |
+
]
|
224 |
+
},
|
225 |
+
{
|
226 |
+
"cell_type": "markdown",
|
227 |
+
"metadata": {},
|
228 |
+
"source": [
|
229 |
+
"Number of labels in the dataset"
|
230 |
+
]
|
231 |
+
},
|
232 |
+
{
|
233 |
+
"cell_type": "code",
|
234 |
+
"execution_count": 4,
|
235 |
+
"metadata": {},
|
236 |
+
"outputs": [
|
237 |
+
{
|
238 |
+
"name": "stdout",
|
239 |
+
"output_type": "stream",
|
240 |
+
"text": [
|
241 |
+
"10\n",
|
242 |
+
"{20, 21, 22, 23, 24, 25, 26, 27, 28, 29}\n"
|
243 |
+
]
|
244 |
+
}
|
245 |
+
],
|
246 |
+
"source": [
|
247 |
+
"num_labels = len(set(labels))\n",
|
248 |
+
"print(num_labels)\n",
|
249 |
+
"print(set(labels))"
|
250 |
+
]
|
251 |
+
},
|
252 |
+
{
|
253 |
+
"cell_type": "markdown",
|
254 |
+
"metadata": {},
|
255 |
+
"source": [
|
256 |
+
"Extract keypoints"
|
257 |
+
]
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"cell_type": "code",
|
261 |
+
"execution_count": null,
|
262 |
+
"metadata": {},
|
263 |
+
"outputs": [
|
264 |
+
{
|
265 |
+
"name": "stderr",
|
266 |
+
"output_type": "stream",
|
267 |
+
"text": [
|
268 |
+
" \r"
|
269 |
+
]
|
270 |
+
}
|
271 |
+
],
|
272 |
+
"source": [
|
273 |
+
"import pandas as pd\n",
|
274 |
+
"import mediapipe as mp\n",
|
275 |
+
"import cv2\n",
|
276 |
+
"import os\n",
|
277 |
+
"from collections import defaultdict\n",
|
278 |
+
"from joblib import Parallel, delayed\n",
|
279 |
+
"from tqdm import tqdm\n",
|
280 |
+
"\n",
|
281 |
+
"mp_holistic = mp.solutions.holistic\n",
|
282 |
+
"mp_drawing = mp.solutions.drawing_utils\n",
|
283 |
+
"\n",
|
284 |
+
"hand_landmarks = ['INDEX_FINGER_DIP', 'INDEX_FINGER_MCP', 'INDEX_FINGER_PIP', 'INDEX_FINGER_TIP', \n",
|
285 |
+
" 'MIDDLE_FINGER_DIP', 'MIDDLE_FINGER_MCP', 'MIDDLE_FINGER_PIP', 'MIDDLE_FINGER_TIP', \n",
|
286 |
+
" 'PINKY_DIP', 'PINKY_MCP', 'PINKY_PIP', 'PINKY_TIP', 'RING_FINGER_DIP', 'RING_FINGER_MCP', \n",
|
287 |
+
" 'RING_FINGER_PIP', 'RING_FINGER_TIP', 'THUMB_CMC', 'THUMB_IP', 'THUMB_MCP', 'THUMB_TIP', 'WRIST']\n",
|
288 |
+
"pose_landmarks = ['LEFT_ANKLE', 'LEFT_EAR', 'LEFT_ELBOW', 'LEFT_EYE', 'LEFT_EYE_INNER', 'LEFT_EYE_OUTER', \n",
|
289 |
+
" 'LEFT_FOOT_INDEX', 'LEFT_HEEL', 'LEFT_HIP', 'LEFT_INDEX', 'LEFT_KNEE', 'LEFT_PINKY', \n",
|
290 |
+
" 'LEFT_SHOULDER', 'LEFT_THUMB', 'LEFT_WRIST', 'MOUTH_LEFT', 'MOUTH_RIGHT', 'NOSE', \n",
|
291 |
+
" 'RIGHT_ANKLE', 'RIGHT_EAR', 'RIGHT_ELBOW', 'RIGHT_EYE', 'RIGHT_EYE_INNER', 'RIGHT_EYE_OUTER', \n",
|
292 |
+
" 'RIGHT_FOOT_INDEX', 'RIGHT_HEEL', 'RIGHT_HIP', 'RIGHT_INDEX', 'RIGHT_KNEE', 'RIGHT_PINKY', \n",
|
293 |
+
" 'RIGHT_SHOULDER', 'RIGHT_THUMB', 'RIGHT_WRIST']\n",
|
294 |
+
"\n",
|
295 |
+
"def extract_keypoint(video_path, label, actor):\n",
|
296 |
+
" cap = cv2.VideoCapture(video_path)\n",
|
297 |
+
" \n",
|
298 |
+
" keypoint_dict = defaultdict(list)\n",
|
299 |
+
" count = 0\n",
|
300 |
+
"\n",
|
301 |
+
" with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:\n",
|
302 |
+
" while True:\n",
|
303 |
+
" ret, frame = cap.read()\n",
|
304 |
+
" if not ret:\n",
|
305 |
+
" break\n",
|
306 |
+
" \n",
|
307 |
+
" count += 1\n",
|
308 |
+
" image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
|
309 |
+
" results = holistic.process(image)\n",
|
310 |
+
"\n",
|
311 |
+
" if results.right_hand_landmarks:\n",
|
312 |
+
" for idx, landmark in enumerate(results.right_hand_landmarks.landmark): \n",
|
313 |
+
" keypoint_dict[f\"{hand_landmarks[idx]}_right_x\"].append(landmark.x)\n",
|
314 |
+
" keypoint_dict[f\"{hand_landmarks[idx]}_right_y\"].append(landmark.y)\n",
|
315 |
+
" keypoint_dict[f\"{hand_landmarks[idx]}_right_z\"].append(landmark.z)\n",
|
316 |
+
" else:\n",
|
317 |
+
" for idx in range(len(hand_landmarks)):\n",
|
318 |
+
" keypoint_dict[f\"{hand_landmarks[idx]}_right_x\"].append(0)\n",
|
319 |
+
" keypoint_dict[f\"{hand_landmarks[idx]}_right_y\"].append(0)\n",
|
320 |
+
" keypoint_dict[f\"{hand_landmarks[idx]}_right_z\"].append(0)\n",
|
321 |
+
"\n",
|
322 |
+
" if results.left_hand_landmarks:\n",
|
323 |
+
" for idx, landmark in enumerate(results.left_hand_landmarks.landmark): \n",
|
324 |
+
" keypoint_dict[f\"{hand_landmarks[idx]}_left_x\"].append(landmark.x)\n",
|
325 |
+
" keypoint_dict[f\"{hand_landmarks[idx]}_left_y\"].append(landmark.y)\n",
|
326 |
+
" keypoint_dict[f\"{hand_landmarks[idx]}_left_z\"].append(landmark.z)\n",
|
327 |
+
" else:\n",
|
328 |
+
" for idx in range(len(hand_landmarks)):\n",
|
329 |
+
" keypoint_dict[f\"{hand_landmarks[idx]}_left_x\"].append(0)\n",
|
330 |
+
" keypoint_dict[f\"{hand_landmarks[idx]}_left_y\"].append(0)\n",
|
331 |
+
" keypoint_dict[f\"{hand_landmarks[idx]}_left_z\"].append(0)\n",
|
332 |
+
"\n",
|
333 |
+
" if results.pose_landmarks:\n",
|
334 |
+
" for idx, landmark in enumerate(results.pose_landmarks.landmark): \n",
|
335 |
+
" keypoint_dict[f\"{pose_landmarks[idx]}_x\"].append(landmark.x)\n",
|
336 |
+
" keypoint_dict[f\"{pose_landmarks[idx]}_y\"].append(landmark.y)\n",
|
337 |
+
" keypoint_dict[f\"{pose_landmarks[idx]}_z\"].append(landmark.z)\n",
|
338 |
+
" else:\n",
|
339 |
+
" for idx in range(len(pose_landmarks)):\n",
|
340 |
+
" keypoint_dict[f\"{pose_landmarks[idx]}_x\"].append(0)\n",
|
341 |
+
" keypoint_dict[f\"{pose_landmarks[idx]}_y\"].append(0)\n",
|
342 |
+
" keypoint_dict[f\"{pose_landmarks[idx]}_z\"].append(0)\n",
|
343 |
+
"\n",
|
344 |
+
" keypoint_dict[\"frame\"] = count\n",
|
345 |
+
" keypoint_dict[\"video_path\"] = video_path\n",
|
346 |
+
" keypoint_dict[\"label\"] = label\n",
|
347 |
+
" keypoint_dict[\"actor\"] = actor\n",
|
348 |
+
"\n",
|
349 |
+
" return keypoint_dict\n",
|
350 |
+
"\n",
|
351 |
+
"def process_videos():\n",
|
352 |
+
" csv_file = f\"videos_list.csv\"\n",
|
353 |
+
" data = pd.read_csv(csv_file)\n",
|
354 |
+
"\n",
|
355 |
+
" keypoints_list = Parallel(n_jobs=-1)( \n",
|
356 |
+
" delayed(extract_keypoint)(row['file'], row['label'], row['actor']) for index, row in tqdm(data.iterrows(), total=len(data), desc=\"Processing videos\", leave=False)\n",
|
357 |
+
" )\n",
|
358 |
+
"\n",
|
359 |
+
" keypoints_df = pd.DataFrame(keypoints_list)\n",
|
360 |
+
" keypoints_df.to_csv(f\"vsl{num_labels}_keypoints.csv\", index=False)\n",
|
361 |
+
"\n",
|
362 |
+
"if __name__ == '__main__':\n",
|
363 |
+
" process_videos()\n"
|
364 |
+
]
|
365 |
+
},
|
366 |
+
{
|
367 |
+
"cell_type": "markdown",
|
368 |
+
"metadata": {},
|
369 |
+
"source": [
|
370 |
+
"Interpolation"
|
371 |
+
]
|
372 |
+
},
|
373 |
+
{
|
374 |
+
"cell_type": "code",
|
375 |
+
"execution_count": null,
|
376 |
+
"metadata": {},
|
377 |
+
"outputs": [
|
378 |
+
{
|
379 |
+
"name": "stderr",
|
380 |
+
"output_type": "stream",
|
381 |
+
"text": [
|
382 |
+
"100%|██████████| 280/280 [00:04<00:00, 67.80it/s]\n"
|
383 |
+
]
|
384 |
+
},
|
385 |
+
{
|
386 |
+
"name": "stdout",
|
387 |
+
"output_type": "stream",
|
388 |
+
"text": [
|
389 |
+
"Interpolated keypoints saved to vsl10_interpolated_keypoints.csv\n"
|
390 |
+
]
|
391 |
+
},
|
392 |
+
{
|
393 |
+
"name": "stderr",
|
394 |
+
"output_type": "stream",
|
395 |
+
"text": [
|
396 |
+
"100%|██████████| 280/280 [00:03<00:00, 91.60it/s] "
|
397 |
+
]
|
398 |
+
},
|
399 |
+
{
|
400 |
+
"name": "stdout",
|
401 |
+
"output_type": "stream",
|
402 |
+
"text": [
|
403 |
+
"Data processing and saving completed.\n"
|
404 |
+
]
|
405 |
+
},
|
406 |
+
{
|
407 |
+
"name": "stderr",
|
408 |
+
"output_type": "stream",
|
409 |
+
"text": [
|
410 |
+
"\n"
|
411 |
+
]
|
412 |
+
}
|
413 |
+
],
|
414 |
+
"source": [
|
415 |
+
"import pandas as pd\n",
|
416 |
+
"import numpy as np\n",
|
417 |
+
"import ast\n",
|
418 |
+
"from tqdm import tqdm\n",
|
419 |
+
"\n",
|
420 |
+
"def find_index(array):\n",
|
421 |
+
" for i, num in enumerate(array):\n",
|
422 |
+
" if num != 0:\n",
|
423 |
+
" return i\n",
|
424 |
+
"\n",
|
425 |
+
"def curl_skeleton(array):\n",
|
426 |
+
" if sum(array) == 0:\n",
|
427 |
+
" return array\n",
|
428 |
+
" for i, location in enumerate(array):\n",
|
429 |
+
" if location != 0:\n",
|
430 |
+
" continue\n",
|
431 |
+
" else:\n",
|
432 |
+
" if i == 0 or i == len(array) - 1:\n",
|
433 |
+
" continue\n",
|
434 |
+
" else:\n",
|
435 |
+
" if array[i + 1] != 0:\n",
|
436 |
+
" array[i] = float((array[i - 1] + array[i + 1]) / 2)\n",
|
437 |
+
" else:\n",
|
438 |
+
" if sum(array[i:]) == 0:\n",
|
439 |
+
" continue\n",
|
440 |
+
" else:\n",
|
441 |
+
" j = find_index(array[i + 1:])\n",
|
442 |
+
" array[i] = float(((1 + j) * array[i - 1] + 1 * array[i + 1 + j]) / (2 + j))\n",
|
443 |
+
" return array\n",
|
444 |
+
"\n",
|
445 |
+
"def interpolate_keypoints(input_file, output_file, body_identifiers):\n",
|
446 |
+
" train_data = pd.read_csv(input_file)\n",
|
447 |
+
" output_df = train_data.copy()\n",
|
448 |
+
"\n",
|
449 |
+
" for index, video in tqdm(train_data.iterrows(), total=train_data.shape[0]):\n",
|
450 |
+
" for identifier in body_identifiers:\n",
|
451 |
+
" # Interpolate the x and y keypoints\n",
|
452 |
+
" x_values = curl_skeleton(ast.literal_eval(video[identifier + \"_x\"]))\n",
|
453 |
+
" y_values = curl_skeleton(ast.literal_eval(video[identifier + \"_y\"]))\n",
|
454 |
+
"\n",
|
455 |
+
" output_df.at[index, identifier + \"_x\"] = str(x_values)\n",
|
456 |
+
" output_df.at[index, identifier + \"_y\"] = str(y_values)\n",
|
457 |
+
"\n",
|
458 |
+
" output_df.to_csv(output_file, index=False)\n",
|
459 |
+
" print(f\"Interpolated keypoints saved to {output_file}\")\n",
|
460 |
+
"\n",
|
461 |
+
"if __name__ == \"__main__\":\n",
|
462 |
+
" input_file_path = f\"vsl{num_labels}_keypoints.csv\"\n",
|
463 |
+
" output_file_path = f\"vsl{num_labels}_interpolated_keypoints.csv\"\n",
|
464 |
+
"\n",
|
465 |
+
" hand_landmarks = [\n",
|
466 |
+
" 'INDEX_FINGER_DIP', 'INDEX_FINGER_MCP', 'INDEX_FINGER_PIP', 'INDEX_FINGER_TIP', \n",
|
467 |
+
" 'MIDDLE_FINGER_DIP', 'MIDDLE_FINGER_MCP', 'MIDDLE_FINGER_PIP', 'MIDDLE_FINGER_TIP', \n",
|
468 |
+
" 'PINKY_DIP', 'PINKY_MCP', 'PINKY_PIP', 'PINKY_TIP', \n",
|
469 |
+
" 'RING_FINGER_DIP', 'RING_FINGER_MCP', 'RING_FINGER_PIP', 'RING_FINGER_TIP', \n",
|
470 |
+
" 'THUMB_CMC', 'THUMB_IP', 'THUMB_MCP', 'THUMB_TIP', 'WRIST'\n",
|
471 |
+
" ]\n",
|
472 |
+
" HAND_IDENTIFIERS = [id + \"_right\" for id in hand_landmarks] + [id + \"_left\" for id in hand_landmarks]\n",
|
473 |
+
" POSE_IDENTIFIERS = [\"RIGHT_SHOULDER\", \"LEFT_SHOULDER\", \"LEFT_ELBOW\", \"RIGHT_ELBOW\"]\n",
|
474 |
+
" body_identifiers = HAND_IDENTIFIERS + POSE_IDENTIFIERS \n",
|
475 |
+
"\n",
|
476 |
+
" interpolate_keypoints(input_file_path, output_file_path, body_identifiers)\n",
|
477 |
+
"\n",
|
478 |
+
" # Load interpolated data and store them in numpy files\n",
|
479 |
+
" train_data = pd.read_csv(output_file_path)\n",
|
480 |
+
" frames = 80\n",
|
481 |
+
"\n",
|
482 |
+
" data = []\n",
|
483 |
+
" labels = []\n",
|
484 |
+
"\n",
|
485 |
+
" for video_index, video in tqdm(train_data.iterrows(), total=train_data.shape[0]):\n",
|
486 |
+
" T = len(ast.literal_eval(video[\"INDEX_FINGER_DIP_right_x\"]))\n",
|
487 |
+
" current_row = np.empty(shape=(2, T, len(body_identifiers), 1))\n",
|
488 |
+
"\n",
|
489 |
+
" for index, identifier in enumerate(body_identifiers):\n",
|
490 |
+
" data_keypoint_preprocess_x = ast.literal_eval(video[identifier + \"_x\"])\n",
|
491 |
+
" current_row[0, :, index, :] = np.asarray(data_keypoint_preprocess_x).reshape(T, 1)\n",
|
492 |
+
"\n",
|
493 |
+
" data_keypoint_preprocess_y = ast.literal_eval(video[identifier + \"_y\"])\n",
|
494 |
+
" current_row[1, :, index, :] = np.asarray(data_keypoint_preprocess_y).reshape(T, 1)\n",
|
495 |
+
"\n",
|
496 |
+
" if T < frames:\n",
|
497 |
+
" target = np.zeros(shape=(2, frames, len(body_identifiers), 1))\n",
|
498 |
+
" target[:, :T, :, :] = current_row\n",
|
499 |
+
" else:\n",
|
500 |
+
" target = current_row[:, :frames, :, :]\n",
|
501 |
+
"\n",
|
502 |
+
" data.append(target)\n",
|
503 |
+
" labels.append(int(video[\"label\"]))\n",
|
504 |
+
"\n",
|
505 |
+
" keypoint_data = np.stack(data, axis=0)\n",
|
506 |
+
" label_data = np.stack(labels, axis=0)\n",
|
507 |
+
" np.save(f'vsl{num_labels}_data_preprocess.npy', keypoint_data)\n",
|
508 |
+
" np.save(f'vsl{num_labels}_label_preprocess.npy', label_data)\n",
|
509 |
+
" print(\"Data processing and saving completed.\")\n"
|
510 |
+
]
|
511 |
+
},
|
512 |
+
{
|
513 |
+
"cell_type": "code",
|
514 |
+
"execution_count": 7,
|
515 |
+
"metadata": {},
|
516 |
+
"outputs": [
|
517 |
+
{
|
518 |
+
"name": "stdout",
|
519 |
+
"output_type": "stream",
|
520 |
+
"text": [
|
521 |
+
"(280, 2, 80, 46, 1)\n",
|
522 |
+
"(280,)\n"
|
523 |
+
]
|
524 |
+
}
|
525 |
+
],
|
526 |
+
"source": [
|
527 |
+
"import numpy as np\n",
|
528 |
+
"a = np.load(f'vsl{num_labels}_data_preprocess.npy')\n",
|
529 |
+
"b = np.load(f'vsl{num_labels}_label_preprocess.npy')\n",
|
530 |
+
"\n",
|
531 |
+
"print(a.shape)\n",
|
532 |
+
"print(b.shape)"
|
533 |
+
]
|
534 |
+
},
|
535 |
+
{
|
536 |
+
"cell_type": "markdown",
|
537 |
+
"metadata": {},
|
538 |
+
"source": [
|
539 |
+
"Do K-Folds and store the keypoints in numpy files"
|
540 |
+
]
|
541 |
+
},
|
542 |
+
{
|
543 |
+
"cell_type": "code",
|
544 |
+
"execution_count": null,
|
545 |
+
"metadata": {},
|
546 |
+
"outputs": [
|
547 |
+
{
|
548 |
+
"name": "stdout",
|
549 |
+
"output_type": "stream",
|
550 |
+
"text": [
|
551 |
+
"Number of actors: 28\n",
|
552 |
+
"-----------------------------------------------------\n",
|
553 |
+
"Fold 1: 30 test samples\n",
|
554 |
+
"Fold 2: 29 test samples\n",
|
555 |
+
"Fold 3: 30 test samples\n",
|
556 |
+
"Fold 4: 30 test samples\n",
|
557 |
+
"Fold 5: 30 test samples\n",
|
558 |
+
"Fold 6: 30 test samples\n",
|
559 |
+
"Fold 7: 31 test samples\n",
|
560 |
+
"Fold 8: 30 test samples\n",
|
561 |
+
"Fold 9: 20 test samples\n",
|
562 |
+
"Fold 10: 20 test samples\n",
|
563 |
+
"Processed and saved vsl10 fold 1 successfully.\n",
|
564 |
+
"Processed and saved vsl10 fold 2 successfully.\n",
|
565 |
+
"Processed and saved vsl10 fold 3 successfully.\n",
|
566 |
+
"Processed and saved vsl10 fold 4 successfully.\n",
|
567 |
+
"Processed and saved vsl10 fold 5 successfully.\n",
|
568 |
+
"Processed and saved vsl10 fold 6 successfully.\n",
|
569 |
+
"Processed and saved vsl10 fold 7 successfully.\n",
|
570 |
+
"Processed and saved vsl10 fold 8 successfully.\n",
|
571 |
+
"Processed and saved vsl10 fold 9 successfully.\n",
|
572 |
+
"Processed and saved vsl10 fold 10 successfully.\n"
|
573 |
+
]
|
574 |
+
}
|
575 |
+
],
|
576 |
+
"source": [
|
577 |
+
"from sklearn.model_selection import KFold\n",
|
578 |
+
"import os\n",
|
579 |
+
"import numpy as np\n",
|
580 |
+
"import pandas as pd\n",
|
581 |
+
"from tqdm import tqdm\n",
|
582 |
+
"\n",
|
583 |
+
"def k_fold_cross_validation(train_data, keypoint_data, label_data, num_labels, k_folds, destination_folder=\"numpy_files\"):\n",
|
584 |
+
" os.makedirs(destination_folder, exist_ok=True)\n",
|
585 |
+
"\n",
|
586 |
+
" actors = train_data['actor'].unique()\n",
|
587 |
+
" print(f\"Number of actors: {len(actors)}\")\n",
|
588 |
+
" print('-----------------------------------------------------')\n",
|
589 |
+
"\n",
|
590 |
+
" kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)\n",
|
591 |
+
"\n",
|
592 |
+
" actor_to_indices = {actor: train_data.index[train_data['actor'] == actor].tolist() for actor in actors}\n",
|
593 |
+
" folds = [[] for _ in range(k_folds)]\n",
|
594 |
+
"\n",
|
595 |
+
" for fold, (train_actors, test_actors) in enumerate(kf.split(actors)):\n",
|
596 |
+
" train_actors = actors[train_actors]\n",
|
597 |
+
" test_actors = actors[test_actors]\n",
|
598 |
+
" \n",
|
599 |
+
" for actor in test_actors:\n",
|
600 |
+
" folds[fold].extend(actor_to_indices[actor])\n",
|
601 |
+
"\n",
|
602 |
+
" tqdm.write(f\"Fold {fold+1}: {len(folds[fold])} test samples\")\n",
|
603 |
+
"\n",
|
604 |
+
" # Iterate over each fold to create train-test splits\n",
|
605 |
+
" for fold in range(k_folds):\n",
|
606 |
+
" test_indices = folds[fold]\n",
|
607 |
+
" train_indices = [idx for f in range(k_folds) if f != fold for idx in folds[f]]\n",
|
608 |
+
"\n",
|
609 |
+
" X_train, X_test = keypoint_data[train_indices], keypoint_data[test_indices]\n",
|
610 |
+
" y_train = np.array(label_data[train_indices], dtype=np.int64)\n",
|
611 |
+
" y_test = np.array(label_data[test_indices], dtype=np.int64)\n",
|
612 |
+
"\n",
|
613 |
+
" np.save(os.path.join(destination_folder, f'vsl{num_labels}_data_fold{fold+1}_train.npy'), X_train)\n",
|
614 |
+
" np.save(os.path.join(destination_folder, f'vsl{num_labels}_label_fold{fold+1}_train.npy'), y_train)\n",
|
615 |
+
" np.save(os.path.join(destination_folder, f'vsl{num_labels}_data_fold{fold+1}_test.npy'), X_test)\n",
|
616 |
+
" np.save(os.path.join(destination_folder, f'vsl{num_labels}_label_fold{fold+1}_test.npy'), y_test)\n",
|
617 |
+
"\n",
|
618 |
+
" tqdm.write(f\"Processed and saved vsl{num_labels} fold {fold+1} successfully.\")\n",
|
619 |
+
"\n",
|
620 |
+
"if __name__ == \"__main__\":\n",
|
621 |
+
" input_file_path = f\"vsl{num_labels}_interpolated_keypoints.csv\"\n",
|
622 |
+
" train_data = pd.read_csv(input_file_path)\n",
|
623 |
+
"\n",
|
624 |
+
" keypoint_data = np.load(f'vsl{num_labels}_data_preprocess.npy')\n",
|
625 |
+
" label_data = np.load(f'vsl{num_labels}_label_preprocess.npy')\n",
|
626 |
+
"\n",
|
627 |
+
" num_labels = len(np.unique(label_data))\n",
|
628 |
+
"\n",
|
629 |
+
" k_folds = 10\n",
|
630 |
+
" k_fold_cross_validation(train_data, keypoint_data, label_data, num_labels, k_folds)\n"
|
631 |
+
]
|
632 |
+
},
|
633 |
+
{
|
634 |
+
"cell_type": "code",
|
635 |
+
"execution_count": 9,
|
636 |
+
"metadata": {},
|
637 |
+
"outputs": [
|
638 |
+
{
|
639 |
+
"name": "stdout",
|
640 |
+
"output_type": "stream",
|
641 |
+
"text": [
|
642 |
+
"(29, 2, 80, 46, 1)\n",
|
643 |
+
"(251, 2, 80, 46, 1)\n"
|
644 |
+
]
|
645 |
+
}
|
646 |
+
],
|
647 |
+
"source": [
|
648 |
+
"import numpy as np\n",
|
649 |
+
"a = np.load(f'numpy_files/vsl{num_labels}_data_fold2_test.npy')\n",
|
650 |
+
"b = np.load(f'numpy_files/vsl{num_labels}_data_fold2_train.npy')\n",
|
651 |
+
"\n",
|
652 |
+
"print(a.shape)\n",
|
653 |
+
"print(b.shape)"
|
654 |
+
]
|
655 |
+
},
|
656 |
+
{
|
657 |
+
"cell_type": "markdown",
|
658 |
+
"metadata": {},
|
659 |
+
"source": [
|
660 |
+
"train directly with different folds"
|
661 |
+
]
|
662 |
+
},
|
663 |
+
{
|
664 |
+
"cell_type": "code",
|
665 |
+
"execution_count": null,
|
666 |
+
"metadata": {},
|
667 |
+
"outputs": [],
|
668 |
+
"source": [
|
669 |
+
"'''\n",
|
670 |
+
"import os\n",
|
671 |
+
"import numpy as np\n",
|
672 |
+
"import torch\n",
|
673 |
+
"from torch.utils.data import DataLoader\n",
|
674 |
+
"import pytorch_lightning as pl\n",
|
675 |
+
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
|
676 |
+
"from feeder import FeederINCLUDE\n",
|
677 |
+
"from aagcn import Model\n",
|
678 |
+
"from augumentation import Rotate, Compose\n",
|
679 |
+
"\n",
|
680 |
+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
|
681 |
+
"\n",
|
682 |
+
"if __name__ == '__main__':\n",
|
683 |
+
" k_folds = 10\n",
|
684 |
+
" config = {'batch_size': 128, 'learning_rate': 0.0137296, 'weight_decay': 0.000150403}\n",
|
685 |
+
" \n",
|
686 |
+
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
687 |
+
"\n",
|
688 |
+
" best_accuracy = 0.0\n",
|
689 |
+
" best_fold = -1\n",
|
690 |
+
"\n",
|
691 |
+
" for fold in range(k_folds):\n",
|
692 |
+
" print(f\"Starting fold {fold + 1}/{k_folds}\")\n",
|
693 |
+
" train_data_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_data_fold{fold+1}_train.npy')\n",
|
694 |
+
" train_label_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_label_fold{fold+1}_train.npy')\n",
|
695 |
+
" val_data_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_data_fold{fold+1}_test.npy')\n",
|
696 |
+
" val_label_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_label_fold{fold+1}_test.npy')\n",
|
697 |
+
"\n",
|
698 |
+
" transforms = Compose([\n",
|
699 |
+
" Rotate(15, 80, 25, (0.5, 0.5))\n",
|
700 |
+
" ])\n",
|
701 |
+
"\n",
|
702 |
+
" train_dataset = FeederINCLUDE(\n",
|
703 |
+
" data_path=train_data_path,\n",
|
704 |
+
" label_path=train_label_path,\n",
|
705 |
+
" transform=transforms\n",
|
706 |
+
" )\n",
|
707 |
+
" val_dataset = FeederINCLUDE(\n",
|
708 |
+
" data_path=val_data_path,\n",
|
709 |
+
" label_path=val_label_path\n",
|
710 |
+
" )\n",
|
711 |
+
"\n",
|
712 |
+
" train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)\n",
|
713 |
+
" val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)\n",
|
714 |
+
"\n",
|
715 |
+
" model = Model(num_class=num_labels, num_point=46, num_person=1, in_channels=2,\n",
|
716 |
+
" graph_args={\"layout\": \"mediapipe_two_hand\", \"strategy\": \"spatial\"},\n",
|
717 |
+
" learning_rate=config['learning_rate'], weight_decay=config['weight_decay'])\n",
|
718 |
+
"\n",
|
719 |
+
" callbacks = [\n",
|
720 |
+
" ModelCheckpoint(\n",
|
721 |
+
" dirpath=\"checkpoints\",\n",
|
722 |
+
" monitor=\"valid_accuracy\",\n",
|
723 |
+
" mode=\"max\",\n",
|
724 |
+
" every_n_epochs=2,\n",
|
725 |
+
" filename=f'vsl{num_labels}-aagcn-fold={fold+1}'\n",
|
726 |
+
" ),\n",
|
727 |
+
" ]\n",
|
728 |
+
"\n",
|
729 |
+
" trainer = pl.Trainer(max_epochs=2, accelerator=\"auto\", check_val_every_n_epoch=1,\n",
|
730 |
+
" devices=1, callbacks=callbacks)\n",
|
731 |
+
"\n",
|
732 |
+
" trainer.fit(model, train_dataloader, val_dataloader)\n",
|
733 |
+
" val_accuracy = trainer.callback_metrics['valid_accuracy'].item()\n",
|
734 |
+
" print(f\"Fold {fold + 1} finished with validation accuracy: {val_accuracy:.4f}\")\n",
|
735 |
+
"\n",
|
736 |
+
" if val_accuracy > best_accuracy:\n",
|
737 |
+
" best_accuracy = val_accuracy\n",
|
738 |
+
" best_fold = fold + 1 \n",
|
739 |
+
"\n",
|
740 |
+
" print(f\"The highest validation accuracy achieved is {best_accuracy:.4f} from fold {best_fold}.\")\n",
|
741 |
+
"'''"
|
742 |
+
]
|
743 |
+
},
|
744 |
+
{
|
745 |
+
"cell_type": "code",
|
746 |
+
"execution_count": null,
|
747 |
+
"metadata": {},
|
748 |
+
"outputs": [],
|
749 |
+
"source": [
|
750 |
+
"\n",
|
751 |
+
"print(f\"The highest validation accuracy achieved of vsl{num_labels} is {best_accuracy:.4f} from fold {best_fold}.\")"
|
752 |
+
]
|
753 |
+
},
|
754 |
+
{
|
755 |
+
"cell_type": "markdown",
|
756 |
+
"metadata": {},
|
757 |
+
"source": [
|
758 |
+
"train based on AUTSL with different folds"
|
759 |
+
]
|
760 |
+
},
|
761 |
+
{
|
762 |
+
"cell_type": "code",
|
763 |
+
"execution_count": null,
|
764 |
+
"metadata": {},
|
765 |
+
"outputs": [],
|
766 |
+
"source": [
|
767 |
+
"\n",
|
768 |
+
"import os\n",
|
769 |
+
"import numpy as np\n",
|
770 |
+
"import torch\n",
|
771 |
+
"from torch.utils.data import DataLoader\n",
|
772 |
+
"import pytorch_lightning as pl\n",
|
773 |
+
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
|
774 |
+
"from feeder import FeederINCLUDE\n",
|
775 |
+
"from aagcn import Model\n",
|
776 |
+
"from augumentation import Rotate, Compose\n",
|
777 |
+
"from pytorch_lightning.utilities.migration import pl_legacy_patch\n",
|
778 |
+
"\n",
|
779 |
+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
|
780 |
+
"\n",
|
781 |
+
"if __name__ == '__main__':\n",
|
782 |
+
" k_folds = 10 \n",
|
783 |
+
" config = {'batch_size': 128, 'learning_rate': 0.0137296, 'weight_decay': 0.000150403}\n",
|
784 |
+
" \n",
|
785 |
+
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
786 |
+
"\n",
|
787 |
+
" best_accuracy = 0.0\n",
|
788 |
+
" best_fold = -1\n",
|
789 |
+
"\n",
|
790 |
+
" for fold in range(k_folds):\n",
|
791 |
+
" print(f\"Starting fold {fold + 1}/{k_folds}\")\n",
|
792 |
+
" train_data_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_data_fold{fold+1}_train.npy')\n",
|
793 |
+
" train_label_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_label_fold{fold+1}_train.npy')\n",
|
794 |
+
" val_data_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_data_fold{fold+1}_test.npy')\n",
|
795 |
+
" val_label_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_label_fold{fold+1}_test.npy')\n",
|
796 |
+
"\n",
|
797 |
+
" transforms = Compose([\n",
|
798 |
+
" Rotate(15, 80, 25, (0.5, 0.5))\n",
|
799 |
+
" ])\n",
|
800 |
+
"\n",
|
801 |
+
" train_dataset = FeederINCLUDE(\n",
|
802 |
+
" data_path=train_data_path,\n",
|
803 |
+
" label_path=train_label_path,\n",
|
804 |
+
" transform=transforms\n",
|
805 |
+
" )\n",
|
806 |
+
" val_dataset = FeederINCLUDE(\n",
|
807 |
+
" data_path=val_data_path,\n",
|
808 |
+
" label_path=val_label_path\n",
|
809 |
+
" )\n",
|
810 |
+
"\n",
|
811 |
+
" train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)\n",
|
812 |
+
" val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)\n",
|
813 |
+
"\n",
|
814 |
+
" model = Model(num_class=num_labels, num_point=46, num_person=1, in_channels=2,\n",
|
815 |
+
" graph_args={\"layout\": \"mediapipe_two_hand\", \"strategy\": \"spatial\"},\n",
|
816 |
+
" learning_rate=config['learning_rate'], weight_decay=config['weight_decay'])\n",
|
817 |
+
"\n",
|
818 |
+
" # Path pre-trained checkpoint file on AUTSL\n",
|
819 |
+
" checkpoint_path = \"epoch=55-valid_loss=0.41-valid_accuracy=0.85-autsl-aagcn.ckpt\"\n",
|
820 |
+
"\n",
|
821 |
+
" with pl_legacy_patch():\n",
|
822 |
+
" checkpoint = torch.load(checkpoint_path, map_location=device)\n",
|
823 |
+
"\n",
|
824 |
+
" state_dict = checkpoint['state_dict']\n",
|
825 |
+
" filtered_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc.')}\n",
|
826 |
+
" model.load_state_dict(filtered_state_dict, strict=False)\n",
|
827 |
+
"\n",
|
828 |
+
" callbacks = [\n",
|
829 |
+
" ModelCheckpoint(\n",
|
830 |
+
" dirpath=\"checkpoints\",\n",
|
831 |
+
" monitor=\"valid_accuracy\",\n",
|
832 |
+
" mode=\"max\",\n",
|
833 |
+
" every_n_epochs=2,\n",
|
834 |
+
" filename=f'autsl_vsl{num_labels}-aagcn-fold={fold+1}'\n",
|
835 |
+
" ),\n",
|
836 |
+
" ]\n",
|
837 |
+
"\n",
|
838 |
+
" trainer = pl.Trainer(max_epochs=100, accelerator=\"auto\", check_val_every_n_epoch=1,\n",
|
839 |
+
" devices=1, callbacks=callbacks)\n",
|
840 |
+
"\n",
|
841 |
+
" trainer.fit(model, train_dataloader, val_dataloader)\n",
|
842 |
+
" val_accuracy = trainer.callback_metrics['valid_accuracy'].item() \n",
|
843 |
+
" print(f\"Fold {fold + 1} finished with validation accuracy: {val_accuracy:.4f}\")\n",
|
844 |
+
"\n",
|
845 |
+
" if val_accuracy > best_accuracy:\n",
|
846 |
+
" best_accuracy = val_accuracy\n",
|
847 |
+
" best_fold = fold + 1 \n",
|
848 |
+
"\n",
|
849 |
+
" print(f\"The highest validation accuracy achieved is {best_accuracy:.4f} from fold {best_fold}.\")\n"
|
850 |
+
]
|
851 |
+
},
|
852 |
+
{
|
853 |
+
"cell_type": "code",
|
854 |
+
"execution_count": null,
|
855 |
+
"metadata": {},
|
856 |
+
"outputs": [],
|
857 |
+
"source": [
|
858 |
+
"print(f\"The highest validation accuracy achieved of autsl vsl{num_labels} is {best_accuracy:.4f} from fold {best_fold}.\")"
|
859 |
+
]
|
860 |
+
}
|
861 |
+
],
|
862 |
+
"metadata": {
|
863 |
+
"kernelspec": {
|
864 |
+
"display_name": "Python 3",
|
865 |
+
"language": "python",
|
866 |
+
"name": "python3"
|
867 |
+
},
|
868 |
+
"language_info": {
|
869 |
+
"codemirror_mode": {
|
870 |
+
"name": "ipython",
|
871 |
+
"version": 3
|
872 |
+
},
|
873 |
+
"file_extension": ".py",
|
874 |
+
"mimetype": "text/x-python",
|
875 |
+
"name": "python",
|
876 |
+
"nbconvert_exporter": "python",
|
877 |
+
"pygments_lexer": "ipython3",
|
878 |
+
"version": "3.10.11"
|
879 |
+
}
|
880 |
+
},
|
881 |
+
"nbformat": 4,
|
882 |
+
"nbformat_minor": 2
|
883 |
+
}
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.autograd import Variable
|
7 |
+
from graph import Graph
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
from torchmetrics.classification import MulticlassAccuracy, BinaryAccuracy
|
10 |
+
import torch.optim as optim
|
11 |
+
|
12 |
+
def import_class(name):
|
13 |
+
components = name.split('.')
|
14 |
+
mod = __import__(components[0])
|
15 |
+
for comp in components[1:]:
|
16 |
+
mod = getattr(mod, comp)
|
17 |
+
return mod
|
18 |
+
|
19 |
+
|
20 |
+
def conv_branch_init(conv, branches):
|
21 |
+
weight = conv.weight
|
22 |
+
n = weight.size(0)
|
23 |
+
k1 = weight.size(1)
|
24 |
+
k2 = weight.size(2)
|
25 |
+
nn.init.normal_(weight, 0, math.sqrt(2. / (n * k1 * k2 * branches)))
|
26 |
+
nn.init.constant_(conv.bias, 0)
|
27 |
+
|
28 |
+
|
29 |
+
def conv_init(conv):
|
30 |
+
nn.init.kaiming_normal_(conv.weight, mode='fan_out')
|
31 |
+
nn.init.constant_(conv.bias, 0)
|
32 |
+
|
33 |
+
|
34 |
+
def bn_init(bn, scale):
|
35 |
+
nn.init.constant_(bn.weight, scale)
|
36 |
+
nn.init.constant_(bn.bias, 0)
|
37 |
+
|
38 |
+
|
39 |
+
class unit_tcn(nn.Module):
|
40 |
+
def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
|
41 |
+
super(unit_tcn, self).__init__()
|
42 |
+
pad = int((kernel_size - 1) / 2)
|
43 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), padding=(pad, 0),
|
44 |
+
stride=(stride, 1))
|
45 |
+
|
46 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
47 |
+
self.relu = nn.ReLU(inplace=True)
|
48 |
+
conv_init(self.conv)
|
49 |
+
bn_init(self.bn, 1)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
x = self.bn(self.conv(x))
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
class unit_gcn(nn.Module):
|
57 |
+
def __init__(self, in_channels, out_channels, A, coff_embedding=4, num_subset=3, adaptive=True, attention=True):
|
58 |
+
super(unit_gcn, self).__init__()
|
59 |
+
inter_channels = out_channels // coff_embedding
|
60 |
+
self.inter_c = inter_channels
|
61 |
+
self.out_c = out_channels
|
62 |
+
self.in_c = in_channels
|
63 |
+
self.num_subset = num_subset
|
64 |
+
num_jpts = A.shape[-1]
|
65 |
+
|
66 |
+
self.conv_d = nn.ModuleList()
|
67 |
+
for i in range(self.num_subset):
|
68 |
+
self.conv_d.append(nn.Conv2d(in_channels, out_channels, 1))
|
69 |
+
|
70 |
+
if adaptive:
|
71 |
+
self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
|
72 |
+
self.alpha = nn.Parameter(torch.zeros(1))
|
73 |
+
# self.beta = nn.Parameter(torch.ones(1))
|
74 |
+
# nn.init.constant_(self.PA, 1e-6)
|
75 |
+
# self.A = Variable(torch.from_numpy(A.astype(np.float32)), requires_grad=False)
|
76 |
+
# self.A = self.PA
|
77 |
+
self.conv_a = nn.ModuleList()
|
78 |
+
self.conv_b = nn.ModuleList()
|
79 |
+
for i in range(self.num_subset):
|
80 |
+
self.conv_a.append(nn.Conv2d(in_channels, inter_channels, 1))
|
81 |
+
self.conv_b.append(nn.Conv2d(in_channels, inter_channels, 1))
|
82 |
+
else:
|
83 |
+
self.A = Variable(torch.from_numpy(A.astype(np.float32)), requires_grad=False)
|
84 |
+
self.adaptive = adaptive
|
85 |
+
|
86 |
+
if attention:
|
87 |
+
# self.beta = nn.Parameter(torch.zeros(1))
|
88 |
+
# self.gamma = nn.Parameter(torch.zeros(1))
|
89 |
+
# unified attention
|
90 |
+
# self.Attention = nn.Parameter(torch.ones(num_jpts))
|
91 |
+
|
92 |
+
# temporal attention
|
93 |
+
self.conv_ta = nn.Conv1d(out_channels, 1, 9, padding=4)
|
94 |
+
nn.init.constant_(self.conv_ta.weight, 0)
|
95 |
+
nn.init.constant_(self.conv_ta.bias, 0)
|
96 |
+
|
97 |
+
# s attention
|
98 |
+
ker_jpt = num_jpts - 1 if not num_jpts % 2 else num_jpts
|
99 |
+
pad = (ker_jpt - 1) // 2
|
100 |
+
self.conv_sa = nn.Conv1d(out_channels, 1, ker_jpt, padding=pad)
|
101 |
+
nn.init.xavier_normal_(self.conv_sa.weight)
|
102 |
+
nn.init.constant_(self.conv_sa.bias, 0)
|
103 |
+
|
104 |
+
# channel attention
|
105 |
+
rr = 2
|
106 |
+
self.fc1c = nn.Linear(out_channels, out_channels // rr)
|
107 |
+
self.fc2c = nn.Linear(out_channels // rr, out_channels)
|
108 |
+
nn.init.kaiming_normal_(self.fc1c.weight)
|
109 |
+
nn.init.constant_(self.fc1c.bias, 0)
|
110 |
+
nn.init.constant_(self.fc2c.weight, 0)
|
111 |
+
nn.init.constant_(self.fc2c.bias, 0)
|
112 |
+
|
113 |
+
# self.bn = nn.BatchNorm2d(out_channels)
|
114 |
+
# bn_init(self.bn, 1)
|
115 |
+
self.attention = attention
|
116 |
+
|
117 |
+
if in_channels != out_channels:
|
118 |
+
self.down = nn.Sequential(
|
119 |
+
nn.Conv2d(in_channels, out_channels, 1),
|
120 |
+
nn.BatchNorm2d(out_channels)
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
self.down = lambda x: x
|
124 |
+
|
125 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
126 |
+
self.soft = nn.Softmax(-2)
|
127 |
+
self.tan = nn.Tanh()
|
128 |
+
self.sigmoid = nn.Sigmoid()
|
129 |
+
self.relu = nn.ReLU(inplace=True)
|
130 |
+
|
131 |
+
for m in self.modules():
|
132 |
+
if isinstance(m, nn.Conv2d):
|
133 |
+
conv_init(m)
|
134 |
+
elif isinstance(m, nn.BatchNorm2d):
|
135 |
+
bn_init(m, 1)
|
136 |
+
bn_init(self.bn, 1e-6)
|
137 |
+
for i in range(self.num_subset):
|
138 |
+
conv_branch_init(self.conv_d[i], self.num_subset)
|
139 |
+
|
140 |
+
def forward(self, x):
|
141 |
+
N, C, T, V = x.size()
|
142 |
+
|
143 |
+
y = None
|
144 |
+
if self.adaptive:
|
145 |
+
A = self.PA
|
146 |
+
# A = A + self.PA
|
147 |
+
for i in range(self.num_subset):
|
148 |
+
A1 = self.conv_a[i](x).permute(0, 3, 1, 2).contiguous().view(N, V, self.inter_c * T)
|
149 |
+
A2 = self.conv_b[i](x).view(N, self.inter_c * T, V)
|
150 |
+
A1 = self.tan(torch.matmul(A1, A2) / A1.size(-1)) # N V V
|
151 |
+
A1 = A[i] + A1 * self.alpha
|
152 |
+
A2 = x.view(N, C * T, V)
|
153 |
+
z = self.conv_d[i](torch.matmul(A2, A1).view(N, C, T, V))
|
154 |
+
y = z + y if y is not None else z
|
155 |
+
else:
|
156 |
+
A = self.A.cuda(x.get_device()) * self.mask
|
157 |
+
for i in range(self.num_subset):
|
158 |
+
A1 = A[i]
|
159 |
+
A2 = x.view(N, C * T, V)
|
160 |
+
z = self.conv_d[i](torch.matmul(A2, A1).view(N, C, T, V))
|
161 |
+
y = z + y if y is not None else z
|
162 |
+
|
163 |
+
y = self.bn(y)
|
164 |
+
y += self.down(x)
|
165 |
+
y = self.relu(y)
|
166 |
+
|
167 |
+
if self.attention:
|
168 |
+
# spatial attention
|
169 |
+
se = y.mean(-2) # N C V
|
170 |
+
se1 = self.sigmoid(self.conv_sa(se))
|
171 |
+
y = y * se1.unsqueeze(-2) + y
|
172 |
+
# a1 = se1.unsqueeze(-2)
|
173 |
+
|
174 |
+
# temporal attention
|
175 |
+
se = y.mean(-1)
|
176 |
+
se1 = self.sigmoid(self.conv_ta(se))
|
177 |
+
y = y * se1.unsqueeze(-1) + y
|
178 |
+
# a2 = se1.unsqueeze(-1)
|
179 |
+
|
180 |
+
# channel attention
|
181 |
+
se = y.mean(-1).mean(-1)
|
182 |
+
se1 = self.relu(self.fc1c(se))
|
183 |
+
se2 = self.sigmoid(self.fc2c(se1))
|
184 |
+
y = y * se2.unsqueeze(-1).unsqueeze(-1) + y
|
185 |
+
# a3 = se2.unsqueeze(-1).unsqueeze(-1)
|
186 |
+
|
187 |
+
# unified attention
|
188 |
+
# y = y * self.Attention + y
|
189 |
+
# y = y + y * ((a2 + a3) / 2)
|
190 |
+
# y = self.bn(y)
|
191 |
+
return y
|
192 |
+
|
193 |
+
|
194 |
+
class TCN_GCN_unit(nn.Module):
|
195 |
+
def __init__(self, in_channels, out_channels, A, stride=1, residual=True, adaptive=True, attention=True):
|
196 |
+
super(TCN_GCN_unit, self).__init__()
|
197 |
+
self.gcn1 = unit_gcn(in_channels, out_channels, A, adaptive=adaptive, attention=attention)
|
198 |
+
self.tcn1 = unit_tcn(out_channels, out_channels, stride=stride)
|
199 |
+
self.relu = nn.ReLU(inplace=True)
|
200 |
+
# if attention:
|
201 |
+
# self.alpha = nn.Parameter(torch.zeros(1))
|
202 |
+
# self.beta = nn.Parameter(torch.ones(1))
|
203 |
+
# temporal attention
|
204 |
+
# self.conv_ta1 = nn.Conv1d(out_channels, out_channels//rt, 9, padding=4)
|
205 |
+
# self.bn = nn.BatchNorm2d(out_channels)
|
206 |
+
# bn_init(self.bn, 1)
|
207 |
+
# self.conv_ta2 = nn.Conv1d(out_channels, 1, 9, padding=4)
|
208 |
+
# nn.init.kaiming_normal_(self.conv_ta1.weight)
|
209 |
+
# nn.init.constant_(self.conv_ta1.bias, 0)
|
210 |
+
# nn.init.constant_(self.conv_ta2.weight, 0)
|
211 |
+
# nn.init.constant_(self.conv_ta2.bias, 0)
|
212 |
+
|
213 |
+
# rt = 4
|
214 |
+
# self.inter_c = out_channels // rt
|
215 |
+
# self.conv_ta1 = nn.Conv2d(out_channels, out_channels // rt, 1)
|
216 |
+
# self.conv_ta2 = nn.Conv2d(out_channels, out_channels // rt, 1)
|
217 |
+
# nn.init.constant_(self.conv_ta1.weight, 0)
|
218 |
+
# nn.init.constant_(self.conv_ta1.bias, 0)
|
219 |
+
# nn.init.constant_(self.conv_ta2.weight, 0)
|
220 |
+
# nn.init.constant_(self.conv_ta2.bias, 0)
|
221 |
+
# s attention
|
222 |
+
# num_jpts = A.shape[-1]
|
223 |
+
# ker_jpt = num_jpts - 1 if not num_jpts % 2 else num_jpts
|
224 |
+
# pad = (ker_jpt - 1) // 2
|
225 |
+
# self.conv_sa = nn.Conv1d(out_channels, 1, ker_jpt, padding=pad)
|
226 |
+
# nn.init.constant_(self.conv_sa.weight, 0)
|
227 |
+
# nn.init.constant_(self.conv_sa.bias, 0)
|
228 |
+
|
229 |
+
# channel attention
|
230 |
+
# rr = 16
|
231 |
+
# self.fc1c = nn.Linear(out_channels, out_channels // rr)
|
232 |
+
# self.fc2c = nn.Linear(out_channels // rr, out_channels)
|
233 |
+
# nn.init.kaiming_normal_(self.fc1c.weight)
|
234 |
+
# nn.init.constant_(self.fc1c.bias, 0)
|
235 |
+
# nn.init.constant_(self.fc2c.weight, 0)
|
236 |
+
# nn.init.constant_(self.fc2c.bias, 0)
|
237 |
+
#
|
238 |
+
# self.softmax = nn.Softmax(-2)
|
239 |
+
# self.sigmoid = nn.Sigmoid()
|
240 |
+
self.attention = attention
|
241 |
+
|
242 |
+
if not residual:
|
243 |
+
self.residual = lambda x: 0
|
244 |
+
|
245 |
+
elif (in_channels == out_channels) and (stride == 1):
|
246 |
+
self.residual = lambda x: x
|
247 |
+
|
248 |
+
else:
|
249 |
+
self.residual = unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)
|
250 |
+
|
251 |
+
def forward(self, x):
|
252 |
+
if self.attention:
|
253 |
+
y = self.relu(self.tcn1(self.gcn1(x)) + self.residual(x))
|
254 |
+
|
255 |
+
# spatial attention
|
256 |
+
# se = y.mean(-2) # N C V
|
257 |
+
# se1 = self.sigmoid(self.conv_sa(se))
|
258 |
+
# y = y * se1.unsqueeze(-2) + y
|
259 |
+
# a1 = se1.unsqueeze(-2)
|
260 |
+
|
261 |
+
# temporal attention
|
262 |
+
# se = y.mean(-1) # N C T
|
263 |
+
# # se1 = self.relu(self.bn(self.conv_ta1(se)))
|
264 |
+
# se2 = self.sigmoid(self.conv_ta2(se))
|
265 |
+
# # y = y * se1.unsqueeze(-1) + y
|
266 |
+
# a2 = se2.unsqueeze(-1)
|
267 |
+
|
268 |
+
# se = y # NCTV
|
269 |
+
# N, C, T, V = y.shape
|
270 |
+
# se1 = self.conv_ta1(se).permute(0, 2, 1, 3).contiguous().view(N, T, self.inter_c * V) # NTCV
|
271 |
+
# se2 = self.conv_ta2(se).permute(0, 1, 3, 2).contiguous().view(N, self.inter_c * V, T) # NCVT
|
272 |
+
# a2 = self.softmax(torch.matmul(se1, se2) / np.sqrt(se1.size(-1))) # N T T
|
273 |
+
# y = torch.matmul(y.permute(0, 1, 3, 2).contiguous().view(N, C * V, T), a2) \
|
274 |
+
# .view(N, C, V, T).permute(0, 1, 3, 2) * self.alpha + y
|
275 |
+
|
276 |
+
# channel attention
|
277 |
+
# se = y.mean(-1).mean(-1)
|
278 |
+
# se1 = self.relu(self.fc1c(se))
|
279 |
+
# se2 = self.sigmoid(self.fc2c(se1))
|
280 |
+
# # y = y * se2.unsqueeze(-1).unsqueeze(-1) + y
|
281 |
+
# a3 = se2.unsqueeze(-1).unsqueeze(-1)
|
282 |
+
#
|
283 |
+
# y = y * ((a2 + a3) / 2) + y
|
284 |
+
# y = self.bn(y)
|
285 |
+
else:
|
286 |
+
y = self.relu(self.tcn1(self.gcn1(x)) + self.residual(x))
|
287 |
+
return y
|
288 |
+
|
289 |
+
|
290 |
+
class Model(pl.LightningModule):
|
291 |
+
def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, graph_args=dict(), in_channels=3,
|
292 |
+
drop_out=0, adaptive=True, attention=True, learning_rate=1e-4, weight_decay=1e-4):
|
293 |
+
super(Model, self).__init__()
|
294 |
+
|
295 |
+
# if graph is None:
|
296 |
+
# raise ValueError()
|
297 |
+
# else:
|
298 |
+
# Graph = import_class(graph)
|
299 |
+
self.graph = Graph(**graph_args)
|
300 |
+
|
301 |
+
A = self.graph.A
|
302 |
+
self.num_class = num_class
|
303 |
+
|
304 |
+
self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)
|
305 |
+
|
306 |
+
self.l1 = TCN_GCN_unit(in_channels, 64, A, residual=False, adaptive=adaptive, attention=attention)
|
307 |
+
self.l2 = TCN_GCN_unit(64, 64, A, adaptive=adaptive, attention=attention)
|
308 |
+
self.l3 = TCN_GCN_unit(64, 64, A, adaptive=adaptive, attention=attention)
|
309 |
+
self.l4 = TCN_GCN_unit(64, 64, A, adaptive=adaptive, attention=attention)
|
310 |
+
self.l5 = TCN_GCN_unit(64, 128, A, stride=2, adaptive=adaptive, attention=attention)
|
311 |
+
self.l6 = TCN_GCN_unit(128, 128, A, adaptive=adaptive, attention=attention)
|
312 |
+
self.l7 = TCN_GCN_unit(128, 128, A, adaptive=adaptive, attention=attention)
|
313 |
+
self.l8 = TCN_GCN_unit(128, 256, A, stride=2, adaptive=adaptive, attention=attention)
|
314 |
+
self.l9 = TCN_GCN_unit(256, 256, A, adaptive=adaptive, attention=attention)
|
315 |
+
self.l10 = TCN_GCN_unit(256, 256, A, adaptive=adaptive, attention=attention)
|
316 |
+
# self.l11 = TCN_GCN_unit(256, 512, A, stride=2, adaptive=adaptive, attention=attention)
|
317 |
+
# self.l12 = TCN_GCN_unit(512, 512, A, adaptive=adaptive, attention=attention)
|
318 |
+
# self.l13 = TCN_GCN_unit(512, 512, A, adaptive=adaptive, attention=attention)
|
319 |
+
|
320 |
+
self.fc = nn.Linear(256, num_class)
|
321 |
+
nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))
|
322 |
+
bn_init(self.data_bn, 1)
|
323 |
+
if drop_out:
|
324 |
+
self.drop_out = nn.Dropout(drop_out)
|
325 |
+
else:
|
326 |
+
self.drop_out = lambda x: x
|
327 |
+
|
328 |
+
self.loss = nn.CrossEntropyLoss()
|
329 |
+
self.metric = MulticlassAccuracy(num_class)
|
330 |
+
# self.metric = BinaryAccuracy()
|
331 |
+
self.learning_rate = learning_rate
|
332 |
+
self.weight_decay = weight_decay
|
333 |
+
self.validation_step_loss_outputs = []
|
334 |
+
self.validation_step_acc_outputs = []
|
335 |
+
|
336 |
+
self.save_hyperparameters()
|
337 |
+
|
338 |
+
def forward(self, x):
|
339 |
+
N, C, T, V, M = x.size()
|
340 |
+
|
341 |
+
x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
|
342 |
+
x = self.data_bn(x.float())
|
343 |
+
x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)
|
344 |
+
|
345 |
+
x = self.l1(x)
|
346 |
+
x = self.l2(x)
|
347 |
+
x = self.l3(x)
|
348 |
+
x = self.l4(x)
|
349 |
+
x = self.l5(x)
|
350 |
+
x = self.l6(x)
|
351 |
+
x = self.l7(x)
|
352 |
+
x = self.l8(x)
|
353 |
+
x = self.l9(x)
|
354 |
+
x = self.l10(x)
|
355 |
+
# x = self.l11(x)
|
356 |
+
# x = self.l12(x)
|
357 |
+
# x = self.l13(x)
|
358 |
+
|
359 |
+
# N*M,C,T,V
|
360 |
+
c_new = x.size(1)
|
361 |
+
x = x.view(N, M, c_new, -1)
|
362 |
+
x = x.mean(3).mean(1)
|
363 |
+
x = self.drop_out(x)
|
364 |
+
|
365 |
+
return self.fc(x)
|
366 |
+
|
367 |
+
def training_step(self, batch, batch_idx):
|
368 |
+
inputs, targets = batch
|
369 |
+
outputs = self(inputs)
|
370 |
+
y_pred_class = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
|
371 |
+
# print("Targets : ", targets)
|
372 |
+
# print("Preds : ", y_pred_class)
|
373 |
+
train_accuracy = self.metric(y_pred_class, targets)
|
374 |
+
loss = self.loss(outputs, targets)
|
375 |
+
self.log('train_accuracy', train_accuracy, prog_bar=True, on_epoch=True)
|
376 |
+
self.log('train_loss', loss, prog_bar=True, on_epoch=True)
|
377 |
+
# return {"loss": loss, "train_accuracy" : train_accuracy}
|
378 |
+
return loss
|
379 |
+
|
380 |
+
def validation_step(self, batch, batch_idx):
|
381 |
+
inputs, targets = batch
|
382 |
+
outputs = self.forward(inputs)
|
383 |
+
y_pred_class = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
|
384 |
+
valid_accuracy = self.metric(y_pred_class, targets)
|
385 |
+
loss = self.loss(outputs, targets)
|
386 |
+
self.log('valid_accuracy', valid_accuracy, prog_bar=True, on_epoch=True)
|
387 |
+
self.log('valid_loss', loss, prog_bar=True, on_epoch=True)
|
388 |
+
self.validation_step_loss_outputs.append(loss)
|
389 |
+
self.validation_step_acc_outputs.append(valid_accuracy)
|
390 |
+
return {"valid_loss" : loss, "valid_accuracy" : valid_accuracy}
|
391 |
+
|
392 |
+
def on_validation_epoch_end(self):
|
393 |
+
# avg_loss = torch.stack(
|
394 |
+
# [x["valid_loss"] for x in outputs]).mean()
|
395 |
+
# avg_acc = torch.stack(
|
396 |
+
# [x["valid_accuracy"] for x in outputs]).mean()
|
397 |
+
avg_loss = torch.stack(self.validation_step_loss_outputs).mean()
|
398 |
+
avg_acc = torch.stack(self.validation_step_acc_outputs).mean()
|
399 |
+
self.log("ptl/val_loss", avg_loss)
|
400 |
+
self.log("ptl/val_accuracy", avg_acc)
|
401 |
+
self.validation_step_loss_outputs.clear()
|
402 |
+
self.validation_step_acc_outputs.clear()
|
403 |
+
|
404 |
+
def test_step(self, batch, batch_idx):
|
405 |
+
inputs, targets = batch
|
406 |
+
outputs = self.forward(inputs)
|
407 |
+
y_pred_class = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
|
408 |
+
print("Targets : ", targets)
|
409 |
+
print("Preds : ", y_pred_class)
|
410 |
+
test_accuracy = self.metric(y_pred_class, targets)
|
411 |
+
loss = self.loss(outputs, targets)
|
412 |
+
self.log('test_accuracy', test_accuracy, prog_bar=True, on_epoch=True)
|
413 |
+
self.log('test_loss', loss, prog_bar=True, on_epoch=True)
|
414 |
+
return {"test_loss" : loss, "test_accuracy" : test_accuracy}
|
415 |
+
|
416 |
+
def configure_optimizers(self):
|
417 |
+
params = self.parameters()
|
418 |
+
optimizer = optim.Adam(params=params, lr = self.learning_rate, weight_decay = self.weight_decay)
|
419 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')
|
420 |
+
return {"optimizer": optimizer,
|
421 |
+
"lr_scheduler": {"scheduler": scheduler, "monitor": "valid_accuracy"}
|
422 |
+
}
|
423 |
+
# return optimizer
|
424 |
+
|
425 |
+
def predict_step(self, batch, batch_idx):
|
426 |
+
return self(batch)
|
427 |
+
|
428 |
+
if __name__ == "__main__":
|
429 |
+
import os
|
430 |
+
from torchinfo import summary
|
431 |
+
print(os.getcwd())
|
432 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
433 |
+
model = Model(num_class=20, num_point=18, num_person=1,
|
434 |
+
graph_args={}, in_channels=2).to(device)
|
435 |
+
# print(model.device)
|
436 |
+
# N, C, T, V, M
|
437 |
+
summary(model)
|
438 |
+
x = torch.randn((1, 2, 80, 18, 1)).to(device)
|
439 |
+
y = model(x)
|
440 |
+
print(y.shape)
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.autograd import Variable
|
7 |
+
from graph import Graph
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
from torchmetrics.classification import MulticlassAccuracy
|
10 |
+
import torch.optim as optim
|
11 |
+
|
12 |
+
def import_class(name):
|
13 |
+
components = name.split('.')
|
14 |
+
mod = __import__(components[0])
|
15 |
+
for comp in components[1:]:
|
16 |
+
mod = getattr(mod, comp)
|
17 |
+
return mod
|
18 |
+
|
19 |
+
|
20 |
+
def conv_branch_init(conv, branches):
|
21 |
+
weight = conv.weight
|
22 |
+
n = weight.size(0)
|
23 |
+
k1 = weight.size(1)
|
24 |
+
k2 = weight.size(2)
|
25 |
+
nn.init.normal_(weight, 0, math.sqrt(2. / (n * k1 * k2 * branches)))
|
26 |
+
nn.init.constant_(conv.bias, 0)
|
27 |
+
|
28 |
+
|
29 |
+
def conv_init(conv):
|
30 |
+
nn.init.kaiming_normal_(conv.weight, mode='fan_out')
|
31 |
+
nn.init.constant_(conv.bias, 0)
|
32 |
+
|
33 |
+
|
34 |
+
def bn_init(bn, scale):
|
35 |
+
nn.init.constant_(bn.weight, scale)
|
36 |
+
nn.init.constant_(bn.bias, 0)
|
37 |
+
|
38 |
+
|
39 |
+
class unit_tcn(nn.Module):
|
40 |
+
def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
|
41 |
+
super(unit_tcn, self).__init__()
|
42 |
+
pad = int((kernel_size - 1) / 2)
|
43 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), padding=(pad, 0),
|
44 |
+
stride=(stride, 1))
|
45 |
+
|
46 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
47 |
+
self.relu = nn.ReLU()
|
48 |
+
conv_init(self.conv)
|
49 |
+
bn_init(self.bn, 1)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
x = self.bn(self.conv(x))
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
class unit_gcn(nn.Module):
|
57 |
+
def __init__(self, in_channels, out_channels, A, coff_embedding=4, num_subset=3):
|
58 |
+
super(unit_gcn, self).__init__()
|
59 |
+
inter_channels = out_channels // coff_embedding
|
60 |
+
self.inter_c = inter_channels
|
61 |
+
self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
|
62 |
+
nn.init.constant_(self.PA, 1e-6)
|
63 |
+
self.A = Variable(torch.from_numpy(A.astype(np.float32)), requires_grad=False)
|
64 |
+
self.num_subset = num_subset
|
65 |
+
|
66 |
+
self.conv_a = nn.ModuleList()
|
67 |
+
self.conv_b = nn.ModuleList()
|
68 |
+
self.conv_d = nn.ModuleList()
|
69 |
+
for i in range(self.num_subset):
|
70 |
+
self.conv_a.append(nn.Conv2d(in_channels, inter_channels, 1))
|
71 |
+
self.conv_b.append(nn.Conv2d(in_channels, inter_channels, 1))
|
72 |
+
self.conv_d.append(nn.Conv2d(in_channels, out_channels, 1))
|
73 |
+
|
74 |
+
if in_channels != out_channels:
|
75 |
+
self.down = nn.Sequential(
|
76 |
+
nn.Conv2d(in_channels, out_channels, 1),
|
77 |
+
nn.BatchNorm2d(out_channels)
|
78 |
+
)
|
79 |
+
else:
|
80 |
+
self.down = lambda x: x
|
81 |
+
|
82 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
83 |
+
self.soft = nn.Softmax(-2)
|
84 |
+
self.relu = nn.ReLU()
|
85 |
+
|
86 |
+
for m in self.modules():
|
87 |
+
if isinstance(m, nn.Conv2d):
|
88 |
+
conv_init(m)
|
89 |
+
elif isinstance(m, nn.BatchNorm2d):
|
90 |
+
bn_init(m, 1)
|
91 |
+
bn_init(self.bn, 1e-6)
|
92 |
+
for i in range(self.num_subset):
|
93 |
+
conv_branch_init(self.conv_d[i], self.num_subset)
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
N, C, T, V = x.size()
|
97 |
+
A = self.A.cuda(x.get_device())
|
98 |
+
A = A + self.PA
|
99 |
+
|
100 |
+
y = None
|
101 |
+
for i in range(self.num_subset):
|
102 |
+
A1 = self.conv_a[i](x).permute(0, 3, 1, 2).contiguous().view(N, V, self.inter_c * T)
|
103 |
+
A2 = self.conv_b[i](x).view(N, self.inter_c * T, V)
|
104 |
+
A1 = self.soft(torch.matmul(A1, A2) / A1.size(-1)) # N V V
|
105 |
+
A1 = A1 + A[i]
|
106 |
+
A2 = x.view(N, C * T, V)
|
107 |
+
z = self.conv_d[i](torch.matmul(A2, A1).view(N, C, T, V))
|
108 |
+
y = z + y if y is not None else z
|
109 |
+
|
110 |
+
y = self.bn(y)
|
111 |
+
y += self.down(x)
|
112 |
+
return self.relu(y)
|
113 |
+
|
114 |
+
|
115 |
+
class TCN_GCN_unit(nn.Module):
|
116 |
+
def __init__(self, in_channels, out_channels, A, stride=1, residual=True):
|
117 |
+
super(TCN_GCN_unit, self).__init__()
|
118 |
+
self.gcn1 = unit_gcn(in_channels, out_channels, A)
|
119 |
+
self.tcn1 = unit_tcn(out_channels, out_channels, stride=stride)
|
120 |
+
self.relu = nn.ReLU()
|
121 |
+
if not residual:
|
122 |
+
self.residual = lambda x: 0
|
123 |
+
|
124 |
+
elif (in_channels == out_channels) and (stride == 1):
|
125 |
+
self.residual = lambda x: x
|
126 |
+
|
127 |
+
else:
|
128 |
+
self.residual = unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)
|
129 |
+
|
130 |
+
def forward(self, x):
|
131 |
+
x = self.tcn1(self.gcn1(x)) + self.residual(x)
|
132 |
+
return self.relu(x)
|
133 |
+
|
134 |
+
|
135 |
+
class Model(pl.LightningModule):
|
136 |
+
def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, graph_args=dict(), in_channels=3,
|
137 |
+
learning_rate = 1e-4, weight_decay = 1e-4):
|
138 |
+
super(Model, self).__init__()
|
139 |
+
|
140 |
+
# if graph is None:
|
141 |
+
# raise ValueError()
|
142 |
+
# else:
|
143 |
+
# Graph = import_class(graph)
|
144 |
+
self.graph = Graph(**graph_args)
|
145 |
+
|
146 |
+
A = self.graph.A
|
147 |
+
# print(num_person * in_channels * num_point)
|
148 |
+
self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)
|
149 |
+
|
150 |
+
self.l1 = TCN_GCN_unit(in_channels, 64, A, residual=False)
|
151 |
+
self.l2 = TCN_GCN_unit(64, 64, A)
|
152 |
+
self.l3 = TCN_GCN_unit(64, 64, A)
|
153 |
+
self.l4 = TCN_GCN_unit(64, 64, A)
|
154 |
+
self.l5 = TCN_GCN_unit(64, 128, A, stride=2)
|
155 |
+
self.l6 = TCN_GCN_unit(128, 128, A)
|
156 |
+
self.l7 = TCN_GCN_unit(128, 128, A)
|
157 |
+
self.l8 = TCN_GCN_unit(128, 256, A, stride=2)
|
158 |
+
self.l9 = TCN_GCN_unit(256, 256, A)
|
159 |
+
self.l10 = TCN_GCN_unit(256, 256, A)
|
160 |
+
|
161 |
+
self.fc = nn.Linear(256, num_class)
|
162 |
+
nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))
|
163 |
+
bn_init(self.data_bn, 1)
|
164 |
+
|
165 |
+
self.loss = nn.CrossEntropyLoss()
|
166 |
+
self.metric = MulticlassAccuracy(num_class)
|
167 |
+
self.learning_rate = learning_rate
|
168 |
+
self.weight_decay = weight_decay
|
169 |
+
self.validation_step_loss_outputs = []
|
170 |
+
self.validation_step_acc_outputs = []
|
171 |
+
|
172 |
+
self.save_hyperparameters()
|
173 |
+
|
174 |
+
def forward(self, x):
|
175 |
+
# 0, 1, 2, 3, 4
|
176 |
+
N, C, T, V, M = x.size()
|
177 |
+
# print(f"N {N}, C {C}, T {T}, V {V}, M {M}")
|
178 |
+
# N, M, V, C, T
|
179 |
+
x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
|
180 |
+
# print(M*V*C)
|
181 |
+
x = self.data_bn(x)
|
182 |
+
x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)
|
183 |
+
|
184 |
+
x = self.l1(x)
|
185 |
+
x = self.l2(x)
|
186 |
+
x = self.l3(x)
|
187 |
+
x = self.l4(x)
|
188 |
+
x = self.l5(x)
|
189 |
+
x = self.l6(x)
|
190 |
+
x = self.l7(x)
|
191 |
+
x = self.l8(x)
|
192 |
+
x = self.l9(x)
|
193 |
+
x = self.l10(x)
|
194 |
+
|
195 |
+
# N*M,C,T,V
|
196 |
+
c_new = x.size(1)
|
197 |
+
x = x.view(N, M, c_new, -1)
|
198 |
+
x = x.mean(3).mean(1)
|
199 |
+
|
200 |
+
return self.fc(x)
|
201 |
+
|
202 |
+
def training_step(self, batch, batch_idx):
|
203 |
+
inputs, targets = batch
|
204 |
+
outputs = self(inputs)
|
205 |
+
y_pred_class = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
|
206 |
+
# print("Targets : ", targets)
|
207 |
+
# print("Preds : ", y_pred_class)
|
208 |
+
train_accuracy = self.metric(y_pred_class, targets)
|
209 |
+
loss = self.loss(outputs, targets)
|
210 |
+
self.log('train_accuracy', train_accuracy, prog_bar=True, on_epoch=True)
|
211 |
+
self.log('train_loss', loss, prog_bar=True, on_epoch=True)
|
212 |
+
# return {"loss": loss, "train_accuracy" : train_accuracy}
|
213 |
+
return loss
|
214 |
+
|
215 |
+
def validation_step(self, batch, batch_idx):
|
216 |
+
inputs, targets = batch
|
217 |
+
outputs = self.forward(inputs)
|
218 |
+
y_pred_class = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
|
219 |
+
valid_accuracy = self.metric(y_pred_class, targets)
|
220 |
+
loss = self.loss(outputs, targets)
|
221 |
+
self.log('valid_accuracy', valid_accuracy, prog_bar=True, on_epoch=True)
|
222 |
+
self.log('valid_loss', loss, prog_bar=True, on_epoch=True)
|
223 |
+
self.validation_step_loss_outputs.append(loss)
|
224 |
+
self.validation_step_acc_outputs.append(valid_accuracy)
|
225 |
+
return {"valid_loss" : loss, "valid_accuracy" : valid_accuracy}
|
226 |
+
|
227 |
+
def on_validation_epoch_end(self):
|
228 |
+
# avg_loss = torch.stack(
|
229 |
+
# [x["valid_loss"] for x in outputs]).mean()
|
230 |
+
# avg_acc = torch.stack(
|
231 |
+
# [x["valid_accuracy"] for x in outputs]).mean()
|
232 |
+
avg_loss = torch.stack(self.validation_step_loss_outputs).mean()
|
233 |
+
avg_acc = torch.stack(self.validation_step_acc_outputs).mean()
|
234 |
+
self.log("ptl/val_loss", avg_loss)
|
235 |
+
self.log("ptl/val_accuracy", avg_acc)
|
236 |
+
self.validation_step_loss_outputs.clear()
|
237 |
+
self.validation_step_acc_outputs.clear()
|
238 |
+
|
239 |
+
def test_step(self, batch, batch_idx):
|
240 |
+
inputs, targets = batch
|
241 |
+
outputs = self.forward(inputs)
|
242 |
+
y_pred_class = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
|
243 |
+
print("Targets : ", targets)
|
244 |
+
print("Preds : ", y_pred_class)
|
245 |
+
test_accuracy = self.metric(y_pred_class, targets)
|
246 |
+
loss = self.loss(outputs, targets)
|
247 |
+
self.log('test_accuracy', test_accuracy, prog_bar=True, on_epoch=True)
|
248 |
+
self.log('test_loss', loss, prog_bar=True, on_epoch=True)
|
249 |
+
return {"test_loss" : loss, "test_accuracy" : test_accuracy}
|
250 |
+
|
251 |
+
def configure_optimizers(self):
|
252 |
+
params = self.parameters()
|
253 |
+
optimizer = optim.Adam(params=params, lr = self.learning_rate, weight_decay = self.weight_decay)
|
254 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')
|
255 |
+
return {"optimizer": optimizer,
|
256 |
+
"lr_scheduler": {"scheduler": scheduler, "monitor": "valid_accuracy"}
|
257 |
+
}
|
258 |
+
# return optimizer
|
259 |
+
|
260 |
+
def predict_step(self, batch, batch_idx):
|
261 |
+
return self(batch)
|
262 |
+
|
263 |
+
if __name__ == "__main__":
|
264 |
+
import os
|
265 |
+
from torchinfo import summary
|
266 |
+
print(os.getcwd())
|
267 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
268 |
+
model = Model(num_class=20, num_point=25, num_person=1,
|
269 |
+
graph_args={"layout":"mediapipe", "strategy":"spatial"}, in_channels=2).to(device)
|
270 |
+
# print(model.device)
|
271 |
+
# N, C, T, V, M
|
272 |
+
summary(model)
|
273 |
+
x = torch.randn((1, 2, 80, 25, 1)).to(device)
|
274 |
+
y = model(x)
|
275 |
+
print(y.shape)
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
from numpy import random
|
4 |
+
from einops import einsum
|
5 |
+
import torch
|
6 |
+
class Rotate(object):
|
7 |
+
def __init__(self, range_angle, num_frames, num_nodes, point):
|
8 |
+
self.range_angle = range_angle
|
9 |
+
self.num_frames = num_frames
|
10 |
+
self.num_nodes = num_nodes
|
11 |
+
self.point = point
|
12 |
+
def __call__(self, data, label):
|
13 |
+
# data, label = sample
|
14 |
+
data = data.double()
|
15 |
+
angle = math.radians(random.uniform((-1)*self.range_angle, self.range_angle))
|
16 |
+
rotation_matrix = torch.Tensor([[math.cos(angle), (-1)*math.sin(angle)],
|
17 |
+
[math.sin(angle), math.cos(angle)]])
|
18 |
+
# print(type(rotation_matrix))
|
19 |
+
ox, oy = self.point
|
20 |
+
data[0, :, :] -= ox
|
21 |
+
data[1, :, :] -= oy
|
22 |
+
|
23 |
+
result = einsum(rotation_matrix.double(), data, "a b, b c d e -> a c d e") + 0.5
|
24 |
+
|
25 |
+
return result, label
|
26 |
+
|
27 |
+
|
28 |
+
class Left(object):
|
29 |
+
def __init__(self, width):
|
30 |
+
self.width = width
|
31 |
+
def __call__(self, data, label):
|
32 |
+
idx = find_frames(data)
|
33 |
+
p = random.random()
|
34 |
+
if p > 0.5:
|
35 |
+
data[0, :idx, :] -= self.width
|
36 |
+
return data, label
|
37 |
+
|
38 |
+
class Right(object):
|
39 |
+
def __init__(self, width):
|
40 |
+
self.width = width
|
41 |
+
def __call__(self, data, label):
|
42 |
+
idx = find_frames(data)
|
43 |
+
p = random.random()
|
44 |
+
if p > 0.5:
|
45 |
+
data[0, :idx, :] += self.width
|
46 |
+
return data, label
|
47 |
+
|
48 |
+
class GaussianNoise(object):
|
49 |
+
def __init__(self, mean, var):
|
50 |
+
self.mean = mean
|
51 |
+
self.var = var
|
52 |
+
def __call__(self, data, label):
|
53 |
+
# C, T, V, 1
|
54 |
+
print(data.size())
|
55 |
+
noise = torch.randn(size = data.size())
|
56 |
+
data = data + noise
|
57 |
+
return data, label
|
58 |
+
|
59 |
+
class Compose(object):
|
60 |
+
def __init__(self, transforms):
|
61 |
+
self.transforms = transforms
|
62 |
+
|
63 |
+
def __call__(self, data, label):
|
64 |
+
for t in self.transforms:
|
65 |
+
data, label = t(data, label)
|
66 |
+
return data, label
|
67 |
+
|
68 |
+
def find_frames(data):
|
69 |
+
for i in range(data.shape[1]):
|
70 |
+
if(data[:, i, :][0][0] == 0):
|
71 |
+
# print(i)
|
72 |
+
return i
|
73 |
+
return data.shape[1]
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0ede54d8c48503897bdad5d8f680cfe1d24d1960242d11e1c6a704fb7bbb1dbb
|
3 |
+
size 47248523
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
from pathlib import Path
|
4 |
+
import numpy as np
|
5 |
+
import random
|
6 |
+
from einops import rearrange
|
7 |
+
from augumentation import Rotate
|
8 |
+
from torch.utils.data import random_split
|
9 |
+
|
10 |
+
# Class read npy and pickle file to make data and label in couple
|
11 |
+
class FeederINCLUDE(Dataset):
|
12 |
+
""" Feeder for skeleton-based action recognition
|
13 |
+
Arguments:
|
14 |
+
data_path: the path to '.npy' data, the shape of data should be (N, C, T, V, M)
|
15 |
+
label_path: the path to label
|
16 |
+
window_size: The length of the output sequence
|
17 |
+
"""
|
18 |
+
def __init__(self, data_path: Path, label_path: Path, transform = None):
|
19 |
+
super(FeederINCLUDE, self).__init__
|
20 |
+
self.data_path = data_path
|
21 |
+
self.label_path = label_path
|
22 |
+
self.transform = transform
|
23 |
+
self.load_data()
|
24 |
+
|
25 |
+
def load_data(self):
|
26 |
+
# data: N C V T M
|
27 |
+
# Load label with numpy
|
28 |
+
self.label = np.load(self.label_path)
|
29 |
+
# load data
|
30 |
+
self.data = np.load(self.data_path)
|
31 |
+
self.N, self.C, self.T, self.V, self.M = self.data.shape
|
32 |
+
|
33 |
+
def __getitem__(self, index):
|
34 |
+
"""
|
35 |
+
Input shape (N, C, V, T, M)
|
36 |
+
N : batch size
|
37 |
+
C : numbers of features
|
38 |
+
V : numbers of joints (as nodes)
|
39 |
+
T : numbers of frames
|
40 |
+
M : numbers of people (should delete)
|
41 |
+
|
42 |
+
Output shape (C, V, T, M)
|
43 |
+
C : numbers of features
|
44 |
+
V : numbers of joints (as nodes)
|
45 |
+
T : numbers of frames
|
46 |
+
label : label of videos
|
47 |
+
"""
|
48 |
+
data_numpy = torch.tensor(self.data[index]).float()
|
49 |
+
# Delete one dimension
|
50 |
+
# data_numpy = data_numpy[:, :, :2]
|
51 |
+
# data_numpy = rearrange(data_numpy, ' t v c 1 -> c t v 1')
|
52 |
+
label = self.label[index]
|
53 |
+
p = random.random()
|
54 |
+
if self.transform and p > 0.5:
|
55 |
+
data_numpy, label = self.transform(data_numpy, label)
|
56 |
+
return data_numpy, label
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return len(self.label)
|
60 |
+
|
61 |
+
if __name__ == '__main__':
|
62 |
+
file, label = np.load("wsl100_train_data_preprocess.npy"), np.load("wsl100_train_label_preprocess.npy")
|
63 |
+
print(file.shape, label.shape)
|
64 |
+
data = FeederINCLUDE(data_path=f"wsl100_train_data_preprocess.npy", label_path=f"wsl100_train_data_preprocess.npy",
|
65 |
+
transform=None)
|
66 |
+
# test_dataset = FeederINCLUDE(data_path=f"data/vsl100_test_data_preprocess.npy", label_path=f"data/vsl100_test_label_preprocess.npy")
|
67 |
+
# valid_dataset = FeederINCLUDE(data_path=f"data/vsl100_valid_data_preprocess.npy", label_path=f"data/vsl100_valid_label_preprocess.npy")
|
68 |
+
# data = FeederINCLUDE(data_path=f"data/vsl100_test_data_preprocess.npy", label_path=f"data/vsl100_test_label_preprocess.npy",
|
69 |
+
# transform=None)
|
70 |
+
print(data.N, data.C, data.T, data.V, data.M)
|
71 |
+
print(data.data.shape)
|
72 |
+
print(data.__len__())
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from enum import Enum
|
3 |
+
|
4 |
+
class BodyIdentifier(Enum):
|
5 |
+
INDEX_FINGER_DIP_right = 0
|
6 |
+
INDEX_FINGER_MCP_right = 1
|
7 |
+
INDEX_FINGER_PIP_right = 2
|
8 |
+
INDEX_FINGER_TIP_right = 3
|
9 |
+
MIDDLE_FINGER_DIP_right = 4
|
10 |
+
MIDDLE_FINGER_MCP_right = 5
|
11 |
+
MIDDLE_FINGER_PIP_right = 6
|
12 |
+
MIDDLE_FINGER_TIP_right = 7
|
13 |
+
PINKY_DIP_right = 8
|
14 |
+
PINKY_MCP_right = 9
|
15 |
+
PINKY_PIP_right = 10
|
16 |
+
PINKY_TIP_right = 11
|
17 |
+
RING_FINGER_DIP_right = 12
|
18 |
+
RING_FINGER_MCP_right = 13
|
19 |
+
RING_FINGER_PIP_right = 14
|
20 |
+
RING_FINGER_TIP_right = 15
|
21 |
+
THUMB_CMC_right = 16
|
22 |
+
THUMB_IP_right = 17
|
23 |
+
THUMB_MCP_right = 18
|
24 |
+
THUMB_TIP_right = 19
|
25 |
+
WRIST_right = 20
|
26 |
+
INDEX_FINGER_DIP_left = 21
|
27 |
+
INDEX_FINGER_MCP_left = 22
|
28 |
+
INDEX_FINGER_PIP_left = 23
|
29 |
+
INDEX_FINGER_TIP_left = 24
|
30 |
+
MIDDLE_FINGER_DIP_left = 25
|
31 |
+
MIDDLE_FINGER_MCP_left = 26
|
32 |
+
MIDDLE_FINGER_PIP_left = 27
|
33 |
+
MIDDLE_FINGER_TIP_left = 28
|
34 |
+
PINKY_DIP_left = 29
|
35 |
+
PINKY_MCP_left = 30
|
36 |
+
PINKY_PIP_left = 31
|
37 |
+
PINKY_TIP_left = 32
|
38 |
+
RING_FINGER_DIP_left = 33
|
39 |
+
RING_FINGER_MCP_left = 34
|
40 |
+
RING_FINGER_PIP_left = 35
|
41 |
+
RING_FINGER_TIP_left = 36
|
42 |
+
THUMB_CMC_left = 37
|
43 |
+
THUMB_IP_left = 38
|
44 |
+
THUMB_MCP_left = 39
|
45 |
+
THUMB_TIP_left = 40
|
46 |
+
WRIST_left = 41
|
47 |
+
RIGHT_SHOULDER = 42
|
48 |
+
LEFT_SHOULDER = 43
|
49 |
+
LEFT_ELBOW = 44
|
50 |
+
RIGHT_ELBOW = 45
|
51 |
+
|
52 |
+
class Graph():
|
53 |
+
""" The Graph to model the skeletons extracted by the openpose
|
54 |
+
|
55 |
+
Args:
|
56 |
+
strategy (string): must be one of the follow candidates
|
57 |
+
- uniform: Uniform Labeling
|
58 |
+
- distance: Distance Partitioning
|
59 |
+
- spatial: Spatial Configuration
|
60 |
+
For more information, please refer to the section 'Partition Strategies'
|
61 |
+
in our paper (https://arxiv.org/abs/1801.07455).
|
62 |
+
|
63 |
+
layout (string): must be one of the follow candidates
|
64 |
+
- openpose: Is consists of 18 joints. For more information, please
|
65 |
+
refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose#output
|
66 |
+
- ntu-rgb+d: Is consists of 25 joints. For more information, please
|
67 |
+
refer to https://github.com/shahroudy/NTURGB-D
|
68 |
+
|
69 |
+
max_hop (int): the maximal distance between two connected nodes
|
70 |
+
dilation (int): controls the spacing between the kernel points
|
71 |
+
|
72 |
+
"""
|
73 |
+
|
74 |
+
def __init__(self,
|
75 |
+
layout='openpose',
|
76 |
+
strategy='uniform',
|
77 |
+
max_hop=1,
|
78 |
+
dilation=1):
|
79 |
+
self.max_hop = max_hop
|
80 |
+
self.dilation = dilation
|
81 |
+
|
82 |
+
self.get_edge(layout)
|
83 |
+
self.hop_dis = get_hop_distance(
|
84 |
+
self.num_node, self.edge, max_hop=max_hop)
|
85 |
+
self.get_adjacency(strategy)
|
86 |
+
|
87 |
+
def __str__(self):
|
88 |
+
return self.A
|
89 |
+
|
90 |
+
def get_edge(self, layout):
|
91 |
+
if layout == 'openpose':
|
92 |
+
self.num_node = 18
|
93 |
+
self_link = [(i, i) for i in range(self.num_node)]
|
94 |
+
neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12,
|
95 |
+
11),
|
96 |
+
(10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1),
|
97 |
+
(0, 1), (15, 0), (14, 0), (17, 15), (16, 14)]
|
98 |
+
self.edge = self_link + neighbor_link
|
99 |
+
self.center = 1
|
100 |
+
elif layout == 'ntu-rgb+d':
|
101 |
+
self.num_node = 25
|
102 |
+
self_link = [(i, i) for i in range(self.num_node)]
|
103 |
+
neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21),
|
104 |
+
(6, 5), (7, 6), (8, 7), (9, 21), (10, 9),
|
105 |
+
(11, 10), (12, 11), (13, 1), (14, 13), (15, 14),
|
106 |
+
(16, 15), (17, 1), (18, 17), (19, 18), (20, 19),
|
107 |
+
(22, 23), (23, 8), (24, 25), (25, 12)]
|
108 |
+
neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
|
109 |
+
self.edge = self_link + neighbor_link
|
110 |
+
self.center = 21 - 1
|
111 |
+
elif layout == 'ntu_edge':
|
112 |
+
self.num_node = 24
|
113 |
+
self_link = [(i, i) for i in range(self.num_node)]
|
114 |
+
neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6),
|
115 |
+
(8, 7), (9, 2), (10, 9), (11, 10), (12, 11),
|
116 |
+
(13, 1), (14, 13), (15, 14), (16, 15), (17, 1),
|
117 |
+
(18, 17), (19, 18), (20, 19), (21, 22), (22, 8),
|
118 |
+
(23, 24), (24, 12)]
|
119 |
+
neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
|
120 |
+
self.edge = self_link + neighbor_link
|
121 |
+
self.center = 2
|
122 |
+
elif layout == 'mediapipe':
|
123 |
+
self.num_node = 25
|
124 |
+
self_link = [(i, i) for i in range(self.num_node)]
|
125 |
+
neighbor_1base = [(20, 18), (18, 16), (20, 16), (16, 22), (16, 14), (14, 12),
|
126 |
+
(19, 17), (17, 15), (19, 15), (15, 21), (15, 13), (13, 11),
|
127 |
+
(12, 11), (12, 24), (24, 23), (23, 11),
|
128 |
+
(10, 9),
|
129 |
+
(0, 4), (4, 5), (5, 6), (6, 8),
|
130 |
+
(0, 1), (1, 2), (2, 3), (3, 7)]
|
131 |
+
neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
|
132 |
+
self.edge = self_link + neighbor_link
|
133 |
+
self.center = 10
|
134 |
+
|
135 |
+
elif layout == "mediapipe_two_hand":
|
136 |
+
self.num_node = 46
|
137 |
+
self_link = [(i, i) for i in range(self.num_node)]
|
138 |
+
neighbor_1base = [(BodyIdentifier.WRIST_left.value, BodyIdentifier.THUMB_CMC_left.value),
|
139 |
+
(BodyIdentifier.THUMB_CMC_left.value, BodyIdentifier.THUMB_MCP_left.value),
|
140 |
+
(BodyIdentifier.THUMB_MCP_left.value, BodyIdentifier.THUMB_IP_left.value),
|
141 |
+
(BodyIdentifier.THUMB_IP_left.value, BodyIdentifier.THUMB_TIP_left.value),
|
142 |
+
|
143 |
+
(BodyIdentifier.WRIST_left.value, BodyIdentifier.INDEX_FINGER_MCP_left.value),
|
144 |
+
(BodyIdentifier.INDEX_FINGER_MCP_left.value, BodyIdentifier.INDEX_FINGER_PIP_left.value),
|
145 |
+
(BodyIdentifier.INDEX_FINGER_PIP_left.value, BodyIdentifier.INDEX_FINGER_DIP_left.value),
|
146 |
+
(BodyIdentifier.INDEX_FINGER_DIP_left.value, BodyIdentifier.INDEX_FINGER_TIP_left.value),
|
147 |
+
|
148 |
+
(BodyIdentifier.INDEX_FINGER_MCP_left.value, BodyIdentifier.MIDDLE_FINGER_MCP_left.value),
|
149 |
+
(BodyIdentifier.MIDDLE_FINGER_MCP_left.value, BodyIdentifier.MIDDLE_FINGER_PIP_left.value),
|
150 |
+
(BodyIdentifier.MIDDLE_FINGER_PIP_left.value, BodyIdentifier.MIDDLE_FINGER_DIP_left.value),
|
151 |
+
(BodyIdentifier.MIDDLE_FINGER_DIP_left.value, BodyIdentifier.MIDDLE_FINGER_TIP_left.value),
|
152 |
+
|
153 |
+
(BodyIdentifier.MIDDLE_FINGER_MCP_left.value, BodyIdentifier.RING_FINGER_MCP_left.value),
|
154 |
+
(BodyIdentifier.RING_FINGER_MCP_left.value, BodyIdentifier.RING_FINGER_PIP_left.value),
|
155 |
+
(BodyIdentifier.RING_FINGER_PIP_left.value, BodyIdentifier.RING_FINGER_DIP_left.value),
|
156 |
+
(BodyIdentifier.RING_FINGER_DIP_left.value, BodyIdentifier.RING_FINGER_TIP_left.value),
|
157 |
+
|
158 |
+
(BodyIdentifier.WRIST_left.value, BodyIdentifier.PINKY_MCP_left.value),
|
159 |
+
(BodyIdentifier.PINKY_MCP_left.value, BodyIdentifier.PINKY_PIP_left.value),
|
160 |
+
(BodyIdentifier.PINKY_PIP_left.value, BodyIdentifier.PINKY_DIP_left.value),
|
161 |
+
(BodyIdentifier.PINKY_DIP_left.value, BodyIdentifier.PINKY_TIP_left.value),
|
162 |
+
|
163 |
+
# RIGHT HAND
|
164 |
+
(BodyIdentifier.WRIST_right.value, BodyIdentifier.THUMB_CMC_right.value),
|
165 |
+
(BodyIdentifier.THUMB_CMC_right.value, BodyIdentifier.THUMB_MCP_right.value),
|
166 |
+
(BodyIdentifier.THUMB_MCP_right.value, BodyIdentifier.THUMB_IP_right.value),
|
167 |
+
(BodyIdentifier.THUMB_IP_right.value, BodyIdentifier.THUMB_TIP_right.value),
|
168 |
+
|
169 |
+
(BodyIdentifier.WRIST_right.value, BodyIdentifier.INDEX_FINGER_MCP_right.value),
|
170 |
+
(BodyIdentifier.INDEX_FINGER_MCP_right.value, BodyIdentifier.INDEX_FINGER_PIP_right.value),
|
171 |
+
(BodyIdentifier.INDEX_FINGER_PIP_right.value, BodyIdentifier.INDEX_FINGER_DIP_right.value),
|
172 |
+
(BodyIdentifier.INDEX_FINGER_DIP_right.value, BodyIdentifier.INDEX_FINGER_TIP_right.value),
|
173 |
+
|
174 |
+
(BodyIdentifier.INDEX_FINGER_MCP_right.value, BodyIdentifier.MIDDLE_FINGER_MCP_right.value),
|
175 |
+
(BodyIdentifier.MIDDLE_FINGER_MCP_right.value, BodyIdentifier.MIDDLE_FINGER_PIP_right.value),
|
176 |
+
(BodyIdentifier.MIDDLE_FINGER_PIP_right.value, BodyIdentifier.MIDDLE_FINGER_DIP_right.value),
|
177 |
+
(BodyIdentifier.MIDDLE_FINGER_DIP_right.value, BodyIdentifier.MIDDLE_FINGER_TIP_right.value),
|
178 |
+
|
179 |
+
(BodyIdentifier.MIDDLE_FINGER_MCP_right.value, BodyIdentifier.RING_FINGER_MCP_right.value),
|
180 |
+
(BodyIdentifier.RING_FINGER_MCP_right.value, BodyIdentifier.RING_FINGER_PIP_right.value),
|
181 |
+
(BodyIdentifier.RING_FINGER_PIP_right.value, BodyIdentifier.RING_FINGER_DIP_right.value),
|
182 |
+
(BodyIdentifier.RING_FINGER_DIP_right.value, BodyIdentifier.RING_FINGER_TIP_right.value),
|
183 |
+
|
184 |
+
(BodyIdentifier.WRIST_right.value, BodyIdentifier.PINKY_MCP_right.value),
|
185 |
+
(BodyIdentifier.PINKY_MCP_right.value, BodyIdentifier.PINKY_PIP_right.value),
|
186 |
+
(BodyIdentifier.PINKY_PIP_right.value, BodyIdentifier.PINKY_DIP_right.value),
|
187 |
+
(BodyIdentifier.PINKY_DIP_right.value, BodyIdentifier.PINKY_TIP_right.value),
|
188 |
+
|
189 |
+
# 2 HAND + SHOULDER + ELBOW
|
190 |
+
(BodyIdentifier.RIGHT_SHOULDER.value, BodyIdentifier.RIGHT_ELBOW.value),
|
191 |
+
(BodyIdentifier.RIGHT_ELBOW.value, BodyIdentifier.WRIST_right.value),
|
192 |
+
|
193 |
+
(BodyIdentifier.RIGHT_SHOULDER.value, BodyIdentifier.LEFT_SHOULDER.value),
|
194 |
+
|
195 |
+
(BodyIdentifier.LEFT_SHOULDER.value, BodyIdentifier.LEFT_ELBOW.value),
|
196 |
+
(BodyIdentifier.LEFT_ELBOW.value, BodyIdentifier.WRIST_left.value)]
|
197 |
+
|
198 |
+
neighbor_link = [(i, j) for (i, j) in neighbor_1base]
|
199 |
+
self.edge = self_link + neighbor_link
|
200 |
+
self.center = BodyIdentifier.RIGHT_SHOULDER.value
|
201 |
+
# elif layout=='customer settings'
|
202 |
+
# pass
|
203 |
+
else:
|
204 |
+
raise ValueError("Do Not Exist This Layout.")
|
205 |
+
|
206 |
+
def get_adjacency(self, strategy):
|
207 |
+
valid_hop = range(0, self.max_hop + 1, self.dilation)
|
208 |
+
adjacency = np.zeros((self.num_node, self.num_node))
|
209 |
+
for hop in valid_hop:
|
210 |
+
adjacency[self.hop_dis == hop] = 1
|
211 |
+
normalize_adjacency = normalize_digraph(adjacency)
|
212 |
+
|
213 |
+
if strategy == 'uniform':
|
214 |
+
A = np.zeros((1, self.num_node, self.num_node))
|
215 |
+
A[0] = normalize_adjacency
|
216 |
+
self.A = A
|
217 |
+
elif strategy == 'distance':
|
218 |
+
A = np.zeros((len(valid_hop), self.num_node, self.num_node))
|
219 |
+
for i, hop in enumerate(valid_hop):
|
220 |
+
A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis ==
|
221 |
+
hop]
|
222 |
+
self.A = A
|
223 |
+
elif strategy == 'spatial':
|
224 |
+
A = []
|
225 |
+
for hop in valid_hop:
|
226 |
+
a_root = np.zeros((self.num_node, self.num_node))
|
227 |
+
a_close = np.zeros((self.num_node, self.num_node))
|
228 |
+
a_further = np.zeros((self.num_node, self.num_node))
|
229 |
+
for i in range(self.num_node):
|
230 |
+
for j in range(self.num_node):
|
231 |
+
if self.hop_dis[j, i] == hop:
|
232 |
+
if self.hop_dis[j, self.center] == self.hop_dis[
|
233 |
+
i, self.center]:
|
234 |
+
a_root[j, i] = normalize_adjacency[j, i]
|
235 |
+
elif self.hop_dis[j, self.
|
236 |
+
center] > self.hop_dis[i, self.
|
237 |
+
center]:
|
238 |
+
a_close[j, i] = normalize_adjacency[j, i]
|
239 |
+
else:
|
240 |
+
a_further[j, i] = normalize_adjacency[j, i]
|
241 |
+
if hop == 0:
|
242 |
+
A.append(a_root)
|
243 |
+
else:
|
244 |
+
A.append(a_root + a_close)
|
245 |
+
A.append(a_further)
|
246 |
+
A = np.stack(A)
|
247 |
+
self.A = A
|
248 |
+
else:
|
249 |
+
raise ValueError("Do Not Exist This Strategy")
|
250 |
+
|
251 |
+
|
252 |
+
def get_hop_distance(num_node, edge, max_hop=1):
|
253 |
+
A = np.zeros((num_node, num_node))
|
254 |
+
print(edge)
|
255 |
+
for i, j in edge:
|
256 |
+
A[j, i] = 1
|
257 |
+
A[i, j] = 1
|
258 |
+
|
259 |
+
# compute hop steps
|
260 |
+
hop_dis = np.zeros((num_node, num_node)) + np.inf
|
261 |
+
transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
|
262 |
+
arrive_mat = (np.stack(transfer_mat) > 0)
|
263 |
+
for d in range(max_hop, -1, -1):
|
264 |
+
hop_dis[arrive_mat[d]] = d
|
265 |
+
return hop_dis
|
266 |
+
|
267 |
+
|
268 |
+
def normalize_digraph(A):
|
269 |
+
Dl = np.sum(A, 0)
|
270 |
+
num_node = A.shape[0]
|
271 |
+
Dn = np.zeros((num_node, num_node))
|
272 |
+
for i in range(num_node):
|
273 |
+
if Dl[i] > 0:
|
274 |
+
Dn[i, i] = Dl[i]**(-1)
|
275 |
+
AD = np.dot(A, Dn)
|
276 |
+
return AD
|
277 |
+
|
278 |
+
|
279 |
+
def normalize_undigraph(A):
|
280 |
+
Dl = np.sum(A, 0)
|
281 |
+
num_node = A.shape[0]
|
282 |
+
Dn = np.zeros((num_node, num_node))
|
283 |
+
for i in range(num_node):
|
284 |
+
if Dl[i] > 0:
|
285 |
+
Dn[i, i] = Dl[i]**(-0.5)
|
286 |
+
DAD = np.dot(np.dot(Dn, A), Dn)
|
287 |
+
return DAD
|