# raw2logit/processing/pipeline_torch.py
import os

import torch
import torch.nn as nn

if not os.path.exists('README.md'):
    os.chdir('..')

from processing.pipeline_numpy import processing as default_processing
from utils.base import np2torch, torch2np

import segmentation_models_pytorch as smp

from utils.debug import debug

# bilinear demosaicing kernels for the green and the red/blue planes
K_G = torch.Tensor([[0, 1, 0],
                    [1, 4, 1],
                    [0, 1, 0]]) / 4
K_RB = torch.Tensor([[1, 2, 1],
                     [2, 4, 2],
                     [1, 2, 1]]) / 4

# RGB <-> YUV colour conversion matrices
M_RGB_2_YUV = torch.Tensor([[0.299, 0.587, 0.114],
                            [-0.14714119, -0.28886916, 0.43601035],
                            [0.61497538, -0.51496512, -0.10001026]])
M_YUV_2_RGB = torch.Tensor([[1.0000000000e+00, -4.1827794561e-09, 1.1398830414e+00],
                            [1.0000000000e+00, -3.9464232326e-01, -5.8062183857e-01],
                            [1.0000000000e+00, 2.0320618153e+00, -1.2232658220e-09]])
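
# Sanity check (illustrative): the two matrices are numerical inverses, so an
# RGB -> YUV -> RGB round trip is the identity up to float error:
#
#   assert torch.allclose(M_YUV_2_RGB @ M_RGB_2_YUV, torch.eye(3), atol=1e-6)
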
# 5x5 Gaussian blur kernel and 3x3 sharpening kernel
K_BLUR = torch.Tensor([[6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08],
                       [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
                       [2.0755e-04, 8.3731e-02, 6.1869e-01, 8.3731e-02, 2.0755e-04],
                       [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
                       [6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08]])
K_SHARP = torch.Tensor([[0, -1, 0],
                        [-1, 5, -1],
                        [0, -1, 0]])

# (black_level, white_balance, colour_matrix)
DEFAULT_CAMERA_PARAMS = (
    [0., 0., 0., 0.],
    [1., 1., 1.],
    [1., 0., 0., 0., 1., 0., 0., 0., 1.],
)

class RawToRGB(nn.Module):
    """Minimal pipeline: demosaic the raw mosaic and optionally normalize it."""

    def __init__(self, reduce_size=True, out_channels=3, track_stages=False, normalize_mosaic=None):
        super().__init__()
        self.stages = None
        self.buffer = None
        self.reduce_size = reduce_size
        self.out_channels = out_channels
        self.track_stages = track_stages
        self.normalize_mosaic = normalize_mosaic

    def forward(self, raw):
        self.stages = {}
        self.buffer = {}

        rgb = raw2rgb(raw, reduce_size=self.reduce_size, out_channels=self.out_channels)
        self.stages['demosaic'] = rgb

        if self.normalize_mosaic:
            rgb = self.normalize_mosaic(rgb)

        if self.track_stages and raw.requires_grad:
            for stage in self.stages.values():
                stage.retain_grad()

        self.buffer['processed_rgb'] = rgb
        return rgb
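
# Usage sketch (illustrative; assumes a raw Bayer batch of shape (B, H, W)
# with values in [0, 1]):
#
#   proc = RawToRGB(reduce_size=True, out_channels=3)
#   rgb = proc(torch.rand(4, 256, 256))  # -> shape (4, 3, 128, 128)
#   proc.stages['demosaic']              # demosaiced intermediate
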
class NNProcessing(nn.Module):
    """Learned pipeline: demosaic, then map to RGB with a U-Net++ model."""

    def __init__(self, track_stages=False, normalize_mosaic=None, batch_norm_output=True):
        super().__init__()
        self.stages = None
        self.buffer = None
        self.track_stages = track_stages
        self.model = smp.UnetPlusPlus(
            encoder_name='resnet34',
            encoder_depth=3,
            decoder_channels=[256, 128, 64],
            in_channels=3,
            classes=3,
        )
        self.batch_norm = nn.BatchNorm2d(3, affine=False) if batch_norm_output else None
        self.normalize_mosaic = normalize_mosaic

    def forward(self, raw):
        self.stages = {}
        self.buffer = {}
        # self.stages['raw'] = raw

        rgb = raw2rgb(raw)
        if self.normalize_mosaic:
            rgb = self.normalize_mosaic(rgb)
        self.stages['demosaic'] = rgb

        rgb = self.model(rgb)
        if self.batch_norm is not None:
            rgb = self.batch_norm(rgb)
        self.stages['rgb'] = rgb

        if self.track_stages and raw.requires_grad:
            for stage in self.stages.values():
                stage.retain_grad()

        self.buffer['processed_rgb'] = rgb
        return rgb
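
# Usage sketch (illustrative; the U-Net++ starts with untrained weights, so
# the output is only meaningful after fitting):
#
#   proc = NNProcessing()
#   rgb = proc(torch.rand(1, 256, 256))  # demosaiced to (1, 3, 128, 128), then refined
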
def add_additive_layer(processor):
    """Attach a learnable additive perturbation, applied after gamma correction."""
    processor.additive_layer = nn.Parameter(torch.zeros((1, 3, 256, 256)))
    # processor.additive_layer = nn.Parameter(0.001 * torch.randn((1, 3, 256, 256)))
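
# Usage sketch (illustrative): attach the perturbation before fine-tuning so
# it is registered as a learnable parameter:
#
#   proc = ParametrizedProcessing(DEFAULT_CAMERA_PARAMS)
#   add_additive_layer(proc)
#   any(p is proc.additive_layer for p in proc.parameters())  # True
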
class ParametrizedProcessing(nn.Module):
    """Differentiable counterpart of the static numpy pipeline: every
    processing step (black level, white balance, colour correction,
    sharpening, denoising, gamma) is parametrized and learnable."""

    def __init__(self, camera_parameters, track_stages=False, batch_norm_output=True):
        super().__init__()
        self.stages = None
        self.buffer = None
        self.track_stages = track_stages

        black_level, white_balance, colour_matrix = camera_parameters
        self.black_level = nn.Parameter(torch.as_tensor(black_level))
        self.white_balance = nn.Parameter(torch.as_tensor(white_balance).reshape(1, 3))
        self.colour_correction = nn.Parameter(torch.as_tensor(colour_matrix).reshape(3, 3))
        self.gamma_correct = nn.Parameter(torch.Tensor([2.2]))

        self.debayer = Debayer()

        self.sharpening_filter = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.sharpening_filter.weight.data[0][0] = K_SHARP.clone()

        self.gaussian_blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, padding_mode='reflect', bias=False)
        self.gaussian_blur.weight.data[0][0] = K_BLUR.clone()

        self.batch_norm = nn.BatchNorm2d(3, affine=False) if batch_norm_output else None

        self.register_buffer('M_RGB_2_YUV', M_RGB_2_YUV.clone())
        self.register_buffer('M_YUV_2_RGB', M_YUV_2_RGB.clone())

        self.additive_layer = None  # can be attached later (see add_additive_layer)

    def forward(self, raw):
        assert raw.ndim == 3, f"needs dims (B, H, W), got {raw.shape}"

        self.stages = {}
        self.buffer = {}
        # self.stages['raw'] = raw

        rgb = raw2rgb(raw, black_level=self.black_level, reduce_size=False)
        rgb = rgb.contiguous()
        self.stages['demosaic'] = rgb

        rgb = self.debayer(rgb)
        # self.stages['debayer'] = rgb

        # scale each channel by its white-balance gain, then mix channels with
        # the colour-correction matrix
        rgb = torch.einsum('bchw,kc->bchw', rgb, self.white_balance).contiguous()
        rgb = torch.einsum('bchw,kc->bkhw', rgb, self.colour_correction).contiguous()
        self.stages['color_correct'] = rgb

        # sharpen only the luminance channel in YUV space
        yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()
        yuv[:, [0], ...] = self.sharpening_filter(yuv[:, [0], ...])
        if self.track_stages:  # keep stage in computational graph for grad information
            rgb = torch.einsum('bchw,kc->bkhw', yuv.clone(), self.M_YUV_2_RGB).contiguous()
            self.stages['sharpening'] = rgb
            yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()

        # blur the luminance channel, then convert back to RGB
        yuv[:, [0], ...] = self.gaussian_blur(yuv[:, [0], ...])
        rgb = torch.einsum('bchw,kc->bkhw', yuv, self.M_YUV_2_RGB).contiguous()
        self.stages['gaussian'] = rgb

        # clip away from zero so the log in the gamma step stays finite
        rgb = torch.clip(rgb, 1e-5, 1)
        self.stages['clipped'] = rgb

        # gamma correction, x ** (1 / gamma), written via exp/log
        rgb = torch.exp((1 / self.gamma_correct) * torch.log(rgb))
        self.stages['gamma_correct'] = rgb

        if self.additive_layer is not None:
            rgb = rgb + self.additive_layer
            self.stages['noise'] = rgb

        if self.batch_norm is not None:
            rgb = self.batch_norm(rgb)

        if self.track_stages and raw.requires_grad:
            for stage in self.stages.values():
                stage.retain_grad()

        self.buffer['processed_rgb'] = rgb
        return rgb
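
# Usage sketch (illustrative; DEFAULT_CAMERA_PARAMS stands in for a real
# camera's calibrated black level, white balance and colour matrix):
#
#   proc = ParametrizedProcessing(DEFAULT_CAMERA_PARAMS, track_stages=True)
#   rgb = proc(torch.rand(1, 256, 256))  # -> (1, 3, 256, 256)
#   list(proc.stages)  # ['demosaic', 'color_correct', 'sharpening', ...]
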
class Debayer(nn.Conv2d):
    """Bilinear demosaicing as a 3x3 convolution with fixed initialisation."""

    def __init__(self):
        super().__init__(3, 3, kernel_size=3, padding=1, padding_mode='reflect', bias=False)  # default_pipeline uses 'replicate'
        self.weight.data.fill_(0)
        self.weight.data[0, 0] = K_RB.clone()
        self.weight.data[1, 1] = K_G.clone()
        self.weight.data[2, 2] = K_RB.clone()
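
# Illustrative sketch: applied to the zero-filled mosaic produced by
# raw2rgb(..., reduce_size=False), the fixed kernels fill in each colour
# plane's missing pixels from their neighbours:
#
#   dense = Debayer()(raw2rgb(torch.rand(1, 8, 8), reduce_size=False))  # (1, 3, 8, 8)
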
def raw2rgb(raw, black_level=None, reduce_size=True, out_channels=3):
    """Transform a raw image with 1 channel to RGB with 3 (or 4) channels.

    Args:
        raw (Tensor): raw Tensor of shape (B, H, W)
        black_level (iterable, optional): RGGB black level values to subtract
        reduce_size (bool, optional): if False, the output image has the same
            height and width as the raw input, i.e. (B, C, H, W); empty values
            are filled with zeros. If True, the output dimensions are reduced
            by half, (B, C, H // 2, W // 2), and the two green channels are
            averaged (for out_channels=3).
        out_channels (int, optional): number of output channels. One of {3, 4}.
    """
    assert out_channels in [3, 4]
    if black_level is None:
        black_level = [0, 0, 0, 0]

    Bch, H, W = raw.shape

    # slice the RGGB Bayer pattern into its four colour planes
    R = raw[:, 0::2, 0::2] - black_level[0]
    G1 = raw[:, 0::2, 1::2] - black_level[1]
    G2 = raw[:, 1::2, 0::2] - black_level[2]
    B = raw[:, 1::2, 1::2] - black_level[3]

    if reduce_size:
        rgb = torch.zeros((Bch, out_channels, H // 2, W // 2), device=raw.device)
        if out_channels == 3:
            rgb[:, 0, :, :] = R
            rgb[:, 1, :, :] = (G1 + G2) / 2
            rgb[:, 2, :, :] = B
        elif out_channels == 4:
            rgb[:, 0, :, :] = R
            rgb[:, 1, :, :] = G1
            rgb[:, 2, :, :] = G2
            rgb[:, 3, :, :] = B
    else:
        rgb = torch.zeros((Bch, out_channels, H, W), device=raw.device)
        if out_channels == 3:
            rgb[:, 0, 0::2, 0::2] = R
            rgb[:, 1, 0::2, 1::2] = G1
            rgb[:, 1, 1::2, 0::2] = G2
            rgb[:, 2, 1::2, 1::2] = B
        elif out_channels == 4:
            rgb[:, 0, 0::2, 0::2] = R
            rgb[:, 1, 0::2, 1::2] = G1
            rgb[:, 2, 1::2, 0::2] = G2
            rgb[:, 3, 1::2, 1::2] = B
    return rgb
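
# Example (illustrative):
#
#   raw = torch.rand(2, 128, 128)
#   raw2rgb(raw).shape                     # (2, 3, 64, 64), greens averaged
#   raw2rgb(raw, reduce_size=False).shape  # (2, 3, 128, 128), zero-filled mosaic
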
# pipeline validation
if __name__ == "__main__":

    import matplotlib.pyplot as plt
    from dataset import get_dataset

    raw_dataset = get_dataset('DS')
    loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1)
    batch_raw, batch_mask = next(iter(loader))

    # torch pipeline
    camera_parameters = raw_dataset.camera_parameters
    black_level = camera_parameters[0]
    proc = ParametrizedProcessing(camera_parameters)
    batch_rgb = proc(batch_raw)
    rgb = batch_rgb[0]

    # numpy reference pipeline
    raw_img = batch_raw[0]
    numpy_raw = torch2np(raw_img)
    default_rgb = default_processing(numpy_raw, *camera_parameters,
                                     sharpening='sharpening_filter', denoising='gaussian_denoising')
    rgb_valid = np2torch(default_rgb)

    print("pipeline norm difference:", (rgb - rgb_valid).norm().item())

    rgb_mosaic = raw2rgb(batch_raw, reduce_size=False).squeeze()
    rgb_reduced = raw2rgb(batch_raw, reduce_size=True).squeeze()

    plt.figure(figsize=(16, 8))
    plt.subplot(151)
    plt.title('Raw')
    plt.imshow(torch2np(raw_img))
    plt.subplot(152)
    plt.title('RGB Mosaic')
    plt.imshow(torch2np(rgb_mosaic))
    plt.subplot(153)
    plt.title('RGB Reduced')
    plt.imshow(torch2np(rgb_reduced))
    plt.subplot(154)
    plt.title('Torch Pipeline')
    plt.imshow(torch2np(rgb))
    plt.subplot(155)
    plt.title('Default Pipeline')
    plt.imshow(torch2np(rgb_valid))
    plt.show()

    # assert rgb.allclose(rgb_valid)