staswrs commited on
Commit
0dcf605
·
1 Parent(s): 95553e7

add octree depth controls fix 2

Browse files
Files changed (2) hide show
  1. app.py +13 -13
  2. inference_triposg.py +26 -30
app.py CHANGED
@@ -67,10 +67,10 @@ rmbg_net = BriaRMBG.from_pretrained(rmbg_path).to(device)
67
  rmbg_net.eval()
68
 
69
 
70
- def generate(image_path, face_number=50000, guidance_scale=5.0, num_steps=25):
71
- # def generate(image_path, face_number=50000, guidance_scale=5.0, num_steps=25, octree_depth=9):
72
- # print(f"[INPUT] face_number={face_number}, guidance_scale={guidance_scale}, num_steps={num_steps}, octree_depth={octree_depth}")# 👈 добавлено
73
- print(f"[INPUT] face_number={face_number}, guidance_scale={guidance_scale}, num_steps={num_steps}")# 👈 добавлено
74
  print("[API CALL] image_path received:", image_path)
75
  print("[API CALL] File exists:", os.path.exists(image_path))
76
 
@@ -87,7 +87,7 @@ def generate(image_path, face_number=50000, guidance_scale=5.0, num_steps=25):
87
  num_inference_steps=int(num_steps),
88
  guidance_scale=float(guidance_scale),
89
  faces=int(face_number),
90
- # octree_depth=int(octree_depth), # 👈 добавлено
91
  )
92
 
93
  if mesh is None or mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
@@ -129,14 +129,14 @@ def generate(image_path, face_number=50000, guidance_scale=5.0, num_steps=25):
129
  # Интерфейс Gradio
130
  demo = gr.Interface(
131
  fn=generate,
132
- inputs=gr.Image(type="filepath", label="Upload image"),
133
- # inputs=[
134
- # gr.Image(type="filepath", label="Upload image"),
135
- # gr.Slider(10000, 150000, step=10000, value=50000, label="Face count"),
136
- # gr.Slider(1.0, 10.0, step=0.5, value=5.0, label="Guidance Scale"),
137
- # gr.Slider(10, 100, step=5, value=25, label="Steps"),
138
- # gr.Slider(6, 10, step=1, value=9, label="Octree Depth"),
139
- # ], # 👈 добавлено
140
  outputs=gr.File(label="Download .glb"),
141
  title="TripoSG Image to 3D",
142
  description="Upload an image to generate a 3D model (.glb)",
 
67
  rmbg_net.eval()
68
 
69
 
70
+ # def generate(image_path, face_number=50000, guidance_scale=5.0, num_steps=25):
71
+ def generate(image_path, face_number=50000, guidance_scale=5.0, num_steps=25, octree_depth=9):
72
+ print(f"[INPUT] face_number={face_number}, guidance_scale={guidance_scale}, num_steps={num_steps}, octree_depth={octree_depth}")# 👈 добавлено_et
73
+ # print(f"[INPUT] face_number={face_number}, guidance_scale={guidance_scale}, num_steps={num_steps}")# 👈 добавлено_et
74
  print("[API CALL] image_path received:", image_path)
75
  print("[API CALL] File exists:", os.path.exists(image_path))
76
 
 
87
  num_inference_steps=int(num_steps),
88
  guidance_scale=float(guidance_scale),
89
  faces=int(face_number),
90
+ octree_depth=int(octree_depth), # 👈 добавлено_et
91
  )
92
 
93
  if mesh is None or mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
 
129
  # Интерфейс Gradio
