from datetime import datetime
from pathlib import Path

import ignite.distributed as idist
import torch
from ignite.contrib.engines import common
from ignite.contrib.handlers import TensorboardLogger
from ignite.engine import Engine, Events
from ignite.utils import manual_seed, setup_logger
from torch.cuda.amp import autocast

from scenedino.common.array_operations import to
from scenedino.common.logging import log_basic_info
from scenedino.common.metrics import ConcatenateMetric, DictMeanMetric, SegmentationMetric
from scenedino.training.handlers import VisualizationHandler
from scenedino.visualization.vis_2d import tb_visualize

from .wrapper import make_eval_fn


def base_evaluation(
    local_rank,
    config,
    get_dataflow,
    initialize,
):
    rank = idist.get_rank()
    if "eval_seed" in config:
        manual_seed(config["eval_seed"] + rank)
    else:
        manual_seed(config["seed"] + rank)
    device = idist.device()

    model_name = config["name"]
    logger = setup_logger(name=model_name, format="%(levelname)s: %(message)s")

    # Rank 0 creates a unique output folder and records it back into the config.
    output_path = config["output"]["path"]
    if rank == 0:
        unique_id = config["output"].get(
            "unique_id", datetime.now().strftime("%Y%m%d-%H%M%S")
        )
        folder_name = unique_id

        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output"]["path"] = output_path.as_posix()
        logger.info(f"Output path: {config['output']['path']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

    tb_logger = TensorboardLogger(log_dir=output_path)

    log_basic_info(logger, config)

    # Setup dataflow and model.
    test_loader = get_dataflow(config)
    if hasattr(test_loader, "dataset"):
        logger.info(f"Dataset length: Test: {len(test_loader.dataset)}")
    config["dataset"]["steps_per_epoch"] = len(test_loader)

    # ===================================== MODEL =====================================
    model = initialize(config)

    # Resolve the checkpoint: either a *.pt file or a directory containing
    # training*.pt checkpoints (the first match is used).
    cp_path = config.get("checkpoint", None)
    if cp_path is not None:
        if not cp_path.endswith(".pt"):
            cp_path = Path(cp_path)
            cp_path = next(cp_path.glob("training*.pt"))
        checkpoint = torch.load(cp_path, map_location=device)
        logger.info(f"Loading checkpoint from path: {cp_path}")
        if "model" in checkpoint:
            model.load_state_dict(checkpoint["model"], strict=False)
        else:
            model.load_state_dict(checkpoint, strict=False)
    else:
        logger.warning("Careful: no checkpoint is loaded")
    model.to(device)

    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
    logger.info(
        f"Trainable model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    )

    # Set up the evaluator engine that runs the model and computes the metrics.
    evaluator = create_evaluator(model, config=config, logger=logger, vis_logger=tb_logger)

    # evaluator.add_event_handler(
    #     Events.ITERATION_COMPLETED(every=config["log_every"]),
    #     log_metrics_current(logger, metrics),
    # )

    try:
        state = evaluator.run(test_loader, max_epochs=1)
        log_metrics(logger, state.times["COMPLETED"], "Test", state.metrics)
        logger.info(f"Checkpoint: {str(cp_path)}")
    except Exception as e:
        logger.exception("")
        raise e
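

# --------------------------------------------------------------------------- #
# Hypothetical usage sketch (not part of the original module): base_evaluation
# follows the ignite.distributed "run function" signature, so it is meant to be
# launched through idist.Parallel, which supplies local_rank as the first
# argument. The config path and the get_dataflow / initialize factories below
# are placeholders for the project-specific ones.
#
#     if __name__ == "__main__":
#         import yaml
#
#         with open("configs/eval.yaml") as f:   # illustrative path
#             config = yaml.safe_load(f)
#
#         with idist.Parallel(backend=config.get("backend")) as parallel:
#             parallel.run(base_evaluation, config, get_dataflow, initialize)
# --------------------------------------------------------------------------- #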


# def log_metrics_current(logger, metrics):
#     def f(engine):
#         out_str = "\n" + "\t".join(
#             [
#                 f"{v.compute():.3f}".ljust(8)
#                 for v in metrics.values()
#                 if v._num_examples != 0
#             ]
#         )
#         out_str += "\n" + "\t".join([f"{k}".ljust(8) for k in metrics.keys()])
#         logger.info(out_str)
#     return f


def log_metrics(logger, elapsed, tag, metrics):
    metrics_output = "\n".join([f"\t{k}: {v}" for k, v in metrics.items()])
    logger.info(
        f"\nEvaluation time (seconds): {elapsed:.2f} - {tag} metrics:\n {metrics_output}"
    )
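

# Example of the resulting log message (metric names and values below are
# purely illustrative):
#
#     Evaluation time (seconds): 12.34 - Test metrics:
#          miou: 0.42
#          accuracy: 0.87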


# def create_evaluator(model, metrics, config, tag="val"):
def create_evaluator(model, config, logger=None, vis_logger=None, tag="val"):
    with_amp = config["with_amp"]
    device = idist.device()

    metrics = {}
    for eval_config in config["evaluations"]:
        agg_type = eval_config.get("agg_type", None)
        if agg_type == "unsup_seg":
            metrics[eval_config["type"]] = SegmentationMetric(
                eval_config["type"], make_eval_fn(model, eval_config), assign_pseudo=True
            )
        elif agg_type == "sup_seg":
            metrics[eval_config["type"]] = SegmentationMetric(
                eval_config["type"], make_eval_fn(model, eval_config), assign_pseudo=False
            )
        elif agg_type == "concat":
            metrics[eval_config["type"]] = ConcatenateMetric(
                eval_config["type"], make_eval_fn(model, eval_config)
            )
        else:
            metrics[eval_config["type"]] = DictMeanMetric(
                eval_config["type"], make_eval_fn(model, eval_config)
            )
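
    # Illustrative (assumed) shape of an entry in config["evaluations"]: only
    # "type" and "agg_type" are read here, everything else is consumed by
    # make_eval_fn in .wrapper. An entry without a recognised agg_type falls
    # back to DictMeanMetric.
    #
    #     {"type": "unsup_seg", "agg_type": "unsup_seg"}   # hypothetical keys/values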

    def evaluate_step(engine: Engine, data):
        # if not engine.state_dict["iteration"] % 10 == 0:  # to prevent iterating the whole test set for visualization purposes
        model.eval()

        # Per-batch data-loading time, if the dataset reports it.
        if "t__get_item__" in data:
            timing = {"t__get_item__": torch.mean(data["t__get_item__"]).item()}
        else:
            timing = {}

        data = to(data, device)

        with autocast(enabled=with_amp):
            data = model(data)  # NOTE: this forward pass is where the occupancy prediction is made.
            loss_metrics = {}

        return {
            "output": data,
            "loss_dict": loss_metrics,
            "timings_dict": timing,
            "metrics_dict": {},
        }

    evaluator = Engine(evaluate_step)
    evaluator.logger = logger

    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # Optionally attach TensorBoard visualizations for selected iterations.
    eval_visualize = config.get("eval_visualize", [])
    if eval_visualize and vis_logger is not None:
        for name, vis_config in config["validation"].items():
            if "visualize" in vis_config:
                visualize = tb_visualize(
                    (model.renderer.net if hasattr(model, "renderer") else model.module.renderer.net),
                    None,
                    vis_config["visualize"],
                )

                # Bind the current `visualize` as a default argument so the
                # wrapper does not pick up a later loop iteration's value.
                def vis_wrapper(*args, visualize=visualize, **kwargs):
                    with autocast(enabled=with_amp):
                        return visualize(*args, **kwargs)

                def custom_vis_filter(engine, event):
                    return engine.state.iteration - 1 in eval_visualize

                vis_logger.attach(
                    evaluator,
                    VisualizationHandler(tag=tag, visualizer=vis_wrapper),
                    Events.ITERATION_COMPLETED(event_filter=custom_vis_filter),
                )

    if idist.get_rank() == 0 and (not config.get("with_clearml", False)):
        common.ProgressBar(desc=f"Evaluation ({tag})", persist=False).attach(evaluator)

    return evaluator
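

# --------------------------------------------------------------------------- #
# Minimal standalone sketch (assumptions marked): create_evaluator only needs a
# model and a config dict with the keys read above. All values here are
# illustrative, not taken from the repository's real configs.
#
#     config = {
#         "with_amp": False,
#         "evaluations": [{"type": "my_metric"}],   # hypothetical metric name
#         "eval_visualize": [],
#         "validation": {},
#         "with_clearml": False,
#     }
#     evaluator = create_evaluator(model, config)
#     state = evaluator.run(test_loader, max_epochs=1)
#     print(state.metrics)
# --------------------------------------------------------------------------- #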