jetclustering / src /train.py
gregorkrzmanc's picture
.
e75a247
raw
history blame
10.2 kB
#!/usr/bin/env python
import shutil
import glob
import argparse
import functools
import numpy as np
import math
import torch
import sys
import os
import wandb
import time
from pathlib import Path
torch.autograd.set_detect_anomaly(True)
from src.utils.train_utils import count_parameters, get_gt_func, get_loss_func
from src.utils.utils import clear_empty_paths
from src.utils.wandb_utils import get_run_by_name, update_args
from src.logger.logger import _logger, _configLogger
from src.dataset.dataset import SimpleIterDataset
from src.utils.import_tools import import_module
from src.utils.train_utils import (
to_filelist,
train_load,
test_load,
get_model,
get_optimizer_and_scheduler,
get_model_obj_score
)
from src.evaluation.clustering_metrics import compute_f1_score_from_result
from src.dataset.functions_graph import graph_batch_func
from src.utils.parser_args import parser
from src.utils.paths import get_path
import warnings
import pickle
import os
def find_free_port():
    """Return, as a string, a TCP port number picked by the operating system.

    The approach follows
    https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number

    The probing socket is closed before the function returns, so the port is
    not reserved for the caller.
    """
    import socket

    # Binding to port 0 asks the OS to choose the port number.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _, port = sock.getsockname()
        return str(port)
# Parse the command-line arguments, create the run directory and start the
# wandb run.
args = parser.parse_args()
if args.load_from_run:
    # Replace the parsed arguments with the ones recorded for an earlier run.
    print("Loading args from run", args.load_from_run)
    run = get_run_by_name(args.load_from_run)
    args = update_args(args, run)

# The timestamp plus a random suffix avoids overwriting in case two jobs are
# started at the same time.
timestamp = time.strftime("%Y_%m_%d_%H_%M_%S")
random_number = str(np.random.randint(0, 1000))
args.run_name = f"{args.run_name}_{timestamp}_{random_number}"

if (
    "transformer" in args.network_config.lower()
    or args.network_config == "src/models/GATr/Gatr.py"
):
    args.spatial_part_only = False

# Resolve checkpoint locations through get_path.
if args.load_model_weights:
    print("Changing args.load_model_weights")
    args.load_model_weights = get_path(
        args.load_model_weights, "results", fallback=True
    )
if args.load_objectness_score_weights:
    args.load_objectness_score_weights = get_path(
        args.load_objectness_score_weights, "results", fallback=True
    )

# Clear paths of failed runs that don't have any files or folders in them.
clear_empty_paths(get_path(os.path.join(args.prefix, "train"), "results"))

run_path = get_path(os.path.join(args.prefix, "train", args.run_name), "results")
# exist_ok=False: raise instead of silently reusing an existing run directory.
os.makedirs(run_path, exist_ok=False)
assert os.path.exists(run_path)
print("Created directory", run_path)
args.run_path = run_path

wandb.init(project=args.wandb_projectname, entity=os.environ["SVJ_WANDB_ENTITY"])
wandb.run.name = args.run_name
print("Setting the run name to", args.run_name)
wandb.config.update(vars(args))
# Record the SVJ_*, CUDA_* and SLURM_* environment variables with the run.
wandb.config.env_vars = {
    name: value
    for name, value in os.environ.items()
    if name.startswith(("SVJ_", "CUDA_", "SLURM_"))
}
if args.tag:
    wandb.run.tags = [args.tag.strip()]
# Without a distributed backend there is no local rank; otherwise it is read
# from the LOCAL_RANK environment variable (defaulting to 0).
if args.backend is None:
    args.local_rank = None
else:
    args.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    port = find_free_port()
    args.port = port
    world_size = torch.cuda.device_count()

# In distributed runs the log file name gets a per-rank suffix, and every rank
# other than 0 hands stdout=None to the logger setup.
if args.local_rank is not None:
    args.log += ".%03d" % args.local_rank
stdout = sys.stdout if args.local_rank in (None, 0) else None
_configLogger("weaver", stdout=stdout, filename=args.log)
warnings.filterwarnings("ignore")

