Spaces:
Runtime error
Runtime error
| from typing import List, Literal | |
| from pathlib import Path | |
| from functools import partial | |
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from torchvision.datasets.utils import download_and_extract_archive | |
| from einops import repeat | |
| from omegaconf import OmegaConf | |
| from modeling.pipeline import VMemPipeline | |
| from diffusers.utils import export_to_video, export_to_gif | |
| from scipy.spatial.transform import Rotation, Slerp | |
| from navigation import Navigator | |
| from PIL import Image | |
| from utils import tensor_to_pil, encode_vae_image, encode_image, get_default_intrinsics, load_img_and_K, transform_img_and_K | |
| import os | |
| import glob | |
# Inference configuration, parsed once when the module is imported.
CONFIG_PATH = "configs/inference/inference.yaml"
CONFIG = OmegaConf.load(CONFIG_PATH)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The pipeline is built once at import time and shared by every request.
MODEL = VMemPipeline(CONFIG, DEVICE)

# Module-wide navigator storage; the handlers in this file only ever use index 0.
NAVIGATORS = []

# Frame rate handed to export_to_video for the navigation preview.
NAVIGATION_FPS = 3

# Default size (in pixels) of the placeholder images defined below.
WIDTH = 576
HEIGHT = 576

# Example images offered in the selection gallery.
IMAGE_PATHS = [
    'test_samples/changi.jpg',
    'test_samples/oxford.jpeg',
    'test_samples/open_door.jpg',
    'test_samples/jesus.jpg',
    'test_samples/friends.jpg',
]
| # for asset_dir in ASSET_DIRS: | |
| # if os.path.exists(asset_dir): | |
| # for ext in ["*.jpg", "*.jpeg", "*.png"]: | |
| # IMAGE_PATHS.extend(glob.glob(os.path.join(asset_dir, ext))) | |
| # If no images found, create placeholders | |
if not IMAGE_PATHS:
    def _placeholder_gradient(index, num_samples, height, width):
        """Return one (height, width, 3) uint8 gradient image.

        Red ramps with the row index, green ramps with the column index and
        blue is a constant that depends on ``index`` so that each sample is
        distinguishable. The ramps are computed with NumPy broadcasting
        instead of a per-pixel Python loop; the float value is truncated to
        an integer exactly as ``int(...)`` would.
        """
        img = np.zeros((height, width, 3), dtype=np.uint8)
        red = (255 * np.arange(height) / height).astype(np.uint8)
        green = (255 * np.arange(width) / width).astype(np.uint8)
        img[:, :, 0] = red[:, None]    # same value along each row
        img[:, :, 1] = green[None, :]  # same value along each column
        img[:, :, 2] = int(255 * (index + 1) / num_samples)
        return img

    def create_placeholder_images(num_samples=5, height=HEIGHT, width=WIDTH):
        """Create placeholder images for the demo.

        Returns a list of ``num_samples`` uint8 arrays of shape
        (height, width, 3).
        """
        return [
            _placeholder_gradient(i, num_samples, height, width)
            for i in range(num_samples)
        ]

    def create_placeholder_video_and_poses(num_samples=5, num_frames=1, height=HEIGHT, width=WIDTH):
        """Create placeholder videos and poses for the demo.

        Returns ``(videos, poses)``: two lists of length ``num_samples``.
        Each video is a float tensor of shape [num_frames, 3, height, width]
        with values in [0, 1]; each pose entry is ``num_frames`` stacked 4x4
        identity matrices.
        """
        videos = []
        poses = []
        for i in range(num_samples):
            img = _placeholder_gradient(i, num_samples, height, width)
            # Channels-first float tensor; every frame of a sample is identical,
            # so the image is built once and stacked num_frames times.
            frame = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
            videos.append(torch.stack([frame] * num_frames))
            poses.append(torch.eye(4).unsqueeze(0).repeat(num_frames, 1, 1))
        return videos, poses

    first_frame_list = create_placeholder_images(num_samples=5)
    video_list, poses_list = create_placeholder_video_and_poses(num_samples=5)
