File size: 3,968 Bytes
d2d7c02
2ae859b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2d7c02
2ae859b
d2d7c02
2ae859b
 
 
d2d7c02
2ae859b
 
d2d7c02
 
1b37cad
2ae859b
d2d7c02
2ae859b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2d7c02
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import gradio as gr
import subprocess
import uuid
import shutil
from huggingface_hub import snapshot_download

# ----------------------------------------
# Step 1: Download Model Weights
# ----------------------------------------
MODEL_REPO = "roll-ai/DOVE"
MODEL_PATH = "pretrained_models/"

# The snapshot is fetched only when the local weights directory is missing
# or has no entries; a populated directory is reused as-is.
_weights_present = os.path.exists(MODEL_PATH) and bool(os.listdir(MODEL_PATH))
if not _weights_present:
    print("🔽 Downloading model weights from Hugging Face Hub...")
    snapshot_download(
        repo_id=MODEL_REPO,
        repo_type="dataset",
        local_dir=MODEL_PATH,
        local_dir_use_symlinks=False,
    )
    print("✅ Download complete.")

# ----------------------------------------
# Step 2: Setup Directories
# ----------------------------------------
INFERENCE_SCRIPT = "inference_script.py"
OUTPUT_DIR = "results/DOVE/demo"
UPLOAD_DIR = "input_videos"

# Make sure both working directories exist before any request is served;
# exist_ok keeps this a no-op when they are already there.
for _workdir in (UPLOAD_DIR, OUTPUT_DIR):
    os.makedirs(_workdir, exist_ok=True)

# ----------------------------------------
# Step 3: Inference Function
# ----------------------------------------
def run_inference(video_path, save_format):
    """Run the inference script on one uploaded video and return the result.

    The upload is copied under a fresh UUID-based name into its own
    sub-directory of ``UPLOAD_DIR``. The inference script receives a
    directory (``--input_dir``), not a single file, so giving every request
    a private directory keeps earlier uploads out of the current run. The
    private directory is removed once the script has finished.

    If the script leaves a ``.mkv`` file named after the input in
    ``OUTPUT_DIR``, it is remuxed to ``.mp4`` with ffmpeg (video stream
    copied, audio re-encoded to AAC).

    Args:
        video_path: Filesystem path of the uploaded video, or ``None``/empty
            when the user has not uploaded anything.
        save_format: Value forwarded to the script's ``--save_format`` flag.

    Returns:
        A ``(status_message, output_path)`` tuple. ``output_path`` is the
        path of the resulting ``.mp4`` on success and ``None`` on any
        failure.
    """
    import sys  # local import: only needed to locate the running interpreter

    # Nothing was uploaded; report it instead of failing inside shutil.copy.
    if not video_path:
        return "Please upload a video first.", None

    job_id = str(uuid.uuid4())
    input_name = f"{job_id}.mp4"
    job_dir = os.path.join(UPLOAD_DIR, job_id)
    os.makedirs(job_dir, exist_ok=True)

    try:
        shutil.copy(video_path, os.path.join(job_dir, input_name))

        # --- Run inference script ---
        # sys.executable guarantees the child uses the same interpreter (and
        # therefore the same installed packages) as this app.
        cmd = [
            sys.executable, INFERENCE_SCRIPT,
            "--input_dir", job_dir,
            "--model_path", MODEL_PATH,
            "--output_path", OUTPUT_DIR,
            "--is_vae_st",
            "--save_format", save_format,
        ]

        try:
            inference_result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            print("❌ Inference failed.")
            print("⚠️ STDOUT:\n", e.stdout)
            print("⚠️ STDERR:\n", e.stderr)
            return f"Inference failed:\n{e.stderr}", None

        print("📄 Inference stdout:\n", inference_result.stdout)
        print("⚠️ Inference stderr:\n", inference_result.stderr)
    finally:
        # The copied upload is no longer needed whether inference succeeded
        # or not; removing it keeps UPLOAD_DIR from growing with every request.
        shutil.rmtree(job_dir, ignore_errors=True)

    # --- Convert .mkv to .mp4 ---
    # NOTE(review): assumes the script names its output after the input
    # file's basename inside OUTPUT_DIR — confirm against inference_script.py.
    mp4_path = os.path.join(OUTPUT_DIR, input_name)
    mkv_path = os.path.splitext(mp4_path)[0] + ".mkv"

    if os.path.exists(mkv_path):
        convert_cmd = [
            "ffmpeg", "-y", "-i", mkv_path, "-c:v", "copy", "-c:a", "aac", mp4_path
        ]
        try:
            convert_result = subprocess.run(
                convert_cmd,
                capture_output=True,
                text=True,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            print("❌ FFmpeg conversion failed.")
            print("⚠️ STDOUT:\n", e.stdout)
            print("⚠️ STDERR:\n", e.stderr)
            return f"Inference OK, but conversion failed:\n{e.stderr}", None
        except OSError as e:
            # Raised when the ffmpeg executable itself cannot be started
            # (for example, it is not installed).
            print("❌ FFmpeg conversion failed.")
            print("⚠️ STDERR:\n", e)
            return f"Inference OK, but conversion failed:\n{e}", None

        print("🔄 FFmpeg stdout:\n", convert_result.stdout)
        print("⚠️ FFmpeg stderr:\n", convert_result.stderr)

    if os.path.exists(mp4_path):
        return "Inference successful!", mp4_path
    return "Output video not found.", None

# ----------------------------------------
# Step 4: Gradio Interface
# ----------------------------------------
# Build the UI: a video input, a video output, a save-format selector, a run
# button and a status box. Clicking the button calls run_inference.
with gr.Blocks() as demo:
    gr.Markdown("# 🎥 DOVE Video SR + Restoration Inference Demo")
    gr.Markdown("⚙️ **Note:** Default `save_format` is `yuv444p`. If playback fails, try `yuv420p` for compatibility.")

    # Input and output video components share one row.
    with gr.Row():
        input_video = gr.Video(label="Upload input video")
        output_video = gr.Video(label="Output video")

    # Selector for the value passed to run_inference as `save_format`.
    with gr.Row():
        save_format = gr.Dropdown(
            choices=["yuv444p", "yuv420p"],
            value="yuv444p",
            label="Save format (for video playback compatibility)"
        )

    run_button = gr.Button("Run Inference")
    status = gr.Textbox(label="Status")

    # run_inference returns (status message, output path); the outputs list
    # maps those two values onto the status box and the output video.
    run_button.click(
        fn=run_inference,
        inputs=[input_video, save_format],
        outputs=[status, output_video],
    )

# Launched unconditionally at import time (no __main__ guard).
demo.launch()