|
try:
    import spaces
except ImportError:
    # The `spaces` package is not installed; provide a stand-in object that
    # exposes a no-op `GPU` decorator so the `@spaces.GPU` lines below still
    # work.

    def spaces_gpu(func=None, **_options):
        """No-op replacement for the ``spaces.GPU`` decorator.

        Supports both the bare form (``@spaces.GPU``) and the parameterised
        form (``@spaces.GPU(duration=...)``); in either case the decorated
        function is returned unchanged.
        """
        if func is None:
            # Called with keyword options only: return a pass-through decorator.
            return lambda wrapped: wrapped
        return func

    # `staticmethod` is required: a plain function stored in the class dict
    # becomes a bound method when accessed through the instance, so
    # `spaces.GPU(func)` would pass the instance as an extra first argument
    # and raise TypeError.
    spaces = type('spaces', (), {'GPU': staticmethod(spaces_gpu)})()
|
|
|
import gradio as gr |
|
import torch |
|
from torchvision.transforms import functional as F |
|
from PIL import Image |
|
import os |
|
import cv2 |
|
import numpy as np |
|
from super_image import EdsrModel, ImageLoader |
|
from safetensors.torch import load_file |
|
|
|
|
|
@spaces.GPU
def upscale_video(video_path, scale_factor, progress=gr.Progress()):
    """
    Upscale a video frame by frame with the EDSR super-resolution model.

    This function is decorated with @spaces.GPU to run on ZeroGPU.

    Args:
        video_path: Filesystem path of the input video.
        scale_factor: Upscaling factor; must be 2 or 4.
        progress: Gradio progress tracker used to report per-frame progress.

    Returns:
        Relative path of the written video,
        ``upscaled_{scale_factor}x_{basename of video_path}``.

    Raises:
        gr.Error: If the scale factor is not 2 or 4, the input file is missing
            (or no file was supplied), or the input/output video cannot be
            opened.
    """
    if scale_factor not in (2, 4):
        raise gr.Error("Invalid scale factor. Choose 2 or 4.")
    # Normalise e.g. 2.0 -> 2 so the output dimensions below stay integers.
    scale_factor = int(scale_factor)

    # `video_path` is None/empty when no file was supplied; report that the
    # same way as a missing file instead of crashing inside os.path.exists.
    if not video_path or not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=scale_factor)
    model = model.to(device)

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise gr.Error(f"Could not open video file {video_path}")

    video_writer = None
    try:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

        output_width = width * scale_factor
        output_height = height * scale_factor

        output_path = f"upscaled_{scale_factor}x_{os.path.basename(video_path)}"
        video_writer = cv2.VideoWriter(output_path, fourcc, fps, (output_width, output_height))
        if not video_writer.isOpened():
            raise gr.Error(f"Could not open output video file {output_path}")

        # Inference only: skip autograd bookkeeping for every frame.
        with torch.no_grad():
            for _ in progress.tqdm(range(frame_count), desc=f"Upscaling {scale_factor}x"):
                ret, frame = video_capture.read()
                if not ret:
                    break

                # OpenCV delivers BGR; the loader is given an RGB PIL image.
                pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                inputs = ImageLoader.load_image(pil_frame).to(device)
                preds = model(inputs)

                # Convert the prediction straight to a uint8 frame.
                # NOTE(review): assumes `preds` is a (1, 3, H, W) RGB tensor
                # scaled to [0, 1] -- confirm against super_image.
                # Clamping prevents overshoot from wrapping around in uint8.
                upscaled_rgb = (
                    preds[0].detach().clamp(0, 1).mul(255).round().byte()
                    .permute(1, 2, 0).cpu().numpy()
                )
                video_writer.write(cv2.cvtColor(upscaled_rgb, cv2.COLOR_RGB2BGR))
    finally:
        # Always release the OpenCV handles, even when a frame fails midway.
        video_capture.release()
        if video_writer is not None:
            video_writer.release()

    return output_path
