vincentqyw commited on
Commit
6d69218
·
1 Parent(s): d47224a
Files changed (7) hide show
  1. README.md +3 -2
  2. app.py +8 -0
  3. requirements.txt +7 -0
  4. src/__init__.py +5 -0
  5. src/comfy_ui.py +397 -0
  6. src/gradio_ui.py +188 -0
  7. src/omni_processor.py +225 -0
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Omni Sfm
3
  emoji: 💻
4
  colorFrom: purple
5
  colorTo: purple
@@ -9,4 +9,5 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
  ---
2
+ title: Omni SFM
3
  emoji: 💻
4
  colorFrom: purple
5
  colorTo: purple
 
9
  pinned: false
10
  ---
11
 
12
+ Check out the project repository at https://github.com/Vincentqyw/omni-sfm
13
+
app.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
from src.gradio_ui import OmniConverterUI

if __name__ == "__main__":
    # Build the Gradio interface and serve it on all interfaces so the app
    # is reachable from outside a container (e.g. a Hugging Face Space,
    # which expects port 7860). share=False: no public Gradio tunnel.
    ui = OmniConverterUI()
    app = ui.create_interface()
    app.queue().launch(
        server_name="0.0.0.0", server_port=7860, share=False
    )
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ loguru
3
+ numpy
4
+ opencv-python
5
+ Pillow
6
+ py360convert
7
+ scipy
8
+ torch
9
+ tqdm
src/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
from .comfy_ui import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
from .gradio_ui import OmniConverterUI
from .omni_processor import OmniVideoProcessor

# Public API: export everything imported above, not just the ComfyUI
# mappings — OmniConverterUI and OmniVideoProcessor are imported here
# precisely so package users can reach them.
__all__ = [
    "NODE_CLASS_MAPPINGS",
    "NODE_DISPLAY_NAME_MAPPINGS",
    "OmniConverterUI",
    "OmniVideoProcessor",
]
src/comfy_ui.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+
10
+ try:
11
+ from .omni_processor import OmniVideoProcessor
12
+ except ImportError:
13
+ print(
14
+ "Warning: omni_processor not found, some functionality may be limited",
15
+ file=sys.stderr,
16
+ )
17
+ try:
18
+ from .read_write_model import read_model
19
+ except ImportError:
20
+ print(
21
+ "Warning: read_write_model not found, some functionality may be limited",
22
+ file=sys.stderr,
23
+ )
24
+
25
+
26
class OmniParameterControls:
    """ComfyUI node that assembles the OMNI_PARAMS dict for the pipeline."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "frame_interval": (
                    "INT",
                    {"default": 24, "min": 1, "max": 100},
                ),
                "width": ("INT", {"default": 640, "min": 100, "max": 2000}),
                "height": ("INT", {"default": 640, "min": 100, "max": 2000}),
                "cx": ("FLOAT", {"default": 320.0, "min": 0.0, "max": 2000.0}),
                "cy": ("FLOAT", {"default": 320.0, "min": 0.0, "max": 2000.0}),
                "fov_h": (
                    "FLOAT",
                    {"default": 90.0, "min": 30.0, "max": 180.0},
                ),
                "fov_v": (
                    "FLOAT",
                    {"default": 90.0, "min": 30.0, "max": 180.0},
                ),
                "base_pitch": (
                    "FLOAT",
                    {"default": 35.0, "min": -90.0, "max": 90.0},
                ),
                "yaw_steps": ("INT", {"default": 4, "min": 1, "max": 12}),
                "yaw_offset": (
                    "FLOAT",
                    {"default": 0.0, "min": -180.0, "max": 180.0},
                ),
            },
            "optional": {
                "pano_projection": (
                    ["equirectangular", "cubemap"],
                    {"default": "equirectangular"},
                ),
                "pano_quality": (
                    ["low", "medium", "high"],
                    {"default": "medium"},
                ),
                "stabilize": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("OMNI_PARAMS",)
    FUNCTION = "get_params"
    CATEGORY = "Omnidirectional Video"

    def get_params(
        self,
        frame_interval,
        width,
        height,
        fov_h,
        fov_v,
        base_pitch,
        yaw_steps,
        yaw_offset,
        **kwargs,
    ):
        """Build the parameter dict, generating one view per (pitch, yaw).

        Views lie on two pitch rings — +base_pitch first, then -base_pitch —
        each holding `yaw_steps` equally spaced yaw angles shifted by
        `yaw_offset` and wrapped into (-180, 180].
        """
        step = 360.0 / yaw_steps
        views = {}
        for pitch in (base_pitch, -base_pitch):
            for idx in range(yaw_steps):
                yaw = (idx * step + yaw_offset) % 360
                if yaw > 180:
                    yaw -= 360
                views[f"pitch_{pitch}_yaw_{round(yaw,1)}"] = (pitch, yaw)

        params = {
            "frame_interval": frame_interval,
            "width": width,
            "height": height,
            "fov_h": fov_h,
            "fov_v": fov_v,
            "views": views,
        }
        # Optional / extra inputs (cx, cy, pano_projection, ...) pass through.
        params.update(kwargs)
        return (params,)
117
+
118
+
119
class OmniVideoProcessorNode:
    """ComfyUI node that splits an omnidirectional video into pinhole views."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "omni_video": ("IMAGE",),
                "omni_params": ("OMNI_PARAMS",),
            }
        }

    RETURN_TYPES = ("OMNI_PROCESSED",)
    FUNCTION = "process_video"
    CATEGORY = "Omnidirectional Video"

    def process_video(self, omni_video, omni_params):
        """Run OmniVideoProcessor on the input frames.

        Args:
            omni_video: Frame batch (ComfyUI IMAGE tensor) of the panorama
                video — forwarded unchanged to OmniVideoProcessor.
            omni_params: Parameter dict produced by OmniParameterControls.

        Returns:
            One-tuple wrapping a dict with the output directory, the
            extracted panoramic frames, and the generated pinhole views.
        """
        # Removed an unused `import tempfile`; only gettempdir is needed.
        import time
        from tempfile import gettempdir

        # Unique per-run output directory under the system temp dir.
        run_timestamp = time.strftime("%Y%m%d-%H%M%S")
        output_dir = Path(gettempdir()) / f"omni_output_{run_timestamp}"
        output_dir.mkdir(parents=True, exist_ok=True)

        processor = OmniVideoProcessor(omni_params)
        panoramic_frames, pinhole_images_data = processor.process_video(
            omni_video, output_dir
        )
        result = {
            "output_dir": str(output_dir),
            "panoramic_frames": panoramic_frames,
            "pinhole_views": pinhole_images_data,
        }

        return (result,)
