Dove / app.py
Muhammad Taqi Raza
adding requirements.txt
1b37cad
raw
history blame
3.97 kB
import os
import shutil
import subprocess
import sys
import uuid

import gradio as gr
from huggingface_hub import snapshot_download
# ----------------------------------------
# Step 1: Download Model Weights
# ----------------------------------------
MODEL_REPO = "roll-ai/DOVE"
MODEL_PATH = "pretrained_models/"

# Reuse whatever is already in MODEL_PATH; only fetch from the Hub when the
# folder is missing or empty. (A partially filled folder is NOT re-checked.)
_weights_present = os.path.exists(MODEL_PATH) and bool(os.listdir(MODEL_PATH))
if not _weights_present:
    print("🔽 Downloading model weights from Hugging Face Hub...")
    # NOTE(review): the weights are pulled from a *dataset* repo, and
    # local_dir_use_symlinks may be deprecated in newer huggingface_hub
    # releases -- confirm both against the repo and the pinned library version.
    snapshot_download(
        repo_id=MODEL_REPO,
        repo_type="dataset",
        local_dir=MODEL_PATH,
        local_dir_use_symlinks=False,
    )
    print("✅ Download complete.")
# ----------------------------------------
# Step 2: Setup Directories
# ----------------------------------------
INFERENCE_SCRIPT = "inference_script.py"  # CLI launched for every request
OUTPUT_DIR = "results/DOVE/demo"          # passed to the script as --output_path
UPLOAD_DIR = "input_videos"               # uploaded clips are staged here

# Both working folders must exist before the first request arrives.
for _folder in (UPLOAD_DIR, OUTPUT_DIR):
    os.makedirs(_folder, exist_ok=True)
# ----------------------------------------
# Step 3: Inference Function
# ----------------------------------------
def run_inference(video_path, save_format):
    """Run DOVE inference on one uploaded video and return the result for Gradio.

    The upload is copied under a fresh UUID file name into its own sub-folder
    of ``UPLOAD_DIR``, and ``INFERENCE_SCRIPT`` is run on that folder. If the
    script produced an ``.mkv``, ffmpeg converts it to ``.mp4``: the video
    stream is copied and the audio is re-encoded to AAC.

    Args:
        video_path: Local path of the uploaded video as handed over by
            ``gr.Video``; ``None``/empty when nothing was uploaded.
        save_format: Pixel format forwarded to the script's ``--save_format``
            flag (the UI offers ``"yuv444p"`` and ``"yuv420p"``).

    Returns:
        A ``(status_message, output_path)`` tuple. ``output_path`` is ``None``
        whenever no result video is available.
    """
    # Clicking "Run" without an upload gives no path; bail out before copying.
    if not video_path:
        return "Please upload a video first.", None

    request_id = uuid.uuid4()
    input_name = f"{request_id}.mp4"
    # The script is pointed at a directory (--input_dir). Staging each upload
    # in its own sub-folder keeps earlier uploads and concurrent requests out
    # of this run (assuming the script processes every video it finds there).
    request_dir = os.path.join(UPLOAD_DIR, str(request_id))
    os.makedirs(request_dir, exist_ok=True)

    # --- Run inference script ---
    cmd = [
        # Use the interpreter running this app so the script sees the same
        # installed packages; fall back to a PATH lookup if it is unknown.
        sys.executable or "python", INFERENCE_SCRIPT,
        "--input_dir", request_dir,
        "--model_path", MODEL_PATH,
        "--output_path", OUTPUT_DIR,
        "--is_vae_st",
        "--save_format", save_format,
    ]
    try:
        shutil.copy(video_path, os.path.join(request_dir, input_name))
        inference_result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        print("❌ Inference failed.")
        print("⚠️ STDOUT:\n", e.stdout)
        print("⚠️ STDERR:\n", e.stderr)
        return f"Inference failed:\n{e.stderr}", None
    finally:
        # The staged copy is only needed while the script runs; removing it
        # keeps UPLOAD_DIR from growing with every request.
        shutil.rmtree(request_dir, ignore_errors=True)
    print("📄 Inference stdout:\n", inference_result.stdout)
    print("⚠️ Inference stderr:\n", inference_result.stderr)

    # --- Convert .mkv to .mp4 ---
    mp4_path = os.path.join(OUTPUT_DIR, input_name)
    # Swap only the extension; str.replace would also rewrite any ".mp4"
    # appearing elsewhere in the path.
    mkv_path = os.path.splitext(mp4_path)[0] + ".mkv"
    if os.path.exists(mkv_path):
        convert_cmd = [
            "ffmpeg", "-y", "-i", mkv_path, "-c:v", "copy", "-c:a", "aac", mp4_path
        ]
        try:
            convert_result = subprocess.run(
                convert_cmd,
                capture_output=True,
                text=True,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            print("❌ FFmpeg conversion failed.")
            print("⚠️ STDOUT:\n", e.stdout)
            print("⚠️ STDERR:\n", e.stderr)
            return f"Inference OK, but conversion failed:\n{e.stderr}", None
        except FileNotFoundError as e:
            # Raised when the ffmpeg executable itself cannot be found.
            print("❌ FFmpeg conversion failed.")
            print("⚠️ ERROR:\n", e)
            return f"Inference OK, but conversion failed:\n{e}", None
        print("🔄 FFmpeg stdout:\n", convert_result.stdout)
        print("⚠️ FFmpeg stderr:\n", convert_result.stderr)

    # Covers both cases: the script wrote .mp4 directly, or ffmpeg produced it.
    if os.path.exists(mp4_path):
        return "Inference successful!", mp4_path
    return "Output video not found.", None
# ----------------------------------------
# Step 4: Gradio Interface
# ----------------------------------------
# Pixel formats the user may pick; the first entry is the default selection.
_SAVE_FORMATS = ["yuv444p", "yuv420p"]
_NOTE_MD = (
    "⚙️ **Note:** Default `save_format` is `yuv444p`. "
    "If playback fails, try `yuv420p` for compatibility."
)

with gr.Blocks() as demo:
    gr.Markdown("# 🎥 DOVE Video SR + Restoration Inference Demo")
    gr.Markdown(_NOTE_MD)

    # Upload and result players side by side.
    with gr.Row():
        input_video = gr.Video(label="Upload input video")
        output_video = gr.Video(label="Output video")

    with gr.Row():
        save_format = gr.Dropdown(
            choices=_SAVE_FORMATS,
            value=_SAVE_FORMATS[0],
            label="Save format (for video playback compatibility)",
        )

    run_button = gr.Button("Run Inference")
    status = gr.Textbox(label="Status")

    # run_inference returns (status text, output video path), matching the
    # order of `outputs` below.
    run_button.click(
        fn=run_inference,
        inputs=[input_video, save_format],
        outputs=[status, output_video],
    )

demo.launch()