safinal commited on
Commit
3c06693
Β·
verified Β·
1 Parent(s): bfa1c41

Create dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +56 -0
dataset.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import pandas as pd
4
+
5
+
6
+
7
+ class RetrievalDataset(torch.utils.data.Dataset):
8
+ def __init__(self, img_dir_path: str, annotations_file_path: str, split: str, transform=None, tokenizer=None) -> None:
9
+ self.img_dir_path = img_dir_path
10
+ self.transform = transform
11
+ self.tokenizer = tokenizer
12
+ self.split = split
13
+ self.annotations = self.split_data(
14
+ self.convert_image_names_to_path(
15
+ pd.read_csv(annotations_file_path)
16
+ )
17
+ )
18
+
19
+ def __len__(self) -> int:
20
+ return len(self.annotations)
21
+
22
+ def __getitem__(self, idx: int) -> tuple:
23
+ query_img_path = self.annotations.iloc[idx]['query_image']
24
+ query_text = self.annotations.iloc[idx]['query_text']
25
+ target_img_path = self.annotations.iloc[idx]['target_image']
26
+ query_img = Image.open(query_img_path).convert('RGB')
27
+ target_img = Image.open(target_img_path).convert('RGB')
28
+ # query_img = torchvision.io.read_image(path=query_img_path, mode=torchvision.io.image.ImageReadMode.RGB)
29
+ # target_img = torchvision.io.read_image(path=target_img_path, mode=torchvision.io.image.ImageReadMode.RGB)
30
+ if self.transform:
31
+ query_img = self.transform(query_img)
32
+ target_img = self.transform(target_img)
33
+ if self.tokenizer:
34
+ query_text = self.tokenizer(query_text).squeeze(0)
35
+ return query_img, query_text, target_img, self.annotations.iloc[idx]['query_text']
36
+
37
+ def split_data(self, annotations):
38
+ shuffled_df = annotations.sample(frac=1, random_state=42).reset_index(drop=True)
39
+ if self.split == "test":
40
+ return shuffled_df # sample test set
41
+ if self.split == "train":
42
+ return shuffled_df.iloc[:int(0.9 * len(shuffled_df))] # train set
43
+ if self.split == "validation":
44
+ return shuffled_df.iloc[int(0.9 * len(shuffled_df)):] # validation set
45
+ raise Exception("split is not valid")
46
+
47
+ def load_queries(self):
48
+ return self.annotations
49
+
50
+ def load_database(self):
51
+ return pd.DataFrame({'target_image': self.annotations["target_image"].unique()})
52
+
53
+ def convert_image_names_to_path(self, df):
54
+ df["query_image"] = self.img_dir_path + "/" + df["query_image"]
55
+ df["target_image"] = self.img_dir_path + "/" + df["target_image"]
56
+ return df