import os
import subprocess
from datetime import datetime
from pathlib import Path
import gradio as gr
import numpy as np
# -----------------------------
# Setup paths and env
# -----------------------------
HF_HOME = "/app/hf_cache"
os.environ["HF_HOME"] = HF_HOME
os.environ["TRANSFORMERS_CACHE"] = HF_HOME
os.makedirs(HF_HOME, exist_ok=True)
PRETRAINED_DIR = "/app/pretrained"
os.makedirs(PRETRAINED_DIR, exist_ok=True)
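# Both HF_HOME and TRANSFORMERS_CACHE are pointed at /app/hf_cache so that
# Hugging Face downloads land on a writable path inside the container;
# PRETRAINED_DIR holds the EPiC checkpoints fetched by download_models() below.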
# -----------------------------
# Step 1: Optional Model Download
# -----------------------------
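# download_models() uses the RAFT checkpoint as a sentinel: if it is missing,
# the bundled download script is run once to fetch all pretrained weights.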
def download_models():
    expected_model = os.path.join(PRETRAINED_DIR, "RAFT/raft-things.pth")
    if not Path(expected_model).exists():
        print("⚙️ Downloading pretrained models...")
        try:
            subprocess.check_call(["bash", "download/download_models.sh"])
            print("✅ Models downloaded.")
        except subprocess.CalledProcessError as e:
            print(f"❌ Model download failed: {e}")
    else:
        print("✅ Pretrained models already exist.")
download_models()
# -----------------------------
# Step 2: Inference Logic
# -----------------------------
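# Helper: derive near/far depth planes from the 5th/95th percentiles of a depth
# map, which is more robust to outliers than taking the raw min/max.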
def estimate_near_far(depths, lower_percentile=5, upper_percentile=95):
    flat = depths.flatten()
    near = np.percentile(flat, lower_percentile)
    far = np.percentile(flat, upper_percentile)
    return near, far
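# Illustrative usage of estimate_near_far (it is not called elsewhere in this
# script; `depth_map` is a made-up placeholder array, not data produced here):
#   depth_map = np.random.rand(480, 640).astype(np.float32)
#   near, far = estimate_near_far(depth_map)

# Stage 1: render the masked "anchor" video for the requested camera pose by
# invoking the v2v_data inference script as a subprocess.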
def run_epic_inference(video_path, fps, num_frames, target_pose, mode):
    temp_input_path = "/app/temp_input.mp4"
    output_dir = "/app/output_anchor"
    video_output_path = f"{output_dir}/masked_videos/output.mp4"
    # Save uploaded video
    if video_path:
        os.system(f"cp '{video_path}' {temp_input_path}")
    try:
        theta, phi, r, x, y = target_pose.strip().split()
    except ValueError:
        return "Invalid target pose format. Use: θ φ r x y", None
    logs = f"Running inference with target pose: θ={theta}, φ={phi}, r={r}, x={x}, y={y}\n"
    command = [
        "python", "/app/inference/v2v_data/inference.py",
        "--video_path", temp_input_path,
        "--stride", "1",
        "--out_dir", output_dir,
        "--radius_scale", "1",
        "--camera", "target",
        "--mask",
        "--target_pose", theta, phi, r, x, y,
        "--video_length", str(num_frames),
        "--save_name", "output",
        "--mode", mode,
        "--fps", str(fps),
    ]
    try:
        result = subprocess.run(command, capture_output=True, text=True, check=True)
        logs += result.stdout
    except subprocess.CalledProcessError as e:
        logs += f"❌ Inference failed:\n{e.stderr}{e.stdout}"
        return logs, None
    # stdout is already in logs; only return the video path if the file was produced.
    if os.path.exists(video_output_path):
        return logs, video_output_path
    return logs, None
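# Utility: build an indented, tree-style listing of a directory for the log box.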
def print_output_directory(out_dir):
    result = ""
    for root, dirs, files in os.walk(out_dir):
        level = root.replace(out_dir, '').count(os.sep)
        indent = ' ' * 4 * level
        result += f"{indent}{os.path.basename(root)}/\n"
        sub_indent = ' ' * 4 * (level + 1)
        for f in files:
            result += f"{sub_indent}{f}\n"
    return result
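# Stage 2: run the ControlNet-conditioned CogVideoX-5b-I2V pipeline over the
# anchor renders in /app/output_anchor to produce the final camera-controlled video.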
def inference(video_path, num_frames, fps, target_pose, mode):
    logs, video_masked = run_epic_inference(video_path, fps, num_frames, target_pose, mode)
    MODEL_PATH = "/app/pretrained/CogVideoX-5b-I2V"
    ckpt_steps = 500
    ckpt_dir = "/app/out/EPiC_pretrained"
    ckpt_file = f"checkpoint-{ckpt_steps}.pt"
    ckpt_path = f"{ckpt_dir}/{ckpt_file}"
    video_root_dir = "/app/output_anchor"
    out_dir = "/app/output"
    command = [
        "python", "/app/inference/cli_demo_camera_i2v_pcd.py",
        "--video_root_dir", video_root_dir,
        "--base_model_path", MODEL_PATH,
        "--controlnet_model_path", ckpt_path,
        "--output_path", out_dir,
        "--start_camera_idx", "0",
        "--end_camera_idx", "8",
        "--controlnet_weights", "1.0",
        "--controlnet_guidance_start", "0.0",
        "--controlnet_guidance_end", "0.4",
        "--controlnet_input_channels", "3",
        "--controlnet_transformer_num_attn_heads", "4",
        "--controlnet_transformer_attention_head_dim", "64",
        "--controlnet_transformer_out_proj_dim_factor", "64",
        "--controlnet_transformer_out_proj_dim_zero_init",
        "--vae_channels", "16",
        "--num_frames", str(num_frames),
        "--controlnet_transformer_num_layers", "8",
        "--infer_with_mask",
        "--pool_style", "max",
        "--seed", "43",
        "--fps", str(fps),
    ]
    result = subprocess.run(command, capture_output=True, text=True)
    logs += "\n" + result.stdout
    result_dir = print_output_directory(out_dir)
    if result.returncode == 0:
        logs += "Inference completed successfully."
    else:
        logs += f"Error occurred during inference: {result.stderr}"
    return logs + result_dir, f"{out_dir}/00000_43_out.mp4", video_masked
# -----------------------------
# Step 3: Create Gradio UI
# -----------------------------
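# The three outputs below correspond, in order, to the (logs, final video path,
# masked anchor video) tuple returned by inference().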
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Video(label="Upload Video (MP4)"),
        gr.Slider(minimum=1, maximum=120, value=50, step=1, label="Number of Frames"),
        gr.Slider(minimum=1, maximum=90, value=10, step=1, label="FPS"),
        gr.Textbox(label="Target Pose (θ φ r x y)", placeholder="e.g., 0 30 -0.6 0 0"),
        gr.Dropdown(choices=["gradual", "direct", "bullet"], value="gradual", label="Camera Mode"),
    ],
    outputs=[
        gr.Textbox(label="Inference Logs"),
        gr.Video(label="Generated Video"),
        gr.Video(label="Masked Video"),
    ],
    title="🎬 EPiC: Efficient Video Camera Control",
    description="Upload a video, set a target camera pose, and apply cinematic camera motion using pretrained EPiC models.",
)
# -----------------------------
# Step 4: Launch App
# -----------------------------
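# Optional (untested here): for long-running inference jobs, enabling Gradio's
# request queue before launching, e.g. `demo = demo.queue()`, may keep the UI
# responsive; the original launch behaviour is kept unchanged below.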
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)