Update app.py
Browse files
app.py
CHANGED
@@ -1,327 +1,163 @@
|
|
1 |
-
|
2 |
try:
    import spaces
except ImportError:
    # The `spaces` package is not installed: provide a stand-in object
    # exposing a no-op `GPU` decorator so the rest of the module still imports.
    from types import SimpleNamespace

    def spaces_gpu(func=None, **_options):
        """No-op replacement for the ``spaces.GPU`` decorator.

        Works both bare (``@spaces.GPU``) and parameterized
        (``@spaces.GPU(duration=60)``); any options are ignored.

        Args:
            func: The function being decorated, or None when the decorator
                is called with keyword options only.
            **_options: Ignored decorator options.

        Returns:
            ``func`` unchanged, or an identity decorator when ``func`` is None.
        """
        if func is None:
            return lambda wrapped: wrapped
        return func

    # A namespace (not a class instance) is used on purpose: a function stored
    # on a class becomes a bound method when read through an instance, which
    # would pass the instance itself as `func`.
    spaces = SimpleNamespace(GPU=spaces_gpu)
|
15 |
-
8
|
16 |
|
17 |
-
9
|
18 |
import gradio as gr
|
19 |
-
10
|
20 |
import torch
|
21 |
-
11
|
22 |
from torchvision.transforms import functional as F
|
23 |
-
12
|
24 |
from PIL import Image
|
25 |
-
13
|
26 |
import os
|
27 |
-
14
|
28 |
import cv2
|
29 |
-
15
|
30 |
import numpy as np
|
31 |
-
16
|
32 |
from super_image import EdsrModel, ImageLoader
|
33 |
-
17
|
34 |
|
35 |
-
18
|
36 |
|
37 |
-
19
|
38 |
|
39 |
-
20
|
40 |
@spaces.GPU
def upscale_video(video_path, scale_factor, progress=gr.Progress()):
    """
    Upscale every frame of a video with the EDSR super-resolution model.

    This function is decorated with @spaces.GPU to run on ZeroGPU.

    Args:
        video_path: Path of the input video file.
        scale_factor: Upscaling factor; must be 2 or 4.
        progress: Gradio progress tracker used to report per-frame progress.

    Returns:
        Path of the written video, ``upscaled_{scale_factor}x_<input basename>``,
        relative to the current working directory.

    Raises:
        gr.Error: If the scale factor is not 2 or 4, the input video is
            missing or cannot be opened, or the output file cannot be created.
    """
    # Validate cheap inputs before paying for the model download/load.
    if scale_factor not in (2, 4):
        raise gr.Error("Invalid scale factor. Choose 2 or 4.")
    scale_factor = int(scale_factor)

    # `video_path` is None when no video was supplied.
    if not video_path or not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    # Load models inside the function for ZeroGPU compatibility
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=scale_factor)
    model = model.to(device).eval()

    video_capture = cv2.VideoCapture(video_path)
    video_writer = None
    try:
        if not video_capture.isOpened():
            raise gr.Error(f"Could not open video file {video_path}")

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # NOTE(review): the container-reported frame count bounds the loop;
        # confirm it is reliable for the formats this app must accept.
        frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

        output_size = (width * scale_factor, height * scale_factor)
        output_path = f"upscaled_{scale_factor}x_{os.path.basename(video_path)}"
        video_writer = cv2.VideoWriter(output_path, fourcc, fps, output_size)
        if not video_writer.isOpened():
            raise gr.Error(f"Could not create output video file {output_path}")

        # Inference only: without no_grad every frame would keep an autograd graph.
        with torch.no_grad():
            for _ in progress.tqdm(range(frame_count), desc=f"Upscaling {scale_factor}x"):
                ret, frame = video_capture.read()
                if not ret:
                    break

                # OpenCV decodes to BGR; the model pipeline works on RGB.
                pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                inputs = ImageLoader.load_image(pil_frame).to(device)
                preds = model(inputs)

                # ImageLoader.save_image writes to a file path instead of
                # returning an image, so convert the (1, 3, H, W) prediction
                # in [0, 1] to an H x W x 3 uint8 array directly.
                rgb = (
                    preds[0]
                    .clamp(0, 1)
                    .mul(255)
                    .round()
                    .byte()
                    .permute(1, 2, 0)
                    .contiguous()
                    .cpu()
                    .numpy()
                )
                video_writer.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    finally:
        # Always release the native handles, even when a frame fails.
        video_capture.release()
        if video_writer is not None:
            video_writer.release()

    return output_path
