qubvel-hf committed on
Commit
4d657e7
·
1 Parent(s): c509e76

Gradio app

Browse files
Files changed (3) hide show
  1. app.py +38 -0
  2. inference_gradio.py +352 -0
  3. packages.txt +1 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ from inference_gradio import inference_one_image, model_init
4
+
5
# Checkpoint holding the restoration model weights; loaded once at start-up.
MODEL_PATH = "./checkpoints/docres.pkl"
model = model_init(MODEL_PATH)

# Task names offered in the UI; they are forwarded as-is to inference_one_image.
possible_tasks = ["dewarping", "deshadowing", "appearance", "deblurring", "binarization"]
15
+
16
@spaces.GPU
def run_tasks(image, tasks):
    """Restore an uploaded document image with the selected tasks.

    Args:
        image: RGB image as a numpy array coming from the ``gr.Image``
            input, or ``None`` when the user has not provided an image.
        tasks: Task names ticked in the checkbox group.

    Returns:
        The restored image. Three-dimensional results are flipped back from
        BGR to RGB; two-dimensional results are returned unchanged.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard the channel flip below fails with an opaque
        # "'NoneType' object is not subscriptable" error.
        raise gr.Error("Please upload an image before running the tasks.")

    # inference_one_image expects BGR input, so reverse the channel axis.
    bgr_image = image[..., ::-1].copy()
    bgr_restored_image = inference_one_image(model, bgr_image, tasks)

    if bgr_restored_image.ndim != 3:
        # A 2-D result has no channel axis to reorder.
        return bgr_restored_image
    return bgr_restored_image[..., ::-1]
25
+
26
+
27
# UI layout: input and output images side by side, followed by the task
# selector and the button that triggers the restoration.
with gr.Blocks() as demo:
    with gr.Row():
        input_image = gr.Image(type="numpy")
        output_image = gr.Image(type="numpy")

    task = gr.CheckboxGroup(choices=possible_tasks, label="Choose tasks:")
    button = gr.Button()
    button.click(fn=run_tasks, inputs=[input_image, task], outputs=[output_image])

demo.launch()
inference_gradio.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import cv2
3
+ import utils
4
+ import numpy as np
5
+
6
+ import torch
7
+ from PIL import Image
8
+
9
+ from utils import convert_state_dict
10
+ from models import restormer_arch
11
+ from data.preprocess.crop_merge_image import stride_integral
12
+
13
+ sys.path.append("./data/MBD/")
14
+ from data.MBD.infer import net1_net2_infer_single_im
15
+
16
+
17
# Use the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
18
+
19
+
20
def dewarp_prompt(img):
    """Build the dewarping prompt for ``img``.

    ``img`` is modified in place: every pixel where the mask returned by
    ``net1_net2_infer_single_im`` equals 0 is set to 0.

    Returns:
        A tuple ``(img, prompt)``. ``prompt`` concatenates, along the last
        axis, ``utils.getBasecoord(256, 256) / 256`` and the mask resized to
        256x256 and divided by 255.
    """
    seg_mask = net1_net2_infer_single_im(img, "data/MBD/checkpoint/mbd.pkl")
    coords = utils.getBasecoord(256, 256) / 256
    img[seg_mask == 0] = 0
    small_mask = cv2.resize(seg_mask, (256, 256)) / 255
    prompt = np.concatenate((coords, small_mask[..., np.newaxis]), -1)
    return img, prompt
26
+
27
+
28
def deshadow_prompt(img):
    """Estimate the background of ``img`` for use as the deshadowing prompt.

    Each channel of a 1024x1024 copy of the image is dilated with a 7x7
    kernel and median-blurred with kernel size 21. The merged channels are
    then resized back to the input resolution.

    Earlier code also built a normalised difference image and a shadow map
    that were never returned; that dead work has been dropped, which leaves
    the returned value unchanged.

    Args:
        img: Image array; presumably uint8 with a trailing channel axis,
            as passed by ``deshadowing`` (confirm against callers).

    Returns:
        The background estimate with the same height and width as ``img``.
    """
    h, w = img.shape[:2]
    resized = cv2.resize(img, (1024, 1024))
    kernel = np.ones((7, 7), np.uint8)
    backgrounds = [
        cv2.medianBlur(cv2.dilate(plane, kernel), 21)
        for plane in cv2.split(resized)
    ]
    return cv2.resize(cv2.merge(backgrounds), (w, h))
64
+
65
+
66
def deblur_prompt(img):
    """Build the deblurring prompt: a high-frequency (edge) map of ``img``.

    Horizontal and vertical Sobel responses are converted to absolute 8-bit
    values and blended with equal weights. The result is collapsed to
    grayscale and expanded back to three channels.
    """
    grad_x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
    grad_y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
    # Convert the 16-bit responses back to uint8.
    abs_x = cv2.convertScaleAbs(grad_x)
    abs_y = cv2.convertScaleAbs(grad_y)
    edges = cv2.addWeighted(abs_x, 0.5, abs_y, 0.5, 0)
    gray_edges = cv2.cvtColor(edges, cv2.COLOR_BGR2GRAY)
    return cv2.cvtColor(gray_edges, cv2.COLOR_GRAY2BGR)
75
+
76
+
77
def appearance_prompt(img):
    """Build the appearance-enhancement prompt for ``img``.

    Working on a 1024x1024 copy, each channel gets a background estimate
    (7x7 dilation followed by a median blur of size 21). The inverted
    absolute difference between channel and background is min-max
    normalised to 0..255. The merged channels are resized back to the
    input resolution.
    """
    h, w = img.shape[:2]
    resized = cv2.resize(img, (1024, 1024))

    normalized_planes = []
    for channel in cv2.split(resized):
        dilated = cv2.dilate(channel, np.ones((7, 7), np.uint8))
        background = cv2.medianBlur(dilated, 21)
        difference = 255 - cv2.absdiff(channel, background)
        normalized_planes.append(
            cv2.normalize(
                difference,
                None,
                alpha=0,
                beta=255,
                norm_type=cv2.NORM_MINMAX,
                dtype=cv2.CV_8UC1,
            )
        )

    merged = cv2.merge(normalized_planes)
    return cv2.resize(merged, (w, h))
101
+
102
+
103
def binarization_promptv2(img):
    """Build the three-channel binarization prompt for ``img``.

    The channels, stacked along the last axis, are:
      0. the threshold map from ``utils.SauvolaModBinarization`` cast to uint8,
      1. a grayscale Sobel edge map of ``img``,
      2. the binarization result from the same call, snapped to 0/255 around
         the value 155.
    """
    binary, thresh = utils.SauvolaModBinarization(img)
    thresh = thresh.astype(np.uint8)
    binary[binary > 155] = 255
    binary[binary <= 155] = 0

    grad_x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
    grad_y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
    # Convert the 16-bit responses back to uint8.
    abs_x = cv2.convertScaleAbs(grad_x)
    abs_y = cv2.convertScaleAbs(grad_y)
    edges = cv2.addWeighted(abs_x, 0.5, abs_y, 0.5, 0)
    edges = cv2.cvtColor(edges, cv2.COLOR_BGR2GRAY)

    return np.stack((thresh, edges, binary), axis=-1)
123
+
124
+
125
def dewarping(model, im_org):
    """Dewarp ``im_org`` with ``model``.

    The model runs in float32 on a 256x256 masked copy of the image
    concatenated with the dewarping prompt. Its first two output channels
    are added to the base coordinate grid and smoothed with 15 passes of a
    3x3 box blur. The grid is then scaled to the original resolution and
    used to remap ``im_org``.

    Returns:
        The three prompt channels (scaled to uint8 and resized to the size of
        ``im_org``) followed by the dewarped image.
    """
    input_size = 256
    masked, prompt_org = dewarp_prompt(im_org.copy())
    h, w = masked.shape[:2]

    masked = cv2.resize(masked, (input_size, input_size)) / 255.0
    masked_t = torch.from_numpy(masked.transpose(2, 0, 1)).unsqueeze(0)
    masked_t = masked_t.float().to(DEVICE)
    prompt_t = torch.from_numpy(prompt_org.transpose(2, 0, 1)).unsqueeze(0)
    prompt_t = prompt_t.float().to(DEVICE)
    net_input = torch.cat((masked_t, prompt_t), dim=1)

    # Inference: the model output is added to the base coordinate grid.
    base_coord = utils.getBasecoord(input_size, input_size) / input_size
    model = model.float()
    with torch.no_grad():
        grid = model(net_input)
        grid = grid[0][:2].permute(1, 2, 0).cpu().numpy()
    grid = grid + base_coord

    # Smooth the sampling grid before scaling it up.
    for _ in range(15):
        grid = cv2.blur(grid, (3, 3), borderType=cv2.BORDER_REPLICATE)
    grid = (cv2.resize(grid, (w, h)) * (w, h)).astype(np.float32)
    out_im = cv2.remap(im_org, grid[:, :, 0], grid[:, :, 1], cv2.INTER_LINEAR)

    prompt_vis = (prompt_org * 255).astype(np.uint8)
    prompt_vis = cv2.resize(prompt_vis, im_org.shape[:2][::-1])

    return prompt_vis[:, :, 0], prompt_vis[:, :, 1], prompt_vis[:, :, 2], out_im
159
+
160
+
161
def _restore_with_prompt(model, im_org, prompt, max_size=1600):
    """Run ``model`` in half precision on ``im_org`` concatenated with ``prompt``.

    Shared by ``appearance`` and ``deshadowing``, whose bodies were identical
    apart from the prompt function.

    Args:
        model: Network called on the 6-channel input tensor.
        im_org: Image to restore.
        prompt: Prompt array concatenated to ``im_org`` along the last axis.
        max_size: Images whose longer side reaches this value are processed
            at ``max_size`` x ``max_size`` instead of at native resolution.

    Returns:
        The restored image as a uint8 array.
    """
    h, w = im_org.shape[:2]
    fits = max(w, h) < max_size
    in_im = np.concatenate((im_org, prompt), -1)

    # Constrain the max resolution.
    if fits:
        # NOTE(review): stride_integral presumably pads to a multiple of 8
        # and reports the padding that is cropped off below; confirm in
        # data.preprocess.crop_merge_image.
        in_im, padding_h, padding_w = stride_integral(in_im, 8)
    else:
        in_im = cv2.resize(in_im, (max_size, max_size))

    # Normalize and convert to a 1 x C x H x W tensor.
    in_im = in_im / 255.0
    in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)

    # Inference.
    # NOTE(review): half precision is used even when DEVICE is "cpu";
    # confirm the installed torch build supports that.
    in_im = in_im.half().to(DEVICE)
    model = model.half()
    with torch.no_grad():
        pred = model(in_im)
        pred = torch.clamp(pred, 0, 1)
        pred = pred[0].permute(1, 2, 0).cpu().numpy()
        pred = (pred * 255).astype(np.uint8)

    if fits:
        return pred[padding_h:, padding_w:]

    # Large image: take the ratio between the downscaled input and the
    # prediction, resize it to the original resolution, and divide the
    # original image by it.
    pred[pred == 0] = 1  # avoid division by zero
    shadow_map = cv2.resize(im_org, (max_size, max_size)).astype(
        float
    ) / pred.astype(float)
    shadow_map = cv2.resize(shadow_map, (w, h))
    shadow_map[shadow_map == 0] = 0.00001
    return np.clip(im_org.astype(float) / shadow_map, 0, 255).astype(np.uint8)


def appearance(model, im_org):
    """Enhance the appearance of ``im_org`` using the appearance prompt.

    Returns:
        The three prompt channels followed by the restored image.
    """
    prompt = appearance_prompt(im_org)
    out_im = _restore_with_prompt(model, im_org, prompt)
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def deshadowing(model, im_org):
    """Remove shadows from ``im_org`` using the deshadowing prompt.

    Returns:
        The three prompt channels followed by the restored image.
    """
    prompt = deshadow_prompt(im_org)
    out_im = _restore_with_prompt(model, im_org, prompt)
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
239
+
240
+
241
def deblurring(model, im_org):
    """Deblur ``im_org`` with ``model`` in half precision.

    The image returned by ``stride_integral`` is concatenated with its
    deblurring prompt and fed to the model. The prediction is clamped to
    0..1, scaled to uint8, and cropped by the reported padding.

    Returns:
        The three prompt channels (computed on the padded image) followed by
        the restored image.
    """
    padded, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = deblur_prompt(padded)

    stacked = np.concatenate((padded, prompt), -1) / 255.0
    net_input = torch.from_numpy(stacked.transpose(2, 0, 1)).unsqueeze(0)
    net_input = net_input.half().to(DEVICE)

    # Inference.
    model.to(DEVICE)
    model.eval()
    model = model.half()
    with torch.no_grad():
        restored = torch.clamp(model(net_input), 0, 1)
    restored = restored[0].permute(1, 2, 0).cpu().numpy()
    restored = (restored * 255).astype(np.uint8)
    out_im = restored[padding_h:, padding_w:]

    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
261
+
262
+
263
def binarization(model, im_org):
    """Binarize ``im_org`` with ``model`` in half precision.

    The first two output channels of the model are treated as class scores.
    The per-pixel index of the larger softmax value is scaled to 0/255 and
    cropped by the padding reported by ``stride_integral``.

    Returns:
        The three prompt channels (computed on the padded image) followed by
        the single-channel binarized image.
    """
    padded, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = binarization_promptv2(padded)
    h, w = padded.shape[:2]

    stacked = np.concatenate((padded, prompt), -1) / 255.0
    net_input = torch.from_numpy(stacked.transpose(2, 0, 1)).unsqueeze(0)
    net_input = net_input.to(DEVICE)
    model = model.half()
    net_input = net_input.half()

    with torch.no_grad():
        scores = model(net_input)[:, :2, :, :]
        labels = torch.softmax(scores, 1).argmax(1)

    mask = (labels[0].cpu().numpy() * 255).astype(np.uint8)
    mask = cv2.resize(mask, (w, h))
    out_im = mask[padding_h:, padding_w:]

    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
284
+
285
+
286
def model_init(model_path):
    """Create the Restormer network, load its weights and move it to ``DEVICE``.

    Args:
        model_path: Path to a checkpoint whose ``"model_state"`` entry holds
            the weights.

    Returns:
        The model in eval mode on ``DEVICE``.
    """
    net = restormer_arch.Restormer(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type="WithBias",
        dual_pixel_task=True,
    )

    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    map_location = "cpu" if DEVICE == "cpu" else "cuda:0"
    checkpoint = torch.load(model_path, map_location=map_location)
    net.load_state_dict(convert_state_dict(checkpoint["model_state"]))

    net.eval()
    return net.to(DEVICE)
314
+
315
+
316
def resize(image, max_size):
    """Downscale ``image`` so that its longer side is at most ``max_size``.

    The aspect ratio is kept and Lanczos resampling is used. Images that
    already fit are returned unchanged (the same object, no copy).

    Args:
        image: Image array whose first two axes are height and width.
        max_size: Maximum allowed length of the longer side, in pixels.

    Returns:
        The original array if no resize is needed, otherwise a new array.
    """
    h, w = image.shape[:2]
    if max(h, w) <= max_size:
        return image

    if h > w:
        h_new = max_size
        # Clamp to 1 px: on extreme aspect ratios the truncation would give
        # 0, which PIL rejects as an invalid size.
        w_new = max(1, int(w * h_new / h))
    else:
        w_new = max_size
        h_new = max(1, int(h * w_new / w))

    pil_image = Image.fromarray(image)
    pil_image = pil_image.resize((w_new, h_new), Image.Resampling.LANCZOS)
    return np.array(pil_image)
329
+
330
+
331
def inference_one_image(model, image, tasks):
    """Apply the requested restoration tasks to ``image`` in a fixed order.

    Dewarping runs first, on the full-resolution input. If it is the only
    task, its result is returned directly. Otherwise the image is limited to
    1536 px on its longer side before deshadowing, appearance, deblurring and
    binarization are applied, in that order, for each name in ``tasks``.

    Args:
        model: Model forwarded to every task function.
        image: Input image in BGR format.
        tasks: Collection of task names to run.

    Returns:
        The image produced by the last executed task.
    """
    if "dewarping" in tasks:
        image = dewarping(model, image)[-1]
        # If only dewarping was requested, return here.
        if len(tasks) == 1:
            return image

    image = resize(image, 1536)

    # Every task function returns the restored image as its last element.
    pipeline = (
        ("deshadowing", deshadowing),
        ("appearance", appearance),
        ("deblurring", deblurring),
        ("binarization", binarization),
    )
    for task_name, restore in pipeline:
        if task_name in tasks:
            image = restore(model, image)[-1]

    return image
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python3-opencv