gokaygokay commited on
Commit
9ad295d
Β·
1 Parent(s): c63008c
Files changed (3) hide show
  1. app.py +2 -0
  2. histogram_processor.py +129 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -4,6 +4,7 @@ from lut_processor import create_lut_tab
4
  from sharpen_processor import create_sharpen_tab
5
  from color_match_processor import create_color_match_tab
6
  from simple_effects_processor import create_effects_tab
 
7
 
8
  with gr.Blocks(title="Image Processing Suite") as demo:
9
  gr.Markdown("# Image Processing Suite")
@@ -13,6 +14,7 @@ with gr.Blocks(title="Image Processing Suite") as demo:
13
  create_sharpen_tab()
14
  create_color_match_tab()
15
  create_effects_tab()
 
16
 
17
  if __name__ == "__main__":
18
  demo.launch(share=True)
 
4
  from sharpen_processor import create_sharpen_tab
5
  from color_match_processor import create_color_match_tab
6
  from simple_effects_processor import create_effects_tab
7
+ from histogram_processor import create_histogram_tab
8
 
9
  with gr.Blocks(title="Image Processing Suite") as demo:
10
  gr.Markdown("# Image Processing Suite")
 
14
  create_sharpen_tab()
15
  create_color_match_tab()
16
  create_effects_tab()
17
+ create_histogram_tab()
18
 
19
  if __name__ == "__main__":
20
  demo.launch(share=True)
histogram_processor.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from skimage.exposure import match_histograms
7
+
8
+ class HistogramMatcher(nn.Module):
9
+ def __init__(self, differentiable=False):
10
+ super(HistogramMatcher, self).__init__()
11
+ self.differentiable = differentiable
12
+
13
+ def forward(self, dst, ref):
14
+ B, C, H, W = dst.size()
15
+ hist_dst = self.cal_hist(dst)
16
+ hist_ref = self.cal_hist(ref)
17
+ tables = self.cal_trans_batch(hist_dst, hist_ref)
18
+
19
+ rst = dst.clone()
20
+ for b in range(B):
21
+ for c in range(C):
22
+ rst[b,c] = tables[b*c, (dst[b,c] * 255).long()]
23
+
24
+ return rst / 255.
25
+
26
+ def cal_hist(self, img):
27
+ B, C, H, W = img.size()
28
+ if self.differentiable:
29
+ hists = self.soft_histc_batch(img * 255, bins=256, min=0, max=256, sigma=75)
30
+ else:
31
+ hists = torch.stack([torch.histc(img[b,c] * 255, bins=256, min=0, max=255)
32
+ for b in range(B) for c in range(C)])
33
+
34
+ hists = hists.float()
35
+ hists = F.normalize(hists, p=1)
36
+ bc, n = hists.size()
37
+ triu = torch.ones(bc, n, n, device=hists.device).triu()
38
+ hists = torch.bmm(hists[:,None,:], triu)[:,0,:]
39
+ return hists
40
+
41
+ def soft_histc_batch(self, x, bins=256, min=0, max=256, sigma=75):
42
+ B, C, H, W = x.size()
43
+ x = x.view(B*C, -1)
44
+ delta = float(max - min) / float(bins)
45
+ centers = float(min) + delta * (torch.arange(bins, device=x.device) + 0.5)
46
+
47
+ x = torch.unsqueeze(x, 1)
48
+ centers = centers[None,:,None]
49
+ x = x - centers
50
+ x = torch.sigmoid(sigma * (x + delta/2)) - torch.sigmoid(sigma * (x - delta/2))
51
+ x = x.sum(dim=2)
52
+ return x
53
+
54
+ def cal_trans_batch(self, hist_dst, hist_ref):
55
+ hist_dst = hist_dst[:,None,:].repeat(1,256,1)
56
+ hist_ref = hist_ref[:,:,None].repeat(1,1,256)
57
+ table = hist_dst - hist_ref
58
+ table = torch.where(table>=0, 1., 0.)
59
+ table = torch.sum(table, dim=1) - 1
60
+ table = torch.clamp(table, min=0, max=255)
61
+ return table
62
+
63
+ def apply_histogram_matching(image, reference, method, factor):
64
+ if image is None or reference is None:
65
+ return None
66
+
67
+ # Convert to torch tensors and normalize
68
+ image = torch.from_numpy(image).float() / 255.0
69
+ reference = torch.from_numpy(reference).float() / 255.0
70
+
71
+ # Add batch dimension and rearrange to BCHW
72
+ image = image.unsqueeze(0).permute(0, 3, 1, 2)
73
+ reference = reference.unsqueeze(0).permute(0, 3, 1, 2)
74
+
75
+ if method == "pytorch":
76
+ # Apply PyTorch-based histogram matching
77
+ matcher = HistogramMatcher(differentiable=True)
78
+ matched = matcher(image, reference)
79
+ else: # skimage
80
+ # Convert back to numpy for skimage
81
+ matched = match_histograms(
82
+ image.permute(0, 2, 3, 1).numpy(),
83
+ reference.permute(0, 2, 3, 1).numpy(),
84
+ channel_axis=3
85
+ )
86
+ matched = torch.from_numpy(matched).permute(0, 3, 1, 2)
87
+
88
+ # Apply factor blending
89
+ result = factor * matched + (1 - factor) * image
90
+
91
+ # Convert back to HWC format and to uint8
92
+ output = result.squeeze(0).permute(1, 2, 0)
93
+ output = (output.clamp(0, 1).numpy() * 255).astype(np.uint8)
94
+
95
+ return output
96
+
97
+ def create_histogram_tab():
98
+ with gr.Tab("Histogram Matching"):
99
+ gr.Markdown("Match histograms between images using PyTorch or scikit-image methods")
100
+
101
+ with gr.Row():
102
+ with gr.Column():
103
+ input_image = gr.Image(label="Input Image")
104
+ reference_image = gr.Image(label="Reference Image")
105
+
106
+ method = gr.Radio(
107
+ choices=["pytorch", "skimage"],
108
+ value="pytorch",
109
+ label="Matching Method"
110
+ )
111
+
112
+ factor = gr.Slider(
113
+ minimum=0.0,
114
+ maximum=1.0,
115
+ value=1.0,
116
+ step=0.05,
117
+ label="Blend Factor"
118
+ )
119
+
120
+ match_btn = gr.Button("Apply Histogram Matching")
121
+
122
+ with gr.Column():
123
+ output_image = gr.Image(label="Matched Image")
124
+
125
+ match_btn.click(
126
+ fn=apply_histogram_matching,
127
+ inputs=[input_image, reference_image, method, factor],
128
+ outputs=output_image
129
+ )
requirements.txt CHANGED
@@ -3,4 +3,5 @@ torch
3
  torchvision
4
  numpy==1.26.4
5
  colour-science
6
- kornia
 
 
3
  torchvision
4
  numpy==1.26.4
5
  colour-science
6
+ kornia
7
+ scikit-image