import os from numpy.lib.function_base import interp import torch import torch.nn as nn if not os.path.exists('README.md'): os.chdir('..') from processing.pipeline_numpy import processing as default_processing from utils.base import np2torch, torch2np import segmentation_models_pytorch as smp from utils.debug import debug K_G = torch.Tensor([[0, 1, 0], [1, 4, 1], [0, 1, 0]]) / 4 K_RB = torch.Tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]]) / 4 M_RGB_2_YUV = torch.Tensor([[0.299, 0.587, 0.114], [-0.14714119, -0.28886916, 0.43601035], [0.61497538, -0.51496512, -0.10001026]]) M_YUV_2_RGB = torch.Tensor([[1.0000000000e+00, -4.1827794561e-09, 1.1398830414e+00], [1.0000000000e+00, -3.9464232326e-01, -5.8062183857e-01], [1.0000000000e+00, 2.0320618153e+00, -1.2232658220e-09]]) K_BLUR = torch.Tensor([[6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08], [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05], [2.0755e-04, 8.3731e-02, 6.1869e-01, 8.3731e-02, 2.0755e-04], [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05], [6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08]]) K_SHARP = torch.Tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]) DEFAULT_CAMERA_PARAMS = ( [0., 0., 0., 0.], [1., 1., 1.], [1., 0., 0., 0., 1., 0., 0., 0., 1.], ) class RawToRGB(nn.Module): def __init__(self, reduce_size=True, out_channels=3, track_stages=False, normalize_mosaic=None): super().__init__() self.stages = None self.buffer = None self.reduce_size = reduce_size self.out_channels = out_channels self.track_stages = track_stages self.normalize_mosaic = normalize_mosaic def forward(self, raw): self.stages = {} self.buffer = {} rgb = raw2rgb(raw, reduce_size=self.reduce_size, out_channels=self.out_channels) self.stages['demosaic'] = rgb if self.normalize_mosaic: rgb = self.normalize_mosaic(rgb) if self.track_stages and raw.requires_grad: for stage in self.stages.values(): stage.retain_grad() self.buffer['processed_rgb'] = rgb return rgb class NNProcessing(nn.Module): def __init__(self, track_stages=False, normalize_mosaic=None, batch_norm_output=True): super().__init__() self.stages = None self.buffer = None self.track_stages = track_stages self.model = smp.UnetPlusPlus( encoder_name='resnet34', encoder_depth=3, decoder_channels=[256, 128, 64], in_channels=3, classes=3, ) self.batch_norm = None if not batch_norm_output else nn.BatchNorm2d(3, affine=False) self.normalize_mosaic = normalize_mosaic def forward(self, raw): self.stages = {} self.buffer = {} # self.stages['raw'] = raw rgb = raw2rgb(raw) if self.normalize_mosaic: rgb = self.normalize_mosaic(rgb) self.stages['demosaic'] = rgb rgb = self.model(rgb) if self.batch_norm is not None: rgb = self.batch_norm(rgb) self.stages['rgb'] = rgb if self.track_stages and raw.requires_grad: for stage in self.stages.values(): stage.retain_grad() self.buffer['processed_rgb'] = rgb return rgb def add_additive_layer(processor): processor.additive_layer = nn.Parameter(torch.zeros((1, 3, 256, 256))) # processor.additive_layer = nn.Parameter(0.001 * torch.randn((1, 3, 256, 256))) class ParametrizedProcessing(nn.Module): def __init__(self, camera_parameters, track_stages=False, batch_norm_output=True): super().__init__() self.stages = None self.buffer = None self.track_stages = track_stages black_level, white_balance, colour_matrix = camera_parameters self.black_level = nn.Parameter(torch.as_tensor(black_level)) self.white_balance = nn.Parameter(torch.as_tensor(white_balance).reshape(1, 3)) self.colour_correction = nn.Parameter(torch.as_tensor(colour_matrix).reshape(3, 3)) self.gamma_correct = nn.Parameter(torch.Tensor([2.2])) self.debayer = Debayer() self.sharpening_filter = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False) self.sharpening_filter.weight.data[0][0] = K_SHARP.clone() self.gaussian_blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, padding_mode='reflect', bias=False) self.gaussian_blur.weight.data[0][0] = K_BLUR.clone() self.batch_norm = nn.BatchNorm2d(3, affine=False) if batch_norm_output else None self.register_buffer('M_RGB_2_YUV', M_RGB_2_YUV.clone()) self.register_buffer('M_YUV_2_RGB', M_YUV_2_RGB.clone()) self.additive_layer = None # this can be added in later def forward(self, raw): assert raw.ndim == 3, f"needs dims (B, H, W), got {raw.shape}" self.stages = {} self.buffer = {} # self.stages['raw'] = raw rgb = raw2rgb(raw, black_level=self.black_level, reduce_size=False) rgb = rgb.contiguous() self.stages['demosaic'] = rgb rgb = self.debayer(rgb) # self.stages['debayer'] = rgb rgb = torch.einsum('bchw,kc->bchw', rgb, self.white_balance).contiguous() rgb = torch.einsum('bchw,kc->bkhw', rgb, self.colour_correction).contiguous() self.stages['color_correct'] = rgb yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous() yuv[:, [0], ...] = self.sharpening_filter(yuv[:, [0], ...]) if self.track_stages: # keep stage in computational graph for grad information rgb = torch.einsum('bchw,kc->bkhw', yuv.clone(), self.M_YUV_2_RGB).contiguous() self.stages['sharpening'] = rgb yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous() yuv[:, [0], ...] = self.gaussian_blur(yuv[:, [0], ...]) rgb = torch.einsum('bchw,kc->bkhw', yuv, self.M_YUV_2_RGB).contiguous() self.stages['gaussian'] = rgb rgb = torch.clip(rgb, 1e-5, 1) self.stages['clipped'] = rgb rgb = torch.exp((1 / self.gamma_correct) * torch.log(rgb)) self.stages['gamma_correct'] = rgb if self.additive_layer is not None: rgb = rgb + self.additive_layer self.stages['noise'] = rgb if self.batch_norm is not None: rgb = self.batch_norm(rgb) if self.track_stages and raw.requires_grad: for stage in self.stages.values(): stage.retain_grad() self.buffer['processed_rgb'] = rgb return rgb class Debayer(nn.Conv2d): def __init__(self): super().__init__(3, 3, kernel_size=3, padding=1, padding_mode='reflect', bias=False) # default_pipeline uses 'replicate' self.weight.data.fill_(0) self.weight.data[0, 0] = K_RB.clone() self.weight.data[1, 1] = K_G.clone() self.weight.data[2, 2] = K_RB.clone() def raw2rgb(raw, black_level=None, reduce_size=True, out_channels=3): """transform raw image with 1 channel to rgb with 3 channels Args: raw (Tensor): raw Tensor of shape (B, H, W) black_level (iterable, optional): RGGB black level values to subtract reduce_size (bool, optional): if False, the output image will have the same height and width as the raw input, i.e. (B, C, H, W), empty values are filled with zeros. if True, the output dimensions are reduced by half (B, C, H//2, W//2), the two green channels are averaged. out_channels (int, optional): number of output channels. One of {3, 4}. """ assert out_channels in [3, 4] if black_level is None: black_level = [0, 0, 0, 0] Bch, H, W = raw.shape R = raw[:, 0::2, 0::2] - black_level[0] # R G1 = raw[:, 0::2, 1::2] - black_level[1] # G G2 = raw[:, 1::2, 0::2] - black_level[2] # G B = raw[:, 1::2, 1::2] - black_level[3] # B if reduce_size: rgb = torch.zeros((Bch, out_channels, H // 2, W // 2), device=raw.device) if out_channels == 3: rgb[:, 0, :, :] = R rgb[:, 1, :, :] = (G1 + G2) / 2 rgb[:, 2, :, :] = B elif out_channels == 4: rgb[:, 0, :, :] = R rgb[:, 1, :, :] = G1 rgb[:, 2, :, :] = G2 rgb[:, 3, :, :] = B else: rgb = torch.zeros((Bch, out_channels, H, W), device=raw.device) if out_channels == 3: rgb[:, 0, 0::2, 0::2] = R rgb[:, 1, 0::2, 1::2] = G1 rgb[:, 1, 1::2, 0::2] = G2 rgb[:, 2, 1::2, 1::2] = B elif out_channels == 4: rgb[:, 0, 0::2, 0::2] = R rgb[:, 1, 0::2, 1::2] = G1 rgb[:, 2, 1::2, 0::2] = G2 rgb[:, 3, 1::2, 1::2] = B return rgb # pipeline validation if __name__ == "__main__": import torch import numpy as np if not os.path.exists('README.md'): os.chdir('..') import matplotlib.pyplot as plt from dataset import get_dataset from utils.base import np2torch, torch2np from utils.debug import debug from processing.pipeline_numpy import processing as default_processing raw_dataset = get_dataset('DS') loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1) batch_raw, batch_mask = next(iter(loader)) # torch proc camera_parameters = raw_dataset.camera_parameters black_level = camera_parameters[0] proc = ParametrizedProcessing(camera_parameters) batch_rgb = proc(batch_raw) rgb = batch_rgb[0] # numpy proc raw_img = batch_raw[0] numpy_raw = torch2np(raw_img) default_rgb = default_processing(numpy_raw, *camera_parameters, sharpening='sharpening_filter', denoising='gaussian_denoising') rgb_valid = np2torch(default_rgb) print("pipeline norm difference:", (rgb - rgb_valid).norm().item()) rgb_mosaic = raw2rgb(batch_raw, reduce_size=False).squeeze() rgb_reduced = raw2rgb(batch_raw, reduce_size=True).squeeze() plt.figure(figsize=(16, 8)) plt.subplot(151) plt.title('Raw') plt.imshow(torch2np(raw_img)) plt.subplot(152) plt.title('RGB Mosaic') plt.imshow(torch2np(rgb_mosaic)) plt.subplot(153) plt.title('RGB Reduced') plt.imshow(torch2np(rgb_reduced)) plt.subplot(154) plt.title('Torch Pipeline') plt.imshow(torch2np(rgb)) plt.subplot(155) plt.title('Default Pipeline') plt.imshow(torch2np(rgb_valid)) plt.show() # assert rgb.allclose(rgb_valid)