|
import os |
|
from models.module import DiffGeolocalizer |
|
import hydra |
|
from os.path import join |
|
|
|
import torch |
|
|
|
from omegaconf import OmegaConf |
|
from omegaconf import open_dict |
|
from hydra.utils import instantiate |
|
|
|
from models.eval_best_model import EvalModule |
|
|
|
# Let float32 matmuls use faster reduced-precision internal paths (e.g. TF32)
# where the hardware supports them, trading a little precision for speed.
torch.set_float32_matmul_precision("high")

# Expose Python's builtin `eval` to configs as `${eval:...}`.
# SECURITY: this executes arbitrary Python taken from config strings and
# command-line overrides -- only run with trusted configs.
# Guarded because registering a name that already exists raises ValueError
# (e.g. if another imported module registered it, or this module is reloaded).
if not OmegaConf.has_resolver("eval"):
    OmegaConf.register_new_resolver("eval", eval)
|
|
|
|
|
def _resolve_wandb_id(wandb_id):
    """Return a usable run id, reading it from disk when given an id-file path."""
    if wandb_id is None:
        return None
    candidate = str(wandb_id)
    if os.path.isfile(candidate):
        # The file holds the id on its first line; an empty file means "no id".
        with open(candidate, encoding="utf-8") as id_file:
            return id_file.readline().strip() or None
    if "/" in candidate or os.sep in candidate:
        # Run ids may not contain path separators, so this can only be a path
        # to an id file that does not exist: let the logger start a new run.
        return None
    return wandb_id


def load_model(cfg, dict_config, wandb_id):
    """Build the Lightning trainer and the evaluation module.

    Args:
        cfg: Hydra/OmegaConf config. ``cfg.logger`` and ``cfg.trainer`` are
            instantiated here, and ``cfg.model`` is passed to ``EvalModule``.
        dict_config: ``cfg`` resolved to plain Python containers. Only its
            ``"model"`` and ``"dataset"`` entries are used; they are recorded
            as the logger's run config.
        wandb_id: Run id to resume, or a path to a text file whose first line
            is that id. ``None`` (or a path to a missing file) lets the
            logger start a new run.

    Returns:
        ``(trainer, model)``.
    """
    wandb_id = _resolve_wandb_id(wandb_id)
    logger = instantiate(cfg.logger, id=wandb_id, resume="allow")

    log_dict = {"model": dict_config["model"], "dataset": dict_config["dataset"]}
    # NOTE(review): `_wandb_init` is a private attribute, presumably the kwargs
    # dict Lightning's WandbLogger forwards to run creation. Verify against the
    # installed Lightning version.
    logger._wandb_init.update({"config": log_dict})

    model = EvalModule(cfg.model)

    # Hand the logger to the trainer. Otherwise it is built and configured but
    # never used, and test metrics never reach the resumed run.
    trainer = instantiate(
        cfg.trainer, strategy=cfg.trainer.strategy, logger=logger
    )

    return trainer, model
|
|
|
|
|
def hydra_boilerplate(cfg):
    """Build the ``(trainer, model)`` pair for evaluation from a Hydra config.

    ``cfg`` is first resolved into plain Python containers, with every
    interpolation evaluated. That copy is then given to ``load_model``,
    together with the original ``cfg`` and its ``wandb_id`` entry.
    """
    resolved_config = OmegaConf.to_container(cfg, resolve=True)
    run_id = cfg.wandb_id
    built_trainer, built_model = load_model(cfg, resolved_config, run_id)
    return built_trainer, built_model
|
|
|
|
|
import copy |
|
|
|
|
|
def init_datamodule(cfg):
    """Instantiate and return the object described by ``cfg.datamodule``."""
    return instantiate(cfg.datamodule)
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Inject a `pt_model_path` key. The `+` prefix adds a key the config does
    # not define. Its value interpolates Hydra's runtime config sources and is
    # resolved inside the app, so `main` can see the searched directories.
    sys.argv = (
        [sys.argv[0]]
        + ["+pt_model_path=${hydra:runtime.config_sources}"]
        + sys.argv[1:]
    )

    @hydra.main(config_path="configs", config_name="config", version_base=None)
    def main(cfg):
        """Evaluate a trained run with ``trainer.test``.

        The run directory is taken from Hydra's config sources.
        ``cfg.checkpoint`` is pointed at ``last.ckpt`` in that directory. If
        ``wandb_id.txt`` exists there, the id it holds is used as the logger
        run id.
        """
        with open_dict(cfg):
            # NOTE(review): assumes entry 1 of the config sources is the run
            # directory supplied via --config-path. TODO confirm.
            path = cfg.pt_model_path[1]["path"]

            # The file stores the run id. Pass the id itself, not the file's
            # path (a path is not a valid run id). Without the file, None lets
            # the logger start a new run.
            wandb_id_file = join(path, "wandb_id.txt")
            wandb_id = None
            if os.path.isfile(wandb_id_file):
                with open(wandb_id_file, encoding="utf-8") as id_file:
                    wandb_id = id_file.readline().strip() or None
            cfg.wandb_id = wandb_id

            cfg.checkpoint = join(path, "last.ckpt")
            # Evaluate on a single device.
            cfg.computer.devices = 1

        trainer, model = hydra_boilerplate(cfg)
        datamodule = init_datamodule(cfg)
        trainer.test(model, datamodule=datamodule)

    main()
|
|