File size: 1,831 Bytes
587665f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import itertools
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS
from utils.utils import denorm

def compute_lpips_variability(samples: torch.Tensor,
                              net: str = 'alex',
                              device: str = 'cuda'
                              ) -> float:
    """Return the mean pairwise LPIPS distance between the images in *samples*.

    Every unordered pair of samples (indexed along dim 0) is compared once and
    the resulting distances are averaged.

    Args:
        samples: Batch of images, one per index along dim 0. If the minimum of
            the whole batch is non-negative, the batch is rescaled with
            ``x * 2 - 1`` (i.e. it is treated as [0, 1] data); otherwise it is
            used unchanged. Values are clamped to [-1, 1] before being passed
            to ``denorm``.
        net: Backbone name forwarded to LPIPS as ``net_type``.
        device: Device on which the metric and the images are evaluated.

    Returns:
        Arithmetic mean of the LPIPS distance over all sample pairs.

    Raises:
        ValueError: If fewer than two samples are given (no pair exists).
    """
    num_samples = samples.size(0)
    # Validate before building the network: with no pairs the average is
    # undefined, and the old code failed late with a ZeroDivisionError.
    if num_samples < 2:
        raise ValueError(
            "compute_lpips_variability needs at least 2 samples, "
            f"got {num_samples}"
        )

    loss_fn = LPIPS(net_type=net).to(device)
    loss_fn.eval()

    # Heuristic range detection: a non-negative batch is assumed to be in
    # [0, 1] and is converted to [-1, 1].
    if samples.min() >= 0.0:
        samples = samples * 2 - 1

    # NOTE(review): `denorm` is defined outside this file. If it maps
    # [-1, 1] -> [0, 1], LPIPS should probably be built with normalize=True,
    # since torchmetrics expects [-1, 1] inputs by default - confirm.
    with torch.no_grad():  # evaluation only; no autograd graph needed
        # Transfer, clamp and denorm each sample once rather than once per
        # pair (assumes `denorm` is a pure function of its input).
        bounded = samples.to(device).clamp(-1, 1)
        prepared = [denorm(bounded[i:i + 1]) for i in range(num_samples)]
        scores = [
            loss_fn(prepared[i], prepared[j]).item()
            for i, j in itertools.combinations(range(num_samples), 2)
        ]

    return sum(scores) / len(scores)

def compute_pixelwise_correlation(samples: torch.Tensor) -> float:
    """Return the mean pairwise per-channel correlation between samples.

    For every unordered pair of samples, each channel is flattened over its
    spatial positions, mean-centered, and the normalized inner product
    (Pearson correlation, with a 1e-8 term added to the denominator) is
    computed. The per-channel values are averaged per pair, and the pair
    values are then averaged.

    Args:
        samples: 4-D tensor of shape (N, C, H, W). It does not need to be
            contiguous.

    Returns:
        Mean correlation over all sample pairs and channels.

    Raises:
        ValueError: If fewer than two samples are given (no pair exists).
    """
    # Unpacking also enforces that the input is 4-D.
    num_samples, num_channels, _height, _width = samples.shape
    if num_samples < 2:
        raise ValueError(
            "compute_pixelwise_correlation needs at least 2 samples, "
            f"got {num_samples}"
        )

    # reshape (not view) so inputs with non-contiguous spatial dims work.
    flat = samples.reshape(num_samples, num_channels, -1)  # (N, C, H*W)

    # Centering and norms depend on a single sample only, so compute them
    # once instead of once per pair.
    centered = flat - flat.mean(dim=2, keepdim=True)  # (N, C, H*W)
    norms = centered.norm(dim=2)  # (N, C)

    corrs = []
    for i, j in itertools.combinations(range(num_samples), 2):
        numerator = (centered[i] * centered[j]).sum(dim=1)  # (C,)
        # Epsilon keeps constant (zero-variance) channels from dividing by 0.
        denominator = norms[i] * norms[j] + 1e-8
        corrs.append((numerator / denominator).mean().item())
    return sum(corrs) / len(corrs)

def compute_dynamic_range(samples: torch.Tensor) -> float:
    """Return the mean spread (max - min across dim 0) of *samples*.

    The maximum and minimum are taken over the first dimension for every
    remaining position; the differences are then averaged into one scalar.
    """
    spread = samples.max(dim=0).values - samples.min(dim=0).values
    return spread.mean().item()