Spaces:
Sleeping
Sleeping
import os | |
import ast | |
import glob | |
import functools | |
import math | |
import torch | |
from torch.utils.data import DataLoader | |
from src.logger.logger import _logger, _configLogger | |
from src.dataset.dataset import EventDatasetCollection, EventDataset | |
from src.utils.import_tools import import_module | |
from src.dataset.functions_graph import graph_batch_func | |
from src.dataset.functions_data import concat_events | |
from src.utils.paths import get_path | |
from src.layers.object_cond import calc_eta_phi | |
from src.layers.object_cond import object_condensation_loss | |
def to_filelist(args, mode="train"):
    """Return the resolved list of input files for one data split.

    :param args: run configuration; ``args.data_train``, ``args.data_val`` or
        ``args.data_test`` is read depending on *mode*.
    :param mode: one of ``"train"``, ``"val"`` or ``"test"``.
    :return: a new list with every entry passed through
        ``get_path(entry, "preprocessed_data")``.
    :raises NotImplementedError: if *mode* is not one of the three splits.
    """
    split_attributes = (("train", "data_train"), ("val", "data_val"), ("test", "data_test"))
    attribute = next((attr for split, attr in split_attributes if mode == split), None)
    if attribute is None:
        raise NotImplementedError("Invalid mode %s" % mode)
    files = getattr(args, attribute)
    print(mode, "filelist:", files)
    return [get_path(entry, "preprocessed_data") for entry in files]
class TensorCollection:
    """Simple attribute container for values that travel together.

    Every keyword argument passed to the constructor becomes an attribute.
    Only attributes that are ``torch`` tensors are touched by :meth:`to` and
    returned by :meth:`dict_rep`; everything else (lists, scalars, ...) is
    left as-is.
    """

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    def to(self, device):
        """Move all tensor attributes to *device* in place and return ``self``."""
        moved = {
            name: value.to(device)
            for name, value in vars(self).items()
            if torch.is_tensor(value)
        }
        vars(self).update(moved)
        return self

    def dict_rep(self):
        """Return a new ``{name: tensor}`` dict holding only the tensor attributes."""
        return {name: value for name, value in vars(self).items() if torch.is_tensor(value)}
def train_load(args, aug_soft=False, aug_collinear=False):
    """Create the training and validation data loaders.

    Both loaders share the same settings: ``concat_events`` collation, pinned
    memory, ``drop_last=True`` and ``shuffle=False``.
    NOTE(review): the training loader does not shuffle; confirm this is intended
    (e.g. the dataset already randomizes its order).

    :param args: run configuration. Reads the train/val file lists (through
        ``to_filelist``), ``train_dataset_size``, ``val_dataset_size``,
        ``batch_size`` and ``num_workers``.
    :param aug_soft: forwarded to the training ``EventDatasetCollection``.
    :param aug_collinear: forwarded to the training ``EventDatasetCollection``.
    :return: ``(train_loader, val_loader, val_data)``, where ``val_data`` is the
        (possibly truncated) validation dataset.
    """

    def truncate(dataset, size):
        # Keep only the first ``size`` items when a limit is configured.
        if size is None:
            return dataset
        return torch.utils.data.Subset(dataset, list(range(size)))

    def make_loader(dataset):
        # Identical loader settings for the training and validation splits.
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            drop_last=True,
            pin_memory=True,
            num_workers=args.num_workers,
            collate_fn=concat_events,
            persistent_workers=args.num_workers > 0,
            shuffle=False,
        )

    train_files = to_filelist(args, "train")
    val_files = to_filelist(args, "val")

    train_data = truncate(
        EventDatasetCollection(train_files, args, aug_soft=aug_soft, aug_collinear=aug_collinear),
        args.train_dataset_size,
    )
    train_loader = make_loader(train_data)

    val_data = truncate(EventDatasetCollection(val_files, args), args.val_dataset_size)
    val_loader = make_loader(val_data)
    return train_loader, val_loader, val_data
