Spaces:
Running
Running
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import torch | |
| from detectron2.config import CfgNode | |
| from detectron2.solver import build_lr_scheduler as build_d2_lr_scheduler | |
| from .lr_scheduler import WarmupPolyLR | |
def build_lr_scheduler(
    cfg: CfgNode, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:
    """
    Build a learning-rate scheduler for ``optimizer`` from ``cfg``.

    If ``cfg.SOLVER.LR_SCHEDULER_NAME`` is ``"WarmupPolyLR"``, a ``WarmupPolyLR``
    is constructed from the solver settings. These are MAX_ITER,
    WARMUP_FACTOR, WARMUP_ITERS, WARMUP_METHOD, POLY_LR_POWER and
    POLY_LR_CONSTANT_ENDING. Any other name is handed off unchanged to
    detectron2's own ``build_lr_scheduler``.
    """
    solver = cfg.SOLVER

    # Only the poly schedule is handled locally; everything else is
    # detectron2's responsibility.
    if solver.LR_SCHEDULER_NAME != "WarmupPolyLR":
        return build_d2_lr_scheduler(cfg, optimizer)

    return WarmupPolyLR(
        optimizer,
        solver.MAX_ITER,
        warmup_factor=solver.WARMUP_FACTOR,
        warmup_iters=solver.WARMUP_ITERS,
        warmup_method=solver.WARMUP_METHOD,
        power=solver.POLY_LR_POWER,
        constant_ending=solver.POLY_LR_CONSTANT_ENDING,
    )