# NOTE: stray notebook/runtime log output ("Spaces: Running on L4") removed —
# it was not Python and broke parsing of this module.
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
try:
    from kornia.morphology import opening
except ImportError:
    # Older kornia releases exported the operator as `open`.
    from kornia.morphology import open as opening
from torchvision import transforms
from torchvision.utils import make_grid, save_image
from typing import Any
def exist(val: Any) -> bool:
    """Check whether *val* holds a value, i.e. is not None."""
    return val is not None
def morph_open(x: torch.Tensor, k: int) -> torch.Tensor:
    """Apply a morphological opening with a ``k x k`` all-ones kernel.

    A kernel size of 0 is treated as a no-op and returns *x* unchanged.
    The morphology call runs under ``torch.no_grad()``, so no gradients
    flow through it.
    """
    if k == 0:
        return x
    kernel = torch.ones(k, k, device=x.device)
    with torch.no_grad():
        return opening(x, kernel)
def make_grid_images(images: list[torch.Tensor], **kwargs) -> torch.Tensor:
    """Concatenate *images* along the width axis and tile them into a grid.

    Extra keyword arguments are forwarded to torchvision's ``make_grid``.
    """
    side_by_side = torch.cat(images, dim=3)
    return make_grid(side_by_side, **kwargs)
def save_images(images: tuple[torch.Tensor, torch.Tensor], path: str, **kwargs) -> None:
    """Save a (generated, real) image pair side by side to *path*.

    Extra keyword arguments are forwarded to ``make_grid``.
    NOTE(review): the uint8 round-trip below quantizes pixel values before
    saving — presumably intentional so the file matches a quantized
    preview; verify against callers. Assumes values already lie in [0, 1].
    """
    gen, real = images
    grid = make_grid(torch.cat((gen, real), dim=3), **kwargs)
    hwc = grid.permute(1, 2, 0).to("cpu").numpy()
    quantized = (hwc * 255).astype(np.uint8)
    chw = torch.from_numpy(quantized).permute(2, 0, 1) / 255
    save_image(chw, path)
def save_triplet(images: tuple[torch.Tensor, ...], path: str, **kwargs) -> None:
    """Save a tuple of image batches side by side to *path*.

    Same pipeline as ``save_images`` but for an arbitrary number of
    batches. Extra keyword arguments are forwarded to ``make_grid``.
    NOTE(review): as in ``save_images``, the uint8 round-trip quantizes
    values before saving; assumes inputs lie in [0, 1].
    """
    grid = make_grid(torch.cat(images, dim=3), **kwargs)
    hwc = grid.permute(1, 2, 0).to("cpu").numpy()
    quantized = (hwc * 255).astype(np.uint8)
    chw = torch.from_numpy(quantized).permute(2, 0, 1) / 255
    save_image(chw, path)
def plot_images(images: torch.Tensor) -> None:
    """Display a batch of images in a single row with matplotlib.

    *images* is assumed to be a batch (N, C, H, W) — TODO confirm with
    callers; the N images are concatenated along width and shown as one
    figure.
    """
    plt.figure(figsize=(32, 32))
    # The original wrapped the width-wise concat in a second torch.cat over
    # a one-element list, which is a no-op; a single concat is equivalent.
    row = torch.cat(list(images.cpu()), dim=-1)
    plt.imshow(row.permute(1, 2, 0))
    plt.show()
def make_graphic(metric_name: str, metrics: list[torch.Tensor], path: str) -> None:
    """Plot a per-epoch metric curve and save it as ``<path>/<metric_name>.png``."""
    values = [metric.cpu().numpy() for metric in metrics]
    plt.figure(figsize=(32, 32))
    plt.plot(values)
    plt.title(metric_name)
    plt.xlabel("Epoch")
    plt.ylabel(metric_name)
    out_file = os.path.join(path, f"{metric_name}.png")
    plt.savefig(out_file)
    plt.close()
def norm(
    img: torch.Tensor,
    mean: tuple[float, ...] = (0.5, 0.5, 0.5),
    std: tuple[float, ...] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """Normalize *img* channel-wise: ``(img - mean) / std``.

    Defaults changed from mutable list literals to tuples (same values) —
    mutable default arguments are shared across calls and a known Python
    pitfall. ``transforms.Normalize`` accepts any sequence, so callers
    passing lists are unaffected.
    """
    normalize = transforms.Normalize(mean, std)
    return normalize(img)
def denorm(
    img: torch.Tensor,
    mean: tuple[float, ...] = (0.5, 0.5, 0.5),
    std: tuple[float, ...] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """Invert a channel-wise normalization: ``img * std + mean``.

    Defaults changed from mutable list literals to tuples (same values) to
    avoid the mutable-default-argument pitfall. The stats are reshaped to
    (1, C, 1, 1) for broadcasting — equivalent to the original
    ``[None][..., None, None]`` chain but explicit. Assumes *img* is
    (..., C, H, W) with C == len(mean) — TODO confirm with callers.
    """
    mean_t = torch.tensor(mean, device=img.device).view(1, -1, 1, 1)
    std_t = torch.tensor(std, device=img.device).view(1, -1, 1, 1)
    return img * std_t + mean_t