File size: 2,411 Bytes
3c06693
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from PIL import Image
import pandas as pd



class RetrievalDataset(torch.utils.data.Dataset):
    """Dataset of (query image, query text, target image) retrieval samples.

    Annotations are read from a CSV file that must provide the columns
    ``query_image``, ``query_text`` and ``target_image``. The two image
    columns hold file names that are resolved relative to ``img_dir_path``.

    The rows are shuffled with a fixed seed and then sliced according to
    ``split``:

    * ``"train"``      -- the first 90% of the shuffled rows,
    * ``"validation"`` -- the remaining rows,
    * ``"test"``       -- every shuffled row.
    """

    # Fraction of the shuffled rows that goes to "train"; the rest is
    # "validation".
    _TRAIN_FRACTION = 0.9
    # Fixed seed: every instance shuffles identically, so the "train" and
    # "validation" slices of the same file never overlap.
    _SHUFFLE_SEED = 42

    def __init__(self, img_dir_path: str, annotations_file_path: str, split: str, transform=None, tokenizer=None) -> None:
        """Load the annotation CSV and keep only the rows of ``split``.

        Args:
            img_dir_path: Directory prepended (with ``"/"``) to the image
                file names found in the annotations.
            annotations_file_path: Path of the CSV file passed to
                ``pandas.read_csv``.
            split: One of ``"train"``, ``"validation"`` or ``"test"``.
            transform: Optional callable applied to both PIL images.
            tokenizer: Optional callable applied to the query text; its
                result must support ``.squeeze(0)``.

        Raises:
            ValueError: If ``split`` is not a supported split name.
        """
        self.img_dir_path = img_dir_path
        self.transform = transform
        self.tokenizer = tokenizer
        self.split = split
        raw_annotations = pd.read_csv(annotations_file_path)
        self.annotations = self.split_data(
            self.convert_image_names_to_path(raw_annotations)
        )

    def __len__(self) -> int:
        """Return the number of annotation rows in the selected split."""
        return len(self.annotations)

    def __getitem__(self, idx: int) -> tuple:
        """Return the sample stored at positional index ``idx``.

        Returns:
            A 4-tuple ``(query_img, query_text, target_img, raw_query_text)``.
            The images are RGB PIL images, passed through ``transform`` when
            one was given. ``query_text`` is the tokenizer output with
            ``.squeeze(0)`` applied when a tokenizer was given, otherwise the
            raw text. ``raw_query_text`` is always the untokenized text.
        """
        # Fetch the row once instead of re-indexing the frame per column.
        row = self.annotations.iloc[idx]
        raw_query_text = row['query_text']

        query_img = self._load_rgb(row['query_image'])
        target_img = self._load_rgb(row['target_image'])
        if self.transform is not None:
            query_img = self.transform(query_img)
            target_img = self.transform(target_img)

        query_text = raw_query_text
        if self.tokenizer is not None:
            query_text = self.tokenizer(raw_query_text).squeeze(0)
        return query_img, query_text, target_img, raw_query_text

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy of it."""
        # The context manager closes the underlying file right away;
        # ``convert`` has already produced an independent in-memory image, so
        # no file handle is left open per sample.
        with Image.open(path) as img:
            return img.convert('RGB')

    def split_data(self, annotations):
        """Shuffle ``annotations`` deterministically and slice out ``self.split``.

        Args:
            annotations: DataFrame holding every annotation row.

        Returns:
            The rows belonging to ``self.split``. The index is reset after
            shuffling, so for "validation" the labels start at the cut point
            rather than at zero.

        Raises:
            ValueError: If ``self.split`` is not ``"train"``,
                ``"validation"`` or ``"test"``.
        """
        if self.split not in ("train", "validation", "test"):
            # ValueError is a subclass of Exception, so callers catching the
            # previously raised generic Exception keep working.
            raise ValueError("split is not valid")

        shuffled = annotations.sample(
            frac=1, random_state=self._SHUFFLE_SEED
        ).reset_index(drop=True)
        if self.split == "test":
            return shuffled
        cut = int(self._TRAIN_FRACTION * len(shuffled))
        if self.split == "train":
            return shuffled.iloc[:cut]
        return shuffled.iloc[cut:]

    def load_queries(self):
        """Return the annotation rows of the selected split."""
        return self.annotations

    def load_database(self):
        """Return a one-column frame of the unique ``target_image`` paths."""
        return pd.DataFrame({'target_image': self.annotations["target_image"].unique()})

    def convert_image_names_to_path(self, df):
        """Prefix both image columns of ``df`` with ``img_dir_path + "/"``.

        ``df`` is modified in place and also returned.
        """
        for column in ("query_image", "target_image"):
            df[column] = self.img_dir_path + "/" + df[column]
        return df