Spaces:
Runtime error
Runtime error
| import torch | |
# Module-level handles to the per-device outputs and model replicas from the
# most recent call of data_parallel_workaround. Storing them here keeps those
# objects referenced after the function returns.
# NOTE(review): this looks like a deliberate keep-alive workaround; the exact
# motivation is not visible here — confirm before removing.
_output_ref = None
_replicas_ref = None


def data_parallel_workaround(model, *input):
    """Run ``model`` data-parallel over all visible CUDA devices.

    Builds data parallelism by hand from the ``torch.nn.parallel`` primitives:
    replicate the model onto every GPU, scatter the positional inputs across
    those GPUs, apply each replica to its shard in parallel, then gather the
    per-device outputs onto the first GPU.

    Args:
        model: Module to run. Presumably already placed on CUDA device 0 —
            TODO confirm against callers.
        *input: Positional arguments for ``model``; ``scatter`` splits tensor
            arguments across the devices (along dim 0 by default).

    Returns:
        The gathered model output, placed on CUDA device id 0.

    Raises:
        IndexError: if ``torch.cuda.device_count()`` is 0 (no GPU visible).

    Side effects:
        Overwrites the module globals ``_output_ref`` and ``_replicas_ref``
        with this call's per-device outputs and replicas.
    """
    global _output_ref, _replicas_ref

    gpus = list(range(torch.cuda.device_count()))
    target = gpus[0]  # IndexError here when no CUDA device is available
    copies = torch.nn.parallel.replicate(model, gpus)

    # input is a tuple of arguments: (num_args, batch, ...)
    shards = torch.nn.parallel.scatter(input, gpus)
    # shards: one argument tuple per device -> (num_gpus, num_args, batch/num_gpus, ...)

    # scatter can yield fewer shards than devices (e.g. a batch smaller than
    # the GPU count), so drop the replicas that would have nothing to run on.
    copies = copies[:len(shards)]
    per_device = torch.nn.parallel.parallel_apply(copies, shards)
    merged = torch.nn.parallel.gather(per_device, target)

    _output_ref = per_device
    _replicas_ref = copies
    return merged
class ValueWindow:
    """Sliding window over the most recent values, with sum/count/average.

    Only the last ``window_size`` appended values are retained; older ones are
    discarded as new values arrive.
    """

    def __init__(self, window_size=100):
        """Create an empty window.

        Args:
            window_size: Maximum number of values retained; expected to be >= 1.
        """
        self._window_size = window_size
        self._values = []

    def append(self, x):
        """Add ``x`` and drop the oldest values beyond ``window_size``."""
        self._values.append(x)
        # Keep only the last ``window_size`` items. Deleting the prefix slice is
        # also correct for window_size == 1, where a ``[-(size - 1):]`` slice
        # degenerates to ``[-0:]`` (the whole list) and never trims anything.
        del self._values[:-self._window_size]

    def sum(self):
        """Return the sum of the values currently in the window."""
        return sum(self._values)

    def count(self):
        """Return how many values are currently in the window."""
        return len(self._values)

    def average(self):
        """Return the mean of the values in the window, or 0 when it is empty."""
        # sum and count are methods and must be called; max(1, ...) avoids
        # dividing by zero on an empty window.
        return self.sum() / max(1, self.count())

    def reset(self):
        """Discard all stored values."""
        self._values = []