Spaces:
Sleeping
Sleeping
File size: 5,077 Bytes
b2a27a7 0dcf605 b2a27a7 0dcf605 b2a27a7 0dcf605 b2a27a7 0dcf605 d22bdf6 0dcf605 b2a27a7 0dcf605 b2a27a7 d22bdf6 0dcf605 d22bdf6 0dcf605 d22bdf6 b2a27a7 0dcf605 b2a27a7 144fa82 b2a27a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import argparse
import os
import sys
from glob import glob
from typing import Any, Union
import numpy as np
import torch
import trimesh
from huggingface_hub import snapshot_download
from PIL import Image
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from triposg.pipelines.pipeline_triposg import TripoSGPipeline
from image_process import prepare_image
from briarmbg import BriaRMBG
import pymeshlab
@torch.no_grad()
def run_triposg(
    pipe: Any,
    image_input: Union[str, Image.Image],
    rmbg_net: Any,
    seed: int,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    faces: int = -1,
    octree_depth: int = 9,
) -> trimesh.Trimesh:
    """Generate a triangle mesh from a single image with a TripoSG pipeline.

    Args:
        pipe: Loaded TripoSG pipeline (the script passes a ``TripoSGPipeline``).
        image_input: Image file path or PIL image, handed to ``prepare_image``.
        rmbg_net: Background-removal network forwarded to ``prepare_image``.
        seed: Seed for the ``torch.Generator`` created on ``pipe.device``.
        num_inference_steps: Forwarded to the pipeline as ``num_inference_steps``.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        faces: Target face count for decimation; values <= 0 skip simplification.
        octree_depth: Forwarded to the pipeline as ``flash_octree_depth``.

    Returns:
        The generated ``trimesh.Trimesh``. If simplification was requested but
        ``simplify_mesh`` reports an empty result (``None``), the unsimplified
        mesh is returned instead so callers always receive a mesh.
    """
    print("[DEBUG] Preparing image...")
    # White background (RGB 1.0, 1.0, 1.0) is passed to prepare_image.
    img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)

    print("[DEBUG] Running TripoSG pipeline...")
    outputs = pipe(
        image=img_pil,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        flash_octree_depth=octree_depth,
    ).samples[0]
    # outputs is indexed below as (vertices, faces); presumably the pipeline's
    # per-sample mesh arrays — NOTE(review): confirm against TripoSGPipeline.
    print("[DEBUG] TripoSG output keys:", type(outputs), outputs[0].shape, outputs[1].shape)

    mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
    print(f"[DEBUG] Mesh created: {mesh.vertices.shape[0]} verts / {mesh.faces.shape[0]} faces")

    if faces > 0:
        print(f"[DEBUG] Simplifying mesh to {faces} faces")
        simplified = simplify_mesh(mesh, faces)
        if simplified is None:
            # simplify_mesh returns None when decimation leaves no faces; returning
            # that would break callers such as ``run_triposg(...).export(...)``.
            print("[DEBUG] Simplification produced an empty mesh; keeping the original mesh")
        else:
            mesh = simplified
    return mesh
def mesh_to_pymesh(vertices, faces):
    """Wrap vertex and face arrays in a fresh ``pymeshlab.MeshSet``.

    Args:
        vertices: Passed to ``pymeshlab.Mesh`` as ``vertex_matrix``.
        faces: Passed to ``pymeshlab.Mesh`` as ``face_matrix``.

    Returns:
        A ``pymeshlab.MeshSet`` holding a single mesh built from the inputs.
    """
    mesh_set = pymeshlab.MeshSet()
    mesh_set.add_mesh(pymeshlab.Mesh(vertex_matrix=vertices, face_matrix=faces))
    return mesh_set
def pymesh_to_trimesh(mesh):
    """Convert a ``pymeshlab.Mesh`` into a ``trimesh.Trimesh``.

    Only geometry is carried over: the mesh's vertex matrix and face matrix.
    """
    return trimesh.Trimesh(
        vertices=mesh.vertex_matrix(),
        faces=mesh.face_matrix(),
    )
# old version
# def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
# if mesh.faces.shape[0] > n_faces:
# ms = mesh_to_pymesh(mesh.vertices, mesh.faces)
# ms.meshing_merge_close_vertices()
# ms.meshing_decimation_quadric_edge_collapse(targetfacenum = n_faces)
# return pymesh_to_trimesh(ms.current_mesh())
# else:
# return mesh
# new version
# def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
# if mesh.faces.shape[0] > n_faces:
# ms = mesh_to_pymesh(mesh.vertices, mesh.faces)
# ms.meshing_merge_close_vertices()
# ms.meshing_decimation_quadric_edge_collapse(targetfacenum=n_faces)
# simplified = ms.current_mesh()
# if simplified is None or simplified.face_number() == 0:
# return None
# return pymesh_to_trimesh(simplified)
# return mesh
# Current version: same as the commented-out version above, but also logs the face-count reduction.
def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
    """Decimate ``mesh`` toward ``n_faces`` faces using pymeshlab.

    A mesh that does not have more than ``n_faces`` faces is returned unchanged
    (the same object). Otherwise close vertices are merged and quadric
    edge-collapse decimation runs with ``targetfacenum=n_faces``.

    Returns:
        A new ``trimesh.Trimesh`` with the decimated geometry; the input mesh
        itself when no decimation is needed; or ``None`` when the decimated
        mesh ends up with zero faces.
    """
    face_count_before = mesh.faces.shape[0]  # kept for the before/after log line
    if not (face_count_before > n_faces):
        return mesh

    mesh_set = mesh_to_pymesh(mesh.vertices, mesh.faces)
    mesh_set.meshing_merge_close_vertices()
    mesh_set.meshing_decimation_quadric_edge_collapse(targetfacenum=n_faces)

    result = mesh_set.current_mesh()
    if result is None or result.face_number() == 0:
        return None

    face_count_after = result.face_number()
    print(f"[DEBUG] Simplified mesh: {face_count_before} → {face_count_after} faces")
    return pymesh_to_trimesh(result)
if __name__ == "__main__":
    # Inference runs on CUDA in half precision.
    device = "cuda"
    dtype = torch.float16

    parser = argparse.ArgumentParser()
    parser.add_argument("--image-input", type=str, required=True)
    parser.add_argument("--output-path", type=str, default="./output.glb")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-inference-steps", type=int, default=50)
    parser.add_argument("--guidance-scale", type=float, default=7.0)
    parser.add_argument("--faces", type=int, default=-1)
    # Forwarded to run_triposg(octree_depth=...); the default matches that
    # function's own default so omitting the flag keeps prior behavior.
    parser.add_argument("--octree-depth", type=int, default=9)
    args = parser.parse_args()

    # Download pretrained weights into local directories.
    triposg_weights_dir = "pretrained_weights/TripoSG"
    rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
    snapshot_download(repo_id="VAST-AI/TripoSG", local_dir=triposg_weights_dir)
    snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)

    # Background-removal model used by prepare_image.
    rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(device)
    rmbg_net.eval()

    # TripoSG pipeline, moved to the target device and dtype.
    pipe: TripoSGPipeline = TripoSGPipeline.from_pretrained(triposg_weights_dir).to(device, dtype)

    # Run inference and write the resulting mesh to disk.
    run_triposg(
        pipe,
        image_input=args.image_input,
        rmbg_net=rmbg_net,
        seed=args.seed,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        faces=args.faces,
        octree_depth=args.octree_depth,
    ).export(args.output_path)
    print(f"Mesh saved to {args.output_path}")
|