File size: 2,854 Bytes
6ae852e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from typing import List, Union
from torch.utils import data

# FIXME Test and fix these split functions later


def random_split(dataset: data.Dataset, lengths, seed):
    """Randomly split ``dataset`` into subsets of the given ``lengths``.

    Thin wrapper around :func:`torch.utils.data.random_split` that seeds a
    dedicated generator, so the split is reproducible for a given ``seed``.
    """
    generator = torch.Generator()
    generator.manual_seed(seed)
    return data.random_split(dataset, lengths, generator=generator)


def cold_start(dataset: data.Dataset, frac: List[float], entities: Union[str, List[str]],
               seed: Union[int, None] = None):
    """Create cold-start splits for PyTorch datasets.

    Entities are held out as whole units: a fraction of each entity's unique
    instances is drawn for the test (resp. validation) pool, and a sample is
    assigned to that split only if *all* of its cold entities fall in the
    corresponding pool. Everything else becomes training data.

    Args:
        dataset (Dataset): PyTorch dataset object whose samples expose each
            entity as an attribute (e.g. ``sample.drug``).
        frac (list): A list of train/valid/test fractions.
        entities (Union[str, List[str]]): Either a single "cold" entity or a list of "cold" entities
            on which the split is done.
        seed (int, optional): If given, seeds the entity sampling so the
            split is reproducible. Defaults to None (unseeded).

    Returns:
        dict: A dictionary of splitted datasets, where keys are 'train', 'valid', and 'test',
            and values correspond to each dataset.

    Raises:
        ValueError: If the test or validation split ends up empty.
    """
    if isinstance(entities, str):
        entities = [entities]

    # train fraction is implicit (whatever is left over), so it is unused here
    _, val_frac, test_frac = frac

    generator = None if seed is None else torch.Generator().manual_seed(seed)

    # Collect the unique instances of each entity across the dataset.
    entity_instances = {
        entity: list({getattr(sample, entity) for sample in dataset})
        for entity in entities
    }

    def _sample_instances(entity, fraction):
        # Randomly pick `fraction` of the entity's unique instances, returned
        # as a set for O(1) membership tests. (The previous implementation
        # indexed a Python list with an index *tensor*, which raises TypeError.)
        pool = entity_instances[entity]
        picked = torch.randperm(len(pool), generator=generator)[:int(len(pool) * fraction)]
        return {pool[i] for i in picked.tolist()}

    # Sample instances belonging to the test datasets
    test_instances = {entity: _sample_instances(entity, test_frac) for entity in entities}

    # Select samples where all entities are in the test set
    test_indices = [
        i for i, sample in enumerate(dataset)
        if all(getattr(sample, entity) in test_instances[entity] for entity in entities)
    ]
    if not test_indices:
        raise ValueError('No test samples found. Try increasing the test frac or a less stringent splitting strategy.')

    # Proceed with validation data
    train_val_indices = list(set(range(len(dataset))) - set(test_indices))

    # Rescale the validation fraction to the non-test remainder.
    val_instances = {
        entity: _sample_instances(entity, val_frac / (1 - test_frac))
        for entity in entities
    }

    val_indices = [
        i for i in train_val_indices
        if all(getattr(dataset[i], entity) in val_instances[entity] for entity in entities)
    ]
    if not val_indices:
        raise ValueError('No validation samples found. Try increasing the test frac or a less stringent splitting strategy.')

    train_indices = list(set(train_val_indices) - set(val_indices))

    return {
        'train': data.Subset(dataset, train_indices),
        'valid': data.Subset(dataset, val_indices),
        'test': data.Subset(dataset, test_indices),
    }