|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
import random |
|
import pickle |
|
import os
from os.path import join
|
from os.path import isfile |
|
from PIL import Image |
|
from sklearn.model_selection import train_test_split |
|
from torch.utils.data import Dataset |
|
from torchvision.transforms import ( |
|
Compose, |
|
RandomCrop, |
|
CenterCrop, |
|
RandomHorizontalFlip, |
|
ToTensor, |
|
) |
|
import time |
|
from torchvision.transforms import GaussianBlur |
|
from torchvision import transforms |
|
from pathlib import Path |
|
import json |
|
from tqdm import tqdm |
|
import multiprocessing as mp |
|
import ctypes |
|
|
|
|
|
def normalize(lat, lon): |
|
"""Used to put all lat lon inside ±90 and ±180.""" |
|
lat = (lat + 90) % 360 - 90 |
|
if lat > 90: |
|
lat = 180 - lat |
|
lon += 180 |
|
lon = (lon + 180) % 360 - 180 |
|
return lat, lon |
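# A small self-check of the wrapping behaviour above (illustration only, not
# used by the pipeline): crossing a pole lands on the opposite meridian.
def _normalize_demo():
    assert normalize(100, 190) == (80, 10)
    assert normalize(-100, 0) == (-80, -180)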
|
|
|
|
|
def collate_fn(batch): |
|
"""Collate function for the dataloader. |
|
Args: |
|
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label" |
|
Returns: |
|
dict: dictionary with keys "img", "gps", "idx" and optionally "label" |
|
""" |
|
keys = list(batch[0].keys()) |
|
if "weight" in batch[0].keys(): |
|
keys.remove("weight") |
|
output = {} |
|
for key in [ |
|
"idx", |
|
"unique_country", |
|
"unique_region", |
|
"unique_sub-region", |
|
"unique_city", |
|
"img_idx", |
|
"text", |
|
]: |
|
if key in keys: |
|
idx = [x[key] for x in batch] |
|
output[key] = idx |
|
keys.remove(key) |
|
if "img" in keys and isinstance(batch[0]["img"], Image.Image): |
|
output["img"] = [x["img"] for x in batch] |
|
keys.remove("img") |
|
for key in keys: |
|
if not ("text" in key): |
|
output[key] = torch.stack([x[key] for x in batch]) |
|
return output |
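# Illustrative sketch (hypothetical tensors): collate_fn stacks tensor fields
# and keeps index fields such as "idx" as plain lists.
def _collate_fn_demo():
    batch = [
        {"img": torch.zeros(3, 224, 224), "gps": torch.zeros(2), "idx": k}
        for k in range(4)
    ]
    out = collate_fn(batch)
    assert out["img"].shape == (4, 3, 224, 224)
    assert out["idx"] == [0, 1, 2, 3]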
|
|
|
|
|
def collate_fn_streetclip(batch): |
|
"""Collate function for the dataloader. |
|
Args: |
|
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label" |
|
Returns: |
|
dict: dictionary with keys "img", "gps", "idx" and optionally "label" |
|
""" |
|
keys = list(batch[0].keys()) |
|
if "weight" in batch[0].keys(): |
|
keys.remove("weight") |
|
output = {} |
|
for key in [ |
|
"idx", |
|
"unique_country", |
|
"unique_region", |
|
"unique_sub-region", |
|
"unique_city", |
|
"img_idx", |
|
"img", |
|
"text", |
|
]: |
|
if key in keys: |
|
idx = [x[key] for x in batch] |
|
output[key] = idx |
|
keys.remove(key) |
|
for key in keys: |
|
if not ("text" in key): |
|
output[key] = torch.stack([x[key] for x in batch]) |
|
return output |
|
|
|
|
|
def collate_fn_density(batch):
    """Collate function with density-based resampling: the batch is redrawn
    (with replacement) with probability proportional to each sample's
    inverse-density "weight".

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx",
            "weight" and optionally "label"
    Returns:
        dict: dictionary with keys "img", "gps", "idx" and optionally "label"
    """
|
keys = list(batch[0].keys()) |
|
if "weight" in batch[0].keys(): |
|
keys.remove("weight") |
|
|
|
weights = np.array([x["weight"] for x in batch]) |
|
normalized_weights = weights / np.sum(weights) |
|
sampled_indices = np.random.choice( |
|
len(batch), size=len(batch), p=normalized_weights, replace=True |
|
) |
|
output = {} |
|
for key in [ |
|
"idx", |
|
"unique_country", |
|
"unique_region", |
|
"unique_sub-region", |
|
"unique_city", |
|
"img_idx", |
|
"text", |
|
]: |
|
if key in keys: |
|
idx = [batch[i][key] for i in sampled_indices] |
|
output[key] = idx |
|
keys.remove(key) |
|
for key in keys: |
|
if not ("text" in key): |
|
output[key] = torch.stack([batch[i][key] for i in sampled_indices]) |
|
return output |
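# Illustrative sketch (hypothetical samples): the density collate function
# keeps the batch size but redraws samples with probability proportional to
# their "weight", so low-density (high-weight) samples appear more often.
def _density_collate_demo():
    batch = [
        {"img": torch.full((1,), float(k)), "gps": torch.zeros(2), "idx": k, "weight": w}
        for k, w in enumerate([0.7, 0.1, 0.1, 0.1])
    ]
    out = collate_fn_density(batch)
    assert out["img"].shape == (4, 1)
    assert len(out["idx"]) == 4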
|
|
|
|
|
def collate_fn_streetclip_density(batch):
    """StreetCLIP variant of collate_fn_density: resamples the batch by
    "weight" and keeps "img" as a list of PIL images.

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx",
            "weight" and optionally "label"
    Returns:
        dict: dictionary with keys "img", "gps", "idx" and optionally "label"
    """
|
keys = list(batch[0].keys()) |
|
if "weight" in batch[0].keys(): |
|
keys.remove("weight") |
|
|
|
weights = np.array([x["weight"] for x in batch]) |
|
normalized_weights = weights / np.sum(weights) |
|
sampled_indices = np.random.choice( |
|
len(batch), size=len(batch), p=normalized_weights, replace=True |
|
) |
|
output = {} |
|
for key in [ |
|
"idx", |
|
"unique_country", |
|
"unique_region", |
|
"unique_sub-region", |
|
"unique_city", |
|
"img_idx", |
|
"img", |
|
"text", |
|
]: |
|
if key in keys: |
|
idx = [batch[i][key] for i in sampled_indices] |
|
output[key] = idx |
|
keys.remove(key) |
|
for key in keys: |
|
if not ("text" in key): |
|
output[key] = torch.stack([batch[i][key] for i in sampled_indices]) |
|
return output |
|
|
|
|
|
def collate_fn_contrastive(batch): |
|
"""Collate function for the dataloader. |
|
Args: |
|
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label" |
|
Returns: |
|
dict: dictionary with keys "img", "gps", "idx" and optionally "label" |
|
""" |
|
output = collate_fn(batch) |
|
pos_img = torch.stack([x["pos_img"] for x in batch]) |
|
output["pos_img"] = pos_img |
|
return output |
|
|
|
|
|
def collate_fn_contrastive_density(batch): |
|
"""Collate function for the dataloader. |
|
Args: |
|
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label" |
|
Returns: |
|
dict: dictionary with keys "img", "gps", "idx" and optionally "label" |
|
""" |
|
keys = list(batch[0].keys()) |
|
if "weight" in batch[0].keys(): |
|
keys.remove("weight") |
|
|
|
weights = np.array([x["weight"] for x in batch]) |
|
normalized_weights = weights / np.sum(weights) |
|
sampled_indices = np.random.choice( |
|
len(batch), size=len(batch), p=normalized_weights, replace=True |
|
) |
|
output = {} |
|
for key in [ |
|
"idx", |
|
"unique_country", |
|
"unique_region", |
|
"unique_sub-region", |
|
"unique_city", |
|
"img_idx", |
|
]: |
|
if key in keys: |
|
idx = [batch[i][key] for i in sampled_indices] |
|
output[key] = idx |
|
keys.remove(key) |
|
for key in keys: |
|
if not ("text" in key): |
|
output[key] = torch.stack([batch[i][key] for i in sampled_indices]) |
|
return output |
|
|
|
|
|
class iNaturalist(Dataset): |
|
def __init__( |
|
self, |
|
path, |
|
transforms, |
|
split="train", |
|
output_type="image", |
|
embedding_name="dinov2", |
|
): |
|
super().__init__() |
|
self.split = split |
|
with open(Path(path) / f"{split}.json", "r") as f: |
|
self.metadata = json.load(f) |
|
self.metadata = [ |
|
datapoint |
|
for datapoint in self.metadata["images"] |
|
if "latitude" in datapoint and datapoint["latitude"] is not None |
|
] |
|
self.path = path |
|
self.transforms = transforms |
|
self.output_type = output_type |
|
self.embedding_name = embedding_name |
|
|
|
self.collate_fn = collate_fn |
|
|
|
def __getitem__(self, i): |
|
output = {} |
|
if "image" in self.output_type: |
|
image_path = Path(self.path) / "images" / self.metadata[i]["file_name"] |
|
img = self.transforms(Image.open(image_path)) |
|
output["img"] = img |
|
if "emb" in self.output_type: |
|
emb_path = ( |
|
Path(self.path) |
|
/ "embeddings" |
|
/ self.embedding_name |
|
/ self.metadata[i]["file_name"].replace(".jpg", ".npy") |
|
) |
|
output["emb"] = torch.tensor(np.load(emb_path)) |
|
lat, lon = normalize( |
|
self.metadata[i]["latitude"], self.metadata[i]["longitude"] |
|
) |
|
output["gps"] = torch.tensor( |
|
[np.radians(lat), np.radians(lon)], dtype=torch.float |
|
) |
|
output["idx"] = i |
|
output["img_idx"] = self.metadata[i]["id"] |
|
return output |
|
|
|
def __len__(self): |
|
return len(self.metadata) |
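# Hypothetical usage sketch: `root` is a placeholder and must contain the
# {split}.json metadata plus the images/ folder read by iNaturalist above.
def _inaturalist_example(root="/data/inaturalist"):
    dataset = iNaturalist(
        root,
        transforms=Compose([CenterCrop(224), ToTensor()]),
        split="train",
        output_type="image",
    )
    return torch.utils.data.DataLoader(
        dataset, batch_size=32, collate_fn=dataset.collate_fn
    )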
|
|
|
|
|
class OSV5M(Dataset): |
|
csv_dtype = {"category": str, "country": str, "city": str} |
|
|
|
def __init__( |
|
self, |
|
path, |
|
transforms, |
|
split="train", |
|
class_name=None, |
|
aux_data=[], |
|
is_baseline=False, |
|
areas=["country", "region", "sub-region", "city"], |
|
streetclip=False, |
|
suff="", |
|
blur=False, |
|
output_type="image", |
|
embedding_name="dinov2", |
|
): |
|
"""Initializes the dataset. |
|
Args: |
|
path (str): path to the dataset |
|
transforms (torchvision.transforms): transforms to apply to the images |
|
split (str): split to use (train, val, test) |
|
class_name (str): category to use (e.g. "city") |
|
aux_data (list of str): auxilliary datas to use |
|
areas (list of str): regions to perform accuracy |
|
streetclip (bool): if the model is streetclip, do not use transform |
|
suff (str): suffix of test csv |
|
blur (bool): blur bottom of images or not |
|
output_type (str): type of output (image or emb) |
|
""" |
|
self.suff = suff |
|
self.path = path |
|
self.aux = len(aux_data) > 0 |
|
self.aux_list = aux_data |
|
self.split = split |
|
if split == "select": |
|
self.df = self.load_split(split) |
|
split = "test" |
|
else: |
|
self.df = self.load_split(split) |
|
self.split = split |
|
if "image" in output_type: |
|
self.image_data_folder = join( |
|
path, |
|
"images", |
|
("train" if split == "val" else split), |
|
) |
|
self.image_dict_names = {} |
|
for root, _, files in os.walk(self.image_data_folder): |
|
for file in files: |
|
self.image_dict_names[file] = os.path.join(root, file) |
|
if "emb" in output_type: |
|
self.emb_data_folder = join( |
|
path, |
|
"embeddings", |
|
embedding_name, |
|
("train" if split == "val" else split), |
|
) |
|
self.emb_dict_names = {} |
|
for root, _, files in os.walk(self.emb_data_folder): |
|
for file in files: |
|
self.emb_dict_names[file] = os.path.join(root, file) |
|
|
|
self.output_type = output_type |
|
|
|
self.is_baseline = is_baseline |
|
if self.aux: |
|
self.aux_data = {} |
|
for col in self.aux_list: |
|
if col in ["land_cover", "climate", "soil"]: |
|
self.aux_data[col] = pd.get_dummies(self.df[col], dtype=float) |
|
if col == "climate": |
|
for i in range(31): |
|
if not (i in list(self.aux_data[col].columns)): |
|
self.aux_data[col][i] = 0 |
|
desired_order = [i for i in range(31)] |
|
desired_order.remove(20) |
|
self.aux_data[col] = self.aux_data[col][desired_order] |
|
else: |
|
self.aux_data[col] = self.df[col].apply(lambda x: [x]) |
|
|
|
self.areas = ["_".join(["unique", area]) for area in areas] |
|
if class_name is None: |
|
self.class_name = class_name |
|
elif "quadtree" in class_name: |
|
self.class_name = class_name |
|
else: |
|
self.class_name = "_".join(["unique", class_name]) |
|
ex = self.extract_classes(self.class_name) |
|
self.df = self.df[ |
|
["id", "latitude", "longitude", "weight"] + self.areas + ex |
|
].fillna("NaN") |
|
if self.class_name in self.areas: |
|
self.df.columns = list(self.df.columns)[:-1] + [self.class_name + "_2"] |
|
self.transforms = transforms |
|
self.collate_fn = collate_fn |
|
        self.collate_fn_density = collate_fn_density
|
self.blur = blur |
|
self.streetclip = streetclip |
|
if self.streetclip: |
|
self.collate_fn = collate_fn_streetclip |
|
            self.collate_fn_density = collate_fn_streetclip_density
|
|
|
def load_split(self, split): |
|
"""Returns a new dataset with the given split.""" |
|
start_time = time.time() |
|
if split == "test": |
|
df = pd.read_csv(join(self.path, "test.csv"), dtype=self.csv_dtype) |
|
|
|
longitude = df["longitude"].values |
|
latitude = df["latitude"].values |
|
|
|
num_bins = 100 |
|
lon_bins = np.linspace(longitude.min(), longitude.max(), num_bins) |
|
lat_bins = np.linspace(latitude.min(), latitude.max(), num_bins) |
|
|
|
hist, _, _ = np.histogram2d(longitude, latitude, bins=[lon_bins, lat_bins]) |
|
weights = 1.0 / np.power(hist[df["lon_bin"], df["lat_bin"]], 0.75) |
|
normalized_weights = weights / np.sum(weights) |
|
df["weight"] = normalized_weights |
|
return df |
|
elif split == "select": |
|
df = pd.read_csv(join(self.path, "select.csv"), dtype=self.csv_dtype) |
|
|
|
longitude = df["longitude"].values |
|
latitude = df["latitude"].values |
|
|
|
num_bins = 100 |
|
lon_bins = np.linspace(longitude.min(), longitude.max(), num_bins) |
|
lat_bins = np.linspace(latitude.min(), latitude.max(), num_bins) |
|
|
|
hist, _, _ = np.histogram2d(longitude, latitude, bins=[lon_bins, lat_bins]) |
|
weights = 1.0 / np.power(hist[df["lon_bin"], df["lat_bin"]], 0.75) |
|
normalized_weights = weights / np.sum(weights) |
|
df["weight"] = normalized_weights |
|
return df |
|
else: |
|
if len(self.suff) == 0: |
|
df = pd.read_csv(join(self.path, "train.csv"), dtype=self.csv_dtype) |
|
else: |
|
df = pd.read_csv( |
|
join(self.path, "train" + "_" + self.suff + ".csv"), |
|
dtype=self.csv_dtype, |
|
) |
|
|
|
|
|
longitude = df["longitude"].values |
|
latitude = df["latitude"].values |
|
|
|
num_bins = 100 |
|
lon_bins = np.linspace(longitude.min(), longitude.max(), num_bins) |
|
lat_bins = np.linspace(latitude.min(), latitude.max(), num_bins) |
|
|
|
hist, _, _ = np.histogram2d(longitude, latitude, bins=[lon_bins, lat_bins]) |
|
weights = 1.0 / np.power(hist[df["lon_bin"], df["lat_bin"]], 0.75) |
|
normalized_weights = weights / np.sum(weights) |
|
df["weight"] = normalized_weights |
|
|
|
            # Hold out 10% of the training rows (sampled with the density
            # weights) to serve as the validation split.
            val_df = df.sample(
                n=int(0.1 * len(df)),
                weights=normalized_weights,
                replace=False,
                random_state=42,
            )

            end_time = time.time()
            print(f"Loading {split} dataset took {(end_time - start_time):.2f} seconds")

            if split == "val":
                return val_df
            return df.drop(val_df.index)
|
|
|
def extract_classes(self, tag=None): |
|
"""Extracts the categories from the dataset.""" |
|
if tag is None: |
|
self.has_labels = False |
|
return [] |
|
splits = ["train", "test"] if self.is_baseline else ["train"] |
|
|
|
print(f"Loading categories from {splits}") |
|
|
|
|
|
self.categories = sorted( |
|
pd.concat( |
|
[pd.read_csv(join(self.path, f"{split}.csv"))[tag] for split in splits] |
|
) |
|
.fillna("NaN") |
|
.unique() |
|
.tolist() |
|
) |
|
|
|
if "NaN" in self.categories: |
|
self.categories.remove("NaN") |
|
if self.split != "test": |
|
self.df = self.df.dropna(subset=[tag]) |
|
|
|
self.num_classes = len(self.categories) |
|
|
|
|
|
self.category_to_index = { |
|
category: i for i, category in enumerate(self.categories) |
|
} |
|
self.has_labels = True |
|
return [tag] |
|
|
|
def __getitem__(self, i): |
|
"""Returns an item from the dataset. |
|
Args: |
|
i (int): index of the item |
|
Returns: |
|
dict: dictionary with keys "img", "gps", "idx" and optionally "label" |
|
""" |
|
x = list(self.df.iloc[i]) |
|
output = {} |
|
if "image" in self.output_type: |
|
if self.streetclip: |
|
img = Image.open(self.image_dict_names[f"{int(x[0])}.jpg"]) |
|
elif self.blur: |
|
img = transforms.ToTensor()( |
|
Image.open(self.image_dict_names[f"{int(x[0])}.jpg"]) |
|
) |
|
                # Blur only the bottom 14 pixel rows of the image.
                u = GaussianBlur(kernel_size=13, sigma=2.0)
                bottom_part = img[:, -14:, :].unsqueeze(0)
                blurred_bottom = u(bottom_part)
                img[:, -14:, :] = blurred_bottom.squeeze()
|
img = self.transforms(transforms.ToPILImage()(img)) |
|
else: |
|
img = self.transforms( |
|
Image.open(self.image_dict_names[f"{int(x[0])}.jpg"]) |
|
) |
|
output["img"] = img |
|
if "emb" in self.output_type: |
|
output["emb"] = torch.FloatTensor( |
|
np.load(self.emb_dict_names[f"{int(x[0])}.npy"]) |
|
) |
|
|
|
lat, lon = normalize(x[1], x[2]) |
|
gps = torch.FloatTensor([np.radians(lat), np.radians(lon)]).squeeze(0) |
|
|
|
output.update( |
|
{ |
|
"gps": gps, |
|
"idx": i, |
|
"img_idx": int(x[0]), |
|
"weight": x[3], |
|
} |
|
) |
|
|
|
        for count, area in enumerate(self.areas):
            # Area columns start right after id, latitude, longitude, weight.
            output[area] = x[count + 4]
|
|
|
if self.has_labels: |
|
if x[-1] in self.categories: |
|
output["label"] = torch.LongTensor( |
|
[self.category_to_index[x[-1]]] |
|
).squeeze(-1) |
|
else: |
|
output["label"] = torch.LongTensor([-1]).squeeze(-1) |
|
if self.aux: |
|
for col in self.aux_list: |
|
output[col] = torch.FloatTensor(self.aux_data[col].iloc[i]) |
|
return output |
|
|
|
def __len__(self): |
|
return len(self.df) |
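# Illustrative sketch of the inverse-density weighting used in load_split:
# a cell holding 16x more points gets 16 ** 0.75 = 8x less weight per point.
def _density_weight_demo():
    counts = np.array([16.0, 1.0])
    weights = 1.0 / np.power(counts, 0.75)
    assert np.isclose(weights[1] / weights[0], 8.0)


# Hypothetical usage sketch: `root` is a placeholder and must contain the
# train.csv/test.csv files and the images/<split>/ tree read by OSV5M above.
def _osv5m_example(root="/data/osv5m"):
    dataset = OSV5M(
        root,
        transforms=Compose([RandomCrop(224), ToTensor()]),
        split="train",
        class_name="country",
    )
    return torch.utils.data.DataLoader(
        dataset, batch_size=256, shuffle=True, collate_fn=dataset.collate_fn
    )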
|
|
|
|
|
class ContrastiveOSV5M(OSV5M): |
|
def __init__( |
|
self, |
|
path, |
|
transforms, |
|
split="train", |
|
class_name=None, |
|
aux_data=[], |
|
class_name2=None, |
|
blur=False, |
|
): |
|
""" |
|
class_name2 (str): if not None, we do contrastive an other class than the one specified for classif |
|
""" |
|
super().__init__( |
|
path, |
|
transforms, |
|
split=split, |
|
class_name=class_name, |
|
aux_data=aux_data, |
|
blur=blur, |
|
) |
|
self.add_label = False |
|
if not (class_name2 is None) and split != "test" and split != "select": |
|
self.add_label = True |
|
self.class_name = class_name2 |
|
self.extract_classes_contrastive(tag=class_name2) |
|
self.df = self.df.reset_index(drop=True) |
|
self.dict_classes = { |
|
value: indices.tolist() |
|
for value, indices in self.df.groupby(self.class_name).groups.items() |
|
} |
|
self.collate_fn = collate_fn_contrastive |
|
self.random_crop = RandomCrop(224) |
|
|
|
    def sample_positive(self, i):
        """Samples a positive image from the same class (e.g. same city) when
        another one is available; otherwise returns a different random crop of
        the anchor image itself.
        """
        x = self.df.iloc[i]
        class_name = x[self.class_name]
        # Copy the cached index list so repeated calls do not mutate it.
        idxs = [idx for idx in self.dict_classes[class_name] if idx != i]

        if len(idxs) > 0:
            idx = random.choice(idxs)
            x = self.df.iloc[idx]
            pos_img = self.transforms(
                Image.open(self.image_dict_names[f"{int(x['id'])}.jpg"])
            )
        else:
            pos_img = self.random_crop(
                self.transforms(
                    Image.open(self.image_dict_names[f"{int(x['id'])}.jpg"])
                )
            )
        return pos_img
|
|
|
def extract_classes_contrastive(self, tag=None): |
|
"""Extracts the categories from the dataset.""" |
|
if tag is None: |
|
self.has_labels = False |
|
return [] |
|
splits = ["train", "test"] if self.is_baseline else ["train"] |
|
|
|
print(f"Loading categories from {splits}") |
|
|
|
|
|
categories = sorted( |
|
pd.concat( |
|
[pd.read_csv(join(self.path, f"{split}.csv"))[tag] for split in splits] |
|
) |
|
.fillna("NaN") |
|
.unique() |
|
.tolist() |
|
) |
|
|
|
self.contrastive_category_to_index = { |
|
category: i for i, category in enumerate(categories) |
|
} |
|
|
|
def __getitem__(self, i): |
|
output = super().__getitem__(i) |
|
pos_img = self.sample_positive(i) |
|
output["pos_img"] = pos_img |
|
if self.add_label: |
|
output["label_contrastive"] = torch.LongTensor( |
|
[self.contrastive_category_to_index[self.df[self.class_name].iloc[i]]] |
|
).squeeze(-1) |
|
return output |
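# Hypothetical usage sketch: positives grouped by the "unique_city" column
# while the classification head predicts the country (path and column names
# are placeholders).
def _contrastive_example(root="/data/osv5m"):
    dataset = ContrastiveOSV5M(
        root,
        transforms=Compose([RandomCrop(224), ToTensor()]),
        split="train",
        class_name="country",
        class_name2="unique_city",
    )
    return torch.utils.data.DataLoader(
        dataset, batch_size=64, shuffle=True, collate_fn=dataset.collate_fn
    )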
|
|
|
|
|
class TextContrastiveOSV5M(OSV5M): |
|
def __init__( |
|
self, |
|
path, |
|
transforms, |
|
split="train", |
|
class_name=None, |
|
aux_data=[], |
|
blur=False, |
|
): |
|
super().__init__( |
|
path, |
|
transforms, |
|
split=split, |
|
class_name=class_name, |
|
aux_data=aux_data, |
|
blur=blur, |
|
) |
|
self.df = self.df.reset_index(drop=True) |
|
|
|
def get_text(self, i): |
|
""" |
|
sample positive image from the same city, country if it is available |
|
otherwise, apply different crop to the image |
|
""" |
|
x = self.df.iloc[i] |
|
l = [ |
|
name.split("_")[-1] |
|
for name in [ |
|
x["unique_city"], |
|
x["unique_sub-region"], |
|
x["unique_region"], |
|
x["unique_country"], |
|
] |
|
] |
|
|
|
pre = False |
|
sentence = "An image of " |
|
if l[0] != "NaN": |
|
sentence += "the city of " |
|
sentence += l[0] |
|
pre = True |
|
|
|
if l[1] != "NaN": |
|
if pre: |
|
sentence += ", in " |
|
sentence += "the area of " |
|
sentence += l[1] |
|
pre = True |
|
|
|
if l[2] != "NaN": |
|
if pre: |
|
sentence += ", in " |
|
sentence += "the region of " |
|
sentence += l[2] |
|
pre = True |
|
|
|
if l[3] != "NaN": |
|
if pre: |
|
sentence += ", in " |
|
sentence += l[3] |
|
|
|
return sentence |
|
|
|
def __getitem__(self, i): |
|
output = super().__getitem__(i) |
|
output["text"] = self.get_text(i) |
|
return output |
|
|
|
|
|
|
|
|
|
|
class Baseline(Dataset): |
|
def __init__( |
|
self, |
|
path, |
|
which, |
|
transforms, |
|
): |
|
"""Initializes the dataset. |
|
Args: |
|
path (str): path to the dataset |
|
            which (str): which baseline to use (im2gps, im2gps3k, yfcc4k)
|
transforms (torchvision.transforms): transforms to apply to the images |
|
""" |
|
baselines = { |
|
"im2gps": self.load_im2gps, |
|
"im2gps3k": self.load_im2gps, |
|
"yfcc4k": self.load_yfcc4k, |
|
} |
|
self.path = path |
|
self.samples = baselines[which]() |
|
self.transforms = transforms |
|
self.collate_fn = collate_fn |
|
self.class_name = which |
|
|
|
def load_im2gps( |
|
self, |
|
): |
|
json_path = join(self.path, "info.json") |
|
with open(json_path) as f: |
|
data = json.load(f) |
|
|
|
samples = [] |
|
for f in os.listdir(join(self.path, "images")): |
|
            if len(data[f]):
                # The trailing metadata entries hold "latitude: ..." and
                # "longitude: ..." strings.
                lat = float(data[f][-4].replace("latitude: ", ""))
                lon = float(data[f][-3].replace("longitude: ", ""))
                samples.append((f, lat, lon))
|
|
|
return samples |
|
|
|
def load_yfcc4k( |
|
self, |
|
): |
|
samples = [] |
|
with open(join(self.path, "info.txt")) as f: |
|
lines = f.readlines() |
|
        for line in lines:
            x = line.split("\t")
            # Tab-separated fields: index 1 is the photo id, indices 12 and 13
            # are longitude and latitude.
            f, lon, lat = x[1], x[12], x[13]
            samples.append((f + ".jpg", float(lat), float(lon)))
|
|
|
return samples |
|
|
|
def __getitem__(self, i): |
|
"""Returns an item from the dataset. |
|
Args: |
|
i (int): index of the item |
|
Returns: |
|
dict: dictionary with keys "img", "gps", "idx" and optionally "label" |
|
""" |
|
img_path, lat, lon = self.samples[i] |
|
img = self.transforms( |
|
Image.open(join(self.path, "images", img_path)).convert("RGB") |
|
) |
|
lat, lon = normalize(lat, lon) |
|
gps = torch.FloatTensor([np.radians(lat), np.radians(lon)]).squeeze(0) |
|
|
|
return { |
|
"img": img, |
|
"gps": gps, |
|
"idx": i, |
|
} |
|
|
|
def __len__(self): |
|
return len(self.samples) |
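# Hypothetical usage sketch: `root` is a placeholder pointing at an im2gps3k
# dump with the images/ folder and info.json read by Baseline above.
def _baseline_example(root="/data/im2gps3k"):
    dataset = Baseline(
        root, which="im2gps3k", transforms=Compose([CenterCrop(224), ToTensor()])
    )
    return torch.utils.data.DataLoader(
        dataset, batch_size=128, collate_fn=dataset.collate_fn
    )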
|
|
|
|
|
# Identity transform: returns its input unchanged.
null_transform = lambda x: x
|
|