import numpy as np
import torch


class ImageClassificationDataset:
    """
    A custom dataset class for image classification tasks.

    Args:
        data (list): A list of data samples, where each sample is a dictionary containing image and target information.
        transforms (callable): A function/transform that takes in an image and returns a transformed version.
        config (object): A configuration object containing the column names for images and targets.

    Attributes:
        data (list): The dataset containing image and target information.
        transforms (callable): The transformation function to be applied to the images.
        config (object): The configuration object with image and target column names.

    Methods:
        __len__(): Returns the number of samples in the dataset.
        __getitem__(item): Retrieves the image and target at the specified index, applies transformations, and returns them as tensors.

    Example:
        dataset = ImageClassificationDataset(data, transforms, config)
        image, target = dataset[0]
    """

    def __init__(self, data, transforms, config):
        self.data = data
        self.transforms = transforms
        self.config = config

    def __len__(self):
        return len(self.data)

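    def __getitem__(self, item):
        # Look up the raw PIL image and its class label via the configured column names.
        image = self.data[item][self.config.image_column]
        target = int(self.data[item][self.config.target_column])

        # Force 3-channel RGB, convert to an HWC numpy array, and apply the transform,
        # which returns a dict carrying the transformed array under the "image" key.
        image = self.transforms(image=np.array(image.convert("RGB")))["image"]
        # Reorder HWC -> CHW, the channel-first layout PyTorch vision models expect.
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)

        return {
            "pixel_values": torch.tensor(image, dtype=torch.float),
            "labels": torch.tensor(target, dtype=torch.long),
        }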
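

# A minimal, hypothetical transform sketching the calling convention that
# ``ImageClassificationDataset.__getitem__`` relies on: it is called as
# ``transforms(image=array)`` with an HWC uint8 array and must return a dict
# whose "image" entry is the transformed HWC array. The name
# ``scale_to_unit_range`` is illustrative and not part of the original module;
# real pipelines would typically compose augmentations with a library such as
# albumentations instead. This sketch only rescales pixel values to [0, 1].
def scale_to_unit_range(image):
    """Hypothetical example transform: rescale uint8 pixels to float32 in [0, 1]."""
    # Keep the HWC layout unchanged; the dataset performs the HWC -> CHW transpose itself.
    return {"image": image.astype("float32") / 255.0}


# Usage (assuming ``data`` and ``config`` as described in the class docstring):
#     dataset = ImageClassificationDataset(data, scale_to_unit_range, config)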