Spaces:
Sleeping
Sleeping
File size: 2,358 Bytes
3a664f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
"""
Contains functionality for creating PyTorch DataLoaders for
image classification data (Food101).
"""
import os
from pathlib import Path
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
# Default number of DataLoader worker processes: one per CPU reported by the OS.
# os.cpu_count() returns None when the count cannot be determined; fall back to
# 0 so the default handed to DataLoader is always an int.
num_workers = os.cpu_count() or 0
def create_dataloaders(transform: transforms.Compose,
                       batch_size: int,
                       num_workers: int = num_workers):
    """Build Food101 training and testing DataLoaders.

    Creates a local ``data`` directory if needed, instantiates the
    torchvision Food101 "train" and "test" splits there (both with
    ``download=True``), and wraps each split in a PyTorch DataLoader.

    Args:
        transform: torchvision transforms applied to both the training
            and the testing data.
        batch_size: Number of samples per batch in each DataLoader.
        num_workers: Number of workers used by each DataLoader.

    Returns:
        A tuple of (train_dataloader, test_dataloader, class_names),
        where class_names is the ``classes`` attribute of the training
        dataset.

    Example usage:
        train_dataloader, test_dataloader, class_names = \
            create_dataloaders(transform=some_transform,
                               batch_size=32,
                               num_workers=4)
    """
    # The dataset root must exist before torchvision is pointed at it.
    root_dir = Path("data")
    root_dir.mkdir(parents=True, exist_ok=True)

    # The two splits differ only in the split name; train is built first.
    train_set, test_set = (
        torchvision.datasets.Food101(root=root_dir,
                                     split=split_name,
                                     transform=transform,
                                     download=True)
        for split_name in ("train", "test")
    )

    # Options shared by both loaders; only the training loader shuffles.
    loader_options = {"batch_size": batch_size, "num_workers": num_workers}
    train_dataloader = DataLoader(dataset=train_set,
                                  shuffle=True,
                                  **loader_options)
    test_dataloader = DataLoader(dataset=test_set,
                                 shuffle=False,
                                 **loader_options)

    return train_dataloader, test_dataloader, train_set.classes
|