"""Run wandb sweep."""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Union

import pytorch_lightning as pl
import wandb
from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import WandbLogger
from utils import flatten_hpo_params

from anomalib.config import get_configurable_parameters, update_input_size_config
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.utils.sweep import flatten_sweep_params, set_in_nested_config


class WandbSweep:
"""wandb sweep.
Args:
config (DictConfig): Original model configuration.
sweep_config (DictConfig): Sweep configuration.
"""
    def __init__(self, config: Union[DictConfig, ListConfig], sweep_config: Union[DictConfig, ListConfig]) -> None:
        self.config = config
        self.sweep_config = sweep_config
        # ``observation_budget`` caps the number of runs the wandb agent launches.
        # It is not part of the wandb sweep spec, so it is read out here and then
        # removed before the sweep is registered. Reading the attribute first means
        # a missing key fails loudly at construction time.
        self.observation_budget = sweep_config.observation_budget
        # This instance check is to silence mypy; ``ListConfig`` has no ``pop``.
        if isinstance(self.sweep_config, DictConfig):
            self.sweep_config.pop("observation_budget")

def run(self):
"""Run the sweep."""
flattened_hpo_params = flatten_hpo_params(self.sweep_config.parameters)
self.sweep_config.parameters = flattened_hpo_params
sweep_id = wandb.sweep(
OmegaConf.to_object(self.sweep_config),
project=f"{self.config.model.name}_{self.config.dataset.name}",
)
wandb.agent(sweep_id, function=self.sweep, count=self.observation_budget)
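
    # For reference, a minimal ``sweep_config`` consumed by ``run`` might look
    # like this (illustrative; everything except ``observation_budget`` follows
    # the wandb sweep spec, and the metric name depends on what the model logs):
    #
    #   observation_budget: 10
    #   method: bayes
    #   metric:
    #     name: image_AUROC
    #     goal: maximize
    #   parameters:
    #     model:
    #       lr:
    #         min: 0.0001
    #         max: 0.01
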
def sweep(self):
"""Method to load the model, update config and call fit. The metrics are logged to ```wandb``` dashboard."""
wandb_logger = WandbLogger(config=flatten_sweep_params(self.sweep_config), log_model=False)
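        # The wandb agent injects the hyperparameter values it sampled for this
        # run into the experiment config, so they are read back from there rather
        # than from ``self.sweep_config``.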
sweep_config = wandb_logger.experiment.config
for param in sweep_config.keys():
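            # A flat sweep key such as "model.lr" is split on "." and written
            # back into the nested model config, overwriting the default with
            # the value wandb sampled for this run.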
set_in_nested_config(self.config, param.split("."), sweep_config[param])
config = update_input_size_config(self.config)
model = get_model(config)
datamodule = get_datamodule(config)
# Disable saving checkpoints as all checkpoints from the sweep will get uploaded
config.trainer.checkpoint_callback = False
trainer = pl.Trainer(**config.trainer, logger=wandb_logger)
trainer.fit(model, datamodule=datamodule)


def get_args() -> Namespace:
    """Get the command-line arguments."""
parser = ArgumentParser()
parser.add_argument("--model", type=str, default="padim", help="Name of the algorithm to train/test")
parser.add_argument("--model_config", type=Path, required=False, help="Path to a model config file")
parser.add_argument("--sweep_config", type=Path, required=True, help="Path to sweep configuration")
return parser.parse_args()
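

# Example invocation (script name and config path are illustrative):
#
#   python sweep.py --model padim --sweep_config sweep.yaml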
if __name__ == "__main__":
args = get_args()
model_config = get_configurable_parameters(model_name=args.model, config_path=args.model_config)
hpo_config = OmegaConf.load(args.sweep_config)
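    # Seed only when the project config pins an explicit, non-zero seed.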
if model_config.project.seed != 0:
seed_everything(model_config.project.seed)
sweep = WandbSweep(model_config, hpo_config)
sweep.run()