danielsapit committed
Commit bff44e4 · 1 Parent(s): 8753088

Update app.py

Files changed (1)
  1. app.py +74 -63
app.py CHANGED
@@ -17,7 +17,7 @@ for model_path in ['fbcnn_gray.pth','fbcnn_color.pth']:
     r = requests.get(url, allow_redirects=True)
     open(model_path, 'wb').write(r.content)
 
-def inference(input_img, is_gray, input_quality, zoom, x_shift, y_shift):
+def inference(input_img, is_gray, input_quality, enable_zoom, zoom, x_shift, y_shift, state):
 
     if is_gray:
         n_channels = 1 # set 1 for grayscale image, set 3 for color image
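Note: the two new parameters let the zoom preview be re-adjusted without re-running the model. `state` carries the previous `[input_img, img_E]` pair from call to call, and an updated copy is returned as a fourth output. Below is a minimal sketch of that state round-trip using the same legacy `gr.inputs`/`gr.outputs` API this file targets; the `double` function and its labels are illustrative, not code from this repo.

    import gradio as gr

    # Hypothetical demo: `state` rides along with every submit, so the
    # expensive step can be skipped once a cached result exists.
    def double(x, state):
        if state is None:           # first call: do the expensive work
            state = x * 2
        return state, state         # (visible result, updated state)

    gr.Interface(fn=double,
                 inputs=[gr.inputs.Number(label="x"),
                         gr.inputs.State(default=None)],
                 outputs=[gr.outputs.Textbox(label="2x"), "state"]).launch()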
@@ -46,57 +46,59 @@ def inference(input_img, is_gray, input_quality, zoom, x_shift, y_shift):
     # ----------------------------------------
     # load model
     # ----------------------------------------
-    model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R')
-    model.load_state_dict(torch.load(model_path), strict=True)
-    model.eval()
-    for k, v in model.named_parameters():
-        v.requires_grad = False
-    model = model.to(device)
-
-    test_results = OrderedDict()
-    test_results['psnr'] = []
-    test_results['ssim'] = []
-    test_results['psnrb'] = []
-
-    # ------------------------------------
-    # (1) img_L
-    # ------------------------------------
-
-    if n_channels == 1:
-        open_cv_image = Image.fromarray(input_img)
-        open_cv_image = ImageOps.grayscale(open_cv_image)
-        open_cv_image = np.array(open_cv_image) # PIL to open cv image
-        img = np.expand_dims(open_cv_image, axis=2) # HxWx1
-    elif n_channels == 3:
-        open_cv_image = np.array(input_img) # PIL to open cv image
-        if open_cv_image.ndim == 2:
-            open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_GRAY2RGB) # GGG
-        else:
-            open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB) # RGB
-
-    img_L = util.uint2tensor4(open_cv_image)
-    img_L = img_L.to(device)
-
-    # ------------------------------------
-    # (2) img_E
-    # ------------------------------------
-
-    img_E,QF = model(img_L)
-    QF = 1- QF
-    img_E = util.tensor2single(img_E)
-    img_E = util.single2uint(img_E)
-
-    qf_input = torch.tensor([[1-input_quality/100]]).cuda() if device == torch.device('cuda') else torch.tensor([[1-input_quality/100]])
-    img_E,QF = model(img_L, qf_input)
-    QF = 1- QF
-    img_E = util.tensor2single(img_E)
-    img_E = util.single2uint(img_E)
-
-    if img_E.ndim == 3:
-        img_E = img_E[:, :, [2, 1, 0]]
-
-    print("--inference finished")
-
+    if (not enable_zoom) or (state[1] is None):
+        model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R')
+        model.load_state_dict(torch.load(model_path), strict=True)
+        model.eval()
+        for k, v in model.named_parameters():
+            v.requires_grad = False
+        model = model.to(device)
+
+        test_results = OrderedDict()
+        test_results['psnr'] = []
+        test_results['ssim'] = []
+        test_results['psnrb'] = []
+
+        # ------------------------------------
+        # (1) img_L
+        # ------------------------------------
+
+        if n_channels == 1:
+            open_cv_image = Image.fromarray(input_img)
+            open_cv_image = ImageOps.grayscale(open_cv_image)
+            open_cv_image = np.array(open_cv_image) # PIL to open cv image
+            img = np.expand_dims(open_cv_image, axis=2) # HxWx1
+        elif n_channels == 3:
+            open_cv_image = np.array(input_img) # PIL to open cv image
+            if open_cv_image.ndim == 2:
+                open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_GRAY2RGB) # GGG
+            else:
+                open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB) # RGB
+
+        img_L = util.uint2tensor4(open_cv_image)
+        img_L = img_L.to(device)
+
+        # ------------------------------------
+        # (2) img_E
+        # ------------------------------------
+
+        img_E,QF = model(img_L)
+        QF = 1- QF
+        img_E = util.tensor2single(img_E)
+        img_E = util.single2uint(img_E)
+
+        qf_input = torch.tensor([[1-input_quality/100]]).cuda() if device == torch.device('cuda') else torch.tensor([[1-input_quality/100]])
+        img_E,QF = model(img_L, qf_input)
+        QF = 1- QF
+        img_E = util.tensor2single(img_E)
+        img_E = util.single2uint(img_E)
+
+        if img_E.ndim == 3:
+            img_E = img_E[:, :, [2, 1, 0]]
+
+        print("--inference finished")
+    if (state[1] is not None) and enable_zoom:
+        img_E = state[1]
     out_img = Image.fromarray(img_E)
     out_img_w, out_img_h = out_img.size # output image size
     zoom = zoom/100
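Note: this hunk only wraps the existing inference body in the `if (not enable_zoom) or (state[1] is None):` gate and restores the cached result otherwise; the body itself is unchanged. Two details worth flagging for readers: the output of the unconditioned `model(img_L)` pass is immediately overwritten by the QF-conditioned pass, and the quality-factor input is inverted relative to the slider (`qf_input = 1 - input_quality/100`), so a higher "Intensity" means a lower assumed JPEG quality and stronger artifact removal. A quick check of that mapping (my arithmetic, not code from the diff):

    # Intensity slider value -> qf_input passed to model(img_L, qf_input)
    for intensity in (30, 60, 100):
        print(intensity, "->", 1 - intensity / 100)   # 30 -> 0.7, 60 -> 0.4, 100 -> 0.0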
@@ -105,37 +107,46 @@ def inference(input_img, is_gray, input_quality, zoom, x_shift, y_shift):
     zoom_w, zoom_h = out_img_w*zoom, out_img_h*zoom
     zoom_left, zoom_right = int((out_img_w - zoom_w)*x_shift), int(zoom_w + (out_img_w - zoom_w)*x_shift)
     zoom_top, zoom_bottom = int((out_img_h - zoom_h)*y_shift), int(zoom_h + (out_img_h - zoom_h)*y_shift)
-    in_img = Image.fromarray(input_img)
+    if (state[0] is None) or not enable_zoom:
+        in_img = Image.fromarray(input_img)
+        state[0] = input_img
+    else:
+        in_img = Image.fromarray(state[0])
     in_img = in_img.crop((zoom_left, zoom_top, zoom_right, zoom_bottom))
     in_img = in_img.resize((int(zoom_w/zoom), int(zoom_h/zoom)), Image.NEAREST)
     out_img = out_img.crop((zoom_left, zoom_top, zoom_right, zoom_bottom))
     out_img = out_img.resize((int(zoom_w/zoom), int(zoom_h/zoom)), Image.NEAREST)
 
-    return img_E, in_img, out_img
+    return img_E, in_img, out_img, [state[0],img_E]
 
 gr.Interface(
     fn = inference,
     inputs = [gr.inputs.Image(label="Input Image"),
         gr.inputs.Checkbox(label="Grayscale (Check this if your image is grayscale)"),
         gr.inputs.Slider(minimum=1, maximum=100, step=1, label="Intensity (Higher = stronger JPEG artifact removal)"),
+        gr.inputs.Checkbox(default=False, label="Edit Zoom preview (This is optional. "
+            "After the image result is loaded, check this to edit zoom parameters "
+            "so that the input image will not be processed when the submit button is pressed.)"),
         gr.inputs.Slider(minimum=10, maximum=100, step=1, default=50, label="Zoom Image "
             "(Use this to see the image quality up close. "
             "100 = original size)"),
         gr.inputs.Slider(minimum=0, maximum=100, step=1, label="Zoom preview horizontal shift "
             "(Increase to shift to the right)"),
         gr.inputs.Slider(minimum=0, maximum=100, step=1, label="Zoom preview vertical shift "
-            "(Increase to shift downwards)")
+            "(Increase to shift downwards)"),
+        gr.inputs.State(default=[None,None], label="\t")
         ],
     outputs = [gr.outputs.Image(label="Result"),
         gr.outputs.Image(label="Before:"),
-        gr.outputs.Image(label="After:")],
-    examples = [["doraemon.jpg",False,60,42,50,50],
-        ["tomandjerry.jpg",False,60,40,57,44],
-        ["somepanda.jpg",True,100,30,8,24],
-        ["cemetry.jpg",False,70,20,76,62],
-        ["michelangelo_david.jpg",True,30,12,53,27],
-        ["elon_musk.jpg",False,45,15,33,30],
-        ["text.jpg",True,70,50,11,29]],
+        gr.outputs.Image(label="After:"),
+        "state"],
+    examples = [["doraemon.jpg",False,60,False,42,50,50],
+        ["tomandjerry.jpg",False,60,False,40,57,44],
+        ["somepanda.jpg",True,100,False,30,8,24],
+        ["cemetry.jpg",False,70,False,20,76,62],
+        ["michelangelo_david.jpg",True,30,False,12,53,27],
+        ["elon_musk.jpg",False,45,False,15,33,30],
+        ["text.jpg",True,70,False,50,11,29]],
     title = "JPEG Artifacts Removal [FBCNN]",
     description = "Gradio Demo for JPEG Artifacts Removal. To use it, simply upload your image, "
        "or click one of the examples to load them. Check out the paper and the original GitHub repo at the link below. "
 
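Note: the crop arithmetic in this hunk treats `zoom` as a fraction of the image size and the two shifts as fractions of the leftover margin (`zoom` is divided by 100 inside the hunk; the shifts are presumably rescaled to 0-1 on the lines just above it, which the diff does not show). A self-contained sketch of the same box computation; the `zoom_box` name is mine:

    def zoom_box(w, h, zoom, x_shift, y_shift):
        # zoom, x_shift, y_shift are fractions in [0, 1]
        zw, zh = w * zoom, h * zoom                   # size of the zoom window
        left = int((w - zw) * x_shift)                # slide the window across the margin
        top = int((h - zh) * y_shift)
        return (left, top, int(zw + (w - zw) * x_shift), int(zh + (h - zh) * y_shift))

    print(zoom_box(800, 600, 0.5, 0.5, 0.5))          # (200, 150, 600, 450): a centered 50% crop

Both crops are then resized back to the original dimensions with `Image.NEAREST`, which keeps the JPEG block edges visible instead of smoothing them away.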