import json
import math
import os
import re
import shutil
from typing import List, Optional, Union
import cv2
import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
# import wandb
from matplotlib import cm
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image, ImageDraw
from .typing import *
def tensor_to_image(
data: Union[Image.Image, torch.Tensor, np.ndarray],
batched: bool = False,
format: str = "HWC",
) -> Union[Image.Image, List[Image.Image]]:
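    """
    Convert a tensor or array (optionally batched, in "HWC" or "CHW" layout) to a
    PIL image, or to a list of PIL images when `batched` is True. Float inputs are
    assumed to lie in [0, 1]; boolean masks map to {0, 255}.
    """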
if isinstance(data, Image.Image):
return data
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()
if data.dtype == np.float32 or data.dtype == np.float16:
data = (data * 255).astype(np.uint8)
elif data.dtype == np.bool_:
data = data.astype(np.uint8) * 255
assert data.dtype == np.uint8
if format == "CHW":
if batched and data.ndim == 4:
data = data.transpose((0, 2, 3, 1))
elif not batched and data.ndim == 3:
data = data.transpose((1, 2, 0))
if batched:
return [Image.fromarray(d) for d in data]
return Image.fromarray(data)
def largest_factor_near_sqrt(n: int) -> int:
"""
Finds the largest factor of n that is closest to the square root of n.
Args:
n (int): The integer for which to find the largest factor near its square root.
Returns:
int: The largest factor of n that is closest to the square root of n.
"""
sqrt_n = int(math.sqrt(n)) # Get the integer part of the square root
# First, check if the square root itself is a factor
if sqrt_n * sqrt_n == n:
return sqrt_n
# Otherwise, find the largest factor by iterating from sqrt_n downwards
for i in range(sqrt_n, 0, -1):
if n % i == 0:
return i
    # Fallback (only reached when n < 1, since 1 divides every positive n)
    return 1
def make_image_grid(
images: List[Image.Image],
rows: Optional[int] = None,
cols: Optional[int] = None,
resize: Optional[int] = None,
) -> Image.Image:
"""
Prepares a single grid of images. Useful for visualization purposes.
"""
if rows is None and cols is not None:
assert len(images) % cols == 0
rows = len(images) // cols
elif cols is None and rows is not None:
assert len(images) % rows == 0
cols = len(images) // rows
elif rows is None and cols is None:
rows = largest_factor_near_sqrt(len(images))
cols = len(images) // rows
assert len(images) == rows * cols
if resize is not None:
images = [img.resize((resize, resize)) for img in images]
w, h = images[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(images):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
class SaverMixin:
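    """
    Mixin providing image, video and data saving utilities. Everything is written
    under the directory set via `set_save_dir`; saved images can additionally be
    mirrored to an attached wandb logger.
    """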
_save_dir: Optional[str] = None
_wandb_logger: Optional[Any] = None
def set_save_dir(self, save_dir: str):
self._save_dir = save_dir
def get_save_dir(self):
if self._save_dir is None:
raise ValueError("Save dir is not set")
return self._save_dir
def convert_data(self, data):
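        """
        Recursively convert torch tensors (including lists/dicts of them) to numpy
        arrays, upcasting half-precision floats to float32 first.
        """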
if data is None:
return None
elif isinstance(data, np.ndarray):
return data
elif isinstance(data, torch.Tensor):
if data.dtype in [torch.float16, torch.bfloat16]:
data = data.float()
return data.detach().cpu().numpy()
elif isinstance(data, list):
return [self.convert_data(d) for d in data]
elif isinstance(data, dict):
return {k: self.convert_data(v) for k, v in data.items()}
        else:
            raise TypeError(
                f"Data must be of type numpy.ndarray, torch.Tensor, list or dict, got {type(data)}"
            )
def get_save_path(self, filename):
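        """Join `filename` onto the save dir, creating parent directories as needed."""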
save_path = os.path.join(self.get_save_dir(), filename)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
return save_path
DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)}
DEFAULT_UV_KWARGS = {
"data_format": "HWC",
"data_range": (0, 1),
"cmap": "checkerboard",
}
DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"}
DEFAULT_GRID_KWARGS = {"align": "max"}
def get_rgb_image_(self, img, data_format, data_range, rgba=False):
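        """
        Convert an RGB(A) image in `data_range` to a uint8 BGR(A) array for cv2.
        Inputs with more channels are split into groups of 3 (or 4) that are tiled
        horizontally, zero-padding the last group if necessary.
        """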
img = self.convert_data(img)
assert data_format in ["CHW", "HWC"]
if data_format == "CHW":
img = img.transpose(1, 2, 0)
if img.dtype != np.uint8:
img = img.clip(min=data_range[0], max=data_range[1])
img = (
(img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0
).astype(np.uint8)
nc = 4 if rgba else 3
imgs = [img[..., start : start + nc] for start in range(0, img.shape[-1], nc)]
imgs = [
(
img_
if img_.shape[-1] == nc
else np.concatenate(
[
img_,
np.zeros(
(img_.shape[0], img_.shape[1], nc - img_.shape[2]),
dtype=img_.dtype,
),
],
axis=-1,
)
)
for img_ in imgs
]
img = np.concatenate(imgs, axis=1)
if rgba:
img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
else:
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
return img
def _save_rgb_image(
self,
filename,
img,
data_format,
data_range,
name: Optional[str] = None,
step: Optional[int] = None,
):
img = self.get_rgb_image_(img, data_format, data_range)
cv2.imwrite(filename, img)
        if name and self._wandb_logger:
            self._wandb_logger.log_image(key=name, images=[filename], step=step)
def save_rgb_image(
self,
filename,
img,
data_format=DEFAULT_RGB_KWARGS["data_format"],
data_range=DEFAULT_RGB_KWARGS["data_range"],
name: Optional[str] = None,
step: Optional[int] = None,
) -> str:
save_path = self.get_save_path(filename)
self._save_rgb_image(save_path, img, data_format, data_range, name, step)
return save_path
def get_uv_image_(self, img, data_format, data_range, cmap):
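        """
        Visualize a 2-channel UV map, either as a magenta/white checkerboard or by
        mapping U/V to the red/green channels. Returns a uint8 BGR array.
        """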
img = self.convert_data(img)
assert data_format in ["CHW", "HWC"]
if data_format == "CHW":
img = img.transpose(1, 2, 0)
img = img.clip(min=data_range[0], max=data_range[1])
img = (img - data_range[0]) / (data_range[1] - data_range[0])
assert cmap in ["checkerboard", "color"]
if cmap == "checkerboard":
n_grid = 64
mask = (img * n_grid).astype(int)
mask = (mask[..., 0] + mask[..., 1]) % 2 == 0
img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255
img[mask] = np.array([255, 0, 255], dtype=np.uint8)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
elif cmap == "color":
img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8)
img_[..., 0] = (img[..., 0] * 255).astype(np.uint8)
img_[..., 1] = (img[..., 1] * 255).astype(np.uint8)
img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR)
img = img_
return img
def save_uv_image(
self,
filename,
img,
data_format=DEFAULT_UV_KWARGS["data_format"],
data_range=DEFAULT_UV_KWARGS["data_range"],
cmap=DEFAULT_UV_KWARGS["cmap"],
) -> str:
save_path = self.get_save_path(filename)
img = self.get_uv_image_(img, data_format, data_range, cmap)
cv2.imwrite(save_path, img)
return save_path
def get_grayscale_image_(self, img, data_range, cmap):
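        """
        Normalize a single-channel image to [0, 1] (using `data_range` if given,
        otherwise its own min/max) and colorize it with the requested colormap.
        Returns a uint8 BGR array.
        """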
img = self.convert_data(img)
img = np.nan_to_num(img)
if data_range is None:
img = (img - img.min()) / (img.max() - img.min())
else:
img = img.clip(data_range[0], data_range[1])
img = (img - data_range[0]) / (data_range[1] - data_range[0])
assert cmap in [None, "jet", "magma", "spectral"]
        if cmap is None:
img = (img * 255.0).astype(np.uint8)
img = np.repeat(img[..., None], 3, axis=2)
elif cmap == "jet":
img = (img * 255.0).astype(np.uint8)
img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
elif cmap == "magma":
img = 1.0 - img
            base = plt.get_cmap("magma")  # cm.get_cmap was removed in recent matplotlib
num_bins = 256
colormap = LinearSegmentedColormap.from_list(
f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins
)(np.linspace(0, 1, num_bins))[:, :3]
a = np.floor(img * 255.0)
b = (a + 1).clip(max=255.0)
f = img * 255.0 - a
a = a.astype(np.uint16).clip(0, 255)
b = b.astype(np.uint16).clip(0, 255)
img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None]
img = (img * 255.0).astype(np.uint8)
elif cmap == "spectral":
colormap = plt.get_cmap("Spectral")
            def blend_rgba(image):
                # Composite the RGBA colormap output over a white background
                return image[..., :3] * image[..., -1:] + (1.0 - image[..., -1:])
img = colormap(img)
img = blend_rgba(img)
img = (img * 255).astype(np.uint8)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
return img
def _save_grayscale_image(
self,
filename,
img,
data_range,
cmap,
name: Optional[str] = None,
step: Optional[int] = None,
):
img = self.get_grayscale_image_(img, data_range, cmap)
cv2.imwrite(filename, img)
        if name and self._wandb_logger:
            self._wandb_logger.log_image(key=name, images=[filename], step=step)
def save_grayscale_image(
self,
filename,
img,
data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"],
cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"],
name: Optional[str] = None,
step: Optional[int] = None,
) -> str:
save_path = self.get_save_path(filename)
self._save_grayscale_image(save_path, img, data_range, cmap, name, step)
return save_path
def get_image_grid_(self, imgs, align):
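        """
        Assemble a grid from a (possibly nested) list of image specs, each a dict
        with "type" ("rgb", "uv" or "grayscale"), "img" and "kwargs" keys. Inner
        lists become rows; entries within a row are resized to a common height
        (per `align`) and concatenated horizontally.
        """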
if isinstance(imgs[0], list):
return np.concatenate(
[self.get_image_grid_(row, align) for row in imgs], axis=0
)
cols = []
for col in imgs:
assert col["type"] in ["rgb", "uv", "grayscale"]
if col["type"] == "rgb":
rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy()
rgb_kwargs.update(col["kwargs"])
cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs))
elif col["type"] == "uv":
uv_kwargs = self.DEFAULT_UV_KWARGS.copy()
uv_kwargs.update(col["kwargs"])
cols.append(self.get_uv_image_(col["img"], **uv_kwargs))
elif col["type"] == "grayscale":
grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy()
grayscale_kwargs.update(col["kwargs"])
cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs))
if align == "max":
h = max([col.shape[0] for col in cols])
elif align == "min":
h = min([col.shape[0] for col in cols])
elif isinstance(align, int):
h = align
else:
raise ValueError(
f"Unsupported image grid align: {align}, should be min, max, or int"
)
for i in range(len(cols)):
if cols[i].shape[0] != h:
w = int(cols[i].shape[1] * h / cols[i].shape[0])
cols[i] = cv2.resize(cols[i], (w, h), interpolation=cv2.INTER_CUBIC)
return np.concatenate(cols, axis=1)
def save_image_grid(
self,
filename,
imgs,
align=DEFAULT_GRID_KWARGS["align"],
name: Optional[str] = None,
step: Optional[int] = None,
texts: Optional[List[float]] = None,
):
save_path = self.get_save_path(filename)
img = self.get_image_grid_(imgs, align=align)
if texts is not None:
img = Image.fromarray(img)
draw = ImageDraw.Draw(img)
black, white = (0, 0, 0), (255, 255, 255)
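            # One label per horizontal band: draw offset white copies first as an outline, then the black text on top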
for i, text in enumerate(texts):
draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white)
draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white)
draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white)
draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white)
draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black)
img = np.asarray(img)
cv2.imwrite(save_path, img)
if name and self._wandb_logger:
self._wandb_logger.log_image(key=name, images=[save_path], step=step)
return save_path
def save_image(self, filename, img) -> str:
save_path = self.get_save_path(filename)
img = self.convert_data(img)
assert img.dtype == np.uint8 or img.dtype == np.uint16
if img.ndim == 3 and img.shape[-1] == 3:
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
elif img.ndim == 3 and img.shape[-1] == 4:
img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
cv2.imwrite(save_path, img)
return save_path
def save_cubemap(self, filename, img, data_range=(0, 1), rgba=False) -> str:
save_path = self.get_save_path(filename)
img = self.convert_data(img)
assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2]
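        # `img` holds six square faces; each group of three channels is unfolded into a
        # horizontal cross, and the resulting crosses are tiled side by side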
imgs_full = []
for start in range(0, img.shape[-1], 3):
img_ = img[..., start : start + 3]
img_ = np.stack(
[
self.get_rgb_image_(img_[i], "HWC", data_range, rgba=rgba)
for i in range(img_.shape[0])
],
axis=0,
)
            placeholder = np.zeros_like(img_[0])  # blank tile matching the face shape and dtype
img_full = np.concatenate(
[
np.concatenate(
[placeholder, img_[2], placeholder, placeholder], axis=1
),
np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1),
np.concatenate(
[placeholder, img_[3], placeholder, placeholder], axis=1
),
],
axis=0,
)
imgs_full.append(img_full)
imgs_full = np.concatenate(imgs_full, axis=1)
cv2.imwrite(save_path, imgs_full)
return save_path
def save_data(self, filename, data) -> str:
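        """Save a dict of arrays as an .npz archive, or a single array as a .npy file."""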
data = self.convert_data(data)
if isinstance(data, dict):
if not filename.endswith(".npz"):
filename += ".npz"
save_path = self.get_save_path(filename)
np.savez(save_path, **data)
else:
if not filename.endswith(".npy"):
filename += ".npy"
save_path = self.get_save_path(filename)
np.save(save_path, data)
return save_path
def save_state_dict(self, filename, data) -> str:
save_path = self.get_save_path(filename)
torch.save(data, save_path)
return save_path
def save_img_sequence(
self,
filename,
img_dir,
matcher,
save_format="mp4",
fps=30,
name: Optional[str] = None,
step: Optional[int] = None,
) -> str:
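        """
        Collect frames in `img_dir` whose filenames match the `matcher` regex
        (the first capture group gives the frame index used for ordering) and
        write them out as a gif or mp4 at the given fps.
        """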
assert save_format in ["gif", "mp4"]
if not filename.endswith(save_format):
filename += f".{save_format}"
save_path = self.get_save_path(filename)
matcher = re.compile(matcher)
img_dir = os.path.join(self.get_save_dir(), img_dir)
imgs = []
for f in os.listdir(img_dir):
if matcher.search(f):
imgs.append(f)
imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0]))
imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs]
if save_format == "gif":
imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
imageio.mimsave(save_path, imgs, fps=fps, palettesize=256)
elif save_format == "mp4":
imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
imageio.mimsave(save_path, imgs, fps=fps)
if name and self._wandb_logger:
from .core import warn
warn("Wandb logger does not support video logging yet!")
return save_path
def save_img_sequences(
self,
seq_dir,
matcher,
save_format="mp4",
fps=30,
delete=True,
name: Optional[str] = None,
step: Optional[int] = None,
):
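        """
        Render every subdirectory of `seq_dir` as a video via `save_img_sequence`,
        optionally deleting the frame directory afterwards.
        """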
seq_dir_ = os.path.join(self.get_save_dir(), seq_dir)
for f in os.listdir(seq_dir_):
img_dir_ = os.path.join(seq_dir_, f)
if not os.path.isdir(img_dir_):
continue
try:
self.save_img_sequence(
os.path.join(seq_dir, f),
os.path.join(seq_dir, f),
matcher,
save_format=save_format,
fps=fps,
name=f"{name}_{f}",
step=step,
)
if delete:
shutil.rmtree(img_dir_)
            except Exception:
                from .core import warn
                warn(f"Video saving for directory {img_dir_} failed!")
def save_file(self, filename, src_path, delete=False) -> str:
save_path = self.get_save_path(filename)
shutil.copyfile(src_path, save_path)
if delete:
os.remove(src_path)
return save_path
def save_json(self, filename, payload) -> str:
save_path = self.get_save_path(filename)
with open(save_path, "w") as f:
f.write(json.dumps(payload))
return save_path