from r_basicsr.utils import get_root_logger
from r_basicsr.utils.registry import MODEL_REGISTRY
from .video_base_model import VideoBaseModel


@MODEL_REGISTRY.register()
class EDVRModel(VideoBaseModel):
    """EDVR Model.

    Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.  # noqa: E501
    """

    def __init__(self, opt):
        super(EDVRModel, self).__init__(opt)
        if self.is_train:
            # number of initial iterations during which only the TSA fusion module is trained
            self.train_tsa_iter = opt['train'].get('tsa_iter')
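
    # A minimal sketch of the `opt['train']` options this model reads. The keys
    # (`tsa_iter`, `dcn_lr_mul`, `optim_g`) all appear in this file; the concrete
    # values below are illustrative assumptions, not from this repo:
    #
    #   opt['train'] = {
    #       'tsa_iter': 50000,    # only the TSA fusion module trains for the first 50k iters
    #       'dcn_lr_mul': 0.25,   # DCN layers use 0.25x the base learning rate
    #       'optim_g': {'type': 'Adam', 'lr': 4e-4},
    #   }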

    def setup_optimizers(self):
        train_opt = self.opt['train']
        dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
        logger = get_root_logger()
        logger.info(f'Multiply the learning rate for dcn by {dcn_lr_mul}.')
        if dcn_lr_mul == 1:
            optim_params = self.net_g.parameters()
        else:  # separate dcn params and normal params for different lr
            normal_params = []
            dcn_params = []
            for name, param in self.net_g.named_parameters():
                if 'dcn' in name:
                    dcn_params.append(param)
                else:
                    normal_params.append(param)
            optim_params = [
                {  # add normal params first
                    'params': normal_params,
                    'lr': train_opt['optim_g']['lr']
                },
                {
                    'params': dcn_params,
                    'lr': train_opt['optim_g']['lr'] * dcn_lr_mul
                },
            ]

        optim_type = train_opt['optim_g'].pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
        self.optimizers.append(self.optimizer_g)
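
    # For reference, a self-contained sketch of the same per-group learning-rate
    # pattern in plain PyTorch; the two-layer net is hypothetical, and only the
    # parameter-group mechanics mirror setup_optimizers():
    #
    #   import torch
    #   net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Conv2d(8, 3, 3))
    #   base_lr, dcn_lr_mul = 4e-4, 0.25
    #   optimizer = torch.optim.Adam([
    #       {'params': net[0].parameters(), 'lr': base_lr},               # "normal" params
    #       {'params': net[1].parameters(), 'lr': base_lr * dcn_lr_mul},  # "dcn" params
    #   ])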

    def optimize_parameters(self, current_iter):
        # warm-up: for the first `tsa_iter` iterations, train only the TSA fusion module
        if self.train_tsa_iter:
            if current_iter == 1:
                logger = get_root_logger()
                logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.')
                for name, param in self.net_g.named_parameters():
                    if 'fusion' not in name:
                        param.requires_grad = False
            elif current_iter == self.train_tsa_iter:
                # warm-up finished: unfreeze and train the whole network
                logger = get_root_logger()
                logger.warning('Train all the parameters.')
                for param in self.net_g.parameters():
                    param.requires_grad = True

        super(EDVRModel, self).optimize_parameters(current_iter)
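
    # A standalone sketch of the TSA warm-up schedule implemented above, assuming
    # a hypothetical training loop and a net whose TSA layers carry 'fusion' in
    # their parameter names:
    #
    #   for current_iter in range(1, total_iters + 1):
    #       if current_iter == 1:            # freeze everything except TSA fusion
    #           for name, param in net.named_parameters():
    #               param.requires_grad = 'fusion' in name
    #       elif current_iter == tsa_iter:   # from here on, train the full network
    #           for param in net.parameters():
    #               param.requires_grad = True
    #       ...  # forward pass, loss, backward, optimizer step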