|
139 |
-
70
|
140 |
|
141 |
-
71
|
142 |
from RIFE import Model as RIFEModel
|
143 |
-
72
|
144 |
from safetensors.torch import load_file
|
145 |
-
73
|
146 |
|
147 |
-
74
|
148 |
# ... (existing code)
|
149 |
-
75
|
150 |
|
151 |
-
76
|
152 |
@spaces.GPU
def rife_interpolate_video(video_path, progress=gr.Progress()):
    """
    Double a video's frame rate by inserting RIFE-interpolated frames.

    Between every pair of consecutive source frames one interpolated frame is
    written, and the output is encoded at twice the source frame rate.
    This function is decorated with @spaces.GPU to run on ZeroGPU.

    Args:
        video_path: Path of the input video file.
        progress: Gradio progress tracker used to report per-frame progress.

    Returns:
        Path of the written video, ``interpolated_<input basename>``,
        relative to the current working directory.

    Raises:
        gr.Error: If the input video or the model weights are missing, the
            video cannot be opened, or the output file cannot be created.
    """
    # `video_path` is None when no video was supplied.
    if not video_path or not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    # The weights location can be overridden with RIFE_WEIGHTS_PATH so the app
    # is not tied to one developer machine; the original path stays the default.
    weights_path = os.environ.get(
        "RIFE_WEIGHTS_PATH",
        "/Users/craigellenwood/Workspace/video_upscaler_rife_interpolator/rife_model_new/rife-flownet-4.13.2.safetensors",
    )
    if not os.path.exists(weights_path):
        raise gr.Error(f"RIFE weights not found at {weights_path}")

    use_cuda = torch.cuda.is_available()

    # Load the RIFE model
    model = RIFEModel()
    model.load_state_dict(load_file(weights_path))
    model.eval()
    if use_cuda:
        model.cuda()

    def to_tensor(bgr_frame):
        """Convert an H x W x 3 uint8 frame to a 1 x 3 x H x W float tensor in [0, 1]."""
        tensor = torch.from_numpy(bgr_frame.transpose(2, 0, 1)).float().unsqueeze(0)
        if use_cuda:
            tensor = tensor.cuda()
        return tensor / 255.

    video_capture = cv2.VideoCapture(video_path)
    video_writer = None
    try:
        if not video_capture.isOpened():
            raise gr.Error(f"Could not open video file {video_path}")

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

        output_path = f"interpolated_{os.path.basename(video_path)}"
        video_writer = cv2.VideoWriter(output_path, fourcc, fps * 2, (width, height))
        if not video_writer.isOpened():
            raise gr.Error(f"Could not create output video file {output_path}")

        # NOTE(review): frames are passed to the model at their native size;
        # confirm RIFEModel.inference accepts arbitrary dimensions.
        prev_tensor = None
        for _ in progress.tqdm(range(frame_count), desc="Interpolating"):
            ret, frame = video_capture.read()
            if not ret:
                break

            current_tensor = to_tensor(frame)
            if prev_tensor is not None:
                # Run inference
                with torch.no_grad():
                    middle = model.inference(prev_tensor, current_tensor)[0]
                middle = middle.cpu().numpy().transpose(1, 2, 0) * 255.0
                # Clip before the uint8 cast: out-of-range values would
                # otherwise wrap around.
                middle = np.clip(middle, 0, 255).round().astype(np.uint8)
                video_writer.write(np.ascontiguousarray(middle))

            video_writer.write(frame)
            # Reuse this frame's tensor as the next pair's first frame.
            prev_tensor = current_tensor
    finally:
        # Always release the native handles, even when a frame fails.
        video_capture.release()
        if video_writer is not None:
            video_writer.release()

    return output_path
|
255 |
-
128
|
256 |
|
257 |
-
129
|
258 |
|
259 |
-
130
|
260 |
|
261 |
-
131
|
262 |
|
263 |
-
132
|
264 |
# Gradio front end: one tab per operation, each with an input column
# (source video plus trigger button) next to an output column.
with gr.Blocks() as demo:
    gr.Markdown(value="# Video Upscaler and Frame Interpolator")

    with gr.Tab(label="Upscale"):
        with gr.Row():
            with gr.Column():
                video_input_upscale = gr.Video(label="Input Video")
                scale_factor = gr.Radio(choices=[2, 4], value=2, label="Scale Factor")
                upscale_button = gr.Button(value="Upscale Video")
            with gr.Column():
                video_output_upscale = gr.Video(label="Upscaled Video")

    with gr.Tab(label="Interpolate"):
        with gr.Row():
            with gr.Column():
                video_input_rife = gr.Video(label="Input Video")
                rife_button = gr.Button(value="Interpolate Frames")
            with gr.Column():
                video_output_rife = gr.Video(label="Interpolated Video")

    # Wire each button to its processing function: (fn, inputs, outputs).
    upscale_button.click(
        upscale_video,
        [video_input_upscale, scale_factor],
        video_output_upscale,
    )
    rife_button.click(
        rife_interpolate_video,
        [video_input_rife],
        video_output_rife,
    )

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch(share=True)
|
327 |
-
164
|
|
|
|
|
1 |
try:
    import spaces
