Spaces:
Runtime error
Runtime error
Commit
Β·
9ad295d
1
Parent(s):
c63008c
add tab
Browse files- app.py +2 -0
- histogram_processor.py +129 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -4,6 +4,7 @@ from lut_processor import create_lut_tab
|
|
4 |
from sharpen_processor import create_sharpen_tab
|
5 |
from color_match_processor import create_color_match_tab
|
6 |
from simple_effects_processor import create_effects_tab
|
|
|
7 |
|
8 |
with gr.Blocks(title="Image Processing Suite") as demo:
|
9 |
gr.Markdown("# Image Processing Suite")
|
@@ -13,6 +14,7 @@ with gr.Blocks(title="Image Processing Suite") as demo:
|
|
13 |
create_sharpen_tab()
|
14 |
create_color_match_tab()
|
15 |
create_effects_tab()
|
|
|
16 |
|
17 |
if __name__ == "__main__":
|
18 |
demo.launch(share=True)
|
|
|
4 |
from sharpen_processor import create_sharpen_tab
|
5 |
from color_match_processor import create_color_match_tab
|
6 |
from simple_effects_processor import create_effects_tab
|
7 |
+
from histogram_processor import create_histogram_tab
|
8 |
|
9 |
with gr.Blocks(title="Image Processing Suite") as demo:
|
10 |
gr.Markdown("# Image Processing Suite")
|
|
|
14 |
create_sharpen_tab()
|
15 |
create_color_match_tab()
|
16 |
create_effects_tab()
|
17 |
+
create_histogram_tab()
|
18 |
|
19 |
if __name__ == "__main__":
|
20 |
demo.launch(share=True)
|
histogram_processor.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
from skimage.exposure import match_histograms
|
7 |
+
|
8 |
+
class HistogramMatcher(nn.Module):
|
9 |
+
def __init__(self, differentiable=False):
|
10 |
+
super(HistogramMatcher, self).__init__()
|
11 |
+
self.differentiable = differentiable
|
12 |
+
|
13 |
+
def forward(self, dst, ref):
|
14 |
+
B, C, H, W = dst.size()
|
15 |
+
hist_dst = self.cal_hist(dst)
|
16 |
+
hist_ref = self.cal_hist(ref)
|
17 |
+
tables = self.cal_trans_batch(hist_dst, hist_ref)
|
18 |
+
|
19 |
+
rst = dst.clone()
|
20 |
+
for b in range(B):
|
21 |
+
for c in range(C):
|
22 |
+
rst[b,c] = tables[b*c, (dst[b,c] * 255).long()]
|
23 |
+
|
24 |
+
return rst / 255.
|
25 |
+
|
26 |
+
def cal_hist(self, img):
|
27 |
+
B, C, H, W = img.size()
|
28 |
+
if self.differentiable:
|
29 |
+
hists = self.soft_histc_batch(img * 255, bins=256, min=0, max=256, sigma=75)
|
30 |
+
else:
|
31 |
+
hists = torch.stack([torch.histc(img[b,c] * 255, bins=256, min=0, max=255)
|
32 |
+
for b in range(B) for c in range(C)])
|
33 |
+
|
34 |
+
hists = hists.float()
|
35 |
+
hists = F.normalize(hists, p=1)
|
36 |
+
bc, n = hists.size()
|
37 |
+
triu = torch.ones(bc, n, n, device=hists.device).triu()
|
38 |
+
hists = torch.bmm(hists[:,None,:], triu)[:,0,:]
|
39 |
+
return hists
|
40 |
+
|
41 |
+
def soft_histc_batch(self, x, bins=256, min=0, max=256, sigma=75):
|
42 |
+
B, C, H, W = x.size()
|
43 |
+
x = x.view(B*C, -1)
|
44 |
+
delta = float(max - min) / float(bins)
|
45 |
+
centers = float(min) + delta * (torch.arange(bins, device=x.device) + 0.5)
|
46 |
+
|
47 |
+
x = torch.unsqueeze(x, 1)
|
48 |
+
centers = centers[None,:,None]
|
49 |
+
x = x - centers
|
50 |
+
x = torch.sigmoid(sigma * (x + delta/2)) - torch.sigmoid(sigma * (x - delta/2))
|
51 |
+
x = x.sum(dim=2)
|
52 |
+
return x
|
53 |
+
|
54 |
+
def cal_trans_batch(self, hist_dst, hist_ref):
|
55 |
+
hist_dst = hist_dst[:,None,:].repeat(1,256,1)
|
56 |
+
hist_ref = hist_ref[:,:,None].repeat(1,1,256)
|
57 |
+
table = hist_dst - hist_ref
|
58 |
+
table = torch.where(table>=0, 1., 0.)
|
59 |
+
table = torch.sum(table, dim=1) - 1
|
60 |
+
table = torch.clamp(table, min=0, max=255)
|
61 |
+
return table
|
62 |
+
|
63 |
+
def apply_histogram_matching(image, reference, method, factor):
|
64 |
+
if image is None or reference is None:
|
65 |
+
return None
|
66 |
+
|
67 |
+
# Convert to torch tensors and normalize
|
68 |
+
image = torch.from_numpy(image).float() / 255.0
|
69 |
+
reference = torch.from_numpy(reference).float() / 255.0
|
70 |
+
|
71 |
+
# Add batch dimension and rearrange to BCHW
|
72 |
+
image = image.unsqueeze(0).permute(0, 3, 1, 2)
|
73 |
+
reference = reference.unsqueeze(0).permute(0, 3, 1, 2)
|
74 |
+
|
75 |
+
if method == "pytorch":
|
76 |
+
# Apply PyTorch-based histogram matching
|
77 |
+
matcher = HistogramMatcher(differentiable=True)
|
78 |
+
matched = matcher(image, reference)
|
79 |
+
else: # skimage
|
80 |
+
# Convert back to numpy for skimage
|
81 |
+
matched = match_histograms(
|
82 |
+
image.permute(0, 2, 3, 1).numpy(),
|
83 |
+
reference.permute(0, 2, 3, 1).numpy(),
|
84 |
+
channel_axis=3
|
85 |
+
)
|
86 |
+
matched = torch.from_numpy(matched).permute(0, 3, 1, 2)
|
87 |
+
|
88 |
+
# Apply factor blending
|
89 |
+
result = factor * matched + (1 - factor) * image
|
90 |
+
|
91 |
+
# Convert back to HWC format and to uint8
|
92 |
+
output = result.squeeze(0).permute(1, 2, 0)
|
93 |
+
output = (output.clamp(0, 1).numpy() * 255).astype(np.uint8)
|
94 |
+
|
95 |
+
return output
|
96 |
+
|
97 |
+
def create_histogram_tab():
|
98 |
+
with gr.Tab("Histogram Matching"):
|
99 |
+
gr.Markdown("Match histograms between images using PyTorch or scikit-image methods")
|
100 |
+
|
101 |
+
with gr.Row():
|
102 |
+
with gr.Column():
|
103 |
+
input_image = gr.Image(label="Input Image")
|
104 |
+
reference_image = gr.Image(label="Reference Image")
|
105 |
+
|
106 |
+
method = gr.Radio(
|
107 |
+
choices=["pytorch", "skimage"],
|
108 |
+
value="pytorch",
|
109 |
+
label="Matching Method"
|
110 |
+
)
|
111 |
+
|
112 |
+
factor = gr.Slider(
|
113 |
+
minimum=0.0,
|
114 |
+
maximum=1.0,
|
115 |
+
value=1.0,
|
116 |
+
step=0.05,
|
117 |
+
label="Blend Factor"
|
118 |
+
)
|
119 |
+
|
120 |
+
match_btn = gr.Button("Apply Histogram Matching")
|
121 |
+
|
122 |
+
with gr.Column():
|
123 |
+
output_image = gr.Image(label="Matched Image")
|
124 |
+
|
125 |
+
match_btn.click(
|
126 |
+
fn=apply_histogram_matching,
|
127 |
+
inputs=[input_image, reference_image, method, factor],
|
128 |
+
outputs=output_image
|
129 |
+
)
|
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ torch
|
|
3 |
torchvision
|
4 |
numpy==1.26.4
|
5 |
colour-science
|
6 |
-
kornia
|
|
|
|
3 |
torchvision
|
4 |
numpy==1.26.4
|
5 |
colour-science
|
6 |
+
kornia
|
7 |
+
scikit-image
|