|
import torch |
|
from torch.utils.data import Dataset |
|
from pathlib import Path |
|
import numpy as np |
|
import random |
|
from einops import rearrange |
|
from augumentation import Rotate |
|
from torch.utils.data import random_split |
|
|
|
|
|
class FeederINCLUDE(Dataset):
    """Feeder for skeleton-based action recognition.

    Loads a pre-processed skeleton array and its labels from ``.npy`` files
    and serves ``(sample, label)`` pairs, optionally augmenting a sample with
    a user-supplied transform.

    Arguments:
        data_path: path to the ``.npy`` data file. The array must be 5-D and
            is unpacked as (N, C, T, V, M): samples, features, frames,
            joints, persons.
        label_path: path to the ``.npy`` label file, one entry per sample.
        transform: optional callable ``transform(data, label) -> (data, label)``
            applied to a sample in ``__getitem__``.
        transform_prob: probability (keyword-only, default 0.5) that
            ``transform`` is applied to a fetched sample.
    """

    def __init__(self, data_path: Path, label_path: Path, transform=None,
                 *, transform_prob: float = 0.5):
        super().__init__()
        if not 0.0 <= transform_prob <= 1.0:
            raise ValueError(
                f"transform_prob must be in [0, 1], got {transform_prob}"
            )
        self.data_path = data_path
        self.label_path = label_path
        self.transform = transform
        self.transform_prob = transform_prob
        self.load_data()

    def load_data(self):
        """Load data and labels from disk and record the data dimensions.

        Sets ``self.label``, ``self.data`` and the size attributes
        ``self.N, self.C, self.T, self.V, self.M``.

        Raises:
            ValueError: if the data array is not 5-D, or the number of labels
                differs from the number of samples ``N``.
        """
        self.label = np.load(self.label_path)
        self.data = np.load(self.data_path)

        if self.data.ndim != 5:
            raise ValueError(
                f"expected 5-D data (N, C, T, V, M), got shape {self.data.shape}"
            )
        self.N, self.C, self.T, self.V, self.M = self.data.shape

        # __len__ counts labels while __getitem__ indexes data; they must agree.
        if self.label.ndim == 0 or len(self.label) != self.N:
            raise ValueError(
                f"label count does not match sample count: labels shape "
                f"{self.label.shape}, data has N={self.N}"
            )

    def __getitem__(self, index):
        """Return the ``index``-th sample and its label.

        The sample is one slice of the (N, C, T, V, M) data array, so it has
        shape (C, T, V, M):
            C : number of features
            T : number of frames
            V : number of joints (graph nodes)
            M : number of persons

        Returns:
            (sample, label): ``sample`` is a float32 ``torch.Tensor`` copied
            from the stored array (``torch.tensor`` always copies, so a
            transform cannot modify ``self.data``); ``label`` is the matching
            entry of the label array. With probability ``transform_prob`` both
            are passed through ``transform`` first.
        """
        sample = torch.tensor(self.data[index]).float()
        label = self.label[index]

        # Always draw, so the RNG stream does not depend on whether a transform
        # is set. ``p > 1 - prob`` equals the historical ``p > 0.5`` at the
        # default probability of 0.5.
        p = random.random()
        if self.transform is not None and p > 1.0 - self.transform_prob:
            sample, label = self.transform(sample, label)

        return sample, label

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.label)
|
|
|
if __name__ == '__main__':
    # Smoke check: load the preprocessed arrays directly and through the
    # dataset wrapper, then print their shapes and the dataset length.
    data_file = "wsl100_train_data_preprocess.npy"
    label_file = "wsl100_train_label_preprocess.npy"

    raw_data, raw_label = np.load(data_file), np.load(label_file)
    print(raw_data.shape, raw_label.shape)

    # label_path must point at the label file, not the data file.
    data = FeederINCLUDE(data_path=data_file, label_path=label_file,
                         transform=None)

    print(data.N, data.C, data.T, data.V, data.M)
    print(data.data.shape)
    print(len(data))
|
|