import argparse
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from torchmetrics.image import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
    LearnedPerceptualImagePatchSimilarity,
    FrechetInceptionDistance,
)

from ocr import OCR
from character_error_rate import CharacterErrorRate
from word_error_rate import WordErrorRate
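
# Project-local helpers: OCR wraps the text recognizer that drives the text
# metrics; CharacterErrorRate and WordErrorRate are assumed to follow the same
# update()/compute() protocol as the torchmetrics objects imported above,
# since the evaluation loop below treats them interchangeably.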


class ImageFolderPairDataset(Dataset):
    """Yields (original, reconstructed) image pairs matched by filename.

    Assumes both directories contain the same set of identically named
    image files.
    """

    def __init__(self, dir1, dir2, transform=None):
        self.dir1 = dir1
        self.dir2 = dir2
        self.filenames = sorted(os.listdir(dir1))
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        name = self.filenames[idx]
        img1 = Image.open(os.path.join(self.dir1, name)).convert("RGB")
        img2 = Image.open(os.path.join(self.dir2, name)).convert("RGB")
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2


def evaluate(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    transform = transforms.Compose(
        [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor()]
    )

    dataset = ImageFolderPairDataset(
        args.original_dir, args.reconstructed_dir, transform
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    # The OCR model is only needed for the text-based metrics.
    if "cer" in args.metrics or "wer" in args.metrics:
        ocr = OCR(device)

    metrics = {}
if "psnr" in args.metrics: |
|
metrics["psnr"] = PeakSignalNoiseRatio().to(device) |
|
if "ssim" in args.metrics: |
|
metrics["ssim"] = StructuralSimilarityIndexMeasure().to(device) |
|
if "lpips" in args.metrics: |
|
metrics["lpips"] = LearnedPerceptualImagePatchSimilarity().to(device) |
|
if "fid" in args.metrics: |
|
metrics["fid"] = FrechetInceptionDistance().to(device) |
|
if "cer" in args.metrics: |
|
metrics["cer"] = CharacterErrorRate(ocr) |
|
if "wer" in args.metrics: |
|
metrics["wer"] = WordErrorRate(ocr) |
    for batch in tqdm(loader, desc="Evaluating"):
        img1, img2 = [b.to(device) for b in batch]

        # update(preds, target): img2 is the reconstruction, img1 the original.
        if "psnr" in metrics:
            metrics["psnr"].update(img2, img1)
        if "ssim" in metrics:
            metrics["ssim"].update(img2, img1)
        if "lpips" in metrics:
            metrics["lpips"].update(img2, img1)
        if "cer" in metrics:
            metrics["cer"].update(img2, img1)
        if "wer" in metrics:
            metrics["wer"].update(img2, img1)
        if "fid" in metrics:
            # FrechetInceptionDistance expects uint8 images in [0, 255] by
            # default, so rescale before updating the real/fake statistics.
            img1_uint8 = (img1 * 255).clamp(0, 255).to(torch.uint8)
            img2_uint8 = (img2 * 255).clamp(0, 255).to(torch.uint8)
            metrics["fid"].update(img1_uint8, real=True)
            metrics["fid"].update(img2_uint8, real=False)
print("\nResults:") |
|
for name, metric in metrics.items(): |
|
print(f"{name.upper()}", end="\t") |
|
print() |
|
for name, metric in metrics.items(): |
|
result = metric.compute().item() |
|
print(f"{result:.4f}", end="\t") |
|
print() |
|
|
|
|
|


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--original_dir", type=str, required=True, help="Path to original images"
    )
    parser.add_argument(
        "--reconstructed_dir",
        type=str,
        required=True,
        help="Path to reconstructed images",
    )
    parser.add_argument(
        "--metrics",
        nargs="+",
        default=["psnr", "ssim", "lpips", "fid"],
        help="Metrics to compute: psnr, ssim, lpips, fid, cer, wer",
    )
    parser.add_argument(
        "--batch_size", type=int, default=8, help="Batch size for processing"
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=256,
        help="Side length to which images are resized",
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of workers for DataLoader"
    )
    args = parser.parse_args()

    evaluate(args)
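
# Example invocation, assuming this file is saved as evaluate.py (the
# directory paths are placeholders):
#   python evaluate.py --original_dir data/originals \
#       --reconstructed_dir data/reconstructions \
#       --metrics psnr ssim lpips fid cer wer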