|
import io
from pathlib import Path

import pandas as pd
import torch
import torchvision.transforms as transforms
import webdataset as wds
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from utils.image_processing import CenterCrop

# Enable df.progress_apply with a tqdm progress bar.
tqdm.pandas()
|
|
print("Loading dinov2") |
|
augmentation_dinov2 = transforms.Compose( |
|
[ |
|
CenterCrop(ratio="1:1"), |
|
transforms.Resize(336, interpolation=transforms.InterpolationMode.BICUBIC), |
|
transforms.ToTensor(), |
|
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), |
|
] |
|
) |
|
|
|
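# DINOv2 ViT-L/14 with register tokens, fetched from torch.hub.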
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg") |
|
model.eval() |
|
model.to(device) |
|
print(f"Model loaded on {device}") |
|
|
|
|
|
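# Assumed input schema: a tab-separated metadata file with at least `hash`,
# `photo_id`, `latitude`, and `longitude` columns (as in the YFCC100M dumps),
# with images sharded on disk as <root>/hash[:3]/hash[3:6]/<hash>.jpg.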
class YFCCDataset(Dataset):
    def __init__(self, csv_path, images_root):
        self.df = pd.read_csv(csv_path, sep="\t")
        # Keep only geotagged rows.
        self.df = self.df[self.df["latitude"].notna() & self.df["longitude"].notna()]
        self.images_root = Path(images_root)

        # Precompute image paths; existence is checked lazily in __getitem__.
        print("Building image paths...")
        self.df["image_path"] = self.df["hash"].progress_apply(
            lambda x: self.images_root / x[:3] / x[3:6] / f"{x}.jpg"
        )
|
    def __len__(self):
        return len(self.df)
|
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row["image_path"]

        if not image_path.exists():
            print(f"Image {image_path} does not exist")
            return None

        # Read the raw bytes once: they are stored verbatim in the shard and
        # also decoded for the model input.
        with open(image_path, "rb") as f:
            jpg_data = f.read()

        try:
            image = Image.open(io.BytesIO(jpg_data)).convert("RGB")
        except OSError:
            # Skip files that exist but cannot be decoded (e.g. truncated downloads).
            print(f"Image {image_path} could not be decoded")
            return None
        image = augmentation_dinov2(image)

        # Drop the Path object: it is not JSON-serializable.
        metadata = row.to_dict()
        del metadata["image_path"]

        return {
            "image": image,
            "jpg_data": jpg_data,
            "photo_id": str(row["photo_id"]),
            "metadata": metadata,
        }
|
|
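# The dataset can return None for missing or unreadable images, and the raw
# JPEG bytes and metadata dicts should stay plain Python lists, so batching
# uses a custom collate function instead of the default one.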
def custom_collate(batch):
    """Collate dict samples from the dataset, dropping items returned as None."""
    batch = [item for item in batch if item is not None]
    if not batch:
        # Every sample in this batch was missing or unreadable.
        return None
    return {
        "image": torch.stack([item["image"] for item in batch]),
        "jpg_data": [item["jpg_data"] for item in batch],
        "photo_id": [item["photo_id"] for item in batch],
        "metadata": [item["metadata"] for item in batch],
    }
|
|
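# Embed one batch with DINOv2 and pair each embedding with its original JPEG
# bytes and metadata row, ready to be written as a WebDataset sample.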
def process_batch(batch, model, device):
    images = batch["image"].to(device)
    with torch.no_grad():
        embeddings = model(images).cpu().numpy()

    # WebDataset picks the encoder from each key's extension: raw bytes pass
    # through for "jpg", arrays are saved for ".npy", dicts are JSON-encoded.
    samples = []
    for i in range(len(batch["photo_id"])):
        sample = {
            "__key__": batch["photo_id"][i],
            "jpg": batch["jpg_data"][i],
            "dinov2_vitl14_registers.npy": embeddings[i],
            "json": batch["metadata"][i],
        }
        samples.append(sample)
    return samples
|
|
def main(
    src_csv,
    src_images,
    dest_folder,
    num_samples_per_tar=10000,
    job_offset=0,
    batch_size=32,
):
    print("Loading dataset")
    dataset = YFCCDataset(src_csv, src_images)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        collate_fn=custom_collate,
    )

    print(f"Processing job {job_offset} with {len(dataset)} samples")
    # Each job writes into its own shard-number range; start_shard=10 * job_offset
    # assumes a single job never produces more than 10 shards.
    with wds.ShardWriter(
        str(Path(dest_folder) / "%04d.tar"),
        maxcount=num_samples_per_tar,
        start_shard=10 * job_offset,
    ) as sink:
        for batch in tqdm(dataloader):
            if batch is None:
                continue
            samples = process_batch(batch, model, device)
            for sample in samples:
                sink.write(sample)
|
|
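# Example invocation (hypothetical paths; script name assumed):
#   python embed_yfcc_webdataset.py --src_csv_dir /data/csv --src_images_dir /data/images \
#       --dest /data/shards --job_offset 0 --batch_size 256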
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--src_csv_dir", help="directory of per-job metadata CSVs (tab-separated)")
    parser.add_argument("--src_images_dir", help="path to source images")
    parser.add_argument("--dest", help="path to destination WebDataset folder")
    parser.add_argument(
        "--num_samples_per_tar",
        help="number of samples per tar",
        type=int,
        default=10000,
    )
    parser.add_argument("--job_offset", help="job offset", type=int, default=0)
    parser.add_argument("--batch_size", help="batch size", type=int, default=256)
    args = parser.parse_args()

    dest = Path(args.dest)
    dest.mkdir(exist_ok=True, parents=True)

    # Each job processes one zero-padded CSV shard, e.g. 000.csv, 001.csv, ...
    main(
        Path(args.src_csv_dir) / f"{args.job_offset:03d}.csv",
        args.src_images_dir,
        args.dest,
        args.num_samples_per_tar,
        args.job_offset,
        args.batch_size,
    )
|