Amodal3R / app.py
Sm0kyWu's picture
Upload app.py
3a1aea9 verified
raw
history blame
15.5 kB
import gradio as gr
import spaces
from gradio_litmodel3d import LitModel3D
import os
import shutil
os.environ['SPCONV_ALGO'] = 'native'
from typing import *
import torch
import numpy as np
import imageio
from easydict import EasyDict as edict
from PIL import Image
from Amodal3R.pipelines import Amodal3RImageTo3DPipeline
from trellis.pipelines import TrellisImageTo3DPipeline
from Amodal3R.representations import Gaussian, MeshExtractResult
from Amodal3R.utils import render_utils, postprocessing_utils
# Largest value accepted for a seed: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Scratch space for per-session outputs: a ``tmp`` folder next to this script.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
TMP_DIR = os.path.join(_APP_DIR, 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
def start_session(req: gr.Request):
    """
    Create the scratch directory for a newly connected browser session.
    Args:
        req (gr.Request): Gradio request; its ``session_hash`` names the directory under TMP_DIR.
    """
    session_dir = os.path.join(TMP_DIR, f"{req.session_hash}")
    os.makedirs(session_dir, exist_ok=True)
def end_session(req: gr.Request):
    """
    Delete the scratch directory of a session that is closing.
    Args:
        req (gr.Request): Gradio request; its ``session_hash`` names the directory under TMP_DIR.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(user_dir)
    except FileNotFoundError:
        # Nothing to clean up. An unload hook should not raise just because
        # the session directory is already gone or was never created.
        pass
# def preprocess_image(image: Image.Image) -> Image.Image:
# """
# Preprocess the input image.
# Args:
# image (Image.Image): The input image.
# Returns:
# Image.Image: The preprocessed image.
# """
# processed_image = pipeline.preprocess_image(image)
# return processed_image
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """
    Run the pipeline's preprocessing on every entry of a gallery value.
    Args:
        images (List[Tuple[Image.Image, str]]): Gallery entries; only element 0 (the image) is used.
    Returns:
        List[Image.Image]: The preprocessed images, in input order.
    """
    # Pull out the image part of every entry first, then preprocess them.
    raw_images = [entry[0] for entry in images]
    processed = []
    for img in raw_images:
        processed.append(pipeline.preprocess_image(img))
    return processed
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """
    Serialize a generated Gaussian model and mesh into a dict of NumPy arrays
    so it can be held in a ``gr.State`` and later restored with ``unpack_state``.
    Args:
        gs (Gaussian): The generated Gaussian representation.
        mesh (MeshExtractResult): The generated mesh.
    Returns:
        dict: ``{'gaussian': {...init_params, tensor fields}, 'mesh': {'vertices', 'faces'}}``.
    """
    def to_numpy(tensor):
        # Move to host memory before converting.
        return tensor.cpu().numpy()

    gaussian_state = {**gs.init_params}
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_state[attr] = to_numpy(getattr(gs, attr))
    mesh_state = {
        'vertices': to_numpy(mesh.vertices),
        'faces': to_numpy(mesh.faces),
    }
    return {'gaussian': gaussian_state, 'mesh': mesh_state}
def unpack_state(state: dict, device: str = 'cuda') -> Tuple[Gaussian, edict]:
    """
    Rebuild a Gaussian model and a lightweight mesh from a dict made by ``pack_state``.
    Args:
        state (dict): Packed state with a ``'gaussian'`` entry (constructor params plus
            ``_xyz``/``_features_dc``/``_scaling``/``_rotation``/``_opacity`` arrays)
            and a ``'mesh'`` entry (``vertices``/``faces`` arrays).
        device (str): Torch device the restored tensors are placed on. Defaults to ``'cuda'``.
    Returns:
        Gaussian: The restored Gaussian representation.
        edict: Mesh with ``vertices`` and ``faces`` tensors.
    """
    gs_state = state['gaussian']
    gs = Gaussian(
        aabb=gs_state['aabb'],
        sh_degree=gs_state['sh_degree'],
        # Keyword spelled exactly as in the original call; check Gaussian's
        # signature before "fixing" it.
        mininum_kernel_size=gs_state['mininum_kernel_size'],
        scaling_bias=gs_state['scaling_bias'],
        opacity_bias=gs_state['opacity_bias'],
        scaling_activation=gs_state['scaling_activation'],
    )
    gs._xyz = torch.tensor(gs_state['_xyz'], device=device)
    gs._features_dc = torch.tensor(gs_state['_features_dc'], device=device)
    gs._scaling = torch.tensor(gs_state['_scaling'], device=device)
    gs._rotation = torch.tensor(gs_state['_rotation'], device=device)
    gs._opacity = torch.tensor(gs_state['_opacity'], device=device)
    mesh = edict(
        vertices=torch.tensor(state['mesh']['vertices'], device=device),
        faces=torch.tensor(state['mesh']['faces'], device=device),
    )
    return gs, mesh
def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Pick the seed for the next generation.
    Args:
        randomize_seed (bool): When truthy, ignore ``seed`` and draw a fresh one.
        seed (int): The seed to use when not randomizing.
    Returns:
        int: ``seed`` unchanged, or a random integer in ``[0, MAX_SEED)``.
    """
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)
@spaces.GPU
def image_to_3d(
    image: Image.Image,
    mask: Image.Image,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
) -> Tuple[dict, str]:
    """
    Convert an input image and its mask to a 3D model and render a preview video.
    Args:
        image (Image.Image): The input image.
        mask (Image.Image): The mask image passed to the pipeline together with ``image``.
        seed (int): The random seed.
        ss_guidance_strength (float): The guidance strength for sparse structure generation.
        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
        slat_guidance_strength (float): The guidance strength for structured latent generation.
        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
        multiimage_algo (Literal["multidiffusion", "stochastic"]): Forwarded to the pipeline as ``mode``.
        req (gr.Request): Gradio request; ``session_hash`` selects the output directory.
    Returns:
        dict: The packed state of the generated 3D model (see ``pack_state``).
        str: The path to the preview video of the 3D model.
    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        # Fail with a readable message instead of crashing inside the pipeline.
        raise gr.Error("Please provide an input image before generating.")
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    outputs = pipeline.run_multi_image(
        [image],
        [mask],
        seed=seed,
        formats=["gaussian", "mesh"],
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
        mode=multiimage_algo,
    )
    # Render the Gaussians (color) and the mesh (normals) with the same frame
    # count and place them side by side in each video frame.
    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120, bg_color=(1, 1, 1))['color']
    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120, bg_color=(1, 1, 1))['normal']
    video = [np.concatenate([color, normal], axis=1) for color, normal in zip(video, video_geo)]
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)
    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
    torch.cuda.empty_cache()
    return state, video_path
