# jetclustering/src/utils/train_utils.py
import os
import ast
import glob
import functools
import math
import torch
from torch.utils.data import DataLoader
from src.logger.logger import _logger, _configLogger
from src.dataset.dataset import EventDatasetCollection, EventDataset
from src.utils.import_tools import import_module
from src.dataset.functions_graph import graph_batch_func
from src.dataset.functions_data import concat_events
from src.utils.paths import get_path
from src.layers.object_cond import calc_eta_phi
from src.layers.object_cond import object_condensation_loss
def to_filelist(args, mode="train"):
if mode == "train":
flist = args.data_train
elif mode == "val":
flist = args.data_val
elif mode == "test":
flist = args.data_test
    else:
        raise ValueError("Invalid mode %s" % mode)
print(mode, "filelist:", flist)
flist = [get_path(p, "preprocessed_data") for p in flist]
return flist
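
# Example (hypothetical values): with args.data_train = ["qcd/train", "svj/train"],
# to_filelist(args, "train") resolves each entry via get_path(..., "preprocessed_data")
# and returns the list of resolved dataset paths.
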
class TensorCollection:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def to(self, device):
# Move all tensors to device
for k, v in self.__dict__.items():
if torch.is_tensor(v):
setattr(self, k, v.to(device))
return self
def dict_rep(self):
d = {}
for k, v in self.__dict__.items():
if torch.is_tensor(v):
d[k] = v
return d
#def __getitem__(self, i):
# return TensorCollection(**{k: v[i] for k, v in self.__dict__.items()})
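
# Minimal usage sketch (the attribute names here are illustrative, not fixed by the class):
#   tc = TensorCollection(labels=torch.zeros(10), coords=torch.randn(10, 2))
#   tc = tc.to("cuda")        # moves only the tensor-valued attributes
#   tensors = tc.dict_rep()   # {"labels": ..., "coords": ...}, tensors only
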
def train_load(args, aug_soft=False, aug_collinear=False):
train_files = to_filelist(args, "train")
val_files = to_filelist(args, "val")
train_data = EventDatasetCollection(train_files, args, aug_soft=aug_soft, aug_collinear=aug_collinear)
if args.train_dataset_size is not None:
train_data = torch.utils.data.Subset(train_data, list(range(args.train_dataset_size)))
train_loader = DataLoader(
train_data,
batch_size=args.batch_size,
drop_last=True,
pin_memory=True,
num_workers=args.num_workers,
collate_fn=concat_events,
persistent_workers=args.num_workers > 0,
        shuffle=False,  # NOTE: ordering/shuffling is assumed to be handled upstream (e.g. by EventDatasetCollection)
)
val_data = EventDatasetCollection(val_files, args)
if args.val_dataset_size is not None:
val_data = torch.utils.data.Subset(val_data, list(range(args.val_dataset_size)))
val_loader = DataLoader(
val_data,
batch_size=args.batch_size,
drop_last=True,
pin_memory=True,
num_workers=args.num_workers,
collate_fn=concat_events,
persistent_workers=args.num_workers > 0,
shuffle=False
)
return train_loader, val_loader, val_data
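
# Usage sketch (mapping the aug flags to these args is an assumption for illustration):
#   train_loader, val_loader, val_data = train_load(args, aug_soft=args.augment_soft_particles)
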
def test_load(args):
test_files = to_filelist(args, "test")
test_loaders = {}
for filename in test_files:
test_data = EventDataset.from_directory(filename, mmap=True, aug_soft=args.augment_soft_particles, seed=1000000)
if args.test_dataset_max_size is not None:
print("Limiting test dataset size to", args.test_dataset_max_size)
test_data = torch.utils.data.Subset(test_data, list(range(args.test_dataset_max_size)))
test_loaders[filename] = DataLoader(
test_data,
batch_size=args.batch_size,
drop_last=True,
pin_memory=True,
collate_fn=concat_events,
num_workers=args.num_workers,
persistent_workers=args.num_workers > 0,
)
return test_loaders
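
# The returned dict maps each test file path to its DataLoader:
#   for fname, loader in test_load(args).items():
#       ...  # run evaluation on `loader`
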
def get_optimizer_and_scheduler(args, model, device, load_model_weights="load_model_weights"):
"""
Optimizer and scheduler.
:param args:
:param model:
:return:
"""
optimizer_options = {k: ast.literal_eval(v) for k, v in args.optimizer_option}
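    # Each (key, value-string) pair above is parsed with ast.literal_eval. A hypothetical example:
    #   args.optimizer_option = [("weight_decay", "0.01"), ("lr_mult", "('head.*', 10)")]
    # yields weight decay 0.01 and a 10x LR multiplier for parameters matching "head.*".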
_logger.info("Optimizer options: %s" % str(optimizer_options))
names_lr_mult = []
if "weight_decay" in optimizer_options or "lr_mult" in optimizer_options:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/optim_factory.py#L31
import re
decay, no_decay = {}, {}
names_no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if (
len(param.shape) == 1
or name.endswith(".bias")
or (
hasattr(model, "no_weight_decay")
and name in model.no_weight_decay()
)
):
no_decay[name] = param
names_no_decay.append(name)
else:
decay[name] = param
decay_1x, no_decay_1x = [], []
decay_mult, no_decay_mult = [], []
mult_factor = 1
if "lr_mult" in optimizer_options:
pattern, mult_factor = optimizer_options.pop("lr_mult")
for name, param in decay.items():
if re.match(pattern, name):
decay_mult.append(param)
names_lr_mult.append(name)
else:
decay_1x.append(param)
for name, param in no_decay.items():
if re.match(pattern, name):
no_decay_mult.append(param)
names_lr_mult.append(name)
else:
no_decay_1x.append(param)
assert len(decay_1x) + len(decay_mult) == len(decay)
assert len(no_decay_1x) + len(no_decay_mult) == len(no_decay)
else:
decay_1x, no_decay_1x = list(decay.values()), list(no_decay.values())
wd = optimizer_options.pop("weight_decay", 0.0)
parameters = [
{"params": no_decay_1x, "weight_decay": 0.0},
{"params": decay_1x, "weight_decay": wd},
{
"params": no_decay_mult,
"weight_decay": 0.0,
"lr": args.start_lr * mult_factor,
},
{
"params": decay_mult,
"weight_decay": wd,
"lr": args.start_lr * mult_factor,
},
]
_logger.info(
"Parameters excluded from weight decay:\n - %s",
"\n - ".join(names_no_decay),
)
if len(names_lr_mult):
_logger.info(
"Parameters with lr multiplied by %s:\n - %s",
mult_factor,
"\n - ".join(names_lr_mult),
)
else:
parameters = model.parameters()
if args.optimizer == "ranger":
from src.utils.nn.optimizer.ranger import Ranger
opt = Ranger(parameters, lr=args.start_lr, **optimizer_options)
elif args.optimizer == "adam":
opt = torch.optim.Adam(parameters, lr=args.start_lr, **optimizer_options)
elif args.optimizer == "adamW":
opt = torch.optim.AdamW(parameters, lr=args.start_lr, **optimizer_options)
elif args.optimizer == "radam":
opt = torch.optim.RAdam(parameters, lr=args.start_lr, **optimizer_options)
if args.__dict__[load_model_weights] is not None:
_logger.info("Resume training from file %s" % args.__dict__[load_model_weights])
model_state = torch.load(
args.__dict__[load_model_weights],
map_location=device,
)
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model.module.load_state_dict(model_state["model"])
else:
model.load_state_dict(model_state["model"])
opt_state = model_state["optimizer"]
opt.load_state_dict(opt_state)
scheduler = None
if args.lr_scheduler == "steps":
lr_step = round(args.num_epochs / 3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
opt,
milestones=[10],
gamma=0.20,
last_epoch=-1
)
elif args.lr_scheduler == "flat+decay":
num_decay_epochs = max(1, int(args.num_epochs * 0.3))
milestones = list(
range(args.num_epochs - num_decay_epochs, args.num_epochs)
)
gamma = 0.01 ** (1.0 / num_decay_epochs)
if len(names_lr_mult):
def get_lr(epoch):
return gamma ** max(0, epoch - milestones[0] + 1) # noqa
scheduler = torch.optim.lr_scheduler.LambdaLR(
opt,
(lambda _: 1, lambda _: 1, get_lr, get_lr),
last_epoch=-1,
verbose=True,
)
else:
scheduler = torch.optim.lr_scheduler.MultiStepLR(
opt,
milestones=milestones,
gamma=gamma,
last_epoch=-1
)
elif args.lr_scheduler == "flat+linear" or args.lr_scheduler == "flat+cos":
total_steps = args.num_epochs * args.steps_per_epoch
warmup_steps = args.warmup_steps
flat_steps = total_steps * 0.7 - 1
min_factor = 0.001
def lr_fn(step_num):
if step_num > total_steps:
raise ValueError(
"Tried to step {} times. The specified number of total steps is {}".format(
step_num + 1, total_steps
)
)
if step_num < warmup_steps:
return 1.0 * step_num / warmup_steps
if step_num <= flat_steps:
return 1.0
pct = (step_num - flat_steps) / (total_steps - flat_steps)
if args.lr_scheduler == "flat+linear":
return max(min_factor, 1 - pct)
else:
return max(min_factor, 0.5 * (math.cos(math.pi * pct) + 1))
scheduler = torch.optim.lr_scheduler.LambdaLR(
opt,
lr_fn,
last_epoch=-1
if args.load_epoch is None
else args.load_epoch * args.steps_per_epoch,
)
        # Mark the scheduler to update the LR every step instead of every epoch.
        scheduler._update_per_step = True
elif args.lr_scheduler == "one-cycle":
scheduler = torch.optim.lr_scheduler.OneCycleLR(
opt,
max_lr=args.start_lr,
epochs=args.num_epochs,
steps_per_epoch=args.steps_per_epoch,
pct_start=0.3,
anneal_strategy="cos",
div_factor=25.0,
last_epoch=-1 if args.load_epoch is None else args.load_epoch,
)
        # Mark the scheduler to update the LR every step instead of every epoch.
        scheduler._update_per_step = True
elif args.lr_scheduler == "reduceplateau":
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
opt, patience=2, threshold=0.01
)
        # This scheduler steps once per epoch on the validation metric, not per step.
        scheduler._update_per_step = False
    if args.__dict__[load_model_weights] is not None:
if scheduler is not None:
scheduler.load_state_dict(model_state["scheduler"])
return opt, scheduler
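
# Sketch of how the `_update_per_step` flag is presumably consumed by the training loop
# (the loop lives elsewhere in the repo; this is an illustrative assumption):
#   per_step = scheduler is not None and getattr(scheduler, "_update_per_step", False)
#   for batch in train_loader:
#       ...  # forward/backward/optimizer step
#       if per_step:
#           scheduler.step()
#   if scheduler is not None and not per_step:
#       scheduler.step()  # ReduceLROnPlateau additionally expects the validation metric
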
def get_target_obj_score(clusters_eta, clusters_phi, clusters_pt, event_idx_clusters, dq_eta, dq_phi, dq_event_idx, gt_mode="all_in_radius"):
    # Return the target objectness score for each cluster (a flat tensor of 1s and 0s).
    # dq_eta, dq_phi: (eta, phi) coordinates of each dark quark
    # dq_event_idx: event index of each dark quark
target = []
if gt_mode == "all_in_radius":
for event in event_idx_clusters.unique():
filt = event_idx_clusters == event
clusters = torch.stack([clusters_eta[filt], clusters_phi[filt], clusters_pt[filt]], dim=1)
dq_coords_event = torch.stack([dq_eta[dq_event_idx == event], dq_phi[dq_event_idx == event]], dim=1)
dist_matrix = torch.cdist(
dq_coords_event,
clusters[:, :2].to(dq_coords_event.device),
p=2
).T
            if dist_matrix.numel() == 0:
                # No dark quarks (or no clusters) in this event: every target is 0.
                target.append(torch.zeros(len(clusters), device=dist_matrix.device))
                continue
            closest_quark_dist, closest_quark_idx = dist_matrix.min(dim=1)
            # Clusters farther than DeltaR = 0.8 from every dark quark get target 0.
            closest_quark_idx[closest_quark_dist > 0.8] = -1
            target.append((closest_quark_idx != -1).float())
else:
# GT is set by only considering the closest jet to each dark quark (if it's within radius)
for event in event_idx_clusters.unique():
filt = event_idx_clusters == event
clusters = torch.stack([clusters_eta[filt], clusters_phi[filt], clusters_pt[filt]], dim=1)
dq_coords_event = torch.stack([dq_eta[dq_event_idx == event], dq_phi[dq_event_idx == event]], dim=1)
dist_matrix = torch.cdist(
dq_coords_event,
clusters[:, :2].to(dq_coords_event.device),
p=2
).T
            if dist_matrix.numel() == 0:
                # No dark quarks (or no clusters) in this event: every target is 0.
                target.append(torch.zeros(len(clusters), device=dist_matrix.device))
                continue
            closest_cluster_dist, closest_cluster_idx = dist_matrix.min(dim=0)
            # A dark quark is matched only if its closest cluster is within DeltaR = 0.8.
            closest_cluster_idx[closest_cluster_dist > 0.8] = -1
            matched_clusters = closest_cluster_idx[closest_cluster_idx != -1]
            t = torch.zeros_like(clusters_eta[filt])
            t[matched_clusters.long()] = 1
            target.append(t)
return torch.cat(target).flatten()
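
# Worked example (hypothetical numbers): one event with two clusters at
# (eta, phi) = (0.0, 0.0) and (3.0, 0.0) and a single dark quark at (0.1, 0.0).
# Only the first cluster lies within DeltaR = 0.8 of the quark, so with
# gt_mode="all_in_radius" the returned targets are [1., 0.].
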
def plot_obj_score_debug(dq_eta, dq_phi, dq_batch_idx, clusters_eta, clusters_phi, clusters_pt, clusters_batch_idx, clusters_labels, input_pxyz, input_event_idx, input_clusters, pred_obj_score_clusters):
# For debugging the Objectness Score head.
import matplotlib.pyplot as plt
n_events = dq_batch_idx.max().int().item() + 1
pfcands_pt = torch.sqrt(input_pxyz[:, 0] ** 2 + input_pxyz[:, 1] ** 2)
pfcands_eta, pfcands_phi = calc_eta_phi(input_pxyz, return_stacked=0)
    fig, ax = plt.subplots(1, n_events, figsize=(n_events * 3, 3), squeeze=False)
    ax = ax[0]  # flatten the 1 x n_events grid so indexing also works for a single event
colors = {0: "grey", 1: "green"}
for i in range(n_events):
# Plot the clusters as dots that are green for label 1 and gray for label 0
filt = clusters_batch_idx == i
        ax[i].scatter(clusters_eta[filt].cpu(), clusters_phi[filt].cpu(), c=[colors[x] for x in clusters_labels[filt].tolist()], s=clusters_pt[filt].cpu(), alpha=0.5)
        # Also print the predicted objectness score next to each cluster in light gray text.
for j in range(len(clusters_eta[filt])):
ax[i].text(clusters_eta[filt][j].cpu()-0.5, clusters_phi[filt][j].cpu()-0.5, str(round(pred_obj_score_clusters[filt][j].item(), 2)), fontsize=6, color="gray", alpha=0.7)
# Plot the dark quarks as red dots
filt = dq_batch_idx == i
ax[i].scatter(dq_eta[filt].cpu(), dq_phi[filt].cpu(), c="red", alpha=0.5)
ax[i].scatter(pfcands_eta[input_event_idx == i].cpu(), pfcands_phi[input_event_idx == i].cpu(), c=input_clusters[input_event_idx == i].cpu(), cmap="coolwarm", s=pfcands_pt[input_event_idx == i].cpu(), alpha=0.5)
        # Print the pt of each cluster in black text on top of it.
filt = clusters_batch_idx == i
for j in range(len(clusters_eta[filt])):
ax[i].text(clusters_eta[filt][j].cpu(), clusters_phi[filt][j].cpu(), str(round(clusters_pt[filt][j].item(), 2)), fontsize=8, color="black")
fig.tight_layout()
return fig
def get_loss_func(args):
# Loss function takes in the output of a model and the output of GT (the GT labels) and returns the loss.
def loss(model_input, model_output, gt_labels):
batch_numbers = model_input.batch_idx
        if not (args.loss == "quark_distance" or args.train_objectness_score):
            # Shift labels so that noise (-1) becomes 0 and real clusters become >= 1.
            labels = gt_labels + 1
        else:
            labels = gt_labels
return object_condensation_loss(model_input, model_output, labels, batch_numbers,
attr_weight=args.attr_loss_weight,
repul_weight=args.repul_loss_weight,
coord_weight=args.coord_loss_weight,
beta_type=args.beta_type,
lorentz_norm=args.lorentz_norm,
spatial_part_only=args.spatial_part_only,
loss_quark_distance=args.loss=="quark_distance",
oc_scalars=args.scalars_oc,
loss_obj_score=args.train_objectness_score)
return loss
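
# Usage sketch (the model/GT call signatures are those used elsewhere in this repo):
#   loss_fn = get_loss_func(args)
#   loss = loss_fn(batch, model(batch), gt_func(batch))
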
def renumber_clusters(tensor):
    # Compact the (non-negative) labels in `tensor` to consecutive integers 0..n-1,
    # ordered by the original label values.
    _, inverse = tensor.unique(return_inverse=True)
    return inverse
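
# Example:
#   renumber_clusters(torch.tensor([3, 3, 7, 0]))  # -> tensor([1, 1, 2, 0])
# i.e. labels are compacted to 0..n-1, ordered by their original values.
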
def get_gt_func(args):
    # Returns the GT function: it accepts an Event batch and returns the ground-truth
    # labels (the index of the dark quark each particle belongs to, or -1 for noise).
    # By default, each particle is assigned to its closest dark quark IF that quark is within radius R.
R = args.gt_radius
def get_idx_for_event(obj, i):
return obj.batch_number[i], obj.batch_number[i + 1]
def get_labels(b, pfcands, special=False, get_coordinates=False, get_dq_coords=False):
        # b: batch of events
        # If get_coordinates is True, return the (E, px, py, pz) coordinates of the matched
        # dark quarks instead of just the clustering labels.
labels = torch.zeros(len(pfcands)).long()
if get_coordinates:
labels_coordinates = torch.zeros(len(b.matrix_element_gen_particles.pt), 4).float()
labels_no_renumber = torch.ones_like(labels)*-1
offset = 0
if get_dq_coords:
dq_coords = [b.matrix_element_gen_particles.eta, b.matrix_element_gen_particles.phi]
            dq_coords_batch_idx = torch.zeros(b.matrix_element_gen_particles.pt.shape)
            # Expand the per-event offsets in batch_number into a per-quark event index.
            for i in range(len(b.matrix_element_gen_particles.batch_number) - 1):
                dq_coords_batch_idx[b.matrix_element_gen_particles.batch_number[i]:b.matrix_element_gen_particles.batch_number[i + 1]] = i
for i in range(len(b)):
s_dq, e_dq = get_idx_for_event(b.matrix_element_gen_particles, i)
dq_eta = b.matrix_element_gen_particles.eta[s_dq:e_dq]
dq_phi = b.matrix_element_gen_particles.phi[s_dq:e_dq]
# dq_pt = b.matrix_element_gen_particles.pt[s:e] # Maybe we can somehow weigh the loss by pt?
s, e = get_idx_for_event(pfcands, i)
pfcands_eta = pfcands.eta[s:e]
pfcands_phi = pfcands.phi[s:e]
# calculate the distance matrix between each dark quark and pfcands
dist_matrix = torch.cdist(
torch.stack([dq_eta, dq_phi], dim=1),
torch.stack([pfcands_eta, pfcands_phi], dim=1),
p=2
)
dist_matrix = dist_matrix.T
closest_quark_dist, closest_quark_idx = dist_matrix.min(dim=1)
closest_quark_idx[closest_quark_dist > R] = -1
if len(closest_quark_idx):
if not get_coordinates:
closest_quark_idx = renumber_clusters(closest_quark_idx + 1) - 1
else:
labels_no_renumber[s:e] = closest_quark_idx
closest_quark_idx[closest_quark_idx != -1] += offset
labels[s:e] = closest_quark_idx
if get_coordinates:
E_dq = b.matrix_element_gen_particles.E[s_dq:e_dq]
pxyz_dq = b.matrix_element_gen_particles.pxyz[s_dq:e_dq] # the -1 doesn't matter as it will be ignored anyway
labels_coordinates[s_dq:e_dq] = torch.cat([E_dq.unsqueeze(-1), pxyz_dq], dim=1)
offset += len(E_dq)
if get_coordinates:
return TensorCollection(labels=labels, labels_coordinates=labels_coordinates, labels_no_renumber=labels_no_renumber)
if get_dq_coords:
return TensorCollection(labels=labels, dq_coords=dq_coords, dq_coords_batch_idx=dq_coords_batch_idx)
return labels
def gt(events):
pfcands = events.pfcands
if args.parton_level:
pfcands = events.final_parton_level_particles
if args.gen_level:
pfcands = events.final_gen_particles
return get_labels(events, pfcands, get_coordinates=args.loss=="quark_distance",
get_dq_coords=args.train_objectness_score)
return gt
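
# Usage sketch:
#   gt_func = get_gt_func(args)
#   labels = gt_func(events)  # per-particle dark-quark index, or -1 for noise
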
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
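
# Example:
#   count_parameters(torch.nn.Linear(10, 2))  # -> 22 (10*2 weights + 2 biases)
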
def get_model(args, dev):
network_options = {} # TODO: implement network options
network_module = import_module(args.network_config, name="_network_module")
model = network_module.get_model(obj_score=False, args=args, **network_options)
if args.load_model_weights:
print("Loading model state dict from %s" % args.load_model_weights)
model_state = torch.load(args.load_model_weights, map_location=dev)["model"]
missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)
_logger.info(
"Model initialized with weights from %s\n ... Missing: %s\n ... Unexpected: %s"
% (args.load_model_weights, missing_keys, unexpected_keys)
)
    assert len(missing_keys) == 0, "Missing keys when loading model weights: %s" % str(missing_keys)
    assert len(unexpected_keys) == 0, "Unexpected keys when loading model weights: %s" % str(unexpected_keys)
return model
def get_model_obj_score(args, dev):
network_options = {} # TODO: implement network options
network_module = import_module(args.obj_score_module, name="_network_module")
model = network_module.get_model(obj_score=True, args=args, **network_options)
if args.load_objectness_score_weights:
assert args.train_objectness_score
print("Loading objectness score model state dict from %s" % args.load_objectness_score_weights)
model_state = torch.load(args.load_objectness_score_weights, map_location=dev)["model"]
missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)
_logger.info(
"Objectness score model initialized with weights from %s\n ... Missing: %s\n ... Unexpected: %s"
% (args.load_objectness_score_weights, missing_keys, unexpected_keys)
)
    assert len(missing_keys) == 0, "Missing keys when loading objectness score weights: %s" % str(missing_keys)
    assert len(unexpected_keys) == 0, "Unexpected keys when loading objectness score weights: %s" % str(unexpected_keys)
return model