def test_load(args):
    """Create one test ``DataLoader`` per test file.

    Each file is opened with ``EventDataset.from_directory`` (memory-mapped, with
    a fixed ``seed=1000000``, presumably so soft-particle augmentation is
    reproducible across runs). If ``args.test_dataset_max_size`` is set, each
    dataset is cut to its first that-many events.

    :param args: run configuration. Reads the test file list (through
        ``to_filelist``), ``augment_soft_particles``, ``test_dataset_max_size``,
        ``batch_size`` and ``num_workers``.
    :return: dict mapping each resolved test file path to its ``DataLoader``.
    """
    loaders = {}
    for path in to_filelist(args, "test"):
        dataset = EventDataset.from_directory(
            path, mmap=True, aug_soft=args.augment_soft_particles, seed=1000000
        )
        limit = args.test_dataset_max_size
        if limit is not None:
            print("Limiting test dataset size to", limit)
            dataset = torch.utils.data.Subset(dataset, list(range(limit)))
        loaders[path] = DataLoader(
            dataset,
            batch_size=args.batch_size,
            drop_last=True,
            pin_memory=True,
            collate_fn=concat_events,
            num_workers=args.num_workers,
            persistent_workers=args.num_workers > 0,
        )
    return loaders
def get_optimizer_and_scheduler(args, model, device, load_model_weights="load_model_weights"):
    """
    Build the optimizer and, optionally, the learning-rate scheduler.

    :param args: run configuration. Reads ``optimizer_option`` (pairs of
        ``(name, python-literal string)``), ``optimizer``, ``start_lr`` and
        ``lr_scheduler``. Depending on the scheduler, it also reads ``num_epochs``,
        ``steps_per_epoch``, ``warmup_steps`` and ``load_epoch``.
    :param model: the model to optimize; may be wrapped in ``DistributedDataParallel``.
    :param device: ``map_location`` passed to ``torch.load`` when resuming.
    :param load_model_weights: name of the ``args`` attribute holding the checkpoint
        path to resume from (``None`` means start from scratch). The checkpoint must
        contain ``"model"`` and ``"optimizer"`` entries, plus ``"scheduler"`` when a
        scheduler is configured.
    :return: ``(opt, scheduler)``; ``scheduler`` is ``None`` when ``args.lr_scheduler``
        is not one of the supported names.
    :raises ValueError: if ``args.optimizer`` is not a supported optimizer name.
    """
    optimizer_options = {k: ast.literal_eval(v) for k, v in args.optimizer_option}
    _logger.info("Optimizer options: %s" % str(optimizer_options))
    names_lr_mult = []
    if "weight_decay" in optimizer_options or "lr_mult" in optimizer_options:
        # Split parameters into decay / no-decay (and optionally lr-multiplied) groups,
        # following https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/optim_factory.py#L31
        import re

        decay, no_decay = {}, {}
        names_no_decay = []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            # 1-D tensors, biases and names listed by ``model.no_weight_decay()``
            # are excluded from weight decay.
            if (
                len(param.shape) == 1
                or name.endswith(".bias")
                or (
                    hasattr(model, "no_weight_decay")
                    and name in model.no_weight_decay()
                )
            ):
                no_decay[name] = param
                names_no_decay.append(name)
            else:
                decay[name] = param
        decay_1x, no_decay_1x = [], []
        decay_mult, no_decay_mult = [], []
        mult_factor = 1
        if "lr_mult" in optimizer_options:
            # ``lr_mult`` is a (regex, factor) pair: parameters whose name matches the
            # regex are trained with ``start_lr * factor``.
            pattern, mult_factor = optimizer_options.pop("lr_mult")
            for name, param in decay.items():
                if re.match(pattern, name):
                    decay_mult.append(param)
                    names_lr_mult.append(name)
                else:
                    decay_1x.append(param)
            for name, param in no_decay.items():
                if re.match(pattern, name):
                    no_decay_mult.append(param)
                    names_lr_mult.append(name)
                else:
                    no_decay_1x.append(param)
            assert len(decay_1x) + len(decay_mult) == len(decay)
            assert len(no_decay_1x) + len(no_decay_mult) == len(no_decay)
        else:
            decay_1x, no_decay_1x = list(decay.values()), list(no_decay.values())
        wd = optimizer_options.pop("weight_decay", 0.0)
        # Group order matters: the "flat+decay" LambdaLR below passes one lambda per
        # group, in exactly this order.
        parameters = [
            {"params": no_decay_1x, "weight_decay": 0.0},
            {"params": decay_1x, "weight_decay": wd},
            {
                "params": no_decay_mult,
                "weight_decay": 0.0,
                "lr": args.start_lr * mult_factor,
            },
            {
                "params": decay_mult,
                "weight_decay": wd,
                "lr": args.start_lr * mult_factor,
            },
        ]
        _logger.info(
            "Parameters excluded from weight decay:\n - %s",
            "\n - ".join(names_no_decay),
        )
        if len(names_lr_mult):
            _logger.info(
                "Parameters with lr multiplied by %s:\n - %s",
                mult_factor,
                "\n - ".join(names_lr_mult),
            )
    else:
        parameters = model.parameters()

    if args.optimizer == "ranger":
        from src.utils.nn.optimizer.ranger import Ranger

        opt = Ranger(parameters, lr=args.start_lr, **optimizer_options)
    elif args.optimizer == "adam":
        opt = torch.optim.Adam(parameters, lr=args.start_lr, **optimizer_options)
    elif args.optimizer == "adamW":
        opt = torch.optim.AdamW(parameters, lr=args.start_lr, **optimizer_options)
    elif args.optimizer == "radam":
        opt = torch.optim.RAdam(parameters, lr=args.start_lr, **optimizer_options)
    else:
        # Fail early with a clear message instead of leaving ``opt`` unbound.
        raise ValueError(
            "Unsupported optimizer %r (expected one of: ranger, adam, adamW, radam)"
            % (args.optimizer,)
        )

    checkpoint_path = args.__dict__[load_model_weights]
    if checkpoint_path is not None:
        _logger.info("Resume training from file %s" % checkpoint_path)
        model_state = torch.load(
            checkpoint_path,
            map_location=device,
        )
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model.module.load_state_dict(model_state["model"])
        else:
            model.load_state_dict(model_state["model"])
        opt.load_state_dict(model_state["optimizer"])

    scheduler = None
    if args.lr_scheduler == "steps":
        # Multiply the lr by 0.2 once, at epoch 10 (fixed milestone, independent of num_epochs).
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            opt,
            milestones=[10],
            gamma=0.20,
            last_epoch=-1
        )
    elif args.lr_scheduler == "flat+decay":
        # Flat lr, then over the last 30% of epochs a per-epoch decay whose total
        # factor is 0.01.
        num_decay_epochs = max(1, int(args.num_epochs * 0.3))
        milestones = list(
            range(args.num_epochs - num_decay_epochs, args.num_epochs)
        )
        gamma = 0.01 ** (1.0 / num_decay_epochs)
        if len(names_lr_mult):
            # Only the lr-multiplied groups (3rd and 4th) decay; the 1x groups stay flat.
            def get_lr(epoch):
                return gamma ** max(0, epoch - milestones[0] + 1)  # noqa

            scheduler = torch.optim.lr_scheduler.LambdaLR(
                opt,
                (lambda _: 1, lambda _: 1, get_lr, get_lr),
                last_epoch=-1,
                verbose=True,
            )
        else:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                opt,
                milestones=milestones,
                gamma=gamma,
                last_epoch=-1
            )
    elif args.lr_scheduler == "flat+linear" or args.lr_scheduler == "flat+cos":
        # Per-step schedule: linear warm-up over ``warmup_steps``, flat until 70% of all
        # steps, then linear or cosine decay, never below ``min_factor``.
        total_steps = args.num_epochs * args.steps_per_epoch
        warmup_steps = args.warmup_steps
        flat_steps = total_steps * 0.7 - 1
        min_factor = 0.001

        def lr_fn(step_num):
            if step_num > total_steps:
                raise ValueError(
                    "Tried to step {} times. The specified number of total steps is {}".format(
                        step_num + 1, total_steps
                    )
                )
            if step_num < warmup_steps:
                return 1.0 * step_num / warmup_steps
            if step_num <= flat_steps:
                return 1.0
            pct = (step_num - flat_steps) / (total_steps - flat_steps)
            if args.lr_scheduler == "flat+linear":
                return max(min_factor, 1 - pct)
            else:
                return max(min_factor, 0.5 * (math.cos(math.pi * pct) + 1))

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            opt,
            lr_fn,
            last_epoch=-1
            if args.load_epoch is None
            else args.load_epoch * args.steps_per_epoch,
        )
        scheduler._update_per_step = (
            True  # mark it to update the lr every step, instead of every epoch
        )
    elif args.lr_scheduler == "one-cycle":
        # OneCycleLR counts batches, not epochs: resuming after ``load_epoch`` epochs
        # means ``load_epoch * steps_per_epoch`` batches have already been done.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            opt,
            max_lr=args.start_lr,
            epochs=args.num_epochs,
            steps_per_epoch=args.steps_per_epoch,
            pct_start=0.3,
            anneal_strategy="cos",
            div_factor=25.0,
            last_epoch=-1
            if args.load_epoch is None
            else args.load_epoch * args.steps_per_epoch,
        )
        scheduler._update_per_step = (
            True  # mark it to update the lr every step, instead of every epoch
        )
    elif args.lr_scheduler == "reduceplateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt, patience=2, threshold=0.01
        )
        # Mark it to update the lr once per epoch (not every step).
        scheduler._update_per_step = False

    if checkpoint_path and scheduler is not None:
        scheduler.load_state_dict(model_state["scheduler"])
    return opt, scheduler
