Spaces:
mashroo
/
Running on Zero

YoussefAnso commited on
Commit
2cf3ed2
·
1 Parent(s): 30b6055

Refactor generate3d function in inference.py to streamline noise addition and improve mesh processing. Enhanced color generation for vertices and updated GLB export logic, ensuring efficient handling of mesh data and improved readability.

Browse files
Files changed (1) hide show
  1. inference.py +19 -47
inference.py CHANGED
@@ -2,12 +2,9 @@ import numpy as np
2
  import torch
3
  import time
4
  import tempfile
5
- import os
6
- from PIL import Image
7
  import trimesh
8
 
9
  from util.utils import get_tri
10
- from mesh import Mesh
11
  from util.renderer import Renderer
12
 
13
  def generate3d(model, rgb, ccm, device):
@@ -42,61 +39,36 @@ def generate3d(model, rgb, ccm, device):
42
 
43
  model.eval()
44
 
45
- if model.denoising:
46
- tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
47
- noise_new = torch.randn_like(triplane) * 0.5 + 0.5
48
- triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
49
- with torch.no_grad():
50
  triplane_feature2 = model.unet2(triplane, tnew)
51
- else:
52
- with torch.no_grad():
53
  triplane_feature2 = model.unet2(triplane)
54
 
55
- data_config = {
56
- 'resolution': [1024, 1024],
57
- "triview_color": triplane_color.to(device),
58
- }
59
-
60
  with torch.no_grad():
 
 
 
 
61
  verts, faces = model.decode(data_config, triplane_feature2)
62
- data_config['verts'] = verts[0]
63
- data_config['faces'] = faces
64
 
65
  from kiui.mesh_utils import clean_mesh
66
- verts, faces = clean_mesh(
67
- data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
68
- data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
69
  repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
70
  )
71
- data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
72
- data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()
73
-
74
- # Export to .obj with UV and texture image
75
- mesh_path_obj = tempfile.NamedTemporaryFile(suffix="", delete=False).name
76
- model.export_mesh_wt_uv(
77
- None, # GL context skipped (no dr.RasterizeGLContext)
78
- data_config,
79
- mesh_path_obj,
80
- "",
81
- device,
82
- res=(1024, 1024),
83
- tri_fea_2=triplane_feature2
84
- )
85
-
86
- # Convert to GLB with embedded texture
87
- obj_path = mesh_path_obj + ".obj"
88
- texture_path = mesh_path_obj + ".png"
89
-
90
- if not os.path.exists(obj_path) or not os.path.exists(texture_path):
91
- raise RuntimeError("OBJ or texture file missing — cannot convert to GLB.")
92
 
93
- mesh_scene = trimesh.load(obj_path, force='scene')
94
- texture_image = np.array(Image.open(texture_path))
 
95
 
96
- for name, geom in mesh_scene.geometry.items():
97
- geom.visual = trimesh.visual.texture.TextureVisuals(image=texture_image)
98
 
99
- glb_path = mesh_path_obj + "_final.glb"
100
- mesh_scene.export(glb_path)
101
 
102
  return glb_path
 
2
  import torch
3
  import time
4
  import tempfile
 
 
5
  import trimesh
6
 
7
  from util.utils import get_tri
 
8
  from util.renderer import Renderer
9
 
10
  def generate3d(model, rgb, ccm, device):
 
39
 
40
  model.eval()
41
 
42
+ with torch.no_grad():
43
+ if model.denoising:
44
+ tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
45
+ noise_new = torch.randn_like(triplane) * 0.5 + 0.5
46
+ triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
47
  triplane_feature2 = model.unet2(triplane, tnew)
48
+ else:
 
49
  triplane_feature2 = model.unet2(triplane)
50
 
 
 
 
 
 
51
  with torch.no_grad():
52
+ data_config = {
53
+ 'resolution': [1024, 1024],
54
+ "triview_color": triplane_color.to(device),
55
+ }
56
  verts, faces = model.decode(data_config, triplane_feature2)
 
 
57
 
58
  from kiui.mesh_utils import clean_mesh
59
+ verts_np, faces_np = clean_mesh(
60
+ verts[0].squeeze().cpu().numpy().astype(np.float32),
61
+ faces.squeeze().cpu().numpy().astype(np.int32),
62
  repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
63
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ # === Generate per-vertex color (approximate)
66
+ colors = np.tile(np.mean(rgb.reshape(-1, 3), axis=0, keepdims=True), (verts_np.shape[0], 1)) / 255.0
67
+ # Optionally, use more sophisticated color mapping logic if you have UVs
68
 
69
+ mesh = trimesh.Trimesh(vertices=verts_np, faces=faces_np, vertex_colors=colors, process=False)
 
70
 
71
+ glb_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False).name
72
+ mesh.export(glb_path)
73
 
74
  return glb_path