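# Compute image comparison metrics between a directory of ground-truth images and a
# directory of predicted images: per-image PSNR, LPIPS, and SSIM, plus dataset-level FID
# and an optional patch-wise FID. KID variants are present but commented out throughout.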
import torchmetrics
import torch
from PIL import Image
import argparse
from flair.utils import data_utils
import os
import tqdm
import torch.nn.functional as F
from torchmetrics.image.kid import KernelInceptionDistance


MAX_BATCH_SIZE = None
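# When set to an int, patch-FID updates are chunked into mini-batches of at most this
# many patches (a memory knob for the Inception forward pass); None feeds all patches
# from an image in a single update call.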

@torch.no_grad()
def main(args):
    # Determine device
    if args.device == "cuda" and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # load images
    gt_iterator = data_utils.yield_images(os.path.abspath(args.gt), size=args.resolution)
    pred_iterator = data_utils.yield_images(os.path.abspath(args.pred), size=args.resolution)
    fid_metric = torchmetrics.image.fid.FrechetInceptionDistance(normalize=True).to(device)
    # kid_metric = KernelInceptionDistance(subset_size=args.kid_subset_size, normalize=True).to(device)
    lpips_metric = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(
        net_type="alex", normalize=False, reduction="mean"
    ).to(device)
    if args.patch_size:
        patch_fid_metric = torchmetrics.image.fid.FrechetInceptionDistance(normalize=True).to(device)
        # patch_kid_metric = KernelInceptionDistance(subset_size=args.kid_subset_size, normalize=True).to(device)
    psnr_list = []
    lpips_list = []
    ssim_list = []
    # iterate over images
    for gt, pred in tqdm.tqdm(zip(gt_iterator, pred_iterator)):
        # Move tensors to the selected device
        gt = gt.to(device)
        pred = pred.to(device)

        # resize both images to the evaluation resolution if needed
        if gt.shape[-2:] != (args.resolution, args.resolution):
            gt = F.interpolate(gt, size=args.resolution, mode="area")
        if pred.shape[-2:] != (args.resolution, args.resolution):
            pred = F.interpolate(pred, size=args.resolution, mode="area")
        # to range [0,1]
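        # (the iterators appear to yield tensors in [-1, 1], hence the x * 0.5 + 0.5 remap;
        # LPIPS below is instead fed the original [-1, 1] tensors with normalize=False)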
        gt_norm = gt * 0.5 + 0.5
        pred_norm = pred * 0.5 + 0.5
        # compute PSNR
        psnr = torchmetrics.functional.image.peak_signal_noise_ratio(
            pred_norm, gt_norm, data_range=1.0
        )
        psnr_list.append(psnr.cpu()) # Move result to CPU
        # compute LPIPS
        lpips_score = lpips_metric(pred.clip(-1,1), gt.clip(-1,1))
        lpips_list.append(lpips_score.cpu()) # Move result to CPU
        # compute SSIM
        ssim = torchmetrics.functional.image.structural_similarity_index_measure(
            pred_norm, gt_norm, data_range=1.0
        )
        ssim_list.append(ssim.cpu()) # Move result to CPU
        print(f"PSNR: {psnr}, LPIPS: {lpips_score}, SSIM: {ssim}")
        # compute FID
        # Ensure inputs are on the correct device (already handled by moving gt/pred earlier)
        fid_metric.update(gt_norm, real=True)
        fid_metric.update(pred_norm, real=False)
        # compute KID
        # kid_metric.update(pred, real=False)
        # kid_metric.update(gt, real=True)
        # compute Patchwise FID/KID if patch_size is specified
        if args.patch_size:
            # Extract patches
            patch_size = args.patch_size
            gt_patches = F.unfold(gt_norm, kernel_size=patch_size, stride=patch_size)
            pred_patches = F.unfold(pred_norm, kernel_size=patch_size, stride=patch_size)
            # Reshape patches: (B, C*P*P, N_patches) -> (B*N_patches, C, P, P)
            B, C, H, W = gt.shape
            N_patches = gt_patches.shape[-1]
            gt_patches = gt_patches.permute(0, 2, 1).reshape(B * N_patches, C, patch_size, patch_size)
            pred_patches = pred_patches.permute(0, 2, 1).reshape(B * N_patches, C, patch_size, patch_size)
            # Update patch FID metric (inputs are already on the correct device)
            # Update patch KID metric
            # process mini batches of patches
            if MAX_BATCH_SIZE is None:
                patch_fid_metric.update(pred_patches, real=False)
                patch_fid_metric.update(gt_patches, real=True)
                # patch_kid_metric.update(pred_patches, real=False)
                # patch_kid_metric.update(gt_patches, real=True)
            else:
                # iterate over all extracted patches (B * N_patches), not just N_patches
                total_patches = pred_patches.shape[0]
                for i in range(0, total_patches, MAX_BATCH_SIZE):
                    patch_fid_metric.update(pred_patches[i:i + MAX_BATCH_SIZE], real=False)
                    patch_fid_metric.update(gt_patches[i:i + MAX_BATCH_SIZE], real=True)
                    # patch_kid_metric.update(pred_patches[i:i + MAX_BATCH_SIZE], real=False)
                    # patch_kid_metric.update(gt_patches[i:i + MAX_BATCH_SIZE], real=True)

    # compute FID
    fid = fid_metric.compute()
    # compute KID
    # kid_mean, kid_std = kid_metric.compute()
    if args.patch_size:
        patch_fid = patch_fid_metric.compute()
        # patch_kid_mean, patch_kid_std = patch_kid_metric.compute()
    # compute average metrics (on CPU)
    avg_psnr = torch.mean(torch.stack(psnr_list))
    avg_lpips = torch.mean(torch.stack(lpips_list))
    avg_ssim = torch.mean(torch.stack(ssim_list))
    # compute standard deviation (on CPU)
    std_psnr = torch.std(torch.stack(psnr_list))
    std_lpips = torch.std(torch.stack(lpips_list))
    std_ssim = torch.std(torch.stack(ssim_list))
    print(f"PSNR: {avg_psnr} +/- {std_psnr}")
    print(f"LPIPS: {avg_lpips} +/- {std_lpips}")
    print(f"SSIM: {avg_ssim} +/- {std_ssim}")
    print(f"FID: {fid}") # FID is computed on the selected device, print directly
    # print(f"KID: {kid_mean} +/- {kid_std}") # KID is computed on the selected device, print directly
    if args.patch_size:
        print(f"Patch FID ({args.patch_size}x{args.patch_size}): {patch_fid}") # Patch FID is computed on the selected device, print directly
        # print(f"Patch KID ({args.patch_size}x{args.patch_size}): {patch_kid_mean} +/- {patch_kid_std}") # Patch KID is computed on the selected device, print directly
    # save to prediction folder
    out_file = os.path.join(args.pred, "fid_metrics.txt")
    with open(out_file, "w") as f:
        f.write(f"PSNR: {avg_psnr.item()} +/- {std_psnr.item()}\n") # Use .item() for scalar tensors
        f.write(f"LPIPS: {avg_lpips.item()} +/- {std_lpips.item()}\n")
        f.write(f"SSIM: {avg_ssim.item()} +/- {std_ssim.item()}\n")
        f.write(f"FID: {fid.item()}\n") # Use .item() for scalar tensors
        # f.write(f"KID: {kid_mean.item()} +/- {kid_std.item()}\n") # Use .item() for scalar tensors
        if args.patch_size:
            f.write(f"Patch FID ({args.patch_size}x{args.patch_size}): {patch_fid.item()}\n") # Use .item() for scalar tensors
            # f.write(f"Patch KID ({args.patch_size}x{args.patch_size}): {patch_kid_mean.item()} +/- {patch_kid_std.item()}\n") # Use .item() for scalar tensors


if __name__ == "__main__":
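    # Example invocation (hypothetical script name and paths):
    #   python compute_metrics.py --gt /path/to/gt_dir --pred /path/to/pred_dir \
    #       --resolution 768 --patch_size 256 --device cuda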
    parser = argparse.ArgumentParser(description="Compute metrics")
    parser.add_argument("--gt", type=str, help="Path to ground truth image")
    parser.add_argument("--pred", type=str, help="Path to predicted image")
    parser.add_argument("--resolution", type=int, default=768, help="resolution at which to evaluate")
    parser.add_argument("--patch_size", type=int, default=None, help="Patch size for Patchwise FID/KID computation (e.g., 12). If None, skip.")
    parser.add_argument("--kid_subset_size", type=int, default=1000, help="Subset size for KID computation.")
    parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda"], help="Device to run computation on (cpu or cuda)")
    args = parser.parse_args()

    main(args)