import torch

from utils.flow_viz import flow_tensor_to_image
from .visualization import viz_depth_tensor


class Logger:
    """Accumulate training metrics and write scalar/image summaries.

    The logger keeps a running sum of every metric pushed via :meth:`push`
    and, every ``summary_freq`` steps, prints a status line and writes the
    averaged values to ``summary_writer``.
    """

    def __init__(self, lr_scheduler, summary_writer,
                 summary_freq=100,
                 start_step=0,
                 img_mean=None,
                 img_std=None,
                 ):
        """Store the collaborators and reset the running state.

        Args:
            lr_scheduler: Object exposing ``get_last_lr()``; the first entry
                of its result is logged as the learning rate.
            summary_writer: Object exposing ``add_scalar``, ``add_image`` and
                ``close`` (presumably a tensorboard ``SummaryWriter``).
            summary_freq: Number of steps between two summaries.
            start_step: Initial value of the global step counter.
            img_mean: Per-channel mean (3 values) used by
                :meth:`unnormalize_image`, or ``None``.
            img_std: Per-channel std (3 values) used by
                :meth:`unnormalize_image`, or ``None``.
        """
        self.lr_scheduler = lr_scheduler
        self.total_steps = start_step
        self.running_loss = {}
        self.summary_writer = summary_writer
        self.summary_freq = summary_freq

        self.img_mean = img_mean
        self.img_std = img_std

    def print_training_status(self, mode='train', is_depth=False):
        """Print the headline metric and flush all running averages.

        The headline metric is ``'total_loss'`` when ``is_depth`` is true and
        ``'epe'`` otherwise. Every accumulated metric is divided by
        ``summary_freq``, written as ``'<mode>/<key>'`` and reset to ``0.0``.

        Args:
            mode: Prefix used for the scalar tags.
            is_depth: Selects which headline metric is printed.
        """
        headline_key = 'total_loss' if is_depth else 'epe'

        if headline_key in self.running_loss:
            headline = self.running_loss[headline_key] / self.summary_freq
            if is_depth:
                print('step: %06d \t loss: %.3f' % (self.total_steps, headline))
            else:
                print('step: %06d \t epe: %.3f' % (self.total_steps, headline))
        else:
            # The headline metric was never pushed; still report the step
            # instead of aborting training with a KeyError.
            print('step: %06d' % self.total_steps)

        for k in self.running_loss:
            self.summary_writer.add_scalar(mode + '/' + k,
                                           self.running_loss[k] / self.summary_freq,
                                           self.total_steps)
            self.running_loss[k] = 0.0

    def lr_summary(self):
        """Write the scheduler's first current learning rate under ``'lr'``."""
        lr = self.lr_scheduler.get_last_lr()[0]
        self.summary_writer.add_scalar('lr', lr, self.total_steps)

    def add_image_summary(self, img1, img2, flow_preds=None, flow_gt=None,
                          mode='train',
                          is_depth=False,
                          ):
        """Write an image summary on steps that are multiples of ``summary_freq``.

        Args:
            img1: First image tensor.
            img2: Second image tensor.
            flow_preds: Sequence of flow predictions; the last one is
                visualized. Required when ``is_depth`` is false.
            flow_gt: Ground-truth flow. Required when ``is_depth`` is false.
            mode: Prefix used for the image tag.
            is_depth: If true, log only the two un-normalized images side by
                side under ``'<mode>/img'``; otherwise log images stacked on
                top of predicted/ground-truth flow under
                ``'<mode>/img_pred_gt'``.
        """
        if self.total_steps % self.summary_freq != 0:
            return

        if is_depth:
            # Expected input: [3, H, W]; range [0, 1] after un-normalization.
            img1 = self.unnormalize_image(img1.detach().cpu())
            img2 = self.unnormalize_image(img2.detach().cpu())

            concat = torch.cat((img1, img2), dim=-1)  # [3, H, W*2]

            self.summary_writer.add_image(mode + '/img', concat, self.total_steps)
            return

        # Only index 0 of each input is visualized (presumably the first
        # sample of the batch -- confirm against the caller).
        img_concat = torch.cat((img1[0].detach().cpu(), img2[0].detach().cpu()), dim=-1)
        img_concat = img_concat.type(torch.uint8)  # convert to uint8 to visualize in tensorboard

        flow_pred = flow_tensor_to_image(flow_preds[-1][0])
        forward_flow_gt = flow_tensor_to_image(flow_gt[0])
        flow_concat = torch.cat((torch.from_numpy(flow_pred),
                                 torch.from_numpy(forward_flow_gt)),
                                dim=-1)

        # Images on top, flow visualizations below.
        concat = torch.cat((img_concat, flow_concat), dim=-2)

        self.summary_writer.add_image(mode + '/img_pred_gt', concat, self.total_steps)

    def add_depth_summary(self, depth_pred, depth_gt, mode='train'):
        """Write predicted and ground-truth depth side by side.

        The summary is written on steps that are multiples of
        ``summary_freq``, and on every call when ``mode`` contains ``'val'``.

        Args:
            depth_pred: Predicted depth tensor (expected [H, W]).
            depth_gt: Ground-truth depth tensor (expected [H, W]).
            mode: Prefix used for the image tag.
        """
        # assert depth_pred.dim() == 2  # [H, W]
        if self.total_steps % self.summary_freq != 0 and 'val' not in mode:
            return

        pred_viz = viz_depth_tensor(depth_pred.detach().cpu())  # [3, H, W]
        gt_viz = viz_depth_tensor(depth_gt.detach().cpu())

        concat = torch.cat((pred_viz, gt_viz), dim=-1)  # [3, H, W*2]

        self.summary_writer.add_image(mode + '/depth_pred_gt', concat, self.total_steps)

    def unnormalize_image(self, img):
        """Undo per-channel normalization: ``img * std + mean``.

        Args:
            img: Image tensor broadcastable against a ``[3, 1, 1]`` tensor
                (expected [3, H, W]).

        Returns:
            The un-normalized image. When ``img_mean`` or ``img_std`` was not
            configured, ``img`` is returned unchanged.
        """
        if self.img_mean is None or self.img_std is None:
            # No statistics configured (the constructor default): there is
            # nothing to undo, and torch.tensor(None) would raise.
            return img

        # as_tensor avoids a copy (and a warning) if the stats are already tensors.
        mean = torch.as_tensor(self.img_mean).view(3, 1, 1).type_as(img)
        std = torch.as_tensor(self.img_std).view(3, 1, 1).type_as(img)

        return img * std + mean

    def push(self, metrics, mode='train',
             is_depth=False,
             ):
        """Advance one step, log the learning rate and accumulate ``metrics``.

        Every ``summary_freq`` steps the running averages are printed and
        written via :meth:`print_training_status`, then cleared.

        Args:
            metrics: Mapping from metric name to a value that supports ``+``.
            mode: Prefix used for the scalar tags.
            is_depth: Forwarded to :meth:`print_training_status`.
        """
        self.total_steps += 1

        self.lr_summary()

        for key in metrics:
            self.running_loss[key] = self.running_loss.get(key, 0.0) + metrics[key]

        if self.total_steps % self.summary_freq == 0:
            self.print_training_status(mode, is_depth=is_depth)
            self.running_loss = {}

    def write_dict(self, results):
        """Write each entry of ``results`` as a scalar at the current step.

        The tag is ``'<prefix>/<key>'`` where ``<prefix>`` is the part of
        ``key`` before its first underscore.
        """
        for key in results:
            tag = key.split('_')[0]
            tag = tag + '/' + key
            self.summary_writer.add_scalar(tag, results[key], self.total_steps)

    def close(self):
        """Close the underlying summary writer."""
        self.summary_writer.close()