def get_target_obj_score(clusters_eta, clusters_phi, clusters_pt, event_idx_clusters, dq_eta, dq_phi, dq_event_idx, gt_mode="all_in_radius", radius=0.8):
    """Return the objectness-score target (1.0 or 0.0) for every cluster.

    Distances are plain Euclidean distances in the (eta, phi) plane
    (``torch.cdist`` with ``p=2``); phi periodicity is not taken into account.

    :param clusters_eta: per-cluster eta.
    :param clusters_phi: per-cluster phi.
    :param clusters_pt: per-cluster pt (not used; kept for interface compatibility).
    :param event_idx_clusters: event index of each cluster.
    :param dq_eta: eta of every dark quark in the batch.
    :param dq_phi: phi of every dark quark in the batch.
    :param dq_event_idx: event index of each dark quark.
    :param gt_mode: ``"all_in_radius"`` marks every cluster that has at least one
        dark quark within ``radius``. Any other value marks, for each dark quark,
        only its closest cluster, and only if that cluster is within ``radius``.
    :param radius: matching radius in (eta, phi) units.
    :return: float32 tensor with one entry per cluster, in the same order as the
        input clusters, on the device of ``dq_eta``. Events with no dark quarks
        get all zeros; empty input gives an empty tensor.
    """
    device = dq_eta.device
    target = torch.zeros(len(clusters_eta), dtype=torch.float, device=device)
    for event in event_idx_clusters.unique():
        filt = event_idx_clusters == event
        cluster_coords = torch.stack([clusters_eta[filt], clusters_phi[filt]], dim=1).to(device)
        dq_filt = dq_event_idx == event
        dq_coords = torch.stack([dq_eta[dq_filt], dq_phi[dq_filt]], dim=1)
        if len(cluster_coords) == 0 or len(dq_coords) == 0:
            # Nothing to match (min() over an empty axis would raise): targets stay 0.
            continue
        # (n_clusters, n_quarks) distance matrix for this event.
        dist_matrix = torch.cdist(dq_coords, cluster_coords, p=2).T
        if gt_mode == "all_in_radius":
            closest_dist, _ = dist_matrix.min(dim=1)
            # A cluster is positive unless its closest quark is farther than ``radius``.
            event_target = (~(closest_dist > radius)).float()
        else:
            # For each quark, its closest cluster; keep only those within ``radius``.
            closest_dist, closest_idx = dist_matrix.min(dim=0)
            event_target = torch.zeros(len(cluster_coords), dtype=torch.float, device=device)
            event_target[closest_idx[~(closest_dist > radius)]] = 1.0
        # Scatter back by mask so the result follows the input cluster order.
        target[filt.to(device)] = event_target
    return target
