""" | |
Configurations can be overwritten by adding: key=value | |
Use debug.wandb=False to disable logging to wandb. | |
""" | |
import datetime
import os
import random
import socket
import time
from datetime import timedelta
from glob import glob

import hydra
import ipdb  # noqa: F401
import numpy as np
import omegaconf
import torch
import wandb
from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs
from pytorch3d.renderer import PerspectiveCameras

from diffusionsfm.dataset.co3d_v2 import Co3dDataset, unnormalize_image_for_vis
# from diffusionsfm.dataset.multiloader import get_multiloader, MultiDataset
from diffusionsfm.eval.eval_category import evaluate
from diffusionsfm.model.diffuser import RayDiffuser
from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT
from diffusionsfm.model.scheduler import NoiseScheduler
from diffusionsfm.utils.rays import (
    cameras_to_rays,
    compute_ndc_coordinates,
    normalize_cameras_batch,
)
from diffusionsfm.utils.visualization import (
    create_training_visualizations,
    view_color_coded_images_from_tensor,
)

os.umask(000)  # Default to 777 permissions


class Trainer(object):
    def __init__(self, cfg):
        seed = cfg.training.seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        self.cfg = cfg
        self.debug = cfg.debug
        self.resume = cfg.training.resume
        self.pretrain_path = cfg.training.pretrain_path
        self.batch_size = cfg.training.batch_size
        self.max_iterations = cfg.training.max_iterations
        self.mixed_precision = cfg.training.mixed_precision
        self.interval_visualize = cfg.training.interval_visualize
        self.interval_save_checkpoint = cfg.training.interval_save_checkpoint
        self.interval_delete_checkpoint = cfg.training.interval_delete_checkpoint
        self.interval_evaluate = cfg.training.interval_evaluate
        self.delete_all = cfg.training.delete_all_checkpoints_after_training
        self.freeze_encoder = cfg.training.freeze_encoder
        self.translation_scale = cfg.training.translation_scale
        self.regression = cfg.training.regression
        self.prob_unconditional = cfg.training.prob_unconditional
        self.load_extra_cameras = cfg.training.load_extra_cameras
        self.calculate_intrinsics = cfg.training.calculate_intrinsics
        self.distort = cfg.training.distort
        self.diffuse_origins_and_endpoints = cfg.training.diffuse_origins_and_endpoints
        self.diffuse_depths = cfg.training.diffuse_depths
        self.depth_resolution = cfg.training.depth_resolution
        self.dpt_head = cfg.training.dpt_head
        self.full_num_patches_x = cfg.training.full_num_patches_x
        self.full_num_patches_y = cfg.training.full_num_patches_y
        self.dpt_encoder_features = cfg.training.dpt_encoder_features
        self.nearest_neighbor = cfg.training.nearest_neighbor
        self.no_bg_targets = cfg.training.no_bg_targets
        self.unit_normalize_scene = cfg.training.unit_normalize_scene
        self.sd_scale = cfg.training.sd_scale
        self.bfloat = cfg.training.bfloat
        self.first_cam_mediod = cfg.training.first_cam_mediod
        self.normalize_first_camera = cfg.training.normalize_first_camera
        self.gradient_clipping = cfg.training.gradient_clipping
        self.l1_loss = cfg.training.l1_loss
        self.reinit = cfg.training.reinit

        if self.first_cam_mediod:
            assert self.normalize_first_camera

        self.pred_x0 = cfg.model.pred_x0
        self.num_patches_x = cfg.model.num_patches_x
        self.num_patches_y = cfg.model.num_patches_y
        self.depth = cfg.model.depth
        self.num_images = cfg.model.num_images
        self.num_visualize = min(self.batch_size, 2)
        self.random_num_images = cfg.model.random_num_images
        self.feature_extractor = cfg.model.feature_extractor
        self.append_ndc = cfg.model.append_ndc
        self.use_homogeneous = cfg.model.use_homogeneous
        self.freeze_transformer = cfg.model.freeze_transformer
        self.cond_depth_mask = cfg.model.cond_depth_mask

        self.dataset_name = cfg.dataset.name
        self.shape = cfg.dataset.shape
        self.apply_augmentation = cfg.dataset.apply_augmentation
        self.mask_holes = cfg.dataset.mask_holes
        self.image_size = cfg.dataset.image_size

        if not self.regression and (
            self.diffuse_origins_and_endpoints or self.diffuse_depths
        ):
            assert self.mask_holes or self.cond_depth_mask
        if self.regression:
            assert self.pred_x0

        self.start_time = None
        self.iteration = 0
        self.epoch = 0
        self.wandb_id = None
        self.hostname = socket.gethostname()
        # DDP must tolerate parameters that receive no gradient in a step,
        # which can happen with the DPT head.
        find_unused_parameters = self.dpt_head
        ddp_scaler = DistributedDataParallelKwargs(
            find_unused_parameters=find_unused_parameters
        )
        init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))
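        # even_batches=False allows ragged final batches across processes;
        # device_placement=False because tensors are moved to self.device
        # manually throughout this file.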
        self.accelerator = Accelerator(
            even_batches=False,
            device_placement=False,
            kwargs_handlers=[ddp_scaler, init_kwargs],
        )
        self.device = self.accelerator.device

        scheduler = NoiseScheduler(
            type=cfg.noise_scheduler.type,
            max_timesteps=cfg.noise_scheduler.max_timesteps,
            beta_start=cfg.noise_scheduler.beta_start,
            beta_end=cfg.noise_scheduler.beta_end,
        )
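        # Model variant: RayDiffuserDPT adds a DPT-style dense prediction
        # head, while RayDiffuser predicts ray maps at the patch level.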
        if self.dpt_head:
            self.model = RayDiffuserDPT(
                depth=self.depth,
                width=self.num_patches_x,
                P=1,
                max_num_images=self.num_images,
                noise_scheduler=scheduler,
                freeze_encoder=self.freeze_encoder,
                feature_extractor=self.feature_extractor,
                append_ndc=self.append_ndc,
                use_unconditional=self.prob_unconditional > 0,
                diffuse_depths=self.diffuse_depths,
                depth_resolution=self.depth_resolution,
                encoder_features=self.dpt_encoder_features,
                use_homogeneous=self.use_homogeneous,
                freeze_transformer=self.freeze_transformer,
                cond_depth_mask=self.cond_depth_mask,
            ).to(self.device)
        else:
            self.model = RayDiffuser(
                depth=self.depth,
                width=self.num_patches_x,
                P=1,
                max_num_images=self.num_images,
                noise_scheduler=scheduler,
                freeze_encoder=self.freeze_encoder,
                feature_extractor=self.feature_extractor,
                append_ndc=self.append_ndc,
                use_unconditional=self.prob_unconditional > 0,
                diffuse_depths=self.diffuse_depths,
                depth_resolution=self.depth_resolution,
                use_homogeneous=self.use_homogeneous,
                cond_depth_mask=self.cond_depth_mask,
            ).to(self.device)
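        # Side length of the supervised ray/depth map: the DPT head works at
        # the full patch grid; otherwise the patch grid can be upsampled by
        # depth_resolution.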
        if self.dpt_head:
            depth_size = self.full_num_patches_x
        elif self.depth_resolution > 1:
            depth_size = self.num_patches_x * self.depth_resolution
        else:
            depth_size = self.num_patches_x
        self.depth_size = depth_size
        if self.dataset_name == "multi":
            # NOTE: requires the multiloader import that is commented out at
            # the top of this file.
            self.dataset, self.train_dataloader, self.test_dataset = get_multiloader(
                num_images=self.num_images,
                apply_augmentation=self.apply_augmentation,
                load_extra_cameras=self.load_extra_cameras,
                distort_image=self.distort,
                center_crop=self.diffuse_origins_and_endpoints or self.diffuse_depths,
                crop_images=not (
                    self.diffuse_origins_and_endpoints or self.diffuse_depths
                ),
                load_depths=self.diffuse_origins_and_endpoints or self.diffuse_depths,
                depth_size=depth_size,
                mask_holes=self.mask_holes,
                img_size=self.image_size,
                batch_size=self.batch_size,
                num_workers=cfg.training.num_workers,
                dust3r_pairs=True,
            )
        elif self.dataset_name == "co3d":
            self.dataset = Co3dDataset(
                category=self.shape,
                split="train",
                num_images=self.num_images,
                apply_augmentation=self.apply_augmentation,
                load_extra_cameras=self.load_extra_cameras,
                distort_image=self.distort,
                center_crop=self.diffuse_origins_and_endpoints or self.diffuse_depths,
                crop_images=not (
                    self.diffuse_origins_and_endpoints or self.diffuse_depths
                ),
                load_depths=self.diffuse_origins_and_endpoints or self.diffuse_depths,
                depth_size=depth_size,
                mask_holes=self.mask_holes,
                img_size=self.image_size,
            )
            self.train_dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=cfg.training.num_workers,
                pin_memory=True,
                drop_last=True,
            )
            self.test_dataset = Co3dDataset(
                category=self.shape,
                split="test",
                num_images=self.num_images,
                apply_augmentation=False,
                load_extra_cameras=self.load_extra_cameras,
                distort_image=self.distort,
                center_crop=self.diffuse_origins_and_endpoints or self.diffuse_depths,
                crop_images=not (
                    self.diffuse_origins_and_endpoints or self.diffuse_depths
                ),
                load_depths=self.diffuse_origins_and_endpoints or self.diffuse_depths,
                depth_size=depth_size,
                mask_holes=self.mask_holes,
                img_size=self.image_size,
            )
        else:
            raise NotImplementedError(f"Dataset '{self.dataset_name}' is not supported.")
        self.lr = 1e-4
        self.output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
        self.checkpoint_dir = os.path.join(self.output_dir, "checkpoints")

        if self.accelerator.is_main_process:
            name = os.path.basename(self.output_dir)
            name += f"_{self.debug.run_name}"
            print("Output dir:", self.output_dir)
            with open(os.path.join(self.output_dir, name), "w"):
                # Create empty tag with name
                pass
            self.name = name

        conf_dict = omegaconf.OmegaConf.to_container(
            cfg, resolve=True, throw_on_missing=True
        )
        conf_dict["output_dir"] = self.output_dir
        conf_dict["hostname"] = self.hostname

        if self.dpt_head:
            self.init_optimizer_with_separate_lrs()
        else:
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.gradscaler = torch.cuda.amp.GradScaler(
            growth_interval=100000, enabled=self.mixed_precision
        )

        self.model, self.optimizer, self.train_dataloader = self.accelerator.prepare(
            self.model, self.optimizer, self.train_dataloader
        )
        if self.resume:
            checkpoint_files = sorted(glob(os.path.join(self.checkpoint_dir, "*.pth")))
            last_checkpoint = checkpoint_files[-1]
            print("Resuming from checkpoint:", last_checkpoint)
            self.load_model(last_checkpoint, load_metadata=True)
        elif self.pretrain_path != "":
            print("Loading pretrained model:", self.pretrain_path)
            self.load_model(self.pretrain_path, load_metadata=False)

        if self.accelerator.is_main_process:
            mode = "online" if cfg.debug.wandb else "disabled"
            if self.wandb_id is None:
                self.wandb_id = wandb.util.generate_id()
            self.wandb_run = wandb.init(
                mode=mode,
                name=name,
                project=cfg.debug.project_name,
                config=conf_dict,
                resume=self.resume,
                id=self.wandb_id,
            )
            wandb.define_metric("iteration")
            noise_schedule = self.get_module().noise_scheduler.plot_schedule(
                return_image=True
            )
            wandb.log(
                {"Schedule": wandb.Image(noise_schedule, caption="Noise Schedule")}
            )

    def get_module(self):
        # Unwrap the DDP wrapper (if any) to reach the underlying model.
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            model = self.model.module
        else:
            model = self.model
        return model

    def init_optimizer_with_separate_lrs(self):
        print("Using separate LRs for the DINOv2 encoder and the DiT!")
        feature_extractor_params = [
            p for _, p in self.model.feature_extractor.named_parameters()
        ]
        feature_extractor_param_names = [
            "feature_extractor." + n
            for n, _ in self.model.feature_extractor.named_parameters()
        ]
        ray_predictor_params = [
            p for _, p in self.model.ray_predictor.named_parameters()
        ]
        ray_predictor_param_names = [
            "ray_predictor." + n
            for n, _ in self.model.ray_predictor.named_parameters()
        ]
        other_params = [
            p
            for n, p in self.model.named_parameters()
            if n not in feature_extractor_param_names + ray_predictor_param_names
        ]
        self.optimizer = torch.optim.Adam(
            [
                # Lower LR for the pretrained feature extractor and the DiT
                # (ray_predictor); full LR for the remaining parts.
                {"params": feature_extractor_params, "lr": self.lr * 0.1},
                {"params": ray_predictor_params, "lr": self.lr * 0.1},
                {"params": other_params, "lr": self.lr},
            ]
        )
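
    # One training step: normalize cameras, build ground-truth ray maps,
    # denoise/regress with the model, and optimize the (optionally masked)
    # L1/L2 reconstruction loss.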
    def train(self):
        while self.iteration < self.max_iterations:
            for batch in self.train_dataloader:
                t0 = time.time()
                self.optimizer.zero_grad()
                float_type = torch.bfloat16 if self.bfloat else torch.float16
                with torch.cuda.amp.autocast(
                    enabled=self.mixed_precision, dtype=float_type
                ):
                    images = batch["image"].to(self.device)
                    focal_lengths = batch["focal_length"].to(self.device)
                    crop_params = batch["crop_parameters"].to(self.device)
                    principal_points = batch["principal_point"].to(self.device)
                    R = batch["R"].to(self.device)
                    T = batch["T"].to(self.device)
                    if "distortion_coefficients" in batch:
                        distortion_coefficients = batch["distortion_coefficients"]
                    else:
                        distortion_coefficients = [None for _ in range(R.shape[0])]
                    depths = batch["depth"].to(self.device)
                    # Masks are needed both for masked losses (no_bg_targets)
                    # and for depth-mask conditioning (cond_depth_mask).
                    if self.no_bg_targets or self.cond_depth_mask:
                        masks = batch["depth_masks"].to(self.device).bool()
                    else:
                        masks = None

                    cameras_og = [
                        PerspectiveCameras(
                            focal_length=focal_lengths[b],
                            principal_point=principal_points[b],
                            R=R[b],
                            T=T[b],
                            device=self.device,
                        )
                        for b in range(self.batch_size)
                    ]
                    cameras, _ = normalize_cameras_batch(
                        cameras=cameras_og,
                        scale=self.translation_scale,
                        normalize_first_camera=self.normalize_first_camera,
                        depths=(
                            None
                            if not (
                                self.diffuse_origins_and_endpoints
                                or self.diffuse_depths
                            )
                            else depths
                        ),
                        first_cam_mediod=self.first_cam_mediod,
                        crop_parameters=crop_params,
                        num_patches_x=self.depth_size,
                        num_patches_y=self.depth_size,
                        distortion_coeffs=distortion_coefficients,
                    )

                    # Now that cameras are normalized, fix shapes of camera parameters
                    if self.load_extra_cameras or self.random_num_images:
                        if self.random_num_images:
                            num_images = torch.randint(2, self.num_images + 1, (1,))
                        else:
                            num_images = self.num_images

                        # The correct number of images is already loaded.
                        # Only need to modify these camera parameters shapes.
                        focal_lengths = focal_lengths[:, :num_images]
                        crop_params = crop_params[:, :num_images]
                        R = R[:, :num_images]
                        T = T[:, :num_images]
                        images = images[:, :num_images]
                        depths = depths[:, :num_images]
                        if masks is not None:
                            masks = masks[:, :num_images]

                        cameras = [
                            PerspectiveCameras(
                                focal_length=cameras[b].focal_length[:num_images],
                                principal_point=cameras[b].principal_point[:num_images],
                                R=cameras[b].R[:num_images],
                                T=cameras[b].T[:num_images],
                                device=self.device,
                            )
                            for b in range(self.batch_size)
                        ]
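
                    # Sample a diffusion timestep per example; regression mode
                    # always uses the final (most-noised) timestep so the model
                    # learns a direct prediction.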
                    if self.regression:
                        low = self.get_module().noise_scheduler.max_timesteps - 1
                    else:
                        low = 0
                    t = torch.randint(
                        low=low,
                        high=self.get_module().noise_scheduler.max_timesteps,
                        size=(self.batch_size,),
                        device=self.device,
                    )
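
                    # Randomly drop conditioning for a fraction of examples
                    # (classifier-free-guidance-style unconditional training).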
                    if self.prob_unconditional > 0:
                        unconditional_mask = (
                            (torch.rand(self.batch_size) < self.prob_unconditional)
                            .float()
                            .to(self.device)
                        )
                    else:
                        unconditional_mask = None

                    if self.distort:
                        raise NotImplementedError()
                    else:
                        gt_rays = []
                        rays_dirs = []
                        rays = []
                        for i, (camera, crop_param, depth) in enumerate(
                            zip(cameras, crop_params, depths)
                        ):
                            if self.diffuse_origins_and_endpoints:
                                mode = "segment"
                            else:
                                mode = "plucker"
                            r = cameras_to_rays(
                                cameras=camera,
                                num_patches_x=self.full_num_patches_x,
                                num_patches_y=self.full_num_patches_y,
                                crop_parameters=crop_param,
                                depths=depth,
                                mode=mode,
                                depth_resolution=self.depth_resolution,
                                nearest_neighbor=self.nearest_neighbor,
                                distortion_coefficients=distortion_coefficients[i],
                            )
                            rays_dirs.append(r.get_directions())
                            gt_rays.append(r)

                            if self.diffuse_origins_and_endpoints:
                                assert r.mode == "segment"
                            elif self.diffuse_depths:
                                assert r.mode == "plucker"

                            if self.unit_normalize_scene:
                                if self.diffuse_origins_and_endpoints:
                                    assert r.mode == "segment"
                                    # Normalize the scene scale using the
                                    # segment std (scaled by sd_scale).
                                    scale = r.get_segments().std() * self.sd_scale
                                    assert not scale.isnan().any()
                                    camera.T /= scale
                                    r.rays /= scale
                                    depths[i] /= scale
                                else:
                                    assert r.mode == "plucker"
                                    scale = r.depths.std() * self.sd_scale
                                    assert not scale.isnan().any()
                                    camera.T /= scale
                                    r.depths /= scale
                                    depths[i] /= scale

                            rays.append(
                                r.to_spatial(
                                    include_ndc_coordinates=self.append_ndc,
                                    include_depths=self.diffuse_depths,
                                    use_homogeneous=self.use_homogeneous,
                                )
                            )
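
                    # Stack per-example ray maps into (B, N, C, H, W); when
                    # append_ndc is set, the last two channels hold NDC
                    # coordinates and are split off below.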
                    rays_tensor = torch.stack(rays, dim=0)
                    if self.append_ndc:
                        ndc_coordinates = rays_tensor[..., -2:, :, :]
                        rays_tensor = rays_tensor[..., :-2, :, :]
                        if self.dpt_head:
                            xy_grid = compute_ndc_coordinates(
                                crop_params,
                                num_patches_x=self.depth_size // 16,
                                num_patches_y=self.depth_size // 16,
                                distortion_coeffs=distortion_coefficients,
                            )[..., :2]
                            ndc_coordinates = xy_grid.permute(0, 1, 4, 2, 3).contiguous()
                    else:
                        ndc_coordinates = None

                    if self.cond_depth_mask:
                        condition_mask = masks
                    else:
                        condition_mask = None

                    if rays_tensor.isnan().any():
                        # Dump the offending batch for offline debugging.
                        import pickle

                        with open("bad.pkl", "wb") as f:
                            pickle.dump(batch, f)
                        ipdb.set_trace()
                    eps_pred, eps = self.model(
                        images=images,
                        rays=rays_tensor,
                        t=t,
                        ndc_coordinates=ndc_coordinates,
                        unconditional_mask=unconditional_mask,
                        depth_mask=condition_mask,
                    )
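
                    # Diffusion target: the clean ray map when predicting x0,
                    # otherwise the injected noise.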
                    if self.pred_x0:
                        target = rays_tensor
                    else:
                        target = eps

                    if self.no_bg_targets:
                        # Zero out background pixels in both prediction and
                        # target so they do not contribute to the loss.
                        C = eps_pred.shape[2]
                        loss_masks = masks.unsqueeze(2).repeat(1, 1, C, 1, 1)
                        eps_pred = loss_masks * eps_pred
                        target = loss_masks * target

                    loss = 0
                    if self.l1_loss:
                        loss_reconstruction = torch.mean(torch.abs(eps_pred - target))
                    else:
                        loss_reconstruction = torch.mean((eps_pred - target) ** 2)
                    loss += loss_reconstruction
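
                # Track the gradient norm at three stages of the mixed-precision
                # step: after backward (scaler-scaled), after clipping (applied
                # before unscale_, i.e. to still-scaled gradients), and after
                # unscaling. All three default to 0.0 so log_info also works on
                # the plain-precision path, which does not compute them.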
                scaled_norm = unscaled_norm = clipped_norm = 0.0
                if self.mixed_precision:
                    self.gradscaler.scale(loss).backward()

                    for p in self.model.parameters():
                        if p.requires_grad and p.grad is not None:
                            scaled_norm += p.grad.data.norm(2).item() ** 2
                    scaled_norm = scaled_norm ** 0.5

                    if self.gradient_clipping and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.get_module().parameters(), 1
                        )

                    for p in self.model.parameters():
                        if p.requires_grad and p.grad is not None:
                            clipped_norm += p.grad.data.norm(2).item() ** 2
                    clipped_norm = clipped_norm ** 0.5

                    self.gradscaler.unscale_(self.optimizer)

                    for p in self.model.parameters():
                        if p.requires_grad and p.grad is not None:
                            unscaled_norm += p.grad.data.norm(2).item() ** 2
                    unscaled_norm = unscaled_norm ** 0.5

                    self.gradscaler.step(self.optimizer)
                    self.gradscaler.update()
                else:
                    self.accelerator.backward(loss)
                    if self.gradient_clipping and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.get_module().parameters(), 10
                        )
                    self.optimizer.step()

                if self.accelerator.is_main_process:
                    if self.iteration % 10 == 0:
                        self.log_info(
                            loss_reconstruction,
                            t0,
                            self.lr,
                            scaled_norm,
                            unscaled_norm,
                            clipped_norm,
                        )

                    if self.iteration % self.interval_visualize == 0:
                        self.visualize(
                            images=unnormalize_image_for_vis(images.clone()),
                            cameras_gt=cameras,
                            depths=depths,
                            crop_parameters=crop_params,
                            distortion_coefficients=distortion_coefficients,
                            depth_mask=masks,
                        )

                    if (
                        self.iteration % self.interval_save_checkpoint == 0
                        and self.iteration != 0
                    ):
                        self.save_model()

                    if self.iteration % self.interval_delete_checkpoint == 0:
                        self.clear_old_checkpoints(self.checkpoint_dir)

                    if (
                        self.iteration % self.interval_evaluate == 0
                        and self.iteration > 0
                    ):
                        self.evaluate_train_acc()

                if self.iteration >= self.max_iterations + 1:
                    if self.delete_all:
                        self.clear_old_checkpoints(
                            self.checkpoint_dir, clear_all_old=True
                        )
                    return

                self.iteration += 1

                if self.reinit and self.iteration >= 50000:
                    # Re-instantiate the DPT model with the encoder and
                    # transformer unfrozen, then restore the current weights.
                    state_dict = self.get_module().state_dict()
                    self.model = RayDiffuserDPT(
                        depth=self.depth,
                        width=self.num_patches_x,
                        P=1,
                        max_num_images=self.num_images,
                        noise_scheduler=self.get_module().noise_scheduler,
                        freeze_encoder=False,
                        feature_extractor=self.feature_extractor,
                        append_ndc=self.append_ndc,
                        use_unconditional=self.prob_unconditional > 0,
                        diffuse_depths=self.diffuse_depths,
                        depth_resolution=self.depth_resolution,
                        encoder_features=self.dpt_encoder_features,
                        use_homogeneous=self.use_homogeneous,
                        freeze_transformer=False,
                        cond_depth_mask=self.cond_depth_mask,
                    ).to(self.device)
                    self.init_optimizer_with_separate_lrs()
                    self.gradscaler = torch.cuda.amp.GradScaler(
                        growth_interval=100000, enabled=self.mixed_precision
                    )
                    self.model, self.optimizer = self.accelerator.prepare(
                        self.model, self.optimizer
                    )
                    msg = self.get_module().load_state_dict(
                        state_dict,
                        strict=True,
                    )
                    print(msg)
                    self.reinit = False

            self.epoch += 1

    def load_model(self, path, load_metadata=True):
        save_dict = torch.load(path, map_location=self.device)
        # Always re-create the image positional-encoding table instead of
        # loading it from the checkpoint (the load below is non-strict).
        del save_dict["state_dict"]["ray_predictor.x_pos_enc.image_pos_table"]

        if not self.resume:
            if (
                len(save_dict["state_dict"]["scratch.input_conv.weight"].shape) == 2
                and self.dpt_head
            ):
                print("Initializing conv layer weights from the linear layer!")
                C = save_dict["state_dict"]["scratch.input_conv.weight"].shape[1]
                input_conv_weight = (
                    save_dict["state_dict"]["scratch.input_conv.weight"]
                    .view(384, C, 1, 1)
                    .repeat(1, 1, 16, 16)
                    / 256.0
                )
                input_conv_bias = save_dict["state_dict"]["scratch.input_conv.bias"]
                self.get_module().scratch.input_conv.weight.data = input_conv_weight
                self.get_module().scratch.input_conv.bias.data = input_conv_bias
                del save_dict["state_dict"]["scratch.input_conv.weight"]
                del save_dict["state_dict"]["scratch.input_conv.bias"]

        missing, unexpected = self.get_module().load_state_dict(
            save_dict["state_dict"],
            strict=False,
        )
        print(f"Missing keys: {missing}")
        print(f"Unexpected keys: {unexpected}")

        if load_metadata:
            self.iteration = save_dict["iteration"]
            self.epoch = save_dict["epoch"]
            time_elapsed = save_dict["elapsed"]
            self.start_time = time.time() - time_elapsed
            if "wandb_id" in save_dict:
                self.wandb_id = save_dict["wandb_id"]
            self.optimizer.load_state_dict(save_dict["optimizer"])
            self.gradscaler.load_state_dict(save_dict["gradscaler"])

    def save_model(self):
        path = os.path.join(self.checkpoint_dir, f"ckpt_{self.iteration:08d}.pth")
        os.makedirs(os.path.dirname(path), exist_ok=True)
        elapsed = time.time() - self.start_time if self.start_time is not None else 0
        save_dict = {
            "epoch": self.epoch,
            "elapsed": elapsed,
            "gradscaler": self.gradscaler.state_dict(),
            "iteration": self.iteration,
            "state_dict": self.get_module().state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "wandb_id": self.wandb_id,
        }
        torch.save(save_dict, path)
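
    # Retention policy: clear_all_old keeps only the most recent checkpoint;
    # otherwise only checkpoints saved at multiples of
    # interval_delete_checkpoint are kept.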
    def clear_old_checkpoints(self, checkpoint_dir, clear_all_old=False):
        print("Clearing old checkpoints")
        checkpoint_files = sorted(glob(os.path.join(checkpoint_dir, "ckpt_*.pth")))
        if clear_all_old:
            for checkpoint_file in checkpoint_files[:-1]:
                os.remove(checkpoint_file)
        else:
            for checkpoint_file in checkpoint_files:
                checkpoint = os.path.basename(checkpoint_file)
                checkpoint_iteration = int("".join(filter(str.isdigit, checkpoint)))
                if checkpoint_iteration % self.interval_delete_checkpoint != 0:
                    os.remove(checkpoint_file)

    def log_info(
        self,
        loss,
        t0,
        lr,
        scaled_norm,
        unscaled_norm,
        clipped_norm,
    ):
        if self.start_time is None:
            self.start_time = time.time()
        time_elapsed = round(time.time() - self.start_time)
        time_remaining = round(
            (time.time() - self.start_time)
            / (self.iteration + 1)
            * (self.max_iterations - self.iteration)
        )
        disp = [
            f"Iter: {self.iteration}/{self.max_iterations}",
            f"Epoch: {self.epoch}",
            f"Loss: {loss.item():.4f}",
            f"LR: {lr:.7f}",
            f"Grad Norm: {scaled_norm:.4f}/{unscaled_norm:.4f}/{clipped_norm:.4f}",
            f"Elap: {str(datetime.timedelta(seconds=time_elapsed))}",
            f"Rem: {str(datetime.timedelta(seconds=time_remaining))}",
            self.hostname,
            self.name,
        ]
        print(", ".join(disp), flush=True)
        wandb_log = {
            "loss": loss.item(),
            "iter_time": time.time() - t0,
            "lr": lr,
            "iteration": self.iteration,
            "hours_remaining": time_remaining / 3600,
            "gradient norm": scaled_norm,
            "unscaled norm": unscaled_norm,
            "clipped norm": clipped_norm,
        }
        wandb.log(wandb_log)

    def visualize(
        self,
        images,
        cameras_gt,
        crop_parameters=None,
        depths=None,
        distortion_coefficients=None,
        depth_mask=None,
        high_loss=False,
    ):
        self.get_module().eval()

        for camera in cameras_gt:
            # AMP may not cast back to float
            camera.R = camera.R.float()
            camera.T = camera.T.float()

        loss_tag = "" if not high_loss else " HIGH LOSS"
        for i in range(self.num_visualize):
            imgs = view_color_coded_images_from_tensor(images[i].cpu(), depth=False)
            im = wandb.Image(imgs, caption=f"iteration {self.iteration} example {i}")
            wandb.log({f"Vis images {i}{loss_tag}": im})

            if self.cond_depth_mask:
                imgs = view_color_coded_images_from_tensor(
                    depth_mask[i].cpu(), depth=True
                )
                im = wandb.Image(
                    imgs, caption=f"iteration {self.iteration} example {i}"
                )
                wandb.log({f"Vis masks {i}{loss_tag}": im})

        vis_depths, _, _ = create_training_visualizations(
            model=self.get_module(),
            images=images[: self.num_visualize],
            device=self.device,
            cameras_gt=cameras_gt,
            pred_x0=self.pred_x0,
            num_images=images.shape[1],
            crop_parameters=crop_parameters[: self.num_visualize],
            visualize_pred=self.regression,
            return_first=self.regression,
            calculate_intrinsics=self.calculate_intrinsics,
            mode="segment" if self.diffuse_origins_and_endpoints else "plucker",
            depths=depths[: self.num_visualize],
            diffuse_depths=self.diffuse_depths,
            full_num_patches_x=self.full_num_patches_x,
            full_num_patches_y=self.full_num_patches_y,
            use_homogeneous=self.use_homogeneous,
            distortion_coefficients=distortion_coefficients,
        )

        for i, vis_image in enumerate(vis_depths):
            im = wandb.Image(
                vis_image, caption=f"iteration {self.iteration} example {i}"
            )
            wandb.log({f"Vis origins and endpoints {i}{loss_tag}": im})

        self.get_module().train()
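
    # Periodically evaluate rotation and camera-center accuracy on a few
    # train/test sequences and log the results to wandb.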
    def evaluate_train_acc(self, num_evaluate=10):
        print("Evaluating train accuracy")
        model = self.get_module()
        model.eval()

        additional_timesteps = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
        num_images = self.num_images

        for split in ["train", "test"]:
            if split == "train":
                if self.dataset_name != "co3d":
                    to_evaluate = self.dataset.datasets
                    names = self.dataset.names
                else:
                    to_evaluate = [self.dataset]
                    names = ["co3d"]
            elif split == "test":
                if self.dataset_name != "co3d":
                    to_evaluate = self.test_dataset.datasets
                    names = self.test_dataset.names
                else:
                    to_evaluate = [self.test_dataset]
                    names = ["co3d"]

            for name, dataset in zip(names, to_evaluate):
                results = evaluate(
                    cfg=self.cfg,
                    model=model,
                    dataset=dataset,
                    num_images=num_images,
                    device=self.device,
                    additional_timesteps=additional_timesteps,
                    num_evaluate=num_evaluate,
                    use_pbar=True,
                    mode="segment" if self.diffuse_origins_and_endpoints else "plucker",
                    metrics=False,
                )

                R_err = []
                CC_err = []
                for key in results.keys():
                    R_err.append([v["R_error"] for v in results[key]])
                    CC_err.append([v["CC_error"] for v in results[key]])
                R_err = np.array(R_err)
                CC_err = np.array(CC_err)

                # Accuracy thresholds: 15 degrees for rotations, 0.1 for
                # camera centers.
                R_acc_15 = np.mean(R_err < 15, (0, 2)).max()
                CC_acc = np.mean(CC_err < 0.1, (0, 2)).max()

                wandb.log(
                    {
                        f"R_acc_15_{name}_{split}": R_acc_15,
                        "iteration": self.iteration,
                    }
                )
                wandb.log(
                    {
                        f"CC_acc_0.1_{name}_{split}": CC_acc,
                        "iteration": self.iteration,
                    }
                )
        model.train()


# NOTE: the Hydra decorator is required for the bare main() call below to
# work; the config_path/config_name values here are assumed and should match
# the repo's configuration directory.
@hydra.main(config_path="./conf", config_name="config")
def main(cfg):
    print(cfg)
    torch.autograd.set_detect_anomaly(cfg.debug.anomaly_detection)
    torch.set_float32_matmul_precision(cfg.training.matmul_precision)
    trainer = Trainer(cfg=cfg)
    trainer.train()


if __name__ == "__main__":
    main()