except ImportError:
    # The `spaces` package is not installed: provide a stand-in object
    # exposing a no-op `GPU` decorator so the rest of the module still imports.
    from types import SimpleNamespace

    def spaces_gpu(func=None, **_options):
        """No-op replacement for the ``spaces.GPU`` decorator.

        Works both bare (``@spaces.GPU``) and parameterized
        (``@spaces.GPU(duration=60)``); any options are ignored.

        Args:
            func: The function being decorated, or None when the decorator
                is called with keyword options only.
            **_options: Ignored decorator options.

        Returns:
            ``func`` unchanged, or an identity decorator when ``func`` is None.
        """
        if func is None:
            return lambda wrapped: wrapped
        return func

    # A namespace (not a class instance) is used on purpose: a function stored
    # on a class becomes a bound method when read through an instance, which
    # would pass the instance itself as `func`.
    spaces = SimpleNamespace(GPU=spaces_gpu)
|
|
|
8 |
|
|
|
9 |
import gradio as gr
|
|
|
10 |
import torch
|
|
|
11 |
from torchvision.transforms import functional as F
|
|
|
12 |
from PIL import Image
|
|
|
13 |
import os
|
|
|
14 |
import cv2
|
|
|
15 |
import numpy as np
|
|
|
16 |
from super_image import EdsrModel, ImageLoader
|
|
|
17 |
|
|
|
18 |
|
|
|
19 |
|
|
|
20 |
@spaces.GPU
def upscale_video(video_path, scale_factor, progress=gr.Progress()):
    """
    Upscale every frame of a video with the EDSR super-resolution model.

    This function is decorated with @spaces.GPU to run on ZeroGPU.

    Args:
        video_path: Path of the input video file.
        scale_factor: Upscaling factor; must be 2 or 4.
        progress: Gradio progress tracker used to report per-frame progress.

    Returns:
        Path of the written video, ``upscaled_{scale_factor}x_<input basename>``,
        relative to the current working directory.

    Raises:
        gr.Error: If the scale factor is not 2 or 4, the input video is
            missing or cannot be opened, or the output file cannot be created.
    """
    # Validate cheap inputs before paying for the model download/load.
    if scale_factor not in (2, 4):
        raise gr.Error("Invalid scale factor. Choose 2 or 4.")
    scale_factor = int(scale_factor)

    # `video_path` is None when no video was supplied.
    if not video_path or not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    # Load models inside the function for ZeroGPU compatibility
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=scale_factor)
    model = model.to(device).eval()

    video_capture = cv2.VideoCapture(video_path)
    video_writer = None
    try:
        if not video_capture.isOpened():
            raise gr.Error(f"Could not open video file {video_path}")

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # NOTE(review): the container-reported frame count bounds the loop;
        # confirm it is reliable for the formats this app must accept.
        frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

        output_size = (width * scale_factor, height * scale_factor)
        output_path = f"upscaled_{scale_factor}x_{os.path.basename(video_path)}"
        video_writer = cv2.VideoWriter(output_path, fourcc, fps, output_size)
        if not video_writer.isOpened():
            raise gr.Error(f"Could not create output video file {output_path}")

        # Inference only: without no_grad every frame would keep an autograd graph.
        with torch.no_grad():
            for _ in progress.tqdm(range(frame_count), desc=f"Upscaling {scale_factor}x"):
                ret, frame = video_capture.read()
                if not ret:
                    break

                # OpenCV decodes to BGR; the model pipeline works on RGB.
                pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                inputs = ImageLoader.load_image(pil_frame).to(device)
                preds = model(inputs)

                # ImageLoader.save_image writes to a file path instead of
                # returning an image, so convert the (1, 3, H, W) prediction
                # in [0, 1] to an H x W x 3 uint8 array directly.
                rgb = (
                    preds[0]
                    .clamp(0, 1)
                    .mul(255)
                    .round()
                    .byte()
                    .permute(1, 2, 0)
                    .contiguous()
                    .cpu()
                    .numpy()
                )
                video_writer.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    finally:
        # Always release the native handles, even when a frame fails.
        video_capture.release()
        if video_writer is not None:
            video_writer.release()

    return output_path
