#!/usr/bin/env python
import shutil
import glob
import argparse
import functools
import numpy as np
import math
import torch
import sys
import os
import wandb
import time
from pathlib import Path

# Enable autograd anomaly detection (helps track down NaNs in backward, but slows training)
torch.autograd.set_detect_anomaly(True)

from src.utils.train_utils import count_parameters, get_gt_func, get_loss_func
from src.utils.utils import clear_empty_paths
from src.utils.wandb_utils import get_run_by_name, update_args
from src.logger.logger import _logger, _configLogger
from src.dataset.dataset import SimpleIterDataset
from src.utils.import_tools import import_module
from src.utils.train_utils import (
    to_filelist,
    train_load,
    test_load,
    get_model,
    get_optimizer_and_scheduler,
    get_model_obj_score,
)
from src.evaluation.clustering_metrics import compute_f1_score_from_result
from src.dataset.functions_graph import graph_batch_func
from src.utils.parser_args import parser
from src.utils.paths import get_path
import warnings
import pickle
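

# Helper that asks the OS for a free TCP port; used below to set args.port
# when a distributed backend is requested.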
def find_free_port():
    """https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number"""
    import socket
    from contextlib import closing

    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return str(s.getsockname()[1])


# Create directories and initialize wandb run
args = parser.parse_args()
if args.load_from_run:
    print("Loading args from run", args.load_from_run)
    run = get_run_by_name(args.load_from_run)
    args = update_args(args, run)

timestamp = time.strftime("%Y_%m_%d_%H_%M_%S")
# Random suffix to avoid overwriting in case two jobs are started at the same time
random_number = str(np.random.randint(0, 1000))
args.run_name = f"{args.run_name}_{timestamp}_{random_number}"

if "transformer" in args.network_config.lower() or args.network_config == "src/models/GATr/Gatr.py":
    args.spatial_part_only = False
if args.load_model_weights:
    print("Changing args.load_model_weights")
    args.load_model_weights = get_path(args.load_model_weights, "results", fallback=True)
if args.load_objectness_score_weights:
    args.load_objectness_score_weights = get_path(args.load_objectness_score_weights, "results", fallback=True)

run_path = os.path.join(args.prefix, "train", args.run_name)
# Clear paths of failed runs that don't have any files or folders in them
clear_empty_paths(get_path(os.path.join(args.prefix, "train"), "results"))
run_path = get_path(run_path, "results")
# Path(run_path).mkdir(parents=True, exist_ok=False)
os.makedirs(run_path, exist_ok=False)
assert os.path.exists(run_path)
print("Created directory", run_path)
args.run_path = run_path
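
# Initialize the wandb run: the entity comes from the SVJ_WANDB_ENTITY environment
# variable, and all parsed args plus selected environment variables are stored in
# the run config.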
wandb.init(project=args.wandb_projectname, entity=os.environ["SVJ_WANDB_ENTITY"])
wandb.run.name = args.run_name
print("Setting the run name to", args.run_name)
# wandb.config.run_path = run_path
wandb.config.update(args.__dict__)
wandb.config.env_vars = {
    key: os.environ[key]
    for key in os.environ
    if key.startswith("SVJ_") or key.startswith("CUDA_") or key.startswith("SLURM_")
}
if args.tag:
    wandb.run.tags = [args.tag.strip()]
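
# Distributed setup: when a backend is requested, the local rank is read from the
# LOCAL_RANK environment variable and a free port is reserved for the process group.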
args.local_rank = (
    None if args.backend is None else int(os.environ.get("LOCAL_RANK", "0"))
)
if args.backend is not None:
    port = find_free_port()
    args.port = port
    world_size = torch.cuda.device_count()

stdout = sys.stdout
if args.local_rank is not None:
    args.log += ".%03d" % args.local_rank
    if args.local_rank != 0:
        stdout = None
_configLogger("weaver", stdout=stdout, filename=args.log)
warnings.filterwarnings("ignore")

from src.utils.nn.tools_condensation import train_epoch, evaluate
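
# Training vs. prediction mode is decided by whether training data was provided
# (args.data_train); otherwise only the test files are loaded.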
training_mode = bool(args.data_train)
if training_mode:
    # val_loaders and test_loaders are dictionaries file -> DataLoader, each with a single dataset;
    # train_loader is a single DataLoader over all the training files.
    train_loader, val_loaders, val_dataset = train_load(args)
    if args.irc_safety_loss:
        train_loader_aug, val_loaders_aug, val_dataset_aug = train_load(
            args, aug_soft=False, aug_collinear=True
        )
    else:
        train_loader_aug = None
else:
    test_loaders = test_load(args)
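
# Device selection: distributed (one device per local rank), single/multi GPU, or CPU.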
if args.gpus:
    if args.backend is not None:
        # distributed training
        local_rank = args.local_rank
        print("localrank", local_rank)
        torch.cuda.set_device(local_rank)
        gpus = [local_rank]
        dev = torch.device(local_rank)
        print("initializing process group", dev)
        torch.distributed.init_process_group(backend=args.backend)
        _logger.info(f"Using distributed PyTorch with {args.backend} backend")
        print("finished initializing process group")
    else:
        gpus = [int(i) for i in args.gpus.split(",")]
        # if os.environ.get("CUDA_VISIBLE_DEVICES", None) is not None:
        #     gpus = [int(i) for i in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
        dev = torch.device(gpus[0])
        local_rank = 0
else:
    gpus = None
    local_rank = 0
    dev = torch.device("cpu")
model = get_model(args, dev)
if args.train_objectness_score:
    model_obj_score = get_model_obj_score(args, dev)
    model_obj_score = model_obj_score.to(dev)
else:
    model_obj_score = None

num_parameters_counted = count_parameters(model)
print("Number of parameters:", num_parameters_counted)
wandb.config.num_parameters = num_parameters_counted
orig_model = model

loss = get_loss_func(args)
gt = get_gt_func(args)
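
# batch_config is passed to the train/evaluate helpers and controls how input
# features are built (xyz momenta vs. four-momenta for LGATr, PID usage,
# parton/gen level, and whether objectness-score information is included).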
batch_config = {"use_p_xyz": True, "use_four_momenta": False} | |
if "lgatr" in args.network_config.lower(): | |
batch_config = {"use_four_momenta": True} | |
batch_config["quark_dist_loss"] = args.loss == "quark_distance" | |
batch_config["parton_level"] = args.parton_level | |
batch_config["gen_level"] = args.gen_level | |
batch_config["obj_score"] = args.train_objectness_score | |
if args.no_pid: | |
print("Not using PID in the features") | |
batch_config["no_pid"] = True | |
print("batch_config:", batch_config) | |
if training_mode:
    model = orig_model.to(dev)
    if args.backend is not None:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        print("device_ids = gpus", gpus)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=gpus,
            output_device=local_rank,
            find_unused_parameters=True,
        )
    opt, scheduler = get_optimizer_and_scheduler(args, model, dev)
    if args.train_objectness_score:
        opt_os, scheduler_os = get_optimizer_and_scheduler(
            args, model_obj_score, dev, load_model_weights="load_objectness_score_weights"
        )
    else:
        opt_os, scheduler_os = None, None

    # DataParallel
    if args.backend is None:
        if gpus is not None and len(gpus) > 1:
            # model becomes `torch.nn.DataParallel` with model.module being the original `torch.nn.Module`
            model = torch.nn.DataParallel(model, device_ids=gpus)

    if local_rank == 0:
        wandb.watch(model, log="all", log_freq=10)

    # Training loop
    best_valid_metric = np.inf
    grad_scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
    steps = 0
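
    # Evaluate once before any training: one pass with predict=False for the
    # validation metrics and one with predict=True, whose clustering result is
    # used to log a baseline F1 score on the validation set.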
    evaluate(
        model,
        val_loaders,
        dev,
        0,
        steps,
        loss_func=loss,
        gt_func=gt,
        local_rank=local_rank,
        args=args,
        batch_config=batch_config,
        predict=False,
        model_obj_score=model_obj_score,
    )
    res = evaluate(
        model,
        val_loaders,
        dev,
        0,
        steps,
        loss_func=loss,
        gt_func=gt,
        local_rank=local_rank,
        args=args,
        batch_config=batch_config,
        predict=True,
        model_obj_score=model_obj_score,
    )
    # When the objectness-score model is used, evaluate returns the clustering
    # result together with the objectness-score predictions and targets.
    if model_obj_score is not None:
        res, res_obj_score_pred, res_obj_score_target = res
    f1 = compute_f1_score_from_result(res, val_dataset)
    wandb.log({"val_f1_score": f1}, step=steps)
    epochs = args.num_epochs
    if args.num_steps != -1:
        epochs = 999999999

    for epoch in range(1, epochs + 1):
        _logger.info("-" * 50)
        _logger.info("Epoch #%d training" % epoch)
        steps = train_epoch(
            args,
            model,
            loss_func=loss,
            gt_func=gt,
            opt=opt,
            scheduler=scheduler,
            train_loader=train_loader,
            dev=dev,
            epoch=epoch,
            grad_scaler=grad_scaler,
            local_rank=local_rank,
            current_step=steps,
            val_loader=val_loaders,
            batch_config=batch_config,
            val_dataset=val_dataset,
            obj_score_model=model_obj_score,
            opt_obj_score=opt_os,
            sched_obj_score=scheduler_os,
            train_loader_aug=train_loader_aug,
        )
        if steps == "quit_training":
            break
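
# Final evaluation on the test files: each test loader is evaluated separately and
# its predictions are written to eval_<i>.pkl in the run directory.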
if args.data_test:
    if args.backend is not None and local_rank != 0:
        sys.exit(0)
    if training_mode:
        del train_loader, val_loaders
        test_loaders = test_load(args)
    model = orig_model.to(dev)
    if gpus is not None and len(gpus) > 1:
        model = torch.nn.DataParallel(model, device_ids=gpus)
    model = model.to(dev)

    i = 0
    for filename, test_loader in test_loaders.items():
        result = evaluate(
            model,
            test_loader,
            dev,
            0,
            0,
            loss_func=loss,
            gt_func=gt,
            local_rank=local_rank,
            args=args,
            batch_config=batch_config,
            predict=True,
            model_obj_score=model_obj_score,
        )
        if model_obj_score is not None:
            result, result_obj_score, result_obj_score_target = result
            result["obj_score_pred"] = result_obj_score
            result["obj_score_target"] = result_obj_score_target
        _logger.info(f"Finished evaluating {filename}")
        result["filename"] = filename
        os.makedirs(run_path, exist_ok=True)
        output_filename = os.path.join(run_path, f"eval_{i}.pkl")
        with open(output_filename, "wb") as f:
            pickle.dump(result, f)
        i += 1