def plot_obj_score_debug(dq_eta, dq_phi, dq_batch_idx, clusters_eta, clusters_phi, clusters_pt, clusters_batch_idx, clusters_labels, input_pxyz, input_event_idx, input_clusters, pred_obj_score_clusters):
    """Debug figure for the objectness-score head: one (eta, phi) panel per event.

    The number of panels is ``dq_batch_idx.max() + 1``. Each panel shows:
      * clusters as dots sized by pt, green where ``clusters_labels`` is 1 and
        grey where it is 0. Each cluster is annotated with its predicted
        objectness score (small grey text, offset by -0.5 in eta and phi) and
        with its pt (black text);
      * dark quarks as red dots;
      * input particles, sized by ``sqrt(pxyz[:, 0]**2 + pxyz[:, 1]**2)`` and
        coloured by ``input_clusters``.

    :return: the matplotlib ``Figure``.
    """
    import matplotlib.pyplot as plt

    n_events = dq_batch_idx.max().int().item() + 1
    pfcands_pt = torch.sqrt(input_pxyz[:, 0] ** 2 + input_pxyz[:, 1] ** 2)
    pfcands_eta, pfcands_phi = calc_eta_phi(input_pxyz, return_stacked=0)
    # squeeze=False keeps ``axes`` 2-D even for a single event, so axes[0, i] is always valid.
    fig, axes = plt.subplots(1, n_events, figsize=(n_events * 3, 3), squeeze=False)
    colors = {0: "grey", 1: "green"}
    for i in range(n_events):
        ax = axes[0, i]
        cl_filt = clusters_batch_idx == i
        cl_eta = clusters_eta[cl_filt]
        cl_phi = clusters_phi[cl_filt]
        cl_pt = clusters_pt[cl_filt]
        cl_score = pred_obj_score_clusters[cl_filt]
        # Clusters: green for label 1, grey for label 0 (explicit colours, so no cmap).
        ax.scatter(
            cl_eta.cpu(),
            cl_phi.cpu(),
            c=[colors[x] for x in clusters_labels[cl_filt].tolist()],
            s=cl_pt.cpu(),
            alpha=0.5,
        )
        # Predicted objectness score, light grey, slightly offset from each cluster.
        for j in range(len(cl_eta)):
            ax.text(
                cl_eta[j].item() - 0.5,
                cl_phi[j].item() - 0.5,
                str(round(cl_score[j].item(), 2)),
                fontsize=6,
                color="gray",
                alpha=0.7,
            )
        # Dark quarks as red dots.
        dq_filt = dq_batch_idx == i
        ax.scatter(dq_eta[dq_filt].cpu(), dq_phi[dq_filt].cpu(), c="red", alpha=0.5)
        # Input particles, coloured by their cluster assignment.
        pf_filt = input_event_idx == i
        ax.scatter(
            pfcands_eta[pf_filt].cpu(),
            pfcands_phi[pf_filt].cpu(),
            c=input_clusters[pf_filt].cpu(),
            cmap="coolwarm",
            s=pfcands_pt[pf_filt].cpu(),
            alpha=0.5,
        )
        # Cluster pt in black text on top of each cluster.
        for j in range(len(cl_eta)):
            ax.text(cl_eta[j].item(), cl_phi[j].item(), str(round(cl_pt[j].item(), 2)), fontsize=8, color="black")
    fig.tight_layout()
    return fig
