"""
Contains functionality for creating PyTorch DataLoaders for
image classification data (Food101).
"""
import os

from pathlib import Path

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

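# os.cpu_count() can return None when the CPU count cannot be determined;
# fall back to 0 (load data in the main process) in that case.
num_workers = os.cpu_count() or 0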


def create_dataloaders(transform: transforms.Compose,
                       batch_size: int,
                       num_workers: int = num_workers):
    """Creates training and testing DataLoaders.

    Takes in a transform, downloads the Food101 dataset into a local
    "data" directory, and turns both splits into PyTorch DataLoaders.

    Args:
        transform: torchvision transforms to perform on training and testing data.
        batch_size: Number of samples per batch in each of the DataLoaders.
        num_workers: An integer for number of workers per DataLoader.

    Returns:
        A tuple of (train_dataloader, test_dataloader, class_names),
        where class_names is a list of the target classes.

    Example usage:
        train_dataloader, test_dataloader, class_names = create_dataloaders(
            transform=some_transform,
            batch_size=32,
            num_workers=4)
    """
    # Create the data directory if it does not already exist
    data_path = Path("data")
    data_path.mkdir(parents=True, exist_ok=True)

    # Datasets: download Food101 (if needed) and load the train and test splits
    train_data = torchvision.datasets.Food101(root=data_path,
                                              split="train",
                                              transform=transform,
                                              download=True)
    test_data = torchvision.datasets.Food101(root=data_path,
                                             split="test",
                                             transform=transform,
                                             download=True)
    # DataLoaders
    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    test_dataloader = DataLoader(dataset=test_data,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers)
    class_names = train_data.classes
    return train_dataloader, test_dataloader, class_names