vincent-doan commited on
Commit
95110bc
·
1 Parent(s): ecd8334

Re-configured SRFlow

Browse files
Files changed (2) hide show
  1. app.py +32 -10
  2. models/SRFlow/srflow.py +12 -6
app.py CHANGED
@@ -5,50 +5,54 @@ from PIL import Image
5
  from io import BytesIO
6
  from models.HAT.hat import *
7
  from models.RCAN.rcan import *
8
- from models.SRGAN.srgan import *
 
9
 
10
  # Initialize session state for enhanced images
11
  if 'hat_enhanced_image' not in st.session_state:
12
  st.session_state['hat_enhanced_image'] = None
13
-
14
  if 'rcan_enhanced_image' not in st.session_state:
15
  st.session_state['rcan_enhanced_image'] = None
16
-
17
  if 'srgan_enhanced_image' not in st.session_state:
18
  st.session_state['srgan_enhanced_image'] = None
 
 
19
 
 
20
  if 'hat_clicked' not in st.session_state:
21
  st.session_state['hat_clicked'] = False
22
  if 'rcan_clicked' not in st.session_state:
23
  st.session_state['rcan_clicked'] = False
24
-
25
  if 'srgan_clicked' not in st.session_state:
26
  st.session_state['srgan_clicked'] = False
 
 
27
 
28
  st.markdown("<h1 style='text-align: center'>Image Super Resolution</h1>", unsafe_allow_html=True)
 
29
  # Sidebar for navigation
30
  st.sidebar.title("Options")
31
- app_mode = st.sidebar.selectbox("Choose the input source",
32
- ["Upload image", "Take a photo"])
33
  # Depending on the choice, show the uploader widget or webcam capture
34
  if app_mode == "Upload image":
35
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"], on_change=lambda: reset_states())
36
  if uploaded_file is not None:
37
  image = Image.open(uploaded_file).convert("RGB")
38
  elif app_mode == "Take a photo":
39
- # Using JS code to access user's webcam
40
  camera_input = st.camera_input("Take a picture", on_change=lambda: reset_states())
41
  if camera_input is not None:
42
- # Convert the camera image to an RGB image
43
  image = Image.open(camera_input).convert("RGB")
44
 
45
  def reset_states():
46
  st.session_state['hat_enhanced_image'] = None
47
  st.session_state['rcan_enhanced_image'] = None
48
  st.session_state['srgan_enhanced_image'] = None
 
49
  st.session_state['hat_clicked'] = False
50
  st.session_state['rcan_clicked'] = False
51
  st.session_state['srgan_clicked'] = False
 
52
 
53
  def get_image_download_link(img, filename):
54
  """Generates a link allowing the PIL image to be downloaded"""
@@ -102,7 +106,8 @@ if 'image' in locals():
102
  col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
103
  with col2:
104
  get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')
105
- #--------------------------SRGAN--------------------------#
 
106
  if st.button('Enhance with SRGAN'):
107
  with st.spinner('Processing using SRGAN...'):
108
  with st.spinner('Wait for it... the model is processing the image'):
@@ -120,4 +125,21 @@ if 'image' in locals():
120
  col2.header("Enhanced")
121
  col2.image(st.session_state['srgan_enhanced_image'], use_column_width=True)
122
  with col2:
123
- get_image_download_link(st.session_state['srgan_enhanced_image'], 'srgan_enhanced.jpg')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from io import BytesIO
6
  from models.HAT.hat import *
7
  from models.RCAN.rcan import *
8
+ from models.SRGAN.srgan import *
9
+ from models.SRFlow.srflow import *
10
 
11
  # Initialize session state for enhanced images
12
  if 'hat_enhanced_image' not in st.session_state:
13
  st.session_state['hat_enhanced_image'] = None
 
14
  if 'rcan_enhanced_image' not in st.session_state:
15
  st.session_state['rcan_enhanced_image'] = None
 
16
  if 'srgan_enhanced_image' not in st.session_state:
17
  st.session_state['srgan_enhanced_image'] = None
18
+ if 'srflow_enhanced_image' not in st.session_state:
19
+ st.session_state['srflow_enhanced_image'] = None
20
 
21
+ # Initialize session state for button clicks
22
  if 'hat_clicked' not in st.session_state:
23
  st.session_state['hat_clicked'] = False
24
  if 'rcan_clicked' not in st.session_state:
25
  st.session_state['rcan_clicked'] = False
 
26
  if 'srgan_clicked' not in st.session_state:
27
  st.session_state['srgan_clicked'] = False
28
+ if 'srflow_clicked' not in st.session_state:
29
+ st.session_state['srflow_clicked'] = False
30
 
31
  st.markdown("<h1 style='text-align: center'>Image Super Resolution</h1>", unsafe_allow_html=True)
32
+
33
  # Sidebar for navigation
34
  st.sidebar.title("Options")
35
+ app_mode = st.sidebar.selectbox("Choose the input source", ["Upload image", "Take a photo"])
36
+
37
  # Depending on the choice, show the uploader widget or webcam capture
38
  if app_mode == "Upload image":