# These imports are kept at this point of the script, after the logger and
# warning filter have been configured.
from src.utils.nn.tools_condensation import evaluate, train_epoch
# The script trains when training data was supplied; otherwise it only builds
# the test loaders.
training_mode = bool(args.data_train)
if not training_mode:
    test_loaders = test_load(args)
else:
    # val_loaders and test_loaders are a dictionary file -> Dataloader with only one dataset
    # train_loader is a single dataloader of all the files
    train_loader, val_loaders, val_dataset = train_load(args)
    # A second set of loaders (aug_collinear=True) is built only when the
    # IRC-safety loss is requested.
    train_loader_aug = None
    if args.irc_safety_loss:
        train_loader_aug, val_loaders_aug, val_dataset_aug = train_load(
            args, aug_soft=False, aug_collinear=True
        )
# Choose the device(s): CPU, explicitly listed GPUs, or one GPU per process
# when a distributed backend is configured.
if not args.gpus:
    gpus = None
    local_rank = 0
    dev = torch.device("cpu")
elif args.backend is None:
    # Comma-separated list of GPU indices; the first one is the main device.
    gpus = list(map(int, args.gpus.split(",")))
    dev = torch.device(gpus[0])
    local_rank = 0
else:
    # distributed training
    local_rank = args.local_rank
    print("localrank", local_rank)
    torch.cuda.set_device(local_rank)
    gpus = [local_rank]
    dev = torch.device(local_rank)
    print("initializing group process", dev)
    torch.distributed.init_process_group(backend=args.backend)
    _logger.info(f"Using distributed PyTorch with {args.backend} backend")
    print("ended initializing group process")
# Build the clustering model and, if requested, the objectness-score model.
model = get_model(args, dev)
model_obj_score = None
if args.train_objectness_score:
    model_obj_score = get_model_obj_score(args, dev).to(dev)

num_parameters_counted = count_parameters(model)
print("Number of parameters:", num_parameters_counted)
wandb.config.num_parameters = num_parameters_counted

# Keep a handle on the unwrapped model; `model` may be wrapped later on.
orig_model = model
loss = get_loss_func(args)
gt = get_gt_func(args)

# Options forwarded to training/evaluation as `batch_config`.
if "lgatr" in args.network_config.lower():
    batch_config = {"use_four_momenta": True}
else:
    batch_config = {"use_p_xyz": True, "use_four_momenta": False}
batch_config.update(
    quark_dist_loss=args.loss == "quark_distance",
    parton_level=args.parton_level,
    gen_level=args.gen_level,
    obj_score=args.train_objectness_score,
)
if args.no_pid:
    print("Not using PID in the features")
    batch_config["no_pid"] = True
