Muhammad Taqi Raza committed · Commit 9e96e5e · Parent(s): 6825e25
adding files
Files changed:
- Dockerfile +1 -1
- config.yaml +4 -0
- gradio_app.py +3 -4
- gradio_batch.py +185 -0
Dockerfile
CHANGED
@@ -40,4 +40,4 @@ RUN pip install gradio
EXPOSE 7860

# Start the Gradio app
-CMD ["conda", "run", "--no-capture-output", "-n", "epic", "python", "
+CMD ["conda", "run", "--no-capture-output", "-n", "epic", "python", "gradio_batch.py"]
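This is the entry-point switch for the Space: the image still exposes port 7860, but the container now boots the new batch interface (gradio_batch.py, added below) inside the epic conda environment instead of the previous app.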
config.yaml
ADDED
@@ -0,0 +1,4 @@
ZoomIn:
  target_pose: [0, 0.0, 0.2, 0.0, 0.0]
Pan:
  target_pose: [0, 0.0, 0.0, 0.2, 0.0]
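These two presets are exactly what gradio_batch.py iterates over: yaml.safe_load turns the file into a dict mapping trajectory names to parameter dicts, and most missing keys fall back to .get() defaults in run_batch_process. A minimal loading sketch (not part of the commit):

import yaml

with open("config.yaml") as f:
    trajectories = yaml.safe_load(f)
# {'ZoomIn': {'target_pose': [0, 0.0, 0.2, 0.0, 0.0]},
#  'Pan':    {'target_pose': [0, 0.0, 0.0, 0.2, 0.0]}}

for traj_name, params in trajectories.items():
    print(traj_name, params["target_pose"])
    # Note: run_batch_process also reads params["fps"] and params["num_frames"]
    # directly (no default), so a complete entry would presumably need those keys as well.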
gradio_app.py
CHANGED
@@ -72,7 +72,6 @@ def get_anchor_video(video_path, fps, num_frames, target_pose, mode,
    video_output_path = f"{output_dir}/masked_videos/output.mp4"
    captions_text_file = f"{output_dir}/captions/output.txt"
    depth_file = f"{output_dir}/depth/output.npy"
-   depth_video_path = visualize_depth_npy_as_video(depth_file, fps)


    if video_path:
@@ -133,7 +132,7 @@ def get_anchor_video(video_path, fps, num_frames, target_pose, mode,
    if os.path.exists(captions_text_file):
        with open(captions_text_file, "r") as f:
            caption_text = f.read()
-
+   depth_video_path = visualize_depth_npy_as_video(depth_file, fps)
    return str(video_output_path), logs, caption_text, depth_video_path
# -----------------------------
# Step 2: Run Inference
@@ -198,7 +197,7 @@ with demo:
    near_far_estimated = gr.Checkbox(label="Near Far Estimation", value=True)
    pose_input = gr.Textbox(label="Target Pose (θ φ r x y)", placeholder="e.g., 0 30 -0.6 0 0")
    fps_input = gr.Number(value=24, label="FPS")
-   aspect_ratio_inputs=gr.Textbox(label="Target Aspect Ratio (e.g., 2,3)")
+   aspect_ratio_inputs = gr.Textbox(value="3,4", label="Target Aspect Ratio (e.g., 2,3)")

    init_dx = gr.Number(value=0.0, label="Start Camera Offset X")
    init_dy = gr.Number(value=0.0, label="Start Camera Offset Y")
@@ -214,7 +213,7 @@ with demo:
    depth_guidance_input = gr.Number(value=1.0, label="Depth Guidance")
    window_input = gr.Number(value=64, label="Window Size")
    overlap_input = gr.Number(value=25, label="Overlap")
-   maxres_input = gr.Number(value=
+   maxres_input = gr.Number(value=720, label="Max Resolution")
    sample_size = gr.Textbox(label="Sample Size (height, width)", placeholder="e.g., 384, 672", value="384, 672")
    seed_input = gr.Number(value=43, label="Seed")
    height = gr.Number(value=480, label="Height")
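The substantive change above is ordering: the depth preview is now rendered after the pipeline outputs (captions and the depth .npy) have been read back, rather than before anything has been produced, and the aspect-ratio and max-resolution widgets gain defaults ("3,4" and 720). A slightly more defensive variant of the same function tail, shown only as a sketch (not the committed code), would skip the preview when the depth file is absent:

    depth_video_path = None
    if os.path.exists(depth_file):
        depth_video_path = visualize_depth_npy_as_video(depth_file, fps)
    return str(video_output_path), logs, caption_text, depth_video_path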
gradio_batch.py
ADDED
@@ -0,0 +1,185 @@
import os
import torch
import cv2
import yaml
import shutil
import zipfile
import subprocess
import gradio as gr
import numpy as np
from pathlib import Path
from huggingface_hub import hf_hub_download

# -----------------------------
# Environment Setup
# -----------------------------
HF_HOME = "/app/hf_cache"
os.environ["HF_HOME"] = HF_HOME
os.environ["TRANSFORMERS_CACHE"] = HF_HOME
os.makedirs(HF_HOME, exist_ok=True)

PRETRAINED_DIR = "/app/pretrained"
os.makedirs(PRETRAINED_DIR, exist_ok=True)

INPUT_VIDEOS_DIR = "Input_Videos"
CONFIG_FILE = "config.yaml"
FINAL_RESULTS_DIR = "Final_results"

# -----------------------------
# Utility Functions
# -----------------------------
def download_models():
    expected_model = os.path.join(PRETRAINED_DIR, "RAFT/raft-things.pth")
    if not Path(expected_model).exists():
        print("\u2699\ufe0f Downloading pretrained models...")
        try:
            subprocess.check_call(["bash", "download/download_models.sh"])
            print("\u2705 Models downloaded.")
        except subprocess.CalledProcessError as e:
            print(f"Model download failed: {e}")
    else:
        print("\u2705 Pretrained models already exist.")

def visualize_depth_npy_as_video(npy_file, fps):
    depth_np = np.load(npy_file)
    tensor = torch.from_numpy(depth_np)
    T, _, H, W = tensor.shape

    video_path = "/app/depth_video_preview.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(video_path, fourcc, fps, (W, H))

    for i in range(T):
        frame = tensor[i, 0].numpy()
        norm = (frame - frame.min()) / (frame.max() - frame.min() + 1e-8)
        frame_uint8 = (norm * 255).astype(np.uint8)
        colored = cv2.applyColorMap(frame_uint8, cv2.COLORMAP_INFERNO)
        out.write(colored)

    out.release()
    return video_path

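visualize_depth_npy_as_video expects a depth stack of shape (T, 1, H, W) saved with np.save and writes an INFERNO-colored preview to a fixed path inside the container. A quick self-contained exercise of the helper (file name and sizes are illustrative, not from the commit):

import numpy as np

dummy = np.random.rand(8, 1, 64, 64).astype(np.float32)  # (T, 1, H, W)
np.save("dummy_depth.npy", dummy)
print(visualize_depth_npy_as_video("dummy_depth.npy", fps=8))
# -> /app/depth_video_preview.mp4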
def zip_dir(dir_path, zip_path):
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(dir_path):
            for file in files:
                full_path = os.path.join(root, file)
                rel_path = os.path.relpath(full_path, dir_path)
                zf.write(full_path, rel_path)

# -----------------------------
# Inference Functions (Use actual versions from your main file)
# -----------------------------
from main import get_anchor_video, inference  # Replace with real imports

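The import above is explicitly a placeholder ("Replace with real imports"); no main.py is added in this commit. Since get_anchor_video and the Step 2 inference path live in gradio_app.py (modified above), one hypothetical way to make the import resolve is a thin shim, assuming importing gradio_app only builds its UI without launching it:

# main.py (hypothetical shim; not part of this commit)
from gradio_app import get_anchor_video, inference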
def run_batch_process(progress=gr.Progress()):
    with open(CONFIG_FILE, 'r') as f:
        trajectories = yaml.safe_load(f)

    os.makedirs(FINAL_RESULTS_DIR, exist_ok=True)
    logs = ""
    videos = list(Path(INPUT_VIDEOS_DIR).glob("*.mp4"))
    total = len(videos) * len(trajectories)
    idx = 0

    for video_path in videos:
        video_name = video_path.stem

        for traj_name, params in trajectories.items():
            idx += 1
            logs += f"\n---\nRunning {video_name}/{traj_name} ({idx}/{total})\n"

            out_dir = Path(FINAL_RESULTS_DIR) / video_name / traj_name
            out_dir.mkdir(parents=True, exist_ok=True)

            anchor_path, logs1, caption, depth_path = get_anchor_video(
                video_path=str(video_path),
                fps=params["fps"],
                num_frames=params["num_frames"],
                target_pose=params["target_pose"],
                mode=params.get("mode", "gradual"),
                radius_scale=params.get("radius_scale", 1.0),
                near_far_estimated=params.get("near_far_estimated", True),
                sampler_name=params.get("sampler_name", "DDIM_Origin"),
                diffusion_guidance_scale=params.get("diff_guidance", 6.0),
                diffusion_inference_steps=params.get("diff_steps", 50),
                prompt=params.get("prompt", ""),
                negative_prompt=params.get("neg_prompt", ""),
                refine_prompt=params.get("refine_prompt", ""),
                depth_inference_steps=params.get("depth_steps", 5),
                depth_guidance_scale=params.get("depth_guidance", 1.0),
                window_size=params.get("window_size", 64),
                overlap=params.get("overlap", 25),
                max_res=params.get("max_res", 720),
                sample_size=params.get("sample_size", "384, 672"),
                seed_input=params.get("seed", 43),
                height=params.get("height", 480),
                width=params.get("width", 720),
                aspect_ratio_inputs=params.get("aspect_ratio", "3,4"),
                init_dx=params.get("init_dx", 0.0),
                init_dy=params.get("init_dy", 0.0),
                init_dz=params.get("init_dz", 0.0)
            )

            if not anchor_path:
                logs += f"❌ Failed: {video_name}/{traj_name}\n"
                continue

            shutil.copy(anchor_path, out_dir / "anchor_video.mp4")
            shutil.copy(depth_path, out_dir / "depth.mp4")
            with open(out_dir / "captions.txt", "w") as f:
                f.write(caption or "")
            with open(out_dir / "step1_logs.txt", "w") as f:
                f.write(logs1 or "")

            final_video, logs2 = inference(
                fps=params["fps"],
                num_frames=params["num_frames"],
                controlnet_weights=params.get("controlnet_weights", 0.5),
                controlnet_guidance_start=params.get("controlnet_guidance_start", 0.0),
                controlnet_guidance_end=params.get("controlnet_guidance_end", 0.5),
                guidance_scale=params.get("guidance_scale", 6.0),
                num_inference_steps=params.get("inference_steps", 50),
                dtype=params.get("dtype", "bfloat16"),
                seed=params.get("seed2", 42),
                height=params.get("height", 480),
                width=params.get("width", 720),
                downscale_coef=params.get("downscale_coef", 8),
                vae_channels=params.get("vae_channels", 16),
                controlnet_input_channels=params.get("controlnet_input_channels", 6),
                controlnet_transformer_num_layers=params.get("controlnet_transformer_layers", 8)
            )

            if final_video:
                shutil.copy(final_video, out_dir / "final_video.mp4")
                with open(out_dir / "step2_logs.txt", "w") as f:
                    f.write(logs2 or "")

            progress(idx / total)

    zip_path = FINAL_RESULTS_DIR + ".zip"
    zip_dir(FINAL_RESULTS_DIR, zip_path)
    return logs, zip_path

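For every (video, trajectory) pair the loop above writes a fixed set of artifacts under Final_results/<video>/<trajectory>/ and finally zips the whole tree. Assuming a single clip Input_Videos/demo.mp4 (an illustrative name) and the two presets from config.yaml, the layout would be roughly:

Final_results/
  demo/
    ZoomIn/
      anchor_video.mp4
      depth.mp4
      captions.txt
      step1_logs.txt
      final_video.mp4   (only written if step 2 succeeds)
      step2_logs.txt
    Pan/
      ... (same layout)
Final_results.zip       (returned to the UI for download)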
# -----------------------------
# Gradio Interface
# -----------------------------
demo = gr.Blocks()

with demo:
    gr.Markdown("## 🚀 EPiC Batch Inference: Automate Experiments")

    with gr.TabItem("📁 Run All Experiments"):
        with gr.Row():
            run_batch_btn = gr.Button("▶️ Run Batch Experiments")
            download_btn = gr.Button("⬇️ Download Results")

        batch_logs = gr.Textbox(label="Logs", lines=25)
        zip_file_output = gr.File(label="Final ZIP", visible=True)

        run_batch_btn.click(run_batch_process, outputs=[batch_logs, zip_file_output])
        download_btn.click(lambda: FINAL_RESULTS_DIR + ".zip", outputs=zip_file_output)

if __name__ == "__main__":
    download_models()
    demo.launch(server_name="0.0.0.0", server_port=7860)
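As committed, this script is what the Dockerfile now launches: on start it checks for the pretrained RAFT weights (downloading them via download/download_models.sh if missing) and then serves the batch UI on 0.0.0.0:7860, the same port the image exposes.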