import os
import ast
import glob
import functools
import math

import torch
from torch.utils.data import DataLoader

from src.logger.logger import _logger, _configLogger
from src.dataset.dataset import EventDatasetCollection, EventDataset
from src.utils.import_tools import import_module
from src.dataset.functions_graph import graph_batch_func
from src.dataset.functions_data import concat_events
from src.utils.paths import get_path
from src.layers.object_cond import calc_eta_phi
from src.layers.object_cond import object_condensation_loss


def to_filelist(args, mode="train"):
    """Return the resolved list of data paths for one split.

    :param args: parsed arguments; reads ``data_train``, ``data_val`` or ``data_test``.
    :param mode: ``"train"``, ``"val"`` or ``"test"``.
    :return: list of paths, each resolved with ``get_path(p, "preprocessed_data")``.
    :raises NotImplementedError: for any other ``mode``.
    """
    if mode == "train":
        flist = args.data_train
    elif mode == "val":
        flist = args.data_val
    elif mode == "test":
        flist = args.data_test
    else:
        raise NotImplementedError("Invalid mode %s" % mode)
    print(mode, "filelist:", flist)
    return [get_path(p, "preprocessed_data") for p in flist]


class TensorCollection:
    """Attribute bag built from keyword arguments; ``to`` and ``dict_rep`` only act on tensor attributes."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def to(self, device):
        """Move every tensor attribute to ``device`` (in place) and return ``self``."""
        for k, v in self.__dict__.items():
            if torch.is_tensor(v):
                setattr(self, k, v.to(device))
        return self

    def dict_rep(self):
        """Return a dict containing only the tensor-valued attributes."""
        return {k: v for k, v in self.__dict__.items() if torch.is_tensor(v)}


def _limit_size(dataset, size):
    """Return the first ``size`` items of ``dataset`` as a Subset, or ``dataset`` itself if ``size`` is None."""
    if size is None:
        return dataset
    return torch.utils.data.Subset(dataset, list(range(size)))


def _make_loader(dataset, args):
    """Build a DataLoader with the settings shared by the train, val and test loaders."""
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        drop_last=True,
        pin_memory=True,
        num_workers=args.num_workers,
        collate_fn=concat_events,
        persistent_workers=args.num_workers > 0,
        shuffle=False,
    )


def train_load(args, aug_soft=False, aug_collinear=False):
    """Build the training and validation loaders.

    :param args: parsed arguments (file lists, ``train_dataset_size``, ``val_dataset_size``,
        ``batch_size``, ``num_workers``).
    :param aug_soft: forwarded to the training ``EventDatasetCollection``.
    :param aug_collinear: forwarded to the training ``EventDatasetCollection``.
    :return: ``(train_loader, val_loader, val_data)``; ``val_data`` may be a ``Subset``.
    """
    train_files = to_filelist(args, "train")
    val_files = to_filelist(args, "val")
    train_data = EventDatasetCollection(train_files, args, aug_soft=aug_soft, aug_collinear=aug_collinear)
    train_data = _limit_size(train_data, args.train_dataset_size)
    train_loader = _make_loader(train_data, args)
    val_data = EventDatasetCollection(val_files, args)
    val_data = _limit_size(val_data, args.val_dataset_size)
    val_loader = _make_loader(val_data, args)
    return train_loader, val_loader, val_data


def test_load(args):
    """Build one DataLoader per test directory.

    :return: dict mapping each resolved test path to its DataLoader.
    """
    test_loaders = {}
    for filename in to_filelist(args, "test"):
        test_data = EventDataset.from_directory(
            filename, mmap=True, aug_soft=args.augment_soft_particles, seed=1000000
        )
        if args.test_dataset_max_size is not None:
            print("Limiting test dataset size to", args.test_dataset_max_size)
            test_data = _limit_size(test_data, args.test_dataset_max_size)
        test_loaders[filename] = _make_loader(test_data, args)
    return test_loaders


def get_optimizer_and_scheduler(args, model, device, load_model_weights="load_model_weights"):
    """
    Optimizer and scheduler.

    :param args: parsed arguments (``optimizer``, ``optimizer_option``, ``start_lr``, ``lr_scheduler``, ...).
    :param model: model whose parameters are optimized (may be wrapped in DistributedDataParallel).
    :param device: ``map_location`` used when loading a checkpoint.
    :param load_model_weights: name of the ``args`` attribute holding the checkpoint path (or None).
        When set, model, optimizer and (if any) scheduler states are restored from it.
    :return: ``(optimizer, scheduler)``; ``scheduler`` is None when ``args.lr_scheduler`` is not recognised.
    :raises ValueError: if ``args.optimizer`` is not one of ranger / adam / adamW / radam.
    """
    optimizer_options = {k: ast.literal_eval(v) for k, v in args.optimizer_option}
    _logger.info("Optimizer options: %s" % str(optimizer_options))

    names_lr_mult = []
    if "weight_decay" in optimizer_options or "lr_mult" in optimizer_options:
        # https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/optim_factory.py#L31
        import re

        decay, no_decay = {}, {}
        names_no_decay = []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            # 1-D params (norm scales etc.), biases and model-declared names are excluded from weight decay.
            if (
                len(param.shape) == 1
                or name.endswith(".bias")
                or (hasattr(model, "no_weight_decay") and name in model.no_weight_decay())
            ):
                no_decay[name] = param
                names_no_decay.append(name)
            else:
                decay[name] = param

        decay_1x, no_decay_1x = [], []
        decay_mult, no_decay_mult = [], []
        mult_factor = 1
        if "lr_mult" in optimizer_options:
            # lr_mult is (regex, factor): matching parameter names get lr = start_lr * factor.
            pattern, mult_factor = optimizer_options.pop("lr_mult")
            for name, param in decay.items():
                if re.match(pattern, name):
                    decay_mult.append(param)
                    names_lr_mult.append(name)
                else:
                    decay_1x.append(param)
            for name, param in no_decay.items():
                if re.match(pattern, name):
                    no_decay_mult.append(param)
                    names_lr_mult.append(name)
                else:
                    no_decay_1x.append(param)
            assert len(decay_1x) + len(decay_mult) == len(decay)
            assert len(no_decay_1x) + len(no_decay_mult) == len(no_decay)
        else:
            decay_1x, no_decay_1x = list(decay.values()), list(no_decay.values())
        wd = optimizer_options.pop("weight_decay", 0.0)
        # Group order matters: the flat+decay LambdaLR below assigns one lambda per group in this order.
        parameters = [
            {"params": no_decay_1x, "weight_decay": 0.0},
            {"params": decay_1x, "weight_decay": wd},
            {"params": no_decay_mult, "weight_decay": 0.0, "lr": args.start_lr * mult_factor},
            {"params": decay_mult, "weight_decay": wd, "lr": args.start_lr * mult_factor},
        ]
        _logger.info("Parameters excluded from weight decay:\n - %s", "\n - ".join(names_no_decay))
        if len(names_lr_mult):
            _logger.info(
                "Parameters with lr multiplied by %s:\n - %s", mult_factor, "\n - ".join(names_lr_mult)
            )
    else:
        parameters = model.parameters()

    if args.optimizer == "ranger":
        from src.utils.nn.optimizer.ranger import Ranger

        opt = Ranger(parameters, lr=args.start_lr, **optimizer_options)
    elif args.optimizer == "adam":
        opt = torch.optim.Adam(parameters, lr=args.start_lr, **optimizer_options)
    elif args.optimizer == "adamW":
        opt = torch.optim.AdamW(parameters, lr=args.start_lr, **optimizer_options)
    elif args.optimizer == "radam":
        opt = torch.optim.RAdam(parameters, lr=args.start_lr, **optimizer_options)
    else:
        # Previously this fell through and failed later with an UnboundLocalError on `opt`.
        raise ValueError("Unknown optimizer %s" % args.optimizer)

    weights_path = args.__dict__[load_model_weights]
    model_state = None
    if weights_path is not None:
        _logger.info("Resume training from file %s" % weights_path)
        model_state = torch.load(weights_path, map_location=device)
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model.module.load_state_dict(model_state["model"])
        else:
            model.load_state_dict(model_state["model"])
        opt.load_state_dict(model_state["optimizer"])

    scheduler = None
    if args.lr_scheduler == "steps":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[10], gamma=0.20, last_epoch=-1)
    elif args.lr_scheduler == "flat+decay":
        # Constant lr, then geometric decay to 1% over the last 30% of epochs.
        num_decay_epochs = max(1, int(args.num_epochs * 0.3))
        milestones = list(range(args.num_epochs - num_decay_epochs, args.num_epochs))
        gamma = 0.01 ** (1.0 / num_decay_epochs)
        if len(names_lr_mult):

            def get_lr(epoch):
                return gamma ** max(0, epoch - milestones[0] + 1)  # noqa

            # Only the two lr_mult groups decay; the 1x groups keep a constant factor.
            scheduler = torch.optim.lr_scheduler.LambdaLR(
                opt, (lambda _: 1, lambda _: 1, get_lr, get_lr), last_epoch=-1, verbose=True
            )
        else:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                opt, milestones=milestones, gamma=gamma, last_epoch=-1
            )
    elif args.lr_scheduler == "flat+linear" or args.lr_scheduler == "flat+cos":
        total_steps = args.num_epochs * args.steps_per_epoch
        warmup_steps = args.warmup_steps
        flat_steps = total_steps * 0.7 - 1
        min_factor = 0.001

        def lr_fn(step_num):
            """lr factor: linear warm-up, flat until 70% of steps, then linear/cosine decay floored at min_factor."""
            if step_num > total_steps:
                raise ValueError(
                    "Tried to step {} times. The specified number of total steps is {}".format(
                        step_num + 1, total_steps
                    )
                )
            if step_num < warmup_steps:
                return 1.0 * step_num / warmup_steps
            if step_num <= flat_steps:
                return 1.0
            pct = (step_num - flat_steps) / (total_steps - flat_steps)
            if args.lr_scheduler == "flat+linear":
                return max(min_factor, 1 - pct)
            return max(min_factor, 0.5 * (math.cos(math.pi * pct) + 1))

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            opt,
            lr_fn,
            last_epoch=-1 if args.load_epoch is None else args.load_epoch * args.steps_per_epoch,
        )
        scheduler._update_per_step = True  # mark it to update the lr every step, instead of every epoch
    elif args.lr_scheduler == "one-cycle":
        # NOTE(review): this scheduler is stepped per step, but last_epoch is given in epochs here,
        # unlike the flat+* branch above which multiplies by steps_per_epoch — confirm intent.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            opt,
            max_lr=args.start_lr,
            epochs=args.num_epochs,
            steps_per_epoch=args.steps_per_epoch,
            pct_start=0.3,
            anneal_strategy="cos",
            div_factor=25.0,
            last_epoch=-1 if args.load_epoch is None else args.load_epoch,
        )
        scheduler._update_per_step = True  # mark it to update the lr every step, instead of every epoch
    elif args.lr_scheduler == "reduceplateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=2, threshold=0.01)
        scheduler._update_per_step = False  # stepped once per epoch, not every step

    # Same condition as above, so model_state is always defined here.
    if weights_path is not None and scheduler is not None:
        scheduler.load_state_dict(model_state["scheduler"])
    return opt, scheduler


def get_target_obj_score(
    clusters_eta,
    clusters_phi,
    clusters_pt,
    event_idx_clusters,
    dq_eta,
    dq_phi,
    dq_event_idx,
    gt_mode="all_in_radius",
    radius=0.8,
):
    """Return the objectness-score target (1 = matched to a dark quark, 0 = not) for every cluster.

    :param clusters_eta, clusters_phi: per-cluster coordinates. ``clusters_pt`` is accepted for
        interface compatibility but is not used for matching.
    :param event_idx_clusters: event index of each cluster.
    :param dq_eta, dq_phi: per-dark-quark coordinates.
    :param dq_event_idx: event index of each dark quark.
    :param gt_mode: ``"all_in_radius"`` marks every cluster whose closest dark quark lies within
        ``radius``; any other value marks, for each dark quark, only its closest cluster (if within ``radius``).
    :param radius: matching radius in (eta, phi); 0.8 was previously hard-coded.
    :return: 1-D tensor with one entry per cluster, grouped by event in ``event_idx_clusters.unique()``
        order (equal to the input order only if clusters are sorted by event).
    """
    target = []
    for event in event_idx_clusters.unique():
        filt = event_idx_clusters == event
        clusters_eta_event = clusters_eta[filt]
        clusters_coords = torch.stack([clusters_eta_event, clusters_phi[filt]], dim=1)
        dq_filt = dq_event_idx == event
        dq_coords_event = torch.stack([dq_eta[dq_filt], dq_phi[dq_filt]], dim=1)
        # Shape (n_clusters_in_event, n_dq_in_event).
        # NOTE(review): plain Euclidean distance in (eta, phi); phi wrap-around at +-pi is not handled.
        dist_matrix = torch.cdist(dq_coords_event, clusters_coords.to(dq_coords_event.device), p=2).T
        # n_clusters >= 1 here (the event comes from the clusters' own indices), so an empty matrix
        # means the event has no dark quarks; the old len() check tested the wrong dimension and
        # let min(dim=1) crash on the empty quark axis.
        has_quarks = dist_matrix.numel() > 0
        if gt_mode == "all_in_radius":
            if not has_quarks:
                target.append(torch.zeros(len(clusters_eta_event), device=dist_matrix.device))
                continue
            closest_quark_dist, closest_quark_idx = dist_matrix.min(dim=1)
            closest_quark_idx[closest_quark_dist > radius] = -1
            target.append((closest_quark_idx != -1).float())
        else:
            # GT is set by only considering the closest cluster to each dark quark (if it's within radius).
            t = torch.zeros_like(clusters_eta_event)
            if has_quarks:
                closest_cluster_dist, closest_cluster_idx = dist_matrix.min(dim=0)
                closest_cluster_idx[closest_cluster_dist > radius] = -1
                matched_clusters = closest_cluster_idx[closest_cluster_idx != -1]
                t[matched_clusters.long().to(t.device)] = 1
            target.append(t)
    if not target:
        # No clusters at all: torch.cat([]) would raise.
        return torch.zeros(0, device=clusters_eta.device)
    return torch.cat(target).flatten()


def plot_obj_score_debug(
    dq_eta,
    dq_phi,
    dq_batch_idx,
    clusters_eta,
    clusters_phi,
    clusters_pt,
    clusters_batch_idx,
    clusters_labels,
    input_pxyz,
    input_event_idx,
    input_clusters,
    pred_obj_score_clusters,
):
    """Plot, per event, clusters, dark quarks and input particles in (eta, phi) for debugging the objectness head.

    The number of panels is ``dq_batch_idx.max() + 1`` (i.e. derived from the dark-quark event indices).
    Clusters are green for label 1 and grey for label 0, with the predicted objectness score and pt written
    next to them; dark quarks are red; input particles are coloured by ``input_clusters``.

    :return: the matplotlib Figure.
    """
    import matplotlib.pyplot as plt

    n_events = dq_batch_idx.max().int().item() + 1
    pfcands_pt = torch.sqrt(input_pxyz[:, 0] ** 2 + input_pxyz[:, 1] ** 2)
    pfcands_eta, pfcands_phi = calc_eta_phi(input_pxyz, return_stacked=0)
    # squeeze=False keeps the axes array indexable even when there is only one event.
    fig, axes = plt.subplots(1, n_events, figsize=(n_events * 3, 3), squeeze=False)
    axes = axes[0]
    colors = {0: "grey", 1: "green"}
    for i in range(n_events):
        ax = axes[i]
        c_filt = clusters_batch_idx == i
        c_eta = clusters_eta[c_filt].cpu()
        c_phi = clusters_phi[c_filt].cpu()
        c_pt = clusters_pt[c_filt].cpu()
        c_pred = pred_obj_score_clusters[c_filt]
        # Clusters: green for label 1, grey for label 0, marker size given by pt.
        ax.scatter(c_eta, c_phi, c=[colors[x] for x in clusters_labels[c_filt].tolist()], s=c_pt, alpha=0.5)
        # Predicted objectness score in light grey, offset from each cluster.
        for j in range(len(c_eta)):
            ax.text(
                c_eta[j] - 0.5, c_phi[j] - 0.5, str(round(c_pred[j].item(), 2)),
                fontsize=6, color="gray", alpha=0.7,
            )
        # Dark quarks as red dots.
        q_filt = dq_batch_idx == i
        ax.scatter(dq_eta[q_filt].cpu(), dq_phi[q_filt].cpu(), c="red", alpha=0.5)
        # Input particles coloured by cluster assignment, sized by pt.
        p_filt = input_event_idx == i
        ax.scatter(
            pfcands_eta[p_filt].cpu(), pfcands_phi[p_filt].cpu(), c=input_clusters[p_filt].cpu(),
            cmap="coolwarm", s=pfcands_pt[p_filt].cpu(), alpha=0.5,
        )
        # Cluster pt in black on top of each cluster.
        for j in range(len(c_eta)):
            ax.text(c_eta[j], c_phi[j], str(round(c_pt[j].item(), 2)), fontsize=8, color="black")
    fig.tight_layout()
    return fig


def get_loss_func(args):
    """Return ``loss(model_input, model_output, gt_labels)`` wrapping ``object_condensation_loss``.

    Labels are shifted by +1 unless ``args.loss == "quark_distance"`` or ``args.train_objectness_score``
    (presumably so the -1 "noise" label becomes 0 — confirm against object_condensation_loss).
    """

    def loss(model_input, model_output, gt_labels):
        batch_numbers = model_input.batch_idx
        if not (args.loss == "quark_distance" or args.train_objectness_score):
            labels = gt_labels + 1
        else:
            labels = gt_labels
        return object_condensation_loss(
            model_input,
            model_output,
            labels,
            batch_numbers,
            attr_weight=args.attr_loss_weight,
            repul_weight=args.repul_loss_weight,
            coord_weight=args.coord_loss_weight,
            beta_type=args.beta_type,
            lorentz_norm=args.lorentz_norm,
            spatial_part_only=args.spatial_part_only,
            loss_quark_distance=args.loss == "quark_distance",
            oc_scalars=args.scalars_oc,
            loss_obj_score=args.train_objectness_score,
        )

    return loss


def renumber_clusters(tensor):
    """Map each value to its rank among the sorted unique values (e.g. [5, 2, 5, 9] -> [1, 0, 1, 2]).

    Uses ``torch.unique(return_inverse=True)``, so it works on any device and for negative values
    (the previous lookup table lived on the CPU and mis-indexed negatives). The result keeps the
    previous floating-point (default) dtype for backward compatibility.
    """
    _, inverse = torch.unique(tensor, sorted=True, return_inverse=True)
    return inverse.to(torch.get_default_dtype())


def get_gt_func(args):
    """Return ``gt(events)`` computing ground-truth labels for a batch of events.

    Each particle is labelled with the index of its closest dark quark (``matrix_element_gen_particles``)
    if that quark is within ``args.gt_radius`` in (eta, phi), otherwise -1 (noise). Particles are taken
    from ``pfcands``, ``final_parton_level_particles`` or ``final_gen_particles`` depending on
    ``args.parton_level`` / ``args.gen_level``.
    """
    R = args.gt_radius

    def get_idx_for_event(obj, i):
        # Start/end offsets of event i in a flattened collection.
        return obj.batch_number[i], obj.batch_number[i + 1]

    def get_labels(b, pfcands, special=False, get_coordinates=False, get_dq_coords=False):
        """Label every particle of ``pfcands`` in batch ``b``.

        :param get_coordinates: also return (E, px, py, pz) of every dark quark and the un-renumbered labels;
            labels are then global quark indices (offset per event) instead of per-event renumbered ones.
        :param get_dq_coords: also return dark-quark (eta, phi) and their event indices.
        """
        dq = b.matrix_element_gen_particles
        labels = torch.zeros(len(pfcands)).long()
        labels_no_renumber = torch.ones_like(labels) * -1
        offset = 0
        if get_coordinates:
            labels_coordinates = torch.zeros(len(dq.pt), 4).float()
        if get_dq_coords:
            dq_coords = [dq.eta, dq.phi]
            dq_coords_batch_idx = torch.zeros(dq.pt.shape)
            for i in range(len(dq.batch_number) - 1):
                dq_coords_batch_idx[dq.batch_number[i]:dq.batch_number[i + 1]] = i
        for i in range(len(b)):
            s_dq, e_dq = get_idx_for_event(dq, i)
            dq_eta = dq.eta[s_dq:e_dq]
            dq_phi = dq.phi[s_dq:e_dq]
            s, e = get_idx_for_event(pfcands, i)
            pfcands_eta = pfcands.eta[s:e]
            pfcands_phi = pfcands.phi[s:e]
            if len(dq_eta) == 0:
                # No dark quarks: every particle is noise (min over an empty quark axis would raise).
                closest_quark_idx = torch.full(
                    (len(pfcands_eta),), -1, dtype=torch.long, device=pfcands_eta.device
                )
            else:
                # (n_particles, n_quarks) distance matrix in (eta, phi).
                dist_matrix = torch.cdist(
                    torch.stack([dq_eta, dq_phi], dim=1),
                    torch.stack([pfcands_eta, pfcands_phi], dim=1),
                    p=2,
                ).T
                closest_quark_dist, closest_quark_idx = dist_matrix.min(dim=1)
                closest_quark_idx[closest_quark_dist > R] = -1
            if len(closest_quark_idx):
                if not get_coordinates:
                    closest_quark_idx = renumber_clusters(closest_quark_idx + 1) - 1
                else:
                    labels_no_renumber[s:e] = closest_quark_idx
                closest_quark_idx[closest_quark_idx != -1] += offset
                labels[s:e] = closest_quark_idx
            if get_coordinates:
                # Done for every event (even without particles) so offset stays aligned with s_dq.
                E_dq = dq.E[s_dq:e_dq]
                pxyz_dq = dq.pxyz[s_dq:e_dq]
                labels_coordinates[s_dq:e_dq] = torch.cat([E_dq.unsqueeze(-1), pxyz_dq], dim=1)
                offset += len(E_dq)
        if get_coordinates:
            return TensorCollection(
                labels=labels, labels_coordinates=labels_coordinates, labels_no_renumber=labels_no_renumber
            )
        if get_dq_coords:
            return TensorCollection(labels=labels, dq_coords=dq_coords, dq_coords_batch_idx=dq_coords_batch_idx)
        return labels

    def gt(events):
        pfcands = events.pfcands
        if args.parton_level:
            pfcands = events.final_parton_level_particles
        if args.gen_level:
            pfcands = events.final_gen_particles
        return get_labels(
            events,
            pfcands,
            get_coordinates=args.loss == "quark_distance",
            get_dq_coords=args.train_objectness_score,
        )

    return gt


def count_parameters(model):
    """Return the number of trainable parameters of ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def get_model(args, dev):
    """Build the clustering model from ``args.network_config`` and optionally load ``args.load_model_weights``.

    Loading is non-strict but then asserts that no keys are missing or unexpected.
    """
    network_options = {}  # TODO: implement network options
    network_module = import_module(args.network_config, name="_network_module")
    model = network_module.get_model(obj_score=False, args=args, **network_options)
    if args.load_model_weights:
        print("Loading model state dict from %s" % args.load_model_weights)
        model_state = torch.load(args.load_model_weights, map_location=dev)["model"]
        missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)
        _logger.info(
            "Model initialized with weights from %s\n ... Missing: %s\n ... Unexpected: %s"
            % (args.load_model_weights, missing_keys, unexpected_keys)
        )
        assert len(missing_keys) == 0
        assert len(unexpected_keys) == 0
    return model


def get_model_obj_score(args, dev):
    """Build the objectness-score model from ``args.obj_score_module``; optionally load its weights.

    Weights come from ``args.load_objectness_score_weights`` (only allowed with ``args.train_objectness_score``).
    """
    network_options = {}  # TODO: implement network options
    network_module = import_module(args.obj_score_module, name="_network_module")
    model = network_module.get_model(obj_score=True, args=args, **network_options)
    if args.load_objectness_score_weights:
        assert args.train_objectness_score
        print("Loading objectness score model state dict from %s" % args.load_objectness_score_weights)
        model_state = torch.load(args.load_objectness_score_weights, map_location=dev)["model"]
        missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)
        _logger.info(
            "Objectness score model initialized with weights from %s\n ... Missing: %s\n ... Unexpected: %s"
            % (args.load_objectness_score_weights, missing_keys, unexpected_keys)
        )
        assert len(missing_keys) == 0
        assert len(unexpected_keys) == 0
    return model