Spaces:
Runtime error
Runtime error
| from mmcv.runner import OptimizerHook, HOOKS | |
# apex is an optional dependency: it is only used when the hook below is
# configured with use_fp16=True, so a missing install is reported, not fatal.
try:
    import apex
except ImportError:
    # Catch only import failures; a bare ``except`` would also swallow
    # KeyboardInterrupt / SystemExit raised while importing.
    print('apex is not installed')
class DistOptimizerHook(OptimizerHook):
    """Optimizer hook for distributed training.

    Extends ``OptimizerHook`` with gradient accumulation (the optimizer
    steps once every ``update_interval`` iterations) and optional
    mixed-precision loss scaling through ``apex.amp``.

    Args:
        update_interval: Number of training iterations whose gradients are
            accumulated before each optimizer step. Must be at least 1.
        grad_clip: Gradient-clipping settings passed on to ``OptimizerHook``;
            clipping is skipped when this is ``None``.
        coalesce: Stored on the hook; not read by the methods defined here.
        bucket_size_mb: Stored on the hook; not read by the methods defined
            here.
        use_fp16: When true, the backward pass goes through
            ``apex.amp.scale_loss`` (requires apex to be installed).

    Raises:
        ValueError: If ``update_interval`` is smaller than 1.
    """

    def __init__(self, update_interval=1, grad_clip=None, coalesce=True,
                 bucket_size_mb=-1, use_fp16=False):
        # Fail at construction time: an interval below 1 would otherwise
        # only surface as a division/modulo error in the middle of training.
        if update_interval < 1:
            raise ValueError(
                f'update_interval must be >= 1, got {update_interval!r}')
        # Let the parent initialise its own state (it stores ``grad_clip``).
        super().__init__(grad_clip=grad_clip)
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb
        self.update_interval = update_interval
        self.use_fp16 = use_fp16

    def before_run(self, runner):
        """Clear any stale gradients before the first iteration."""
        runner.optimizer.zero_grad()

    def after_train_iter(self, runner):
        """Backpropagate the loss and step the optimizer on interval ends.

        The loss in ``runner.outputs['loss']`` is divided in place by
        ``update_interval`` so that the gradients summed over one interval
        correspond to the average loss of that interval.

        Raises:
            ImportError: If ``use_fp16`` is set but apex cannot be imported.
        """
        runner.outputs['loss'] /= self.update_interval
        if self.use_fp16:
            # Import locally so a missing apex yields a clear error instead
            # of a NameError on the module-level name.
            try:
                from apex import amp
            except ImportError as err:
                raise ImportError(
                    'use_fp16=True requires NVIDIA apex, which is not '
                    'installed') from err
            with amp.scale_loss(runner.outputs['loss'],
                                runner.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            runner.outputs['loss'].backward()
        # Only update the weights once per accumulation window.
        if self.every_n_iters(runner, self.update_interval):
            if self.grad_clip is not None:
                self.clip_grads(runner.model.parameters())
            runner.optimizer.step()
            runner.optimizer.zero_grad()