import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'  # cv2 requires this flag to read/write the .exr point maps saved below
from pathlib import Path
import sys
if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
    sys.path.insert(0, _package_root)
import json
import time
import random
from typing import *
import itertools
from contextlib import nullcontext
from concurrent.futures import ThreadPoolExecutor
import io

import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.version
import accelerate
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import set_seed
import utils3d
import click
from tqdm import tqdm, trange
import mlflow

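# Input sizes vary between iterations, so cudnn benchmark autotuning is left disabled.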
torch.backends.cudnn.benchmark = False

from moge.train.dataloader import TrainDataLoaderPipeline
from moge.train.losses import (
    affine_invariant_global_loss,
    affine_invariant_local_loss,
    edge_loss,
    normal_loss,
    mask_l2_loss,
    mask_bce_loss,
    monitoring,
)
from moge.train.utils import build_optimizer, build_lr_scheduler
from moge.utils.geometry_torch import intrinsics_to_fov
from moge.utils.vis import colorize_depth, colorize_normal
from moge.utils.tools import key_average, recursive_replace, CallbackOnException, flatten_nested_dict
from moge.test.metrics import compute_metrics

@click.command()
@click.option('--config', 'config_path', type=str, default='configs/debug.json')
@click.option('--workspace', type=str, default='workspace/debug', help='Path to the workspace')
@click.option('--checkpoint', 'checkpoint_path', type=str, default=None, help='Checkpoint to load: a .pt file path, a step number in the workspace, or "latest"')
@click.option('--batch_size_forward', type=int, default=8, help='Batch size for each forward pass on each device')
@click.option('--gradient_accumulation_steps', type=int, default=1, help='Number of steps to accumulate gradients')
@click.option('--enable_gradient_checkpointing', type=bool, default=True, help='Use gradient checkpointing in the backbone')
@click.option('--enable_mixed_precision', type=bool, default=False, help='Use mixed precision training. The backbone is converted to FP16')
@click.option('--enable_ema', type=bool, default=True, help='Maintain an exponential moving average of the model weights')
@click.option('--num_iterations', type=int, default=1000000, help='Number of iterations to train the model')
@click.option('--save_every', type=int, default=10000, help='Save a checkpoint every n iterations')
@click.option('--log_every', type=int, default=1000, help='Log metrics every n iterations')
@click.option('--vis_every', type=int, default=0, help='Visualize every n iterations (0 disables visualization)')
@click.option('--num_vis_images', type=int, default=32, help='Number of images to visualize; rounded down to a multiple of the forward batch size')
@click.option('--enable_mlflow', type=bool, default=True, help='Log metrics to MLFlow')
@click.option('--seed', type=int, default=0, help='Random seed')
def main(
    config_path: str,
    workspace: str,
    checkpoint_path: str,
    batch_size_forward: int,
    gradient_accumulation_steps: int,
    enable_gradient_checkpointing: bool,
    enable_mixed_precision: bool,
    enable_ema: bool,
    num_iterations: int,
    save_every: int,
    log_every: int,
    vis_every: int,
    num_vis_images: int,
    enable_mlflow: bool,
    seed: Optional[int],
):
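    # Load config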
    with open(config_path, 'r') as f:
        config = json.load(f)

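    # Initialize accelerator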
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision='fp16' if enable_mixed_precision else None,
        kwargs_handlers=[
            DistributedDataParallelKwargs(find_unused_parameters=True)
        ]
    )
    device = accelerator.device
    batch_size_total = batch_size_forward * gradient_accumulation_steps * accelerator.num_processes

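    # Log hyperparameters to MLFlow and save the config to the workspace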
    if accelerator.is_main_process:
        if enable_mlflow:
            try:
                mlflow.log_params({
                    **click.get_current_context().params,
                    'batch_size_total': batch_size_total,
                })
            except Exception:
                print('Failed to log config to MLFlow')
        Path(workspace).mkdir(parents=True, exist_ok=True)
        with Path(workspace).joinpath('config.json').open('w') as f:
            json.dump(config, f, indent=4)

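    # Set the random seed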
    if seed is not None:
        set_seed(seed, device_specific=True)

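    # Initialize the model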
    print('Initialize model')
    with accelerator.local_main_process_first():
        from moge.model import import_model_class_by_version
        MoGeModel = import_model_class_by_version(config['model_version'])
        model = MoGeModel(**config['model'])
    count_total_parameters = sum(p.numel() for p in model.parameters())
    print(f'Total parameters: {count_total_parameters}')

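    # Maintain an exponential moving average (EMA) of the model weights on the main process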
    if enable_ema and accelerator.is_main_process:
        ema_avg_fn = lambda averaged_model_parameter, model_parameter, num_averaged: 0.999 * averaged_model_parameter + 0.001 * model_parameter
        ema_model = torch.optim.swa_utils.AveragedModel(model, device=accelerator.device, avg_fn=ema_avg_fn)

    if enable_gradient_checkpointing:
        model.enable_gradient_checkpointing()
        import warnings
        warnings.filterwarnings("ignore", category=FutureWarning, module="torch.utils.checkpoint")

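    # Build the optimizer and learning-rate scheduler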
    optimizer = build_optimizer(model, config['optimizer'])
    lr_scheduler = build_lr_scheduler(optimizer, config['lr_scheduler'])

    count_grouped_parameters = [sum(p.numel() for p in param_group['params'] if p.requires_grad) for param_group in optimizer.param_groups]
    for i, count in enumerate(count_grouped_parameters):
        print(f'- Group {i}: {count} parameters')

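    # Load a checkpoint if one is specified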
    checkpoint: Dict[str, Any]
    with accelerator.local_main_process_first():
        if checkpoint_path is None:
            checkpoint = None
        elif checkpoint_path.endswith('.pt'):
            # Load the checkpoint file at the given path
            print(f'Load checkpoint: {checkpoint_path}')
            checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
        elif checkpoint_path == "latest":
            # Load the latest checkpoint recorded in the workspace
            checkpoint_path = Path(workspace, 'checkpoint', 'latest.pt')
            if checkpoint_path.exists():
                print(f'Load checkpoint: {checkpoint_path}')
                checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
                i_step = checkpoint['step']
                if 'model' not in checkpoint and (checkpoint_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}.pt')).exists():
                    print(f'Load model checkpoint: {checkpoint_model_path}')
                    checkpoint['model'] = torch.load(checkpoint_model_path, map_location='cpu', weights_only=True)['model']
                if 'optimizer' not in checkpoint and (checkpoint_optimizer_path := Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt')).exists():
                    print(f'Load optimizer checkpoint: {checkpoint_optimizer_path}')
                    checkpoint.update(torch.load(checkpoint_optimizer_path, map_location='cpu', weights_only=True))
                if enable_ema and accelerator.is_main_process:
                    if 'ema_model' not in checkpoint and (checkpoint_ema_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt')).exists():
                        print(f'Load EMA model checkpoint: {checkpoint_ema_model_path}')
                        checkpoint['ema_model'] = torch.load(checkpoint_ema_model_path, map_location='cpu', weights_only=True)['model']
            else:
                checkpoint = None
        else:
            # Load a checkpoint from the workspace by its step number
            i_step = int(checkpoint_path)
            checkpoint = {'step': i_step}
            if (checkpoint_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}.pt')).exists():
                print(f'Load model checkpoint: {checkpoint_model_path}')
                checkpoint['model'] = torch.load(checkpoint_model_path, map_location='cpu', weights_only=True)['model']
            if (checkpoint_optimizer_path := Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt')).exists():
                print(f'Load optimizer checkpoint: {checkpoint_optimizer_path}')
                checkpoint.update(torch.load(checkpoint_optimizer_path, map_location='cpu', weights_only=True))
            if enable_ema and accelerator.is_main_process:
                if (checkpoint_ema_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt')).exists():
                    print(f'Load EMA model checkpoint: {checkpoint_ema_model_path}')
                    checkpoint['ema_model'] = torch.load(checkpoint_ema_model_path, map_location='cpu', weights_only=True)['model']

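    # Initialize the weights from scratch, or restore model/optimizer/EMA/scheduler states from the checkpoint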
    if checkpoint is None:
        print('Initialize model weights')
        with accelerator.local_main_process_first():
            model.init_weights()
        initial_step = 0
    else:
        model.load_state_dict(checkpoint['model'], strict=False)
        if 'step' in checkpoint:
            initial_step = checkpoint['step'] + 1
        else:
            initial_step = 0
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        if enable_ema and accelerator.is_main_process and 'ema_model' in checkpoint:
            ema_model.module.load_state_dict(checkpoint['ema_model'], strict=False)
        if 'lr_scheduler' in checkpoint:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

    del checkpoint

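    # Prepare the model and optimizer for (distributed) training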
    model, optimizer = accelerator.prepare(model, optimizer)
    if torch.version.hip and isinstance(model, torch.nn.parallel.DistributedDataParallel):
        # On ROCm (HIP) builds, register a custom DDP communication hook for gradient synchronization
        from moge.model.utils import sync_ddp_hook
        model.register_comm_hook(None, sync_ddp_hook)

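    # Initialize the training data pipeline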
    with accelerator.local_main_process_first():
        train_data_pipe = TrainDataLoaderPipeline(config['data'], batch_size_forward)

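    # Helper that keeps retrying a checkpoint write so transient I/O errors do not lose it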
    def _write_bytes_retry_loop(save_path: Path, data: bytes):
        while True:
            try:
                save_path.write_bytes(data)
                break
            except Exception as e:
                print('Error while saving checkpoint, retrying in 1 minute: ', e)
                time.sleep(60)

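    # Training loop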
    records = []
    model.train()
    with (
        train_data_pipe,
        tqdm(initial=initial_step, total=num_iterations, desc='Training', disable=not accelerator.is_main_process) as pbar,
        ThreadPoolExecutor(max_workers=1) as save_checkpoint_executor,
    ):
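        # Prefetch a fixed set of batches on the main process for visualization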
        if accelerator.is_main_process:
            batches_for_vis: List[Dict[str, torch.Tensor]] = []
            num_vis_images = num_vis_images // batch_size_forward * batch_size_forward
            for _ in range(num_vis_images // batch_size_forward):
                batch = train_data_pipe.get()
                batches_for_vis.append(batch)

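        # Visualize the ground truth of the prefetched batches (only when starting from step 0)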
        if vis_every > 0 and accelerator.is_main_process and initial_step == 0:
            save_dir = Path(workspace).joinpath('vis/gt')
            for i_batch, batch in enumerate(tqdm(batches_for_vis, desc='Visualize GT', leave=False)):
                image, gt_depth, gt_mask, gt_mask_inf, gt_intrinsics, info = batch['image'], batch['depth'], batch['depth_mask'], batch['depth_mask_inf'], batch['intrinsics'], batch['info']
                gt_points = utils3d.torch.depth_to_points(gt_depth, intrinsics=gt_intrinsics)
                gt_normal, gt_normal_mask = utils3d.torch.points_to_normals(gt_points, gt_mask)
                for i_instance in range(batch['image'].shape[0]):
                    idx = i_batch * batch_size_forward + i_instance
                    image_i = (image[i_instance].numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
                    gt_depth_i = gt_depth[i_instance].numpy()
                    gt_mask_i = gt_mask[i_instance].numpy()
                    gt_mask_inf_i = gt_mask_inf[i_instance].numpy()
                    gt_points_i = gt_points[i_instance].numpy()
                    gt_normal_i = gt_normal[i_instance].numpy()
                    save_dir.joinpath(f'{idx:04d}').mkdir(parents=True, exist_ok=True)
                    cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/image.jpg')), cv2.cvtColor(image_i, cv2.COLOR_RGB2BGR))
                    cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/points.exr')), cv2.cvtColor(gt_points_i, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
                    cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask.png')), gt_mask_i * 255)
                    cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/depth_vis.png')), cv2.cvtColor(colorize_depth(gt_depth_i, gt_mask_i), cv2.COLOR_RGB2BGR))
                    cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/normal.png')), cv2.cvtColor(colorize_normal(gt_normal_i), cv2.COLOR_RGB2BGR))
                    cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask_inf.png')), gt_mask_inf_i * 255)
                    with save_dir.joinpath(f'{idx:04d}/info.json').open('w') as f:
                        json.dump(info[i_instance], f)

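        # Re-seed with the starting step so that a resumed run continues with a different random stream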
        if seed is not None:
            set_seed(seed + initial_step, device_specific=True)

        for i_step in range(initial_step, num_iterations):
            # Accumulate gradients over several forward passes per iteration
            i_accumulate, weight_accumulate = 0, 0
            while i_accumulate < gradient_accumulation_steps:
                batch = train_data_pipe.get()
                image, gt_depth, gt_mask, gt_mask_fin, gt_mask_inf, gt_intrinsics, label_type, is_metric = batch['image'], batch['depth'], batch['depth_mask'], batch['depth_mask_fin'], batch['depth_mask_inf'], batch['intrinsics'], batch['label_type'], batch['is_metric']
                image, gt_depth, gt_mask, gt_mask_fin, gt_mask_inf, gt_intrinsics = image.to(device), gt_depth.to(device), gt_mask.to(device), gt_mask_fin.to(device), gt_mask_inf.to(device), gt_intrinsics.to(device)
                current_batch_size = image.shape[0]
                # Skip batches in which every sample is labeled invalid
                if all(label == 'invalid' for label in label_type):
                    continue

                gt_points = utils3d.torch.depth_to_points(gt_depth, intrinsics=gt_intrinsics)
                gt_focal = 1 / (1 / gt_intrinsics[..., 0, 0] ** 2 + 1 / gt_intrinsics[..., 1, 1] ** 2) ** 0.5

                with accelerator.accumulate(model):
                    # Forward pass. The number of image tokens (which sets the effective input resolution) is
                    # sampled per step and broadcast from the main process so that all processes use the same value.
                    if i_step <= config.get('low_resolution_training_steps', 0):
                        num_tokens = config['model']['num_tokens_range'][0]
                    else:
                        num_tokens = accelerate.utils.broadcast_object_list([random.randint(*config['model']['num_tokens_range'])])[0]
                    with torch.autocast(device_type=accelerator.device.type, dtype=torch.float16, enabled=enable_mixed_precision):
                        output = model(image, num_tokens=num_tokens)
                    pred_points, pred_mask, pred_metric_scale = output['points'], output['mask'], output.get('metric_scale', None)

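                    # Compute the loss for each instance according to the loss configuration of its label type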
                    loss_list, weight_list = [], []
                    for i in range(current_batch_size):
                        gt_metric_scale = None
                        loss_dict, weight_dict, misc_dict = {}, {}, {}
                        misc_dict['monitoring'] = monitoring(pred_points[i])
                        for k, v in config['loss'][label_type[i]].items():
                            weight_dict[k] = v['weight']
                            if v['function'] == 'affine_invariant_global_loss':
                                loss_dict[k], misc_dict[k], gt_metric_scale = affine_invariant_global_loss(pred_points[i], gt_points[i], gt_mask[i], **v['params'])
                            elif v['function'] == 'affine_invariant_local_loss':
                                loss_dict[k], misc_dict[k] = affine_invariant_local_loss(pred_points[i], gt_points[i], gt_mask[i], gt_focal[i], gt_metric_scale, **v['params'])
                            elif v['function'] == 'normal_loss':
                                loss_dict[k], misc_dict[k] = normal_loss(pred_points[i], gt_points[i], gt_mask[i])
                            elif v['function'] == 'edge_loss':
                                loss_dict[k], misc_dict[k] = edge_loss(pred_points[i], gt_points[i], gt_mask[i])
                            elif v['function'] == 'mask_bce_loss':
                                loss_dict[k], misc_dict[k] = mask_bce_loss(pred_mask[i], gt_mask_fin[i], gt_mask_inf[i])
                            elif v['function'] == 'mask_l2_loss':
                                loss_dict[k], misc_dict[k] = mask_l2_loss(pred_mask[i], gt_mask_fin[i], gt_mask_inf[i])
                            else:
                                raise ValueError(f'Undefined loss function: {v["function"]}')
                        weight_dict = {'.'.join(k): v for k, v in flatten_nested_dict(weight_dict).items()}
                        loss_dict = {'.'.join(k): v for k, v in flatten_nested_dict(loss_dict).items()}
                        loss_ = sum([weight_dict[k] * loss_dict[k] for k in loss_dict], start=torch.tensor(0.0, device=device))
                        loss_list.append(loss_)

                        if torch.isnan(loss_).item():
                            pbar.write(f'NaN loss in process {accelerator.process_index}')
                            pbar.write(str(loss_dict))

                        misc_dict = {'.'.join(k): v for k, v in flatten_nested_dict(misc_dict).items()}
                        records.append({
                            **{k: v.item() for k, v in loss_dict.items()},
                            **misc_dict,
                        })

                    loss = sum(loss_list) / len(loss_list)

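                    # Backward pass and optimizer step (gradients are synchronized only when the accumulation cycle completes)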
                    accelerator.backward(loss)
                    if accelerator.sync_gradients:
                        # Skip the update if any gradient is NaN (only checked in full precision)
                        if not enable_mixed_precision and any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None):
                            if accelerator.is_main_process:
                                pbar.write('NaN gradients, skip update')
                            optimizer.zero_grad()
                            continue
                        accelerator.clip_grad_norm_(model.parameters(), 1.0)

                    optimizer.step()
                    optimizer.zero_grad()

                i_accumulate += 1

            lr_scheduler.step()

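            # Update the EMA weights after each completed optimizer update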
            if enable_ema and accelerator.is_main_process and accelerator.sync_gradients:
                ema_model.update_parameters(model)

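            # Aggregate the recorded losses across processes and log them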
            if i_step == initial_step or i_step % log_every == 0:
                records = [key_average(records)]
                records = accelerator.gather_for_metrics(records, use_gather_object=True)
                if accelerator.is_main_process:
                    records = key_average(records)
                    if enable_mlflow:
                        try:
                            mlflow.log_metrics(records, step=i_step)
                        except Exception as e:
                            print(f'Error while logging metrics to mlflow: {e}')
                records = []

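            # Serialize checkpoints in memory and hand them to the background writer thread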
            if accelerator.is_main_process and (i_step % save_every == 0):
                pbar.write(f'Save checkpoint: {i_step:08d}')
                Path(workspace, 'checkpoint').mkdir(parents=True, exist_ok=True)

                # Model weights
                with io.BytesIO() as f:
                    torch.save({
                        'model_config': config['model'],
                        'model': accelerator.unwrap_model(model).state_dict(),
                    }, f)
                    checkpoint_bytes = f.getvalue()
                save_checkpoint_executor.submit(
                    _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}.pt'), checkpoint_bytes
                )

                # Optimizer and learning-rate scheduler state
                with io.BytesIO() as f:
                    torch.save({
                        'model_config': config['model'],
                        'step': i_step,
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                    }, f)
                    checkpoint_bytes = f.getvalue()
                save_checkpoint_executor.submit(
                    _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt'), checkpoint_bytes
                )

                # EMA model weights
                if enable_ema:
                    with io.BytesIO() as f:
                        torch.save({
                            'model_config': config['model'],
                            'model': ema_model.module.state_dict(),
                        }, f)
                        checkpoint_bytes = f.getvalue()
                    save_checkpoint_executor.submit(
                        _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt'), checkpoint_bytes
                    )

                # Record of the latest saved step (the weights live in the per-step files above)
                with io.BytesIO() as f:
                    torch.save({
                        'model_config': config['model'],
                        'step': i_step,
                    }, f)
                    checkpoint_bytes = f.getvalue()
                save_checkpoint_executor.submit(
                    _write_bytes_retry_loop, Path(workspace, 'checkpoint', 'latest.pt'), checkpoint_bytes
                )

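            # Periodically visualize predictions on the prefetched batches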
            if vis_every > 0 and accelerator.is_main_process and (i_step == initial_step or i_step % vis_every == 0):
                unwrapped_model = accelerator.unwrap_model(model)
                save_dir = Path(workspace).joinpath(f'vis/step_{i_step:08d}')
                save_dir.mkdir(parents=True, exist_ok=True)
                with torch.inference_mode():
                    for i_batch, batch in enumerate(tqdm(batches_for_vis, desc=f'Visualize: {i_step:08d}', leave=False)):
                        image, gt_depth, gt_mask, gt_intrinsics = batch['image'], batch['depth'], batch['depth_mask'], batch['intrinsics']
                        image, gt_depth, gt_mask, gt_intrinsics = image.to(device), gt_depth.to(device), gt_mask.to(device), gt_intrinsics.to(device)

                        output = unwrapped_model.infer(image)
                        pred_points, pred_depth, pred_mask = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy()
                        image = image.cpu().numpy()

                        for i_instance in range(image.shape[0]):
                            idx = i_batch * batch_size_forward + i_instance
                            image_i = (image[i_instance].transpose(1, 2, 0) * 255).astype(np.uint8)
                            pred_points_i = pred_points[i_instance]
                            pred_mask_i = pred_mask[i_instance]
                            pred_depth_i = pred_depth[i_instance]
                            save_dir.joinpath(f'{idx:04d}').mkdir(parents=True, exist_ok=True)
                            cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/image.jpg')), cv2.cvtColor(image_i, cv2.COLOR_RGB2BGR))
                            cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/points.exr')), cv2.cvtColor(pred_points_i, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
                            cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask.png')), pred_mask_i * 255)
                            cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/depth_vis.png')), cv2.cvtColor(colorize_depth(pred_depth_i, pred_mask_i), cv2.COLOR_RGB2BGR))

            pbar.set_postfix({'loss': loss.item()}, refresh=False)
            pbar.update(1)

if __name__ == '__main__':
    main()