def get_loss_func(args):
    """Build the training loss closure.

    The returned ``loss(model_input, model_output, gt_labels)`` reads its settings
    from *args* on every call. It delegates to ``object_condensation_loss``, using
    ``model_input.batch_idx`` as the per-item batch assignment.
    """

    def loss(model_input, model_output, gt_labels):
        batch_numbers = model_input.batch_idx
        if args.loss == "quark_distance" or args.train_objectness_score:
            # These modes consume the ground truth unchanged.
            labels = gt_labels
        else:
            # Plain clustering labels are shifted by one (presumably so that the
            # noise label -1 becomes 0).
            labels = gt_labels + 1
        oc_settings = dict(
            attr_weight=args.attr_loss_weight,
            repul_weight=args.repul_loss_weight,
            coord_weight=args.coord_loss_weight,
            beta_type=args.beta_type,
            lorentz_norm=args.lorentz_norm,
            spatial_part_only=args.spatial_part_only,
            loss_quark_distance=args.loss == "quark_distance",
            oc_scalars=args.scalars_oc,
            loss_obj_score=args.train_objectness_score,
        )
        return object_condensation_loss(model_input, model_output, labels, batch_numbers, **oc_settings)

    return loss
def renumber_clusters(tensor):
    """Relabel the values of *tensor* as consecutive integers ``0..k-1``.

    The smallest distinct value becomes 0, the next one 1, and so on, for example
    ``[5, 2, 5, 9] -> [1, 0, 1, 2]``.

    Uses ``torch.unique(..., return_inverse=True)`` instead of a lookup table, so it
    works for negative values, empty input and tensors on any device.

    :return: ``torch.long`` tensor with the same shape and device as *tensor*.
    """
    _, inverse = torch.unique(tensor, sorted=True, return_inverse=True)
    return inverse
