|
import os
import shutil
import subprocess
import sys
import uuid

import gradio as gr
from huggingface_hub import snapshot_download

MODEL_REPO = "roll-ai/DOVE"
MODEL_PATH = "pretrained_models/"

# Fetch the pretrained weights on first launch; skip if MODEL_PATH already contains files.
if not os.path.exists(MODEL_PATH) or len(os.listdir(MODEL_PATH)) == 0:
    print("🔽 Downloading model weights from Hugging Face Hub...")
    snapshot_download(
        repo_id=MODEL_REPO,
        repo_type="dataset",
        local_dir=MODEL_PATH,
        local_dir_use_symlinks=False,
    )
    print("✅ Download complete.")

INFERENCE_SCRIPT = "inference_script.py"
OUTPUT_DIR = "results/DOVE/demo"
UPLOAD_DIR = "input_videos"

os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
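# Copy the uploaded clip into UPLOAD_DIR, run the DOVE inference script on it,
# and return a status message plus the path to the restored MP4.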
def run_inference(video_path, save_format):
    # Give the upload a unique name so repeated runs do not overwrite each other.
    input_name = f"{uuid.uuid4()}.mp4"
    input_path = os.path.join(UPLOAD_DIR, input_name)
    shutil.copy(video_path, input_path)

    # Invoke the inference script with the same interpreter that runs this app.
    cmd = [
        sys.executable, INFERENCE_SCRIPT,
        "--input_dir", UPLOAD_DIR,
        "--model_path", MODEL_PATH,
        "--output_path", OUTPUT_DIR,
        "--is_vae_st",
        "--save_format", save_format,
    ]

    try:
        inference_result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=True,
        )
        print("📄 Inference stdout:\n", inference_result.stdout)
        print("⚠️ Inference stderr:\n", inference_result.stderr)
    except subprocess.CalledProcessError as e:
        print("❌ Inference failed.")
        print("⚠️ STDOUT:\n", e.stdout)
        print("⚠️ STDERR:\n", e.stderr)
        return f"Inference failed:\n{e.stderr}", None

    # The script may emit an MKV; repackage it as MP4 (video stream copied,
    # audio re-encoded to AAC) so the Gradio player can display it.
    mkv_path = os.path.join(OUTPUT_DIR, input_name).replace(".mp4", ".mkv")
    mp4_path = os.path.join(OUTPUT_DIR, input_name)

    if os.path.exists(mkv_path):
        convert_cmd = [
            "ffmpeg", "-y", "-i", mkv_path, "-c:v", "copy", "-c:a", "aac", mp4_path
        ]
        try:
            convert_result = subprocess.run(
                convert_cmd,
                capture_output=True,
                text=True,
                check=True,
            )
            print("🔄 FFmpeg stdout:\n", convert_result.stdout)
            print("⚠️ FFmpeg stderr:\n", convert_result.stderr)
        except subprocess.CalledProcessError as e:
            print("❌ FFmpeg conversion failed.")
            print("⚠️ STDOUT:\n", e.stdout)
            print("⚠️ STDERR:\n", e.stderr)
            return f"Inference OK, but conversion failed:\n{e.stderr}", None

    if os.path.exists(mp4_path):
        return "Inference successful!", mp4_path
    else:
        return "Output video not found.", None
|
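# Gradio UI: upload a video, choose the output pixel format, and view the restored result.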
with gr.Blocks() as demo:
    gr.Markdown("# 🎥 DOVE Video SR + Restoration Inference Demo")
    gr.Markdown("⚙️ **Note:** Default `save_format` is `yuv444p`. If playback fails, try `yuv420p` for compatibility.")

    with gr.Row():
        input_video = gr.Video(label="Upload input video")
        output_video = gr.Video(label="Output video")

    with gr.Row():
        save_format = gr.Dropdown(
            choices=["yuv444p", "yuv420p"],
            value="yuv444p",
            label="Save format (for video playback compatibility)",
        )

    run_button = gr.Button("Run Inference")
    status = gr.Textbox(label="Status")

    run_button.click(
        fn=run_inference,
        inputs=[input_video, save_format],
        outputs=[status, output_video],
    )

demo.launch()
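# launch() starts the local Gradio server and blocks; share=True can be passed for a temporary public link.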
|
|