| # Function to load image from path | |
def load_image_for_navigation(image_path):
    """Load an image from ``image_path`` and prepare the initial navigation state.

    The image is loaded with ``load_img_and_K`` and cropped with
    ``transform_img_and_K`` to the (height, width) given by the model section
    of the already-loaded ``CONFIG`` (the YAML file is not re-read per call).

    Returns a dict with:
      - "image": the transformed image converted by ``tensor_to_pil``
      - "video": the transformed image tensor itself (the single-frame video)
      - "pose":  a [1, 4, 4] identity pose tensor
    """
    image, _ = load_img_and_K(image_path, None, K=None, device=DEVICE)

    target_size = (CONFIG.model.height, CONFIG.model.width)
    image, _ = transform_img_and_K(image, target_size, mode="crop", K=None)

    return {
        "image": tensor_to_pil(image),
        "video": image,
        "pose": torch.eye(4).unsqueeze(0),  # [1, 4, 4]
    }
class CustomProgressBar:
    """Thin proxy around a progress bar that silences ``set_postfix``.

    Every attribute other than ``set_postfix`` is forwarded to the wrapped
    ``pbar`` object.
    """

    def __init__(self, pbar):
        self.pbar = pbar

    def set_postfix(self, **kwargs):
        """Accept and ignore any postfix values."""
        pass

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. If ``pbar`` itself is
        # missing (instance created without __init__, e.g. by copy/pickle),
        # delegating through ``self.pbar`` would recurse forever, so report
        # the attribute as absent instead.
        if attr == "pbar":
            raise AttributeError(attr)
        return getattr(self.pbar, attr)
def get_duration_navigate_video(video: torch.Tensor,
                                poses: torch.Tensor,
                                x_angle: float,
                                y_angle: float,
                                distance: float
                                ):
    """Estimate, in seconds, how long one navigation request will take.

    Starts from 15 seconds, adds 10 for a sharp turn (|x_angle| > 20 or
    |y_angle| > 30), adds 10 for a long move (distance > 100) and adds one
    second per existing frame, capped at 10. ``poses`` is not used.

    NOTE(review): nothing in this file calls this function; it looks like a
    duration callback for a GPU-allocation decorator - confirm.
    """
    is_sharp_turn = abs(x_angle) > 20 or abs(y_angle) > 30
    is_long_move = distance > 100

    seconds = 15
    seconds += 10 if is_sharp_turn else 0
    seconds += 10 if is_long_move else 0
    # Longer existing videos mean more processing, up to a fixed cap.
    seconds += min(10, len(video))
    return seconds
def _navigation_outputs(video, poses):
    """Build the (video, poses, current view, video file, gallery) output tuple.

    Each frame is converted with ``tensor_to_pil`` exactly once and reused for
    the current view, the exported video and the gallery. When there are no
    frames (``None`` or empty) the three visual outputs are empty instead of
    raising.
    """
    if video is None or len(video) == 0:
        return video, poses, None, None, []
    frames = [tensor_to_pil(frame) for frame in video]
    gallery = [(frame, f"t={i}") for i, frame in enumerate(frames)]
    return (
        video,
        poses,
        frames[-1],  # Current view
        export_to_video(frames, fps=NAVIGATION_FPS),  # Video
        gallery,  # Gallery
    )


