|
import sys |
|
from pathlib import Path |
|
|
|
sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) |
|
import argparse |
|
import json |
|
from collections import UserDict |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import torch |
|
import webdataset as wds |
|
from PIL import Image |
|
from torchvision import transforms |
|
from tqdm import tqdm |
|
from webdataset.autodecode import ImageHandler |
|
from utils.image_processing import CenterCrop |
|
|
|
print("Loading dinov2")

# Preprocessing pipeline applied to every image before it is fed to DINOv2.
augmentation_dinov2 = transforms.Compose(
    [
        # Project-local crop helper; presumably a square centre crop -- confirm
        # against utils.image_processing.CenterCrop.
        CenterCrop(ratio="1:1"),
        # 336 is a multiple of 14, the patch size implied by the "vitl14" model name.
        transforms.Resize(336, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        # These values match the commonly used ImageNet channel statistics.
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Fetch the DINOv2 ViT-L/14 (with registers) checkpoint through torch.hub,
# switch it to inference mode and move it to the selected device.
dinov2_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
dinov2_model = dinov2_model.eval().to(device)
print(f"Model loaded on {device}")
|
|
|
|
|
def dict_collate(batch):
    """Collate a list of samples, recursing into dictionaries.

    The type of the first element decides how the whole batch is handled:

    * dict -- the result is a dict with the same keys as the first sample.
      Values stored under the ``"json"`` key are gathered into a plain list;
      values under every other key are collated recursively.
    * ``PIL.Image.Image`` -- the images are returned as a plain list.
    * anything else -- delegated to torch's ``default_collate``.

    Args:
        batch: Non-empty sequence of samples of a homogeneous type.

    Returns:
        The collated batch (dict, list, or whatever ``default_collate`` yields).
    """
    first = batch[0]

    if not isinstance(first, dict):
        if isinstance(first, Image.Image):
            # PIL images cannot be stacked into a tensor; keep them as a list.
            return list(batch)
        return torch.utils.data.dataloader.default_collate(batch)

    collated = {}
    for field in first:
        column = [sample[field] for sample in batch]
        # "json" entries are kept as-is (one entry per sample), never collated.
        collated[field] = column if field == "json" else dict_collate(column)
    return collated
|
|
|
|
|
def log_and_continue(exn):
    """Log ``exn`` as a warning and signal that processing should continue.

    Presumably meant to be passed as a ``handler=`` callback to webdataset
    pipeline stages (it is not referenced in the visible part of this file);
    confirm against the caller.

    Args:
        exn: The exception that was caught.

    Returns:
        Always ``True``.
    """
    # Local import keeps this block self-contained.
    import logging

    # Previously the exception was dropped without a trace, although the
    # docstring promised a warning; make the failure visible in the logs.
    logging.getLogger(__name__).warning("Ignoring exception and continuing: %r", exn)
    return True
|
|
|
|
|
# Rebinds the module-level `device` (a plain "cuda"/"cpu" string up to this
# point) to the equivalent torch.device object.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
def add_clip_scores_and_embeddings(
    src, dest, batch_size=512, *, num_workers=8, expected_samples=10000
):
    """Copy a WebDataset shard to ``dest``, adding a DINOv2 embedding per sample.

    Despite its name, this function computes no CLIP scores: it reads each
    sample's ``jpg``, ``street_clip.npy`` and ``json`` entries, runs the image
    through the module-level ``dinov2_model`` and writes the original entries
    back together with a new ``dinov2_vitl14_registers.npy`` entry.

    Args:
        src: Path (or URL) of the source ``.tar`` shard.
        dest: Path of the ``.tar`` shard to create.
        batch_size: Number of samples per forward pass of the model.
        num_workers: Number of loader worker processes.
        expected_samples: Estimated number of samples in the shard. Only used
            to size the progress bar; it does not limit how much is processed.
    """
    dataset = wds.DataPipeline(
        wds.SimpleShardList(str(src)),
        wds.split_by_worker,
        wds.tarfile_to_samples(),
        # The "jpg" entry is exposed under two names: `dino_image` is decoded
        # and preprocessed for the model, `image` is carried along to be
        # written back to the output shard.
        wds.rename(
            __key__="__key__",
            dino_image="jpg",
            image="jpg",
            street_clip="street_clip.npy",
            json="json",
        ),
        # Only `dino_image` is handed to the image decoder; the other fields
        # are presumably left as raw bytes -- confirm against the installed
        # webdataset version.
        wds.decode(ImageHandler("pilrgb", ["dino_image"])),
        wds.map_dict(
            dino_image=augmentation_dinov2,
            image=lambda x: x,
            street_clip=lambda x: x,
            json=lambda x: x,
        ),
        wds.to_tuple("__key__", "dino_image", "street_clip", "image", "json"),
        wds.batched(batch_size),
    )
    loader = wds.WebLoader(dataset, num_workers=num_workers, batch_size=None)

    # Ceiling division: a trailing partial batch still counts as one step
    # (wds.batched keeps partial batches by default).
    total_batches = (expected_samples + batch_size - 1) // batch_size

    with wds.TarWriter(str(dest)) as sink:
        # `metadata` (rather than `json`) avoids shadowing the imported json module.
        for keys, dino_image, street_clip, image, metadata in tqdm(
            loader, total=total_batches
        ):
            with torch.no_grad():
                dino_embedding = dinov2_model(dino_image.to(device)).cpu().numpy()

            for key, jpg, clip, meta, embedding in zip(
                keys, image, street_clip, metadata, dino_embedding
            ):
                sink.write(
                    {
                        "__key__": key,
                        "jpg": jpg,
                        "street_clip.npy": clip,
                        "json": meta,
                        "dinov2_vitl14_registers.npy": embedding,
                    }
                )
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: pick the `--shard_id`-th ``*.tar`` file (in sorted
    # order) from `--src`, and write a copy enriched with DINOv2 embeddings
    # under the same file name in `--dest`.
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", required=True, help="path to source files")
    parser.add_argument("--dest", required=True, help="path to destination files")
    # Parsed as int by argparse so that a malformed or missing id yields a
    # usage error instead of a traceback.
    parser.add_argument("--shard_id", type=int, required=True, help="shard id")
    args = parser.parse_args()

    src = Path(args.src)
    # Sorted so that a given shard id maps to the same file on every run.
    list_of_shards = sorted(src.glob("*.tar"))
    try:
        shard = list_of_shards[args.shard_id].name
    except IndexError:
        parser.error(
            f"--shard_id {args.shard_id} is out of range: "
            f"found {len(list_of_shards)} shard(s) in {src}"
        )

    dest = Path(args.dest)
    dest.mkdir(exist_ok=True, parents=True)
    batch_size = 256

    print(f"Loading {shard}")

    src_shard = src / shard
    dest_shard = dest / shard

    print(f"Processing {src_shard} to {dest_shard}")
    add_clip_scores_and_embeddings(src_shard, dest_shard, batch_size)
|
|