inoculatemedia committed
Commit c4bd972 · verified · 1 Parent(s): 4a8a7a3

Update app.py

Files changed (1):
  1. app.py +0 -164
app.py CHANGED
@@ -1,327 +1,163 @@

This commit removes a duplicated second copy of the file body (327 lines down to 163); the resulting app.py follows.

try:
    import spaces
except ImportError:
    # Create a dummy decorator if spaces is not available
    def spaces_gpu(func):
        return func
    # staticmethod prevents spaces.GPU from binding the dummy instance as a
    # first argument when the decorator is looked up on the object
    spaces = type('spaces', (), {'GPU': staticmethod(spaces_gpu)})()

import gradio as gr
import torch
from torchvision.transforms import functional as F
from PIL import Image
import os
import cv2
import numpy as np
from super_image import EdsrModel, ImageLoader


@spaces.GPU
def upscale_video(video_path, scale_factor, progress=gr.Progress()):
    """
    Upscales a video using the EDSR model.
    This function is decorated with @spaces.GPU to run on ZeroGPU.
    """
    # Load models inside the function for ZeroGPU compatibility
    if scale_factor == 2:
        model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=2)
    elif scale_factor == 4:
        model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=4)
    else:
        raise gr.Error("Invalid scale factor. Choose 2 or 4.")

    if not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise gr.Error(f"Could not open video file {video_path}")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

    output_width = width * scale_factor
    output_height = height * scale_factor

    output_path = f"upscaled_{scale_factor}x_{os.path.basename(video_path)}"
    video_writer = cv2.VideoWriter(output_path, fourcc, fps, (output_width, output_height))

    for i in progress.tqdm(range(frame_count), desc=f"Upscaling {scale_factor}x"):
        ret, frame = video_capture.read()
        if not ret:
            break

        pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        inputs = ImageLoader.load_image(pil_frame)
        preds = model(inputs)
        # ImageLoader.save_image writes to disk and returns None, so convert
        # the prediction tensor back to a PIL image manually
        output_array = preds.data.cpu().numpy().squeeze(0).transpose(1, 2, 0)
        output_frame = Image.fromarray((output_array * 255.0).clip(0, 255).astype(np.uint8))

        video_writer.write(cv2.cvtColor(np.array(output_frame), cv2.COLOR_RGB2BGR))

    video_capture.release()
    video_writer.release()

    return output_path

from RIFE import Model as RIFEModel
from safetensors.torch import load_file

# ... (existing code)

@spaces.GPU
def rife_interpolate_video(video_path, progress=gr.Progress()):
    """
    Interpolates a video using the RIFE model.
    This function is decorated with @spaces.GPU to run on ZeroGPU.
    """
    if not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    # Load the RIFE model
    model = RIFEModel()
    # NOTE: machine-specific absolute path; the weights must exist at this location
    model.load_state_dict(load_file("/Users/craigellenwood/Workspace/video_upscaler_rife_interpolator/rife_model_new/rife-flownet-4.13.2.safetensors"))
    model.eval()
    model.cuda()

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise gr.Error(f"Could not open video file {video_path}")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

    output_path = f"interpolated_{os.path.basename(video_path)}"
    # Double the frame rate to match the one frame inserted between each pair
    video_writer = cv2.VideoWriter(output_path, fourcc, fps * 2, (width, height))

    prev_frame = None
    for i in progress.tqdm(range(frame_count), desc="Interpolating"):
        ret, frame = video_capture.read()
        if not ret:
            break

        if prev_frame is not None:
            # Preprocess frames
            img0 = torch.from_numpy(prev_frame.transpose(2, 0, 1)).float().unsqueeze(0).cuda() / 255.
            img1 = torch.from_numpy(frame.transpose(2, 0, 1)).float().unsqueeze(0).cuda() / 255.

            # Run inference
            with torch.no_grad():
                interpolated_frame = model.inference(img0, img1)[0].cpu().numpy().transpose(1, 2, 0) * 255

            # Clamp before casting to avoid uint8 wraparound
            video_writer.write(interpolated_frame.clip(0, 255).astype(np.uint8))

        video_writer.write(frame)
        prev_frame = frame

    video_capture.release()
    video_writer.release()

    return output_path


with gr.Blocks() as demo:
    gr.Markdown("# Video Upscaler and Frame Interpolator")
    with gr.Tab("Upscale"):
        with gr.Row():
            with gr.Column():
                video_input_upscale = gr.Video(label="Input Video")
                scale_factor = gr.Radio([2, 4], label="Scale Factor", value=2)
                upscale_button = gr.Button("Upscale Video")
            with gr.Column():
                video_output_upscale = gr.Video(label="Upscaled Video")
    with gr.Tab("Interpolate"):
        with gr.Row():
            with gr.Column():
                video_input_rife = gr.Video(label="Input Video")
                rife_button = gr.Button("Interpolate Frames")
            with gr.Column():
                video_output_rife = gr.Video(label="Interpolated Video")

    upscale_button.click(
        fn=upscale_video,
        inputs=[video_input_upscale, scale_factor],
        outputs=video_output_upscale
    )

    rife_button.click(
        fn=rife_interpolate_video,
        inputs=[video_input_rife],
        outputs=video_output_rife
    )

if __name__ == "__main__":
    demo.launch(share=True)
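A plausible requirements.txt for this Space, inferred from the imports above (an assumption; the repo may pin versions, the vendored RIFE module is not a PyPI package, and the spaces package is provided by the ZeroGPU runtime and guarded by the try/except):

    gradio
    torch
    torchvision
    pillow
    opencv-python
    numpy
    super-image
    safetensors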
 
 
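For a quick local check of the upscale path without launching the UI, a minimal sketch follows; "sample.mp4" is a hypothetical test clip, and the stub stands in for gr.Progress so the function can run outside a Gradio event handler:

    # Hypothetical smoke test for upscale_video (a sketch, not part of the commit).
    class _NoProgress:
        def tqdm(self, iterable, desc=None):
            # Pass the iterable through without reporting progress
            return iterable

    output = upscale_video("sample.mp4", scale_factor=2, progress=_NoProgress())
    print(f"Upscaled video written to: {output}")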