@spaces.GPU(duration=90)
def extract_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Bake the stored 3D model into a GLB file inside the session directory.
    Args:
        state (dict): Packed model state, as produced by ``pack_state``.
        mesh_simplify (float): The mesh simplification factor.
        texture_size (int): The texture resolution.
        req (gr.Request): Gradio request; ``session_hash`` selects the output directory.
    Returns:
        Tuple[str, str]: The GLB path twice (for the 3D viewer and for the download button).
    """
    gaussians, mesh = unpack_state(state)
    glb = postprocessing_utils.to_glb(
        gaussians,
        mesh,
        simplify=mesh_simplify,
        texture_size=texture_size,
        verbose=False,
    )
    glb_file = os.path.join(TMP_DIR, str(req.session_hash), 'sample.glb')
    glb.export(glb_file)
    torch.cuda.empty_cache()
    return glb_file, glb_file
@spaces.GPU
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
    """
    Export the Gaussians of the stored 3D model as a PLY file in the session directory.
    Args:
        state (dict): Packed model state, as produced by ``pack_state``.
        req (gr.Request): Gradio request; ``session_hash`` selects the output directory.
    Returns:
        Tuple[str, str]: The PLY path twice (for the 3D viewer and for the download button).
    """
    gaussians, _mesh = unpack_state(state)
    ply_file = os.path.join(TMP_DIR, str(req.session_hash), 'sample.ply')
    gaussians.save_ply(ply_file)
    torch.cuda.empty_cache()
    return ply_file, ply_file
def prepare_multi_example(
    example_dir: str = "assets/example_multi_image",
    num_views: int = 3,
    height: int = 512,
) -> List[Image.Image]:
    """
    Build one side-by-side preview image per multi-view example case.
    Files in ``example_dir`` are expected to be named ``{case}_{i}.png`` with
    ``i`` in ``1..num_views``. Each view is resized to ``height`` pixels
    (keeping its aspect ratio), and the views are concatenated horizontally.
    Args:
        example_dir (str): Directory holding the example view images.
        num_views (int): Number of views per case.
        height (int): Target height of every view, in pixels.
    Returns:
        List[Image.Image]: One preview per case, ordered by case name.
    """
    # rsplit so case names may themselves contain underscores; sorted so the
    # example order is stable across runs (set iteration order is not).
    cases = sorted({
        name.rsplit('_', 1)[0]
        for name in os.listdir(example_dir)
        if name.endswith('.png') and '_' in name
    })
    previews = []
    for case in cases:
        views = []
        for i in range(1, num_views + 1):
            # Context manager closes the underlying file handle.
            with Image.open(os.path.join(example_dir, f'{case}_{i}.png')) as img:
                W, H = img.size
                views.append(np.array(img.resize((int(W / H * height), height))))
        previews.append(Image.fromarray(np.concatenate(views, axis=1)))
    return previews
def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split an image into multiple views separated by fully transparent columns.
    Each maximal run of columns that contains at least one pixel with alpha > 0
    becomes one view.
    Args:
        image (Image.Image): Image with an alpha channel (read as channel index 3).
    Returns:
        List[Image.Image]: One preprocessed image per view, left to right.
    """
    pixels = np.array(image)
    occupied = np.any(pixels[..., 3] > 0, axis=0)
    # Pad with an empty column on both sides so views touching the left or
    # right border still produce a matching start/end transition.
    padded = np.concatenate(([False], occupied, [False])).astype(np.int8)
    edges = np.diff(padded)
    starts = np.flatnonzero(edges == 1)   # first opaque column of each view
    ends = np.flatnonzero(edges == -1)    # one past the last opaque column
    views = [Image.fromarray(pixels[:, s:e]) for s, e in zip(starts, ends)]
    # Use the pipeline's preprocessing directly: no module-level
    # ``preprocess_image`` function is defined (it is commented out).
    return [pipeline.preprocess_image(view) for view in views]
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
""")
    with gr.Row():
        gr.Markdown("""
