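"""Hydra-based training entry point.

Instantiates the datamodule, model, callbacks, loggers and trainer from a Hydra
config, runs training (optionally resuming from a checkpoint), tests on the best
checkpoint if a test split is configured, and returns objective metrics for sweeps.
"""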
from pathlib import Path
from typing import List, Tuple

import hydra
import lightning
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from deepscreen.utils.hydra import checkpoint_rerun_config
from deepscreen.utils import get_logger, job_wrapper, instantiate_callbacks, instantiate_loggers, log_hyperparameters

log = get_logger(__name__)


@job_wrapper(extra_utils=True)
def train(cfg: DictConfig) -> Tuple[dict, dict]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
    training.

    This method is wrapped in optional @job_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """
    # Set the seed for random number generators in PyTorch, NumPy and Python's random module.
    if cfg.get("seed"):
        lightning.seed_everything(cfg.seed, workers=True)

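    # Presumably reconciles the current config with the checkpointed run's config before
    # resuming (see checkpoint_rerun_config in deepscreen.utils.hydra).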
    if cfg.get("ckpt_path"):
        cfg = checkpoint_rerun_config(cfg)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>.")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>.")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks.")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers.")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

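    # Bundle the instantiated objects so they can be logged as hyperparameters and
    # returned to the caller.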
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    # Temporary fix to explicitly initialize UninitializedParameters in LazyModules
    # for batch in datamodule.train_dataloader():
    #     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    #     batch = batch.to(device)
    #     model(batch)
    #     break

    if logger:
        log.info("Logging hyperparameters...")
        log_hyperparameters(object_dict)

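    # Optionally compile the model with torch.compile (requires PyTorch >= 2.0).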
    if cfg.get("compile"):
        log.info("Compiling model...")
        model = torch.compile(model)

    log.info("Start training...")
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    # trainer.checkpoint_callback is None when checkpointing is disabled, so guard both cases.
    if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
        ckpt_path = Path(trainer.checkpoint_callback.best_model_path).resolve()
        log.info(f"Best checkpoint path: {ckpt_path}")
    else:
        ckpt_path = None
        log.warning("Best checkpoint not saved.")

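    # Run testing only if a test split was configured.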
    if cfg.data.train_val_test_split[2] is not None:
        log.info("Start testing...")
        if ckpt_path is None:
            log.warning("Best checkpoint not found. Using current weights for testing.")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    metric_dict = trainer.callback_metrics
    # Store as str (not Path) so the value can be assigned back into the OmegaConf config.
    metric_dict['ckpt_path'] = str(ckpt_path) if ckpt_path else None

    return metric_dict, object_dict


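# Example invocations (script path and overrides are illustrative, not exhaustive):
#   python train.py seed=42 trainer.max_epochs=50
#   python train.py ckpt_path=/path/to/checkpoint.ckpt  # resume from a checkpoint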
@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig):
    metric_dict, _ = train(cfg)
    cfg.ckpt_path = metric_dict.get('ckpt_path')

    objective_metrics = cfg.get("objective_metrics")
    if not objective_metrics:
        return None

    invalid_metrics = [metric for metric in objective_metrics if metric not in metric_dict]
    if invalid_metrics:
        raise ValueError(
            f"Unable to find {invalid_metrics} (specified in `objective_metrics`) in `metric_dict`.\n"
            "Make sure your `model.metrics` and `sweep.objective_metrics` configs are correct."
        )

    metric_values = tuple(metric_dict[metric].item() for metric in objective_metrics)
    for objective_metric, metric_value in zip(objective_metrics, metric_values):
        log.info(f"Retrieved objective: {objective_metric}={metric_value}")

    # Return the objective metric values so a Hydra sweeper can optimize them.
    return metric_values


if __name__ == "__main__":
    main()