Spaces:
Running
Running
vincent-doan
commited on
Commit
·
95110bc
1
Parent(s):
ecd8334
Re-configured SRFlow
Browse files- app.py +32 -10
- models/SRFlow/srflow.py +12 -6
app.py
CHANGED
@@ -5,50 +5,54 @@ from PIL import Image
|
|
5 |
from io import BytesIO
|
6 |
from models.HAT.hat import *
|
7 |
from models.RCAN.rcan import *
|
8 |
-
from models.SRGAN.srgan import *
|
|
|
9 |
|
10 |
# Initialize session state for enhanced images
|
11 |
if 'hat_enhanced_image' not in st.session_state:
|
12 |
st.session_state['hat_enhanced_image'] = None
|
13 |
-
|
14 |
if 'rcan_enhanced_image' not in st.session_state:
|
15 |
st.session_state['rcan_enhanced_image'] = None
|
16 |
-
|
17 |
if 'srgan_enhanced_image' not in st.session_state:
|
18 |
st.session_state['srgan_enhanced_image'] = None
|
|
|
|
|
19 |
|
|
|
20 |
if 'hat_clicked' not in st.session_state:
|
21 |
st.session_state['hat_clicked'] = False
|
22 |
if 'rcan_clicked' not in st.session_state:
|
23 |
st.session_state['rcan_clicked'] = False
|
24 |
-
|
25 |
if 'srgan_clicked' not in st.session_state:
|
26 |
st.session_state['srgan_clicked'] = False
|
|
|
|
|
27 |
|
28 |
st.markdown("<h1 style='text-align: center'>Image Super Resolution</h1>", unsafe_allow_html=True)
|
|
|
29 |
# Sidebar for navigation
|
30 |
st.sidebar.title("Options")
|
31 |
-
app_mode = st.sidebar.selectbox("Choose the input source",
|
32 |
-
|
33 |
# Depending on the choice, show the uploader widget or webcam capture
|
34 |
if app_mode == "Upload image":
|
35 |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"], on_change=lambda: reset_states())
|
36 |
if uploaded_file is not None:
|
37 |
image = Image.open(uploaded_file).convert("RGB")
|
38 |
elif app_mode == "Take a photo":
|
39 |
-
# Using JS code to access user's webcam
|
40 |
camera_input = st.camera_input("Take a picture", on_change=lambda: reset_states())
|
41 |
if camera_input is not None:
|
42 |
-
# Convert the camera image to an RGB image
|
43 |
image = Image.open(camera_input).convert("RGB")
|
44 |
|
45 |
def reset_states():
|
46 |
st.session_state['hat_enhanced_image'] = None
|
47 |
st.session_state['rcan_enhanced_image'] = None
|
48 |
st.session_state['srgan_enhanced_image'] = None
|
|
|
49 |
st.session_state['hat_clicked'] = False
|
50 |
st.session_state['rcan_clicked'] = False
|
51 |
st.session_state['srgan_clicked'] = False
|
|
|
52 |
|
53 |
def get_image_download_link(img, filename):
|
54 |
"""Generates a link allowing the PIL image to be downloaded"""
|
@@ -102,7 +106,8 @@ if 'image' in locals():
|
|
102 |
col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
|
103 |
with col2:
|
104 |
get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')
|
105 |
-
|
|
|
106 |
if st.button('Enhance with SRGAN'):
|
107 |
with st.spinner('Processing using SRGAN...'):
|
108 |
with st.spinner('Wait for it... the model is processing the image'):
|
@@ -120,4 +125,21 @@ if 'image' in locals():
|
|
120 |
col2.header("Enhanced")
|
121 |
col2.image(st.session_state['srgan_enhanced_image'], use_column_width=True)
|
122 |
with col2:
|
123 |
-
get_image_download_link(st.session_state['srgan_enhanced_image'], 'srgan_enhanced.jpg')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from io import BytesIO
|
6 |
from models.HAT.hat import *
|
7 |
from models.RCAN.rcan import *
|
8 |
+
from models.SRGAN.srgan import *
|
9 |
+
from models.SRFlow.srflow import *
|
10 |
|
11 |
# Initialize session state for enhanced images
|
12 |
if 'hat_enhanced_image' not in st.session_state:
|
13 |
st.session_state['hat_enhanced_image'] = None
|
|
|
14 |
if 'rcan_enhanced_image' not in st.session_state:
|
15 |
st.session_state['rcan_enhanced_image'] = None
|
|
|
16 |
if 'srgan_enhanced_image' not in st.session_state:
|
17 |
st.session_state['srgan_enhanced_image'] = None
|
18 |
+
if 'srflow_enhanced_image' not in st.session_state:
|
19 |
+
st.session_state['srflow_enhanced_image'] = None
|
20 |
|
21 |
+
# Initialize session state for button clicks
|
22 |
if 'hat_clicked' not in st.session_state:
|
23 |
st.session_state['hat_clicked'] = False
|
24 |
if 'rcan_clicked' not in st.session_state:
|
25 |
st.session_state['rcan_clicked'] = False
|
|
|
26 |
if 'srgan_clicked' not in st.session_state:
|
27 |
st.session_state['srgan_clicked'] = False
|
28 |
+
if 'srflow_clicked' not in st.session_state:
|
29 |
+
st.session_state['srflow_clicked'] = False
|
30 |
|
31 |
st.markdown("<h1 style='text-align: center'>Image Super Resolution</h1>", unsafe_allow_html=True)
|
32 |
+
|
33 |
# Sidebar for navigation
|
34 |
st.sidebar.title("Options")
|
35 |
+
app_mode = st.sidebar.selectbox("Choose the input source", ["Upload image", "Take a photo"])
|
36 |
+
|
37 |
# Depending on the choice, show the uploader widget or webcam capture
|
38 |
if app_mode == "Upload image":
|
39 |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"], on_change=lambda: reset_states())
|
40 |
if uploaded_file is not None:
|
41 |
image = Image.open(uploaded_file).convert("RGB")
|
42 |
elif app_mode == "Take a photo":
|
|
|
43 |
camera_input = st.camera_input("Take a picture", on_change=lambda: reset_states())
|
44 |
if camera_input is not None:
|
|
|
45 |
image = Image.open(camera_input).convert("RGB")
|
46 |
|
47 |
def reset_states():
|
48 |
st.session_state['hat_enhanced_image'] = None
|
49 |
st.session_state['rcan_enhanced_image'] = None
|
50 |
st.session_state['srgan_enhanced_image'] = None
|
51 |
+
st.session_state['srflow_enhanced_image'] = None
|
52 |
st.session_state['hat_clicked'] = False
|
53 |
st.session_state['rcan_clicked'] = False
|
54 |
st.session_state['srgan_clicked'] = False
|
55 |
+
st.session_state['srflow_clicked'] = False
|
56 |
|
57 |
def get_image_download_link(img, filename):
|
58 |
"""Generates a link allowing the PIL image to be downloaded"""
|
|
|
106 |
col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
|
107 |
with col2:
|
108 |
get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')
|
109 |
+
|
110 |
+
# --------------------------SRGAN-------------------------- #
|
111 |
if st.button('Enhance with SRGAN'):
|
112 |
with st.spinner('Processing using SRGAN...'):
|
113 |
with st.spinner('Wait for it... the model is processing the image'):
|
|
|
125 |
col2.header("Enhanced")
|
126 |
col2.image(st.session_state['srgan_enhanced_image'], use_column_width=True)
|
127 |
with col2:
|
128 |
+
get_image_download_link(st.session_state['srgan_enhanced_image'], 'srgan_enhanced.jpg')
|
129 |
+
|
130 |
+
# ------------------------ SRFlow ------------------------ #
|
131 |
+
if st.button('Enhance with SRFlow'):
|
132 |
+
with st.spinner('Processing using SRFlow...'):
|
133 |
+
with st.spinner('Wait for it... the model is processing the image'):
|
134 |
+
enhanced_image = return_SRFlow_result(image)
|
135 |
+
st.session_state['srflow_enhanced_image'] = enhanced_image
|
136 |
+
st.session_state['srflow_clicked'] = True
|
137 |
+
st.success('Done!')
|
138 |
+
if st.session_state['srflow_enhanced_image'] is not None:
|
139 |
+
col1, col2 = st.columns(2)
|
140 |
+
col1.header("Original")
|
141 |
+
col1.image(image, use_column_width=True)
|
142 |
+
col2.header("Enhanced")
|
143 |
+
col2.image(st.session_state['srflow_enhanced_image'], use_column_width=True)
|
144 |
+
with col2:
|
145 |
+
get_image_download_link(st.session_state['srflow_enhanced_image'], 'srflow_enhanced.jpg')
|
models/SRFlow/srflow.py
CHANGED
@@ -3,22 +3,23 @@ import torch
|
|
3 |
import sys
|
4 |
sys.path.append('models')
|
5 |
from SRFlow.code import imread, impad, load_model, t, rgb
|
|
|
|
|
6 |
|
7 |
-
|
8 |
-
def return_SRFlow_result(lr_path, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=0.6):
|
9 |
"""
|
10 |
Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
|
11 |
|
12 |
Args:
|
13 |
-
-
|
14 |
- conf_path (str): Configuration file path for the SRFlow model. Default is SRFlow_DF2K_4X.yml.
|
15 |
- heat (float): Heat parameter for the SRFlow model. Default is 0.6.
|
16 |
|
17 |
Returns:
|
18 |
-
- sr
|
19 |
"""
|
20 |
model, opt = load_model(conf_path)
|
21 |
-
lr =
|
22 |
|
23 |
scale = opt['scale']
|
24 |
pad_factor = 2
|
@@ -35,5 +36,10 @@ def return_SRFlow_result(lr_path, conf_path='models/SRFlow/code/confs/SRFlow_DF2
|
|
35 |
sr = rgb(torch.clamp(sr_t, 0, 1))
|
36 |
sr = sr[:h * scale, :w * scale]
|
37 |
|
38 |
-
sr
|
39 |
return sr
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import sys
|
4 |
sys.path.append('models')
|
5 |
from SRFlow.code import imread, impad, load_model, t, rgb
|
6 |
+
from PIL import Image
|
7 |
+
from torchvision.transforms import PILToTensor
|
8 |
|
9 |
+
def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=0.6):
|
|
|
10 |
"""
|
11 |
Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
|
12 |
|
13 |
Args:
|
14 |
+
- lr: PIL Image
|
15 |
- conf_path (str): Configuration file path for the SRFlow model. Default is SRFlow_DF2K_4X.yml.
|
16 |
- heat (float): Heat parameter for the SRFlow model. Default is 0.6.
|
17 |
|
18 |
Returns:
|
19 |
+
- sr: PIL Image
|
20 |
"""
|
21 |
model, opt = load_model(conf_path)
|
22 |
+
lr = PILToTensor()(lr).permute(1, 2, 0).numpy()
|
23 |
|
24 |
scale = opt['scale']
|
25 |
pad_factor = 2
|
|
|
36 |
sr = rgb(torch.clamp(sr_t, 0, 1))
|
37 |
sr = sr[:h * scale, :w * scale]
|
38 |
|
39 |
+
sr = Image.fromarray((sr).astype('uint8'))
|
40 |
return sr
|
41 |
+
|
42 |
+
if __name__ == '__main__':
|
43 |
+
ip = Image.open('images/demo.png')
|
44 |
+
sr = return_SRFlow_result(ip)
|
45 |
+
print(sr.size)
|