import argparse
import os
import os.path as osp
import pickle
import sys
from pathlib import Path
from time import time

import numpy as np
import torch

from src.dataset.dataset import SimpleIterDataset
from src.utils.paths import get_path
from src.utils.utils import to_filelist


parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, help="Input dataset directory, resolved via get_path(..., 'data')")
parser.add_argument("--output", type=str, help="Output directory, resolved via get_path(..., 'preprocessed_data')")
parser.add_argument("--overwrite", action="store_true", help="Re-process datasets whose output directory already exists")
parser.add_argument("--dataset-cap", type=int, default=-1, help="Maximum number of events to read per input file (-1 = no cap)")
parser.add_argument("--v2", action="store_true")  # v2 datasets also store parton-level and genParticle information
parser.add_argument("--delphes", action="store_true", help="Use the Delphes variant of the v2 config (only has effect with --v2)")

args = parser.parse_args()
path = get_path(args.input, "data")
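
# Example invocation (script name and dataset/output names below are illustrative only):
#   python preprocess.py --input my_jets_raw --output my_jets --v2 --dataset-cap 1000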

def remove_from_list(lst):
    """Filter out entries that are not dataset files (not used elsewhere in this script)."""
    return [item for item in lst if item not in ("hgcal", "data.txt", "test_file.root")]

def preprocess_dataset(datasets, output_path, config_file, dataset_cap):
    """Read the given input files through SimpleIterDataset and dump the concatenated events to output_path."""
    #datasets = os.listdir(path)
    #datasets = [os.path.join(path, x) for x in datasets]
    class Args:
        # Minimal stand-in for the command-line arguments expected by to_filelist and SimpleIterDataset.
        def __init__(self):
            self.data_train = datasets
            self.data_val = datasets
            #self.data_train = files_train
            self.data_config = config_file
            self.extra_selection = None
            self.train_val_split = 1.0
            self.data_fraction = 1
            self.file_fraction = 1
            self.fetch_by_files = False
            self.fetch_step = 1
            self.steps_per_epoch = None
            self.in_memory = False
            self.local_rank = None
            self.copy_inputs = False
            self.no_remake_weights = False
            self.batch_size = 10
            self.num_workers = 0
            self.demo = False
            self.laplace = False
            self.diffs = False
            self.class_edges = False
    args = Args()
    train_range = (0, args.train_val_split)
    train_file_dict, train_files = to_filelist(args, 'train')
    train_data = SimpleIterDataset(train_file_dict, args.data_config, for_training=True,
                                   extra_selection=args.extra_selection,
                                   remake_weights=True,
                                   load_range_and_fraction=(train_range, args.data_fraction),
                                   file_fraction=args.file_fraction,
                                   fetch_by_files=args.fetch_by_files,
                                   fetch_step=args.fetch_step,
                                   infinity_mode=False,
                                   in_memory=args.in_memory,
                                   async_load=False,
                                   name='train', jets=True)
    # Materialize the iterable dataset, optionally capping the number of events read.
    t0 = time()
    data = []
    for count, event in enumerate(train_data, start=1):
        data.append(event)
        if dataset_cap > 0 and count >= dataset_cap:
            break
    t1 = time()
    print("Took", t1 - t0, "s -", datasets[0])
    from src.dataset.functions_data import concat_events
    events = concat_events(data) # TODO: This can be done in a nicer way, using less memory (?)
    result = events.serialize()
    # Name the output subdirectory after the parent directory of the input file.
    dir_name = datasets[0].split("/")[-2]
    save_to_dir = os.path.join(output_path, dir_name)
    Path(save_to_dir).mkdir(parents=True, exist_ok=True)
    # The tensors are written in NumPy's .npy format, even though the files keep a .pkl
    # extension; the metadata dictionary is pickled alongside them.
    for key in result[0]:
        with open(osp.join(save_to_dir, key + ".pkl"), "wb") as f:
            #pickle.dump(result[0][key], f) #save with torch for mmap
            #torch.save(result[0][key], f)
            np.save(f, result[0][key].numpy())
    with open(osp.join(save_to_dir, "metadata.pkl"), "wb") as f:
        pickle.dump(result[1], f)
    print("Saved to", save_to_dir)
    print("Finished", dir_name)
    '''
    from src.dataset.functions_data import EventCollection, EventJets, Event
    from src.dataset.dataset import EventDataset
    t2 = time()
    data1 = []
    for event in EventDataset(result[0], result[1]):
        data1.append(event)
    t3 = time()
    print("Took", t3-t2, "s")
    print("Done")
    '''
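
# Sketch of how the dumped arrays could be read back (assumed downstream usage, not
# exercised by this script); np.load recognizes the .npy payload regardless of the
# .pkl file extension and can memory-map it:
#
#   out_dir = "<one of the *_part<i>/<dir_name> output directories>"
#   with open(osp.join(out_dir, "metadata.pkl"), "rb") as f:
#       metadata = pickle.load(f)
#   arr = np.load(osp.join(out_dir, "some_key.pkl"), mmap_mode="r")  # "some_key" is a placeholder
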
output = get_path(args.output, "preprocessed_data")
for dataset_dir in os.listdir(path):
    if args.overwrite or not os.path.exists(os.path.join(output, dataset_dir)):
        config = get_path('config_files/config_jets.yaml', 'code')
        if args.v2:
            delphes_suffix = "_delphes" if args.delphes else ""
            config = get_path(f'config_files/config_jets_2{delphes_suffix}.yaml', 'code')
        # Each input file is preprocessed into its own "_part<i>" output directory.
        for i, file_name in enumerate(sorted(os.listdir(os.path.join(path, dataset_dir)))):
            print("Preprocessing file", file_name)
            preprocess_dataset([os.path.join(path, dataset_dir, file_name)],
                               output + "_part" + str(i),
                               config_file=config, dataset_cap=args.dataset_cap)
    else:
        print("Skipping", dataset_dir + ", already exists")
        sys.stdout.flush()