# File: etTripoSg-api-gradio / inference_triposg.py
# Last commit: 95553e7 by staswrs — "add octree depth controls fix 1"
# Size: 4.99 kB (raw / history / blame views available in the repository browser)
import argparse
import os
import sys
from glob import glob
from typing import Any, Union
import numpy as np
import torch
import trimesh
from huggingface_hub import snapshot_download
from PIL import Image
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from triposg.pipelines.pipeline_triposg import TripoSGPipeline
from image_process import prepare_image
from briarmbg import BriaRMBG
import pymeshlab
# @torch.no_grad()
# def run_triposg(
# pipe: Any,
# image_input: Union[str, Image.Image],
# rmbg_net: Any,
# seed: int,
# num_inference_steps: int = 50,
# guidance_scale: float = 7.0,
# faces: int = -1,
# ) -> trimesh.Scene:
# img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
# outputs = pipe(
# image=img_pil,
# generator=torch.Generator(device=pipe.device).manual_seed(seed),
# num_inference_steps=num_inference_steps,
# guidance_scale=guidance_scale,
# ).samples[0]
# mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
# if faces > 0:
# mesh = simplify_mesh(mesh, faces)
# return mesh
@torch.no_grad()
def run_triposg(
    pipe: Any,
    image_input: Union[str, Image.Image],
    rmbg_net: Any,
    seed: int,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    faces: int = -1,
    # NOTE(review): an `octree_depth: int = 9` parameter forwarded to the
    # pipeline as `flash_octree_depth` was drafted here and left disabled;
    # re-enable only after confirming the pipeline accepts that kwarg.
) -> trimesh.Trimesh:
    """Generate a triangle mesh from a single image with a TripoSG pipeline.

    Args:
        pipe: a loaded TripoSG pipeline; its ``device`` attribute selects where
            the random generator lives.
        image_input: image path or PIL image. It is preprocessed by
            ``prepare_image`` onto a white background using ``rmbg_net``.
        rmbg_net: background-removal network passed through to ``prepare_image``.
        seed: seed for the ``torch.Generator`` used by the pipeline.
        num_inference_steps: number of diffusion steps forwarded to ``pipe``.
        guidance_scale: guidance scale forwarded to ``pipe``.
        faces: when > 0, the mesh is decimated to this target face count via
            ``simplify_mesh``. Values <= 0 disable simplification.

    Returns:
        A ``trimesh.Trimesh`` built from the first sample's vertex and face
        arrays. If simplification yields no faces, the unsimplified mesh is
        returned instead.
    """
    print("[DEBUG] Preparing image...")
    img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)

    print("[DEBUG] Running TripoSG pipeline...")
    outputs = pipe(
        image=img_pil,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).samples[0]
    # samples[0] is indexed as (vertices, faces) below.
    vertices, triangles = outputs[0], outputs[1]
    print("[DEBUG] TripoSG output keys:", type(outputs), vertices.shape, triangles.shape)

    mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(triangles))
    print(f"[DEBUG] Mesh created: {mesh.vertices.shape[0]} verts / {mesh.faces.shape[0]} faces")

    if faces > 0:
        print(f"[DEBUG] Simplifying mesh to {faces} faces")
        simplified = simplify_mesh(mesh, faces)
        # simplify_mesh signals an empty decimation result with None; keep the
        # original mesh rather than handing None to callers (e.g. `.export`).
        if simplified is None:
            print("[DEBUG] Simplification produced an empty mesh; keeping the original")
        else:
            mesh = simplified
    return mesh
def mesh_to_pymesh(vertices, faces):
    """Wrap raw vertex/face arrays in a pymeshlab ``MeshSet`` holding one mesh.

    Args:
        vertices: vertex coordinate matrix passed as ``vertex_matrix``.
        faces: face index matrix passed as ``face_matrix``.

    Returns:
        A new ``pymeshlab.MeshSet`` whose current mesh is built from the arrays.
    """
    mesh_set = pymeshlab.MeshSet()
    mesh_set.add_mesh(pymeshlab.Mesh(vertex_matrix=vertices, face_matrix=faces))
    return mesh_set
def pymesh_to_trimesh(mesh):
    """Convert a ``pymeshlab.Mesh`` back into a ``trimesh.Trimesh``.

    Args:
        mesh: pymeshlab mesh exposing ``vertex_matrix()`` and ``face_matrix()``.

    Returns:
        A ``trimesh.Trimesh`` built from those matrices.
    """
    return trimesh.Trimesh(vertices=mesh.vertex_matrix(), faces=mesh.face_matrix())
# def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
# if mesh.faces.shape[0] > n_faces:
# ms = mesh_to_pymesh(mesh.vertices, mesh.faces)
# ms.meshing_merge_close_vertices()
# ms.meshing_decimation_quadric_edge_collapse(targetfacenum = n_faces)
# return pymesh_to_trimesh(ms.current_mesh())
# else:
# return mesh
def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
    """Decimate ``mesh`` toward ``n_faces`` triangles using pymeshlab.

    Args:
        mesh: input mesh.
        n_faces: target face count. A value <= 0 means "no target", and the
            mesh is returned unchanged instead of being forwarded to the
            decimation filter.

    Returns:
        The input ``mesh`` itself when ``n_faces <= 0`` or the mesh already has
        at most ``n_faces`` faces; otherwise a new ``trimesh.Trimesh`` with the
        decimated geometry. Returns ``None`` when decimation leaves no faces,
        so callers must handle that case.
    """
    if n_faces <= 0 or mesh.faces.shape[0] <= n_faces:
        return mesh

    ms = mesh_to_pymesh(mesh.vertices, mesh.faces)
    # Weld near-coincident vertices before collapsing edges; presumably this
    # lets the decimation treat duplicated seam vertices as shared.
    ms.meshing_merge_close_vertices()
    ms.meshing_decimation_quadric_edge_collapse(targetfacenum=n_faces)

    simplified = ms.current_mesh()
    if simplified is None or simplified.face_number() == 0:
        return None
    return pymesh_to_trimesh(simplified)
if __name__ == "__main__":
    # Models are placed on CUDA; the TripoSG pipeline runs in half precision.
    device = "cuda"
    dtype = torch.float16

    cli = argparse.ArgumentParser()
    cli.add_argument("--image-input", type=str, required=True)
    cli.add_argument("--output-path", type=str, default="./output.glb")
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument("--num-inference-steps", type=int, default=50)
    cli.add_argument("--guidance-scale", type=float, default=7.0)
    cli.add_argument("--faces", type=int, default=-1)
    args = cli.parse_args()

    # (hub repo id, local directory) for each model; TripoSG is fetched first.
    weights = {
        "triposg": ("VAST-AI/TripoSG", "pretrained_weights/TripoSG"),
        "rmbg": ("briaai/RMBG-1.4", "pretrained_weights/RMBG-1.4"),
    }
    for repo_id, local_dir in weights.values():
        snapshot_download(repo_id=repo_id, local_dir=local_dir)

    # Background-removal network handed to prepare_image inside run_triposg.
    background_remover = BriaRMBG.from_pretrained(weights["rmbg"][1]).to(device)
    background_remover.eval()

    triposg_pipe: TripoSGPipeline = TripoSGPipeline.from_pretrained(weights["triposg"][1]).to(device, dtype)

    generated_mesh = run_triposg(
        triposg_pipe,
        image_input=args.image_input,
        rmbg_net=background_remover,
        seed=args.seed,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        faces=args.faces,
    )
    generated_mesh.export(args.output_path)
    print(f"Mesh saved to {args.output_path}")