def get_gt_func(args):
    """Return a function that computes ground-truth labels for a batch of events.

    Each particle (``events.pfcands``, or parton-/gen-level particles when
    ``args.parton_level`` / ``args.gen_level`` is set) gets the index of the closest
    dark quark (``matrix_element_gen_particles``) in (eta, phi), provided that quark
    is within ``args.gt_radius``; otherwise it gets -1 (noise). Events without any
    dark quark are labelled entirely as noise.

    What ``gt(events)`` returns (``quark_distance`` is checked first):
      * ``args.loss == "quark_distance"``: a ``TensorCollection`` with
        ``labels`` (quark index offset by the number of quarks in earlier events,
        presumably a row of ``labels_coordinates``), ``labels_coordinates``
        (``[E, pxyz]`` of every dark quark) and ``labels_no_renumber`` (per-event
        quark index);
      * ``args.train_objectness_score``: a ``TensorCollection`` with ``labels``,
        ``dq_coords`` (``[eta, phi]`` of all dark quarks) and ``dq_coords_batch_idx``;
      * otherwise the ``labels`` tensor itself. In this case, matched indices are
        renumbered per event to ``0..k-1``.
    """
    R = args.gt_radius

    def get_idx_for_event(obj, i):
        # ``obj.batch_number`` holds per-event offsets: event i spans
        # [batch_number[i], batch_number[i + 1]).
        return obj.batch_number[i], obj.batch_number[i + 1]

    def get_labels(b, pfcands, special=False, get_coordinates=False, get_dq_coords=False):
        # b: batch of events; pfcands: the particle collection to label.
        # ``special`` is unused. If get_coordinates is true, the quark four-vectors are
        # returned as well, and labels index into them.
        labels = torch.zeros(len(pfcands)).long()
        if get_coordinates:
            labels_coordinates = torch.zeros(len(b.matrix_element_gen_particles.pt), 4).float()
        labels_no_renumber = torch.ones_like(labels) * -1
        offset = 0
        if get_dq_coords:
            dq_coords = [b.matrix_element_gen_particles.eta, b.matrix_element_gen_particles.phi]
            dq_coords_batch_idx = torch.zeros(b.matrix_element_gen_particles.pt.shape)
            for i in range(len(b.matrix_element_gen_particles.batch_number) - 1):
                dq_coords_batch_idx[b.matrix_element_gen_particles.batch_number[i]:b.matrix_element_gen_particles.batch_number[i + 1]] = i
        for i in range(len(b)):
            s_dq, e_dq = get_idx_for_event(b.matrix_element_gen_particles, i)
            dq_eta = b.matrix_element_gen_particles.eta[s_dq:e_dq]
            dq_phi = b.matrix_element_gen_particles.phi[s_dq:e_dq]
            s, e = get_idx_for_event(pfcands, i)
            if len(dq_eta) == 0:
                # No dark quarks in this event: min() over an empty quark axis would
                # raise, so every particle of the event is noise.
                labels[s:e] = -1
            else:
                pfcands_eta = pfcands.eta[s:e]
                pfcands_phi = pfcands.phi[s:e]
                # (n_particles, n_quarks) Euclidean distances in the (eta, phi) plane.
                dist_matrix = torch.cdist(
                    torch.stack([dq_eta, dq_phi], dim=1),
                    torch.stack([pfcands_eta, pfcands_phi], dim=1),
                    p=2,
                ).T
                closest_quark_dist, closest_quark_idx = dist_matrix.min(dim=1)
                closest_quark_idx[closest_quark_dist > R] = -1
                if len(closest_quark_idx):
                    if not get_coordinates:
                        # Renumber only matched entries to 0..k-1; noise stays -1
                        # whether or not the event contains any noise particle.
                        matched = closest_quark_idx != -1
                        if matched.any():
                            closest_quark_idx[matched] = renumber_clusters(
                                closest_quark_idx[matched]
                            ).to(closest_quark_idx.dtype)
                    else:
                        labels_no_renumber[s:e] = closest_quark_idx
                    # ``offset`` only grows in coordinates mode (see below), turning the
                    # per-event quark index into a batch-wide one.
                    closest_quark_idx[closest_quark_idx != -1] += offset
                    labels[s:e] = closest_quark_idx
            if get_coordinates:
                E_dq = b.matrix_element_gen_particles.E[s_dq:e_dq]
                pxyz_dq = b.matrix_element_gen_particles.pxyz[s_dq:e_dq]
                labels_coordinates[s_dq:e_dq] = torch.cat([E_dq.unsqueeze(-1), pxyz_dq], dim=1)
                offset += len(E_dq)
        if get_coordinates:
            return TensorCollection(labels=labels, labels_coordinates=labels_coordinates, labels_no_renumber=labels_no_renumber)
        if get_dq_coords:
            return TensorCollection(labels=labels, dq_coords=dq_coords, dq_coords_batch_idx=dq_coords_batch_idx)
        return labels

    def gt(events):
        # Choose which particle collection to label; gen_level wins over parton_level.
        pfcands = events.pfcands
        if args.parton_level:
            pfcands = events.final_parton_level_particles
        if args.gen_level:
            pfcands = events.final_gen_particles
        return get_labels(events, pfcands, get_coordinates=args.loss == "quark_distance",
                          get_dq_coords=args.train_objectness_score)

    return gt
