import os

from src.dataset.dataset import SimpleIterDataset, EventDataset
from src.utils.utils import to_filelist
from src.utils.paths import get_path


# To be used for simple analysis scripts, not for the full training!
def get_iter(path, full_dataloader=False, model_clusters_file=None,
             model_output_file=None, include_model_jets_unfiltered=False):
    """Return an iterator over events found under `path`.

    If `full_dataloader` is True, every file in `path` is streamed through
    SimpleIterDataset (the training-style pipeline); otherwise events are
    read directly via EventDataset.from_directory.
    """
    if full_dataloader:
        # Collect every dataset file in the directory.
        datasets = os.listdir(path)
        datasets = [os.path.join(path, x) for x in datasets]

        # Minimal stand-in for the argparse namespace expected by
        # to_filelist and SimpleIterDataset.
        class Args:
            def __init__(self):
                self.data_train = datasets
                self.data_val = datasets
                self.data_config = get_path('config_files/config_jets.yaml', "code")
                self.extra_selection = None
                self.train_val_split = 1
                self.data_fraction = 1
                self.file_fraction = 1
                self.fetch_by_files = False
                self.fetch_step = 0.1
                self.steps_per_epoch = None
                self.in_memory = False
                self.local_rank = None
                self.copy_inputs = False
                self.no_remake_weights = False
                self.batch_size = 10
                self.num_workers = 0
                self.demo = False
                self.laplace = False
                self.diffs = False
                self.class_edges = False

        args = Args()
        # Use the full file range; train_val_split=1 means no validation split.
        train_range = (0, args.train_val_split)
        train_file_dict, _ = to_filelist(args, 'train')
        train_data = SimpleIterDataset(
            train_file_dict,
            args.data_config,
            for_training=True,
            extra_selection=args.extra_selection,
            remake_weights=True,
            load_range_and_fraction=(train_range, args.data_fraction),
            file_fraction=args.file_fraction,
            fetch_by_files=args.fetch_by_files,
            fetch_step=args.fetch_step,
            infinity_mode=False,
            in_memory=args.in_memory,
            async_load=False,
            name='train',
            jets=True,
        )
        iterator = iter(train_data)
    else:
        iterator = iter(EventDataset.from_directory(
            path,
            model_clusters_file=model_clusters_file,
            model_output_file=model_output_file,
            include_model_jets_unfiltered=include_model_jets_unfiltered,
        ))
    return iterator
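

# Example usage (a minimal sketch for an analysis script; the directory
# paths below are hypothetical and must point at real preprocessed files):
if __name__ == "__main__":
    # Stream single events directly from a directory of model outputs.
    events = get_iter("/path/to/eval_output")  # hypothetical path
    event = next(events)

    # Or iterate through the full training-style dataloader, which batches
    # events according to the Args defaults above (batch_size=10).
    batches = get_iter("/path/to/train_files", full_dataloader=True)  # hypothetical path
    batch = next(batches)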