def navigate_video(
    video: torch.Tensor,
    poses: torch.Tensor,
    x_angle: float,
    y_angle: float,
    distance: float,
):
    """
    Generate new video frames by navigating in the 3D scene.

    Movement is delegated to a single ``Navigator`` kept in ``NAVIGATORS[0]``;
    it is created and initialized from the first frame and pose on the first
    call after the list has been emptied.

    - y_angle: non-zero values turn left (positive) or right (negative); the
      navigator receives ``abs(y_angle) // 2``.
    - distance: used only when y_angle is zero; positive moves forward,
      negative moves backward, in ``max(1, int(|distance| / 10))`` steps.
    - x_angle: accepted for interface compatibility but currently ignored.

    Returns a 5-tuple: updated video tensor, updated poses tensor, the latest
    frame as an image, the exported video, and a list of (image, caption)
    pairs for the gallery. Every new frame is paired with the navigator's
    pose after the move. On any error a warning is shown and the outputs are
    rebuilt from the unchanged inputs.
    """
    try:
        # Converting the first frame up front also rejects a missing/empty
        # video before the navigator state is touched.
        initial_frame = tensor_to_pil(video[0])

        if not NAVIGATORS:
            navigator = Navigator(MODEL, step_size=0.1, num_interpolation_frames=4)
            initial_pose = poses[0].cpu().numpy().reshape(4, 4)
            # Default camera intrinsics, since none are tracked per image.
            initial_K = np.array(get_default_intrinsics()[0])
            navigator.initialize(initial_frame, initial_pose, initial_K)
            # Register only after a successful initialize so a failure cannot
            # leave a half-initialized navigator behind for later calls.
            NAVIGATORS.append(navigator)
        navigator = NAVIGATORS[0]

        new_frames = []
        if abs(y_angle) > 0:
            turn = navigator.turn_left if y_angle > 0 else navigator.turn_right
            # Take abs() before the floor division so left and right turns of
            # the same magnitude are symmetric (-15 // 2 is -8, not -7).
            new_frames = turn(abs(y_angle) // 2)
        elif distance > 0:
            new_frames = navigator.move_forward(max(1, int(distance / 10)))
        elif distance < 0:
            new_frames = navigator.move_backward(max(1, int(abs(distance) / 10)))

        if not new_frames:
            # Nothing was generated: report the current state unchanged.
            return _navigation_outputs(video, poses)

        # HWC images in [0, 255] -> CHW float tensors in [-1, 1].
        new_frames_tensor = torch.stack([
            torch.from_numpy((np.array(frame) / 255.0).transpose(2, 0, 1)).float() * 2.0 - 1.0
            for frame in new_frames
        ])

        # All new frames share the navigator's pose after the move.
        new_poses = (
            torch.from_numpy(navigator.current_pose)
            .float()
            .unsqueeze(0)
            .repeat(len(new_frames), 1, 1)
            .view(len(new_frames), 4, 4)
        )

        updated_video = torch.cat([video.cpu(), new_frames_tensor], dim=0)
        updated_poses = torch.cat([poses.cpu(), new_poses], dim=0)
        return _navigation_outputs(updated_video, updated_poses)
    except Exception as e:
        print(f"Error in navigate_video: {e}")
        gr.Warning(f"Navigation error: {e}")
        # Return the original inputs to avoid crashes
        return _navigation_outputs(video, poses)
def undo_navigation(
    video: torch.Tensor,
    poses: torch.Tensor,
):
    """
    Undo the last navigation step by removing the last set of frames.

    Calls ``undo()`` on the navigator in ``NAVIGATORS[0]``. When it reports
    success, ``video`` and ``poses`` are truncated to ``len(navigator.frames)``
    entries. When there is no navigator, or nothing is left to undo, a warning
    is shown and the inputs are returned unchanged.

    Returns the same 5-tuple as the navigation handlers: video, poses, latest
    frame image, exported video, and (image, caption) gallery pairs. If no
    frames remain (or ``video`` is ``None``) the three visual outputs are
    empty instead of raising.
    """
    if not NAVIGATORS:
        gr.Warning("No navigation session available!")
    else:
        navigator = NAVIGATORS[0]
        if navigator.undo():
            # The navigator already dropped the frames internally; trim our
            # tensors to the same length.
            kept = len(navigator.frames)
            video, poses = video[:kept], poses[:kept]
        else:
            gr.Warning("You have no moves left to undo!")

    if video is None or len(video) == 0:
        return video, poses, None, None, []

    # Convert each frame once and reuse it for every visual output.
    frames = [tensor_to_pil(frame) for frame in video]
    return (
        video,
        poses,
        frames[-1],
        export_to_video(frames, fps=NAVIGATION_FPS),
        [(frame, f"t={i}") for i, frame in enumerate(frames)],
    )
| def render_demo3( | |
| s: Literal["Selection", "Generation"], | |
| idx: int, | |
| demo3_stage: gr.State, | |
| demo3_selected_index: gr.State, | |
| demo3_current_video: gr.State, | |
| demo3_current_poses: gr.State | |
| ): | |
| gr.Markdown( | |
| """ | |
| ## Single Image → Consistent Scene Navigation | |
| > #### _Select an image and navigate through the scene by controlling camera movements._ | |
| """, | |
| elem_classes=["task-title"] | |
| ) | |
| match s: | |
| case "Selection": | |
| with gr.Group(): | |
| # Add upload functionality | |
| with gr.Group(elem_classes=["gradio-box"]): | |
| gr.Markdown("### Upload Your Own Image") | |
| gr.Markdown("_Upload an image to navigate through its 3D scene_") | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| upload_image = gr.Image( | |
| label="Upload an image", | |
| type="filepath", | |
| height=300, | |
| elem_id="upload-image" | |
| ) | |
| with gr.Column(scale=1): | |
| gr.Markdown("#### Instructions:") | |
| gr.Markdown("1. Upload a clear, high-quality image") | |
| gr.Markdown("2. Images with distinct visual features work best") | |
| gr.Markdown("3. Landscape or architectural scenes are ideal") | |
| upload_btn = gr.Button("Start Navigation", variant="primary", size="lg") | |
| def process_uploaded_image(image_path): | |
| if image_path is None: | |
| gr.Warning("Please upload an image first") | |
| return "Selection", None, None, None | |
| try: | |
| # Load image and prepare for navigation | |
| result = load_image_for_navigation(image_path) | |
| # Clear any existing navigators | |
| global NAVIGATORS | |
| NAVIGATORS = [] | |
| return ( | |
| "Generation", | |
| None, # No predefined index for uploaded images | |
| result["video"], | |
| result["pose"], | |
| ) | |
| except Exception as e: | |
| print(f"Error in process_uploaded_image: {e}") | |
| gr.Warning(f"Error processing uploaded image: {e}") | |
| return "Selection", None, None, None | |
| upload_btn.click( | |
| fn=process_uploaded_image, | |
| inputs=[upload_image], | |
| outputs=[demo3_stage, demo3_selected_index, demo3_current_video, demo3_current_poses] | |
| ) | |
| gr.Markdown("### Or Choose From Our Examples") | |
| # Define image captions | |
| image_captions = { | |
| 'test_samples/changi.jpg': 'Changi Airport', | |
| 'test_samples/oxford.jpeg': 'Oxford University', | |
| 'test_samples/open_door.jpg': 'Bedroom Interior', | |
| 'test_samples/jesus.jpg': 'Jesus College', | |
| 'test_samples/friends.jpg': 'Friends Café' | |
| } | |
| # Load all images for the gallery with captions | |
| gallery_images = [] | |
| for img_path in IMAGE_PATHS: | |
| try: | |
| # Get caption or default to basename | |
| caption = image_captions.get(img_path, os.path.basename(img_path)) | |
| gallery_images.append((img_path, caption)) | |
| except Exception as e: | |
| print(f"Error loading image {img_path}: {e}") | |
| # Show image gallery for selection | |
| demo3_image_gallery = gr.Gallery( | |
| value=gallery_images, | |
| label="Select an Image to Start Navigation", | |
| columns=len(gallery_images), | |
| height=400, | |
| allow_preview=True, | |
| preview=False, | |
| elem_id="navigation-gallery" | |
| ) | |
| gr.Markdown("_Click on an image to begin navigation_") | |
| def start_navigation(evt: gr.SelectData): | |
| try: | |
| # Get the selected image path | |
| selected_path = IMAGE_PATHS[evt.index] | |
| # Load image and prepare for navigation | |
| result = load_image_for_navigation(selected_path) | |
| # Clear any existing navigators | |
| global NAVIGATORS | |
| NAVIGATORS = [] | |
| return ( | |
| "Generation", | |
| evt.index, | |
| result["video"], | |
| result["pose"], | |
| ) | |
| except Exception as e: | |
| print(f"Error in start_navigation: {e}") | |
| gr.Warning(f"Error starting navigation: {e}") | |
| return "Selection", None, None, None | |
| demo3_image_gallery.select( | |
| fn=start_navigation, | |
| inputs=None, | |
| outputs=[demo3_stage, demo3_selected_index, demo3_current_video, demo3_current_poses] | |
| ) | |
| case "Generation": | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| with gr.Row(): | |
| demo3_current_view = gr.Image( | |
| label="Current View", | |
| width=256, | |
| height=256, | |
| ) | |
| demo3_video = gr.Video( | |
| label="Generated Video", | |
| width=256, | |
| height=256, | |
| autoplay=True, | |
| loop=True, | |
| show_share_button=True, | |
| show_download_button=True, | |
| ) | |
| demo3_generated_gallery = gr.Gallery( | |
| value=[], | |
| label="Generated Frames", | |
| columns=[6], | |
| ) | |
| # Initialize the current view with the selected image if available | |
| if idx is not None: | |
| try: | |
| selected_path = IMAGE_PATHS[idx] | |
| result = load_image_for_navigation(selected_path) | |
| demo3_current_view.value = result["image"] | |
| except Exception as e: | |
| print(f"Error initializing current view: {e}") | |
| with gr.Column(): | |
| gr.Markdown("### Navigation Controls ↓") | |
| with gr.Accordion("Instructions", open=False): | |
| gr.Markdown(""" | |
| - **The model will predict the next few frames based on your camera movements. Repeat the process to continue navigating through the scene.** | |
| - **Use the navigation controls to move forward/backward and turn left/right.** | |
| - **At the end of your navigation, you can save your camera path for later use.** | |
| """) | |
| # with gr.Tab("Basic", elem_id="basic-controls-tab"): | |
| with gr.Group(): | |
| gr.Markdown("_**Select a direction to move:**_") | |
| # First row: Turn left/right | |
| with gr.Row(elem_id="basic-controls"): | |
| gr.Button( | |
| "↰20°\nVeer", | |
| size="sm", | |
| min_width=0, | |
| variant="primary", | |
| ).click( | |
| fn=partial( | |
| navigate_video, | |
| x_angle=0, | |
| y_angle=20, | |
| distance=0, | |
| ), | |
| inputs=[ | |
| demo3_current_video, | |
| demo3_current_poses, | |
| ], | |
| outputs=[ | |
| demo3_current_video, | |
| demo3_current_poses, | |
| demo3_current_view, | |
| demo3_video, | |
| demo3_generated_gallery, | |
| ], | |
| ) | |
| gr.Button( | |
| "↖10°\nTurn", | |
| size="sm", | |
| min_width=0, | |
| variant="primary", | |
| ).click( | |
| fn=partial( | |
| navigate_video, | |
| x_angle=0, | |
| y_angle=10, | |
| distance=0, | |
| ), | |
| inputs=[ | |
| demo3_current_video, | |
| demo3_current_poses, | |
| ], | |
| outputs=[ | |
| demo3_current_video, | |
| demo3_current_poses, | |
| demo3_current_view, | |
| demo3_video, | |
| demo3_generated_gallery, | |
| ], | |
| ) | |
| # gr.Button( | |
| # "↑0°\nAhead", | |
| # size="sm", | |
| # min_width=0, | |
| # variant="primary", | |
| # ).click( | |
| # fn=partial( | |
| # navigate_video, | |
| # x_angle=0, | |
| # y_angle=0, | |
| # distance=10, | |
| # ), | |
| # inputs=[ | |
| # demo3_current_video, | |
| # demo3_current_poses, | |
| # ], | |
| # outputs=[ | |
| # demo3_current_video, | |
| # demo3_current_poses, | |
| # demo3_current_view, | |
| # demo3_video, | |
| # demo3_generated_gallery, | |
| # ], | |
| # ) | |
| gr.Button( | |
| "↗10°\nTurn", | |
| size="sm", | |
| min_width=0, | |
| variant="primary", | |
| ).click( | |
| fn=partial( | |
| navigate_video, | |
| x_angle=0, | |
| y_angle=-10, | |
| distance=0, | |
| ), | |
| inputs=[ | |
| demo3_current_video, | |
| demo3_current_poses, | |
| ], | |
| outputs=[ | |
| demo3_current_video, | |
| demo3_current_poses, | |
| demo3_current_view, | |
| demo3_video, | |
| demo3_generated_gallery, | |
| ], | |
| ) | |
| gr.Button( | |
| "↱\n20° Veer", | |
| size="sm", | |
| min_width=0, | |
| variant="primary", | |
| ).click( | |
| fn=partial( | |
| navigate_video, | |
| x_angle=0, | |
| y_angle=-20, | |
| distance=0, | |
| ), | |
| inputs=[ | |
| demo3_current_video, | |
| demo3_current_poses, | |
| ], | |
| outputs=[ | |
| demo3_current_video, | |
| demo3_current_poses, | |
| demo3_current_view, | |
| demo3_video, | |
| demo3_generated_gallery, | |
| ], | |
| ) | |
| # Second row: Forward/Backward movement | |
| with gr.Row(elem_id="forward-backward-controls"): | |
| gr.Button( | |
| "↓\nBackward", | |
| size="sm", | |
| min_width=0, | |
| variant="secondary", | |
| ).click( | |
| fn=partial( | |
| navigate_video, | |
| x_angle=0, | |
| y_angle=0, | |
| distance=-10, | |
| ), | |
| inputs=[ | |
| demo3_current_video, | |
| demo3_current_poses, | |
| ], | |
| outputs=[ | |
| demo3_current_video, | |
| demo3_current_poses, | |
| demo3_current_view, | |
| demo3_video, | |
| demo3_generated_gallery, | |
| ], | |
| ) | |
| gr.Button( | |
| "↑\nForward", | |
| size="sm", | |
| min_width=0, | |
| variant="secondary", | |
| ).click( | |
| fn=partial( | |
| navigate_video, | |
| x_angle=0, | |
| y_angle=0, | |
| distance=10, | |
| ), | |
| inputs=[ | |
| demo3_current_video, | |
| demo3_current_poses, | |
| ], | |
| outputs=[ | |
| demo3_current_video, | |
| demo3_current_poses, | |
| demo3_current_view, | |
| demo3_video, | |
| demo3_generated_gallery, | |
| ], | |
| ) | |
| # with gr.Tab("Advanced", elem_id="advanced-controls-tab"): | |
| # with gr.Group(): | |
| # gr.Markdown("_**Select angles and distance:**_") | |
| # demo3_y_angle = gr.Slider( | |
| # minimum=-90, | |
| # maximum=90, | |
| # value=0, | |
| # step=10, | |
| # label="Horizontal Angle", | |
| # interactive=True, | |
| # ) | |
| # demo3_x_angle = gr.Slider( | |
| # minimum=-40, | |
| # maximum=40, | |
| # value=0, | |
| # step=10, | |
| # label="Vertical Angle", | |
| # interactive=True, | |
| # ) | |
| # demo3_distance = gr.Slider( | |
| # minimum=-200, | |
| # maximum=200, | |
| # value=100, | |
| # step=10, | |
| # label="Distance (negative = backward)", | |
| # interactive=True, | |
| # ) | |
| # gr.Button( | |
| # "Generate Next Move", variant="primary" | |
| # ).click( | |
| # fn=navigate_video, | |
| # inputs=[ | |
| # demo3_current_video, | |
| # demo3_current_poses, | |
| # demo3_x_angle, | |
| # demo3_y_angle, | |
| # demo3_distance, | |
| # ], | |
| # outputs=[ | |
| # demo3_current_video, | |
| # demo3_current_poses, | |
| # demo3_current_view, | |
| # demo3_video, | |
| # demo3_generated_gallery, | |
| # ], | |
| # ) | |
| gr.Markdown("---") | |
| with gr.Group(): | |
| gr.Markdown("_**Navigation controls:**_") | |
| with gr.Row(): | |
| gr.Button("Undo Last Move", variant="huggingface").click( | |
| fn=undo_navigation, | |
| inputs=[demo3_current_video, demo3_current_poses], | |
| outputs=[ | |
| demo3_current_video, | |
| demo3_current_poses, | |
| demo3_current_view, | |
| demo3_video, | |
| demo3_generated_gallery, | |
| ], | |
| ) | |
| # Add a function to save camera poses | |
| def save_camera_poses(video, poses): | |
| if len(NAVIGATORS) > 0: | |
| navigator = NAVIGATORS[0] | |
| # Create a directory for saved poses | |
| os.makedirs("./visualization", exist_ok=True) | |
| save_path = f"./visualization/transforms_{len(navigator.frames)}_frames.json" | |
| navigator.save_camera_poses(save_path) | |
| return gr.Info(f"Camera poses saved to {save_path}") | |
| return gr.Warning("No navigation instance found") | |
| gr.Button("Save Camera", variant="huggingface").click( | |
| fn=save_camera_poses, | |
| inputs=[demo3_current_video, demo3_current_poses], | |
| outputs=[] | |
| ) | |
| # Add a button to return to image selection | |
| def reset_navigation(): | |
| # Clear current navigator | |
| global NAVIGATORS | |
| NAVIGATORS = [] | |
| return "Selection", None, None, None | |
| gr.Button("Choose New Image", variant="secondary").click( | |
| fn=reset_navigation, | |
| inputs=[], | |
| outputs=[demo3_stage, demo3_selected_index, demo3_current_video, demo3_current_poses] | |
| ) | |
| # Create the Gradio Blocks | |
| with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo: | |
| gr.HTML( | |
| """ | |
| <style> | |
| [data-tab-id="task-1"], [data-tab-id="task-2"], [data-tab-id="task-3"] { | |
| font-size: 16px !important; | |
| font-weight: bold; | |
| } | |
| #page-title h1 { | |
| color: #0D9488 !important; | |
| } | |
| .task-title h2 { | |
| color: #F59E0C !important; | |
| } | |
| .header-button-row { | |
| gap: 4px !important; | |
| } | |
| .header-button-row div { | |
| width: 131.0px !important; | |
| } | |
| .header-button-column { | |
| width: 131.0px !important; | |
| gap: 5px !important; | |
| } | |
| .header-button a { | |
| border: 1px solid #e4e4e7; | |
| } | |
| .header-button .button-icon { | |
| margin-right: 8px; | |
| } | |
| .demo-button-column .gap { | |
| gap: 5px !important; | |
| } | |
| #basic-controls { | |
| column-gap: 0px; | |
| } | |
| #basic-controls-tab { | |
| padding: 0px; | |
| } | |
| #advanced-controls-tab { | |
| padding: 0px; | |
| } | |
| #forward-backward-controls { | |
| column-gap: 0px; | |
| justify-content: center; | |
| margin-top: 8px; | |
| } | |
| #selected-demo-button { | |
| color: #F59E0C; | |
| text-decoration: underline; | |
| } | |
| .demo-button { | |
| text-align: left !important; | |
| display: block !important; | |
| } | |
| #navigation-gallery { | |
| margin-bottom: 15px; | |
| } | |
| #navigation-gallery .gallery-item { | |
| cursor: pointer; | |
| border-radius: 6px; | |
| transition: transform 0.2s, box-shadow 0.2s; | |
| } | |
| #navigation-gallery .gallery-item:hover { | |
| transform: scale(1.02); | |
| box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); | |
| } | |
| #navigation-gallery .gallery-item.selected { | |
| border: 3px solid #0D9488; | |
| } | |
| /* Upload image styling */ | |
| #upload-image { | |
| border-radius: 8px; | |
| border: 2px dashed #0D9488; | |
| padding: 10px; | |
| transition: all 0.3s ease; | |
| } | |
| #upload-image:hover { | |
| border-color: #F59E0C; | |
| box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1); | |
| } | |
| /* Box styling */ | |
| .gradio-box { | |
| border-radius: 10px; | |
| margin-bottom: 20px; | |
| padding: 15px; | |
| background-color: #f8f9fa; | |
| border: 1px solid #e9ecef; | |
| } | |
| </style> | |
| """ | |
| ) | |
| demo_idx = gr.State(value=3) | |
| with gr.Sidebar(): | |
| gr.Markdown("# VMem: Consistent Scene Generation with Surfel Memory of Views", elem_id="page-title") | |
| gr.Markdown( | |
| "### Official Interactive Demo for [_VMem_](https://arxiv.org/abs/2502.06764)" | |
| ) | |
| gr.Markdown("---") | |
| gr.Markdown("#### Links ↓") | |
| with gr.Row(elem_classes=["header-button-row"]): | |
| with gr.Column(elem_classes=["header-button-column"], min_width=0): | |
| gr.Button( | |
| value="Website", | |
| link="https://v-mem.github.io/", | |
| icon="https://simpleicons.org/icons/googlechrome.svg", | |
| elem_classes=["header-button"], | |
| size="md", | |
| min_width=0, | |
| ) | |
| gr.Button( | |
| value="Paper", | |
| link="https://arxiv.org/abs/2502.06764", | |
| icon="https://simpleicons.org/icons/arxiv.svg", | |
| elem_classes=["header-button"], | |
| size="md", | |
| min_width=0, | |
| ) | |
| with gr.Column(elem_classes=["header-button-column"], min_width=0): | |
| gr.Button( | |
| value="Code", | |
| link="https://github.com/kwsong0113/diffusion-forcing-transformer", | |
| icon="https://simpleicons.org/icons/github.svg", | |
| elem_classes=["header-button"], | |
| size="md", | |
| min_width=0, | |
| ) | |
| gr.Button( | |
| value="Weights", | |
| link="https://huggingface.co/liguang0115/vmem", | |
| icon="https://simpleicons.org/icons/huggingface.svg", | |
| elem_classes=["header-button"], | |
| size="md", | |
| min_width=0, | |
| ) | |
| gr.Markdown("---") | |
| gr.Markdown("#### Choose a Demo ↓") | |
| with gr.Column(elem_classes=["demo-button-column"]): | |
| def render_demo_tabs(idx): | |
| demo_tab_button3 = gr.Button( | |
| "Navigate Image", | |
| size="md", elem_classes=["demo-button"], **{"elem_id": "selected-demo-button"} if idx == 3 else {} | |
| ).click( | |
| fn=lambda: 3, | |
| outputs=demo_idx | |
| ) | |
| gr.Markdown("---") | |
| gr.Markdown("#### Troubleshooting ↓") | |
| with gr.Group(): | |
| with gr.Accordion("Error or Unexpected Results?", open=False): | |
| gr.Markdown("Please try again after refreshing the page and ensure you do not click the same button multiple times.") | |
| with gr.Accordion("Too Slow or No GPU Allocation?", open=False): | |
| gr.Markdown( | |
| "Consider running the demo locally (click the dots in the top-right corner). Alternatively, you can subscribe to Hugging Face Pro for an increased GPU quota." | |
| ) | |
| demo3_stage = gr.State(value="Selection") | |
| demo3_selected_index = gr.State(value=None) | |
| demo3_current_video = gr.State(value=None) | |
| demo3_current_poses = gr.State(value=None) | |
| def render_demo( | |
| _demo_idx, _demo3_stage, _demo3_selected_index | |
| ): | |
| match _demo_idx: | |
| case 3: | |
| render_demo3(_demo3_stage, _demo3_selected_index, demo3_stage, demo3_selected_index, demo3_current_video, demo3_current_poses) | |
if __name__ == "__main__":
    # Launch settings for running the demo as a script.
    launch_options = {
        "debug": True,
        "share": True,
        "max_threads": 1,  # Limit concurrent processing
        "show_error": True,  # Show detailed error messages
    }
    demo.launch(**launch_options)