import torch
from PIL import Image
import pandas as pd


class RetrievalDataset(torch.utils.data.Dataset):
    """Dataset of (query image, query text, target image) retrieval triplets.

    Annotations are read from a CSV file that must provide the columns
    ``query_image``, ``query_text`` and ``target_image``. The two image
    columns hold names relative to ``img_dir_path``; they are prefixed with
    that directory when the dataset is constructed.

    Supported values for ``split``:

    * ``"test"`` -- every annotation row, shuffled.
    * ``"train"`` -- the first 90% of the shuffled rows.
    * ``"validation"`` -- the remaining rows.
    """

    # Split names accepted by ``split_data``.
    VALID_SPLITS = ("test", "train", "validation")
    # Share of the shuffled rows assigned to the "train" split.
    TRAIN_FRACTION = 0.9
    # Fixed seed so that "train" and "validation" are cut from the same
    # permutation on every instantiation and therefore never overlap.
    SHUFFLE_SEED = 42

    def __init__(self, img_dir_path: str, annotations_file_path: str, split: str,
                 transform=None, tokenizer=None) -> None:
        """Read the annotations file and keep the rows of the requested split.

        Args:
            img_dir_path: Directory prefixed (with a ``/`` separator) to every
                image name found in the annotations file.
            annotations_file_path: Path of the CSV file with the annotations.
            split: One of ``"test"``, ``"train"`` or ``"validation"``.
            transform: Optional callable applied to both the query and the
                target image in ``__getitem__``.
            tokenizer: Optional callable applied to the query text in
                ``__getitem__``; ``.squeeze(0)`` is called on its result.

        Raises:
            ValueError: If ``split`` is not a supported split name.
        """
        self.img_dir_path = img_dir_path
        self.transform = transform
        self.tokenizer = tokenizer
        self.split = split
        # Fail fast on a bad split name, before the CSV is read.
        self._validate_split()
        self.annotations = self.split_data(
            self.convert_image_names_to_path(
                pd.read_csv(annotations_file_path)
            )
        )

    def __len__(self) -> int:
        """Return the number of annotation rows in the selected split."""
        return len(self.annotations)

    def __getitem__(self, idx: int) -> tuple:
        """Load the sample stored at positional index ``idx``.

        Returns:
            A 4-tuple ``(query_img, query_text, target_img, raw_query_text)``.
            Both images are opened in RGB mode and passed through
            ``transform`` when one is set. ``query_text`` is the tokenizer
            output (after ``.squeeze(0)``) when a tokenizer is set, otherwise
            the raw text. ``raw_query_text`` is always the untouched value of
            the ``query_text`` column.
        """
        # Materialise the row once instead of once per accessed column.
        row = self.annotations.iloc[idx]
        raw_query_text = row['query_text']

        query_img = self._load_rgb(row['query_image'])
        target_img = self._load_rgb(row['target_image'])

        if self.transform:
            query_img = self.transform(query_img)
            target_img = self.transform(target_img)

        query_text = raw_query_text
        if self.tokenizer:
            query_text = self.tokenizer(query_text).squeeze(0)

        return query_img, query_text, target_img, raw_query_text

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy of it."""
        # ``convert`` returns a new, fully loaded image, so the file opened
        # by ``Image.open`` can be released as soon as the block exits.
        with Image.open(path) as img:
            return img.convert('RGB')

    def _validate_split(self) -> None:
        """Raise ``ValueError`` when ``self.split`` is not a supported name."""
        if self.split not in self.VALID_SPLITS:
            raise ValueError(
                f"split is not valid: {self.split!r} "
                f"(expected one of {', '.join(self.VALID_SPLITS)})"
            )

    def split_data(self, annotations):
        """Shuffle ``annotations`` deterministically and return one split.

        Args:
            annotations: DataFrame holding all annotation rows.

        Returns:
            For ``"test"`` the whole shuffled frame; for ``"train"`` its first
            90% of rows; for ``"validation"`` the rest. The train and
            validation slices keep the index labels of the shuffled frame.

        Raises:
            ValueError: If ``self.split`` is not a supported split name.
        """
        # Validate before doing any work on the frame.
        self._validate_split()

        shuffled_df = annotations.sample(
            frac=1, random_state=self.SHUFFLE_SEED
        ).reset_index(drop=True)

        if self.split == "test":
            # The test split is the full shuffled set, not a held-out part.
            return shuffled_df

        boundary = int(self.TRAIN_FRACTION * len(shuffled_df))
        if self.split == "train":
            return shuffled_df.iloc[:boundary]
        return shuffled_df.iloc[boundary:]  # "validation"

    def load_queries(self):
        """Return the annotation rows of the selected split (not a copy)."""
        return self.annotations

    def load_database(self):
        """Return a one-column DataFrame of the unique ``target_image`` paths."""
        return pd.DataFrame({'target_image': self.annotations["target_image"].unique()})

    def convert_image_names_to_path(self, df):
        """Prefix both image columns of ``df`` with ``img_dir_path + "/"``.

        The frame is modified in place and the same object is returned.
        """
        prefix = self.img_dir_path + "/"
        df["query_image"] = prefix + df["query_image"]
        df["target_image"] = prefix + df["target_image"]
        return df