Spaces:
Sleeping
Sleeping
File size: 1,498 Bytes
33d4721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import numpy as np
import torch
class ImageRegressionDataset:
    """
    Map-style dataset pairing transformed images with regression targets.

    Each record is looked up by the column names stored on ``config``. The
    image is converted to RGB, run through ``transforms``, reordered to a
    channels-first float32 array and wrapped in a tensor. The target is
    wrapped in a float tensor.

    Args:
        data: Indexable collection of records (e.g. a list of dicts). Each
            record must be subscriptable by ``config.image_column`` and
            ``config.target_column``.
        transforms (callable): Invoked as ``transforms(image=<ndarray>)``.
            It must return a mapping whose ``"image"`` entry holds the
            transformed array.
        config (object): Supplies the ``image_column`` and ``target_column``
            attribute names used to read each record.

    Attributes:
        data: The records supplied at construction.
        transforms (callable): The preprocessing/augmentation callable.
        config (object): The column-name configuration object.
    """

    def __init__(self, data, transforms, config):
        self.data = data
        self.transforms = transforms
        self.config = config

    def __len__(self):
        """Return the number of records, i.e. ``len(self.data)``."""
        return len(self.data)

    def __getitem__(self, item):
        """
        Assemble the model inputs for the record at index ``item``.

        Args:
            item: Index passed straight through to ``self.data``.

        Returns:
            dict: ``"pixel_values"`` is a float32 tensor whose first axis is
            the former last axis of the transformed image. ``"labels"`` is
            the target as a float32 tensor.
        """
        cols = self.config
        record = self.data[item]
        raw_image = record[cols.image_column]
        raw_target = record[cols.target_column]

        # The image must expose a PIL-style ``convert``; presumably a
        # PIL.Image -- confirm against the data loader. Converting to "RGB"
        # gives every sample the same channel count before the transform.
        rgb_array = np.array(raw_image.convert("RGB"))
        augmented = self.transforms(image=rgb_array)["image"]

        # (H, W, C) -> (C, H, W). This assumes the transform keeps the
        # channels-last layout that np.array yields for an RGB image.
        chw_array = np.transpose(augmented, (2, 0, 1)).astype(np.float32)

        pixel_values = torch.tensor(chw_array, dtype=torch.float)
        labels = torch.tensor(raw_target, dtype=torch.float)
        return {"pixel_values": pixel_values, "labels": labels}
|