Spaces:
Runtime error
Runtime error
| import random | |
| import pickle | |
| import logging | |
| import torch | |
| import cv2 | |
| import os | |
| from torch.utils.data.dataset import Dataset | |
| import numpy as np | |
| from skimage.feature import canny | |
| from .util.STTN_mask import create_random_shape_with_random_motion | |
| from cvbase import read_flow, flow2rgb | |
| from .util.flow_utils import region_fill as rf | |
| import imageio | |
| logger = logging.getLogger('base') | |
class VideoBasedDataset(Dataset):
    """Training dataset for video flow completion.

    Each item pairs an optical-flow field (randomly forward or backward in
    time) with a random free-form motion mask, a diffusion-filled version of
    the masked flow, the two frames the flow connects, and the flow's Canny
    edge map — all converted to float tensors by ``to_tensor``.
    """

    def __init__(self, opt, dataInfo):
        """
        Args:
            opt: options dict; keys used: 'mode', 'use_edges', 'norm',
                and optionally 'ternary'.
            dataInfo: dataset description dict; keys used: 'flow'
                (flow_height/flow_width), 'flow_path', 'frame_path',
                'name2len' (path to a pickle), 'edge' (canny parameters).
        """
        self.opt = opt
        self.mode = opt['mode']
        self.dataInfo = dataInfo
        self.flow_height = dataInfo['flow']['flow_height']
        self.flow_width = dataInfo['flow']['flow_width']
        self.data_path = dataInfo['flow_path']
        self.frame_path = dataInfo['frame_path']
        # One video per directory under data_path.
        self.train_list = os.listdir(self.data_path)
        self.name2length = self.dataInfo['name2len']
        self.require_edge = opt['use_edges']
        # Canny edge-detector parameters used by load_edge().
        self.sigma = dataInfo['edge']['sigma']
        self.low_threshold = dataInfo['edge']['low_threshold']
        self.high_threshold = dataInfo['edge']['high_threshold']
        # Pickled dict: video name -> number of frames.
        # NOTE(review): pickle.load assumes this file is trusted.
        with open(self.name2length, 'rb') as f:
            self.name2len = pickle.load(f)
        self.norm = opt['norm']
        self.ternary_loss = opt.get('ternary', 0)

    def __len__(self):
        return len(self.train_list)

    def __getitem__(self, idx):
        try:
            item = self.load_item(idx)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate. Fall back to the first sample so training can
            # continue past a corrupted item.
            print('Loading error: ' + self.train_list[idx])
            item = self.load_item(0)
        return item

    def frameSample(self, flowLen):
        """Uniformly sample a pivot flow index in [0, flowLen - 1]."""
        return random.randint(0, flowLen - 1)

    def load_item(self, idx):
        """Load one training sample (flows, frames, mask, edges) as tensors."""
        video = self.train_list[idx]
        # Randomly train on forward or backward flow.
        if np.random.uniform(0, 1) > 0.5:
            direction = 'forward_flo'
        else:
            direction = 'backward_flo'
        flow_dir = os.path.join(self.data_path, video, direction)
        frame_dir = os.path.join(self.frame_path, video)
        # N frames yield N - 1 flow fields.
        flowLen = self.name2len[video] - 1
        pivot = self.frameSample(flowLen)
        # Generate one random free-form mask with random motion.
        candidateMasks = create_random_shape_with_random_motion(1, 0.9, 1.1, 1,
                                                                10)
        # Read flow + mask, then fill the masked flow region by diffusion.
        flow = read_flow(os.path.join(flow_dir, '{:05d}.flo'.format(pivot)))
        mask = self.read_mask(candidateMasks[0], self.flow_height, self.flow_width)
        flow = self.flow_tf(flow, self.flow_height, self.flow_width)
        diffused_flow = self.diffusion_fill(flow, mask)
        current_frame, shift_frame = self.read_frames(frame_dir, pivot, direction,
                                                      self.flow_width, self.flow_height)
        edge = self.load_edge(flow)
        inputs = {'flows': flow, 'diffused_flows': diffused_flow, 'current_frame': current_frame,
                  'shift_frame': shift_frame, 'edges': edge, 'masks': mask}
        return self.to_tensor(inputs)

    def read_frames(self, frame_dir, index, direction, width, height):
        """Read the (source, target) frame pair for flow `index`, resized and
        scaled to [0, 1]. For backward flow the pair order is swapped."""
        if direction == 'forward_flo':
            current_frame = os.path.join(frame_dir, '{:05d}.jpg'.format(index))
            shift_frame = os.path.join(frame_dir, '{:05d}.jpg'.format(index + 1))
        else:
            current_frame = os.path.join(frame_dir, '{:05d}.jpg'.format(index + 1))
            shift_frame = os.path.join(frame_dir, '{:05d}.jpg'.format(index))
        current_frame = imageio.imread(current_frame)
        shift_frame = imageio.imread(shift_frame)
        # BUG FIX: cv2.resize's third positional argument is `dst`, not
        # `interpolation`; the flag must be passed by keyword.
        current_frame = cv2.resize(current_frame, (width, height),
                                   interpolation=cv2.INTER_LINEAR)
        shift_frame = cv2.resize(shift_frame, (width, height),
                                 interpolation=cv2.INTER_LINEAR)
        current_frame = current_frame / 255.
        shift_frame = shift_frame / 255.
        return current_frame, shift_frame

    def diffusion_fill(self, flow, mask):
        """Fill the masked region of a 2-channel flow field by region diffusion."""
        flow_filled = np.zeros(flow.shape)
        flow_filled[:, :, 0] = rf.regionfill(flow[:, :, 0] * (1 - mask), mask)
        flow_filled[:, :, 1] = rf.regionfill(flow[:, :, 1] * (1 - mask), mask)
        return flow_filled

    def flow_tf(self, flow, height, width):
        """Resize a flow field and rescale its u/v magnitudes to match."""
        flow_shape = flow.shape
        # BUG FIX: interpolation passed by keyword (third positional is `dst`).
        flow_resized = cv2.resize(flow, (width, height),
                                  interpolation=cv2.INTER_LINEAR)
        # Flow values are displacements in pixels; scale them with the resize.
        flow_resized[:, :, 0] *= (float(width) / float(flow_shape[1]))
        flow_resized[:, :, 1] *= (float(height) / float(flow_shape[0]))
        return flow_resized

    def read_mask(self, mask, height, width):
        """Binarize a mask image and resize it to (height, width), nearest-neighbour
        so the mask stays strictly {0, 1}."""
        mask = np.array(mask)
        mask = mask / 255.
        raw_mask = (mask > 0.5).astype(np.uint8)
        raw_mask = cv2.resize(raw_mask, dsize=(width, height), interpolation=cv2.INTER_NEAREST)
        return raw_mask

    def load_edge(self, flow):
        """Return the Canny edge map of the flow visualization as a float array.

        Removed dead code: the original computed and normalized a flow-magnitude
        image (`gray_flow`) that was never used — and could divide by zero on an
        all-zero flow.
        """
        flow_rgb = flow2rgb(flow)
        flow_gray = cv2.cvtColor(flow_rgb, cv2.COLOR_RGB2GRAY)
        # BUG FIX: `np.float` was removed in NumPy 1.24; use the builtin float.
        return canny(flow_gray, sigma=self.sigma, mask=None, low_threshold=self.low_threshold,
                     high_threshold=self.high_threshold).astype(float)

    def to_tensor(self, data_list):
        """Convert a dict of arrays / lists of arrays to float tensors.

        Args:
            data_list: dict mapping key -> ndarray, list of ndarrays, or None.
        Returns:
            The same dict with None/empty-list entries removed; HxW arrays gain
            a channel axis and become CxHxW tensors, lists of arrays are stacked
            along a new time axis and become CxTxHxW tensors.
        """
        for key in list(data_list.keys()):
            item = data_list[key]
            # BUG FIX: the original tested `item == []`, which is an ambiguous
            # elementwise comparison when item is an ndarray. Test emptiness
            # only for actual lists.
            if item is None or (isinstance(item, list) and len(item) == 0):
                data_list.pop(key)
                continue
            if not isinstance(item, list):
                if len(item.shape) == 2:
                    item = item[:, :, np.newaxis]  # H,W -> H,W,1
                item = torch.from_numpy(np.transpose(item, (2, 0, 1))).float()  # -> C,H,W
            else:
                item = np.stack(item, axis=0)  # T,H,W[,C]
                if len(item.shape) == 3:
                    item = item[:, :, :, np.newaxis]  # T,H,W -> T,H,W,1
                item = torch.from_numpy(np.transpose(item, (3, 0, 1, 2))).float()  # -> C,T,H,W
            data_list[key] = item
        return data_list