import os
from src.dataset.dataset import SimpleIterDataset, EventDataset
from src.utils.utils import to_filelist
from src.utils.paths import get_path


# To be used for simple analysis scripts, not for the full training!
def get_iter(path, full_dataloader=False, model_clusters_file=None, model_output_file=None,
             include_model_jets_unfiltered=False):
    """Return an iterator over the data found at `path`.

    With `full_dataloader=True`, every file under `path` is loaded through
    SimpleIterDataset using the jets data config; otherwise events are read
    with the lighter EventDataset.from_directory loader.
    """
    if full_dataloader:
        # Build the file list from everything directly under `path`.
        datasets = os.listdir(path)
        datasets = [os.path.join(path, x) for x in datasets]

        # Minimal stand-in for the argument namespace that to_filelist and
        # SimpleIterDataset expect from the full training script.
        class Args:
            def __init__(self):
                self.data_train = datasets
                self.data_val = datasets
                self.data_config = get_path('config_files/config_jets.yaml', "code")
                self.extra_selection = None
                self.train_val_split = 1
                self.data_fraction = 1
                self.file_fraction = 1
                self.fetch_by_files = False
                self.fetch_step = 0.1
                self.steps_per_epoch = None
                self.in_memory = False
                self.local_rank = None
                self.copy_inputs = False
                self.no_remake_weights = False
                self.batch_size = 10
                self.num_workers = 0
                self.demo = False
                self.laplace = False
                self.diffs = False
                self.class_edges = False

        args = Args()
        train_range = (0, args.train_val_split)  # full (0, 1) range, since train_val_split == 1
        # Resolve the dataset files and wrap them in the iterable training dataset.
        train_file_dict, train_files = to_filelist(args, 'train')
        train_data = SimpleIterDataset(train_file_dict, args.data_config, for_training=True,
                                       extra_selection=args.extra_selection,
                                       remake_weights=True,
                                       load_range_and_fraction=(train_range, args.data_fraction),
                                       file_fraction=args.file_fraction,
                                       fetch_by_files=args.fetch_by_files,
                                       fetch_step=args.fetch_step,
                                       infinity_mode=False,
                                       in_memory=args.in_memory,
                                       async_load=False,
                                       name='train', jets=True)

        iterator = iter(train_data)
    else:
        # Lightweight path: read stored events (and, optionally, model cluster
        # and model output files) directly from the directory.
        iterator = iter(EventDataset.from_directory(path, model_clusters_file=model_clusters_file,
                                                    model_output_file=model_output_file,
                                                    include_model_jets_unfiltered=include_model_jets_unfiltered))
    return iterator
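

# Minimal usage sketch (not part of the original module): how get_iter might be
# driven from a quick analysis script. The paths below are placeholders, and the
# exact structure of the yielded items depends on the dataset classes above.
if __name__ == "__main__":
    # Default mode: event-level iteration over a directory of stored events.
    events = get_iter("/path/to/event_dir")  # placeholder path
    for i, event in enumerate(events):
        print(f"event {i}: {type(event).__name__}")
        if i >= 4:  # only peek at the first few events
            break

    # Full-dataloader mode: stream items through SimpleIterDataset instead.
    jet_iter = get_iter("/path/to/jet_dataset_dir", full_dataloader=True)  # placeholder path
    first_item = next(jet_iter)
    print(type(first_item))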