Spaces:
Runtime error
Runtime error
File size: 1,020 Bytes
f1dd031 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import numpy as np
from torch.utils.data import DistributedSampler as _DistributedSampler
class DistributedVideoSampler(_DistributedSampler):
    """Distributed sampler that never splits a video across replicas.

    Dataset entries whose ``frame_id`` is 0 are treated as the first frame
    of a video. The videos are divided into ``num_replicas`` contiguous
    groups (via ``np.array_split``). Each rank iterates, in order, over the
    contiguous index range that starts at the first frame of its first
    video and ends where the next rank's range begins.

    Shuffling is not supported: sequential video testing needs frames in
    their original order.

    Assumes ``dataset.data_infos`` is an indexable sequence of dicts with a
    ``"frame_id"`` key. Also assumes all frames of a video are stored
    contiguously right after its ``frame_id == 0`` entry (TODO confirm
    against the dataset implementation).

    Args:
        dataset: Dataset exposing ``__len__`` and ``data_infos``.
        num_replicas: Number of processes. When ``None``, the parent
            ``DistributedSampler`` resolves it.
        rank: Rank of the current process. When ``None``, the parent
            ``DistributedSampler`` resolves it.
        shuffle: Must be ``False``.

    Raises:
        AssertionError: If ``shuffle`` is true (when assertions are enabled).
        ValueError: If the dataset contains fewer videos than replicas, so
            at least one rank would receive no video.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle
        assert not self.shuffle, "Specific for video sequential testing."
        # Each rank iterates its own slice without padding. The full dataset
        # length serves as the end boundary of the last rank's range.
        self.num_samples = len(dataset)

        first_frame_indices = [
            i
            for i, img_info in enumerate(self.dataset.data_infos)
            if img_info["frame_id"] == 0
        ]

        # Use the parent-resolved ``self.num_replicas``: the raw argument may
        # be None. With fewer videos than replicas, ``np.array_split`` would
        # produce empty chunks and leave some ranks without a start index.
        if len(first_frame_indices) < self.num_replicas:
            raise ValueError(
                f"only {len(first_frame_indices)} videos loaded, "
                f"but {self.num_replicas} replicas were given."
            )

        chunks = np.array_split(first_frame_indices, self.num_replicas)
        # Each rank starts at the first frame of its first video and stops
        # where the next rank starts. Any indices before the first
        # ``frame_id == 0`` entry are not assigned to any rank.
        split_flags = [int(c[0]) for c in chunks]
        split_flags.append(self.num_samples)
        self.indices = [
            list(range(start, stop))
            for start, stop in zip(split_flags[:-1], split_flags[1:])
        ]

    def __iter__(self):
        """Yield this rank's dataset indices in their original order."""
        return iter(self.indices[self.rank])

    def __len__(self):
        """Return the number of indices this rank yields."""
        # The inherited ``__len__`` returns ``self.num_samples``, which is the
        # whole dataset length here. That overstates what ``__iter__`` yields.
        return len(self.indices[self.rank])
|