def count_parameters(model):
    """Return the number of trainable scalar weights in *model*.

    Only parameters with ``requires_grad=True`` are counted; each contributes
    its ``numel()``.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def get_model(args, dev):
    """Instantiate the clustering network and optionally load pretrained weights.

    The network module is imported from ``args.network_config`` and its
    ``get_model(obj_score=False, args=args)`` is called. When
    ``args.load_model_weights`` is set, the checkpoint's ``"model"`` entry is
    loaded. Missing and unexpected keys are logged, and both lists must be empty
    (enforced with ``assert``).

    :param args: run configuration.
    :param dev: ``map_location`` for ``torch.load``.
    :return: the model.
    """
    network_options = {}  # TODO: implement network options
    network_module = import_module(args.network_config, name="_network_module")
    model = network_module.get_model(obj_score=False, args=args, **network_options)
    weights_path = args.load_model_weights
    if not weights_path:
        return model
    print("Loading model state dict from %s" % weights_path)
    state_dict = torch.load(weights_path, map_location=dev)["model"]
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    _logger.info(
        "Model initialized with weights from %s\n ... Missing: %s\n ... Unexpected: %s"
        % (weights_path, missing, unexpected)
    )
    assert not missing
    assert not unexpected
    return model
def get_model_obj_score(args, dev):
    """Instantiate the objectness-score network and optionally load its weights.

    The module is imported from ``args.obj_score_module`` and its
    ``get_model(obj_score=True, args=args)`` is called. When
    ``args.load_objectness_score_weights`` is set, ``args.train_objectness_score``
    must be truthy (asserted). The checkpoint's ``"model"`` entry is then loaded;
    missing and unexpected keys are logged and must both be empty.

    :param args: run configuration.
    :param dev: ``map_location`` for ``torch.load``.
    :return: the objectness-score model.
    """
    network_options = {}  # TODO: implement network options
    network_module = import_module(args.obj_score_module, name="_network_module")
    model = network_module.get_model(obj_score=True, args=args, **network_options)
    weights_path = args.load_objectness_score_weights
    if not weights_path:
        return model
    assert args.train_objectness_score
    print("Loading objectness score model state dict from %s" % weights_path)
    state_dict = torch.load(weights_path, map_location=dev)["model"]
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    _logger.info(
        "Objectness score model initialized with weights from %s\n ... Missing: %s\n ... Unexpected: %s"
        % (weights_path, missing, unexpected)
    )
    assert not missing
    assert not unexpected
    return model