# test_ebc/utils/log_utils.py
from __future__ import annotations  # allow OrderedDict[str, Tensor] annotations on Python < 3.9

import logging
import os
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor
from tensorboardX import SummaryWriter


def get_logger(log_file: str) -> logging.Logger:
    """Return a logger that writes DEBUG records to `log_file` and INFO records to the console."""
    logger = logging.getLogger(log_file)
    logger.setLevel(logging.DEBUG)
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    ch.setFormatter(formatter)
    fh.setFormatter(formatter)
    logger.addHandler(ch)
    logger.addHandler(fh)
    return logger


def get_config(config: Dict, mute: bool = False) -> str:
    """Format a config dict as aligned "key: value" lines and optionally print it."""
    info = "\n".join([f"{k.ljust(15)}:\t{v}" for k, v in config.items()])
    if not mute:
        print(info)
    return info
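
# Example (illustrative values, not from the repo):
# get_config({"lr": 1e-3, "epochs": 100}, mute=True) returns two lines of the
# form "lr             :    0.001", each key left-justified to 15 characters
# and separated from its value by a tab.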


def get_writer(ckpt_dir: str) -> SummaryWriter:
    """Return a TensorBoard writer that logs under <ckpt_dir>/logs."""
    return SummaryWriter(log_dir=os.path.join(ckpt_dir, "logs"))


def print_epoch(epoch: int, total_epochs: int, mute: bool = False) -> Optional[str]:
    """Format "Epoch: e / total" with zero-padded epoch numbers; return the string if `mute`, else print it."""
    digits = len(str(total_epochs))
    info = f"Epoch: {epoch:0{digits}d} / {total_epochs:0{digits}d}"
    if mute:
        return info
    print(info)


def print_train_result(loss_info: Dict[str, float], mute: bool = False) -> Optional[str]:
    parts = [f"{k}: {v};" for k, v in loss_info.items()]
    info = "Training: " + " ".join(parts)
    if mute:
        return info
    print(info)


def print_eval_result(curr_scores: Dict[str, float], best_scores: Dict[str, List[float]], mute: bool = False) -> Optional[str]:
    scores = []
    for k in curr_scores.keys():
        info = f"Curr {k}: {curr_scores[k]:.4f}; \t Best {k}: "
        info += " ".join([f"{score:.4f};" for score in best_scores[k]])
        scores.append(info)
    info = "Evaluation:\n" + "\n".join(scores)
    if mute:
        return info
    print(info)
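
# Example line produced by print_eval_result (hypothetical scores):
#   Evaluation:
#   Curr mae: 2.0000;    Best mae: 1.0000; 2.0000; 3.0000;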


def update_train_result(epoch: int, loss_info: Dict[str, float], writer: SummaryWriter) -> None:
    """Log each training loss under "train/<name>" at the given epoch."""
    for k, v in loss_info.items():
        writer.add_scalar(f"train/{k}", v, epoch)


def update_eval_result(
    epoch: int,
    curr_scores: Dict[str, float],
    hist_scores: Dict[str, List[float]],
    best_scores: Dict[str, List[float]],
    writer: SummaryWriter,
    state_dict: OrderedDict[str, Tensor],
    ckpt_dir: str,
) -> Tuple[Dict[str, List[float]], Dict[str, List[float]]]:
    os.makedirs(ckpt_dir, exist_ok=True)
    for k, v in curr_scores.items():
        hist_scores[k].append(v)
        writer.add_scalar(f"val/{k}", v, epoch)
        # best_scores[k] is sorted ascending: best_scores[k][0] is the best
        # score, and smaller is better. Find the position loc where the new
        # score v should be inserted.
        loc = None
        for i in range(len(best_scores[k])):
            if v < best_scores[k][i]:
                best_scores[k].insert(i, v)
                loc = i
                break
        # If the new score beats at least the worst of the current best scores
        if loc is not None:
            # Drop the worst score so the list keeps its original length
            best_scores[k] = best_scores[k][:-1]
            # Shift the checkpoints behind the insertion point one rank down,
            # from the back forward: best_{k}_{i-1}.pth becomes best_{k}_{i}.pth
            for i in range(len(best_scores[k]) - 1, loc, -1):
                src = os.path.join(ckpt_dir, f"best_{k}_{i-1}.pth")
                if os.path.exists(src):
                    # os.replace overwrites the destination on all platforms
                    os.replace(src, os.path.join(ckpt_dir, f"best_{k}_{i}.pth"))
            # Save the new checkpoint at its rank
            torch.save(state_dict, os.path.join(ckpt_dir, f"best_{k}_{loc}.pth"))
    return hist_scores, best_scores
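
# Illustration of the top-K bookkeeping above (hypothetical values, not from
# this repo): with best_scores = {"mae": [1.0, 2.0, 3.0]} and a new score of
# 1.5, the list becomes [1.0, 1.5, 2.0]; best_mae_1.pth is renamed to
# best_mae_2.pth, and the new checkpoint is saved as best_mae_1.pth.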


def update_loss_info(hist_scores: Optional[Dict[str, List[float]]], curr_scores: Dict[str, float]) -> Dict[str, List[float]]:
    assert all(isinstance(v, float) for v in curr_scores.values()), f"Expected all values to be float, got {curr_scores}"
    if hist_scores is None or len(hist_scores) == 0:
        hist_scores = {k: [v] for k, v in curr_scores.items()}
    else:
        for k, v in curr_scores.items():
            hist_scores[k].append(v)
    return hist_scores


def log(
    logger: logging.Logger,
    epoch: Optional[int],
    total_epochs: Optional[int],
    loss_info: Optional[Dict[str, float]] = None,
    curr_scores: Optional[Dict[str, float]] = None,
    best_scores: Optional[Dict[str, List[float]]] = None,
    message: Optional[str] = None,
) -> None:
    if epoch is None:
        assert total_epochs is None, f"Expected total_epochs to be None when epoch is None, got {total_epochs}"
        msg = ""
    else:
        assert total_epochs is not None, "Expected total_epochs to be not None when epoch is not None, got None"
        msg = print_epoch(epoch, total_epochs, mute=True)
    if loss_info is not None:
        msg += "\n" if len(msg) > 0 else ""
        msg += print_train_result(loss_info, mute=True)
    if curr_scores is not None:
        assert best_scores is not None, "Expected best_scores to be not None when curr_scores is not None, got None"
        msg += "\n" if len(msg) > 0 else ""
        msg += print_eval_result(curr_scores, best_scores, mute=True)
    if message is not None:
        msg += "\n" if len(msg) > 0 else ""
        msg += message
    logger.info(msg)
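

# ---------------------------------------------------------------------------
# Minimal usage sketch. The directory name "ckpt_demo", the metric name "mae",
# and the top-3 best list are illustrative assumptions, not part of the repo;
# run the module directly to see the helpers cooperate.
if __name__ == "__main__":
    ckpt_dir = "ckpt_demo"
    os.makedirs(ckpt_dir, exist_ok=True)
    logger = get_logger(os.path.join(ckpt_dir, "train.log"))
    writer = get_writer(ckpt_dir)
    # Seed the best list with sentinel worst-case values (smaller is better).
    best_scores = {"mae": [float("inf")] * 3}
    hist_scores = {"mae": []}
    state_dict = OrderedDict([("weight", torch.zeros(1))])
    for epoch, mae in enumerate([3.0, 2.0, 2.5]):
        update_train_result(epoch, {"loss": 2 * mae}, writer)
        hist_scores, best_scores = update_eval_result(
            epoch, {"mae": mae}, hist_scores, best_scores, writer, state_dict, ckpt_dir
        )
        log(logger, epoch, 3, loss_info={"loss": 2 * mae},
            curr_scores={"mae": mae}, best_scores=best_scores)
    writer.close()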