39
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"], on_change=lambda: reset_states())
40
  if uploaded_file is not None:
41
  image = Image.open(uploaded_file).convert("RGB")
42
  elif app_mode == "Take a photo":
 
43
  camera_input = st.camera_input("Take a picture", on_change=lambda: reset_states())
44
  if camera_input is not None:
 
45
  image = Image.open(camera_input).convert("RGB")
46
 
47
  def reset_states():
48
  st.session_state['hat_enhanced_image'] = None
49
  st.session_state['rcan_enhanced_image'] = None
50
  st.session_state['srgan_enhanced_image'] = None
51
+ st.session_state['srflow_enhanced_image'] = None
52
  st.session_state['hat_clicked'] = False
53
  st.session_state['rcan_clicked'] = False
54
  st.session_state['srgan_clicked'] = False
55
+ st.session_state['srflow_clicked'] = False
56
 
57
  def get_image_download_link(img, filename):
58
  """Generates a link allowing the PIL image to be downloaded"""
 
106
  col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
107
  with col2:
108
  get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')
109
+
110
+ # --------------------------SRGAN-------------------------- #
111
  if st.button('Enhance with SRGAN'):
112
  with st.spinner('Processing using SRGAN...'):
113
  with st.spinner('Wait for it... the model is processing the image'):
 
125
  col2.header("Enhanced")
126
  col2.image(st.session_state['srgan_enhanced_image'], use_column_width=True)
127
  with col2:
128
+ get_image_download_link(st.session_state['srgan_enhanced_image'], 'srgan_enhanced.jpg')
129
+
130
+ # ------------------------ SRFlow ------------------------ #
131
+ if st.button('Enhance with SRFlow'):
132
+ with st.spinner('Processing using SRFlow...'):
133
+ with st.spinner('Wait for it... the model is processing the image'):
134
+ enhanced_image = return_SRFlow_result(image)
135
+ st.session_state['srflow_enhanced_image'] = enhanced_image
136
+ st.session_state['srflow_clicked'] = True
137
+ st.success('Done!')
138
+ if st.session_state['srflow_enhanced_image'] is not None:
139
+ col1, col2 = st.columns(2)
140
+ col1.header("Original")
141
+ col1.image(image, use_column_width=True)
142
+ col2.header("Enhanced")
143
+ col2.image(st.session_state['srflow_enhanced_image'], use_column_width=True)
144
+ with col2:
145
+ get_image_download_link(st.session_state['srflow_enhanced_image'], 'srflow_enhanced.jpg')
models/SRFlow/srflow.py CHANGED
@@ -3,22 +3,23 @@ import torch
3
  import sys
4
  sys.path.append('models')
5
  from SRFlow.code import imread, impad, load_model, t, rgb
 
 
6
 
7
-
8
- def return_SRFlow_result(lr_path, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=0.6):
9
  """
10
  Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
11
 
12
  Args:
13
- - lr_path (str): File path of the LR image. (Refer to OpenCV documentation for suitable image formats with imread)
14
  - conf_path (str): Configuration file path for the SRFlow model. Default is SRFlow_DF2K_4X.yml.
15
  - heat (float): Heat parameter for the SRFlow model. Default is 0.6.
16
 
17
  Returns:
18
- - sr (numpy.ndarray): Super-resolved image in numpy array format.
19
  """
20
  model, opt = load_model(conf_path)
21
- lr = imread(lr_path)
22
 
23
  scale = opt['scale']
24
  pad_factor = 2
@@ -35,5 +36,10 @@ def return_SRFlow_result(lr_path, conf_path='models/SRFlow/code/confs/SRFlow_DF2
35
  sr = rgb(torch.clamp(sr_t, 0, 1))
36
  sr = sr[:h * scale, :w * scale]
37
 
38
- sr
39
  return sr
 
 
 
 
 
 
3
  import sys
4
  sys.path.append('models')
5
  from SRFlow.code import imread, impad, load_model, t, rgb
6
+ from PIL import Image
7
+ from torchvision.transforms import PILToTensor
8
 
9
+ def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=0.6):
 
10
  """
11
  Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
12
 
13
  Args:
14
+ - lr: PIL Image
15
  - conf_path (str): Configuration file path for the SRFlow model. Default is SRFlow_DF2K_4X.yml.
16
  - heat (float): Heat parameter for the SRFlow model. Default is 0.6.
17
 
18
  Returns:
19
+ - sr: PIL Image
20
  """
21
  model, opt = load_model(conf_path)
22
+ lr = PILToTensor()(lr).permute(1, 2, 0).numpy()
23
 
24
  scale = opt['scale']
25
  pad_factor = 2
 
36
  sr = rgb(torch.clamp(sr_t, 0, 1))
37
  sr = sr[:h * scale, :w * scale]
38
 
39
+ sr = Image.fromarray((sr).astype('uint8'))
40
  return sr
41
+
42
+ if __name__ == '__main__':
43
+ ip = Image.open('images/demo.png')
44
+ sr = return_SRFlow_result(ip)
45
+ print(sr.size)