### Step 1 - Generate Visibility Mask and Occlusion Mask.
* Please wait for a few seconds after uploading the image. The 2D segmenter is getting ready.
* Add the point prompts to indicate the target object and occluders separately.
* "Render Point", see the position of the point to be added.
* "Add Point", the point will be added to the list.
* "Generate mask", see the segmented area corresponding to current point list.
* "Add mask", current mask will be added for 3D amodal completion.
""")
    # NOTE(review): the code visible here registers no event handlers for the
    # Step-1 controls below (point/mask buttons, the point dropdown, the
    # "Generate Occluded Input" button). Verify whether they are wired elsewhere.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label='Input Occlusion Image', sources="upload", height=300)
            with gr.Row():
                message = gr.Markdown("Please wait a few seconds after uploading the image.", label="Message")  # Shows hint messages to the user
            with gr.Row():
                x_input = gr.Number(label="X Coordinate", value=0)
                y_input = gr.Number(label="Y Coordinate", value=0)
            with gr.Row():
                see_button = gr.Button("Render Point")
                add_button = gr.Button("Add Point")
            with gr.Row():
                clear_button = gr.Button("Clear Points")
                see_visible_button = gr.Button("Render Added Points")
            with gr.Row():
                # Read-only textbox that displays the current point list
                points_text = gr.Textbox(label="Points List", interactive=False)
            with gr.Row():
                # Dropdown for choosing which point to delete
                visible_points_dropdown = gr.Dropdown(label="Select Point to Delete", choices=[], value=None, interactive=True)
                delete_visible_button = gr.Button("Delete Selected Visible")
        with gr.Column():
            # Displays the SAM segmentation result
            visible_mask = gr.Image(label='Visible Mask', interactive=False, height=300)
            with gr.Row():
                gen_vis_mask = gr.Button("Generate Mask")
                add_vis_mask = gr.Button("Add Mask")
            with gr.Row():
                render_vis_mask = gr.Button("Render Mask")
                undo_vis_mask = gr.Button("Undo Last Mask")
            vis_input = gr.Image(label='Visible Input', interactive=False, height=300)
            with gr.Row():
                zoom_scale = gr.Slider(0.3, 1.0, label="Target Object Scale", value=0.6, step=0.1)
                check_visible_input = gr.Button("Generate Occluded Input")
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
                mask_prompt = gr.Image(label="Mask Prompt", format="png", image_mode="L", type="pil", height=300)
            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
            generate_btn = gr.Button("Generate")
            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
            with gr.Row():
                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
            gr.Markdown("""
*NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
""")
        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
    is_multiimage = gr.State(False)  # not read by any handler in this file
    output_buf = gr.State()  # packed 3D state returned by image_to_3d

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)
    # image_prompt.upload(
    #     lambda x:x,
    #     inputs=[image_prompt],
    #     outputs=[image_prompt],
    # )
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[image_prompt, mask_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
        outputs=[output_buf, video_output],
    ).success(
        # .success (not .then): only enable extraction when generation actually
        # produced a new state; a failed run must not enable the buttons.
        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
        outputs=[extract_glb_btn, extract_gs_btn],
    )
    video_output.clear(
        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
        outputs=[extract_glb_btn, extract_gs_btn],
    )
    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )
    extract_gs_btn.click(
        extract_gaussian,
        inputs=[output_buf],
        outputs=[model_output, download_gs],
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[download_gs],
    )
    # Clearing the viewer disables both download buttons, not just the GLB one.
    model_output.clear(
        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
        outputs=[download_glb, download_gs],
    )
# Launch the Gradio app
if __name__ == "__main__":
    pipeline = Amodal3RImageTo3DPipeline.from_pretrained("Sm0kyWu/Amodal3R")
    pipeline.cuda()
    # Warm-up call so one-time preprocessing setup happens at startup instead of
    # on the first request. Best-effort: a failure here must not stop the app
    # from launching, but it is reported rather than silently swallowed.
    try:
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))  # Preload rembg
    except Exception as exc:
        print(f"Warning: preprocessing warm-up failed: {exc!r}")
    demo.launch()