import operator
import os
import traceback
from collections import Counter
from os.path import join

import country_converter as coco
import joblib
import numpy as np
import pandas as pd
import reverse_geocoder
import torch
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPModel, CLIPProcessor
from torch.utils.data import Dataset

from utils import haversine

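# GeoDataset pairs each image id in `annotation_file` with its ground-truth
# (latitude, longitude) and CLIP-preprocessed pixels; rows without a matching
# .jpg in `image_folder` are dropped.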
class GeoDataset(Dataset):
    def __init__(self, image_folder, annotation_file, tag="image_id"):
        self.image_folder = image_folder
        gt = pd.read_csv(annotation_file, dtype={tag: str})
        files = set(f.replace(".jpg", "") for f in os.listdir(image_folder))
        gt = gt[gt[tag].isin(files)]
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.gt = [
            (g[1][tag], g[1]["latitude"], g[1]["longitude"]) for g in gt.iterrows()
        ]
        self.tag = tag

    def fid(self, i):
        return self.gt[i][0]

    def latlon(self, i):
        # Return the full (latitude, longitude) pair, as the name promises.
        return self.gt[i][1], self.gt[i][2]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        fp = join(self.image_folder, self.gt[idx][0] + ".jpg")
        pil = Image.open(fp)
        proc = self.processor(images=pil, return_tensors="pt")
        proc["image_id"] = self.gt[idx][0]
        return proc

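# Embed a batch of preprocessed images with CLIP and L2-normalise the features
# so the dot products used below are cosine similarities.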
@torch.no_grad()
def compute_features_clip(img, model):
    image_ids = img.data.pop("image_id")
    image_input = img.to(model.device)
    image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
    features = model.get_image_features(**image_input)
    features /= features.norm(dim=-1, keepdim=True)
    return image_ids, features.cpu()

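# Build the four increasingly specific text prompts used as CLIP queries:
# "country", "country, region", "country, region, sub-region" and
# "country, region, sub-region, city". A level is None when its own field is
# empty or a coarser level is already missing.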
def get_prompts(country, region, sub_region, city):
    a = country if country != "" else None
    b, c, d = None, None, None
    if a is not None:
        b = country + ", " + region if region != "" else None
    if b is not None:
        c = (
            country + ", " + region + ", " + sub_region
            if sub_region != ""
            else None
        )
        d = (
            country + ", " + region + ", " + sub_region + ", " + city
            if city != ""
            else None
        )
    return a, b, c, d

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--annotation_file", type=str, required=False, default="train.csv"
    )
    parser.add_argument(
        "--features_parent", type=str, default="/home/isig/gaia-v2/faiss/street-clip"
    )
    parser.add_argument(
        "--data_parent", type=str, default="/home/isig/gaia-v2/loic-data/"
    )

    args = parser.parse_args()
    test_path_csv = join(args.data_parent, "test.csv")
    test_image_dir = join(args.data_parent, "test")
    save_path = join(args.features_parent, "indexes/test.index")
    test_features_dir = join(args.features_parent, "indexes/features-test")

    processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CLIPModel.from_pretrained("geolocal/StreetCLIP").to(device)

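    # Embed a text prompt with StreetCLIP and L2-normalise it so it is
    # directly comparable to the image features by dot product.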
    @torch.no_grad()
    def compute_text_features_clip(text):
        text_pt = processor(text=text, return_tensors="pt").to(device)
        features = model.get_text_features(**text_pt)
        features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().squeeze(0).numpy()

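    # Build (or load from cache) text embeddings for every location prompt
    # derived from the GeoNames cities1000 table, keyed by hierarchy level:
    # 0 = country, 1 = region, 2 = sub-region, 3 = city, -1 = empty control.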
    if not os.path.isfile("text_street-clip-features.pkl"):
        if not os.path.isfile("rg_cities1000.csv"):
            os.system(
                "wget https://raw.githubusercontent.com/thampiman/reverse-geocoder/master/reverse_geocoder/rg_cities1000.csv"
            )

        cities = pd.read_csv("rg_cities1000.csv")
        cities = cities[["lat", "lon", "name", "admin1", "admin2", "cc"]]
        reprs = {0: {}, 1: {}, 2: {}, 3: {}}
        for line in tqdm(
            cities.iterrows(), total=len(cities), desc="Creating hierarchy"
        ):
            lat, lon, city, region, sub_region, cc = line[1]
            try:
                city, region, sub_region, cc = [
                    ("" if pd.isna(x) else x)
                    for x in [
                        city,
                        region,
                        sub_region,
                        coco.convert(cc, to="name_short"),
                    ]
                ]
                # Identical bookkeeping at every level, so handle the four
                # prompts in one loop. The .replace() only matters at level 3,
                # where a missing sub-region leaves ", , " in the prompt text;
                # it is a no-op at the coarser levels.
                prompts = get_prompts(cc, region, sub_region, city)
                for level, prompt in enumerate(prompts):
                    if prompt is None:
                        continue
                    if prompt not in reprs[level]:
                        reprs[level][prompt] = {
                            "gps": {(lat, lon)},
                            "embedding": compute_text_features_clip(
                                prompt.replace(", , ", ", ")
                            ),
                        }
                    else:
                        reprs[level][prompt]["gps"].add((lat, lon))
            except Exception:
                # Log failures (e.g. country codes coco cannot convert)
                # instead of aborting the whole build.
                with open("log.txt", "a") as f:
                    print(traceback.format_exc(), file=f)

        # Control entry: an empty prompt that any candidate match must beat.
        reprs[-1] = {"": {"gps": (0, 0), "embedding": compute_text_features_clip("")}}

        # Collapse each prompt's set of coordinates to their mean (lat, lon).
        for i in range(4):
            for k in reprs[i].keys():
                reprs[i][k]["gps"] = tuple(
                    np.array(list(reprs[i][k]["gps"])).mean(axis=0).tolist()
                )

        joblib.dump(reprs, "text_street-clip-features.pkl")
    else:
        reprs = joblib.load("text_street-clip-features.pkl")

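    # Reverse-geocode a ground-truth (lat, lon) into the same four prompt
    # levels, so predictions and ground truth share one label space.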
    def get_loc(x):
        location = reverse_geocoder.search(x[0].tolist())[0]
        country = coco.convert(names=location["cc"], to="name_short")
        region = location.get("admin1", "")
        sub_region = location.get("admin2", "")
        city = location.get("name", "")
        return get_prompts(country, region, sub_region, city)

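    # Find the best-scoring prompt at one hierarchy level (restricted to
    # children of the parent prompt via the `sw` prefix); fall back to the
    # control entry (the parent prompt, or the empty prompt at the top level)
    # when no candidate beats its similarity.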
    def matches(embed, level_reprs, control, gt, sw=None):
        first_max = max(
            (
                (k, embed.dot(v["embedding"]))
                for k, v in level_reprs.items()
                if sw is None or k.startswith(sw)
            ),
            key=operator.itemgetter(1),
        )
        if first_max[1] > embed.dot(control["embedding"]):
            return level_reprs[first_max[0]]["gps"], gt == first_max[0]
        return control["gps"], False

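    # Walk the hierarchy top-down for one query embedding, counting attempts
    # (N) and correct matches (pos) per level, then score the final GPS
    # estimate against the ground truth with the haversine metrics.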
    def get_match_values(gt, embed, N, pos):
        xa, xb, xc, xd = get_loc(gt)

        if xa is not None:
            N["country"] += 1
            gps, flag = matches(embed, reprs[0], reprs[-1][""], xa)
            if flag:
                pos["country"] += 1
            if xb is not None:
                N["region"] += 1
                gps, flag = matches(embed, reprs[1], reprs[0][xa], xb, sw=xa)
                if flag:
                    pos["region"] += 1
                if xc is not None:
                    N["sub-region"] += 1
                    gps, flag = matches(embed, reprs[2], reprs[1][xb], xc, sw=xb)
                    if flag:
                        pos["sub-region"] += 1
                    if xd is not None:
                        N["city"] += 1
                        gps, flag = matches(embed, reprs[3], reprs[2][xc], xd, sw=xc)
                        if flag:
                            pos["city"] += 1
                elif xd is not None:
                    # No sub-region prompt for this location: match the city
                    # against everything under the region prefix instead.
                    N["city"] += 1
                    gps, flag = matches(
                        embed, reprs[3], reprs[1][xb], xd, sw=xb + ", "
                    )
                    if flag:
                        pos["city"] += 1

        haversine(np.array(gps)[None, :], np.array(gt), N, pos)

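    # Normalise the raw counters into rates and print per-level accuracy plus
    # the distance metrics accumulated by `haversine`.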
    def compute_print_accuracy(N, pos):
        for k in N.keys():
            pos[k] /= N[k]

        print(
            f'Accuracy: {pos["country"] * 100.0:.2f} (country), '
            f'{pos["region"] * 100.0:.2f} (region), '
            f'{pos["sub-region"] * 100.0:.2f} (sub-region), '
            f'{pos["city"] * 100.0:.2f} (city)'
        )
        print(
            f'Haversine: {pos["haversine"]:.2f} (haversine), '
            f'{pos["geoguessr"]:.2f} (geoguessr)'
        )

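    # Evaluation driver: load the ground-truth table, keep only the ids listed
    # in test3_indices.txt, and score each precomputed StreetCLIP image
    # feature.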
    # `data` is unused in the loop below, which reads precomputed .npy
    # features rather than the dataset directly.
    data = GeoDataset(test_image_dir, test_path_csv, tag="id")
    test_gt = pd.read_csv(test_path_csv, dtype={"id": str})[
        ["id", "latitude", "longitude"]
    ]
    test_gt = {
        g[1]["id"]: np.array([g[1]["latitude"], g[1]["longitude"]])
        for g in tqdm(test_gt.iterrows(), total=len(test_gt), desc="Loading test_gt")
    }

    with open("/home/isig/gaia-v2/loic/plonk/test3_indices.txt", "r") as f:
        keep_ids = {line.strip() for line in f}

    N, pos = Counter(), Counter()
    for f in tqdm(os.listdir(test_features_dir)):
        if f.replace(".npy", "") not in keep_ids:
            continue
        query_vector = np.squeeze(np.load(join(test_features_dir, f)))
        test_gps = test_gt[f.replace(".npy", "")][None, :]
        get_match_values(test_gps, query_vector, N, pos)

    compute_print_accuracy(N, pos)