""" |
|
This is a standalone PyTorch implementation of 3D bilateral grid and CP-decomposed 4D bilateral grid. |
|
To use this module, download `lib_bilagrid.py` and put it in your project directory.
|
|
|
For details, please check our research project: ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/).
|
|
|
#### Dependencies |
|
|
|
In addition to PyTorch and NumPy, please install [tensorly](https://github.com/tensorly/tensorly).
We have tested this module with Python 3.9.18, PyTorch 2.0.1 (CUDA 11), tensorly 0.8.1, and NumPy 1.25.2.
|
|
|
#### Overview |
|
|
|
- For bilateral guided training, construct a `BilateralGrid` instance, which can hold multiple bilateral grids,
one for each input view. Then use the `slice` function to obtain the transformed RGB output and the corresponding affine transformations (see the sketch below).
|
|
|
- For bilateral guided finishing, you need to instantiate a `BilateralGridCP4D` object and use `slice4d`. |
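
A minimal usage sketch (the view count, tensor shapes, and `bound` below are illustrative, not prescriptive):

```
import torch
from lib_bilagrid import BilateralGrid, BilateralGridCP4D, slice, slice4d

# Bilateral guided training: one 3D grid per input view.
bil_grids = BilateralGrid(num=2)
xy = torch.rand(4096, 2)                           # pixel coordinates in [0, 1]
rgb = torch.rand(4096, 3)                          # rendered colors in [0, 1]
grid_idx = torch.zeros(4096, 1, dtype=torch.long)  # all pixels from view 0
out = slice(bil_grids, xy, rgb, grid_idx)          # out["rgb"]: (4096, 3)

# Bilateral guided finishing: a single low-rank 4D grid over scene space.
bil_grid4d = BilateralGridCP4D(bound=2.0)
xyz = torch.rand(4096, 3) * 4.0 - 2.0              # points in [-bound, bound]
out4d = slice4d(bil_grid4d, xyz, rgb)              # out4d["rgb"]: (4096, 3)
```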
|
|
|
#### Examples |
|
|
|
- Bilateral grid for approximating ISP: |
|
<a target="_blank" href="https://colab.research.google.com/drive/1tx2qKtsHH9deDDnParMWrChcsa9i7Prr?usp=sharing"> |
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> |
|
|
|
- Low-rank 4D bilateral grid for MR enhancement: |
|
<a target="_blank" href="https://colab.research.google.com/drive/17YOjQqgWFT3QI1vysOIH494rMYtt_mHL?usp=sharing"> |
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> |
|
|
|
|
|
Below is the API reference. |
|
|
|
""" |
|
|
|
import tensorly as tl |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
tl.set_backend("pytorch") |
|
|
|
|
|
def color_correct( |
|
img: torch.Tensor, ref: torch.Tensor, num_iters: int = 5, eps: float = 0.5 / 255 |
|
) -> torch.Tensor: |
|
""" |
|
    Warp `img` to match the colors in `ref` using iterative color matching.
|
|
|
This function performs color correction by warping the colors of the input image |
|
to match those of a reference image. It uses a least squares method to find a |
|
transformation that maps the input image's colors to the reference image's colors. |
|
|
|
The algorithm iteratively solves a system of linear equations, updating the set of |
|
unsaturated pixels in each iteration. This approach helps handle non-linear color |
|
transformations and reduces the impact of clipping. |
|
|
|
Args: |
|
img (torch.Tensor): Input image to be color corrected. Shape: [..., num_channels] |
|
ref (torch.Tensor): Reference image to match colors. Shape: [..., num_channels] |
|
num_iters (int, optional): Number of iterations for the color matching process. |
|
Default is 5. |
|
eps (float, optional): Small value to determine the range of unclipped pixels. |
|
Default is 0.5 / 255. |
|
|
|
Returns: |
|
torch.Tensor: Color corrected image with the same shape as the input image. |
|
|
|
Note: |
|
- Both input and reference images should be in the range [0, 1]. |
|
        - The function works with any number of channels, but is typically used with 3 (RGB).
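
    Example:
        A minimal sketch with random data (shapes are illustrative):
        ```
        pred = torch.rand(32, 32, 3)    # image to correct, in [0, 1]
        ref = torch.rand(32, 32, 3)     # reference image, in [0, 1]
        out = color_correct(pred, ref)  # same shape and range as `pred`
        ```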
|
""" |
|
if img.shape[-1] != ref.shape[-1]: |
|
raise ValueError( |
|
f"img's {img.shape[-1]} and ref's {ref.shape[-1]} channels must match" |
|
) |
|
num_channels = img.shape[-1] |
|
img_mat = img.reshape([-1, num_channels]) |
|
ref_mat = ref.reshape([-1, num_channels]) |
|
|
|
def is_unclipped(z): |
|
return (z >= eps) & (z <= 1 - eps) |
|
|
|
mask0 = is_unclipped(img_mat) |
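    # The set of unsaturated pixels may change after each solve, so the system
    # is re-solved `num_iters` times, updating the mask every iteration.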
|
|
|
|
|
|
|
for _ in range(num_iters): |
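        # Build the left-hand side of the linear system: a quadratic expansion
        # of each pixel color (pairwise channel products), the raw colors, and
        # a constant bias term.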
|
|
|
|
|
a_mat = [] |
|
for c in range(num_channels): |
|
a_mat.append(img_mat[:, c : (c + 1)] * img_mat[:, c:]) |
|
a_mat.append(img_mat) |
|
a_mat.append(torch.ones_like(img_mat[:, :1])) |
|
a_mat = torch.cat(a_mat, dim=-1) |
|
warp = [] |
|
for c in range(num_channels): |
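            # The right-hand side is channel `c` of the reference image.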
|
|
|
|
|
b = ref_mat[:, c] |
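            # Keep only rows unsaturated in the original input, the current
            # estimate, and the reference; masked rows are zeroed out below.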
|
|
|
|
|
mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b) |
|
ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat)) |
|
mb = torch.where(mask, b, torch.zeros_like(b)) |
|
w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0] |
|
assert torch.all(torch.isfinite(w)) |
|
warp.append(w) |
|
warp = torch.stack(warp, dim=-1) |
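        # Apply the fitted per-channel warps and clip into [0, 1] before the
        # next iteration.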
|
|
|
img_mat = torch.clip(torch.matmul(a_mat, warp), 0, 1) |
|
corrected_img = torch.reshape(img_mat, img.shape) |
|
return corrected_img |
|
|
|
|
|
def bilateral_grid_tv_loss(model, config): |
|
"""Computes total variations of bilateral grids.""" |
|
total_loss = 0.0 |
|
|
|
for bil_grids in model.bil_grids: |
|
total_loss += config.bilgrid_tv_loss_mult * total_variation_loss( |
|
bil_grids.grids |
|
) |
|
|
|
return total_loss |
|
|
|
|
|
def color_affine_transform(affine_mats, rgb): |
|
"""Applies color affine transformations. |
|
|
|
Args: |
|
affine_mats (torch.Tensor): Affine transformation matrices. Supported shape: $(..., 3, 4)$. |
|
rgb (torch.Tensor): Input RGB values. Supported shape: $(..., 3)$. |
|
|
|
Returns: |
|
Output transformed colors of shape $(..., 3)$. |
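
    Concretely, $\mathrm{rgb}_{out} = M \, \mathrm{rgb} + b$, where $M$ is the $3 \times 3$ part
    `affine_mats[..., :3]` and $b$ is the bias `affine_mats[..., 3]`.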
|
""" |
|
return ( |
|
torch.matmul(affine_mats[..., :3], rgb.unsqueeze(-1)).squeeze(-1) |
|
+ affine_mats[..., 3] |
|
) |
|
|
|
|
|
def _num_tensor_elems(t): |
|
return max(torch.prod(torch.tensor(t.size()[1:]).float()).item(), 1.0) |
|
|
|
|
|
def total_variation_loss(x): |
|
"""Returns total variation on multi-dimensional tensors. |
|
|
|
Args: |
|
x (torch.Tensor): The input tensor with shape $(B, C, ...)$, where $B$ is the batch size and $C$ is the channel size. |
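
    Concretely, the loop below computes

    $$\mathrm{TV}(x) = \frac{1}{B} \sum_{d \geq 2} \frac{1}{N_d} \sum_{k} \left\| x^{(d)}_{k+1} - x^{(d)}_{k} \right\|_2^2,$$

    where $x^{(d)}_k$ is the $k$-th slice of $x$ along dimension $d$ and $N_d$ is the
    per-sample element count of the difference tensor.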
|
""" |
|
batch_size = x.shape[0] |
|
tv = 0 |
|
for i in range(2, len(x.shape)): |
|
n_res = x.shape[i] |
|
idx1 = torch.arange(1, n_res, device=x.device) |
|
idx2 = torch.arange(0, n_res - 1, device=x.device) |
|
x1 = x.index_select(i, idx1) |
|
x2 = x.index_select(i, idx2) |
|
count = _num_tensor_elems(x1) |
|
tv += torch.pow((x1 - x2), 2).sum() / count |
|
return tv / batch_size |
|
|
|
|
|
def slice(bil_grids, xy, rgb, grid_idx): |
|
"""Slices a batch of 3D bilateral grids by pixel coordinates `xy` and gray-scale guidances of pixel colors `rgb`. |
|
|
|
Supports 2-D, 3-D, and 4-D input shapes. The first dimension of the input is the batch size |
|
and the last dimension is 2 for `xy`, 3 for `rgb`, and 1 for `grid_idx`. |
|
|
|
    The return value is a dictionary containing the affine transformations `rgb_affine_mats` sliced from the bilateral grids and
    the output colors `rgb` after applying the affine transformations.
|
|
|
In the 2-D input case, `xy` is a $(N, 2)$ tensor, `rgb` is a $(N, 3)$ tensor, and `grid_idx` is a $(N, 1)$ tensor. |
|
Then `affine_mats[i]` can be obtained via slicing the bilateral grid indexed at `grid_idx[i]` by `xy[i, :]` and `rgb2gray(rgb[i, :])`. |
|
    For 3-D and 4-D inputs, bilateral grids and coordinates are indexed in the same way as in the 2-D case.
|
|
|
.. note:: |
|
This function can be regarded as a wrapper of `color_affine_transform` and `BilateralGrid` with a slight performance improvement. |
|
        When `grid_idx` contains a single unique index, only one bilateral grid will be used during slicing. In this case, this function will not
|
perform tensor indexing to avoid data copy and extra memory |
|
(see [this](https://discuss.pytorch.org/t/does-indexing-a-tensor-return-a-copy-of-it/164905)). |
|
|
|
Args: |
|
        bil_grids (`BilateralGrid`): A `BilateralGrid` instance holding $N$ bilateral grids.
|
xy (torch.Tensor): The x-y coordinates of shape $(..., 2)$ in the range of $[0,1]$. |
|
rgb (torch.Tensor): The RGB values of shape $(..., 3)$ for computing the guidance coordinates, ranging in $[0,1]$. |
|
grid_idx (torch.Tensor): The indices of bilateral grids for each slicing. Shape: $(..., 1)$. |
|
|
|
Returns: |
|
A dictionary with keys and values as follows: |
|
``` |
|
{ |
|
"rgb": Transformed RGB colors. Shape: (..., 3), |
|
"rgb_affine_mats": The sliced affine transformation matrices from bilateral grids. Shape: (..., 3, 4) |
|
} |
|
``` |
|
""" |
|
|
|
sh_ = rgb.shape |
|
|
|
grid_idx_unique = torch.unique(grid_idx) |
|
if len(grid_idx_unique) == 1: |
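        # All pixels come from a single view: index with the unique id once and
        # add a leading batch dimension instead of gathering per pixel.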
|
|
|
grid_idx = grid_idx_unique |
|
xy = xy.unsqueeze(0) |
|
rgb = rgb.unsqueeze(0) |
|
else: |
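        # Pixels come from multiple views: reduce `grid_idx` to one index per
        # leading batch element before gathering grids.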
|
|
|
if len(grid_idx.shape) == 4: |
|
grid_idx = grid_idx[:, 0, 0, 0] |
|
elif len(grid_idx.shape) == 3: |
|
grid_idx = grid_idx[:, 0, 0] |
|
elif len(grid_idx.shape) == 2: |
|
grid_idx = grid_idx[:, 0] |
|
else: |
|
raise ValueError( |
|
"The input to bilateral grid slicing is not supported yet." |
|
) |
|
|
|
affine_mats = bil_grids(xy, rgb, grid_idx) |
|
rgb = color_affine_transform(affine_mats, rgb) |
|
|
|
return { |
|
"rgb": rgb.reshape(*sh_), |
|
"rgb_affine_mats": affine_mats.reshape( |
|
*sh_[:-1], affine_mats.shape[-2], affine_mats.shape[-1] |
|
), |
|
} |
|
|
|
|
|
class BilateralGrid(nn.Module): |
|
"""Class for 3D bilateral grids. |
|
|
|
    Holds one or more bilateral grids.
|
""" |
|
|
|
def __init__(self, num, grid_X=16, grid_Y=16, grid_W=8): |
|
""" |
|
Args: |
|
num (int): The number of bilateral grids (i.e., the number of views). |
|
grid_X (int): Defines grid width $W$. |
|
grid_Y (int): Defines grid height $H$. |
|
grid_W (int): Defines grid guidance dimension $L$. |
|
""" |
|
super(BilateralGrid, self).__init__() |
|
|
|
self.grid_width = grid_X |
|
"""Grid width. Type: int.""" |
|
self.grid_height = grid_Y |
|
"""Grid height. Type: int.""" |
|
self.grid_guidance = grid_W |
|
"""Grid guidance dimension. Type: int.""" |
|
|
|
|
|
grid = self._init_identity_grid() |
|
self.grids = nn.Parameter(grid.tile(num, 1, 1, 1, 1)) |
|
""" A 5-D tensor of shape $(N, 12, L, H, W)$.""" |
|
|
|
|
|
self.register_buffer("rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]])) |
|
self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0 |
|
""" A function that converts RGB to gray-scale guidance in $[-1, 1]$.""" |
|
|
|
def _init_identity_grid(self): |
|
        grid = torch.tensor(
            # A row-major 3x4 identity affine color transformation: [I_3 | 0].
            [1.0, 0, 0, 0, 0, 1.0, 0, 0, 0, 0, 1.0, 0]
        ).float()
|
grid = grid.repeat( |
|
[self.grid_guidance * self.grid_height * self.grid_width, 1] |
|
) |
|
grid = grid.reshape( |
|
1, self.grid_guidance, self.grid_height, self.grid_width, -1 |
|
) |
|
grid = grid.permute(0, 4, 1, 2, 3) |
|
return grid |
|
|
|
def tv_loss(self): |
|
"""Computes and returns total variation loss on the bilateral grids.""" |
|
return total_variation_loss(self.grids) |
|
|
|
def forward(self, grid_xy, rgb, idx=None): |
|
"""Bilateral grid slicing. Supports 2-D, 3-D, 4-D, and 5-D input. |
|
For the 2-D, 3-D, and 4-D cases, please refer to `slice`. |
|
For the 5-D cases, `idx` will be unused and the first dimension of `xy` should be |
|
equal to the number of bilateral grids. Then this function becomes PyTorch's |
|
[`F.grid_sample`](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). |
|
|
|
Args: |
|
grid_xy (torch.Tensor): The x-y coordinates in the range of $[0,1]$. |
|
rgb (torch.Tensor): The RGB values in the range of $[0,1]$. |
|
idx (torch.Tensor): The bilateral grid indices. |
|
|
|
Returns: |
|
Sliced affine matrices of shape $(..., 3, 4)$. |
|
""" |
|
|
|
|
input_ndims = len(grid_xy.shape) |
|
assert len(rgb.shape) == input_ndims |
|
|
|
if input_ndims > 1 and input_ndims < 5: |
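            # Unsqueeze the inputs up to 5-D: (N, m, h, w, c).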
|
|
|
            for _ in range(5 - input_ndims):
|
grid_xy = grid_xy.unsqueeze(1) |
|
rgb = rgb.unsqueeze(1) |
|
assert idx is not None |
|
elif input_ndims != 5: |
|
raise ValueError( |
|
"Bilateral grid slicing only takes either 2D, 3D, 4D and 5D inputs" |
|
) |
|
|
|
grids = self.grids |
|
if idx is not None: |
|
grids = grids[idx] |
|
assert grids.shape[0] == grid_xy.shape[0] |
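        # Generate slicing coordinates: xy rescaled to [-1, 1] plus the
        # gray-scale guidance as the third (depth) coordinate.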
|
|
|
|
|
grid_xy = (grid_xy - 0.5) * 2 |
|
grid_z = self.rgb2gray(rgb) |
|
|
|
|
|
|
|
grid_xyz = torch.cat([grid_xy, grid_z], dim=-1) |
|
|
|
affine_mats = F.grid_sample( |
|
grids, grid_xyz, mode="bilinear", align_corners=True, padding_mode="border" |
|
) |
|
affine_mats = affine_mats.permute(0, 2, 3, 4, 1) |
|
        affine_mats = affine_mats.reshape(*affine_mats.shape[:-1], 3, 4)
|
|
|
for _ in range(5 - input_ndims): |
|
affine_mats = affine_mats.squeeze(1) |
|
|
|
return affine_mats |
|
|
|
|
|
def slice4d(bil_grid4d, xyz, rgb): |
|
"""Slices a 4D bilateral grid by point coordinates `xyz` and gray-scale guidances of radiance colors `rgb`. |
|
|
|
Args: |
|
bil_grid4d (`BilateralGridCP4D`): The input 4D bilateral grid. |
|
xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$. |
|
rgb (torch.Tensor): The RGB values with shape $(..., 3)$. |
|
|
|
Returns: |
|
A dictionary with keys and values as follows: |
|
``` |
|
{ |
|
"rgb": Transformed radiance RGB colors. Shape: (..., 3), |
|
"rgb_affine_mats": The sliced affine transformation matrices from the 4D bilateral grid. Shape: (..., 3, 4) |
|
} |
|
``` |
|
""" |
|
|
|
affine_mats = bil_grid4d(xyz, rgb) |
|
rgb = color_affine_transform(affine_mats, rgb) |
|
|
|
return {"rgb": rgb, "rgb_affine_mats": affine_mats} |
|
|
|
|
|
class _ScaledTanh(nn.Module): |
|
def __init__(self, s=2.0): |
|
super().__init__() |
|
self.scaler = s |
|
|
|
def forward(self, x): |
|
return torch.tanh(self.scaler * x) |
|
|
|
|
|
class BilateralGridCP4D(nn.Module): |
|
"""Class for low-rank 4D bilateral grids.""" |
|
|
|
def __init__( |
|
self, |
|
grid_X=16, |
|
grid_Y=16, |
|
grid_Z=16, |
|
grid_W=8, |
|
rank=5, |
|
learn_gray=True, |
|
gray_mlp_width=8, |
|
gray_mlp_depth=2, |
|
init_noise_scale=1e-6, |
|
bound=2.0, |
|
): |
|
""" |
|
Args: |
|
grid_X (int): Defines grid width. |
|
grid_Y (int): Defines grid height. |
|
grid_Z (int): Defines grid depth. |
|
grid_W (int): Defines grid guidance dimension. |
|
rank (int): Rank of the 4D bilateral grid. |
|
learn_gray (bool): If True, an MLP will be learned to convert RGB colors to gray-scale guidances. |
|
gray_mlp_width (int): The MLP width for learnable guidance. |
|
gray_mlp_depth (int): The number of MLP layers for learnable guidance. |
|
init_noise_scale (float): The noise scale of the initialized factors. |
|
bound (float): The bound of the xyz coordinates. |
|
""" |
|
super(BilateralGridCP4D, self).__init__() |
|
|
|
self.grid_X = grid_X |
|
"""Grid width. Type: int.""" |
|
self.grid_Y = grid_Y |
|
"""Grid height. Type: int.""" |
|
self.grid_Z = grid_Z |
|
"""Grid depth. Type: int.""" |
|
self.grid_W = grid_W |
|
"""Grid guidance dimension. Type: int.""" |
|
self.rank = rank |
|
"""Rank of the 4D bilateral grid. Type: int.""" |
|
self.learn_gray = learn_gray |
|
"""Flags of learnable guidance is used. Type: bool.""" |
|
self.gray_mlp_width = gray_mlp_width |
|
"""The MLP width for learnable guidance. Type: int.""" |
|
self.gray_mlp_depth = gray_mlp_depth |
|
"""The MLP depth for learnable guidance. Type: int.""" |
|
self.init_noise_scale = init_noise_scale |
|
"""The noise scale of the initialized factors. Type: float.""" |
|
self.bound = bound |
|
"""The bound of the xyz coordinates. Type: float.""" |
|
|
|
self._init_cp_factors_parafac() |
|
|
|
self.rgb2gray = None |
|
""" A function that converts RGB to gray-scale guidances in $[-1, 1]$. |
|
If `learn_gray` is True, this will be an MLP network.""" |
|
|
|
if self.learn_gray: |
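            # Guidance MLP: Linear(3 -> width), then (gray_mlp_depth - 1) blocks
            # of ReLU + Linear (the last one width -> 1), ending in a scaled tanh
            # that maps the guidance into (-1, 1).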
|
|
|
def rgb2gray_mlp_linear(layer): |
|
return nn.Linear( |
|
self.gray_mlp_width, |
|
self.gray_mlp_width if layer < self.gray_mlp_depth - 1 else 1, |
|
) |
|
|
|
def rgb2gray_mlp_actfn(_): |
|
return nn.ReLU(inplace=True) |
|
|
|
self.rgb2gray = nn.Sequential( |
|
*( |
|
[nn.Linear(3, self.gray_mlp_width)] |
|
+ [ |
|
nn_module(layer) |
|
for layer in range(1, self.gray_mlp_depth) |
|
for nn_module in [rgb2gray_mlp_actfn, rgb2gray_mlp_linear] |
|
] |
|
+ [_ScaledTanh(2.0)] |
|
) |
|
) |
|
else: |
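            # Fixed ITU-R BT.601 luma weights; the guidance is rescaled to [-1, 1].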
|
|
|
self.register_buffer( |
|
"rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]]) |
|
) |
|
self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0 |
|
|
|
def _init_identity_grid(self): |
|
        grid = torch.tensor(
            # A row-major 3x4 identity affine color transformation: [I_3 | 0].
            [1.0, 0, 0, 0, 0, 1.0, 0, 0, 0, 0, 1.0, 0]
        ).float()
|
grid = grid.repeat([self.grid_W * self.grid_Z * self.grid_Y * self.grid_X, 1]) |
|
grid = grid.reshape(self.grid_W, self.grid_Z, self.grid_Y, self.grid_X, -1) |
|
grid = grid.permute(4, 0, 1, 2, 3) |
|
return grid |
|
|
|
def _init_cp_factors_parafac(self): |
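        # CP-decompose a perturbed identity grid into factors of rank `rank`.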
|
|
|
init_grids = self._init_identity_grid() |
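        # Small random noise avoids a degenerate (rank-deficient) input to parafac.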
|
|
|
init_grids = torch.randn_like(init_grids) * self.init_noise_scale + init_grids |
|
from tensorly.decomposition import parafac |
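        # `facs[0]` is the mode for the 12 affine entries; `facs[1:]` correspond
        # to the guidance and spatial axes.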
|
|
|
|
|
_, facs = parafac(init_grids.clone().detach(), rank=self.rank) |
|
|
|
self.num_facs = len(facs) |
|
|
|
self.fac_0 = nn.Linear(facs[0].shape[0], facs[0].shape[1], bias=False) |
|
self.fac_0.weight = nn.Parameter(facs[0]) |
|
|
|
for i in range(1, self.num_facs): |
|
fac = facs[i].T |
|
fac = fac.view(1, fac.shape[0], fac.shape[1], 1) |
|
self.register_buffer(f"fac_{i}_init", fac) |
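            # The initial factor is frozen as a buffer; only the zero-initialized
            # residual `fac_{i}` below is optimized.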
|
|
|
fac_resid = torch.zeros_like(fac) |
|
self.register_parameter(f"fac_{i}", nn.Parameter(fac_resid)) |
|
|
|
def tv_loss(self): |
|
"""Computes and returns total variation loss on the factors of the low-rank 4D bilateral grids.""" |
|
|
|
total_loss = 0 |
|
for i in range(1, self.num_facs): |
|
fac = self.get_parameter(f"fac_{i}") |
|
total_loss += total_variation_loss(fac) |
|
|
|
return total_loss |
|
|
|
def forward(self, xyz, rgb): |
|
"""Low-rank 4D bilateral grid slicing. |
|
|
|
Args: |
|
xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$. |
|
rgb (torch.Tensor): The corresponding RGB values with shape $(..., 3)$. |
|
|
|
Returns: |
|
Sliced affine matrices with shape $(..., 3, 4)$. |
|
""" |
|
sh_ = xyz.shape |
|
xyz = xyz.reshape(-1, 3) |
|
rgb = rgb.reshape(-1, 3) |
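        # Normalize xyz by the scene bound (to [-1, 1]) and map colors to the
        # gray-scale guidance, also in [-1, 1].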
|
|
|
xyz = xyz / self.bound |
|
assert self.rgb2gray is not None |
|
gray = self.rgb2gray(rgb) |
|
xyzw = torch.cat([xyz, gray], dim=-1) |
|
xyzw = xyzw.transpose(0, 1) |
|
coords = torch.stack([torch.zeros_like(xyzw), xyzw], dim=-1) |
|
coords = coords.unsqueeze(1) |
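        # Sample each 1-D factor via a 2-D grid_sample (with a dummy width
        # axis), multiply the rank coefficients across factors, then let
        # `fac_0` map the rank coefficients to the 12 affine entries.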
|
|
|
coef = 1.0 |
|
for i in range(1, self.num_facs): |
|
fac = self.get_parameter(f"fac_{i}") + self.get_buffer(f"fac_{i}_init") |
|
coef = coef * F.grid_sample( |
|
fac, coords[[i - 1]], align_corners=True, padding_mode="border" |
|
) |
|
coef = coef.squeeze([0, 2]).transpose(0, 1) |
|
mat = self.fac_0(coef) |
|
return mat.reshape(*sh_[:-1], 3, 4) |
|
|