"""Run wandb sweep.""" # Copyright (C) 2020 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions # and limitations under the License. from argparse import ArgumentParser from pathlib import Path from typing import Union import pytorch_lightning as pl from omegaconf import DictConfig, ListConfig, OmegaConf from pytorch_lightning import seed_everything from pytorch_lightning.loggers import WandbLogger from utils import flatten_hpo_params import wandb from anomalib.config import get_configurable_parameters, update_input_size_config from anomalib.data import get_datamodule from anomalib.models import get_model from anomalib.utils.sweep import flatten_sweep_params, set_in_nested_config class WandbSweep: """wandb sweep. Args: config (DictConfig): Original model configuration. sweep_config (DictConfig): Sweep configuration. """ def __init__(self, config: Union[DictConfig, ListConfig], sweep_config: Union[DictConfig, ListConfig]) -> None: self.config = config self.sweep_config = sweep_config self.observation_budget = sweep_config.observation_budget if "observation_budget" in self.sweep_config.keys(): # this instance check is to silence mypy. if isinstance(self.sweep_config, DictConfig): self.sweep_config.pop("observation_budget") def run(self): """Run the sweep.""" flattened_hpo_params = flatten_hpo_params(self.sweep_config.parameters) self.sweep_config.parameters = flattened_hpo_params sweep_id = wandb.sweep( OmegaConf.to_object(self.sweep_config), project=f"{self.config.model.name}_{self.config.dataset.name}", ) wandb.agent(sweep_id, function=self.sweep, count=self.observation_budget) def sweep(self): """Method to load the model, update config and call fit. The metrics are logged to ```wandb``` dashboard.""" wandb_logger = WandbLogger(config=flatten_sweep_params(self.sweep_config), log_model=False) sweep_config = wandb_logger.experiment.config for param in sweep_config.keys(): set_in_nested_config(self.config, param.split("."), sweep_config[param]) config = update_input_size_config(self.config) model = get_model(config) datamodule = get_datamodule(config) # Disable saving checkpoints as all checkpoints from the sweep will get uploaded config.trainer.checkpoint_callback = False trainer = pl.Trainer(**config.trainer, logger=wandb_logger) trainer.fit(model, datamodule=datamodule) def get_args(): """Gets parameters from commandline.""" parser = ArgumentParser() parser.add_argument("--model", type=str, default="padim", help="Name of the algorithm to train/test") parser.add_argument("--model_config", type=Path, required=False, help="Path to a model config file") parser.add_argument("--sweep_config", type=Path, required=True, help="Path to sweep configuration") return parser.parse_args() if __name__ == "__main__": args = get_args() model_config = get_configurable_parameters(model_name=args.model, config_path=args.model_config) hpo_config = OmegaConf.load(args.sweep_config) if model_config.project.seed != 0: seed_everything(model_config.project.seed) sweep = WandbSweep(model_config, hpo_config) sweep.run()