Spaces:
mashroo
/
Runtime error

YoussefAnso commited on
Commit
76eeb7d
·
1 Parent(s): dfab55e
Files changed (1) hide show
  1. inference.py +52 -32
inference.py CHANGED
@@ -1,13 +1,15 @@
1
  import numpy as np
2
  import torch
3
  import time
4
- import nvdiffrast.torch as dr
5
- from util.utils import get_tri
6
  import tempfile
7
- from mesh import Mesh
 
8
  import zipfile
 
 
 
 
9
  from util.renderer import Renderer
10
- import trimesh # Needed for glb export
11
 
12
  def generate3d(model, rgb, ccm, device):
13
  model.renderer = Renderer(
@@ -17,17 +19,18 @@ def generate3d(model, rgb, ccm, device):
17
  geo_type=model.geo_type
18
  )
19
 
20
- # RGB and coordinate conversion
21
  color_tri = torch.from_numpy(rgb) / 255
22
  xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
 
23
  color = color_tri.permute(2, 0, 1)
24
  xyz = xyz_tri.permute(2, 0, 1)
25
 
26
- def get_imgs(color):
27
- color_list = [color[:, :, 256 * 5:256 * (1 + 5)]]
28
- for i in range(0, 5):
29
- color_list.append(color[:, :, 256 * i:256 * (1 + i)])
30
- return torch.stack(color_list, dim=0)
 
31
 
32
  triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
33
  color = get_imgs(color)
@@ -47,19 +50,20 @@ def generate3d(model, rgb, ccm, device):
47
  with torch.no_grad():
48
  triplane_feature2 = model.unet2(triplane, tnew)
49
  else:
50
- triplane_feature2 = model.unet2(triplane)
 
51
 
52
- with torch.no_grad():
53
- data_config = {
54
- 'resolution': [1024, 1024],
55
- 'triview_color': triplane_color.to(device),
56
- }
57
 
 
58
  verts, faces = model.decode(data_config, triplane_feature2)
59
  data_config['verts'] = verts[0]
60
  data_config['faces'] = faces
61
 
62
- # Optional mesh cleanup (reduce remesh for speed)
63
  from kiui.mesh_utils import clean_mesh
64
  verts, faces = clean_mesh(
65
  data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
@@ -69,23 +73,39 @@ def generate3d(model, rgb, ccm, device):
69
  data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
70
  data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()
71
 
72
- # Rasterization context
73
- glctx = dr.RasterizeGLContext()
74
-
75
- # Temporary output path
76
  mesh_path_obj = tempfile.NamedTemporaryFile(suffix="", delete=False).name
 
 
 
 
 
 
 
 
 
77
 
78
- # Export OBJ with UV and PNG
79
- with torch.no_grad():
80
- model.export_mesh_wt_uv(
81
- glctx, data_config, mesh_path_obj, "", device,
82
- res=(512, 512), tri_fea_2=triplane_feature2
83
- )
84
 
85
- # Convert to .glb using trimesh
86
- mesh = trimesh.load(mesh_path_obj + ".obj", force='mesh')
87
- mesh_path_glb = mesh_path_obj + ".glb"
88
- mesh.export(mesh_path_glb, file_type='glb')
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- print(f"✅ Exported GLB with UV texture: {mesh_path_glb}")
91
  return mesh_path_glb
 
1
  import numpy as np
2
  import torch
3
  import time
 
 
4
  import tempfile
5
+ import os
6
+ from PIL import Image
7
  import zipfile
8
+ import trimesh
9
+
10
+ from util.utils import get_tri
11
+ from mesh import Mesh
12
  from util.renderer import Renderer
 
13
 
14
  def generate3d(model, rgb, ccm, device):
15
  model.renderer = Renderer(
 
19
  geo_type=model.geo_type
20
  )
21
 
 
22
  color_tri = torch.from_numpy(rgb) / 255
23
  xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
24
+
25
  color = color_tri.permute(2, 0, 1)
26
  xyz = xyz_tri.permute(2, 0, 1)
27
 
28
+ def get_imgs(img_tensor):
29
+ images = []
30
+ images.append(img_tensor[:, :, 256*5:256*(1+5)])
31
+ for i in range(5):
32
+ images.append(img_tensor[:, :, 256*i:256*(i+1)])
33
+ return torch.stack(images, dim=0) # [6, C, H, W]
34
 
35
  triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
36
  color = get_imgs(color)
 
50
  with torch.no_grad():
51
  triplane_feature2 = model.unet2(triplane, tnew)
52
  else:
53
+ with torch.no_grad():
54
+ triplane_feature2 = model.unet2(triplane)
55
 
56
+ data_config = {
57
+ 'resolution': [1024, 1024],
58
+ "triview_color": triplane_color.to(device),
59
+ }
 
60
 
61
+ with torch.no_grad():
62
  verts, faces = model.decode(data_config, triplane_feature2)
63
  data_config['verts'] = verts[0]
64
  data_config['faces'] = faces
65
 
66
+ # Clean the mesh
67
  from kiui.mesh_utils import clean_mesh
68
  verts, faces = clean_mesh(
69
  data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
 
73
  data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
74
  data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()
75
 
76
+ # Export mesh with UV
 
 
 
77
  mesh_path_obj = tempfile.NamedTemporaryFile(suffix="", delete=False).name
78
+ model.export_mesh_wt_uv(
79
+ None, # GL context not needed for CPU export
80
+ data_config,
81
+ mesh_path_obj,
82
+ "",
83
+ device,
84
+ res=(1024, 1024),
85
+ tri_fea_2=triplane_feature2
86
+ )
87
 
88
+ # Check if texture exists
89
+ texture_path = mesh_path_obj + ".png"
90
+ if not os.path.exists(texture_path):
91
+ raise RuntimeError("Texture image not created, cannot export textured GLB.")
 
 
92
 
93
+ # Load the .obj file and apply the texture manually
94
+ scene_or_mesh = trimesh.load(mesh_path_obj + ".obj", force='scene')
95
+
96
+ texture_image = Image.open(texture_path)
97
+ texture_image = np.array(texture_image)
98
+
99
+ mesh_path_glb = mesh_path_obj + "_textured.glb"
100
+
101
+ if isinstance(scene_or_mesh, trimesh.Scene):
102
+ for _, geometry in scene_or_mesh.geometry.items():
103
+ material = trimesh.visual.texture.SimpleMaterial(image=texture_image)
104
+ geometry.visual = trimesh.visual.texture.TextureVisuals(image=texture_image, material=material)
105
+ scene_or_mesh.export(mesh_path_glb)
106
+ else:
107
+ material = trimesh.visual.texture.SimpleMaterial(image=texture_image)
108
+ scene_or_mesh.visual = trimesh.visual.texture.TextureVisuals(image=texture_image, material=material)
109
+ scene_or_mesh.export(mesh_path_glb)
110
 
 
111
  return mesh_path_glb