# VTBench / evaluations / evaluate_images.py
# huaweilin's picture
# update
# 14ce5a9
import os
import argparse
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn.functional as F
from ocr import OCR
from character_error_rate import CharacterErrorRate
from word_error_rate import WordErrorRate
from torchmetrics.image import (
PeakSignalNoiseRatio,
StructuralSimilarityIndexMeasure,
LearnedPerceptualImagePatchSimilarity,
FrechetInceptionDistance,
)
class ImageFolderPairDataset(Dataset):
    """Dataset yielding same-named image pairs from two directories.

    The sorted list of regular files in ``dir1`` defines the items. Each item
    opens the file with that name in both ``dir1`` and ``dir2``, converts both
    images to RGB and, when a transform was supplied, applies it to each image
    independently.

    Args:
        dir1: Directory whose file listing defines the set of file names.
        dir2: Directory that must contain a file for every name found in
            ``dir1``.
        transform: Optional callable applied to each loaded PIL image.

    Raises:
        FileNotFoundError: If a file in ``dir1`` has no counterpart in
            ``dir2``. Raised at construction time so the problem surfaces
            before iteration starts rather than part-way through it.
    """

    def __init__(self, dir1, dir2, transform=None):
        self.dir1 = dir1
        self.dir2 = dir2
        # Sub-directories cannot be opened as images, so keep regular files only.
        self.filenames = sorted(
            name
            for name in os.listdir(dir1)
            if os.path.isfile(os.path.join(dir1, name))
        )
        missing = [
            name
            for name in self.filenames
            if not os.path.isfile(os.path.join(dir2, name))
        ]
        if missing:
            preview = ", ".join(missing[:5])
            raise FileNotFoundError(
                f"{len(missing)} file(s) from {dir1!r} have no counterpart in "
                f"{dir2!r} (first few: {preview})"
            )
        self.transform = transform

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return the ``(dir1 image, dir2 image)`` pair at position ``idx``."""
        name = self.filenames[idx]
        img1 = self._load_rgb(os.path.join(self.dir1, name))
        img2 = self._load_rgb(os.path.join(self.dir2, name))
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file afterwards."""
        # convert() returns a new image object, so it stays usable after the
        # context manager has closed the underlying file.
        with Image.open(path) as image:
            return image.convert("RGB")
def evaluate(args):
    """Compute the requested metrics over paired original/reconstructed images.

    Images with the same file name are loaded from ``args.original_dir`` and
    ``args.reconstructed_dir``, resized to ``args.image_size`` squared and
    converted to tensors. Every metric named in ``args.metrics`` is updated
    batch by batch, and the final values are printed as a tab-separated
    header row followed by a tab-separated value row.

    Args:
        args: Namespace providing ``original_dir``, ``reconstructed_dir``,
            ``metrics`` (iterable of metric names), ``batch_size``,
            ``image_size`` and ``num_workers``.

    Supported metric names are ``psnr``, ``ssim``, ``lpips``, ``fid``,
    ``cer`` and ``wer``. Any other name is reported and ignored.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ToTensor() yields float images in [0, 1]; the metric settings below
    # are chosen to match that range.
    transform = transforms.Compose(
        [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor()]
    )
    dataset = ImageFolderPairDataset(
        args.original_dir, args.reconstructed_dir, transform
    )
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
    )

    supported = ("psnr", "ssim", "lpips", "fid", "cer", "wer")
    requested = set(args.metrics)
    unknown = sorted(requested.difference(supported))
    if unknown:
        # Previously these were dropped without any feedback.
        print(f"Warning: ignoring unknown metrics: {', '.join(unknown)}")

    # The OCR model is only built when a text metric needs it.
    ocr = OCR(device) if requested & {"cer", "wer"} else None

    # Metrics init (insertion order determines the printed column order).
    metrics = {}
    if "psnr" in requested:
        # Fix the dynamic range instead of letting it be inferred from the data.
        metrics["psnr"] = PeakSignalNoiseRatio(data_range=1.0).to(device)
    if "ssim" in requested:
        metrics["ssim"] = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    if "lpips" in requested:
        # normalize=True declares [0, 1] inputs; the default expects [-1, 1].
        metrics["lpips"] = LearnedPerceptualImagePatchSimilarity(normalize=True).to(
            device
        )
    if "fid" in requested:
        metrics["fid"] = FrechetInceptionDistance().to(device)
    if "cer" in requested:
        metrics["cer"] = CharacterErrorRate(ocr)
    if "wer" in requested:
        metrics["wer"] = WordErrorRate(ocr)

    for batch in tqdm(loader, desc="Evaluating"):
        img1, img2 = [b.to(device) for b in batch]

        # Paired metrics take (prediction, target) = (reconstructed, original).
        for name in ("psnr", "ssim", "lpips", "cer", "wer"):
            if name in metrics:
                metrics[name].update(img2, img1)

        if "fid" in metrics:
            # FID is fed uint8 images here, so rescale from [0, 1] to [0, 255].
            img1_uint8 = (img1 * 255).clamp(0, 255).to(torch.uint8)
            img2_uint8 = (img2 * 255).clamp(0, 255).to(torch.uint8)
            metrics["fid"].update(img1_uint8, real=True)
            metrics["fid"].update(img2_uint8, real=False)

    print("\nResults:")
    for name in metrics:
        print(f"{name.upper()}", end="\t")
    print()
    for metric in metrics.values():
        result = metric.compute().item()
        print(f"{result:.4f}", end="\t")
    print()
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--original_dir", type=str, required=True, help="Path to original images"
    )
    parser.add_argument(
        "--reconstructed_dir",
        type=str,
        required=True,
        help="Path to reconstructed images",
    )
    parser.add_argument(
        "--metrics",
        nargs="+",
        default=["psnr", "ssim", "lpips", "fid"],
        # cer and wer are handled by evaluate() but were missing from the help.
        help="Metrics to compute: psnr, ssim, lpips, fid, cer, wer",
    )
    parser.add_argument(
        "--batch_size", type=int, default=8, help="Batch size for processing"
    )
    parser.add_argument("--image_size", type=int, default=256, help="Image resize size")
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of workers for DataLoader"
    )
    args = parser.parse_args()
    evaluate(args)