#!/usr/bin/env python
import shutil
import glob
import argparse
import functools
import numpy as np
import math
import torch
import sys
import os
import wandb
import time
from pathlib import Path

# Anomaly detection helps trace NaN/inf gradients while debugging, at a runtime cost.
torch.autograd.set_detect_anomaly(True)

from src.utils.train_utils import count_parameters, get_gt_func, get_loss_func
from src.utils.utils import clear_empty_paths
from src.utils.wandb_utils import get_run_by_name, update_args
from src.logger.logger import _logger, _configLogger
from src.dataset.dataset import SimpleIterDataset
from src.utils.import_tools import import_module
from src.utils.train_utils import (
    to_filelist,
    train_load,
    test_load,
    get_model,
    get_optimizer_and_scheduler,
    get_model_obj_score,
)
from src.evaluation.clustering_metrics import compute_f1_score_from_result
from src.dataset.functions_graph import graph_batch_func
from src.utils.parser_args import parser
from src.utils.paths import get_path
import warnings
import pickle


def find_free_port():
    """https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number"""
    import socket
    from contextlib import closing

    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return str(s.getsockname()[1])


# Create directories and initialize wandb run
args = parser.parse_args()
if args.load_from_run:
    # Reuse the configuration of an existing wandb run, overriding the CLI defaults.
    print("Loading args from run", args.load_from_run)
    run = get_run_by_name(args.load_from_run)
    args = update_args(args, run)

timestamp = time.strftime("%Y_%m_%d_%H_%M_%S")
# Random suffix avoids overwriting in case two jobs are started at the same time.
random_number = str(np.random.randint(0, 1000))
args.run_name = f"{args.run_name}_{timestamp}_{random_number}"

if "transformer" in args.network_config.lower() or args.network_config == "src/models/GATr/Gatr.py":
    args.spatial_part_only = False

if args.load_model_weights:
    print("Changing args.load_model_weights")
    args.load_model_weights = get_path(args.load_model_weights, "results", fallback=True)
if args.load_objectness_score_weights:
    args.load_objectness_score_weights = get_path(
        args.load_objectness_score_weights, "results", fallback=True
    )

run_path = os.path.join(args.prefix, "train", args.run_name)
# Clear paths of failed runs that don't have any files or folders in them.
clear_empty_paths(get_path(os.path.join(args.prefix, "train"), "results"))
run_path = get_path(run_path, "results")
# Path(run_path).mkdir(parents=True, exist_ok=False)
os.makedirs(run_path, exist_ok=False)
assert os.path.exists(run_path)
print("Created directory", run_path)
args.run_path = run_path

wandb.init(project=args.wandb_projectname, entity=os.environ["SVJ_WANDB_ENTITY"])
wandb.run.name = args.run_name
print("Setting the run name to", args.run_name)
# wandb.config.run_path = run_path
wandb.config.update(args.__dict__)
wandb.config.env_vars = {
    key: os.environ[key]
    for key in os.environ
    if key.startswith("SVJ_") or key.startswith("CUDA_") or key.startswith("SLURM_")
}
if args.tag:
    wandb.run.tags = [args.tag.strip()]

args.local_rank = (
    None if args.backend is None else int(os.environ.get("LOCAL_RANK", "0"))
)
if args.backend is not None:
    port = find_free_port()
    args.port = port
    world_size = torch.cuda.device_count()

# Only rank 0 logs to stdout; every rank writes its own log file.
stdout = sys.stdout
if args.local_rank is not None:
    args.log += ".%03d" % args.local_rank
    if args.local_rank != 0:
        stdout = None
_configLogger("weaver", stdout=stdout, filename=args.log)

warnings.filterwarnings("ignore")

from src.utils.nn.tools_condensation import train_epoch
from src.utils.nn.tools_condensation import evaluate as evaluate
training_mode = bool(args.data_train)

if training_mode:
    # val_loaders and test_loaders are dictionaries mapping file -> DataLoader over a single dataset;
    # train_loader is a single dataloader over all the files.
    train_loader, val_loaders, val_dataset = train_load(args)
    if args.irc_safety_loss:
        # Collinearly augmented copy of the training data used by the IRC-safety loss.
        train_loader_aug, val_loaders_aug, val_dataset_aug = train_load(
            args, aug_soft=False, aug_collinear=True
        )
    else:
        train_loader_aug = None
else:
    test_loaders = test_load(args)

if args.gpus:
    if args.backend is not None:
        # Distributed training: one process per GPU, device chosen via LOCAL_RANK.
        local_rank = args.local_rank
        print("localrank", local_rank)
        torch.cuda.set_device(local_rank)
        gpus = [local_rank]
        dev = torch.device(local_rank)
        print("initializing group process", dev)
        torch.distributed.init_process_group(backend=args.backend)
        _logger.info(f"Using distributed PyTorch with {args.backend} backend")
        print("ended initializing group process")
    else:
        gpus = [int(i) for i in args.gpus.split(",")]
        # if os.environ.get("CUDA_VISIBLE_DEVICES", None) is not None:
        #     gpus = [int(i) for i in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
        dev = torch.device(gpus[0])
        local_rank = 0
else:
    gpus = None
    local_rank = 0
    dev = torch.device("cpu")

model = get_model(args, dev)
if args.train_objectness_score:
    model_obj_score = get_model_obj_score(args, dev)
    model_obj_score = model_obj_score.to(dev)
else:
    model_obj_score = None

num_parameters_counted = count_parameters(model)
print("Number of parameters:", num_parameters_counted)
wandb.config.num_parameters = num_parameters_counted
orig_model = model

loss = get_loss_func(args)
gt = get_gt_func(args)

# Options controlling how input batches/features are assembled.
batch_config = {"use_p_xyz": True, "use_four_momenta": False}
if "lgatr" in args.network_config.lower():
    batch_config = {"use_four_momenta": True}
batch_config["quark_dist_loss"] = args.loss == "quark_distance"
batch_config["parton_level"] = args.parton_level
batch_config["gen_level"] = args.gen_level
batch_config["obj_score"] = args.train_objectness_score
if args.no_pid:
    print("Not using PID in the features")
    batch_config["no_pid"] = True
print("batch_config:", batch_config)

if training_mode:
    model = orig_model.to(dev)
    if args.backend is not None:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        print("device_ids = gpus", gpus)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=gpus,
            output_device=local_rank,
            find_unused_parameters=True,
        )
    opt, scheduler = get_optimizer_and_scheduler(args, model, dev)
    if args.train_objectness_score:
        opt_os, scheduler_os = get_optimizer_and_scheduler(
            args, model_obj_score, dev, load_model_weights="load_objectness_score_weights"
        )
    else:
        opt_os, scheduler_os = None, None

    # DataParallel
    if args.backend is None:
        if gpus is not None and len(gpus) > 1:
            # model becomes `torch.nn.DataParallel` w/ model.module being the original `torch.nn.Module`
            model = torch.nn.DataParallel(model, device_ids=gpus)

    if local_rank == 0:
        wandb.watch(model, log="all", log_freq=10)

    # Training loop
    best_valid_metric = np.inf
    grad_scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
    steps = 0

    # Baseline validation pass before any training: once for the losses/metrics,
    # and once in predict mode to compute the initial F1 score.
    evaluate(
        model,
        val_loaders,
        dev,
        0,
        steps,
        loss_func=loss,
        gt_func=gt,
        local_rank=local_rank,
        args=args,
        batch_config=batch_config,
        predict=False,
        model_obj_score=model_obj_score,
    )
    res = evaluate(
        model,
        val_loaders,
        dev,
        0,
        steps,
        loss_func=loss,
        gt_func=gt,
        local_rank=local_rank,
        args=args,
        batch_config=batch_config,
        predict=True,
        model_obj_score=model_obj_score,
    )
    # When an objectness-score model is attached, evaluate() also returns its
    # predictions and targets alongside the clustering result.
    if model_obj_score is not None:
        res, res_obj_score_pred, res_obj_score_target = res
    f1 = compute_f1_score_from_result(res, val_dataset)
    wandb.log({"val_f1_score": f1}, step=steps)

    epochs = args.num_epochs
    if args.num_steps != -1:
        # When training by number of steps, loop until train_epoch returns "quit_training".
        epochs = 999999999
    for epoch in range(1, epochs + 1):
        _logger.info("-" * 50)
        _logger.info("Epoch #%d training" % epoch)
        steps = train_epoch(
            args,
            model,
            loss_func=loss,
            gt_func=gt,
            opt=opt,
            scheduler=scheduler,
            train_loader=train_loader,
            dev=dev,
            epoch=epoch,
            grad_scaler=grad_scaler,
            local_rank=local_rank,
            current_step=steps,
            val_loader=val_loaders,
            batch_config=batch_config,
            val_dataset=val_dataset,
            obj_score_model=model_obj_score,
            opt_obj_score=opt_os,
            sched_obj_score=scheduler_os,
            train_loader_aug=train_loader_aug,
        )
        if steps == "quit_training":
            break

if args.data_test:
    # Only rank 0 runs the test evaluation in the distributed setting.
    if args.backend is not None and local_rank != 0:
        sys.exit(0)
    if training_mode:
        # Free the training loaders and rebuild the model without the DDP wrapper.
        del train_loader, val_loaders
        test_loaders = test_load(args)
        model = orig_model.to(dev)
        if gpus is not None and len(gpus) > 1:
            model = torch.nn.DataParallel(model, device_ids=gpus)
        model = model.to(dev)
    i = 0
    for filename, test_loader in test_loaders.items():
        result = evaluate(
            model,
            test_loader,
            dev,
            0,
            0,
            loss_func=loss,
            gt_func=gt,
            local_rank=local_rank,
            args=args,
            batch_config=batch_config,
            predict=True,
            model_obj_score=model_obj_score,
        )
        if model_obj_score is not None:
            result, result_obj_score, result_obj_score_target = result
            result["obj_score_pred"] = result_obj_score
            result["obj_score_target"] = result_obj_score_target
        _logger.info(f"Finished evaluating {filename}")
        result["filename"] = filename
        # Write one pickle per test file into the run directory.
        os.makedirs(run_path, exist_ok=True)
        output_filename = os.path.join(run_path, f"eval_{i}.pkl")
        with open(output_filename, "wb") as f:
            pickle.dump(result, f)
        i += 1
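
# Illustrative sketch (not part of the training flow): how one of the eval_{i}.pkl
# files written above could be read back for offline analysis. Only the keys this
# script sets are guaranteed ("filename", plus "obj_score_pred"/"obj_score_target"
# when an objectness-score model is used); everything else depends on what
# evaluate() returns. Kept commented out so it never executes as part of the job.
#
#   import pickle
#   with open(os.path.join(run_path, "eval_0.pkl"), "rb") as f:
#       result = pickle.load(f)
#   print(result["filename"], sorted(result.keys()))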