# jetclustering/src/preprocessing/preprocess_dataset.py
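# Reads events via SimpleIterDataset, concatenates them and serializes the
# result to per-key .npy arrays plus a metadata pickle, one output directory
# per input file.
#
# Example (hypothetical invocation; the actual entry point may differ):
#   python src/preprocessing/preprocess_dataset.py --input <dataset> --output <name> --v2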
import os
import os.path as osp
import sys
import pickle
import argparse
from pathlib import Path
from time import time

import numpy as np

from src.dataset.dataset import SimpleIterDataset
from src.dataset.functions_data import concat_events
from src.utils.utils import to_filelist
from src.utils.paths import get_path
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, help="Input dataset directory (resolved via get_path)")
parser.add_argument("--output", type=str, help="Output name (resolved via get_path)")
parser.add_argument("--overwrite", action="store_true", help="Reprocess datasets even if output already exists")
parser.add_argument("--dataset-cap", type=int, default=-1, help="Stop after this many events per file (-1 = no cap)")
# V2 datasets also store parton-level and genParticle information
parser.add_argument("--v2", action="store_true")
parser.add_argument("--delphes", action="store_true", help="Use the Delphes variant of the v2 config")
args = parser.parse_args()
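# Resolve the logical input name to the on-disk data directory.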
path = get_path(args.input, "data")
def remove_from_list(lst):
    """Filter out known non-dataset entries (currently unused here)."""
    return [item for item in lst if item not in ("hgcal", "data.txt", "test_file.root")]
def preprocess_dataset(datasets, output_path, config_file, dataset_cap):
    # Minimal stand-in for the training script's argparse namespace,
    # as expected by to_filelist() and SimpleIterDataset.
    class Args:
        def __init__(self):
            self.data_train = datasets
            self.data_val = datasets
            self.data_config = config_file
self.extra_selection = None
self.train_val_split = 1.0
self.data_fraction = 1
self.file_fraction = 1
self.fetch_by_files = False
self.fetch_step = 1
self.steps_per_epoch = None
self.in_memory = False
self.local_rank = None
self.copy_inputs = False
self.no_remake_weights = False
self.batch_size = 10
self.num_workers = 0
self.demo = False
self.laplace = False
self.diffs = False
self.class_edges = False
args = Args()
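    # train_val_split is 1.0, so the whole file list falls into the 'train' range below.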
train_range = (0, args.train_val_split)
train_file_dict, train_files = to_filelist(args, 'train')
train_data = SimpleIterDataset(train_file_dict, args.data_config, for_training=True,
extra_selection=args.extra_selection,
remake_weights=True,
load_range_and_fraction=(train_range, args.data_fraction),
file_fraction=args.file_fraction,
fetch_by_files=args.fetch_by_files,
fetch_step=args.fetch_step,
infinity_mode=False,
in_memory=args.in_memory,
async_load=False,
name='train', jets=True)
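    # Stream events from the dataset, stopping early once dataset_cap is reached.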
    t0 = time()
    data = []
    for count, event in enumerate(train_data, start=1):
        data.append(event)
        if dataset_cap > 0 and count >= dataset_cap:
            break
    t1 = time()
    print(f"Took {t1 - t0:.1f} s - {datasets[0]}")
    events = concat_events(data)  # TODO: this could be done in a nicer way, using less memory
result = events.serialize()
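    # result[0] maps each key to a tensor of event data; result[1] is metadata.
    # The output directory is named after the parent directory of the input file.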
dir_name = datasets[0].split("/")[-2]
save_to_dir = os.path.join(output_path, dir_name)
Path(save_to_dir).mkdir(parents=True, exist_ok=True)
for key in result[0]:
with open(osp.join(save_to_dir, key + ".pkl"), "wb") as f:
            # Stored with np.save (despite the .pkl extension) so the arrays can
            # later be memory-mapped via np.load(path, mmap_mode="r").
            np.save(f, result[0][key].numpy())
with open(osp.join(save_to_dir, "metadata.pkl"), "wb") as f:
pickle.dump(result[1], f)
print("Saved to", save_to_dir)
print("Finished", dir_name)
'''
from src.dataset.functions_data import EventCollection, EventJets, Event
from src.dataset.dataset import EventDataset
t2 = time()
data1 = []
for event in EventDataset(result[0], result[1]):
data1.append(event)
t3 = time()
print("Took", t3-t2, "s")
print("Done")
'''
output = get_path(args.output, "preprocessed_data")
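# Each subdirectory of the input path is one dataset; every file in it is
# preprocessed into its own "<output>_part<i>" directory.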
for dir_name in os.listdir(path):
    # Outputs are written to f"{output}_part{i}/{dir_name}", so check the first
    # part directory to decide whether this dataset was already processed.
    if args.overwrite or not os.path.exists(os.path.join(output + "_part0", dir_name)):
        config = get_path('config_files/config_jets.yaml', 'code')
        if args.v2:
            delphes_suffix = "_delphes" if args.delphes else ""
            config = get_path(f'config_files/config_jets_2{delphes_suffix}.yaml', 'code')
        for i, file in enumerate(sorted(os.listdir(os.path.join(path, dir_name)))):
            print("Preprocessing file", file)
            preprocess_dataset([os.path.join(path, dir_name, file)], output + "_part" + str(i),
                               config_file=config, dataset_cap=args.dataset_cap)
    else:
        print("Skipping", dir_name + ", already exists")
sys.stdout.flush()