fossbk committed (verified)
Commit a6976f4 · 1 Parent(s): 0889918

Upload 8 files


# AAGCN-VSL-Detection
# Introduction
In this project, we modify a model to better classify Vietnamese Sign Language (VSL). We use MediaPipe to extract hand and pose keypoints from each video, then apply a bilinear interpolation preprocessing step to reconstruct missing keypoints. The model is built on Adaptive Graph Convolutional Networks (AAGCN) combined with attention mechanisms: spatial, temporal, and channel attention. A minimal sketch of the keypoint-reconstruction step is shown below.
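
The reconstruction works frame by frame on each keypoint coordinate sequence, where a value of zero marks a frame in which MediaPipe did not detect the keypoint. Below is a minimal, self-contained sketch of that gap-filling logic, mirroring the `curl_skeleton` routine in Total (VSL).ipynb; the function name `fill_missing` and the toy input are illustrative only.

```python
# Minimal sketch of the missing-keypoint reconstruction (see curl_skeleton in the notebook).
# Zeros mark frames where MediaPipe missed the keypoint; gaps are filled from the
# nearest detected neighbours.
def fill_missing(coords):
    if sum(coords) == 0:                 # keypoint never detected: nothing to reconstruct
        return coords
    for i, value in enumerate(coords):
        if value != 0 or i == 0 or i == len(coords) - 1:
            continue                     # detected, or a boundary frame
        if coords[i + 1] != 0:           # single-frame gap: average the two neighbours
            coords[i] = (coords[i - 1] + coords[i + 1]) / 2
        elif sum(coords[i:]) != 0:       # longer gap: weight by distance to the next detection
            j = next(k for k, v in enumerate(coords[i + 1:]) if v != 0)
            coords[i] = ((1 + j) * coords[i - 1] + coords[i + 1 + j]) / (2 + j)
    return coords

print(fill_missing([0.5, 0.0, 0.0, 0.8]))  # gaps filled to ~[0.5, 0.6, 0.7, 0.8]
```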

The model is evaluated on our self-collected VSL dataset, recorded at a school for hearing-impaired children in Hanoi, Vietnam. The dataset consists of 5,572 videos performed by 28 actors and covering 199 classes, representing the most frequently used words in spoken Vietnamese. Before training on this dataset, we first pre-train the model on the Ankara University Turkish Sign Language dataset (AUTSL) to obtain initial weights and a suitable learning rate for further training; the transfer step is sketched below.
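
The transfer from AUTSL keeps the backbone weights and discards the final classification layer, whose shape depends on the number of classes (226 for AUTSL versus 199 for VSL). A condensed sketch of that warm-start step, as done in the notebook, is shown below; the checkpoint filename is the example used there.

```python
# Condensed sketch of the AUTSL warm start (see the "train based on AUTSL" cell in the notebook):
# load the AUTSL checkpoint and drop the 'fc.' head before fine-tuning on VSL.
import torch
from aagcn import Model

model = Model(num_class=199, num_point=46, num_person=1, in_channels=2,
              graph_args={"layout": "mediapipe_two_hand", "strategy": "spatial"})

checkpoint = torch.load("epoch=55-valid_loss=0.41-valid_accuracy=0.85-autsl-aagcn.ckpt",
                        map_location="cpu")
state_dict = {k: v for k, v in checkpoint["state_dict"].items()
              if not k.startswith("fc.")}        # keep the backbone, drop the classifier head
model.load_state_dict(state_dict, strict=False)  # the new fc layer stays randomly initialised
```

The notebook additionally wraps the `torch.load` call in `pl_legacy_patch()` so that older Lightning checkpoints load cleanly.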

# Setup
Download all the files and run the Jupyter notebook Total (VSL).ipynb. It performs almost the entire pipeline, from extracting the keypoints and interpolating missing values to training the model. The other files provide the supporting code for the model.

Change the path of the input dataset folder; the code will read every video in that folder (see the sketch below for the expected filename convention).
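
The notebook derives the actor and the class label from each filename, so the videos are expected to be named `<actor>_<label>.<extension>`. The sketch below shows how the folder path is set and how each filename is parsed; the example path is a placeholder.

```python
# Point this at your dataset folder (set in the video-listing cell of the notebook).
import os
import re

folder_path = r"/path/to/VSL_dataset"   # placeholder: replace with your own folder

for filename in os.listdir(folder_path):
    if filename.lower().endswith((".mp4", ".mkv", ".avi", ".mov", ".flv", ".wmv")):
        actor = filename.split("_")[0]              # text before the first underscore
        match = re.search(r"_(\d+)\.", filename)    # digits just before the extension
        label = int(match.group(1)) if match else None
        print(filename, actor, label)
```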

In the notebook (Total (VSL).ipynb), there are two training options:
1. Train from scratch.
2. Continue training from the model pre-trained on the AUTSL dataset.

Uncomment the part of the code corresponding to the option you want to use.

# Evaluation
We use k-fold cross-validation with k = 10. The ten trained models (one per fold) are stored in the "checkpoints" folder, and the final output prints the best validation accuracy across folds. A sketch of loading the best fold's checkpoint for inference follows below.
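
Checkpoints are saved one per fold, with names following the ModelCheckpoint pattern in the notebook (e.g. `vsl{num_labels}-aagcn-fold={k}.ckpt`, or prefixed with `autsl_` for the AUTSL-initialised run). Below is a minimal sketch of reloading the best fold's checkpoint for inference; the fold number and class count are placeholders you would take from the training output.

```python
# Sketch: reload the best fold's checkpoint for inference.
# Hyperparameters are restored automatically because the model calls save_hyperparameters().
import torch
from aagcn import Model

num_labels, best_fold = 199, 1          # placeholders: use the values reported by the fold loop
ckpt_path = f"checkpoints/vsl{num_labels}-aagcn-fold={best_fold}.ckpt"

model = Model.load_from_checkpoint(ckpt_path, map_location="cpu")
model.eval()

x = torch.randn(1, 2, 80, 46, 1)        # (N, C, T, V, M): x/y channels, 80 frames, 46 keypoints
with torch.no_grad():
    pred = model(x).argmax(dim=1)
print(pred)
```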

# Acknowledgement
This project is built by the iBME lab at Hanoi University of Science and Technology (HUST), Vietnam, and is funded by HUST under project number T2023-PC-028.

1_1000_label.csv ADDED
The diff for this file is too large to render. See raw diff
 
Total (VSL).ipynb ADDED
@@ -0,0 +1,883 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "pre-train with AUTSL dataset"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "This is the code to train the model with the AUTSL dataset"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 1,
20
+ "metadata": {},
21
+ "outputs": [
22
+ {
23
+ "data": {
24
+ "text/plain": [
25
+ "'\\nfrom torch.utils.data import DataLoader\\nimport torch\\nfrom torchinfo import summary\\nfrom feeder import FeederINCLUDE\\nfrom aagcn import Model\\nimport pytorch_lightning as pl\\nfrom pytorch_lightning.loggers import WandbLogger # Importing here\\nfrom pytorch_lightning.callbacks import ModelCheckpoint\\nimport wandb\\nfrom augumentation import Rotate, Compose\\nfrom torch.utils.data import random_split\\n\\n\\nif __name__ == \\'__main__\\':\\n\\n # Hyper parameter tuning : batch_size, learning_rate, weight_decay\\n config = {\\'batch_size\\': 150, \\'learning_rate\\': 0.0137296, \\'weight_decay\\': 0.000150403}\\n \\n # Load device\\n device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\\n\\n # Initialize wandb\\n wandb.finish()\\n wandb.init(project=\"GCN_VSL\", config=config) \\n wandb_config = wandb.config # Access the config parameters\\n\\n try:\\n # Your training or evaluation code here\\n print(\"WandB initialized successfully.\")\\n\\n finally:\\n # Fnish the WandB run\\n wandb.finish()\\n\\n # Load model\\n model = Model(num_class=226, num_point=46, num_person=1, in_channels=2,\\n graph_args={\"layout\": \"mediapipe_two_hand\", \"strategy\": \"spatial\"},\\n learning_rate=wandb_config.learning_rate, weight_decay=wandb_config.weight_decay)\\n\\n # Callback PL\\n callbacks = [\\n ModelCheckpoint(\\n dirpath=\"checkpoints\",\\n monitor=\"valid_loss\",\\n mode=\"min\",\\n every_n_epochs=2,\\n filename=\\'{epoch}-{valid_loss:.2f}-{valid_accuracy:.2f}-autsl-aagcn\\'\\n ),\\n ]\\n\\n # Augmentation\\n transforms = Compose([\\n Rotate(15, 80, 25, (0.5, 0.5))\\n ])\\n\\n %cd /home/ibmelab/Documents/GG/VSLRecognition/AUTSL/AAGCN\\n # Dataset class\\n train_dataset = FeederINCLUDE(data_path=f\"autsl_train_data_preprocess.npy\", label_path=f\"train_label_preprocess.npy\",\\n transform=transforms)\\n test_dataset = FeederINCLUDE(data_path=f\"autsl_test_data_preprocess.npy\", label_path=f\"test_label_preprocess.npy\")\\n valid_dataset = FeederINCLUDE(data_path=f\"autsl_valid_data_preprocess.npy\", label_path=f\"valid_label_preprocess.npy\")\\n\\n # DataLoader\\n train_dataloader = DataLoader(train_dataset, batch_size=wandb_config.batch_size, shuffle=True)\\n test_dataloader = DataLoader(test_dataset, batch_size=wandb_config.batch_size, shuffle=False)\\n val_dataloader = DataLoader(valid_dataset, batch_size=wandb_config.batch_size, shuffle=False)\\n\\n # Wandb Logger\\n wandb_logger = WandbLogger(log_model=\\'all\\')\\n\\n %cd /media/ibmelab/ibme21/Test\\n # Trainer PL\\n trainer = pl.Trainer(max_epochs=120, accelerator=\"auto\", check_val_every_n_epoch=1,\\n devices=1, callbacks=callbacks, logger=wandb_logger) # Added logger here\\n\\n trainer.fit(model, train_dataloader, val_dataloader)\\n\\n # Optional: Uncomment this when you want to test\\n # trainer.test(model, test_dataloader, ckpt_path=\"checkpoints/your_checkpoint.ckpt\", verbose=True)\\n'"
26
+ ]
27
+ },
28
+ "execution_count": 1,
29
+ "metadata": {},
30
+ "output_type": "execute_result"
31
+ }
32
+ ],
33
+ "source": [
34
+ "'''\n",
35
+ "from torch.utils.data import DataLoader\n",
36
+ "import torch\n",
37
+ "from torchinfo import summary\n",
38
+ "from feeder import FeederINCLUDE\n",
39
+ "from aagcn import Model\n",
40
+ "import pytorch_lightning as pl\n",
41
+ "from pytorch_lightning.loggers import WandbLogger # Importing here\n",
42
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
43
+ "import wandb\n",
44
+ "from augumentation import Rotate, Compose\n",
45
+ "from torch.utils.data import random_split\n",
46
+ "\n",
47
+ "\n",
48
+ "if __name__ == '__main__':\n",
49
+ "\n",
50
+ " # Hyper parameter tuning : batch_size, learning_rate, weight_decay\n",
51
+ " config = {'batch_size': 150, 'learning_rate': 0.0137296, 'weight_decay': 0.000150403}\n",
52
+ " \n",
53
+ " # Load device\n",
54
+ " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
55
+ "\n",
56
+ " # Initialize wandb\n",
57
+ " wandb.finish()\n",
58
+ " wandb.init(project=\"GCN_VSL\", config=config) \n",
59
+ " wandb_config = wandb.config # Access the config parameters\n",
60
+ "\n",
61
+ " try:\n",
62
+ " # Your training or evaluation code here\n",
63
+ " print(\"WandB initialized successfully.\")\n",
64
+ "\n",
65
+ " finally:\n",
66
+ " # Fnish the WandB run\n",
67
+ " wandb.finish()\n",
68
+ "\n",
69
+ " # Load model\n",
70
+ " model = Model(num_class=226, num_point=46, num_person=1, in_channels=2,\n",
71
+ " graph_args={\"layout\": \"mediapipe_two_hand\", \"strategy\": \"spatial\"},\n",
72
+ " learning_rate=wandb_config.learning_rate, weight_decay=wandb_config.weight_decay)\n",
73
+ "\n",
74
+ " # Callback PL\n",
75
+ " callbacks = [\n",
76
+ " ModelCheckpoint(\n",
77
+ " dirpath=\"checkpoints\",\n",
78
+ " monitor=\"valid_loss\",\n",
79
+ " mode=\"min\",\n",
80
+ " every_n_epochs=2,\n",
81
+ " filename='{epoch}-{valid_loss:.2f}-{valid_accuracy:.2f}-autsl-aagcn'\n",
82
+ " ),\n",
83
+ " ]\n",
84
+ "\n",
85
+ " # Augmentation\n",
86
+ " transforms = Compose([\n",
87
+ " Rotate(15, 80, 25, (0.5, 0.5))\n",
88
+ " ])\n",
89
+ "\n",
90
+ " %cd /home/ibmelab/Documents/GG/VSLRecognition/AUTSL/AAGCN\n",
91
+ " # Dataset class\n",
92
+ " train_dataset = FeederINCLUDE(data_path=f\"autsl_train_data_preprocess.npy\", label_path=f\"train_label_preprocess.npy\",\n",
93
+ " transform=transforms)\n",
94
+ " test_dataset = FeederINCLUDE(data_path=f\"autsl_test_data_preprocess.npy\", label_path=f\"test_label_preprocess.npy\")\n",
95
+ " valid_dataset = FeederINCLUDE(data_path=f\"autsl_valid_data_preprocess.npy\", label_path=f\"valid_label_preprocess.npy\")\n",
96
+ "\n",
97
+ " # DataLoader\n",
98
+ " train_dataloader = DataLoader(train_dataset, batch_size=wandb_config.batch_size, shuffle=True)\n",
99
+ " test_dataloader = DataLoader(test_dataset, batch_size=wandb_config.batch_size, shuffle=False)\n",
100
+ " val_dataloader = DataLoader(valid_dataset, batch_size=wandb_config.batch_size, shuffle=False)\n",
101
+ "\n",
102
+ " # Wandb Logger\n",
103
+ " wandb_logger = WandbLogger(log_model='all')\n",
104
+ "\n",
105
+ " %cd /media/ibmelab/ibme21/Test\n",
106
+ " # Trainer PL\n",
107
+ " trainer = pl.Trainer(max_epochs=120, accelerator=\"auto\", check_val_every_n_epoch=1,\n",
108
+ " devices=1, callbacks=callbacks, logger=wandb_logger) # Added logger here\n",
109
+ "\n",
110
+ " trainer.fit(model, train_dataloader, val_dataloader)\n",
111
+ "\n",
112
+ " # Optional: Uncomment this when you want to test\n",
113
+ " # trainer.test(model, test_dataloader, ckpt_path=\"checkpoints/your_checkpoint.ckpt\", verbose=True)\n",
114
+ "'''\n",
115
+ " \n"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": 2,
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "import pandas as pd\n",
125
+ "import mediapipe as mp\n",
126
+ "import cv2\n",
127
+ "from collections import defaultdict\n",
128
+ "from joblib import Parallel, delayed\n",
129
+ "from tqdm import tqdm\n",
130
+ "import ast\n",
131
+ "import os\n",
132
+ "import csv\n",
133
+ "import re\n",
134
+ "from sklearn.model_selection import KFold\n",
135
+ "import numpy as np"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "markdown",
140
+ "metadata": {},
141
+ "source": [
142
+ "Load the videos to videos_list.csv (columns: file (path), label, gloss, video name, actor)"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "execution_count": null,
148
+ "metadata": {},
149
+ "outputs": [
150
+ {
151
+ "name": "stdout",
152
+ "output_type": "stream",
153
+ "text": [
154
+ "Video names have been written to videos_list.csv\n",
155
+ "Minimum label: 20\n",
156
+ "Labels have been updated and saved.\n"
157
+ ]
158
+ }
159
+ ],
160
+ "source": [
161
+ "folder_path = r'path_to_dataset_folder'\n",
162
+ "csv_file_path = 'videos_list.csv'\n",
163
+ "labels_file_path = '1_1000_label.csv'\n",
164
+ "final_file_path = 'temp_videos_list.csv'\n",
165
+ "\n",
166
+ "label_to_gloss = {}\n",
167
+ "with open(labels_file_path, mode='r', encoding='utf-8') as labels_file:\n",
168
+ " csv_reader = csv.DictReader(labels_file)\n",
169
+ " for row in csv_reader:\n",
170
+ " label = int(row['id_label_in_documents'])\n",
171
+ " gloss = row['name']\n",
172
+ " label_to_gloss[label] = gloss\n",
173
+ "\n",
174
+ "with open(csv_file_path, mode='w', newline='', encoding='utf-8') as csv_file:\n",
175
+ " csv_writer = csv.writer(csv_file)\n",
176
+ " csv_writer.writerow(['file', 'label', 'gloss', 'video_name', 'actor'])\n",
177
+ "\n",
178
+ " for filename in os.listdir(folder_path):\n",
179
+ " if filename.lower().endswith(('.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv')):\n",
180
+ " actor = filename.split('_')[0]\n",
181
+ " \n",
182
+ " match = re.search(r'_(\\d+)\\.', filename)\n",
183
+ " if match:\n",
184
+ " label = int(match.group(1))\n",
185
+ " gloss = label_to_gloss.get(label, 'Unknown')\n",
186
+ " else:\n",
187
+ " label = 'N/A'\n",
188
+ " gloss = 'Unknown'\n",
189
+ "\n",
190
+ " if label != 200:\n",
191
+ " full_filename = os.path.join(folder_path, filename)\n",
192
+ " csv_writer.writerow([full_filename, label, gloss, filename, actor]) \n",
193
+ "\n",
194
+ "print(f'Video names have been written to {csv_file_path}')\n",
195
+ "\n",
196
+ "# Find min label\n",
197
+ "with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csv_file:\n",
198
+ " csv_reader = csv.DictReader(csv_file)\n",
199
+ " labels = [int(row[\"label\"]) for row in csv_reader if row[\"label\"].isdigit()] \n",
200
+ " min_label = min(labels) if labels else None\n",
201
+ "\n",
202
+ "print(\"Minimum label:\", min_label)\n",
203
+ "\n",
204
+ "# Normalize labels\n",
205
+ "with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csv_file, \\\n",
206
+ " open(final_file_path, mode='w', newline='', encoding='utf-8') as final_file:\n",
207
+ " \n",
208
+ " csv_reader = csv.DictReader(csv_file)\n",
209
+ " fieldnames = csv_reader.fieldnames\n",
210
+ " \n",
211
+ " csv_writer = csv.DictWriter(final_file, fieldnames=fieldnames)\n",
212
+ " csv_writer.writeheader()\n",
213
+ " \n",
214
+ " for row in csv_reader:\n",
215
+ " if row['label'].isdigit(): # Check if label is a digit before converting\n",
216
+ " row['label'] = str(int(row['label']) - min_label) \n",
217
+ " csv_writer.writerow(row)\n",
218
+ "\n",
219
+ "# Replace the original file with the updated file\n",
220
+ "os.replace(final_file_path, csv_file_path)\n",
221
+ "\n",
222
+ "print(\"Labels have been updated and saved.\")\n"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "markdown",
227
+ "metadata": {},
228
+ "source": [
229
+ "Number of labels in the dataset"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": 4,
235
+ "metadata": {},
236
+ "outputs": [
237
+ {
238
+ "name": "stdout",
239
+ "output_type": "stream",
240
+ "text": [
241
+ "10\n",
242
+ "{20, 21, 22, 23, 24, 25, 26, 27, 28, 29}\n"
243
+ ]
244
+ }
245
+ ],
246
+ "source": [
247
+ "num_labels = len(set(labels))\n",
248
+ "print(num_labels)\n",
249
+ "print(set(labels))"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "markdown",
254
+ "metadata": {},
255
+ "source": [
256
+ "Extract keypoints"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "metadata": {},
263
+ "outputs": [
264
+ {
265
+ "name": "stderr",
266
+ "output_type": "stream",
267
+ "text": [
268
+ " \r"
269
+ ]
270
+ }
271
+ ],
272
+ "source": [
273
+ "import pandas as pd\n",
274
+ "import mediapipe as mp\n",
275
+ "import cv2\n",
276
+ "import os\n",
277
+ "from collections import defaultdict\n",
278
+ "from joblib import Parallel, delayed\n",
279
+ "from tqdm import tqdm\n",
280
+ "\n",
281
+ "mp_holistic = mp.solutions.holistic\n",
282
+ "mp_drawing = mp.solutions.drawing_utils\n",
283
+ "\n",
284
+ "hand_landmarks = ['INDEX_FINGER_DIP', 'INDEX_FINGER_MCP', 'INDEX_FINGER_PIP', 'INDEX_FINGER_TIP', \n",
285
+ " 'MIDDLE_FINGER_DIP', 'MIDDLE_FINGER_MCP', 'MIDDLE_FINGER_PIP', 'MIDDLE_FINGER_TIP', \n",
286
+ " 'PINKY_DIP', 'PINKY_MCP', 'PINKY_PIP', 'PINKY_TIP', 'RING_FINGER_DIP', 'RING_FINGER_MCP', \n",
287
+ " 'RING_FINGER_PIP', 'RING_FINGER_TIP', 'THUMB_CMC', 'THUMB_IP', 'THUMB_MCP', 'THUMB_TIP', 'WRIST']\n",
288
+ "pose_landmarks = ['LEFT_ANKLE', 'LEFT_EAR', 'LEFT_ELBOW', 'LEFT_EYE', 'LEFT_EYE_INNER', 'LEFT_EYE_OUTER', \n",
289
+ " 'LEFT_FOOT_INDEX', 'LEFT_HEEL', 'LEFT_HIP', 'LEFT_INDEX', 'LEFT_KNEE', 'LEFT_PINKY', \n",
290
+ " 'LEFT_SHOULDER', 'LEFT_THUMB', 'LEFT_WRIST', 'MOUTH_LEFT', 'MOUTH_RIGHT', 'NOSE', \n",
291
+ " 'RIGHT_ANKLE', 'RIGHT_EAR', 'RIGHT_ELBOW', 'RIGHT_EYE', 'RIGHT_EYE_INNER', 'RIGHT_EYE_OUTER', \n",
292
+ " 'RIGHT_FOOT_INDEX', 'RIGHT_HEEL', 'RIGHT_HIP', 'RIGHT_INDEX', 'RIGHT_KNEE', 'RIGHT_PINKY', \n",
293
+ " 'RIGHT_SHOULDER', 'RIGHT_THUMB', 'RIGHT_WRIST']\n",
294
+ "\n",
295
+ "def extract_keypoint(video_path, label, actor):\n",
296
+ " cap = cv2.VideoCapture(video_path)\n",
297
+ " \n",
298
+ " keypoint_dict = defaultdict(list)\n",
299
+ " count = 0\n",
300
+ "\n",
301
+ " with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:\n",
302
+ " while True:\n",
303
+ " ret, frame = cap.read()\n",
304
+ " if not ret:\n",
305
+ " break\n",
306
+ " \n",
307
+ " count += 1\n",
308
+ " image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
309
+ " results = holistic.process(image)\n",
310
+ "\n",
311
+ " if results.right_hand_landmarks:\n",
312
+ " for idx, landmark in enumerate(results.right_hand_landmarks.landmark): \n",
313
+ " keypoint_dict[f\"{hand_landmarks[idx]}_right_x\"].append(landmark.x)\n",
314
+ " keypoint_dict[f\"{hand_landmarks[idx]}_right_y\"].append(landmark.y)\n",
315
+ " keypoint_dict[f\"{hand_landmarks[idx]}_right_z\"].append(landmark.z)\n",
316
+ " else:\n",
317
+ " for idx in range(len(hand_landmarks)):\n",
318
+ " keypoint_dict[f\"{hand_landmarks[idx]}_right_x\"].append(0)\n",
319
+ " keypoint_dict[f\"{hand_landmarks[idx]}_right_y\"].append(0)\n",
320
+ " keypoint_dict[f\"{hand_landmarks[idx]}_right_z\"].append(0)\n",
321
+ "\n",
322
+ " if results.left_hand_landmarks:\n",
323
+ " for idx, landmark in enumerate(results.left_hand_landmarks.landmark): \n",
324
+ " keypoint_dict[f\"{hand_landmarks[idx]}_left_x\"].append(landmark.x)\n",
325
+ " keypoint_dict[f\"{hand_landmarks[idx]}_left_y\"].append(landmark.y)\n",
326
+ " keypoint_dict[f\"{hand_landmarks[idx]}_left_z\"].append(landmark.z)\n",
327
+ " else:\n",
328
+ " for idx in range(len(hand_landmarks)):\n",
329
+ " keypoint_dict[f\"{hand_landmarks[idx]}_left_x\"].append(0)\n",
330
+ " keypoint_dict[f\"{hand_landmarks[idx]}_left_y\"].append(0)\n",
331
+ " keypoint_dict[f\"{hand_landmarks[idx]}_left_z\"].append(0)\n",
332
+ "\n",
333
+ " if results.pose_landmarks:\n",
334
+ " for idx, landmark in enumerate(results.pose_landmarks.landmark): \n",
335
+ " keypoint_dict[f\"{pose_landmarks[idx]}_x\"].append(landmark.x)\n",
336
+ " keypoint_dict[f\"{pose_landmarks[idx]}_y\"].append(landmark.y)\n",
337
+ " keypoint_dict[f\"{pose_landmarks[idx]}_z\"].append(landmark.z)\n",
338
+ " else:\n",
339
+ " for idx in range(len(pose_landmarks)):\n",
340
+ " keypoint_dict[f\"{pose_landmarks[idx]}_x\"].append(0)\n",
341
+ " keypoint_dict[f\"{pose_landmarks[idx]}_y\"].append(0)\n",
342
+ " keypoint_dict[f\"{pose_landmarks[idx]}_z\"].append(0)\n",
343
+ "\n",
344
+ " keypoint_dict[\"frame\"] = count\n",
345
+ " keypoint_dict[\"video_path\"] = video_path\n",
346
+ " keypoint_dict[\"label\"] = label\n",
347
+ " keypoint_dict[\"actor\"] = actor\n",
348
+ "\n",
349
+ " return keypoint_dict\n",
350
+ "\n",
351
+ "def process_videos():\n",
352
+ " csv_file = f\"videos_list.csv\"\n",
353
+ " data = pd.read_csv(csv_file)\n",
354
+ "\n",
355
+ " keypoints_list = Parallel(n_jobs=-1)( \n",
356
+ " delayed(extract_keypoint)(row['file'], row['label'], row['actor']) for index, row in tqdm(data.iterrows(), total=len(data), desc=\"Processing videos\", leave=False)\n",
357
+ " )\n",
358
+ "\n",
359
+ " keypoints_df = pd.DataFrame(keypoints_list)\n",
360
+ " keypoints_df.to_csv(f\"vsl{num_labels}_keypoints.csv\", index=False)\n",
361
+ "\n",
362
+ "if __name__ == '__main__':\n",
363
+ " process_videos()\n"
364
+ ]
365
+ },
366
+ {
367
+ "cell_type": "markdown",
368
+ "metadata": {},
369
+ "source": [
370
+ "Interpolation"
371
+ ]
372
+ },
373
+ {
374
+ "cell_type": "code",
375
+ "execution_count": null,
376
+ "metadata": {},
377
+ "outputs": [
378
+ {
379
+ "name": "stderr",
380
+ "output_type": "stream",
381
+ "text": [
382
+ "100%|██████████| 280/280 [00:04<00:00, 67.80it/s]\n"
383
+ ]
384
+ },
385
+ {
386
+ "name": "stdout",
387
+ "output_type": "stream",
388
+ "text": [
389
+ "Interpolated keypoints saved to vsl10_interpolated_keypoints.csv\n"
390
+ ]
391
+ },
392
+ {
393
+ "name": "stderr",
394
+ "output_type": "stream",
395
+ "text": [
396
+ "100%|██████████| 280/280 [00:03<00:00, 91.60it/s] "
397
+ ]
398
+ },
399
+ {
400
+ "name": "stdout",
401
+ "output_type": "stream",
402
+ "text": [
403
+ "Data processing and saving completed.\n"
404
+ ]
405
+ },
406
+ {
407
+ "name": "stderr",
408
+ "output_type": "stream",
409
+ "text": [
410
+ "\n"
411
+ ]
412
+ }
413
+ ],
414
+ "source": [
415
+ "import pandas as pd\n",
416
+ "import numpy as np\n",
417
+ "import ast\n",
418
+ "from tqdm import tqdm\n",
419
+ "\n",
420
+ "def find_index(array):\n",
421
+ " for i, num in enumerate(array):\n",
422
+ " if num != 0:\n",
423
+ " return i\n",
424
+ "\n",
425
+ "def curl_skeleton(array):\n",
426
+ " if sum(array) == 0:\n",
427
+ " return array\n",
428
+ " for i, location in enumerate(array):\n",
429
+ " if location != 0:\n",
430
+ " continue\n",
431
+ " else:\n",
432
+ " if i == 0 or i == len(array) - 1:\n",
433
+ " continue\n",
434
+ " else:\n",
435
+ " if array[i + 1] != 0:\n",
436
+ " array[i] = float((array[i - 1] + array[i + 1]) / 2)\n",
437
+ " else:\n",
438
+ " if sum(array[i:]) == 0:\n",
439
+ " continue\n",
440
+ " else:\n",
441
+ " j = find_index(array[i + 1:])\n",
442
+ " array[i] = float(((1 + j) * array[i - 1] + 1 * array[i + 1 + j]) / (2 + j))\n",
443
+ " return array\n",
444
+ "\n",
445
+ "def interpolate_keypoints(input_file, output_file, body_identifiers):\n",
446
+ " train_data = pd.read_csv(input_file)\n",
447
+ " output_df = train_data.copy()\n",
448
+ "\n",
449
+ " for index, video in tqdm(train_data.iterrows(), total=train_data.shape[0]):\n",
450
+ " for identifier in body_identifiers:\n",
451
+ " # Interpolate the x and y keypoints\n",
452
+ " x_values = curl_skeleton(ast.literal_eval(video[identifier + \"_x\"]))\n",
453
+ " y_values = curl_skeleton(ast.literal_eval(video[identifier + \"_y\"]))\n",
454
+ "\n",
455
+ " output_df.at[index, identifier + \"_x\"] = str(x_values)\n",
456
+ " output_df.at[index, identifier + \"_y\"] = str(y_values)\n",
457
+ "\n",
458
+ " output_df.to_csv(output_file, index=False)\n",
459
+ " print(f\"Interpolated keypoints saved to {output_file}\")\n",
460
+ "\n",
461
+ "if __name__ == \"__main__\":\n",
462
+ " input_file_path = f\"vsl{num_labels}_keypoints.csv\"\n",
463
+ " output_file_path = f\"vsl{num_labels}_interpolated_keypoints.csv\"\n",
464
+ "\n",
465
+ " hand_landmarks = [\n",
466
+ " 'INDEX_FINGER_DIP', 'INDEX_FINGER_MCP', 'INDEX_FINGER_PIP', 'INDEX_FINGER_TIP', \n",
467
+ " 'MIDDLE_FINGER_DIP', 'MIDDLE_FINGER_MCP', 'MIDDLE_FINGER_PIP', 'MIDDLE_FINGER_TIP', \n",
468
+ " 'PINKY_DIP', 'PINKY_MCP', 'PINKY_PIP', 'PINKY_TIP', \n",
469
+ " 'RING_FINGER_DIP', 'RING_FINGER_MCP', 'RING_FINGER_PIP', 'RING_FINGER_TIP', \n",
470
+ " 'THUMB_CMC', 'THUMB_IP', 'THUMB_MCP', 'THUMB_TIP', 'WRIST'\n",
471
+ " ]\n",
472
+ " HAND_IDENTIFIERS = [id + \"_right\" for id in hand_landmarks] + [id + \"_left\" for id in hand_landmarks]\n",
473
+ " POSE_IDENTIFIERS = [\"RIGHT_SHOULDER\", \"LEFT_SHOULDER\", \"LEFT_ELBOW\", \"RIGHT_ELBOW\"]\n",
474
+ " body_identifiers = HAND_IDENTIFIERS + POSE_IDENTIFIERS \n",
475
+ "\n",
476
+ " interpolate_keypoints(input_file_path, output_file_path, body_identifiers)\n",
477
+ "\n",
478
+ " # Load interpolated data and store them in numpy files\n",
479
+ " train_data = pd.read_csv(output_file_path)\n",
480
+ " frames = 80\n",
481
+ "\n",
482
+ " data = []\n",
483
+ " labels = []\n",
484
+ "\n",
485
+ " for video_index, video in tqdm(train_data.iterrows(), total=train_data.shape[0]):\n",
486
+ " T = len(ast.literal_eval(video[\"INDEX_FINGER_DIP_right_x\"]))\n",
487
+ " current_row = np.empty(shape=(2, T, len(body_identifiers), 1))\n",
488
+ "\n",
489
+ " for index, identifier in enumerate(body_identifiers):\n",
490
+ " data_keypoint_preprocess_x = ast.literal_eval(video[identifier + \"_x\"])\n",
491
+ " current_row[0, :, index, :] = np.asarray(data_keypoint_preprocess_x).reshape(T, 1)\n",
492
+ "\n",
493
+ " data_keypoint_preprocess_y = ast.literal_eval(video[identifier + \"_y\"])\n",
494
+ " current_row[1, :, index, :] = np.asarray(data_keypoint_preprocess_y).reshape(T, 1)\n",
495
+ "\n",
496
+ " if T < frames:\n",
497
+ " target = np.zeros(shape=(2, frames, len(body_identifiers), 1))\n",
498
+ " target[:, :T, :, :] = current_row\n",
499
+ " else:\n",
500
+ " target = current_row[:, :frames, :, :]\n",
501
+ "\n",
502
+ " data.append(target)\n",
503
+ " labels.append(int(video[\"label\"]))\n",
504
+ "\n",
505
+ " keypoint_data = np.stack(data, axis=0)\n",
506
+ " label_data = np.stack(labels, axis=0)\n",
507
+ " np.save(f'vsl{num_labels}_data_preprocess.npy', keypoint_data)\n",
508
+ " np.save(f'vsl{num_labels}_label_preprocess.npy', label_data)\n",
509
+ " print(\"Data processing and saving completed.\")\n"
510
+ ]
511
+ },
512
+ {
513
+ "cell_type": "code",
514
+ "execution_count": 7,
515
+ "metadata": {},
516
+ "outputs": [
517
+ {
518
+ "name": "stdout",
519
+ "output_type": "stream",
520
+ "text": [
521
+ "(280, 2, 80, 46, 1)\n",
522
+ "(280,)\n"
523
+ ]
524
+ }
525
+ ],
526
+ "source": [
527
+ "import numpy as np\n",
528
+ "a = np.load(f'vsl{num_labels}_data_preprocess.npy')\n",
529
+ "b = np.load(f'vsl{num_labels}_label_preprocess.npy')\n",
530
+ "\n",
531
+ "print(a.shape)\n",
532
+ "print(b.shape)"
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "markdown",
537
+ "metadata": {},
538
+ "source": [
539
+ "Do K-Folds and store the keypoints in numpy files"
540
+ ]
541
+ },
542
+ {
543
+ "cell_type": "code",
544
+ "execution_count": null,
545
+ "metadata": {},
546
+ "outputs": [
547
+ {
548
+ "name": "stdout",
549
+ "output_type": "stream",
550
+ "text": [
551
+ "Number of actors: 28\n",
552
+ "-----------------------------------------------------\n",
553
+ "Fold 1: 30 test samples\n",
554
+ "Fold 2: 29 test samples\n",
555
+ "Fold 3: 30 test samples\n",
556
+ "Fold 4: 30 test samples\n",
557
+ "Fold 5: 30 test samples\n",
558
+ "Fold 6: 30 test samples\n",
559
+ "Fold 7: 31 test samples\n",
560
+ "Fold 8: 30 test samples\n",
561
+ "Fold 9: 20 test samples\n",
562
+ "Fold 10: 20 test samples\n",
563
+ "Processed and saved vsl10 fold 1 successfully.\n",
564
+ "Processed and saved vsl10 fold 2 successfully.\n",
565
+ "Processed and saved vsl10 fold 3 successfully.\n",
566
+ "Processed and saved vsl10 fold 4 successfully.\n",
567
+ "Processed and saved vsl10 fold 5 successfully.\n",
568
+ "Processed and saved vsl10 fold 6 successfully.\n",
569
+ "Processed and saved vsl10 fold 7 successfully.\n",
570
+ "Processed and saved vsl10 fold 8 successfully.\n",
571
+ "Processed and saved vsl10 fold 9 successfully.\n",
572
+ "Processed and saved vsl10 fold 10 successfully.\n"
573
+ ]
574
+ }
575
+ ],
576
+ "source": [
577
+ "from sklearn.model_selection import KFold\n",
578
+ "import os\n",
579
+ "import numpy as np\n",
580
+ "import pandas as pd\n",
581
+ "from tqdm import tqdm\n",
582
+ "\n",
583
+ "def k_fold_cross_validation(train_data, keypoint_data, label_data, num_labels, k_folds, destination_folder=\"numpy_files\"):\n",
584
+ " os.makedirs(destination_folder, exist_ok=True)\n",
585
+ "\n",
586
+ " actors = train_data['actor'].unique()\n",
587
+ " print(f\"Number of actors: {len(actors)}\")\n",
588
+ " print('-----------------------------------------------------')\n",
589
+ "\n",
590
+ " kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)\n",
591
+ "\n",
592
+ " actor_to_indices = {actor: train_data.index[train_data['actor'] == actor].tolist() for actor in actors}\n",
593
+ " folds = [[] for _ in range(k_folds)]\n",
594
+ "\n",
595
+ " for fold, (train_actors, test_actors) in enumerate(kf.split(actors)):\n",
596
+ " train_actors = actors[train_actors]\n",
597
+ " test_actors = actors[test_actors]\n",
598
+ " \n",
599
+ " for actor in test_actors:\n",
600
+ " folds[fold].extend(actor_to_indices[actor])\n",
601
+ "\n",
602
+ " tqdm.write(f\"Fold {fold+1}: {len(folds[fold])} test samples\")\n",
603
+ "\n",
604
+ " # Iterate over each fold to create train-test splits\n",
605
+ " for fold in range(k_folds):\n",
606
+ " test_indices = folds[fold]\n",
607
+ " train_indices = [idx for f in range(k_folds) if f != fold for idx in folds[f]]\n",
608
+ "\n",
609
+ " X_train, X_test = keypoint_data[train_indices], keypoint_data[test_indices]\n",
610
+ " y_train = np.array(label_data[train_indices], dtype=np.int64)\n",
611
+ " y_test = np.array(label_data[test_indices], dtype=np.int64)\n",
612
+ "\n",
613
+ " np.save(os.path.join(destination_folder, f'vsl{num_labels}_data_fold{fold+1}_train.npy'), X_train)\n",
614
+ " np.save(os.path.join(destination_folder, f'vsl{num_labels}_label_fold{fold+1}_train.npy'), y_train)\n",
615
+ " np.save(os.path.join(destination_folder, f'vsl{num_labels}_data_fold{fold+1}_test.npy'), X_test)\n",
616
+ " np.save(os.path.join(destination_folder, f'vsl{num_labels}_label_fold{fold+1}_test.npy'), y_test)\n",
617
+ "\n",
618
+ " tqdm.write(f\"Processed and saved vsl{num_labels} fold {fold+1} successfully.\")\n",
619
+ "\n",
620
+ "if __name__ == \"__main__\":\n",
621
+ " input_file_path = f\"vsl{num_labels}_interpolated_keypoints.csv\"\n",
622
+ " train_data = pd.read_csv(input_file_path)\n",
623
+ "\n",
624
+ " keypoint_data = np.load(f'vsl{num_labels}_data_preprocess.npy')\n",
625
+ " label_data = np.load(f'vsl{num_labels}_label_preprocess.npy')\n",
626
+ "\n",
627
+ " num_labels = len(np.unique(label_data))\n",
628
+ "\n",
629
+ " k_folds = 10\n",
630
+ " k_fold_cross_validation(train_data, keypoint_data, label_data, num_labels, k_folds)\n"
631
+ ]
632
+ },
633
+ {
634
+ "cell_type": "code",
635
+ "execution_count": 9,
636
+ "metadata": {},
637
+ "outputs": [
638
+ {
639
+ "name": "stdout",
640
+ "output_type": "stream",
641
+ "text": [
642
+ "(29, 2, 80, 46, 1)\n",
643
+ "(251, 2, 80, 46, 1)\n"
644
+ ]
645
+ }
646
+ ],
647
+ "source": [
648
+ "import numpy as np\n",
649
+ "a = np.load(f'numpy_files/vsl{num_labels}_data_fold2_test.npy')\n",
650
+ "b = np.load(f'numpy_files/vsl{num_labels}_data_fold2_train.npy')\n",
651
+ "\n",
652
+ "print(a.shape)\n",
653
+ "print(b.shape)"
654
+ ]
655
+ },
656
+ {
657
+ "cell_type": "markdown",
658
+ "metadata": {},
659
+ "source": [
660
+ "train directly with different folds"
661
+ ]
662
+ },
663
+ {
664
+ "cell_type": "code",
665
+ "execution_count": null,
666
+ "metadata": {},
667
+ "outputs": [],
668
+ "source": [
669
+ "'''\n",
670
+ "import os\n",
671
+ "import numpy as np\n",
672
+ "import torch\n",
673
+ "from torch.utils.data import DataLoader\n",
674
+ "import pytorch_lightning as pl\n",
675
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
676
+ "from feeder import FeederINCLUDE\n",
677
+ "from aagcn import Model\n",
678
+ "from augumentation import Rotate, Compose\n",
679
+ "\n",
680
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
681
+ "\n",
682
+ "if __name__ == '__main__':\n",
683
+ " k_folds = 10\n",
684
+ " config = {'batch_size': 128, 'learning_rate': 0.0137296, 'weight_decay': 0.000150403}\n",
685
+ " \n",
686
+ " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
687
+ "\n",
688
+ " best_accuracy = 0.0\n",
689
+ " best_fold = -1\n",
690
+ "\n",
691
+ " for fold in range(k_folds):\n",
692
+ " print(f\"Starting fold {fold + 1}/{k_folds}\")\n",
693
+ " train_data_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_data_fold{fold+1}_train.npy')\n",
694
+ " train_label_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_label_fold{fold+1}_train.npy')\n",
695
+ " val_data_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_data_fold{fold+1}_test.npy')\n",
696
+ " val_label_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_label_fold{fold+1}_test.npy')\n",
697
+ "\n",
698
+ " transforms = Compose([\n",
699
+ " Rotate(15, 80, 25, (0.5, 0.5))\n",
700
+ " ])\n",
701
+ "\n",
702
+ " train_dataset = FeederINCLUDE(\n",
703
+ " data_path=train_data_path,\n",
704
+ " label_path=train_label_path,\n",
705
+ " transform=transforms\n",
706
+ " )\n",
707
+ " val_dataset = FeederINCLUDE(\n",
708
+ " data_path=val_data_path,\n",
709
+ " label_path=val_label_path\n",
710
+ " )\n",
711
+ "\n",
712
+ " train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)\n",
713
+ " val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)\n",
714
+ "\n",
715
+ " model = Model(num_class=num_labels, num_point=46, num_person=1, in_channels=2,\n",
716
+ " graph_args={\"layout\": \"mediapipe_two_hand\", \"strategy\": \"spatial\"},\n",
717
+ " learning_rate=config['learning_rate'], weight_decay=config['weight_decay'])\n",
718
+ "\n",
719
+ " callbacks = [\n",
720
+ " ModelCheckpoint(\n",
721
+ " dirpath=\"checkpoints\",\n",
722
+ " monitor=\"valid_accuracy\",\n",
723
+ " mode=\"max\",\n",
724
+ " every_n_epochs=2,\n",
725
+ " filename=f'vsl{num_labels}-aagcn-fold={fold+1}'\n",
726
+ " ),\n",
727
+ " ]\n",
728
+ "\n",
729
+ " trainer = pl.Trainer(max_epochs=2, accelerator=\"auto\", check_val_every_n_epoch=1,\n",
730
+ " devices=1, callbacks=callbacks)\n",
731
+ "\n",
732
+ " trainer.fit(model, train_dataloader, val_dataloader)\n",
733
+ " val_accuracy = trainer.callback_metrics['valid_accuracy'].item()\n",
734
+ " print(f\"Fold {fold + 1} finished with validation accuracy: {val_accuracy:.4f}\")\n",
735
+ "\n",
736
+ " if val_accuracy > best_accuracy:\n",
737
+ " best_accuracy = val_accuracy\n",
738
+ " best_fold = fold + 1 \n",
739
+ "\n",
740
+ " print(f\"The highest validation accuracy achieved is {best_accuracy:.4f} from fold {best_fold}.\")\n",
741
+ "'''"
742
+ ]
743
+ },
744
+ {
745
+ "cell_type": "code",
746
+ "execution_count": null,
747
+ "metadata": {},
748
+ "outputs": [],
749
+ "source": [
750
+ "\n",
751
+ "print(f\"The highest validation accuracy achieved of vsl{num_labels} is {best_accuracy:.4f} from fold {best_fold}.\")"
752
+ ]
753
+ },
754
+ {
755
+ "cell_type": "markdown",
756
+ "metadata": {},
757
+ "source": [
758
+ "train based on AUTSL with different folds"
759
+ ]
760
+ },
761
+ {
762
+ "cell_type": "code",
763
+ "execution_count": null,
764
+ "metadata": {},
765
+ "outputs": [],
766
+ "source": [
767
+ "\n",
768
+ "import os\n",
769
+ "import numpy as np\n",
770
+ "import torch\n",
771
+ "from torch.utils.data import DataLoader\n",
772
+ "import pytorch_lightning as pl\n",
773
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
774
+ "from feeder import FeederINCLUDE\n",
775
+ "from aagcn import Model\n",
776
+ "from augumentation import Rotate, Compose\n",
777
+ "from pytorch_lightning.utilities.migration import pl_legacy_patch\n",
778
+ "\n",
779
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
780
+ "\n",
781
+ "if __name__ == '__main__':\n",
782
+ " k_folds = 10 \n",
783
+ " config = {'batch_size': 128, 'learning_rate': 0.0137296, 'weight_decay': 0.000150403}\n",
784
+ " \n",
785
+ " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
786
+ "\n",
787
+ " best_accuracy = 0.0\n",
788
+ " best_fold = -1\n",
789
+ "\n",
790
+ " for fold in range(k_folds):\n",
791
+ " print(f\"Starting fold {fold + 1}/{k_folds}\")\n",
792
+ " train_data_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_data_fold{fold+1}_train.npy')\n",
793
+ " train_label_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_label_fold{fold+1}_train.npy')\n",
794
+ " val_data_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_data_fold{fold+1}_test.npy')\n",
795
+ " val_label_path = os.path.join(\"numpy_files\", f'vsl{num_labels}_label_fold{fold+1}_test.npy')\n",
796
+ "\n",
797
+ " transforms = Compose([\n",
798
+ " Rotate(15, 80, 25, (0.5, 0.5))\n",
799
+ " ])\n",
800
+ "\n",
801
+ " train_dataset = FeederINCLUDE(\n",
802
+ " data_path=train_data_path,\n",
803
+ " label_path=train_label_path,\n",
804
+ " transform=transforms\n",
805
+ " )\n",
806
+ " val_dataset = FeederINCLUDE(\n",
807
+ " data_path=val_data_path,\n",
808
+ " label_path=val_label_path\n",
809
+ " )\n",
810
+ "\n",
811
+ " train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)\n",
812
+ " val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)\n",
813
+ "\n",
814
+ " model = Model(num_class=num_labels, num_point=46, num_person=1, in_channels=2,\n",
815
+ " graph_args={\"layout\": \"mediapipe_two_hand\", \"strategy\": \"spatial\"},\n",
816
+ " learning_rate=config['learning_rate'], weight_decay=config['weight_decay'])\n",
817
+ "\n",
818
+ " # Path pre-trained checkpoint file on AUTSL\n",
819
+ " checkpoint_path = \"epoch=55-valid_loss=0.41-valid_accuracy=0.85-autsl-aagcn.ckpt\"\n",
820
+ "\n",
821
+ " with pl_legacy_patch():\n",
822
+ " checkpoint = torch.load(checkpoint_path, map_location=device)\n",
823
+ "\n",
824
+ " state_dict = checkpoint['state_dict']\n",
825
+ " filtered_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc.')}\n",
826
+ " model.load_state_dict(filtered_state_dict, strict=False)\n",
827
+ "\n",
828
+ " callbacks = [\n",
829
+ " ModelCheckpoint(\n",
830
+ " dirpath=\"checkpoints\",\n",
831
+ " monitor=\"valid_accuracy\",\n",
832
+ " mode=\"max\",\n",
833
+ " every_n_epochs=2,\n",
834
+ " filename=f'autsl_vsl{num_labels}-aagcn-fold={fold+1}'\n",
835
+ " ),\n",
836
+ " ]\n",
837
+ "\n",
838
+ " trainer = pl.Trainer(max_epochs=100, accelerator=\"auto\", check_val_every_n_epoch=1,\n",
839
+ " devices=1, callbacks=callbacks)\n",
840
+ "\n",
841
+ " trainer.fit(model, train_dataloader, val_dataloader)\n",
842
+ " val_accuracy = trainer.callback_metrics['valid_accuracy'].item() \n",
843
+ " print(f\"Fold {fold + 1} finished with validation accuracy: {val_accuracy:.4f}\")\n",
844
+ "\n",
845
+ " if val_accuracy > best_accuracy:\n",
846
+ " best_accuracy = val_accuracy\n",
847
+ " best_fold = fold + 1 \n",
848
+ "\n",
849
+ " print(f\"The highest validation accuracy achieved is {best_accuracy:.4f} from fold {best_fold}.\")\n"
850
+ ]
851
+ },
852
+ {
853
+ "cell_type": "code",
854
+ "execution_count": null,
855
+ "metadata": {},
856
+ "outputs": [],
857
+ "source": [
858
+ "print(f\"The highest validation accuracy achieved of autsl vsl{num_labels} is {best_accuracy:.4f} from fold {best_fold}.\")"
859
+ ]
860
+ }
861
+ ],
862
+ "metadata": {
863
+ "kernelspec": {
864
+ "display_name": "Python 3",
865
+ "language": "python",
866
+ "name": "python3"
867
+ },
868
+ "language_info": {
869
+ "codemirror_mode": {
870
+ "name": "ipython",
871
+ "version": 3
872
+ },
873
+ "file_extension": ".py",
874
+ "mimetype": "text/x-python",
875
+ "name": "python",
876
+ "nbconvert_exporter": "python",
877
+ "pygments_lexer": "ipython3",
878
+ "version": "3.10.11"
879
+ }
880
+ },
881
+ "nbformat": 4,
882
+ "nbformat_minor": 2
883
+ }
aagcn.py ADDED
@@ -0,0 +1,440 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.autograd import Variable
7
+ from graph import Graph
8
+ import pytorch_lightning as pl
9
+ from torchmetrics.classification import MulticlassAccuracy, BinaryAccuracy
10
+ import torch.optim as optim
11
+
12
+ def import_class(name):
13
+ components = name.split('.')
14
+ mod = __import__(components[0])
15
+ for comp in components[1:]:
16
+ mod = getattr(mod, comp)
17
+ return mod
18
+
19
+
20
+ def conv_branch_init(conv, branches):
21
+ weight = conv.weight
22
+ n = weight.size(0)
23
+ k1 = weight.size(1)
24
+ k2 = weight.size(2)
25
+ nn.init.normal_(weight, 0, math.sqrt(2. / (n * k1 * k2 * branches)))
26
+ nn.init.constant_(conv.bias, 0)
27
+
28
+
29
+ def conv_init(conv):
30
+ nn.init.kaiming_normal_(conv.weight, mode='fan_out')
31
+ nn.init.constant_(conv.bias, 0)
32
+
33
+
34
+ def bn_init(bn, scale):
35
+ nn.init.constant_(bn.weight, scale)
36
+ nn.init.constant_(bn.bias, 0)
37
+
38
+
39
+ class unit_tcn(nn.Module):
40
+ def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
41
+ super(unit_tcn, self).__init__()
42
+ pad = int((kernel_size - 1) / 2)
43
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), padding=(pad, 0),
44
+ stride=(stride, 1))
45
+
46
+ self.bn = nn.BatchNorm2d(out_channels)
47
+ self.relu = nn.ReLU(inplace=True)
48
+ conv_init(self.conv)
49
+ bn_init(self.bn, 1)
50
+
51
+ def forward(self, x):
52
+ x = self.bn(self.conv(x))
53
+ return x
54
+
55
+
56
+ class unit_gcn(nn.Module):
57
+ def __init__(self, in_channels, out_channels, A, coff_embedding=4, num_subset=3, adaptive=True, attention=True):
58
+ super(unit_gcn, self).__init__()
59
+ inter_channels = out_channels // coff_embedding
60
+ self.inter_c = inter_channels
61
+ self.out_c = out_channels
62
+ self.in_c = in_channels
63
+ self.num_subset = num_subset
64
+ num_jpts = A.shape[-1]
65
+
66
+ self.conv_d = nn.ModuleList()
67
+ for i in range(self.num_subset):
68
+ self.conv_d.append(nn.Conv2d(in_channels, out_channels, 1))
69
+
70
+ if adaptive:
71
+ self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
72
+ self.alpha = nn.Parameter(torch.zeros(1))
73
+ # self.beta = nn.Parameter(torch.ones(1))
74
+ # nn.init.constant_(self.PA, 1e-6)
75
+ # self.A = Variable(torch.from_numpy(A.astype(np.float32)), requires_grad=False)
76
+ # self.A = self.PA
77
+ self.conv_a = nn.ModuleList()
78
+ self.conv_b = nn.ModuleList()
79
+ for i in range(self.num_subset):
80
+ self.conv_a.append(nn.Conv2d(in_channels, inter_channels, 1))
81
+ self.conv_b.append(nn.Conv2d(in_channels, inter_channels, 1))
82
+ else:
83
+ self.A = Variable(torch.from_numpy(A.astype(np.float32)), requires_grad=False)
84
+ self.adaptive = adaptive
85
+
86
+ if attention:
87
+ # self.beta = nn.Parameter(torch.zeros(1))
88
+ # self.gamma = nn.Parameter(torch.zeros(1))
89
+ # unified attention
90
+ # self.Attention = nn.Parameter(torch.ones(num_jpts))
91
+
92
+ # temporal attention
93
+ self.conv_ta = nn.Conv1d(out_channels, 1, 9, padding=4)
94
+ nn.init.constant_(self.conv_ta.weight, 0)
95
+ nn.init.constant_(self.conv_ta.bias, 0)
96
+
97
+ # s attention
98
+ ker_jpt = num_jpts - 1 if not num_jpts % 2 else num_jpts
99
+ pad = (ker_jpt - 1) // 2
100
+ self.conv_sa = nn.Conv1d(out_channels, 1, ker_jpt, padding=pad)
101
+ nn.init.xavier_normal_(self.conv_sa.weight)
102
+ nn.init.constant_(self.conv_sa.bias, 0)
103
+
104
+ # channel attention
105
+ rr = 2
106
+ self.fc1c = nn.Linear(out_channels, out_channels // rr)
107
+ self.fc2c = nn.Linear(out_channels // rr, out_channels)
108
+ nn.init.kaiming_normal_(self.fc1c.weight)
109
+ nn.init.constant_(self.fc1c.bias, 0)
110
+ nn.init.constant_(self.fc2c.weight, 0)
111
+ nn.init.constant_(self.fc2c.bias, 0)
112
+
113
+ # self.bn = nn.BatchNorm2d(out_channels)
114
+ # bn_init(self.bn, 1)
115
+ self.attention = attention
116
+
117
+ if in_channels != out_channels:
118
+ self.down = nn.Sequential(
119
+ nn.Conv2d(in_channels, out_channels, 1),
120
+ nn.BatchNorm2d(out_channels)
121
+ )
122
+ else:
123
+ self.down = lambda x: x
124
+
125
+ self.bn = nn.BatchNorm2d(out_channels)
126
+ self.soft = nn.Softmax(-2)
127
+ self.tan = nn.Tanh()
128
+ self.sigmoid = nn.Sigmoid()
129
+ self.relu = nn.ReLU(inplace=True)
130
+
131
+ for m in self.modules():
132
+ if isinstance(m, nn.Conv2d):
133
+ conv_init(m)
134
+ elif isinstance(m, nn.BatchNorm2d):
135
+ bn_init(m, 1)
136
+ bn_init(self.bn, 1e-6)
137
+ for i in range(self.num_subset):
138
+ conv_branch_init(self.conv_d[i], self.num_subset)
139
+
140
+ def forward(self, x):
141
+ N, C, T, V = x.size()
142
+
143
+ y = None
144
+ if self.adaptive:
145
+ A = self.PA
146
+ # A = A + self.PA
147
+ for i in range(self.num_subset):
148
+ A1 = self.conv_a[i](x).permute(0, 3, 1, 2).contiguous().view(N, V, self.inter_c * T)
149
+ A2 = self.conv_b[i](x).view(N, self.inter_c * T, V)
150
+ A1 = self.tan(torch.matmul(A1, A2) / A1.size(-1)) # N V V
151
+ A1 = A[i] + A1 * self.alpha
152
+ A2 = x.view(N, C * T, V)
153
+ z = self.conv_d[i](torch.matmul(A2, A1).view(N, C, T, V))
154
+ y = z + y if y is not None else z
155
+ else:
156
+ A = self.A.cuda(x.get_device()) * self.mask
157
+ for i in range(self.num_subset):
158
+ A1 = A[i]
159
+ A2 = x.view(N, C * T, V)
160
+ z = self.conv_d[i](torch.matmul(A2, A1).view(N, C, T, V))
161
+ y = z + y if y is not None else z
162
+
163
+ y = self.bn(y)
164
+ y += self.down(x)
165
+ y = self.relu(y)
166
+
167
+ if self.attention:
168
+ # spatial attention
169
+ se = y.mean(-2) # N C V
170
+ se1 = self.sigmoid(self.conv_sa(se))
171
+ y = y * se1.unsqueeze(-2) + y
172
+ # a1 = se1.unsqueeze(-2)
173
+
174
+ # temporal attention
175
+ se = y.mean(-1)
176
+ se1 = self.sigmoid(self.conv_ta(se))
177
+ y = y * se1.unsqueeze(-1) + y
178
+ # a2 = se1.unsqueeze(-1)
179
+
180
+ # channel attention
181
+ se = y.mean(-1).mean(-1)
182
+ se1 = self.relu(self.fc1c(se))
183
+ se2 = self.sigmoid(self.fc2c(se1))
184
+ y = y * se2.unsqueeze(-1).unsqueeze(-1) + y
185
+ # a3 = se2.unsqueeze(-1).unsqueeze(-1)
186
+
187
+ # unified attention
188
+ # y = y * self.Attention + y
189
+ # y = y + y * ((a2 + a3) / 2)
190
+ # y = self.bn(y)
191
+ return y
192
+
193
+
194
+ class TCN_GCN_unit(nn.Module):
195
+ def __init__(self, in_channels, out_channels, A, stride=1, residual=True, adaptive=True, attention=True):
196
+ super(TCN_GCN_unit, self).__init__()
197
+ self.gcn1 = unit_gcn(in_channels, out_channels, A, adaptive=adaptive, attention=attention)
198
+ self.tcn1 = unit_tcn(out_channels, out_channels, stride=stride)
199
+ self.relu = nn.ReLU(inplace=True)
200
+ # if attention:
201
+ # self.alpha = nn.Parameter(torch.zeros(1))
202
+ # self.beta = nn.Parameter(torch.ones(1))
203
+ # temporal attention
204
+ # self.conv_ta1 = nn.Conv1d(out_channels, out_channels//rt, 9, padding=4)
205
+ # self.bn = nn.BatchNorm2d(out_channels)
206
+ # bn_init(self.bn, 1)
207
+ # self.conv_ta2 = nn.Conv1d(out_channels, 1, 9, padding=4)
208
+ # nn.init.kaiming_normal_(self.conv_ta1.weight)
209
+ # nn.init.constant_(self.conv_ta1.bias, 0)
210
+ # nn.init.constant_(self.conv_ta2.weight, 0)
211
+ # nn.init.constant_(self.conv_ta2.bias, 0)
212
+
213
+ # rt = 4
214
+ # self.inter_c = out_channels // rt
215
+ # self.conv_ta1 = nn.Conv2d(out_channels, out_channels // rt, 1)
216
+ # self.conv_ta2 = nn.Conv2d(out_channels, out_channels // rt, 1)
217
+ # nn.init.constant_(self.conv_ta1.weight, 0)
218
+ # nn.init.constant_(self.conv_ta1.bias, 0)
219
+ # nn.init.constant_(self.conv_ta2.weight, 0)
220
+ # nn.init.constant_(self.conv_ta2.bias, 0)
221
+ # s attention
222
+ # num_jpts = A.shape[-1]
223
+ # ker_jpt = num_jpts - 1 if not num_jpts % 2 else num_jpts
224
+ # pad = (ker_jpt - 1) // 2
225
+ # self.conv_sa = nn.Conv1d(out_channels, 1, ker_jpt, padding=pad)
226
+ # nn.init.constant_(self.conv_sa.weight, 0)
227
+ # nn.init.constant_(self.conv_sa.bias, 0)
228
+
229
+ # channel attention
230
+ # rr = 16
231
+ # self.fc1c = nn.Linear(out_channels, out_channels // rr)
232
+ # self.fc2c = nn.Linear(out_channels // rr, out_channels)
233
+ # nn.init.kaiming_normal_(self.fc1c.weight)
234
+ # nn.init.constant_(self.fc1c.bias, 0)
235
+ # nn.init.constant_(self.fc2c.weight, 0)
236
+ # nn.init.constant_(self.fc2c.bias, 0)
237
+ #
238
+ # self.softmax = nn.Softmax(-2)
239
+ # self.sigmoid = nn.Sigmoid()
240
+ self.attention = attention
241
+
242
+ if not residual:
243
+ self.residual = lambda x: 0
244
+
245
+ elif (in_channels == out_channels) and (stride == 1):
246
+ self.residual = lambda x: x
247
+
248
+ else:
249
+ self.residual = unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)
250
+
251
+ def forward(self, x):
252
+ if self.attention:
253
+ y = self.relu(self.tcn1(self.gcn1(x)) + self.residual(x))
254
+
255
+ # spatial attention
256
+ # se = y.mean(-2) # N C V
257
+ # se1 = self.sigmoid(self.conv_sa(se))
258
+ # y = y * se1.unsqueeze(-2) + y
259
+ # a1 = se1.unsqueeze(-2)
260
+
261
+ # temporal attention
262
+ # se = y.mean(-1) # N C T
263
+ # # se1 = self.relu(self.bn(self.conv_ta1(se)))
264
+ # se2 = self.sigmoid(self.conv_ta2(se))
265
+ # # y = y * se1.unsqueeze(-1) + y
266
+ # a2 = se2.unsqueeze(-1)
267
+
268
+ # se = y # NCTV
269
+ # N, C, T, V = y.shape
270
+ # se1 = self.conv_ta1(se).permute(0, 2, 1, 3).contiguous().view(N, T, self.inter_c * V) # NTCV
271
+ # se2 = self.conv_ta2(se).permute(0, 1, 3, 2).contiguous().view(N, self.inter_c * V, T) # NCVT
272
+ # a2 = self.softmax(torch.matmul(se1, se2) / np.sqrt(se1.size(-1))) # N T T
273
+ # y = torch.matmul(y.permute(0, 1, 3, 2).contiguous().view(N, C * V, T), a2) \
274
+ # .view(N, C, V, T).permute(0, 1, 3, 2) * self.alpha + y
275
+
276
+ # channel attention
277
+ # se = y.mean(-1).mean(-1)
278
+ # se1 = self.relu(self.fc1c(se))
279
+ # se2 = self.sigmoid(self.fc2c(se1))
280
+ # # y = y * se2.unsqueeze(-1).unsqueeze(-1) + y
281
+ # a3 = se2.unsqueeze(-1).unsqueeze(-1)
282
+ #
283
+ # y = y * ((a2 + a3) / 2) + y
284
+ # y = self.bn(y)
285
+ else:
286
+ y = self.relu(self.tcn1(self.gcn1(x)) + self.residual(x))
287
+ return y
288
+
289
+
290
+ class Model(pl.LightningModule):
291
+ def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, graph_args=dict(), in_channels=3,
292
+ drop_out=0, adaptive=True, attention=True, learning_rate=1e-4, weight_decay=1e-4):
293
+ super(Model, self).__init__()
294
+
295
+ # if graph is None:
296
+ # raise ValueError()
297
+ # else:
298
+ # Graph = import_class(graph)
299
+ self.graph = Graph(**graph_args)
300
+
301
+ A = self.graph.A
302
+ self.num_class = num_class
303
+
304
+ self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)
305
+
306
+ self.l1 = TCN_GCN_unit(in_channels, 64, A, residual=False, adaptive=adaptive, attention=attention)
307
+ self.l2 = TCN_GCN_unit(64, 64, A, adaptive=adaptive, attention=attention)
308
+ self.l3 = TCN_GCN_unit(64, 64, A, adaptive=adaptive, attention=attention)
309
+ self.l4 = TCN_GCN_unit(64, 64, A, adaptive=adaptive, attention=attention)
310
+ self.l5 = TCN_GCN_unit(64, 128, A, stride=2, adaptive=adaptive, attention=attention)
311
+ self.l6 = TCN_GCN_unit(128, 128, A, adaptive=adaptive, attention=attention)
312
+ self.l7 = TCN_GCN_unit(128, 128, A, adaptive=adaptive, attention=attention)
313
+ self.l8 = TCN_GCN_unit(128, 256, A, stride=2, adaptive=adaptive, attention=attention)
314
+ self.l9 = TCN_GCN_unit(256, 256, A, adaptive=adaptive, attention=attention)
315
+ self.l10 = TCN_GCN_unit(256, 256, A, adaptive=adaptive, attention=attention)
316
+ # self.l11 = TCN_GCN_unit(256, 512, A, stride=2, adaptive=adaptive, attention=attention)
317
+ # self.l12 = TCN_GCN_unit(512, 512, A, adaptive=adaptive, attention=attention)
318
+ # self.l13 = TCN_GCN_unit(512, 512, A, adaptive=adaptive, attention=attention)
319
+
320
+ self.fc = nn.Linear(256, num_class)
321
+ nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))
322
+ bn_init(self.data_bn, 1)
323
+ if drop_out:
324
+ self.drop_out = nn.Dropout(drop_out)
325
+ else:
326
+ self.drop_out = lambda x: x
327
+
328
+ self.loss = nn.CrossEntropyLoss()
329
+ self.metric = MulticlassAccuracy(num_class)
330
+ # self.metric = BinaryAccuracy()
331
+ self.learning_rate = learning_rate
332
+ self.weight_decay = weight_decay
333
+ self.validation_step_loss_outputs = []
334
+ self.validation_step_acc_outputs = []
335
+
336
+ self.save_hyperparameters()
337
+
338
+ def forward(self, x):
339
+ N, C, T, V, M = x.size()
340
+
341
+ x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
342
+ x = self.data_bn(x.float())
343
+ x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)
344
+
345
+ x = self.l1(x)
346
+ x = self.l2(x)
347
+         x = self.l3(x)
+         x = self.l4(x)
+         x = self.l5(x)
+         x = self.l6(x)
+         x = self.l7(x)
+         x = self.l8(x)
+         x = self.l9(x)
+         x = self.l10(x)
+         # x = self.l11(x)
+         # x = self.l12(x)
+         # x = self.l13(x)
+
+         # N*M,C,T,V
+         c_new = x.size(1)
+         x = x.view(N, M, c_new, -1)
+         x = x.mean(3).mean(1)
+         x = self.drop_out(x)
+
+         return self.fc(x)
+
+     def training_step(self, batch, batch_idx):
+         inputs, targets = batch
+         outputs = self(inputs)
+         y_pred_class = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
+         # print("Targets : ", targets)
+         # print("Preds : ", y_pred_class)
+         train_accuracy = self.metric(y_pred_class, targets)
+         loss = self.loss(outputs, targets)
+         self.log('train_accuracy', train_accuracy, prog_bar=True, on_epoch=True)
+         self.log('train_loss', loss, prog_bar=True, on_epoch=True)
+         # return {"loss": loss, "train_accuracy" : train_accuracy}
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         inputs, targets = batch
+         outputs = self.forward(inputs)
+         y_pred_class = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
+         valid_accuracy = self.metric(y_pred_class, targets)
+         loss = self.loss(outputs, targets)
+         self.log('valid_accuracy', valid_accuracy, prog_bar=True, on_epoch=True)
+         self.log('valid_loss', loss, prog_bar=True, on_epoch=True)
+         self.validation_step_loss_outputs.append(loss)
+         self.validation_step_acc_outputs.append(valid_accuracy)
+         return {"valid_loss": loss, "valid_accuracy": valid_accuracy}
+
+     def on_validation_epoch_end(self):
+         # avg_loss = torch.stack(
+         #     [x["valid_loss"] for x in outputs]).mean()
+         # avg_acc = torch.stack(
+         #     [x["valid_accuracy"] for x in outputs]).mean()
+         avg_loss = torch.stack(self.validation_step_loss_outputs).mean()
+         avg_acc = torch.stack(self.validation_step_acc_outputs).mean()
+         self.log("ptl/val_loss", avg_loss)
+         self.log("ptl/val_accuracy", avg_acc)
+         self.validation_step_loss_outputs.clear()
+         self.validation_step_acc_outputs.clear()
+
+     def test_step(self, batch, batch_idx):
+         inputs, targets = batch
+         outputs = self.forward(inputs)
+         y_pred_class = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
+         print("Targets : ", targets)
+         print("Preds : ", y_pred_class)
+         test_accuracy = self.metric(y_pred_class, targets)
+         loss = self.loss(outputs, targets)
+         self.log('test_accuracy', test_accuracy, prog_bar=True, on_epoch=True)
+         self.log('test_loss', loss, prog_bar=True, on_epoch=True)
+         return {"test_loss": loss, "test_accuracy": test_accuracy}
+
+     def configure_optimizers(self):
+         params = self.parameters()
+         optimizer = optim.Adam(params=params, lr=self.learning_rate, weight_decay=self.weight_decay)
+         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')
+         return {"optimizer": optimizer,
+                 "lr_scheduler": {"scheduler": scheduler, "monitor": "valid_accuracy"}
+                 }
+         # return optimizer
+
+     def predict_step(self, batch, batch_idx):
+         return self(batch)
+
+ if __name__ == "__main__":
+     import os
+     from torchinfo import summary
+     print(os.getcwd())
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = Model(num_class=20, num_point=18, num_person=1,
+                   graph_args={}, in_channels=2).to(device)
+     # print(model.device)
+     # N, C, T, V, M
+     summary(model)
+     x = torch.randn((1, 2, 80, 18, 1)).to(device)
+     y = model(x)
+     print(y.shape)
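The permute/view sequence at the start of `forward` is the least obvious part of the model: it folds the person dimension `M` into the batch so every skeleton passes through the GCN/TCN stack independently, and the final `view`/`mean` pools over frames, joints, and persons before dropout and the classifier. Below is a shape-only sketch of that reshaping with a random tensor and illustrative sizes (the sizes are not taken from the notebook):

```python
import torch

N, C, T, V, M = 4, 2, 80, 18, 1          # batch, channels, frames, joints, persons
x = torch.randn(N, C, T, V, M)

# (N, C, T, V, M) -> (N, M*V*C, T): layout expected by the BatchNorm1d over joints/channels
x_bn = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)

# (N, M*V*C, T) -> (N*M, C, T, V): one skeleton per "sample" for the GCN/TCN blocks
x_gcn = x_bn.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)

print(x_bn.shape)   # torch.Size([4, 36, 80])
print(x_gcn.shape)  # torch.Size([4, 2, 80, 18])
```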
agcn.py ADDED
@@ -0,0 +1,275 @@
+ import math
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Variable
+ from graph import Graph
+ import pytorch_lightning as pl
+ from torchmetrics.classification import MulticlassAccuracy
+ import torch.optim as optim
+
+
+ def import_class(name):
+     components = name.split('.')
+     mod = __import__(components[0])
+     for comp in components[1:]:
+         mod = getattr(mod, comp)
+     return mod
+
+
+ def conv_branch_init(conv, branches):
+     weight = conv.weight
+     n = weight.size(0)
+     k1 = weight.size(1)
+     k2 = weight.size(2)
+     nn.init.normal_(weight, 0, math.sqrt(2. / (n * k1 * k2 * branches)))
+     nn.init.constant_(conv.bias, 0)
+
+
+ def conv_init(conv):
+     nn.init.kaiming_normal_(conv.weight, mode='fan_out')
+     nn.init.constant_(conv.bias, 0)
+
+
+ def bn_init(bn, scale):
+     nn.init.constant_(bn.weight, scale)
+     nn.init.constant_(bn.bias, 0)
+
+
+ class unit_tcn(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
+         super(unit_tcn, self).__init__()
+         pad = int((kernel_size - 1) / 2)
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), padding=(pad, 0),
+                               stride=(stride, 1))
+
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU()
+         conv_init(self.conv)
+         bn_init(self.bn, 1)
+
+     def forward(self, x):
+         x = self.bn(self.conv(x))
+         return x
+
+
+ class unit_gcn(nn.Module):
+     def __init__(self, in_channels, out_channels, A, coff_embedding=4, num_subset=3):
+         super(unit_gcn, self).__init__()
+         inter_channels = out_channels // coff_embedding
+         self.inter_c = inter_channels
+         self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
+         nn.init.constant_(self.PA, 1e-6)
+         self.A = Variable(torch.from_numpy(A.astype(np.float32)), requires_grad=False)
+         self.num_subset = num_subset
+
+         self.conv_a = nn.ModuleList()
+         self.conv_b = nn.ModuleList()
+         self.conv_d = nn.ModuleList()
+         for i in range(self.num_subset):
+             self.conv_a.append(nn.Conv2d(in_channels, inter_channels, 1))
+             self.conv_b.append(nn.Conv2d(in_channels, inter_channels, 1))
+             self.conv_d.append(nn.Conv2d(in_channels, out_channels, 1))
+
+         if in_channels != out_channels:
+             self.down = nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, 1),
+                 nn.BatchNorm2d(out_channels)
+             )
+         else:
+             self.down = lambda x: x
+
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.soft = nn.Softmax(-2)
+         self.relu = nn.ReLU()
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 conv_init(m)
+             elif isinstance(m, nn.BatchNorm2d):
+                 bn_init(m, 1)
+         bn_init(self.bn, 1e-6)
+         for i in range(self.num_subset):
+             conv_branch_init(self.conv_d[i], self.num_subset)
+
+     def forward(self, x):
+         N, C, T, V = x.size()
+         A = self.A.to(x.device)  # keep the fixed adjacency on the same device as the input (CPU or GPU)
+         A = A + self.PA
+
+         y = None
+         for i in range(self.num_subset):
+             A1 = self.conv_a[i](x).permute(0, 3, 1, 2).contiguous().view(N, V, self.inter_c * T)
+             A2 = self.conv_b[i](x).view(N, self.inter_c * T, V)
+             A1 = self.soft(torch.matmul(A1, A2) / A1.size(-1))  # N V V
+             A1 = A1 + A[i]
+             A2 = x.view(N, C * T, V)
+             z = self.conv_d[i](torch.matmul(A2, A1).view(N, C, T, V))
+             y = z + y if y is not None else z
+
+         y = self.bn(y)
+         y += self.down(x)
+         return self.relu(y)
+
+
+ class TCN_GCN_unit(nn.Module):
+     def __init__(self, in_channels, out_channels, A, stride=1, residual=True):
+         super(TCN_GCN_unit, self).__init__()
+         self.gcn1 = unit_gcn(in_channels, out_channels, A)
+         self.tcn1 = unit_tcn(out_channels, out_channels, stride=stride)
+         self.relu = nn.ReLU()
+         if not residual:
+             self.residual = lambda x: 0
+
+         elif (in_channels == out_channels) and (stride == 1):
+             self.residual = lambda x: x
+
+         else:
+             self.residual = unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)
+
+     def forward(self, x):
+         x = self.tcn1(self.gcn1(x)) + self.residual(x)
+         return self.relu(x)
+
+
+ class Model(pl.LightningModule):
+     def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, graph_args=dict(), in_channels=3,
+                  learning_rate=1e-4, weight_decay=1e-4):
+         super(Model, self).__init__()
+
+         # if graph is None:
+         #     raise ValueError()
+         # else:
+         #     Graph = import_class(graph)
+         self.graph = Graph(**graph_args)
+
+         A = self.graph.A
+         # print(num_person * in_channels * num_point)
+         self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)
+
+         self.l1 = TCN_GCN_unit(in_channels, 64, A, residual=False)
+         self.l2 = TCN_GCN_unit(64, 64, A)
+         self.l3 = TCN_GCN_unit(64, 64, A)
+         self.l4 = TCN_GCN_unit(64, 64, A)
+         self.l5 = TCN_GCN_unit(64, 128, A, stride=2)
+         self.l6 = TCN_GCN_unit(128, 128, A)
+         self.l7 = TCN_GCN_unit(128, 128, A)
+         self.l8 = TCN_GCN_unit(128, 256, A, stride=2)
+         self.l9 = TCN_GCN_unit(256, 256, A)
+         self.l10 = TCN_GCN_unit(256, 256, A)
+
+         self.fc = nn.Linear(256, num_class)
+         nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))
+         bn_init(self.data_bn, 1)
+
+         self.loss = nn.CrossEntropyLoss()
+         self.metric = MulticlassAccuracy(num_class)
+         self.learning_rate = learning_rate
+         self.weight_decay = weight_decay
+         self.validation_step_loss_outputs = []
+         self.validation_step_acc_outputs = []
+
+         self.save_hyperparameters()
+
+     def forward(self, x):
+         # 0, 1, 2, 3, 4
+         N, C, T, V, M = x.size()
+         # print(f"N {N}, C {C}, T {T}, V {V}, M {M}")
+         # N, M, V, C, T
+         x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
+         # print(M*V*C)
+         x = self.data_bn(x)
+         x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)
+
+         x = self.l1(x)
+         x = self.l2(x)
+         x = self.l3(x)
+         x = self.l4(x)
+         x = self.l5(x)
+         x = self.l6(x)
+         x = self.l7(x)
+         x = self.l8(x)
+         x = self.l9(x)
+         x = self.l10(x)
+
+         # N*M,C,T,V
+         c_new = x.size(1)
+         x = x.view(N, M, c_new, -1)
+         x = x.mean(3).mean(1)
+
+         return self.fc(x)
+
+     def training_step(self, batch, batch_idx):
+         inputs, targets = batch
+         outputs = self(inputs)
+         y_pred_class = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
+         # print("Targets : ", targets)
+         # print("Preds : ", y_pred_class)
+         train_accuracy = self.metric(y_pred_class, targets)
+         loss = self.loss(outputs, targets)
+         self.log('train_accuracy', train_accuracy, prog_bar=True, on_epoch=True)
+         self.log('train_loss', loss, prog_bar=True, on_epoch=True)
+         # return {"loss": loss, "train_accuracy" : train_accuracy}
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         inputs, targets = batch
+         outputs = self.forward(inputs)
+         y_pred_class = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
+         valid_accuracy = self.metric(y_pred_class, targets)
+         loss = self.loss(outputs, targets)
+         self.log('valid_accuracy', valid_accuracy, prog_bar=True, on_epoch=True)
+         self.log('valid_loss', loss, prog_bar=True, on_epoch=True)
+         self.validation_step_loss_outputs.append(loss)
+         self.validation_step_acc_outputs.append(valid_accuracy)
+         return {"valid_loss": loss, "valid_accuracy": valid_accuracy}
+
+     def on_validation_epoch_end(self):
+         # avg_loss = torch.stack(
+         #     [x["valid_loss"] for x in outputs]).mean()
+         # avg_acc = torch.stack(
+         #     [x["valid_accuracy"] for x in outputs]).mean()
+         avg_loss = torch.stack(self.validation_step_loss_outputs).mean()
+         avg_acc = torch.stack(self.validation_step_acc_outputs).mean()
+         self.log("ptl/val_loss", avg_loss)
+         self.log("ptl/val_accuracy", avg_acc)
+         self.validation_step_loss_outputs.clear()
+         self.validation_step_acc_outputs.clear()
+
+     def test_step(self, batch, batch_idx):
+         inputs, targets = batch
+         outputs = self.forward(inputs)
+         y_pred_class = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
+         print("Targets : ", targets)
+         print("Preds : ", y_pred_class)
+         test_accuracy = self.metric(y_pred_class, targets)
+         loss = self.loss(outputs, targets)
+         self.log('test_accuracy', test_accuracy, prog_bar=True, on_epoch=True)
+         self.log('test_loss', loss, prog_bar=True, on_epoch=True)
+         return {"test_loss": loss, "test_accuracy": test_accuracy}
+
+     def configure_optimizers(self):
+         params = self.parameters()
+         optimizer = optim.Adam(params=params, lr=self.learning_rate, weight_decay=self.weight_decay)
+         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')
+         return {"optimizer": optimizer,
+                 "lr_scheduler": {"scheduler": scheduler, "monitor": "valid_accuracy"}
+                 }
+         # return optimizer
+
+     def predict_step(self, batch, batch_idx):
+         return self(batch)
+
+ if __name__ == "__main__":
+     import os
+     from torchinfo import summary
+     print(os.getcwd())
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = Model(num_class=20, num_point=25, num_person=1,
+                   graph_args={"layout": "mediapipe", "strategy": "spatial"}, in_channels=2).to(device)
+     # print(model.device)
+     # N, C, T, V, M
+     summary(model)
+     x = torch.randn((1, 2, 80, 25, 1)).to(device)
+     y = model(x)
+     print(y.shape)
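The core of `unit_gcn` is the data-dependent adjacency: two 1x1 convolutions embed every joint across all frames, their pairwise products are softmax-normalised into an `(N, V, V)` attention map, and that map is added to the fixed graph `A[i]` and the learned offset `PA`. A self-contained sketch of just that step; the names `theta`/`phi` and the tensor sizes are illustrative (in the module they are `conv_a[i]`/`conv_b[i]`):

```python
import torch
import torch.nn as nn

N, C, T, V = 2, 64, 40, 25            # batch, channels, frames, joints (illustrative sizes)
inter_c = C // 4                      # matches coff_embedding=4 in unit_gcn
x = torch.randn(N, C, T, V)

theta, phi = nn.Conv2d(C, inter_c, 1), nn.Conv2d(C, inter_c, 1)

a = theta(x).permute(0, 3, 1, 2).contiguous().view(N, V, inter_c * T)  # (N, V, inter_c*T)
b = phi(x).view(N, inter_c * T, V)                                     # (N, inter_c*T, V)

# joint-to-joint similarity, normalised along dim -2 exactly as nn.Softmax(-2) does
attn = torch.softmax(torch.matmul(a, b) / a.size(-1), dim=-2)          # (N, V, V)
print(attn.shape)  # torch.Size([2, 25, 25])
```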
augumentation.py ADDED
@@ -0,0 +1,73 @@
+ import math
+ import numpy as np
+ from numpy import random
+ from einops import einsum
+ import torch
+
+
+ class Rotate(object):
+     def __init__(self, range_angle, num_frames, num_nodes, point):
+         self.range_angle = range_angle
+         self.num_frames = num_frames
+         self.num_nodes = num_nodes
+         self.point = point
+
+     def __call__(self, data, label):
+         # data, label = sample
+         data = data.double()
+         angle = math.radians(random.uniform((-1) * self.range_angle, self.range_angle))
+         rotation_matrix = torch.Tensor([[math.cos(angle), (-1) * math.sin(angle)],
+                                         [math.sin(angle), math.cos(angle)]])
+         ox, oy = self.point
+         data[0, :, :] -= ox
+         data[1, :, :] -= oy
+
+         result = einsum(rotation_matrix.double(), data, "a b, b c d e -> a c d e") + 0.5
+
+         return result, label
+
+
+ class Left(object):
+     def __init__(self, width):
+         self.width = width
+
+     def __call__(self, data, label):
+         idx = find_frames(data)
+         p = random.random()
+         if p > 0.5:
+             data[0, :idx, :] -= self.width
+         return data, label
+
+
+ class Right(object):
+     def __init__(self, width):
+         self.width = width
+
+     def __call__(self, data, label):
+         idx = find_frames(data)
+         p = random.random()
+         if p > 0.5:
+             data[0, :idx, :] += self.width
+         return data, label
+
+
+ class GaussianNoise(object):
+     def __init__(self, mean, var):
+         self.mean = mean
+         self.var = var
+
+     def __call__(self, data, label):
+         # data: (C, T, V, 1); scale the noise by the requested mean and variance
+         noise = self.mean + (self.var ** 0.5) * torch.randn(size=data.size())
+         data = data + noise
+         return data, label
+
+
+ class Compose(object):
+     def __init__(self, transforms):
+         self.transforms = transforms
+
+     def __call__(self, data, label):
+         for t in self.transforms:
+             data, label = t(data, label)
+         return data, label
+
+
+ def find_frames(data):
+     # Return the index of the first zero-padded frame, or T if the clip is full length
+     for i in range(data.shape[1]):
+         if data[:, i, :][0][0] == 0:
+             return i
+     return data.shape[1]
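These transforms operate on a single sample of shape `(C, T, V, M)` and are chained with `Compose`; `FeederINCLUDE` then applies the composed transform to roughly half of the samples. A usage sketch, where the angle range, rotation centre, and noise variance are illustrative values rather than settings taken from the notebook:

```python
import torch
from augumentation import Compose, Rotate, GaussianNoise

data = torch.rand(2, 80, 25, 1)   # (C, T, V, M): x/y coordinates, 80 frames, 25 joints, 1 person
label = 3

transform = Compose([
    Rotate(range_angle=13, num_frames=80, num_nodes=25, point=(0.5, 0.5)),
    GaussianNoise(mean=0.0, var=0.01),
])

data_aug, label = transform(data, label)
print(data_aug.shape)  # torch.Size([2, 80, 25, 1])
```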
epoch=55-valid_loss=0.41-valid_accuracy=0.85-autsl-aagcn.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0ede54d8c48503897bdad5d8f680cfe1d24d1960242d11e1c6a704fb7bbb1dbb
+ size 47248523
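The `.ckpt` file is a standard PyTorch Lightning checkpoint stored via Git LFS, so it can be restored with `load_from_checkpoint` on the LightningModule class that produced it; because the model calls `save_hyperparameters()`, the constructor arguments are recovered from the checkpoint itself. A minimal sketch of restoring the AUTSL-pretrained weights, assuming the checkpoint was saved from a compatible class (the import below is illustrative; substitute the AAGCN class from Total (VSL).ipynb if that is what created the checkpoint):

```python
import torch
from agcn import Model  # assumption: the class here matches the one that wrote the checkpoint

ckpt_path = "epoch=55-valid_loss=0.41-valid_accuracy=0.85-autsl-aagcn.ckpt"

# map_location keeps this working on machines without a GPU
model = Model.load_from_checkpoint(ckpt_path, map_location="cpu")
model.eval()

with torch.no_grad():
    x = torch.randn(1, 2, 80, 25, 1)   # dummy (N, C, T, V, M) input
    print(model(x).argmax(dim=1))
```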
feeder.py ADDED
@@ -0,0 +1,72 @@
+ import torch
+ from torch.utils.data import Dataset
+ from pathlib import Path
+ import numpy as np
+ import random
+ from einops import rearrange
+ from augumentation import Rotate
+ from torch.utils.data import random_split
+
+
+ # Reads the .npy data and label files and pairs them up per sample
+ class FeederINCLUDE(Dataset):
+     """Feeder for skeleton-based action recognition
+     Arguments:
+         data_path: path to the '.npy' data; the array shape should be (N, C, T, V, M)
+         label_path: path to the '.npy' label array
+         transform: optional augmentation applied to (data, label) pairs
+     """
+     def __init__(self, data_path: Path, label_path: Path, transform=None):
+         super(FeederINCLUDE, self).__init__()
+         self.data_path = data_path
+         self.label_path = label_path
+         self.transform = transform
+         self.load_data()
+
+     def load_data(self):
+         # data: (N, C, T, V, M)
+         self.label = np.load(self.label_path)
+         self.data = np.load(self.data_path)
+         self.N, self.C, self.T, self.V, self.M = self.data.shape
+
+     def __getitem__(self, index):
+         """
+         Dataset shape (N, C, T, V, M)
+             N : batch size
+             C : number of features
+             T : number of frames
+             V : number of joints (graph nodes)
+             M : number of people (kept for compatibility, should be removed)
+
+         Returns one sample of shape (C, T, V, M) together with its label.
+         """
+         data_numpy = torch.tensor(self.data[index]).float()
+         # Delete one dimension
+         # data_numpy = data_numpy[:, :, :2]
+         # data_numpy = rearrange(data_numpy, ' t v c 1 -> c t v 1')
+         label = self.label[index]
+         p = random.random()
+         if self.transform and p > 0.5:
+             data_numpy, label = self.transform(data_numpy, label)
+         return data_numpy, label
+
+     def __len__(self):
+         return len(self.label)
+
+
+ if __name__ == '__main__':
+     file, label = np.load("wsl100_train_data_preprocess.npy"), np.load("wsl100_train_label_preprocess.npy")
+     print(file.shape, label.shape)
+     data = FeederINCLUDE(data_path="wsl100_train_data_preprocess.npy", label_path="wsl100_train_label_preprocess.npy",
+                          transform=None)
+     # test_dataset = FeederINCLUDE(data_path=f"data/vsl100_test_data_preprocess.npy", label_path=f"data/vsl100_test_label_preprocess.npy")
+     # valid_dataset = FeederINCLUDE(data_path=f"data/vsl100_valid_data_preprocess.npy", label_path=f"data/vsl100_valid_label_preprocess.npy")
+     # data = FeederINCLUDE(data_path=f"data/vsl100_test_data_preprocess.npy", label_path=f"data/vsl100_test_label_preprocess.npy",
+     #                      transform=None)
+     print(data.N, data.C, data.T, data.V, data.M)
+     print(data.data.shape)
+     print(data.__len__())
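`FeederINCLUDE` is a plain `torch.utils.data.Dataset`, so wiring it to the Lightning model only requires standard `DataLoader`s. A sketch of how the pieces are intended to fit together; the file names, batch size, and epoch count are placeholders, and the actual training loop (including the 10-fold split) lives in Total (VSL).ipynb:

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from feeder import FeederINCLUDE
from agcn import Model

# Placeholder .npy files holding arrays of shape (N, C, T, V, M) and matching label vectors
train_set = FeederINCLUDE("vsl_train_data.npy", "vsl_train_label.npy")
valid_set = FeederINCLUDE("vsl_valid_data.npy", "vsl_valid_label.npy")

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=32, shuffle=False)

model = Model(num_class=199, num_point=25, num_person=1,
              graph_args={"layout": "mediapipe", "strategy": "spatial"}, in_channels=2)

trainer = pl.Trainer(max_epochs=100, accelerator="auto")
trainer.fit(model, train_loader, valid_loader)
```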
graph.py ADDED
@@ -0,0 +1,287 @@
+ import numpy as np
+ from enum import Enum
+
+
+ class BodyIdentifier(Enum):
+     INDEX_FINGER_DIP_right = 0
+     INDEX_FINGER_MCP_right = 1
+     INDEX_FINGER_PIP_right = 2
+     INDEX_FINGER_TIP_right = 3
+     MIDDLE_FINGER_DIP_right = 4
+     MIDDLE_FINGER_MCP_right = 5
+     MIDDLE_FINGER_PIP_right = 6
+     MIDDLE_FINGER_TIP_right = 7
+     PINKY_DIP_right = 8
+     PINKY_MCP_right = 9
+     PINKY_PIP_right = 10
+     PINKY_TIP_right = 11
+     RING_FINGER_DIP_right = 12
+     RING_FINGER_MCP_right = 13
+     RING_FINGER_PIP_right = 14
+     RING_FINGER_TIP_right = 15
+     THUMB_CMC_right = 16
+     THUMB_IP_right = 17
+     THUMB_MCP_right = 18
+     THUMB_TIP_right = 19
+     WRIST_right = 20
+     INDEX_FINGER_DIP_left = 21
+     INDEX_FINGER_MCP_left = 22
+     INDEX_FINGER_PIP_left = 23
+     INDEX_FINGER_TIP_left = 24
+     MIDDLE_FINGER_DIP_left = 25
+     MIDDLE_FINGER_MCP_left = 26
+     MIDDLE_FINGER_PIP_left = 27
+     MIDDLE_FINGER_TIP_left = 28
+     PINKY_DIP_left = 29
+     PINKY_MCP_left = 30
+     PINKY_PIP_left = 31
+     PINKY_TIP_left = 32
+     RING_FINGER_DIP_left = 33
+     RING_FINGER_MCP_left = 34
+     RING_FINGER_PIP_left = 35
+     RING_FINGER_TIP_left = 36
+     THUMB_CMC_left = 37
+     THUMB_IP_left = 38
+     THUMB_MCP_left = 39
+     THUMB_TIP_left = 40
+     WRIST_left = 41
+     RIGHT_SHOULDER = 42
+     LEFT_SHOULDER = 43
+     LEFT_ELBOW = 44
+     RIGHT_ELBOW = 45
+
+
+ class Graph():
+     """The graph used to model the extracted skeleton keypoints.
+
+     Args:
+         strategy (string): must be one of the following candidates
+             - uniform: Uniform Labeling
+             - distance: Distance Partitioning
+             - spatial: Spatial Configuration
+             For more information, please refer to the section 'Partition Strategies'
+             in the ST-GCN paper (https://arxiv.org/abs/1801.07455).
+
+         layout (string): must be one of the following candidates
+             - openpose: consists of 18 joints. For more information, please
+               refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose#output
+             - ntu-rgb+d: consists of 25 joints. For more information, please
+               refer to https://github.com/shahroudy/NTURGB-D
+             - mediapipe: 25 upper-body pose joints extracted with Mediapipe
+             - mediapipe_two_hand: 46 joints covering both hands plus shoulders and elbows
+
+         max_hop (int): the maximal distance between two connected nodes
+         dilation (int): controls the spacing between the kernel points
+     """
+
+     def __init__(self,
+                  layout='openpose',
+                  strategy='uniform',
+                  max_hop=1,
+                  dilation=1):
+         self.max_hop = max_hop
+         self.dilation = dilation
+
+         self.get_edge(layout)
+         self.hop_dis = get_hop_distance(
+             self.num_node, self.edge, max_hop=max_hop)
+         self.get_adjacency(strategy)
+
+     def __str__(self):
+         return str(self.A)
+
+     def get_edge(self, layout):
+         if layout == 'openpose':
+             self.num_node = 18
+             self_link = [(i, i) for i in range(self.num_node)]
+             neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12, 11),
+                              (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1),
+                              (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)]
+             self.edge = self_link + neighbor_link
+             self.center = 1
+         elif layout == 'ntu-rgb+d':
+             self.num_node = 25
+             self_link = [(i, i) for i in range(self.num_node)]
+             neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21),
+                               (6, 5), (7, 6), (8, 7), (9, 21), (10, 9),
+                               (11, 10), (12, 11), (13, 1), (14, 13), (15, 14),
+                               (16, 15), (17, 1), (18, 17), (19, 18), (20, 19),
+                               (22, 23), (23, 8), (24, 25), (25, 12)]
+             neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
+             self.edge = self_link + neighbor_link
+             self.center = 21 - 1
+         elif layout == 'ntu_edge':
+             self.num_node = 24
+             self_link = [(i, i) for i in range(self.num_node)]
+             neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6),
+                               (8, 7), (9, 2), (10, 9), (11, 10), (12, 11),
+                               (13, 1), (14, 13), (15, 14), (16, 15), (17, 1),
+                               (18, 17), (19, 18), (20, 19), (21, 22), (22, 8),
+                               (23, 24), (24, 12)]
+             neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
+             self.edge = self_link + neighbor_link
+             self.center = 2
+         elif layout == 'mediapipe':
+             self.num_node = 25
+             self_link = [(i, i) for i in range(self.num_node)]
+             neighbor_1base = [(20, 18), (18, 16), (20, 16), (16, 22), (16, 14), (14, 12),
+                               (19, 17), (17, 15), (19, 15), (15, 21), (15, 13), (13, 11),
+                               (12, 11), (12, 24), (24, 23), (23, 11),
+                               (10, 9),
+                               (0, 4), (4, 5), (5, 6), (6, 8),
+                               (0, 1), (1, 2), (2, 3), (3, 7)]
+             neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
+             self.edge = self_link + neighbor_link
+             self.center = 10
+
+         elif layout == "mediapipe_two_hand":
+             self.num_node = 46
+             self_link = [(i, i) for i in range(self.num_node)]
+             neighbor_1base = [(BodyIdentifier.WRIST_left.value, BodyIdentifier.THUMB_CMC_left.value),
+                               (BodyIdentifier.THUMB_CMC_left.value, BodyIdentifier.THUMB_MCP_left.value),
+                               (BodyIdentifier.THUMB_MCP_left.value, BodyIdentifier.THUMB_IP_left.value),
+                               (BodyIdentifier.THUMB_IP_left.value, BodyIdentifier.THUMB_TIP_left.value),
+
+                               (BodyIdentifier.WRIST_left.value, BodyIdentifier.INDEX_FINGER_MCP_left.value),
+                               (BodyIdentifier.INDEX_FINGER_MCP_left.value, BodyIdentifier.INDEX_FINGER_PIP_left.value),
+                               (BodyIdentifier.INDEX_FINGER_PIP_left.value, BodyIdentifier.INDEX_FINGER_DIP_left.value),
+                               (BodyIdentifier.INDEX_FINGER_DIP_left.value, BodyIdentifier.INDEX_FINGER_TIP_left.value),
+
+                               (BodyIdentifier.INDEX_FINGER_MCP_left.value, BodyIdentifier.MIDDLE_FINGER_MCP_left.value),
+                               (BodyIdentifier.MIDDLE_FINGER_MCP_left.value, BodyIdentifier.MIDDLE_FINGER_PIP_left.value),
+                               (BodyIdentifier.MIDDLE_FINGER_PIP_left.value, BodyIdentifier.MIDDLE_FINGER_DIP_left.value),
+                               (BodyIdentifier.MIDDLE_FINGER_DIP_left.value, BodyIdentifier.MIDDLE_FINGER_TIP_left.value),
+
+                               (BodyIdentifier.MIDDLE_FINGER_MCP_left.value, BodyIdentifier.RING_FINGER_MCP_left.value),
+                               (BodyIdentifier.RING_FINGER_MCP_left.value, BodyIdentifier.RING_FINGER_PIP_left.value),
+                               (BodyIdentifier.RING_FINGER_PIP_left.value, BodyIdentifier.RING_FINGER_DIP_left.value),
+                               (BodyIdentifier.RING_FINGER_DIP_left.value, BodyIdentifier.RING_FINGER_TIP_left.value),
+
+                               (BodyIdentifier.WRIST_left.value, BodyIdentifier.PINKY_MCP_left.value),
+                               (BodyIdentifier.PINKY_MCP_left.value, BodyIdentifier.PINKY_PIP_left.value),
+                               (BodyIdentifier.PINKY_PIP_left.value, BodyIdentifier.PINKY_DIP_left.value),
+                               (BodyIdentifier.PINKY_DIP_left.value, BodyIdentifier.PINKY_TIP_left.value),
+
+                               # RIGHT HAND
+                               (BodyIdentifier.WRIST_right.value, BodyIdentifier.THUMB_CMC_right.value),
+                               (BodyIdentifier.THUMB_CMC_right.value, BodyIdentifier.THUMB_MCP_right.value),
+                               (BodyIdentifier.THUMB_MCP_right.value, BodyIdentifier.THUMB_IP_right.value),
+                               (BodyIdentifier.THUMB_IP_right.value, BodyIdentifier.THUMB_TIP_right.value),
+
+                               (BodyIdentifier.WRIST_right.value, BodyIdentifier.INDEX_FINGER_MCP_right.value),
+                               (BodyIdentifier.INDEX_FINGER_MCP_right.value, BodyIdentifier.INDEX_FINGER_PIP_right.value),
+                               (BodyIdentifier.INDEX_FINGER_PIP_right.value, BodyIdentifier.INDEX_FINGER_DIP_right.value),
+                               (BodyIdentifier.INDEX_FINGER_DIP_right.value, BodyIdentifier.INDEX_FINGER_TIP_right.value),
+
+                               (BodyIdentifier.INDEX_FINGER_MCP_right.value, BodyIdentifier.MIDDLE_FINGER_MCP_right.value),
+                               (BodyIdentifier.MIDDLE_FINGER_MCP_right.value, BodyIdentifier.MIDDLE_FINGER_PIP_right.value),
+                               (BodyIdentifier.MIDDLE_FINGER_PIP_right.value, BodyIdentifier.MIDDLE_FINGER_DIP_right.value),
+                               (BodyIdentifier.MIDDLE_FINGER_DIP_right.value, BodyIdentifier.MIDDLE_FINGER_TIP_right.value),
+
+                               (BodyIdentifier.MIDDLE_FINGER_MCP_right.value, BodyIdentifier.RING_FINGER_MCP_right.value),
+                               (BodyIdentifier.RING_FINGER_MCP_right.value, BodyIdentifier.RING_FINGER_PIP_right.value),
+                               (BodyIdentifier.RING_FINGER_PIP_right.value, BodyIdentifier.RING_FINGER_DIP_right.value),
+                               (BodyIdentifier.RING_FINGER_DIP_right.value, BodyIdentifier.RING_FINGER_TIP_right.value),
+
+                               (BodyIdentifier.WRIST_right.value, BodyIdentifier.PINKY_MCP_right.value),
+                               (BodyIdentifier.PINKY_MCP_right.value, BodyIdentifier.PINKY_PIP_right.value),
+                               (BodyIdentifier.PINKY_PIP_right.value, BodyIdentifier.PINKY_DIP_right.value),
+                               (BodyIdentifier.PINKY_DIP_right.value, BodyIdentifier.PINKY_TIP_right.value),
+
+                               # 2 HANDS + SHOULDERS + ELBOWS
+                               (BodyIdentifier.RIGHT_SHOULDER.value, BodyIdentifier.RIGHT_ELBOW.value),
+                               (BodyIdentifier.RIGHT_ELBOW.value, BodyIdentifier.WRIST_right.value),
+
+                               (BodyIdentifier.RIGHT_SHOULDER.value, BodyIdentifier.LEFT_SHOULDER.value),
+
+                               (BodyIdentifier.LEFT_SHOULDER.value, BodyIdentifier.LEFT_ELBOW.value),
+                               (BodyIdentifier.LEFT_ELBOW.value, BodyIdentifier.WRIST_left.value)]
+
+             neighbor_link = [(i, j) for (i, j) in neighbor_1base]
+             self.edge = self_link + neighbor_link
+             self.center = BodyIdentifier.RIGHT_SHOULDER.value
+         # elif layout == 'custom settings':
+         #     pass
+         else:
+             raise ValueError("This layout does not exist.")
+
+     def get_adjacency(self, strategy):
+         valid_hop = range(0, self.max_hop + 1, self.dilation)
+         adjacency = np.zeros((self.num_node, self.num_node))
+         for hop in valid_hop:
+             adjacency[self.hop_dis == hop] = 1
+         normalize_adjacency = normalize_digraph(adjacency)
+
+         if strategy == 'uniform':
+             A = np.zeros((1, self.num_node, self.num_node))
+             A[0] = normalize_adjacency
+             self.A = A
+         elif strategy == 'distance':
+             A = np.zeros((len(valid_hop), self.num_node, self.num_node))
+             for i, hop in enumerate(valid_hop):
+                 A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop]
+             self.A = A
+         elif strategy == 'spatial':
+             A = []
+             for hop in valid_hop:
+                 a_root = np.zeros((self.num_node, self.num_node))
+                 a_close = np.zeros((self.num_node, self.num_node))
+                 a_further = np.zeros((self.num_node, self.num_node))
+                 for i in range(self.num_node):
+                     for j in range(self.num_node):
+                         if self.hop_dis[j, i] == hop:
+                             if self.hop_dis[j, self.center] == self.hop_dis[i, self.center]:
+                                 a_root[j, i] = normalize_adjacency[j, i]
+                             elif self.hop_dis[j, self.center] > self.hop_dis[i, self.center]:
+                                 a_close[j, i] = normalize_adjacency[j, i]
+                             else:
+                                 a_further[j, i] = normalize_adjacency[j, i]
+                 if hop == 0:
+                     A.append(a_root)
+                 else:
+                     A.append(a_root + a_close)
+                     A.append(a_further)
+             A = np.stack(A)
+             self.A = A
+         else:
+             raise ValueError("This strategy does not exist.")
+
+
+ def get_hop_distance(num_node, edge, max_hop=1):
+     A = np.zeros((num_node, num_node))
+     for i, j in edge:
+         A[j, i] = 1
+         A[i, j] = 1
+
+     # compute hop steps
+     hop_dis = np.zeros((num_node, num_node)) + np.inf
+     transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
+     arrive_mat = (np.stack(transfer_mat) > 0)
+     for d in range(max_hop, -1, -1):
+         hop_dis[arrive_mat[d]] = d
+     return hop_dis
+
+
+ def normalize_digraph(A):
+     Dl = np.sum(A, 0)
+     num_node = A.shape[0]
+     Dn = np.zeros((num_node, num_node))
+     for i in range(num_node):
+         if Dl[i] > 0:
+             Dn[i, i] = Dl[i]**(-1)
+     AD = np.dot(A, Dn)
+     return AD
+
+
+ def normalize_undigraph(A):
+     Dl = np.sum(A, 0)
+     num_node = A.shape[0]
+     Dn = np.zeros((num_node, num_node))
+     for i in range(num_node):
+         if Dl[i] > 0:
+             Dn[i, i] = Dl[i]**(-0.5)
+     DAD = np.dot(np.dot(Dn, A), Dn)
+     return DAD
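`Graph` precomputes the normalised adjacency tensor `A` that the GCN layers consume; under the `spatial` strategy with `max_hop=1` it stacks three subsets (root, centripetal, centrifugal), which is why `unit_gcn` indexes `A[i]` over `num_subset=3`. A quick shape check for the 25-joint Mediapipe layout used in this project:

```python
from graph import Graph

g = Graph(layout="mediapipe", strategy="spatial", max_hop=1)
print(g.A.shape)        # (3, 25, 25): root / centripetal / centrifugal partitions
print(g.A.sum(axis=0))  # summing the subsets recovers the normalised 1-hop adjacency
```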