Spaces:
Sleeping
Sleeping
| import torch | |
| from typing import List, Union | |
| from torch.utils import data | |
| # FIXME Test and fix these split functions later | |
def random_split(dataset: data.Dataset, lengths, seed: int):
    """Reproducibly split a dataset into non-overlapping random subsets.

    Thin wrapper around :func:`torch.utils.data.random_split` that derives
    the shuffle from a dedicated, seeded generator so repeated calls with
    the same ``seed`` yield identical splits (the global torch RNG is
    untouched).

    Args:
        dataset (Dataset): PyTorch dataset to split.
        lengths: Sequence of split lengths (or fractions summing to 1 on
            torch >= 1.13); forwarded unchanged to ``data.random_split``.
        seed (int): Seed for the generator that shuffles the indices.

    Returns:
        list: One ``torch.utils.data.Subset`` per entry in ``lengths``.
    """
    generator = torch.Generator().manual_seed(seed)
    return data.random_split(dataset, lengths, generator=generator)
def cold_start(dataset: data.Dataset, frac: List[float], entities: Union[str, List[str]]):
    """Create cold-start splits for PyTorch datasets.

    For each "cold" entity, a random fraction of that entity's unique
    instance values is reserved. A sample lands in the test (resp.
    validation) split only when *all* of its cold-entity values belong to
    the reserved instances, so test/validation entities never appear in
    the training split.

    NOTE: sampling uses ``torch.randperm`` on the global RNG; call
    ``torch.manual_seed(...)`` beforehand for reproducible splits.

    Args:
        dataset (Dataset): PyTorch dataset object whose samples expose the
            entity values as attributes (e.g. ``sample.user``).
        frac (list): A list of train/valid/test fractions.
        entities (Union[str, List[str]]): Either a single "cold" entity or
            a list of "cold" entities on which the split is done.

    Returns:
        dict: A dictionary of split datasets, with keys 'train', 'valid'
        and 'test' mapping to ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: If the test or validation split comes out empty.
    """
    if isinstance(entities, str):
        entities = [entities]
    train_frac, val_frac, test_frac = frac

    # Collect the unique instance values of each entity across the dataset.
    entity_instances = {
        entity: list({getattr(sample, entity) for sample in dataset})
        for entity in entities
    }

    def _sample_instances(fraction):
        """Randomly reserve ``fraction`` of each entity's instances, returned
        as a dict of *sets of instance values* for O(1) membership tests.

        (The original code kept tensors of positions and indexed a Python
        list with them, which raises TypeError; sets of values are both
        correct and fast.)
        """
        reserved = {}
        for entity in entities:
            instances = entity_instances[entity]
            count = int(len(instances) * fraction)
            chosen = torch.randperm(len(instances))[:count]
            reserved[entity] = {instances[i] for i in chosen.tolist()}
        return reserved

    # Sample instances belonging to the test split, then select samples
    # whose entity values are all reserved for test.
    test_instances = _sample_instances(test_frac)
    test_indices = [
        i for i, sample in enumerate(dataset)
        if all(getattr(sample, entity) in test_instances[entity] for entity in entities)
    ]
    if len(test_indices) == 0:
        raise ValueError('No test samples found. Try increasing the test frac or a less stringent splitting strategy.')

    # Remaining samples are candidates for train/validation.
    train_val_indices = list(set(range(len(dataset))) - set(test_indices))
    # Rescale val_frac: it is a fraction of the whole dataset, but we sample
    # from the (1 - test_frac) remainder.
    val_instances = _sample_instances(val_frac / (1 - test_frac))
    val_indices = [
        i for i in train_val_indices
        if all(getattr(dataset[i], entity) in val_instances[entity] for entity in entities)
    ]
    if len(val_indices) == 0:
        raise ValueError('No validation samples found. Try increasing the valid frac or a less stringent splitting strategy.')

    train_indices = list(set(train_val_indices) - set(val_indices))
    return {
        'train': torch.utils.data.Subset(dataset, train_indices),
        'valid': torch.utils.data.Subset(dataset, val_indices),
        'test': torch.utils.data.Subset(dataset, test_indices),
    }