Spaces:
Runtime error
Runtime error
| # MIT License | |
| # Copyright (c) 2022 Intelligent Systems Lab Org | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy | |
| # of this software and associated documentation files (the "Software"), to deal | |
| # in the Software without restriction, including without limitation the rights | |
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| # copies of the Software, and to permit persons to whom the Software is | |
| # furnished to do so, subject to the following conditions: | |
| # The above copyright notice and this permission notice shall be included in all | |
| # copies or substantial portions of the Software. | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| # SOFTWARE. | |
| # File author: Shariq Farooq Bhat | |
| from zoedepth.utils.misc import count_parameters, parallelize | |
| from zoedepth.utils.config import get_config | |
| from zoedepth.utils.arg_utils import parse_unknown | |
| from zoedepth.trainers.builder import get_trainer | |
| from zoedepth.models.builder import build_model | |
| from zoedepth.data.data_mono import MixedNYUKITTI | |
| import torch.utils.data.distributed | |
| import torch.multiprocessing as mp | |
| import torch | |
| import numpy as np | |
| from pprint import pprint | |
| import argparse | |
| import os | |
# Process-wide environment tweaks applied at import time:
# - PYOPENGL_PLATFORM="egl" selects PyOpenGL's EGL platform (typically used
#   for headless rendering).
# - WANDB_START_METHOD="thread" asks wandb to use its thread start method;
#   NOTE(review): presumably to coexist with torch.multiprocessing used
#   below -- confirm.
os.environ.update(
    PYOPENGL_PLATFORM="egl",
    WANDB_START_METHOD="thread",
)
def fix_random_seed(seed: int):
    """
    Seed every RNG used by training and force deterministic cuDNN kernels.

    Seeds Python's ``random``, NumPy's global RNG, torch's CPU RNG, the
    current CUDA device RNG and all CUDA device RNGs with the same value.
    It then turns on cuDNN determinism and turns off cuDNN autotuning, which
    can pick different kernels from run to run.

    Args:
        seed (int): value fed to every random number generator
    """
    import random
    import numpy
    import torch

    # Applied in this fixed order: stdlib, NumPy, torch CPU, CUDA (current
    # device), CUDA (all devices).
    seeders = (
        random.seed,
        numpy.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def load_ckpt(config, model, checkpoint_dir="./checkpoints", ckpt_type="best"):
    """
    Optionally load saved weights into ``model``.

    The checkpoint is resolved from ``config`` in this order of precedence:

    1. ``config.checkpoint``: used directly as the checkpoint path.
    2. ``config.ckpt_pattern``: the first file (in sorted order) inside
       ``checkpoint_dir`` whose name matches ``*{pattern}*{ckpt_type}*``.

    If ``config`` has neither attribute, ``model`` is returned unchanged.

    Args:
        config: config object; only ``checkpoint`` / ``ckpt_pattern`` are read
        model: model to load the weights into
        checkpoint_dir (str): directory searched when ``ckpt_pattern`` is used
        ckpt_type (str): substring that must follow the pattern in the file
            name (e.g. "best")

    Returns:
        The result of ``load_wts(model, checkpoint)``, or ``model`` itself
        when no checkpoint is configured.

    Raises:
        ValueError: if ``ckpt_pattern`` is set but no file in
            ``checkpoint_dir`` matches it.
    """
    import glob
    import os

    if hasattr(config, "checkpoint"):
        checkpoint = config.checkpoint
    elif hasattr(config, "ckpt_pattern"):
        pattern = config.ckpt_pattern
        # Only the file-name part is a wildcard pattern. Escape the directory
        # so characters such as "[" in the path are taken literally.
        search = os.path.join(glob.escape(checkpoint_dir),
                              f"*{pattern}*{ckpt_type}*")
        # glob() returns files in arbitrary, filesystem-dependent order; sort
        # so the choice among several matching checkpoints is deterministic.
        matches = sorted(glob.glob(search))
        if not matches:
            raise ValueError(f"No matches found for the pattern {pattern}")
        checkpoint = matches[0]
    else:
        return model

    # Imported only when a checkpoint is actually being loaded.
    from zoedepth.models.model_io import load_wts
    model = load_wts(model, checkpoint)
    print("Loaded weights from {0}".format(checkpoint))
    return model
def main_worker(gpu, ngpus_per_node, config):
    """
    Build, optionally initialise, and train one model on a single worker.

    This is called directly in single-process mode, or once per local GPU
    through ``mp.spawn`` in distributed mode, where ``gpu`` is the spawned
    process index.

    Args:
        gpu: stored in ``config.gpu`` and passed to the trainer as ``device``
            (an int GPU index, or None when the script leaves it unset)
        ngpus_per_node (int): not used in this body; kept so the signature
            matches the ``args`` passed through ``mp.spawn``
        config: training config; this function sets ``config.gpu`` and
            ``config.total_params``

    The active wandb run, if any, is finished even when training fails.
    """
    try:
        fix_random_seed(43)
        config.gpu = gpu

        model = build_model(config)
        model = load_ckpt(config, model)
        model = parallelize(config, model)

        total_params = f"{round(count_parameters(model) / 1e6, 2)}M"
        config.total_params = total_params
        print(f"Total parameters : {total_params}")

        train_loader = MixedNYUKITTI(config, "train").data
        test_loader = MixedNYUKITTI(config, "online_eval").data

        trainer = get_trainer(config)(
            config, model, train_loader, test_loader, device=config.gpu)
        trainer.train()
    finally:
        # Guard the import: an ImportError raised inside ``finally`` would
        # replace the exception that training itself raised.
        try:
            import wandb
        except ImportError:
            wandb = None
        if wandb is not None:
            wandb.finish()
if __name__ == '__main__':
    mp.set_start_method('forkserver')

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="synunet")
    parser.add_argument("-d", "--dataset", type=str, default='mix')
    parser.add_argument("--trainer", type=str, default=None)

    # Arguments argparse does not recognise are parsed by parse_unknown and
    # forwarded to get_config as keyword overrides.
    args, unknown_args = parser.parse_known_args()
    overwrite_kwargs = parse_unknown(unknown_args)
    overwrite_kwargs["model"] = args.model
    if args.trainer is not None:
        overwrite_kwargs["trainer"] = args.trainer

    config = get_config(args.model, "train", args.dataset, **overwrite_kwargs)

    config.shared_dict = mp.Manager().dict() if config.use_shared_dict else None
    config.batch_size = config.bs
    config.mode = 'train'

    # exist_ok=True avoids the check-then-create race of a separate isdir().
    if config.root != ".":
        os.makedirs(config.root, exist_ok=True)

    try:
        # NOTE(review): this only strips brackets, so a compressed SLURM
        # hostlist such as "node[01-03]" is read as one node named
        # "node01-03"; confirm jobs provide an explicit comma-separated list.
        node_str = os.environ['SLURM_JOB_NODELIST'].replace(
            '[', '').replace(']', '')
        nodes = node_str.split(',')
        config.world_size = len(nodes)
        config.rank = int(os.environ['SLURM_PROCID'])
    except KeyError:
        # We are NOT using SLURM
        config.world_size = 1
        config.rank = 0
        nodes = ["127.0.0.1"]

    if config.distributed:
        print(config.rank)
        # world_size is the node count and rank comes from SLURM_PROCID, so
        # this script runs once per node. The rendezvous port must therefore
        # be identical on every node: derive it from the shared SLURM job id.
        # A per-process random pick is only safe on a single node.
        job_id = os.environ.get("SLURM_JOB_ID", "")
        if job_id.isdigit():
            port = 15000 + int(job_id) % 25
        else:
            port = np.random.randint(15000, 15025)
        config.dist_url = 'tcp://{}:{}'.format(nodes[0], port)
        print(config.dist_url)
        config.dist_backend = 'nccl'
        config.gpu = None

    ngpus_per_node = torch.cuda.device_count()
    config.num_workers = config.workers
    config.ngpus_per_node = ngpus_per_node
    print("Config:")
    pprint(config)

    if config.distributed:
        # One training process per local GPU: world_size goes from a node
        # count to a total process count.
        config.world_size = ngpus_per_node * config.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, config))
    else:
        if ngpus_per_node == 1:
            config.gpu = 0
        main_worker(config.gpu, ngpus_per_node, config)