# SceneDINO: scenedino/evaluation/base_evaluator.py
from datetime import datetime
from pathlib import Path

import ignite.distributed as idist
import torch
from ignite.contrib.engines import common
# from ignite.contrib.handlers.tensorboard_logger import *
from ignite.contrib.handlers import TensorboardLogger
from ignite.engine import Engine, Events
from ignite.utils import manual_seed, setup_logger
from torch.cuda.amp import autocast

from scenedino.common.array_operations import to
from scenedino.common.logging import log_basic_info
from scenedino.common.metrics import ConcatenateMetric, DictMeanMetric, SegmentationMetric
from scenedino.training.handlers import VisualizationHandler
from scenedino.visualization.vis_2d import tb_visualize

from .wrapper import make_eval_fn

def base_evaluation(
    local_rank,
    config,
    get_dataflow,
    initialize,
):
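    """Run a single evaluation epoch of a SceneDINO model.

    Seeds the run, prepares the output directory and TensorBoard logger,
    builds the test dataloader via ``get_dataflow(config)`` and the model via
    ``initialize(config)``, optionally loads a checkpoint, and finally runs an
    ignite evaluator over the test set, logging the resulting metrics.
    """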
    rank = idist.get_rank()
    if "eval_seed" in config:
        manual_seed(config["eval_seed"] + rank)
    else:
        manual_seed(config["seed"] + rank)
    device = idist.device()

    model_name = config["name"]
    logger = setup_logger(name=model_name, format="%(levelname)s: %(message)s")

    output_path = config["output"]["path"]
    if rank == 0:
        unique_id = config["output"].get(
            "unique_id", datetime.now().strftime("%Y%m%d-%H%M%S")
        )
        folder_name = unique_id

        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output"]["path"] = output_path.as_posix()
        logger.info(f"Output path: {config['output']['path']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

    tb_logger = TensorboardLogger(log_dir=output_path)

    log_basic_info(logger, config)

    # Setup dataflow, model, optimizer, criterion
    test_loader = get_dataflow(config)
    if hasattr(test_loader, "dataset"):
        logger.info(f"Dataset length: Test: {len(test_loader.dataset)}")
    config["dataset"]["steps_per_epoch"] = len(test_loader)

    # ===================================================== MODEL =====================================================
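    # Build the model from the config and optionally restore weights from a
    # checkpoint. If "checkpoint" points to a directory rather than a .pt file,
    # the first file matching "training*.pt" inside it is used. Missing or
    # unexpected keys are tolerated (strict=False).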
    model = initialize(config)
    cp_path = config.get("checkpoint", None)

    if cp_path is not None:
        if not cp_path.endswith(".pt"):
            cp_path = Path(cp_path)
            cp_path = next(cp_path.glob("training*.pt"))
        checkpoint = torch.load(cp_path, map_location=device)
        logger.info(f"Loading checkpoint from path: {cp_path}")
        if "model" in checkpoint:
            model.load_state_dict(checkpoint["model"], strict=False)
        else:
            model.load_state_dict(checkpoint, strict=False)
    else:
        logger.warning("No checkpoint specified; evaluating an untrained (randomly initialized) model.")

    model.to(device)

    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
    logger.info(
        "Trainable model parameters: "
        f"{sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    )
    # Set up the evaluator engine that runs validation and computes metrics.
    evaluator = create_evaluator(model, config=config, logger=logger, vis_logger=tb_logger)

    # evaluator.add_event_handler(
    #     Events.ITERATION_COMPLETED(every=config["log_every"]),
    #     log_metrics_current(logger, metrics),
    # )

    try:
        state = evaluator.run(test_loader, max_epochs=1)
        log_metrics(logger, state.times["COMPLETED"], "Test", state.metrics)
        logger.info(f"Checkpoint: {str(cp_path)}")
    except Exception:
        logger.exception("Evaluation failed")
        raise
# def log_metrics_current(logger, metrics):
#     def f(engine):
#         out_str = "\n" + "\t".join(
#             [
#                 f"{v.compute():.3f}".ljust(8)
#                 for v in metrics.values()
#                 if v._num_examples != 0
#             ]
#         )
#         out_str += "\n" + "\t".join([f"{k}".ljust(8) for k in metrics.keys()])
#         logger.info(out_str)
#     return f
def log_metrics(logger, elapsed, tag, metrics):
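    """Log all computed metrics for ``tag`` together with the evaluation wall time."""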
    metrics_output = "\n".join([f"\t{k}: {v}" for k, v in metrics.items()])
    logger.info(
        f"\nEvaluation time (seconds): {elapsed:.2f} - {tag} metrics:\n {metrics_output}"
    )
# def create_evaluator(model, metrics, config, tag="val"):
def create_evaluator(model, config, logger=None, vis_logger=None, tag="val"):
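    """Build an ignite ``Engine`` that runs the model on each batch and attaches metrics.

    One metric is created per entry in ``config["evaluations"]``; its ``agg_type``
    field selects how per-batch outputs are aggregated (see the comment below).
    ``tag`` is used to label the TensorBoard visualizations and the progress bar.
    """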
    with_amp = config["with_amp"]
    device = idist.device()

    metrics = {}
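    # Each entry in config["evaluations"] selects an aggregation strategy via "agg_type":
    #   - "unsup_seg": SegmentationMetric with assign_pseudo=True (unsupervised segmentation).
    #   - "sup_seg":   SegmentationMetric with assign_pseudo=False (supervised segmentation).
    #   - "concat":    ConcatenateMetric, concatenating per-batch outputs.
    #   - otherwise:   DictMeanMetric, averaging each value returned by the eval function.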
    for eval_config in config["evaluations"]:
        agg_type = eval_config.get("agg_type", None)
        if agg_type == "unsup_seg":
            metrics[eval_config["type"]] = SegmentationMetric(
                eval_config["type"], make_eval_fn(model, eval_config), assign_pseudo=True
            )
        elif agg_type == "sup_seg":
            metrics[eval_config["type"]] = SegmentationMetric(
                eval_config["type"], make_eval_fn(model, eval_config), assign_pseudo=False
            )
        elif agg_type == "concat":
            metrics[eval_config["type"]] = ConcatenateMetric(
                eval_config["type"], make_eval_fn(model, eval_config)
            )
        else:
            metrics[eval_config["type"]] = DictMeanMetric(
                eval_config["type"], make_eval_fn(model, eval_config)
            )
    @torch.no_grad()
    def evaluate_step(engine: Engine, data):
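        """Run a single forward pass on one batch and return its outputs for the attached metrics."""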
        # if not engine.state_dict["iteration"] % 10 == 0:  ## to prevent iterating whole testset for viz purpose
        model.eval()

        if "t__get_item__" in data:
            timing = {"t__get_item__": torch.mean(data["t__get_item__"]).item()}
        else:
            timing = {}

        data = to(data, device)

        with autocast(enabled=with_amp):
            data = model(data)  ## ! This is where the occupancy prediction is made.

        loss_metrics = {}

        return {
            "output": data,
            "loss_dict": loss_metrics,
            "timings_dict": timing,
            "metrics_dict": {},
        }
    evaluator = Engine(evaluate_step)
    evaluator.logger = logger

    for name, metric in metrics.items():
        metric.attach(evaluator, name)
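    # Optionally attach TensorBoard visualizations. Handlers fire only at the
    # iterations listed in config["eval_visualize"]; the rendering settings are
    # taken from the "visualize" entries under config["validation"].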
    eval_visualize = config.get("eval_visualize", [])
    if eval_visualize and vis_logger is not None:
        for name, vis_config in config["validation"].items():
            if "visualize" in vis_config:
                visualize = tb_visualize(
                    (model.renderer.net if hasattr(model, "renderer") else model.module.renderer.net),
                    None,
                    vis_config["visualize"],
                )

                # Bind the current visualizer via a default argument so each handler
                # keeps its own `visualize` instead of the last one from the loop.
                def vis_wrapper(*args, visualize=visualize, **kwargs):
                    with autocast(enabled=with_amp):
                        return visualize(*args, **kwargs)

                def custom_vis_filter(engine, event):
                    return engine.state.iteration - 1 in eval_visualize

                vis_logger.attach(
                    evaluator,
                    VisualizationHandler(
                        tag=tag,
                        visualizer=vis_wrapper,
                    ),
                    Events.ITERATION_COMPLETED(event_filter=custom_vis_filter),
                )
    if idist.get_rank() == 0 and (not config.get("with_clearml", False)):
        common.ProgressBar(desc=f"Evaluation ({tag})", persist=False).attach(evaluator)

    return evaluator
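
# Illustrative usage sketch (not part of this module): `base_evaluation` is meant to be
# launched through ignite's distributed helpers with two user-supplied callbacks.
# The helper names `build_test_loader` and `build_model` below are hypothetical
# placeholders for project-specific factories.
#
#     import ignite.distributed as idist
#
#     def get_dataflow(config):
#         return build_test_loader(config)   # -> torch DataLoader over the test set
#
#     def initialize(config):
#         return build_model(config)         # -> torch.nn.Module to evaluate
#
#     with idist.Parallel(backend=None) as parallel:
#         parallel.run(base_evaluation, config, get_dataflow, initialize)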