|
|
|
|
|
@spaces.GPU
def rife_interpolate_video(video_path, progress=gr.Progress()):
    """
    Interpolate a video with the RIFE model, doubling its frame rate.

    For every pair of consecutive input frames, one interpolated frame is
    written before the later frame of the pair; the output is written at
    twice the input FPS.

    This function is decorated with @spaces.GPU to run on ZeroGPU.

    Args:
        video_path: Filesystem path of the input video.
        progress: Gradio progress tracker used to report per-frame progress.

    Returns:
        Relative path of the written video,
        ``interpolated_{basename of video_path}``.

    Raises:
        gr.Error: If the input file is missing (or no file was supplied), or
            the input/output video cannot be opened.
    """
    # `video_path` is None/empty when no file was supplied; report that the
    # same way as a missing file instead of crashing inside os.path.exists.
    if not video_path or not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    # The weights location can be overridden with the RIFE_MODEL_PATH
    # environment variable; the default is the original hard-coded path.
    weights_path = os.environ.get(
        "RIFE_MODEL_PATH",
        "/Users/craigellenwood/Workspace/video_upscaler_rife_interpolator/rife_model_new/rife-flownet-4.13.2.safetensors",
    )

    # NOTE(review): `RIFEModel` is not imported or defined in the visible part
    # of this file -- confirm where it comes from.
    model = RIFEModel()
    model.load_state_dict(load_file(weights_path))
    model.eval()
    model.cuda()

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise gr.Error(f"Could not open video file {video_path}")

    video_writer = None
    try:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

        output_path = f"interpolated_{os.path.basename(video_path)}"
        # One extra frame per input pair, so play back at double the rate.
        video_writer = cv2.VideoWriter(output_path, fourcc, fps * 2, (width, height))
        if not video_writer.isOpened():
            raise gr.Error(f"Could not open output video file {output_path}")

        # Tensor of the previous frame, kept on the GPU so each frame is
        # converted and uploaded only once.
        prev_tensor = None
        with torch.no_grad():
            for _ in progress.tqdm(range(frame_count), desc="Interpolating"):
                ret, frame = video_capture.read()
                if not ret:
                    break

                # HWC uint8 frame -> 1xCxHxW float tensor scaled by 1/255.
                frame_tensor = (
                    torch.from_numpy(frame.transpose(2, 0, 1)).float().unsqueeze(0).cuda() / 255.
                )

                if prev_tensor is not None:
                    middle = model.inference(prev_tensor, frame_tensor)[0]
                    # Clamp before the uint8 cast: values slightly outside
                    # [0, 1] would otherwise wrap around and show as speckles.
                    middle_frame = (
                        middle.clamp(0, 1).mul(255).round().byte()
                        .permute(1, 2, 0).cpu().numpy()
                    )
                    video_writer.write(np.ascontiguousarray(middle_frame))

                video_writer.write(frame)
                prev_tensor = frame_tensor
    finally:
        # Always release the OpenCV handles, even when a frame fails midway.
        video_capture.release()
        if video_writer is not None:
            video_writer.release()

    return output_path
|
|
|
|
|
|
|
|
|
# Two-tab Gradio UI: one tab per processing function defined above.
with gr.Blocks() as demo:
    gr.Markdown("# Video Upscaler and Frame Interpolator")

    # Upscale tab: input video, scale choice and trigger on the left,
    # result on the right.
    with gr.Tab("Upscale"), gr.Row():
        with gr.Column():
            video_input_upscale = gr.Video(label="Input Video")
            scale_factor = gr.Radio([2, 4], label="Scale Factor", value=2)
            upscale_button = gr.Button("Upscale Video")
        with gr.Column():
            video_output_upscale = gr.Video(label="Upscaled Video")

    # Interpolate tab: input video and trigger on the left, result on the right.
    with gr.Tab("Interpolate"), gr.Row():
        with gr.Column():
            video_input_rife = gr.Video(label="Input Video")
            rife_button = gr.Button("Interpolate Frames")
        with gr.Column():
            video_output_rife = gr.Video(label="Interpolated Video")

    # Event wiring: each button runs its function and shows the returned file.
    upscale_button.click(
        upscale_video,
        [video_input_upscale, scale_factor],
        video_output_upscale,
    )
    rife_button.click(
        rife_interpolate_video,
        [video_input_rife],
        video_output_rife,
    )


# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(share=True)
|
|