# Copyright 2024 Yuehao Wang (https://github.com/yuehaowang). This part of code is borrowed from ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This is a standalone PyTorch implementation of the 3D bilateral grid and the CP-decomposed 4D bilateral grid.
To use this module, you can download the "lib_bilagrid.py" file and simply put it in your project directory.

For the details, please check our research project: ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/).

#### Dependencies

In addition to PyTorch and Numpy, please install [tensorly](https://github.com/tensorly/tensorly).
We have tested this module on Python 3.9.18, PyTorch 2.0.1 (CUDA 11), tensorly 0.8.1, and Numpy 1.25.2.

#### Overview

- For bilateral guided training, you need to construct a `BilateralGrid` instance, which can hold multiple bilateral grids
  for input views. Then, use the `slice` function to obtain the transformed RGB output and the corresponding affine transformations.

- For bilateral guided finishing, you need to instantiate a `BilateralGridCP4D` object and use `slice4d`.

#### Examples

- Bilateral grid for approximating ISP:
    <a target="_blank" href="https://colab.research.google.com/drive/1tx2qKtsHH9deDDnParMWrChcsa9i7Prr?usp=sharing">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

- Low-rank 4D bilateral grid for MR enhancement:
    <a target="_blank" href="https://colab.research.google.com/drive/17YOjQqgWFT3QI1vysOIH494rMYtt_mHL?usp=sharing">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


Below is the API reference.

"""

import tensorly as tl
import torch
import torch.nn.functional as F
from torch import nn

tl.set_backend("pytorch")


def color_correct(
    img: torch.Tensor, ref: torch.Tensor, num_iters: int = 5, eps: float = 0.5 / 255
) -> torch.Tensor:
    """
    Warp `img` to match the colors in `ref` using iterative color matching.

    This function performs color correction by warping the colors of the input image
    to match those of a reference image. It uses a least squares method to find a
    transformation that maps the input image's colors to the reference image's colors.

    The algorithm iteratively solves a system of linear equations, updating the set of
    unsaturated pixels in each iteration. This approach helps handle non-linear color
    transformations and reduces the impact of clipping.

    Args:
        img (torch.Tensor): Input image to be color corrected. Shape: [..., num_channels]
        ref (torch.Tensor): Reference image to match colors. Shape: [..., num_channels]
        num_iters (int, optional): Number of iterations for the color matching process.
                                   Default is 5.
        eps (float, optional): Small value to determine the range of unclipped pixels.
                               Default is 0.5 / 255.

    Returns:
        torch.Tensor: Color corrected image with the same shape as the input image.

    Note:
        - Both input and reference images should be in the range [0, 1].
        - The function works with any number of channels, but is typically used with 3 (RGB).
    """
    if img.shape[-1] != ref.shape[-1]:
        raise ValueError(
            f"img's {img.shape[-1]} and ref's {ref.shape[-1]} channels must match"
        )
    num_channels = img.shape[-1]
    img_mat = img.reshape([-1, num_channels])
    ref_mat = ref.reshape([-1, num_channels])

    def is_unclipped(z):
        return (z >= eps) & (z <= 1 - eps)  # z \in [eps, 1-eps].

    mask0 = is_unclipped(img_mat)
    # Because the set of saturated pixels may change after solving for a
    # transformation, we repeatedly solve a system `num_iters` times and update
    # our estimate of which pixels are saturated.
    for _ in range(num_iters):
        # Construct the left hand side of a linear system that contains a quadratic
        # expansion of each pixel of `img`.
        a_mat = []
        for c in range(num_channels):
            a_mat.append(img_mat[:, c : (c + 1)] * img_mat[:, c:])  # Quadratic term.
        a_mat.append(img_mat)  # Linear term.
        a_mat.append(torch.ones_like(img_mat[:, :1]))  # Bias term.
        a_mat = torch.cat(a_mat, dim=-1)
        warp = []
        for c in range(num_channels):
            # Construct the right hand side of a linear system containing each color
            # of `ref`.
            b = ref_mat[:, c]
            # Ignore rows of the linear system that were saturated in the input or are
            # saturated in the current corrected color estimate.
            mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b)
            ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat))
            mb = torch.where(mask, b, torch.zeros_like(b))
            w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0]
            assert torch.all(torch.isfinite(w))
            warp.append(w)
        warp = torch.stack(warp, dim=-1)
        # Apply the warp to update img_mat.
        img_mat = torch.clip(torch.matmul(a_mat, warp), 0, 1)
    corrected_img = torch.reshape(img_mat, img.shape)
    return corrected_img
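
# A minimal usage sketch of `color_correct` (the tensors below are illustrative
# placeholders, not from the original project):
#
#     render = torch.rand(32, 32, 3)  # image to correct, values in [0, 1]
#     target = torch.rand(32, 32, 3)  # reference image, values in [0, 1]
#     matched = color_correct(render, target, num_iters=5)
#     assert matched.shape == render.shape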


def bilateral_grid_tv_loss(model, config):
    """Computes total variations of bilateral grids."""
    total_loss = 0.0

    for bil_grids in model.bil_grids:
        total_loss += config.bilgrid_tv_loss_mult * total_variation_loss(
            bil_grids.grids
        )

    return total_loss


def color_affine_transform(affine_mats, rgb):
    """Applies color affine transformations.

    Args:
        affine_mats (torch.Tensor): Affine transformation matrices. Supported shape: $(..., 3, 4)$.
        rgb  (torch.Tensor): Input RGB values. Supported shape: $(..., 3)$.

    Returns:
        Output transformed colors of shape $(..., 3)$.
    """
    return (
        torch.matmul(affine_mats[..., :3], rgb.unsqueeze(-1)).squeeze(-1)
        + affine_mats[..., 3]
    )
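
# Sanity-check sketch: an identity affine matrix (3x3 identity with a zero
# translation column) leaves colors unchanged. Values here are illustrative.
#
#     rgb = torch.rand(8, 3)
#     identity = torch.cat([torch.eye(3), torch.zeros(3, 1)], dim=-1)  # (3, 4)
#     out = color_affine_transform(identity.expand(8, 3, 4), rgb)
#     assert torch.allclose(out, rgb)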


def _num_tensor_elems(t):
    return max(torch.prod(torch.tensor(t.size()[1:]).float()).item(), 1.0)


def total_variation_loss(x):  # noqa: F811
    """Returns total variation on multi-dimensional tensors.

    Args:
        x (torch.Tensor): The input tensor with shape $(B, C, ...)$, where $B$ is the batch size and $C$ is the channel size.
    """
    batch_size = x.shape[0]
    tv = 0
    for i in range(2, len(x.shape)):
        n_res = x.shape[i]
        idx1 = torch.arange(1, n_res, device=x.device)
        idx2 = torch.arange(0, n_res - 1, device=x.device)
        x1 = x.index_select(i, idx1)
        x2 = x.index_select(i, idx2)
        count = _num_tensor_elems(x1)
        tv += torch.pow((x1 - x2), 2).sum() / count
    return tv / batch_size
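
# Quick sketch of the expected behavior: a constant tensor has zero total
# variation, while random values yield a positive loss.
#
#     assert total_variation_loss(torch.ones(1, 12, 8, 16, 16)) == 0
#     assert total_variation_loss(torch.rand(1, 12, 8, 16, 16)) > 0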


def slice(bil_grids, xy, rgb, grid_idx):
    """Slices a batch of 3D bilateral grids by pixel coordinates `xy` and gray-scale guidances of pixel colors `rgb`.

    Supports 2-D, 3-D, and 4-D input shapes. The first dimension of the input is the batch size
    and the last dimension is 2 for `xy`, 3 for `rgb`, and 1 for `grid_idx`.

    The return value is a dictionary containing the affine transformations `affine_mats` sliced from bilateral grids and
    the output color `rgb_out` after applying the affine transformations.

    In the 2-D input case, `xy` is a $(N, 2)$ tensor, `rgb` is a $(N, 3)$ tensor, and `grid_idx` is a $(N, 1)$ tensor.
    Then `affine_mats[i]` can be obtained via slicing the bilateral grid indexed at `grid_idx[i]` by `xy[i, :]` and `rgb2gray(rgb[i, :])`.
    For 3-D and 4-D input cases, the behavior of indexing bilateral grids and coordinates is the same with the 2-D case.

    .. note::
        This function can be regarded as a wrapper of `color_affine_transform` and `BilateralGrid` with a slight performance improvement.
        When `grid_idx` contains a unique index, only a single bilateral grid will be used during slicing. In this case, this function will not
        perform tensor indexing, which avoids a data copy and extra memory use
        (see [this](https://discuss.pytorch.org/t/does-indexing-a-tensor-return-a-copy-of-it/164905)).

    Args:
        bil_grids (`BilateralGrid`): An instance of $N$ bilateral grids.
        xy (torch.Tensor): The x-y coordinates of shape $(..., 2)$ in the range of $[0,1]$.
        rgb (torch.Tensor): The RGB values of shape $(..., 3)$ for computing the guidance coordinates, ranging in $[0,1]$.
        grid_idx (torch.Tensor): The indices of bilateral grids for each slicing. Shape: $(..., 1)$.

    Returns:
        A dictionary with keys and values as follows:
        ```
        {
            "rgb": Transformed RGB colors. Shape: (..., 3),
            "rgb_affine_mats": The sliced affine transformation matrices from bilateral grids. Shape: (..., 3, 4)
        }
        ```
    """

    sh_ = rgb.shape

    grid_idx_unique = torch.unique(grid_idx)
    if len(grid_idx_unique) == 1:
        # All pixels are from a single view.
        grid_idx = grid_idx_unique  # (1,)
        xy = xy.unsqueeze(0)  # (1, ..., 2)
        rgb = rgb.unsqueeze(0)  # (1, ..., 3)
    else:
        # Pixels are randomly sampled from different views.
        if len(grid_idx.shape) == 4:
            grid_idx = grid_idx[:, 0, 0, 0]  # (chunk_size,)
        elif len(grid_idx.shape) == 3:
            grid_idx = grid_idx[:, 0, 0]  # (chunk_size,)
        elif len(grid_idx.shape) == 2:
            grid_idx = grid_idx[:, 0]  # (chunk_size,)
        else:
            raise ValueError(
                "The input to bilateral grid slicing is not supported yet."
            )

    affine_mats = bil_grids(xy, rgb, grid_idx)
    rgb = color_affine_transform(affine_mats, rgb)

    return {
        "rgb": rgb.reshape(*sh_),
        "rgb_affine_mats": affine_mats.reshape(
            *sh_[:-1], affine_mats.shape[-2], affine_mats.shape[-1]
        ),
    }
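
# Sketch of training-time slicing with hypothetical shapes (10 views, a batch
# of 1024 rays, one view index per ray):
#
#     bil_grids = BilateralGrid(num=10)
#     xy = torch.rand(1024, 2)                    # pixel coords in [0, 1]
#     rgb = torch.rand(1024, 3)                   # rendered colors in [0, 1]
#     grid_idx = torch.randint(0, 10, (1024, 1))  # view index per ray
#     out = slice(bil_grids, xy, rgb, grid_idx)
#     out["rgb"].shape, out["rgb_affine_mats"].shape  # (1024, 3), (1024, 3, 4)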


class BilateralGrid(nn.Module):
    """Class for 3D bilateral grids.

    Holds one or more bilateral grids.
    """

    def __init__(self, num, grid_X=16, grid_Y=16, grid_W=8):
        """
        Args:
            num (int): The number of bilateral grids (i.e., the number of views).
            grid_X (int): Defines grid width $W$.
            grid_Y (int): Defines grid height $H$.
            grid_W (int): Defines grid guidance dimension $L$.
        """
        super(BilateralGrid, self).__init__()

        self.grid_width = grid_X
        """Grid width. Type: int."""
        self.grid_height = grid_Y
        """Grid height. Type: int."""
        self.grid_guidance = grid_W
        """Grid guidance dimension. Type: int."""

        # Initialize grids.
        grid = self._init_identity_grid()
        self.grids = nn.Parameter(grid.tile(num, 1, 1, 1, 1))  # (N, 12, L, H, W)
        """ A 5-D tensor of shape $(N, 12, L, H, W)$."""

        # Weights of BT601 RGB-to-gray.
        self.register_buffer("rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]]))
        self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0
        """ A function that converts RGB to gray-scale guidance in $[-1, 1]$."""

    def _init_identity_grid(self):
        # A flattened 3x4 identity affine matrix (row-major).
        grid = torch.tensor(
            [1.0, 0, 0, 0, 0, 1.0, 0, 0, 0, 0, 1.0, 0]
        ).float()
        grid = grid.repeat(
            [self.grid_guidance * self.grid_height * self.grid_width, 1]
        )  # (L * H * W, 12)
        grid = grid.reshape(
            1, self.grid_guidance, self.grid_height, self.grid_width, -1
        )  # (1, L, H, W, 12)
        grid = grid.permute(0, 4, 1, 2, 3)  # (1, 12, L, H, W)
        return grid

    def tv_loss(self):
        """Computes and returns total variation loss on the bilateral grids."""
        return total_variation_loss(self.grids)

    def forward(self, grid_xy, rgb, idx=None):
        """Bilateral grid slicing. Supports 2-D, 3-D, 4-D, and 5-D input.
        For the 2-D, 3-D, and 4-D cases, please refer to `slice`.
        For the 5-D case, `idx` is unused and the first dimension of `grid_xy` should be
        equal to the number of bilateral grids. In that case, this function behaves like PyTorch's
        [`F.grid_sample`](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html).

        Args:
            grid_xy (torch.Tensor): The x-y coordinates in the range of $[0,1]$.
            rgb (torch.Tensor): The RGB values in the range of $[0,1]$.
            idx (torch.Tensor): The bilateral grid indices.

        Returns:
            Sliced affine matrices of shape $(..., 3, 4)$.
        """

        input_ndims = len(grid_xy.shape)
        assert len(rgb.shape) == input_ndims

        if input_ndims > 1 and input_ndims < 5:
            # Convert input into 5D
            for i in range(5 - input_ndims):
                grid_xy = grid_xy.unsqueeze(1)
                rgb = rgb.unsqueeze(1)
            assert idx is not None
        elif input_ndims != 5:
            raise ValueError(
                "Bilateral grid slicing only takes 2D, 3D, 4D, or 5D inputs"
            )

        grids = self.grids
        if idx is not None:
            grids = grids[idx]
        assert grids.shape[0] == grid_xy.shape[0]

        # Generate slicing coordinates.
        grid_xy = (grid_xy - 0.5) * 2  # Rescale to [-1, 1].
        grid_z = self.rgb2gray(rgb)

        grid_xyz = torch.cat([grid_xy, grid_z], dim=-1)  # (N, m, h, w, 3)

        affine_mats = F.grid_sample(
            grids, grid_xyz, mode="bilinear", align_corners=True, padding_mode="border"
        )  # (N, 12, m, h, w)
        affine_mats = affine_mats.permute(0, 2, 3, 4, 1)  # (N, m, h, w, 12)
        affine_mats = affine_mats.reshape(
            *affine_mats.shape[:-1], 3, 4
        )  # (N, m, h, w, 3, 4)

        for _ in range(5 - input_ndims):
            affine_mats = affine_mats.squeeze(1)

        return affine_mats
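
# Sketch of the 5-D forward path (full-image slicing, one coordinate plane per
# grid; shapes are hypothetical):
#
#     N, H, W = 4, 32, 32
#     bil_grids = BilateralGrid(num=N)
#     xy = torch.rand(N, 1, H, W, 2)
#     rgb = torch.rand(N, 1, H, W, 3)
#     affine_mats = bil_grids(xy, rgb)  # (N, 1, H, W, 3, 4); `idx` is unused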


def slice4d(bil_grid4d, xyz, rgb):
    """Slices a 4D bilateral grid by point coordinates `xyz` and gray-scale guidances of radiance colors `rgb`.

    Args:
        bil_grid4d (`BilateralGridCP4D`): The input 4D bilateral grid.
        xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$.
        rgb (torch.Tensor): The RGB values with shape $(..., 3)$.

    Returns:
        A dictionary with keys and values as follows:
        ```
        {
            "rgb": Transformed radiance RGB colors. Shape: (..., 3),
            "rgb_affine_mats": The sliced affine transformation matrices from the 4D bilateral grid. Shape: (..., 3, 4)
        }
        ```
    """

    affine_mats = bil_grid4d(xyz, rgb)
    rgb = color_affine_transform(affine_mats, rgb)

    return {"rgb": rgb, "rgb_affine_mats": affine_mats}


class _ScaledTanh(nn.Module):
    def __init__(self, s=2.0):
        super().__init__()
        self.scaler = s

    def forward(self, x):
        return torch.tanh(self.scaler * x)


class BilateralGridCP4D(nn.Module):
    """Class for low-rank 4D bilateral grids."""

    def __init__(
        self,
        grid_X=16,
        grid_Y=16,
        grid_Z=16,
        grid_W=8,
        rank=5,
        learn_gray=True,
        gray_mlp_width=8,
        gray_mlp_depth=2,
        init_noise_scale=1e-6,
        bound=2.0,
    ):
        """
        Args:
            grid_X (int): Defines grid width.
            grid_Y (int): Defines grid height.
            grid_Z (int): Defines grid depth.
            grid_W (int): Defines grid guidance dimension.
            rank (int): Rank of the 4D bilateral grid.
            learn_gray (bool): If True, an MLP will be learned to convert RGB colors to gray-scale guidances.
            gray_mlp_width (int): The MLP width for learnable guidance.
            gray_mlp_depth (int): The number of MLP layers for learnable guidance.
            init_noise_scale (float): The noise scale of the initialized factors.
            bound (float): The bound of the xyz coordinates.
        """
        super(BilateralGridCP4D, self).__init__()

        self.grid_X = grid_X
        """Grid width. Type: int."""
        self.grid_Y = grid_Y
        """Grid height. Type: int."""
        self.grid_Z = grid_Z
        """Grid depth. Type: int."""
        self.grid_W = grid_W
        """Grid guidance dimension. Type: int."""
        self.rank = rank
        """Rank of the 4D bilateral grid. Type: int."""
        self.learn_gray = learn_gray
        """Flags of learnable guidance is used. Type: bool."""
        self.gray_mlp_width = gray_mlp_width
        """The MLP width for learnable guidance. Type: int."""
        self.gray_mlp_depth = gray_mlp_depth
        """The MLP depth for learnable guidance. Type: int."""
        self.init_noise_scale = init_noise_scale
        """The noise scale of the initialized factors. Type: float."""
        self.bound = bound
        """The bound of the xyz coordinates. Type: float."""

        self._init_cp_factors_parafac()

        self.rgb2gray = None
        """ A function that converts RGB to gray-scale guidances in $[-1, 1]$.
        If `learn_gray` is True, this will be an MLP network."""

        if self.learn_gray:

            def rgb2gray_mlp_linear(layer):
                return nn.Linear(
                    self.gray_mlp_width,
                    self.gray_mlp_width if layer < self.gray_mlp_depth - 1 else 1,
                )

            def rgb2gray_mlp_actfn(_):
                return nn.ReLU(inplace=True)

            self.rgb2gray = nn.Sequential(
                *(
                    [nn.Linear(3, self.gray_mlp_width)]
                    + [
                        nn_module(layer)
                        for layer in range(1, self.gray_mlp_depth)
                        for nn_module in [rgb2gray_mlp_actfn, rgb2gray_mlp_linear]
                    ]
                    + [_ScaledTanh(2.0)]
                )
            )
        else:
            # Weights of BT601/BT470 RGB-to-gray.
            self.register_buffer(
                "rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]])
            )
            self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0

    def _init_identity_grid(self):
        # A flattened 3x4 identity affine matrix (row-major).
        grid = torch.tensor(
            [1.0, 0, 0, 0, 0, 1.0, 0, 0, 0, 0, 1.0, 0]
        ).float()
        grid = grid.repeat([self.grid_W * self.grid_Z * self.grid_Y * self.grid_X, 1])
        grid = grid.reshape(self.grid_W, self.grid_Z, self.grid_Y, self.grid_X, -1)
        grid = grid.permute(4, 0, 1, 2, 3)  # (12, grid_W, grid_Z, grid_Y, grid_X)
        return grid

    def _init_cp_factors_parafac(self):
        # Initialize identity grids.
        init_grids = self._init_identity_grid()
        # Random noise is added to avoid singularity.
        init_grids = torch.randn_like(init_grids) * self.init_noise_scale + init_grids
        from tensorly.decomposition import parafac

        # Initialize grid CP factors
        _, facs = parafac(init_grids.clone().detach(), rank=self.rank)

        self.num_facs = len(facs)

        self.fac_0 = nn.Linear(facs[0].shape[0], facs[0].shape[1], bias=False)
        self.fac_0.weight = nn.Parameter(facs[0])  # (12, rank)

        for i in range(1, self.num_facs):
            fac = facs[i].T  # (rank, grid_size)
            fac = fac.view(1, fac.shape[0], fac.shape[1], 1)  # (1, rank, grid_size, 1)
            self.register_buffer(f"fac_{i}_init", fac)

            fac_resid = torch.zeros_like(fac)
            self.register_parameter(f"fac_{i}", nn.Parameter(fac_resid))

    def tv_loss(self):
        """Computes and returns total variation loss on the factors of the low-rank 4D bilateral grids."""

        total_loss = 0
        for i in range(1, self.num_facs):
            fac = self.get_parameter(f"fac_{i}")
            total_loss += total_variation_loss(fac)

        return total_loss

    def forward(self, xyz, rgb):
        """Low-rank 4D bilateral grid slicing.

        Args:
            xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$.
            rgb (torch.Tensor): The corresponding RGB values with shape $(..., 3)$.

        Returns:
            Sliced affine matrices with shape $(..., 3, 4)$.
        """
        sh_ = xyz.shape
        xyz = xyz.reshape(-1, 3)  # flatten (N, 3)
        rgb = rgb.reshape(-1, 3)  # flatten (N, 3)

        xyz = xyz / self.bound
        assert self.rgb2gray is not None
        gray = self.rgb2gray(rgb)
        xyzw = torch.cat([xyz, gray], dim=-1)  # (N, 4)
        xyzw = xyzw.transpose(0, 1)  # (4, N)
        coords = torch.stack([torch.zeros_like(xyzw), xyzw], dim=-1)  # (4, N, 2)
        coords = coords.unsqueeze(1)  # (4, 1, N, 2)

        coef = 1.0
        for i in range(1, self.num_facs):
            fac = self.get_parameter(f"fac_{i}") + self.get_buffer(f"fac_{i}_init")
            coef = coef * F.grid_sample(
                fac, coords[[i - 1]], align_corners=True, padding_mode="border"
            )  # [1, rank, 1, N]
        coef = coef.squeeze([0, 2]).transpose(0, 1)  # (N, rank)  # type: ignore
        mat = self.fac_0(coef)
        return mat.reshape(*sh_[:-1], 3, 4)
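

if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the original
    # module): slice a 3D grid and a low-rank 4D grid on random inputs and
    # check the output shapes.
    bil_grids = BilateralGrid(num=2)
    xy = torch.rand(128, 2)
    rgb = torch.rand(128, 3)
    grid_idx = torch.zeros(128, 1, dtype=torch.long)
    out3d = slice(bil_grids, xy, rgb, grid_idx)
    print(out3d["rgb"].shape, out3d["rgb_affine_mats"].shape)

    bil_grid4d = BilateralGridCP4D()
    xyz = torch.rand(128, 3) * 4.0 - 2.0
    out4d = slice4d(bil_grid4d, xyz, rgb)
    print(out4d["rgb"].shape, out4d["rgb_affine_mats"].shape)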