"""Extract DINOv2 embeddings for a dataset of images.

The work can be sharded across several jobs via --number_of_splits / --split_index.
"""

import argparse
import os
import sys
from pathlib import Path

# Make the repository root importable so that utils/ and data/ resolve.
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(root_dir)

import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

from data.extract_embeddings.dataset_with_path import ImageWithPathDataset
from utils.image_processing import CenterCrop

parser = argparse.ArgumentParser(description="Extract DINOv2 embeddings for a dataset of images.")
parser.add_argument(
    "--number_of_splits",
    type=int,
    help="Number of splits to divide the dataset into",
    default=1,
)
parser.add_argument(
    "--split_index",
    type=int,
    help="Index of the split to process",
    default=0,
)
parser.add_argument(
    "--input_path",
    type=str,
    required=True,
    help="Path to the input dataset",
)
parser.add_argument(
    "--output_path",
    type=str,
    required=True,
    help="Path to the output directory for embeddings",
)

args = parser.parse_args()
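
# Example invocation (script name and paths are illustrative):
#   python extract_dinov2_embeddings.py --input_path /data/images \
#       --output_path /data/embeddings --number_of_splits 4 --split_index 0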

# Run on GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# DINOv2 ViT-L/14 with registers, loaded from torch.hub.
model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
model.to(device)
model.eval()
# Compile once the model sits on its target device; compilation itself is lazy
# and happens on the first forward pass.
model = torch.compile(model, mode="max-autotune")
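
# Each forward pass returns one global image embedding per input
# (1024-dimensional for the ViT-L/14 backbone).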

input_path = Path(args.input_path)
output_path = Path(args.output_path)
output_path.mkdir(exist_ok=True, parents=True)

# DINOv2 preprocessing: square center crop, bicubic resize to 336 px
# (a multiple of the 14-px patch size), then ImageNet normalization.
augmentation = transforms.Compose(
    [
        CenterCrop(ratio="1:1"),
        transforms.Resize(336, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)
dataset = ImageWithPathDataset(input_path, output_path, transform=augmentation)
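
# Note: ImageWithPathDataset (from data/extract_embeddings/) is assumed to yield
# (image, output_embedding_path) pairs, mapping each input image to the location
# its embedding will be written to.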

# Keep only this job's contiguous shard of the dataset; the last shard also
# absorbs any remainder left by the integer division.
split_start = args.split_index * len(dataset) // args.number_of_splits
split_end = (
    (args.split_index + 1) * len(dataset) // args.number_of_splits
    if args.split_index != args.number_of_splits - 1
    else len(dataset)
)
dataset = torch.utils.data.Subset(dataset, range(split_start, split_end))
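
# With, e.g., --number_of_splits 4, four jobs with split_index 0..3 together
# cover the dataset exactly once.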

batch_size = 128


def collate_keep_paths(batch):
    """Collate (image, path) pairs into a tuple of images and a tuple of paths."""
    # A top-level function (rather than a lambda) stays picklable when the
    # DataLoader workers use the spawn start method.
    return tuple(zip(*batch))


dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, num_workers=16, collate_fn=collate_keep_paths
)

# Embed each batch and write one <output_emb_path>.npy file per image.
for images, output_emb_paths in tqdm(dataloader):
    images = torch.stack(images, dim=0).to(device)
    with torch.no_grad():
        embeddings = model(images)
    numpy_embeddings = embeddings.cpu().numpy()
    for emb, output_emb_path in zip(numpy_embeddings, output_emb_paths):
        np.save(f"{output_emb_path}.npy", emb)