import os
from src.dataset.dataset import SimpleIterDataset, EventDataset
from src.utils.utils import to_filelist
from src.utils.paths import get_path
# To be used for simple analysis scripts, not for the full training!
def get_iter(path, full_dataloader=False, model_clusters_file=None, model_output_file=None,
             include_model_jets_unfiltered=False):
    """Return an iterator over events found in `path`.

    If `full_dataloader` is True, the directory is read through SimpleIterDataset
    using a minimal argument namespace; otherwise EventDataset.from_directory is used.
    """
    if full_dataloader:
        # Treat every entry in `path` as a dataset file.
        datasets = os.listdir(path)
        datasets = [os.path.join(path, x) for x in datasets]

        # Minimal stand-in for the training-script argument namespace
        # expected by to_filelist and SimpleIterDataset.
        class Args:
            def __init__(self):
                self.data_train = datasets
                self.data_val = datasets
                # self.data_train = files_train
                self.data_config = get_path('config_files/config_jets.yaml', "code")
                self.extra_selection = None
                self.train_val_split = 1
                self.data_fraction = 1
                self.file_fraction = 1
                self.fetch_by_files = False
                self.fetch_step = 0.1
                self.steps_per_epoch = None
                self.in_memory = False
                self.local_rank = None
                self.copy_inputs = False
                self.no_remake_weights = False
                self.batch_size = 10
                self.num_workers = 0
                self.demo = False
                self.laplace = False
                self.diffs = False
                self.class_edges = False

        args = Args()
        train_range = (0, args.train_val_split)
        train_file_dict, train_files = to_filelist(args, 'train')
        train_data = SimpleIterDataset(train_file_dict, args.data_config, for_training=True,
                                       extra_selection=args.extra_selection,
                                       remake_weights=True,
                                       load_range_and_fraction=(train_range, args.data_fraction),
                                       file_fraction=args.file_fraction,
                                       fetch_by_files=args.fetch_by_files,
                                       fetch_step=args.fetch_step,
                                       infinity_mode=False,
                                       in_memory=args.in_memory,
                                       async_load=False,
                                       name='train', jets=True)
        iterator = iter(train_data)
    else:
        iterator = iter(EventDataset.from_directory(path, model_clusters_file=model_clusters_file,
                                                    model_output_file=model_output_file,
                                                    include_model_jets_unfiltered=include_model_jets_unfiltered))
    return iterator
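

# Example usage (sketch): the dataset path below is a placeholder, not a real
# location in this repo. With full_dataloader=False the events are loaded via
# EventDataset.from_directory; pass full_dataloader=True to go through the
# SimpleIterDataset path instead.
if __name__ == "__main__":
    events = get_iter("/path/to/dataset", full_dataloader=False)
    first_event = next(events)  # pull a single event for quick inspection
    print(type(first_event))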