import os

from PIL import Image
from torch.utils.data import Dataset

from datasets.transforms import get_pair_transforms


def load_image(impath):
    return Image.open(impath)


def load_pairs_from_cache_file(fname, root=""):
    assert os.path.isfile(
        fname
    ), "cannot parse pairs from {:s}, file does not exist".format(fname)
    with open(fname, "r") as fid:
        lines = fid.read().strip().splitlines()
    pairs = [
        (os.path.join(root, l.split()[0]), os.path.join(root, l.split()[1]))
        for l in lines
    ]
    return pairs
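

# Expected cache-file format, as parsed above: one pair of whitespace-separated,
# root-relative image paths per line. Example (hypothetical names):
#   scene0000/000000_1.jpeg scene0000/000000_2.jpeg
#   scene0000/000017_1.jpeg scene0000/000017_2.jpeg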


def load_pairs_from_list_file(fname, root=""):
    assert os.path.isfile(
        fname
    ), "cannot parse pairs from {:s}, file does not exist".format(fname)
    with open(fname, "r") as fid:
        lines = fid.read().strip().splitlines()
    pairs = [
        (os.path.join(root, l + "_1.jpg"), os.path.join(root, l + "_2.jpg"))
        for l in lines
        if not l.startswith("#")
    ]
    return pairs
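

# Expected listing-file format, as parsed above: one pair basename per line,
# with '#'-prefixed lines ignored as comments; a basename B denotes the two
# images B_1.jpg and B_2.jpg. Example (hypothetical name):
#   000042  ->  000042_1.jpg + 000042_2.jpg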


def write_cache_file(fname, pairs, root=""):
    if len(root) > 0:
        if not root.endswith("/"):
            root += "/"
        assert os.path.isdir(root)
    s = ""
    for im1, im2 in pairs:
        if len(root) > 0:
            assert im1.startswith(root), im1
            assert im2.startswith(root), im2
        s += "{:s} {:s}\n".format(im1[len(root) :], im2[len(root) :])
    with open(fname, "w") as fid:
        fid.write(s[:-1])
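

# Minimal usage sketch for write_cache_file (hypothetical paths): all pair
# paths must live under `root`, which is stripped so the cache stays
# relocatable:
#   pairs = [("./data/habitat_release/s0/0_1.jpeg",
#             "./data/habitat_release/s0/0_2.jpeg")]
#   write_cache_file("./data/habitat_release/pairs.txt", pairs,
#                    root="./data/habitat_release")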


def parse_and_cache_all_pairs(dname, data_dir="./data/"):
    if dname == "habitat_release":
        dirname = os.path.join(data_dir, "habitat_release")
        assert os.path.isdir(dirname), (
            "cannot find folder for habitat_release pairs: " + dirname
        )
        cache_file = os.path.join(dirname, "pairs.txt")
        assert not os.path.isfile(cache_file), (
            "cache file already exists: " + cache_file
        )

        print("Parsing pairs for dataset: " + dname)
        pairs = []
        for root, dirs, files in os.walk(dirname):
            if "val" in root:  # skip validation scenes
                continue
            dirs.sort()  # sort in-place for a deterministic walk order
            pairs += [
                (
                    os.path.join(root, f),
                    os.path.join(root, f[: -len("_1.jpeg")] + "_2.jpeg"),
                )
                for f in sorted(files)
                if f.endswith("_1.jpeg")
            ]
        print("Found {:,} pairs".format(len(pairs)))
        print("Writing cache to: " + cache_file)
        write_cache_file(cache_file, pairs, root=dirname)

    else:
        raise NotImplementedError("Unknown dataset: " + dname)
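

# Directory layout assumed by the habitat_release branch above (hypothetical
# scene names): every pair is stored as <prefix>_1.jpeg / <prefix>_2.jpeg
# somewhere below data_dir/habitat_release, and any path containing "val" is
# skipped, e.g.:
#   ./data/habitat_release/train/scene0000/000000_1.jpeg
#   ./data/habitat_release/train/scene0000/000000_2.jpeg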


def dnames_to_image_pairs(dnames, data_dir="./data/"):
    """
    dnames: '+'-separated list of dataset names with image pairs,
    e.g. "habitat_release+ARKitScenes"
    """
    all_pairs = []
    for dname in dnames.split("+"):
        if dname == "habitat_release":
            dirname = os.path.join(data_dir, "habitat_release")
            assert os.path.isdir(dirname), (
                "cannot find folder for habitat_release pairs: " + dirname
            )
            cache_file = os.path.join(dirname, "pairs.txt")
            assert os.path.isfile(cache_file), (
                "cannot find cache file for habitat_release pairs, "
                "please first create the cache file, see instructions. " + cache_file
            )
            pairs = load_pairs_from_cache_file(cache_file, root=dirname)
        elif dname in ["ARKitScenes", "MegaDepth", "3DStreetView", "IndoorVL"]:
            dirname = os.path.join(data_dir, dname + "_crops")
            assert os.path.isdir(
                dirname
            ), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
            list_file = os.path.join(dirname, "listing.txt")
            assert os.path.isfile(
                list_file
            ), "cannot find list file for {:s} pairs, see instructions. {:s}".format(
                dname, list_file
            )
            pairs = load_pairs_from_list_file(list_file, root=dirname)
        else:
            raise NotImplementedError("Unknown dataset: " + dname)
        print(" {:s}: {:,} pairs".format(dname, len(pairs)))
        all_pairs += pairs
    if "+" in dnames:
        print(" Total: {:,} pairs".format(len(all_pairs)))
    return all_pairs
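

# Minimal usage sketch (assumes the relevant cache/listing files already exist
# under ./data/, see instructions): dataset names are combined with '+':
#   pairs = dnames_to_image_pairs("habitat_release+ARKitScenes")
#   im1path, im2path = pairs[0]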


class PairsDataset(Dataset):
    """Dataset of (im1, im2) image pairs from one or more pair datasets,
    with an optional joint transform applied to both images of a pair."""

    def __init__(
        self, dnames, trfs="", totensor=True, normalize=True, data_dir="./data/"
    ):
        super().__init__()
        self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
        self.transforms = get_pair_transforms(
            transform_str=trfs, totensor=totensor, normalize=normalize
        )

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, index):
        im1path, im2path = self.image_pairs[index]
        im1 = load_image(im1path)
        im2 = load_image(im2path)
        if self.transforms is not None:
            im1, im2 = self.transforms(im1, im2)
        return im1, im2
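

# Minimal usage sketch, not part of the original interface (assumes the
# habitat_release cache exists under ./data/ and that the default empty
# transform string yields to-tensor + normalize transforms suitable for
# batching):
#   from torch.utils.data import DataLoader
#   dataset = PairsDataset("habitat_release")
#   loader = DataLoader(dataset, batch_size=32, shuffle=True)
#   im1, im2 = next(iter(loader))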


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Compute and cache the list of pairs for a given dataset"
    )
    parser.add_argument(
        "--data_dir", default="./data/", type=str, help="path where data are stored"
    )
    parser.add_argument(
        "--dataset", default="habitat_release", type=str, help="name of the dataset"
    )
    args = parser.parse_args()
    parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)