Muhammad Taqi Raza committed
Commit 9e96e5e · 1 Parent(s): 6825e25

adding files

Files changed (4)
  1. Dockerfile +1 -1
  2. config.yaml +4 -0
  3. gradio_app.py +3 -4
  4. gradio_batch.py +185 -0
Dockerfile CHANGED
@@ -40,4 +40,4 @@ RUN pip install gradio
EXPOSE 7860

# Start the Gradio app
- CMD ["conda", "run", "--no-capture-output", "-n", "epic", "python", "gradio_app.py"]
+ CMD ["conda", "run", "--no-capture-output", "-n", "epic", "python", "gradio_batch.py"]
config.yaml ADDED
@@ -0,0 +1,4 @@
+ ZoomIn:
+   target_pose: [0, 0.0, 0.2, 0.0, 0.0]
+ Pan:
+   target_pose: [0, 0.0, 0.0, 0.2, 0.0]
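For reference, gradio_batch.py (added below) consumes this file with yaml.safe_load: each top-level key names one trajectory, and its target_pose (plus any optional keys) is forwarded to get_anchor_video. A minimal sketch of that read path; the loop mirrors run_batch_process, and the print call is illustrative only:

import yaml

# Parse the trajectory config the same way run_batch_process does.
with open("config.yaml", "r") as f:
    trajectories = yaml.safe_load(f)
# -> {"ZoomIn": {"target_pose": [0, 0.0, 0.2, 0.0, 0.0]},
#     "Pan": {"target_pose": [0, 0.0, 0.0, 0.2, 0.0]}}

for traj_name, params in trajectories.items():
    target_pose = params["target_pose"]       # required by the batch loop
    mode = params.get("mode", "gradual")      # optional keys fall back to defaults
    print(traj_name, target_pose, mode)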
gradio_app.py CHANGED
@@ -72,7 +72,6 @@ def get_anchor_video(video_path, fps, num_frames, target_pose, mode,
    video_output_path = f"{output_dir}/masked_videos/output.mp4"
    captions_text_file = f"{output_dir}/captions/output.txt"
    depth_file = f"{output_dir}/depth/output.npy"
-     depth_video_path = visualize_depth_npy_as_video(depth_file, fps)


    if video_path:
@@ -133,7 +132,7 @@ def get_anchor_video(video_path, fps, num_frames, target_pose, mode,
    if os.path.exists(captions_text_file):
        with open(captions_text_file, "r") as f:
            caption_text = f.read()
-
+     depth_video_path = visualize_depth_npy_as_video(depth_file, fps)
    return str(video_output_path), logs, caption_text, depth_video_path
# -----------------------------
# Step 2: Run Inference
@@ -198,7 +197,7 @@ with demo:
    near_far_estimated = gr.Checkbox(label="Near Far Estimation", value=True)
    pose_input = gr.Textbox(label="Target Pose (θ φ r x y)", placeholder="e.g., 0 30 -0.6 0 0")
    fps_input = gr.Number(value=24, label="FPS")
-     aspect_ratio_inputs = gr.Textbox(label="Target Aspect Ratio (e.g., 2,3)")
+     aspect_ratio_inputs = gr.Textbox(value="3,4", label="Target Aspect Ratio (e.g., 2,3)")

    init_dx = gr.Number(value=0.0, label="Start Camera Offset X")
    init_dy = gr.Number(value=0.0, label="Start Camera Offset Y")
@@ -214,7 +213,7 @@ with demo:
    depth_guidance_input = gr.Number(value=1.0, label="Depth Guidance")
    window_input = gr.Number(value=64, label="Window Size")
    overlap_input = gr.Number(value=25, label="Overlap")
-     maxres_input = gr.Number(value=1920, label="Max Resolution")
+     maxres_input = gr.Number(value=720, label="Max Resolution")
    sample_size = gr.Textbox(label="Sample Size (height, width)", placeholder="e.g., 384, 672", value="384, 672")
    seed_input = gr.Number(value=43, label="Seed")
    height = gr.Number(value=480, label="Height")
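The net effect of these hunks: visualize_depth_npy_as_video now runs after the pipeline has written depth/output.npy rather than before, and two UI defaults change (aspect ratio preset to "3,4", max resolution lowered from 1920 to 720). A minimal sketch of the reordered depth-preview step; the helper name render_depth_preview and the existence guard are illustrative assumptions, with the visualizer passed in as a callable:

import os

def render_depth_preview(output_dir, fps, visualize_fn):
    # Illustrative helper (not in the commit): only visualize the depth map
    # once the pipeline has actually written depth/output.npy.
    depth_file = f"{output_dir}/depth/output.npy"
    if not os.path.exists(depth_file):
        return None
    # visualize_fn is expected to behave like visualize_depth_npy_as_video(npy_file, fps)
    return visualize_fn(depth_file, fps)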
gradio_batch.py ADDED
@@ -0,0 +1,185 @@
+ import os
+ import torch
+ import cv2
+ import yaml
+ import shutil
+ import zipfile
+ import subprocess
+ import gradio as gr
+ import numpy as np
+ from pathlib import Path
+ from huggingface_hub import hf_hub_download
+
+ # -----------------------------
+ # Environment Setup
+ # -----------------------------
+ HF_HOME = "/app/hf_cache"
+ os.environ["HF_HOME"] = HF_HOME
+ os.environ["TRANSFORMERS_CACHE"] = HF_HOME
+ os.makedirs(HF_HOME, exist_ok=True)
+
+ PRETRAINED_DIR = "/app/pretrained"
+ os.makedirs(PRETRAINED_DIR, exist_ok=True)
+
+ INPUT_VIDEOS_DIR = "Input_Videos"
+ CONFIG_FILE = "config.yaml"
+ FINAL_RESULTS_DIR = "Final_results"
+
+ # -----------------------------
+ # Utility Functions
+ # -----------------------------
+ def download_models():
+     expected_model = os.path.join(PRETRAINED_DIR, "RAFT/raft-things.pth")
+     if not Path(expected_model).exists():
+         print("\u2699\ufe0f Downloading pretrained models...")
+         try:
+             subprocess.check_call(["bash", "download/download_models.sh"])
+             print("\u2705 Models downloaded.")
+         except subprocess.CalledProcessError as e:
+             print(f"Model download failed: {e}")
+     else:
+         print("\u2705 Pretrained models already exist.")
+
+ def visualize_depth_npy_as_video(npy_file, fps):
+     depth_np = np.load(npy_file)
+     tensor = torch.from_numpy(depth_np)
+     T, _, H, W = tensor.shape
+
+     video_path = "/app/depth_video_preview.mp4"
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter(video_path, fourcc, fps, (W, H))
+
+     for i in range(T):
+         frame = tensor[i, 0].numpy()
+         norm = (frame - frame.min()) / (frame.max() - frame.min() + 1e-8)
+         frame_uint8 = (norm * 255).astype(np.uint8)
+         colored = cv2.applyColorMap(frame_uint8, cv2.COLORMAP_INFERNO)
+         out.write(colored)
+
+     out.release()
+     return video_path
+
+ def zip_dir(dir_path, zip_path):
+     with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
+         for root, _, files in os.walk(dir_path):
+             for file in files:
+                 full_path = os.path.join(root, file)
+                 rel_path = os.path.relpath(full_path, dir_path)
+                 zf.write(full_path, rel_path)
+
+ # -----------------------------
+ # Inference Functions (Use actual versions from your main file)
+ # -----------------------------
+ from main import get_anchor_video, inference  # Replace with real imports
+
+ def run_batch_process(progress=gr.Progress()):
+     with open(CONFIG_FILE, 'r') as f:
+         trajectories = yaml.safe_load(f)
+
+     os.makedirs(FINAL_RESULTS_DIR, exist_ok=True)
+     logs = ""
+     videos = list(Path(INPUT_VIDEOS_DIR).glob("*.mp4"))
+     total = len(videos) * len(trajectories)
+     idx = 0
+
+     for video_path in videos:
+         video_name = video_path.stem
+
+         for traj_name, params in trajectories.items():
+             idx += 1
+             logs += f"\n---\nRunning {video_name}/{traj_name} ({idx}/{total})\n"
+
+             out_dir = Path(FINAL_RESULTS_DIR) / video_name / traj_name
+             out_dir.mkdir(parents=True, exist_ok=True)
+
+             anchor_path, logs1, caption, depth_path = get_anchor_video(
+                 video_path=str(video_path),
+                 fps=params["fps"],
+                 num_frames=params["num_frames"],
+                 target_pose=params["target_pose"],
+                 mode=params.get("mode", "gradual"),
+                 radius_scale=params.get("radius_scale", 1.0),
+                 near_far_estimated=params.get("near_far_estimated", True),
+                 sampler_name=params.get("sampler_name", "DDIM_Origin"),
+                 diffusion_guidance_scale=params.get("diff_guidance", 6.0),
+                 diffusion_inference_steps=params.get("diff_steps", 50),
+                 prompt=params.get("prompt", ""),
+                 negative_prompt=params.get("neg_prompt", ""),
+                 refine_prompt=params.get("refine_prompt", ""),
+                 depth_inference_steps=params.get("depth_steps", 5),
+                 depth_guidance_scale=params.get("depth_guidance", 1.0),
+                 window_size=params.get("window_size", 64),
+                 overlap=params.get("overlap", 25),
+                 max_res=params.get("max_res", 720),
+                 sample_size=params.get("sample_size", "384, 672"),
+                 seed_input=params.get("seed", 43),
+                 height=params.get("height", 480),
+                 width=params.get("width", 720),
+                 aspect_ratio_inputs=params.get("aspect_ratio", "3,4"),
+                 init_dx=params.get("init_dx", 0.0),
+                 init_dy=params.get("init_dy", 0.0),
+                 init_dz=params.get("init_dz", 0.0)
+             )
+
+             if not anchor_path:
+                 logs += f"❌ Failed: {video_name}/{traj_name}\n"
+                 continue
+
+             shutil.copy(anchor_path, out_dir / "anchor_video.mp4")
+             shutil.copy(depth_path, out_dir / "depth.mp4")
+             with open(out_dir / "captions.txt", "w") as f:
+                 f.write(caption or "")
+             with open(out_dir / "step1_logs.txt", "w") as f:
+                 f.write(logs1 or "")
+
+             final_video, logs2 = inference(
+                 fps=params["fps"],
+                 num_frames=params["num_frames"],
+                 controlnet_weights=params.get("controlnet_weights", 0.5),
+                 controlnet_guidance_start=params.get("controlnet_guidance_start", 0.0),
+                 controlnet_guidance_end=params.get("controlnet_guidance_end", 0.5),
+                 guidance_scale=params.get("guidance_scale", 6.0),
+                 num_inference_steps=params.get("inference_steps", 50),
+                 dtype=params.get("dtype", "bfloat16"),
+                 seed=params.get("seed2", 42),
+                 height=params.get("height", 480),
+                 width=params.get("width", 720),
+                 downscale_coef=params.get("downscale_coef", 8),
+                 vae_channels=params.get("vae_channels", 16),
+                 controlnet_input_channels=params.get("controlnet_input_channels", 6),
+                 controlnet_transformer_num_layers=params.get("controlnet_transformer_layers", 8)
+             )
+
+             if final_video:
+                 shutil.copy(final_video, out_dir / "final_video.mp4")
+             with open(out_dir / "step2_logs.txt", "w") as f:
+                 f.write(logs2 or "")
+
+             progress(idx / total)
+
+     zip_path = FINAL_RESULTS_DIR + ".zip"
+     zip_dir(FINAL_RESULTS_DIR, zip_path)
+     return logs, zip_path
+
+ # -----------------------------
+ # Gradio Interface
+ # -----------------------------
+ demo = gr.Blocks()
+
+ with demo:
+     gr.Markdown("## 🚀 EPiC Batch Inference: Automate Experiments")
+
+     with gr.TabItem("📁 Run All Experiments"):
+         with gr.Row():
+             run_batch_btn = gr.Button("▶️ Run Batch Experiments")
+             download_btn = gr.Button("⬇️ Download Results")
+
+         batch_logs = gr.Textbox(label="Logs", lines=25)
+         zip_file_output = gr.File(label="Final ZIP", visible=True)
+
+         run_batch_btn.click(run_batch_process, outputs=[batch_logs, zip_file_output])
+         download_btn.click(lambda: FINAL_RESULTS_DIR + ".zip", outputs=zip_file_output)
+
+ if __name__ == "__main__":
+     download_models()
+     demo.launch(server_name="0.0.0.0", server_port=7860)
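As a usage note: the batch app expects an Input_Videos/ directory of .mp4 files next to the script and one entry per trajectory in config.yaml. Because run_batch_process reads params["fps"] and params["num_frames"] with bracket access, each trajectory entry also needs those two keys. A minimal headless smoke-test sketch, assuming main.get_anchor_video and main.inference are importable and the /app cache paths are writable; the num_frames value and the lambda progress callback are illustrative:

import yaml
from gradio_batch import run_batch_process  # module added in this commit

# Illustrative config: target_pose plus the two keys the batch loop requires
# (fps, num_frames); everything else falls back to the .get() defaults above.
config = {
    "ZoomIn": {"target_pose": [0, 0.0, 0.2, 0.0, 0.0], "fps": 24, "num_frames": 49},
    "Pan": {"target_pose": [0, 0.0, 0.0, 0.2, 0.0], "fps": 24, "num_frames": 49},
}
with open("config.yaml", "w") as f:
    yaml.safe_dump(config, f)

# Drive the batch loop without the UI; a plain callable stands in for gr.Progress.
logs, zip_path = run_batch_process(progress=lambda frac: print(f"{frac:.0%} done"))
print(zip_path)  # Final_results.zip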