Spaces:
Sleeping
Sleeping
File size: 1,741 Bytes
33d4721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import numpy as np
import torch
class ImageClassificationDataset:
"""
A custom dataset class for image classification tasks.
Args:
data (list): A list of data samples, where each sample is a dictionary containing image and target information.
transforms (callable): A function/transform that takes in an image and returns a transformed version.
config (object): A configuration object containing the column names for images and targets.
Attributes:
data (list): The dataset containing image and target information.
transforms (callable): The transformation function to be applied to the images.
config (object): The configuration object with image and target column names.
Methods:
__len__(): Returns the number of samples in the dataset.
__getitem__(item): Retrieves the image and target at the specified index, applies transformations, and returns them as tensors.
Example:
dataset = ImageClassificationDataset(data, transforms, config)
image, target = dataset[0]
"""
def __init__(self, data, transforms, config):
self.data = data
self.transforms = transforms
self.config = config
def __len__(self):
return len(self.data)
def __getitem__(self, item):
image = self.data[item][self.config.image_column]
target = int(self.data[item][self.config.target_column])
image = self.transforms(image=np.array(image.convert("RGB")))["image"]
image = np.transpose(image, (2, 0, 1)).astype(np.float32)
return {
"pixel_values": torch.tensor(image, dtype=torch.float),
"labels": torch.tensor(target, dtype=torch.long),
}
|