import torch
import os.path as osp
import os
import sys
from src.dataset.dataset import SimpleIterDataset
from src.utils.utils import to_filelist
from pathlib import Path
import pickle
from src.utils.paths import get_path
import argparse
import numpy as np
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, help="Input dataset location, resolved via get_path(..., 'data')")
parser.add_argument("--output", type=str, help="Output location, resolved via get_path(..., 'preprocessed_data')")
parser.add_argument("--overwrite", action="store_true", help="Reprocess datasets whose output directory already exists")
parser.add_argument("--dataset-cap", type=int, default=-1, help="Maximum number of events to read per input file (-1 = no cap)")
parser.add_argument("--v2", action="store_true", help="V2 datasets also store parton-level and genParticles")
parser.add_argument("--delphes", action="store_true", help="Use the Delphes variant of the V2 data config")
args = parser.parse_args()
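# Example invocation (illustrative only; the input/output names are placeholders that
# get_path resolves to actual locations, and the script filename is assumed):
#   python preprocess_dataset.py --input jets_raw --output jets_preprocessed --v2 --dataset-cap 10000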
path = get_path(args.input, "data")
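# Helper that drops non-dataset entries from a directory listing (not called in the main loop below).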
def remove_from_list(lst):
out = []
for item in lst:
if item in ["hgcal", "data.txt", "test_file.root"]:
continue
out.append(item)
return out
def preprocess_dataset(datasets, output_path, config_file, dataset_cap):
#datasets = os.listdir(path)
#datasets = [os.path.join(path, x) for x in datasets]
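    # Minimal stand-in for the training-script argument namespace expected by
    # to_filelist and SimpleIterDataset.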
class Args:
def __init__(self):
self.data_train = datasets
self.data_val = datasets
#self.data_train = files_train
self.data_config = config_file
self.extra_selection = None
self.train_val_split = 1.0
self.data_fraction = 1
self.file_fraction = 1
self.fetch_by_files = False
self.fetch_step = 1
self.steps_per_epoch = None
self.in_memory = False
self.local_rank = None
self.copy_inputs = False
self.no_remake_weights = False
self.batch_size = 10
self.num_workers = 0
self.demo = False
self.laplace = False
self.diffs = False
self.class_edges = False
args = Args()
train_range = (0, args.train_val_split)
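    # Build the train file list and the iterable dataset; with train_val_split = 1.0
    # every file falls into the 'train' range.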
train_file_dict, train_files = to_filelist(args, 'train')
train_data = SimpleIterDataset(train_file_dict, args.data_config, for_training=True,
extra_selection=args.extra_selection,
remake_weights=True,
load_range_and_fraction=(train_range, args.data_fraction),
file_fraction=args.file_fraction,
fetch_by_files=args.fetch_by_files,
fetch_step=args.fetch_step,
infinity_mode=False,
in_memory=args.in_memory,
async_load=False,
name='train', jets=True)
iterator = iter(train_data)
from time import time
t0 = time()
data = []
count = 0
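    # Drain the iterator event by event; stop early once dataset_cap events have been read.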
while True:
try:
i = next(iterator)
data.append(i)
count += 1
if dataset_cap > 0 and count >= dataset_cap:
break
except StopIteration:
break
t1 = time()
print("Took", t1-t0, "s -", datasets[0])
from src.dataset.functions_data import concat_events
events = concat_events(data) # TODO: This can be done in a nicer way, using less memory (?)
result = events.serialize()
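    # Name the output subdirectory after the parent folder of the input file.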
dir_name = datasets[0].split("/")[-2]
save_to_dir = os.path.join(output_path, dir_name)
Path(save_to_dir).mkdir(parents=True, exist_ok=True)
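    # Each array is written with np.save (NumPy .npy format) even though the filename
    # keeps the .pkl suffix; only the metadata dictionary is pickled.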
for key in result[0]:
with open(osp.join(save_to_dir, key + ".pkl"), "wb") as f:
#pickle.dump(result[0][key], f) #save with torch for mmap
#torch.save(result[0][key], f)
np.save(f, result[0][key].numpy())
with open(osp.join(save_to_dir, "metadata.pkl"), "wb") as f:
pickle.dump(result[1], f)
print("Saved to", save_to_dir)
print("Finished", dir_name)
'''
from src.dataset.functions_data import EventCollection, EventJets, Event
from src.dataset.dataset import EventDataset
t2 = time()
data1 = []
for event in EventDataset(result[0], result[1]):
data1.append(event)
t3 = time()
print("Took", t3-t2, "s")
print("Done")
'''
output = get_path(args.output, "preprocessed_data")
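# Preprocess every file of every dataset directory under `path`. File i of a directory is
# written under its own output root, output + "_part<i>" (note that the existence check
# below looks directly under `output`, not under the "_part" roots).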
for dir in os.listdir(path):
if args.overwrite or not os.path.exists(os.path.join(output, dir)):
config = get_path('config_files/config_jets.yaml', 'code')
if args.v2:
delphes_suffix = ""
if args.delphes:
delphes_suffix = "_delphes"
config = get_path(f'config_files/config_jets_2{delphes_suffix}.yaml', 'code')
for i, file in enumerate(sorted(os.listdir(os.path.join(path, dir)))):
print("Preprocessing file", file)
preprocess_dataset([os.path.join(path, dir, file)], output + "_part"+str(i), config_file=config, dataset_cap=args.dataset_cap)
else:
print("Skipping", dir + ", already exists")
# flush
sys.stdout.flush()
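# Illustrative sketch of reading one preprocessed directory back (the path below is a
# hypothetical placeholder; the arrays were written with np.save and can also be
# memory-mapped with np.load(..., mmap_mode="r")):
#   import numpy as np, pickle, os
#   save_dir = "preprocessed_data_part0/some_dataset"
#   arrays = {name[:-4]: np.load(os.path.join(save_dir, name))
#             for name in os.listdir(save_dir) if name != "metadata.pkl"}
#   with open(os.path.join(save_dir, "metadata.pkl"), "rb") as f:
#       metadata = pickle.load(f)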