Spaces:
Sleeping
Sleeping
import os
from types import SimpleNamespace

from src.dataset.dataset import SimpleIterDataset, EventDataset
from src.utils.paths import get_path
from src.utils.utils import to_filelist
# To be used for simple analysis scripts, not for the full training!
def get_iter(path, full_dataloader=False, model_clusters_file=None, model_output_file=None,
             include_model_jets_unfiltered=False):
    """Return an iterator over the dataset found under ``path``.

    Meant for simple analysis scripts, not for the full training.

    Args:
        path: Directory to read from. With ``full_dataloader=True`` every
            entry of this directory is joined with ``path`` and passed on as
            a dataset path; otherwise the directory itself is handed to
            ``EventDataset.from_directory``.
        full_dataloader: If true, build a ``SimpleIterDataset`` from a fixed
            set of loader options; if false, build an ``EventDataset``.
        model_clusters_file: Forwarded unchanged to
            ``EventDataset.from_directory``. Unused when ``full_dataloader``
            is true.
        model_output_file: Forwarded unchanged to
            ``EventDataset.from_directory``. Unused when ``full_dataloader``
            is true.
        include_model_jets_unfiltered: Forwarded unchanged to
            ``EventDataset.from_directory``. Unused when ``full_dataloader``
            is true.

    Returns:
        The result of calling ``iter()`` on the constructed dataset.
    """
    if not full_dataloader:
        event_dataset = EventDataset.from_directory(
            path,
            model_clusters_file=model_clusters_file,
            model_output_file=model_output_file,
            include_model_jets_unfiltered=include_model_jets_unfiltered,
        )
        return iter(event_dataset)

    # os.listdir() yields entries in arbitrary order; sort them so the list of
    # dataset paths is the same on every run and every filesystem.
    datasets = [os.path.join(path, entry) for entry in sorted(os.listdir(path))]

    # Stand-in for the command-line arguments object that to_filelist() expects.
    # The same paths are used for both the train and the validation lists.
    args = SimpleNamespace(
        data_train=datasets,
        data_val=datasets,
        data_config=get_path('config_files/config_jets.yaml', "code"),
        extra_selection=None,
        train_val_split=1,
        data_fraction=1,
        file_fraction=1,
        fetch_by_files=False,
        fetch_step=0.1,
        steps_per_epoch=None,
        in_memory=False,
        local_rank=None,
        copy_inputs=False,
        no_remake_weights=False,
        batch_size=10,
        num_workers=0,
        demo=False,
        laplace=False,
        diffs=False,
        class_edges=False,
    )

    train_range = (0, args.train_val_split)
    # to_filelist() returns a pair; only the file dictionary is needed here.
    train_file_dict, _ = to_filelist(args, 'train')
    train_data = SimpleIterDataset(
        train_file_dict,
        args.data_config,
        for_training=True,
        extra_selection=args.extra_selection,
        remake_weights=True,
        load_range_and_fraction=(train_range, args.data_fraction),
        file_fraction=args.file_fraction,
        fetch_by_files=args.fetch_by_files,
        fetch_step=args.fetch_step,
        infinity_mode=False,
        in_memory=args.in_memory,
        async_load=False,
        name='train',
        jets=True,
    )
    return iter(train_data)