print("batch_config:", batch_config)
# Training branch: wrap the model for data parallelism, build the optimizers,
# run a baseline validation pass at step 0 and then run the epoch loop.
# NOTE(review): the indentation of this copy of the script is missing, so the
# nesting of the statements below cannot be read from the text alone — confirm
# against the repository version before changing the control flow.
if training_mode:
model = orig_model.to(dev)
if args.backend is not None:
# Distributed backend configured: convert to SyncBatchNorm and wrap the model
# in DistributedDataParallel on this process's GPU list.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
print("device_ids = gpus", gpus)
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=gpus,
output_device=local_rank,
find_unused_parameters=True,
)
opt, scheduler = get_optimizer_and_scheduler(args, model, dev)
if args.train_objectness_score:
# Separate optimizer/scheduler for the objectness-score model.
# presumably `load_model_weights` names the args attribute that holds the
# checkpoint path — verify in get_optimizer_and_scheduler.
opt_os, scheduler_os = get_optimizer_and_scheduler(args, model_obj_score, dev, load_model_weights="load_objectness_score_weights")
else:
opt_os, scheduler_os = None, None
# DataParallel (only without a distributed backend and with more than one GPU)
if args.backend is None:
if gpus is not None and len(gpus) > 1:
# model becomes `torch.nn.DataParallel` w/ model.module being the original `torch.nn.Module`
model = torch.nn.DataParallel(model, device_ids=gpus)
if local_rank == 0:
wandb.watch(model, log="all", log_freq=10)
# Training loop
# NOTE(review): best_valid_metric is assigned here but not read again in the
# visible part of the script.
best_valid_metric = np.inf
# A gradient scaler is created only when args.use_amp is set.
grad_scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
steps = 0
# Baseline validation at step 0 with predict=False; the return value is
# discarded.
evaluate(
model,
val_loaders,
dev,
0,
steps,
loss_func=loss,
gt_func=gt,
local_rank=local_rank,
args=args,
batch_config=batch_config,
predict=False,
model_obj_score=model_obj_score
)
# Second pass with predict=True; its result feeds the F1 computation below.
res = evaluate(
model,
val_loaders,
dev,
0,
steps,
loss_func=loss,
gt_func=gt,
local_rank=local_rank,
args=args,
batch_config=batch_config,
predict=True,
model_obj_score=model_obj_score
)
# It was the quickest to do it like this
# With an objectness-score model the result is unpacked as a 3-tuple; only its
# first element is used for the F1 score.
if model_obj_score is not None:
res, res_obj_score_pred, res_obj_score_target = res
f1 = compute_f1_score_from_result(res, val_dataset)
wandb.log({"val_f1_score": f1}, step=steps)
epochs = args.num_epochs
# When a step budget is given (num_steps != -1) the epoch count is made
# effectively unbounded; the loop then ends through the "quit_training" check.
if args.num_steps != -1:
epochs = 999999999
for epoch in range(1, epochs + 1):
_logger.info("-" * 50)
_logger.info("Epoch #%d training" % epoch)
# train_epoch returns the updated step counter, which is passed back in as
# current_step on the next epoch.
steps = train_epoch(
args,
model,
loss_func=loss,
gt_func=gt,
opt=opt,
scheduler=scheduler,
train_loader=train_loader,
dev=dev,
epoch=epoch,
grad_scaler=grad_scaler,
local_rank=local_rank,
current_step=steps,
val_loader=val_loaders,
batch_config=batch_config,
val_dataset=val_dataset,
obj_score_model=model_obj_score,
opt_obj_score=opt_os,
sched_obj_score=scheduler_os,
train_loader_aug=train_loader_aug
)
# train_epoch returns the string "quit_training" to stop the loop early.
if steps == "quit_training":
break
# Evaluation on the test files: run the model on every test loader and write
# one pickle per file (eval_<index>.pkl) into the run directory.
if args.data_test:
    # With a distributed backend only rank 0 runs the evaluation.
    if args.backend is not None and local_rank != 0:
        sys.exit(0)
    if training_mode:
        # Drop the training/validation loaders before building the test loaders.
        del train_loader, val_loaders
        test_loaders = test_load(args)
    # Start again from the unwrapped model and re-wrap it for multi-GPU use.
    model = orig_model.to(dev)
    if gpus is not None and len(gpus) > 1:
        model = torch.nn.DataParallel(model, device_ids=gpus)
    model = model.to(dev)
    for i, (filename, test_loader) in enumerate(test_loaders.items()):
        result = evaluate(
            model,
            test_loader,
            dev,
            0,
            0,
            loss_func=loss,
            gt_func=gt,
            local_rank=local_rank,
            args=args,
            batch_config=batch_config,
            predict=True,
            model_obj_score=model_obj_score,
        )
        if model_obj_score is not None:
            # With an objectness-score model the result is a 3-tuple; fold the
            # score predictions and targets into the result mapping.
            result, result_obj_score, result_obj_score_target = result
            result["obj_score_pred"] = result_obj_score
            result["obj_score_target"] = result_obj_score_target
        _logger.info(f"Finished evaluating {filename}")
        result["filename"] = filename
        os.makedirs(run_path, exist_ok=True)
        output_filename = os.path.join(run_path, f"eval_{i}.pkl")
        # Use a context manager so the file is flushed and closed even if
        # pickling fails (the handle was previously never closed).
        with open(output_filename, "wb") as output_file:
            pickle.dump(result, output_file)