130
  demo = gr.Interface(
131
  fn=generate,
132
+ # inputs=gr.Image(type="filepath", label="Upload image"),
133
+ inputs=[
134
+ gr.Image(type="filepath", label="Upload image"),
135
+ gr.Slider(10000, 150000, step=10000, value=50000, label="Face count"),
136
+ gr.Slider(1.0, 10.0, step=0.5, value=5.0, label="Guidance Scale"),
137
+ gr.Slider(10, 100, step=5, value=25, label="Steps"),
138
+ gr.Slider(6, 9, step=1, value=9, label="Octree Depth"),
139
+ ], # 👈 добавлено
140
  outputs=gr.File(label="Download .glb"),
141
  title="TripoSG Image to 3D",
142
  description="Upload an image to generate a 3D model (.glb)",
inference_triposg.py CHANGED
@@ -19,32 +19,6 @@ from briarmbg import BriaRMBG
19
  import pymeshlab
20
 
21
 
22
- # @torch.no_grad()
23
- # def run_triposg(
24
- # pipe: Any,
25
- # image_input: Union[str, Image.Image],
26
- # rmbg_net: Any,
27
- # seed: int,
28
- # num_inference_steps: int = 50,
29
- # guidance_scale: float = 7.0,
30
- # faces: int = -1,
31
- # ) -> trimesh.Scene:
32
-
33
- # img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
34
-
35
- # outputs = pipe(
36
- # image=img_pil,
37
- # generator=torch.Generator(device=pipe.device).manual_seed(seed),
38
- # num_inference_steps=num_inference_steps,
39
- # guidance_scale=guidance_scale,
40
- # ).samples[0]
41
- # mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
42
-
43
- # if faces > 0:
44
- # mesh = simplify_mesh(mesh, faces)
45
-
46
- # return mesh
47
-
48
  @torch.no_grad()
49
  def run_triposg(
50
  pipe: Any,
@@ -54,7 +28,7 @@ def run_triposg(
54
  num_inference_steps: int = 50,
55
  guidance_scale: float = 7.0,
56
  faces: int = -1,
57
- # octree_depth: int = 9, # 👈 добавлено
58
  ) -> trimesh.Scene:
59
  print("[DEBUG] Preparing image...")
60
  img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
@@ -65,7 +39,7 @@ def run_triposg(
65
  generator=torch.Generator(device=pipe.device).manual_seed(seed),
66
  num_inference_steps=num_inference_steps,
67
  guidance_scale=guidance_scale,
68
- # flash_octree_depth=octree_depth, # 👈 добавлено
69
  ).samples[0]
70
 
71
  print("[DEBUG] TripoSG output keys:", type(outputs), outputs[0].shape, outputs[1].shape)
@@ -75,7 +49,7 @@ def run_triposg(
75
 
76
  if faces > 0:
77
  print(f"[DEBUG] Simplifying mesh to {faces} faces")
78
- # mesh = simplify_mesh(mesh, faces)
79
 
80
  return mesh
81
 
@@ -91,6 +65,7 @@ def pymesh_to_trimesh(mesh):
91
  faces = mesh.face_matrix()#.tolist()
92
  return trimesh.Trimesh(vertices=verts, faces=faces) #, vID, fID
93
 
 
94
  # def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
95
  # if mesh.faces.shape[0] > n_faces:
96
  # ms = mesh_to_pymesh(mesh.vertices, mesh.faces)
@@ -100,17 +75,38 @@ def pymesh_to_trimesh(mesh):
100
  # else:
101
  # return mesh
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
104
- if mesh.faces.shape[0] > n_faces:
 
 
105
  ms = mesh_to_pymesh(mesh.vertices, mesh.faces)
106
  ms.meshing_merge_close_vertices()
107
  ms.meshing_decimation_quadric_edge_collapse(targetfacenum=n_faces)
108
  simplified = ms.current_mesh()
109
  if simplified is None or simplified.face_number() == 0:
110
  return None
 
 
 
 
111
  return pymesh_to_trimesh(simplified)
 
112
  return mesh
113
 
 
114
  if __name__ == "__main__":
115
  device = "cuda"
116
  dtype = torch.float16
 
19
  import pymeshlab
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  @torch.no_grad()
23
  def run_triposg(
24
  pipe: Any,
 
28
  num_inference_steps: int = 50,
29
  guidance_scale: float = 7.0,
30
  faces: int = -1,
31
+ octree_depth: int = 9, # 👈 добавлено_et
32
  ) -> trimesh.Scene:
33
  print("[DEBUG] Preparing image...")
34
  img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
 
39
  generator=torch.Generator(device=pipe.device).manual_seed(seed),
40
  num_inference_steps=num_inference_steps,
41
  guidance_scale=guidance_scale,
42
+ flash_octree_depth=octree_depth, # 👈 добавлено_et
43
  ).samples[0]
44
 
45
  print("[DEBUG] TripoSG output keys:", type(outputs), outputs[0].shape, outputs[1].shape)
 
49
 
50
  if faces > 0:
51
  print(f"[DEBUG] Simplifying mesh to {faces} faces")
52
+ mesh = simplify_mesh(mesh, faces) # 👈 добавлено_et
53
 
54
  return mesh
55
 
 
65
  faces = mesh.face_matrix()#.tolist()
66
  return trimesh.Trimesh(vertices=verts, faces=faces) #, vID, fID
67
 
68
+ # old version
69
  # def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
70
  # if mesh.faces.shape[0] > n_faces:
71
  # ms = mesh_to_pymesh(mesh.vertices, mesh.faces)
 
75
  # else:
76
  # return mesh
77
 
78
+ # new version
79
+ # def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
80
+ # if mesh.faces.shape[0] > n_faces:
81
+ # ms = mesh_to_pymesh(mesh.vertices, mesh.faces)
82
+ # ms.meshing_merge_close_vertices()
83
+ # ms.meshing_decimation_quadric_edge_collapse(targetfacenum=n_faces)
84
+ # simplified = ms.current_mesh()
85
+ # if simplified is None or simplified.face_number() == 0:
86
+ # return None
87
+ # return pymesh_to_trimesh(simplified)
88
+ # return mesh
89
+
90
+ # new version demo
91
  def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
92
+ original_faces = mesh.faces.shape[0] # 👈 сохраняем исходное количество
93
+
94
+ if original_faces > n_faces:
95
  ms = mesh_to_pymesh(mesh.vertices, mesh.faces)
96
  ms.meshing_merge_close_vertices()
97
  ms.meshing_decimation_quadric_edge_collapse(targetfacenum=n_faces)
98
  simplified = ms.current_mesh()
99
  if simplified is None or simplified.face_number() == 0:
100
  return None
101
+
102
+ simplified_faces = simplified.face_number()
103
+ print(f"[DEBUG] Simplified mesh: {original_faces} → {simplified_faces} faces") # 👈 лог здесь
104
+
105
  return pymesh_to_trimesh(simplified)
106
+
107
  return mesh
108
 
109
+
110
  if __name__ == "__main__":
111
  device = "cuda"
112
  dtype = torch.float16