153
+
154
+
155
class OmniReconstructionNode:
    """ComfyUI node that runs COLMAP sparse reconstruction on pinhole views."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "omni_processed": ("OMNI_PROCESSED",),
                "colmap_path": ("STRING", {"default": "colmap"}),
                "quality": (
                    ["low", "medium", "high", "extreme"],
                    {"default": "medium"},
                ),
            }
        }

    RETURN_TYPES = ("RECONSTRUCTION", "model_file")
    RETURN_NAMES = ("reconstruction", "model_file")
    FUNCTION = "run_reconstruction"
    CATEGORY = "Omnidirectional Video"

    def run_reconstruction(self, omni_processed, colmap_path, quality):
        """Run COLMAP feature extraction, matching and mapping; export a PLY.

        Args:
            omni_processed: Result dict from OmniVideoProcessorNode
                (must contain "output_dir").
            colmap_path: Path to the COLMAP executable.
            quality: Quality preset (currently not used by the commands below).

        Returns:
            Tuple of (sparse model directory, path to the exported sparse.ply).

        Raises:
            RuntimeError: If any COLMAP command exits with non-zero status.
        """
        output_dir = Path(omni_processed["output_dir"])
        image_dir = output_dir / "pinhole_images" / "images"
        db_path = output_dir / "database.db"
        sparse_dir = output_dir / "sparse"
        dense_dir = output_dir / "dense"

        # Create necessary directories
        sparse_dir.mkdir(exist_ok=True)
        dense_dir.mkdir(exist_ok=True)

        # NOTE(review): commands run through the shell via os.system, with
        # colmap_path and the derived paths interpolated into the string.
        # If these ever come from untrusted input, switch to
        # subprocess.run([...], shell=False) with an argument list.
        cmds = [
            f'"{colmap_path}" feature_extractor --database_path "{db_path}" --image_path "{image_dir}" --ImageReader.camera_model PINHOLE --ImageReader.single_camera_per_folder 1',
            f'"{colmap_path}" sequential_matcher --database_path "{db_path}" --SequentialMatching.loop_detection 1',
            f'"{colmap_path}" mapper --database_path "{db_path}" --image_path "{image_dir}" --output_path "{sparse_dir}" --Mapper.ba_refine_focal_length 0 --Mapper.ba_refine_principal_point 0 --Mapper.ba_refine_extra_params 0',
        ]

        for cmd in cmds:
            print(f"Executing: {cmd}")
            ret = os.system(cmd)
            if ret != 0:
                raise RuntimeError(f"Command failed with exit code {ret}: {cmd}")

        # Export the sparse point cloud. Fix: previously the returned
        # sparse.ply path was announced but the writer was commented out,
        # so the file never existed.
        cameras, images, points3D = read_model(sparse_dir / "0")
        sparse_ply_path = sparse_dir / "0" / "sparse.ply"
        with open(sparse_ply_path, "w") as f:
            f.write("ply\n")
            f.write("format ascii 1.0\n")
            f.write(f"element vertex {len(points3D)}\n")
            f.write("property float x\n")
            f.write("property float y\n")
            f.write("property float z\n")
            f.write("property uchar red\n")
            f.write("property uchar green\n")
            f.write("property uchar blue\n")
            f.write("end_header\n")
            for pts in points3D.values():
                x, y, z = pts.xyz
                r, g, b = pts.rgb
                f.write(f"{x} {y} {z} {int(r)} {int(g)} {int(b)}\n")
        print(f"Generated sparse point cloud at: {sparse_ply_path}")
        return (
            str(sparse_dir / "0"),
            str(sparse_ply_path),
        )
232
+
233
+
234
class OmniPreviewNode:
    """ComfyUI node that renders a simple placeholder preview image."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "reconstruction": ("RECONSTRUCTION",),
                "model_file": ("model_file",),
            },
            "optional": {
                "show_type": (
                    ["input_frame", "reconstruction", "mesh", "model_file"],
                    {"default": "input_frame"},
                ),
                "view_yaw": (
                    "FLOAT",
                    {"default": 0.0, "min": -180.0, "max": 180.0},
                ),
                "view_pitch": (
                    "FLOAT",
                    {"default": 0.0, "min": -90.0, "max": 90.0},
                ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate_preview"
    CATEGORY = "Omnidirectional Video"

    def _create_placeholder_preview(self, text):
        """Return a 640x480 dark PIL image with `text` drawn near the center."""
        img = Image.new("RGB", (640, 480), (30, 30, 50))
        try:
            from PIL import ImageDraw, ImageFont

            draw = ImageDraw.Draw(img)
            try:
                font = ImageFont.truetype("Arial.ttf", 40)
            # Fix: was a bare `except:`; ImageFont.truetype raises OSError
            # when the font file is missing — catch only that.
            except OSError:
                font = ImageFont.load_default()
            text_width = draw.textlength(text, font=font)
            position = ((640 - text_width) // 2, 220)
            draw.text(position, text, fill=(200, 200, 255), font=font)
        except ImportError:
            # Best-effort: without ImageDraw/ImageFont, return the plain canvas.
            pass
        return img

    def generate_preview(self, show_type="input_frame", view_yaw=0.0, view_pitch=0.0, **kwargs):
        """Return a (1, 480, 640, 3) float32 tensor preview in [0, 1].

        When `show_type` names an artifact whose path exists (passed via
        kwargs), a "<Name> Ready" placeholder is rendered; otherwise a
        generic "No Preview Available" image is returned.
        """
        blank_image = self._create_placeholder_preview("No Preview Available")

        def to_tensor(img):
            img = img.convert("RGB").resize((640, 480))
            return torch.from_numpy(np.array(img).astype(np.float32) / 255.0)[None,]

        if show_type in ["reconstruction", "mesh", "model_file"]:
            file_path = kwargs.get(show_type)
            if file_path and Path(file_path).exists():
                text = f"{show_type.replace('_', ' ').title()} Ready"
                image = self._create_placeholder_preview(text)
                return (to_tensor(image),)

        return (to_tensor(blank_image),)
294
+
295
+
296
+ # NEW NODE FOR ADVANCED VISUALIZATION
297
class OmniAdvancedPreviewNode:
    """ComfyUI node that previews paginated batches of pinhole/panorama frames."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "omni_processed": ("OMNI_PROCESSED",),
                "show_type": (["Pinhole Images", "Panoramic Frames"],),
                "max_items_to_show": (
                    "INT",
                    {"default": 8, "min": 1, "max": 64},
                ),
                "start_index": ("INT", {"default": 0, "min": 0}),
                "enable_annotation": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate_preview_batch"
    CATEGORY = "Omnidirectional Video"

    @staticmethod
    def _blank_batch():
        """Return a 1x256x256 black image batch used when nothing can be shown."""
        blank_image = Image.new("RGB", (256, 256), "black")
        return torch.from_numpy(np.array(blank_image).astype(np.float32) / 255.0)[None,]

    def generate_preview_batch(
        self,
        omni_processed,
        show_type,
        max_items_to_show,
        start_index,
        enable_annotation,
    ):
        """Build an IMAGE batch from a slice of the processed results.

        Items may carry in-memory arrays under "image"/"frame" or a file
        path string; path items are loaded with cv2 and converted to RGB.
        Pinhole items are optionally annotated with pitch/yaw/size text.
        """
        images_to_process = []
        if show_type == "Pinhole Images" and "pinhole_views" in omni_processed:
            images_to_process = omni_processed["pinhole_views"]
        elif show_type == "Panoramic Frames" and "panoramic_frames" in omni_processed:
            images_to_process = omni_processed["panoramic_frames"]

        if not images_to_process:
            return (self._blank_batch(),)

        # Pagination window over the selected item list.
        subset = images_to_process[start_index : start_index + max_items_to_show]

        output_images = []
        for item in subset:
            # Fix: initialise img_data so items lacking both "image" and
            # "frame" are skipped instead of raising NameError.
            img_data = None
            if isinstance(item, dict):
                if "image" in item:
                    img_data = item["image"]
                if "frame" in item:  # "frame" takes precedence when both exist
                    img_data = item["frame"]
            if isinstance(img_data, str):
                # Fix: cv2.imread returns None on failure — check before
                # cvtColor, which would otherwise raise on a None input.
                img_data = cv2.imread(img_data)
                if img_data is not None:
                    img_data = cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)
            if img_data is None:
                print(f"Warning: Image data is None for item {item}")
                continue
            pil_img = Image.fromarray(img_data)

            if show_type == "Pinhole Images" and enable_annotation:
                from PIL import ImageDraw, ImageFont

                draw = ImageDraw.Draw(pil_img)
                try:
                    font = ImageFont.truetype("arial.ttf", 20)
                except IOError:
                    font = ImageFont.load_default()

                text = (
                    f"P: {item['pitch']:.1f}, Y: {item['yaw']:.1f}\n"
                    f"Size: {item['width']}x{item['height']}\n"
                    f"Pano Idx: {item['pano_index']}"
                )

                draw.text((10, 10), text, font=font, fill="yellow")

            img_tensor = torch.from_numpy(np.array(pil_img).astype(np.float32) / 255.0)
            output_images.append(img_tensor)

        if not output_images:
            return (self._blank_batch(),)

        return (torch.stack(output_images),)
378
+
379
+
380
# Registry mapping ComfyUI node identifiers to their implementing classes.
NODE_CLASS_MAPPINGS = {
    # "OmniLoadVideoUpload": OmniLoadVideoUpload,
    "OmniParameterControls": OmniParameterControls,
    "OmniVideoProcessor": OmniVideoProcessorNode,
    "OmniReconstruction": OmniReconstructionNode,
    "OmniPreview": OmniPreviewNode,  # Keeping the old one for simple previews
    "OmniAdvancedPreview": OmniAdvancedPreviewNode,  # Adding the new one
}

# Human-readable names shown in the ComfyUI node picker; keys must match
# NODE_CLASS_MAPPINGS exactly.
NODE_DISPLAY_NAME_MAPPINGS = {
    # "OmniLoadVideoUpload": "Load Omni Video Upload",
    "OmniParameterControls": "Omnidirectional Parameters",
    "OmniVideoProcessor": "Process Omnidirectional Video",
    "OmniReconstruction": "Run COLMAP Reconstruction",
    "OmniPreview": "Omni Model Preview",  # Renamed for clarity
    "OmniAdvancedPreview": "Omni Advanced Preview",  # New node's display name
}
src/gradio_ui.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from pathlib import Path
3
+
4
+ import cv2
5
+ import gradio as gr
6
+ from PIL import Image
7
+
8
+ from .omni_processor import OmniVideoProcessor
9
+
10
+
11
class OmniConverterUI:
    """Gradio front-end for converting omnidirectional videos to pinhole views."""

    def __init__(self):
        self.processor = OmniVideoProcessor()
        # Snapshot of the processor defaults used to seed the UI widgets.
        self.default_params = self.processor.params.copy()
        # Cap how many images the gallery renders to keep the UI responsive.
        self.max_gallery_items = 20

    def create_interface(self):
        """Create Gradio interface"""
        with gr.Blocks(title="Omnidirectional Video to Pinhole Converter") as demo:
            gr.Markdown("## Omnidirectional Video to Pinhole Converter")

            with gr.Row():
                with gr.Column():
                    # Video input
                    video_input = gr.File(label="Upload Video", type="filepath")

                    # Submit button
                    submit_btn = gr.Button("Convert", variant="primary")

                    # Frame extraction settings
                    with gr.Accordion("Frame Extraction", open=True):
                        frame_interval = gr.Slider(
                            1,
                            100,
                            value=self.default_params["frame_interval"],
                            label="Frame Interval",
                            interactive=True,
                        )

                    # Pinhole camera settings
                    with gr.Accordion("Pinhole Parameters", open=True):
                        with gr.Row():
                            image_width = gr.Slider(
                                100,
                                2000,
                                value=self.default_params["width"],
                                label="Image Width",
                                interactive=True,
                            )
                            image_height = gr.Slider(
                                100,
                                2000,
                                value=self.default_params["height"],
                                label="Image Height",
                                interactive=True,
                            )
                        with gr.Row():
                            cx = gr.Slider(
                                50,
                                1000,
                                value=self.default_params["cx"],
                                label="Principal Point X",
                                interactive=True,
                            )
                            cy = gr.Slider(
                                50,
                                1000,
                                value=self.default_params["cy"],
                                label="Principal Point Y",
                                interactive=True,
                            )
                        with gr.Row():
                            fov_h = gr.Slider(
                                30,
                                180,
                                value=self.default_params["fov_h"],
                                label="Horizontal FOV (deg)",
                                interactive=True,
                            )
                            fov_v = gr.Slider(
                                30,
                                180,
                                value=self.default_params["fov_v"],
                                label="Vertical FOV (deg)",
                                interactive=True,
                            )
                        with gr.Row():
                            fx = gr.Slider(
                                50,
                                1000,
                                value=self.default_params["fx"],
                                label="Focal Length X",
                                interactive=True,
                            )
                            fy = gr.Slider(
                                50,
                                1000,
                                value=self.default_params["fy"],
                                label="Focal Length Y",
                                interactive=True,
                            )

                    # View selection
                    with gr.Accordion("Custom View editions", open=False):
                        with gr.Row():
                            custom_pitch = gr.Slider(-90, 90, value=0, label="Custom Pitch")
                            custom_yaw = gr.Slider(-180, 180, value=0, label="Custom Yaw")
                        add_custom = gr.Button("Add Custom View")

                with gr.Column():
                    # Results display
                    output_gallery = gr.Gallery(
                        label="Generated Pinhole Images",
                        columns=len(self.default_params["views"]),  # Use initial value
                        object_fit="contain",
                        height="auto",
                    )
                    view_state_display = gr.JSON(
                        label="Current Views",
                        value=self.default_params["views"].copy(),
                    )

            # Initialize views state
            views_state = gr.State(self.default_params["views"].copy())

            # Event handlers
            add_custom.click(
                fn=self._update_views,
                inputs=[custom_pitch, custom_yaw, views_state],
                outputs=[views_state, view_state_display],
            )

            submit_btn.click(
                fn=self._run_conversion,
                inputs=[
                    video_input,
                    frame_interval,
                    fx,
                    fy,
                    cx,
                    cy,
                    image_width,
                    image_height,
                    fov_h,
                    fov_v,
                    views_state,
                ],
                outputs=output_gallery,
            )

        return demo

    def _update_views(self, pitch, yaw, current_views):
        """Update views state with new custom view"""
        new_views = {**current_views, f"pitch_{pitch}_yaw_{yaw}": (pitch, yaw)}
        return new_views, new_views

    def _run_conversion(self, video_file, *params):
        """Run conversion and return a gallery update with the pinhole views.

        `params` arrives in the order wired up in create_interface and is
        zipped into a dict keyed by the processor's parameter names.
        """
        # Fix: guard against submitting with no file selected.
        if video_file is None:
            return gr.update(value=[], visible=False)

        param_names = [
            "frame_interval",
            "fx",
            "fy",
            "cx",
            "cy",
            "width",
            "height",
            "fov_h",
            "fov_v",
            "views",
        ]
        params_dict = dict(zip(param_names, params))

        self.processor.set_params(params_dict)

        output_dir = Path.cwd() / "outputs" / time.strftime("%Y%m%d%H%M%S")
        output_dir.mkdir(parents=True, exist_ok=True)

        # Fix: gr.File(type="filepath") passes a plain string path, which has
        # no `.name` attribute; older Gradio versions passed a tempfile
        # wrapper exposing `.name` — accept both.
        video_path = video_file if isinstance(video_file, str) else video_file.name

        pano_images, pinhole_images_data = self.processor.process_video(
            video_path, output_dir
        )
        image_list_for_gallery = [
            (
                Image.fromarray(cv2.cvtColor(img_info["image"], cv2.COLOR_BGR2RGB)),
                "Frame {}, View: {}".format(img_info["pano_index"], img_info["view_name"]),
            )
            for img_info in pinhole_images_data
        ][: self.max_gallery_items]
        if not image_list_for_gallery:
            return gr.update(value=[], visible=False)
        return gr.update(columns=len(params_dict["views"]), value=image_list_for_gallery)
src/omni_processor.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import py360convert
7
+ import torch
8
+ from scipy.spatial.transform import Rotation as R
9
+ from tqdm import tqdm
10
+
11
+
12
def compute_focal_length(image_size, fov_deg):
    """Return the pinhole focal length in pixels for a given span and FOV.

    Derived from f = (size / 2) / tan(fov / 2) with `fov_deg` in degrees.
    """
    half_fov_rad = np.deg2rad(fov_deg) * 0.5
    return 0.5 * image_size / np.tan(half_fov_rad)
14
+
15
+
16
class OmniVideoProcessor:
    """Converts omnidirectional (equirectangular) video into pinhole views.

    Frames are extracted at a fixed interval, each panorama is reprojected
    into a configurable set of pinhole views, and COLMAP-compatible camera
    and rig metadata are written next to the images.
    """

    # Class-level defaults; copied (including nested containers) per instance.
    default_params = {
        "fx": 320.0,
        "fy": 320.0,
        "cx": 320.0,
        "cy": 320.0,
        "height": 640,
        "width": 640,
        "fov_h": 90,
        "fov_v": 90,
        "frame_interval": 24,
        "num_steps_yaw": 4,
        "pitches_deg": [-35.0, 35.0],
        "views": {
            "pitch_35_yaw_0": (35, 0),
            # Fix: the yaw_90 entries previously held 60 despite the key
            # naming yaw 90, breaking the key <-> (pitch, yaw) convention.
            "pitch_35_yaw_90": (35, 90),
            "pitch_35_yaw_-90": (35, -90),
            "pitch_35_yaw_180": (35, 180),
            "pitch_-35_yaw_0": (-35, 0),
            "pitch_-35_yaw_90": (-35, 90),
            "pitch_-35_yaw_-90": (-35, -90),
            "pitch_-35_yaw_180": (-35, 180),
        },
    }

    def __init__(self, params=None):
        """Initialize with `params`, or a private copy of the class defaults.

        Fix: the previous signature used a mutable default (`params={}`)
        and a shallow `.copy()` of the class attribute, so every instance
        shared — and could mutate — the same nested "views" dict.
        """
        if params:
            self.params = params
        else:
            self.params = {
                key: dict(value)
                if isinstance(value, dict)
                else list(value)
                if isinstance(value, list)
                else value
                for key, value in self.default_params.items()
            }
        # The first configured view acts as the rig's reference sensor.
        self.ref_sensor = next(iter(self.params["views"]))

    def set_params(self, params):
        """Replace the full parameter dict."""
        self.params = params

    def process_video(self, video_or_path, output_dir):
        """Extract panoramic frames and generate pinhole views.

        Args:
            video_or_path: Video file path (str or Path), or an in-memory
                frame batch (torch.Tensor / np.ndarray, values in [0, 1]).
            output_dir: Directory where images and metadata are written.

        Returns:
            Tuple (pano_images, pinhole_views_data).

        Raises:
            IOError: If the video file cannot be opened.
            ValueError: If `video_or_path` has an unsupported type.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Fix: Path inputs were previously rejected even though the error
        # message claimed they were supported.
        if isinstance(video_or_path, (str, Path)):
            video = cv2.VideoCapture(str(video_or_path))
            if not video.isOpened():
                raise IOError(f"Cannot open video file: {video_or_path}")
            try:
                pano_images = self._extract_frames(video, output_dir)
            finally:
                # Release the capture even if extraction fails mid-way.
                video.release()
        elif isinstance(video_or_path, (torch.Tensor, np.ndarray)):
            pano_images = self._extract_frames_torch(video_or_path)
        else:
            raise ValueError(
                "video_or_path must be a str/Path or a torch.Tensor/np.ndarray"
            )

        pinhole_images_data = self._generate_pinhole_images(pano_images, output_dir)
        return pano_images, pinhole_images_data

    def _extract_frames(self, video, output_dir):
        """Read frames from an opened cv2.VideoCapture at `frame_interval` spacing."""
        frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        pano_images = []

        for frame_idx in tqdm(range(frame_count), desc="Extracting Frames"):
            ret, frame = video.read()
            if not ret:
                break
            if frame_idx % self.params["frame_interval"] == 0:
                pano_images.append({"image": frame, "idx": frame_idx})
        return pano_images

    def _extract_frames_torch(self, video_tensor):
        """Extract frames from an in-memory batch at `frame_interval` spacing.

        Fix: accepts np.ndarray as well as torch.Tensor, matching what
        process_video routes here. Values are assumed normalized to [0, 1]
        (RGB order — TODO confirm against the ComfyUI IMAGE convention).
        """
        if isinstance(video_tensor, torch.Tensor):
            video_tensor = video_tensor.detach().cpu().numpy()
        elif not isinstance(video_tensor, np.ndarray):
            raise ValueError("video_tensor must be a torch.Tensor or np.ndarray")

        pano_images = []
        interval = self.params["frame_interval"]
        for frame_idx in tqdm(range(video_tensor.shape[0]), desc="Extracting Frames"):
            if frame_idx % interval == 0:
                img = (video_tensor[frame_idx] * 255.0).astype(np.uint8)
                pano_images.append({"image": img, "idx": frame_idx})

        return pano_images

    def _generate_pinhole_images(self, pano_images, output_dir):
        """Reproject every panorama into all configured views; write metadata.

        Returns a list of per-view records (image array, pose, size, path).
        """
        output_pinhole_dir = output_dir / "pinhole_images" / "images"
        output_pinhole_dir.mkdir(parents=True, exist_ok=True)

        camera_params_list = []
        camera_rig_params = {}
        pinhole_views = []

        for pano_info in tqdm(pano_images, desc="Generating Pinhole Views"):
            pano_idx, pano_image = pano_info["idx"], pano_info["image"]
            for view_name, (pitch, yaw) in self.params["views"].items():
                pinhole_image = self._convert_to_pinhole(pano_image, pitch, yaw)

                # One sub-directory per view so COLMAP can treat each view
                # as its own camera (single_camera_per_folder).
                save_dir = output_pinhole_dir / view_name
                save_dir.mkdir(parents=True, exist_ok=True)
                save_path = save_dir / f"{pano_idx:06d}.jpg"
                cv2.imwrite(str(save_path), pinhole_image)

                h, w = pinhole_image.shape[:2]
                pinhole_views.append(
                    {
                        "image": pinhole_image,
                        "pano_index": pano_idx,
                        "view_name": view_name,
                        "pitch": pitch,
                        "yaw": yaw,
                        "width": w,
                        "height": h,
                        "save_path": str(save_path),
                    }
                )

                is_ref = view_name == self.ref_sensor
                cam_params = self._create_camera_params(
                    save_path, pano_idx, view_name, pitch, yaw, is_ref
                )
                camera_params_list.append(cam_params)

                if view_name not in camera_rig_params:
                    camera_rig_params[view_name] = {
                        "image_prefix": view_name,
                        "yaw": yaw,
                        "pitch": pitch,
                        "ref_sensor": is_ref,
                    }

        self._save_camera_params(
            camera_params_list,
            output_dir / "pinhole_images" / "camera_params.json",
        )
        self._save_colmap_camera_rig(
            camera_rig_params, output_dir / "pinhole_images" / "rig_config.json"
        )

        return pinhole_views

    def _convert_to_pinhole(self, pano_image, pitch, yaw):
        """Reproject an equirectangular image to a pinhole view via py360convert."""
        return py360convert.e2p(
            e_img=pano_image,
            fov_deg=(self.params["fov_h"], self.params["fov_v"]),
            u_deg=yaw,
            v_deg=pitch,
            out_hw=(self.params["height"], self.params["width"]),
            in_rot_deg=0,
            mode="bilinear",
        )

    def _create_camera_params(
        self, save_path: Path, pano_idx, view_name, pitch, yaw, ref_sensor=None
    ):
        """Build the per-image camera metadata record (focal from FOV)."""
        fx = compute_focal_length(self.params["width"], self.params["fov_h"])
        fy = compute_focal_length(self.params["height"], self.params["fov_v"])
        return {
            "image_name": save_path.name,
            "image_prefix": view_name,
            "fx": fx,
            "fy": fy,
            "cx": self.params["width"] / 2,
            "cy": self.params["height"] / 2,
            "height": self.params["height"],
            "width": self.params["width"],
            "fov_h": self.params["fov_h"],
            "fov_v": self.params["fov_v"],
            "yaw": yaw,
            "pitch": pitch,
            "pano_index": pano_idx,
            "ref_sensor": ref_sensor,
        }

    def _save_camera_params(self, params, output_file):
        """Serialize the camera parameter list as JSON."""
        with open(output_file, "w") as f:
            json.dump(params, f, indent=4)

    def _save_colmap_camera_rig(self, camera_rig_params, output_file):
        """Write the COLMAP rig configuration relating each view to the reference."""
        if not self.params["views"]:
            return

        ref_view_name = next(iter(self.params["views"]))
        ref_pitch, ref_yaw = self.params["views"][ref_view_name]

        # COLMAP: X right, Y down, Z forward. Euler: yaw, pitch, roll
        R_ref_world = R.from_euler("yx", [ref_yaw, ref_pitch], degrees=True)

        rig_cameras = []
        for image_prefix, view_params in camera_rig_params.items():
            R_view_world = R.from_euler(
                "yx", [view_params["yaw"], view_params["pitch"]], degrees=True
            )
            R_view_ref = R_view_world.inv() * R_ref_world  # Cam from Rig

            # Scipy quat (x,y,z,w) -> COLMAP quat (w,x,y,z)
            qvec_scipy = R_view_ref.as_quat()
            qvec_colmap = [
                qvec_scipy[3],
                qvec_scipy[0],
                qvec_scipy[1],
                qvec_scipy[2],
            ]

            cam_entry = {"image_prefix": image_prefix}
            if view_params.get("ref_sensor"):
                cam_entry["ref_sensor"] = True
            else:
                cam_entry["cam_from_rig_rotation"] = qvec_colmap
                cam_entry["cam_from_rig_translation"] = [0.0, 0.0, 0.0]
            rig_cameras.append(cam_entry)

        colmap_rig_config = [{"cameras": rig_cameras}]
        with open(output_file, "w") as f:
            json.dump(colmap_rig_config, f, indent=4)