|
|
|
70 |
|
|
|
71 |
from RIFE import Model as RIFEModel
|
|
|
72 |
from safetensors.torch import load_file
|
|
|
73 |
|
|
|
74 |
# ... (existing code)
|
|
|
75 |
|
|
|
76 |
@spaces.GPU
def rife_interpolate_video(video_path, progress=gr.Progress()):
    """
    Double a video's frame rate by inserting RIFE-interpolated frames.

    Between every pair of consecutive source frames one interpolated frame is
    written, and the output is encoded at twice the source frame rate.
    This function is decorated with @spaces.GPU to run on ZeroGPU.

    Args:
        video_path: Path of the input video file.
        progress: Gradio progress tracker used to report per-frame progress.

    Returns:
        Path of the written video, ``interpolated_<input basename>``,
        relative to the current working directory.

    Raises:
        gr.Error: If the input video or the model weights are missing, the
            video cannot be opened, or the output file cannot be created.
    """
    # `video_path` is None when no video was supplied.
    if not video_path or not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    # The weights location can be overridden with RIFE_WEIGHTS_PATH so the app
    # is not tied to one developer machine; the original path stays the default.
    weights_path = os.environ.get(
        "RIFE_WEIGHTS_PATH",
        "/Users/craigellenwood/Workspace/video_upscaler_rife_interpolator/rife_model_new/rife-flownet-4.13.2.safetensors",
    )
    if not os.path.exists(weights_path):
        raise gr.Error(f"RIFE weights not found at {weights_path}")

    use_cuda = torch.cuda.is_available()

    # Load the RIFE model
    model = RIFEModel()
    model.load_state_dict(load_file(weights_path))
    model.eval()
    if use_cuda:
        model.cuda()

    def to_tensor(bgr_frame):
        """Convert an H x W x 3 uint8 frame to a 1 x 3 x H x W float tensor in [0, 1]."""
        tensor = torch.from_numpy(bgr_frame.transpose(2, 0, 1)).float().unsqueeze(0)
        if use_cuda:
            tensor = tensor.cuda()
        return tensor / 255.

    video_capture = cv2.VideoCapture(video_path)
    video_writer = None
    try:
        if not video_capture.isOpened():
            raise gr.Error(f"Could not open video file {video_path}")

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

        output_path = f"interpolated_{os.path.basename(video_path)}"
        video_writer = cv2.VideoWriter(output_path, fourcc, fps * 2, (width, height))
        if not video_writer.isOpened():
            raise gr.Error(f"Could not create output video file {output_path}")

        # NOTE(review): frames are passed to the model at their native size;
        # confirm RIFEModel.inference accepts arbitrary dimensions.
        prev_tensor = None
        for _ in progress.tqdm(range(frame_count), desc="Interpolating"):
            ret, frame = video_capture.read()
            if not ret:
                break

            current_tensor = to_tensor(frame)
            if prev_tensor is not None:
                # Run inference
                with torch.no_grad():
                    middle = model.inference(prev_tensor, current_tensor)[0]
                middle = middle.cpu().numpy().transpose(1, 2, 0) * 255.0
                # Clip before the uint8 cast: out-of-range values would
                # otherwise wrap around.
                middle = np.clip(middle, 0, 255).round().astype(np.uint8)
                video_writer.write(np.ascontiguousarray(middle))

            video_writer.write(frame)
            # Reuse this frame's tensor as the next pair's first frame.
            prev_tensor = current_tensor
    finally:
        # Always release the native handles, even when a frame fails.
        video_capture.release()
        if video_writer is not None:
            video_writer.release()

    return output_path
|
|
|
128 |
|
|
|
129 |
|
|
|
130 |
|
|
|
131 |
|
|
|
132 |
# Gradio front end: one tab per operation, each with an input column
# (source video plus trigger button) next to an output column.
with gr.Blocks() as demo:
    gr.Markdown(value="# Video Upscaler and Frame Interpolator")

    with gr.Tab(label="Upscale"):
        with gr.Row():
            with gr.Column():
                video_input_upscale = gr.Video(label="Input Video")
                scale_factor = gr.Radio(choices=[2, 4], value=2, label="Scale Factor")
                upscale_button = gr.Button(value="Upscale Video")
            with gr.Column():
                video_output_upscale = gr.Video(label="Upscaled Video")

    with gr.Tab(label="Interpolate"):
        with gr.Row():
            with gr.Column():
                video_input_rife = gr.Video(label="Input Video")
                rife_button = gr.Button(value="Interpolate Frames")
            with gr.Column():
                video_output_rife = gr.Video(label="Interpolated Video")

    # Wire each button to its processing function: (fn, inputs, outputs).
    upscale_button.click(
        upscale_video,
        [video_input_upscale, scale_factor],
        video_output_upscale,
    )
    rife_button.click(
        rife_interpolate_video,
        [video_input_rife],
        video_output_rife,
    )

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch(share=True)
|
|