import random
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset, random_split
from einops import rearrange
from augumentation import Rotate


# Dataset class that reads .npy data/label files and pairs each sample with its label.
class FeederINCLUDE(Dataset):
""" Feeder for skeleton-based action recognition
Arguments:
data_path: the path to '.npy' data, the shape of data should be (N, C, T, V, M)
label_path: the path to label
window_size: The length of the output sequence
"""

    def __init__(self, data_path: Path, label_path: Path, transform=None):
        super().__init__()
        self.data_path = data_path
        self.label_path = label_path
        self.transform = transform
        self.load_data()

    def load_data(self):
        # Load labels and data from the .npy files; data shape is (N, C, T, V, M).
        self.label = np.load(self.label_path)
        self.data = np.load(self.data_path)
        self.N, self.C, self.T, self.V, self.M = self.data.shape

    def __getitem__(self, index):
        """
        Returns one sample of shape (C, T, V, M) together with its label:
            C : number of channels (features per joint)
            T : number of frames
            V : number of joints (graph nodes)
            M : number of people (could be dropped)
        """
        data_numpy = torch.tensor(self.data[index]).float()
        # Optional reshaping kept from the original (currently disabled):
        # data_numpy = data_numpy[:, :, :2]
        # data_numpy = rearrange(data_numpy, 't v c 1 -> c t v 1')
        label = self.label[index]
        # Apply the augmentation transform to roughly half of the samples.
        p = random.random()
        if self.transform and p > 0.5:
            data_numpy, label = self.transform(data_numpy, label)
        return data_numpy, label

    def __len__(self):
        return len(self.label)
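

# A minimal sketch of a transform compatible with FeederINCLUDE: __getitem__ calls
# self.transform(data_numpy, label) and expects a (data, label) pair back, so any callable
# with that signature works. RandomJitter below is a hypothetical example for illustration;
# it is not part of the augumentation module.
class RandomJitter:
    def __init__(self, sigma=0.01):
        self.sigma = sigma

    def __call__(self, data, label):
        # Add small Gaussian noise to the joint coordinates; the label is unchanged.
        return data + self.sigma * torch.randn_like(data), label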


if __name__ == '__main__':
    file, label = np.load("wsl100_train_data_preprocess.npy"), np.load("wsl100_train_label_preprocess.npy")
    print(file.shape, label.shape)
    data = FeederINCLUDE(data_path="wsl100_train_data_preprocess.npy",
                         label_path="wsl100_train_label_preprocess.npy",
                         transform=None)
    # test_dataset = FeederINCLUDE(data_path="data/vsl100_test_data_preprocess.npy", label_path="data/vsl100_test_label_preprocess.npy")
    # valid_dataset = FeederINCLUDE(data_path="data/vsl100_valid_data_preprocess.npy", label_path="data/vsl100_valid_label_preprocess.npy")
    # data = FeederINCLUDE(data_path="data/vsl100_test_data_preprocess.npy",
    #                      label_path="data/vsl100_test_label_preprocess.npy", transform=None)
    print(data.N, data.C, data.T, data.V, data.M)
    print(data.data.shape)
    print(len(data))
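
    # Sketch: wrap the dataset in a DataLoader and carve out a validation split.
    # The 80/20 ratio and batch_size=32 are illustrative choices, not values taken from this file.
    from torch.utils.data import DataLoader
    n_train = int(0.8 * len(data))
    train_set, valid_set = random_split(data, [n_train, len(data) - n_train])
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    x, y = next(iter(train_loader))
    print(x.shape, y.shape)  # expected: (32, C, T, V, M) and (32,)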