# MIT License # Copyright (c) 2022 Intelligent Systems Lab Org # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # File author: Shariq Farooq Bhat import torch import torch.cuda.amp as amp import torch.nn as nn from zoedepth.trainers.loss import GradL1Loss, SILogLoss from zoedepth.utils.config import DATASETS_CONFIG from zoedepth.utils.misc import compute_metrics from zoedepth.data.preprocess import get_black_border from .base_trainer import BaseTrainer from torchvision import transforms from PIL import Image import numpy as np class Trainer(BaseTrainer): def __init__(self, config, model, train_loader, test_loader=None, device=None): super().__init__(config, model, train_loader, test_loader=test_loader, device=device) #self.device = device self.silog_loss = SILogLoss() self.grad_loss = GradL1Loss() self.scaler = amp.GradScaler(enabled=self.config.use_amp) def train_on_batch(self, batch, train_step): """ Expects a batch of images and depth as input batch["image"].shape : batch_size, c, h, w batch["depth"].shape : batch_size, 1, h, w """ images, depths_gt = batch['image'].to( self.device), batch['depth'].to(self.device) if "masked_depth" in batch.keys(): # FIXME fix the permutation here, i've missed this upstream somewhere masked_depth = batch["masked_depth"].to(self.device).permute(0, 3, 1, 2) depth_mask = (masked_depth != 0).float() dataset = batch['dataset'][0] max_depth = self.config.max_depth if self.config["add_depth_channel"] and "masked_depth" in batch.keys(): images = torch.cat([images, masked_depth / max_depth, depth_mask], dim=1) elif self.config["add_depth_channel"]: images = torch.cat([images, depths_gt / max_depth, depth_mask], dim=1) b, c, h, w = images.size() mask = batch["mask"].to(self.device).to(torch.bool) losses = {} with amp.autocast(enabled=self.config.use_amp): output = self.model(images) pred_depths = output['metric_depth'] l_si, pred = self.silog_loss( pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True) loss = self.config.w_si * l_si losses[self.silog_loss.name] = l_si if self.config.w_grad > 0: l_grad = self.grad_loss(pred, depths_gt, mask=mask) loss = loss + self.config.w_grad * l_grad losses[self.grad_loss.name] = l_grad else: l_grad = torch.Tensor([0]) if hasattr(self.config, "w_sd") and self.config.w_sd > 0: l_sd = (nn.functional.mse_loss(pred, depths_gt, reduction="none") * depth_mask).mean() loss = loss + self.config.w_sd * l_sd losses["SparseDepth"] = l_sd self.scaler.scale(loss).backward() if self.config.clip_grad > 0: self.scaler.unscale_(self.optimizer) nn.utils.clip_grad_norm_( self.model.parameters(), self.config.clip_grad) self.scaler.step(self.optimizer) if self.should_log and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0: # -99 is treated as invalid depth in the log_images function and is colored grey. depths_gt[torch.logical_not(mask)] = -99 rand_batch_idx = torch.randint(0, b, (1,)).item() depth_log_items = {"GT": depths_gt[rand_batch_idx], "PredictedMono": pred[rand_batch_idx]} if "masked_depth" in batch.keys(): depth_log_items["MaskedGT"] = masked_depth[rand_batch_idx] self.log_images(rgb={"Input": images[rand_batch_idx, :3, ...]}, depth=depth_log_items, prefix="Train", min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) if self.config.get("log_rel", False): self.log_images( scalar_field={"RelPred": output["relative_depth"][rand_batch_idx]}, prefix="TrainRel") self.scaler.update() self.optimizer.zero_grad() return losses @torch.no_grad() def eval_infer(self, x): with amp.autocast(enabled=self.config.use_amp): m = self.model.module if self.config.multigpu else self.model pred_depths = m(x)['metric_depth'] return pred_depths @torch.no_grad() def crop_aware_infer(self, x): # if we are not avoiding the black border, we can just use the normal inference if not self.config.get("avoid_boundary", False): return self.eval_infer(x) # otherwise, we need to crop the image to avoid the black border # For now, this may be a bit slow due to converting to numpy and back # We assume no normalization is done on the input image # get the black border assert x.shape[0] == 1, "Only batch size 1 is supported for now" x_pil = transforms.ToPILImage()(x[0].cpu()) x_np = np.array(x_pil, dtype=np.uint8) black_border_params = get_black_border(x_np) top, bottom, left, right = black_border_params.top, black_border_params.bottom, black_border_params.left, black_border_params.right x_np_cropped = x_np[top:bottom, left:right, :] x_cropped = transforms.ToTensor()(Image.fromarray(x_np_cropped)) # run inference on the cropped image pred_depths_cropped = self.eval_infer(x_cropped.unsqueeze(0).to(self.device)) # resize the prediction to x_np_cropped's size pred_depths_cropped = nn.functional.interpolate( pred_depths_cropped, size=(x_np_cropped.shape[0], x_np_cropped.shape[1]), mode="bilinear", align_corners=False) # pad the prediction back to the original size pred_depths = torch.zeros((1, 1, x_np.shape[0], x_np.shape[1]), device=pred_depths_cropped.device, dtype=pred_depths_cropped.dtype) pred_depths[:, :, top:bottom, left:right] = pred_depths_cropped return pred_depths def validate_on_batch(self, batch, val_step): images = batch['image'].to(self.device) depths_gt = batch['depth'].to(self.device) dataset = batch['dataset'][0] mask = batch["mask"].to(self.device) if 'has_valid_depth' in batch: if not batch['has_valid_depth']: return None, None depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0) mask = mask.squeeze().unsqueeze(0).unsqueeze(0) if self.config["add_depth_channel"]: images = torch.cat([images, torch.zeros_like(depths_gt), torch.zeros_like(depths_gt)], dim=1) if dataset == 'nyu': pred_depths = self.crop_aware_infer(images) else: pred_depths = self.eval_infer(images) pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0) with amp.autocast(enabled=self.config.use_amp): l_depth = self.silog_loss( pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True) metrics = compute_metrics(depths_gt, pred_depths, **self.config) losses = {f"{self.silog_loss.name}": l_depth.item()} if val_step == 1 and self.should_log: depths_gt[torch.logical_not(mask)] = -99 self.log_images(rgb={"Input": images[0, :3, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test", min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) return metrics, losses