Spaces:
mashroo
/
Running on Zero

YoussefAnso commited on
Commit
a14c9ce
·
1 Parent(s): 2cf3ed2

Refactor generate3d function in inference.py to enhance mesh processing and export logic. Streamlined image tensor handling, improved denoising integration, and updated mesh export to include UV mapping. Ensured proper handling of temporary files for OBJ and GLB formats, enhancing overall readability and maintainability.

Browse files
Files changed (1) hide show
  1. inference.py +40 -43
inference.py CHANGED
@@ -1,32 +1,21 @@
1
- import numpy as np
2
- import torch
3
- import time
4
- import tempfile
5
  import trimesh
6
-
7
- from util.utils import get_tri
8
- from util.renderer import Renderer
 
9
 
10
  def generate3d(model, rgb, ccm, device):
11
- model.renderer = Renderer(
12
- tet_grid_size=model.tet_grid_size,
13
- camera_angle_num=model.camera_angle_num,
14
- scale=model.input.scale,
15
- geo_type=model.geo_type
16
- )
17
 
18
  color_tri = torch.from_numpy(rgb) / 255
19
  xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
20
-
21
  color = color_tri.permute(2, 0, 1)
22
  xyz = xyz_tri.permute(2, 0, 1)
23
 
24
- def get_imgs(img_tensor):
25
- images = []
26
- images.append(img_tensor[:, :, 256*5:256*(1+5)])
27
- for i in range(5):
28
- images.append(img_tensor[:, :, 256*i:256*(i+1)])
29
- return torch.stack(images, dim=0) # [6, C, H, W]
30
 
31
  triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
32
  color = get_imgs(color)
@@ -34,41 +23,49 @@ def generate3d(model, rgb, ccm, device):
34
 
35
  color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
36
  xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
37
-
38
  triplane = torch.cat([color, xyz], dim=1).to(device)
39
 
40
  model.eval()
41
-
42
- with torch.no_grad():
43
- if model.denoising:
44
- tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
45
- noise_new = torch.randn_like(triplane) * 0.5 + 0.5
46
- triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
47
  triplane_feature2 = model.unet2(triplane, tnew)
48
- else:
 
49
  triplane_feature2 = model.unet2(triplane)
50
 
 
 
 
 
 
51
  with torch.no_grad():
52
- data_config = {
53
- 'resolution': [1024, 1024],
54
- "triview_color": triplane_color.to(device),
55
- }
56
  verts, faces = model.decode(data_config, triplane_feature2)
 
 
57
 
58
- from kiui.mesh_utils import clean_mesh
59
- verts_np, faces_np = clean_mesh(
60
- verts[0].squeeze().cpu().numpy().astype(np.float32),
61
- faces.squeeze().cpu().numpy().astype(np.int32),
62
  repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
63
  )
 
 
64
 
65
- # === Generate per-vertex color (approximate)
66
- colors = np.tile(np.mean(rgb.reshape(-1, 3), axis=0, keepdims=True), (verts_np.shape[0], 1)) / 255.0
67
- # Optionally, use more sophisticated color mapping logic if you have UVs
68
 
69
- mesh = trimesh.Trimesh(vertices=verts_np, faces=faces_np, vertex_colors=colors, process=False)
 
 
 
 
70
 
71
- glb_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False).name
72
- mesh.export(glb_path)
 
73
 
74
- return glb_path
 
1
+ from kiui.mesh_utils import clean_mesh
 
 
 
2
  import trimesh
3
+ import zipfile
4
+ import tempfile
5
+ import os
6
+ import nvdiffrast.torch as dr
7
 
8
  def generate3d(model, rgb, ccm, device):
9
+ model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
10
+ scale=model.input.scale, geo_type=model.geo_type)
 
 
 
 
11
 
12
  color_tri = torch.from_numpy(rgb) / 255
13
  xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
 
14
  color = color_tri.permute(2, 0, 1)
15
  xyz = xyz_tri.permute(2, 0, 1)
16
 
17
+ def get_imgs(color):
18
+ return torch.stack([color[:, :, 256 * i:256 * (i + 1)] for i in [5, 0, 1, 2, 3, 4]], dim=0)
 
 
 
 
19
 
20
  triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
21
  color = get_imgs(color)
 
23
 
24
  color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
25
  xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
 
26
  triplane = torch.cat([color, xyz], dim=1).to(device)
27
 
28
  model.eval()
29
+ if model.denoising:
30
+ tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
31
+ noise_new = torch.randn_like(triplane) * 0.5 + 0.5
32
+ triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
33
+ with torch.no_grad():
 
34
  triplane_feature2 = model.unet2(triplane, tnew)
35
+ else:
36
+ with torch.no_grad():
37
  triplane_feature2 = model.unet2(triplane)
38
 
39
+ data_config = {
40
+ 'resolution': [1024, 1024],
41
+ "triview_color": triplane_color.to(device),
42
+ }
43
+
44
  with torch.no_grad():
 
 
 
 
45
  verts, faces = model.decode(data_config, triplane_feature2)
46
+ data_config['verts'] = verts[0]
47
+ data_config['faces'] = faces
48
 
49
+ verts, faces = clean_mesh(
50
+ data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
51
+ data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
 
52
  repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
53
  )
54
+ data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
55
+ data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()
56
 
57
+ # Create base filename
58
+ temp_path = tempfile.NamedTemporaryFile(suffix="", delete=False).name
59
+ obj_base = temp_path # no extension
60
 
61
+ # Export mesh with UV and PNG
62
+ glctx = dr.RasterizeCudaContext()
63
+ model.export_mesh_wt_uv(
64
+ glctx, data_config, obj_base, "", device, res=(1024, 1024), tri_fea_2=triplane_feature2
65
+ )
66
 
67
+ # Load .obj with texture and export .glb
68
+ mesh = trimesh.load(obj_base + ".obj", process=False)
69
+ mesh.export(obj_base + ".glb")
70
 
71
+ return obj_base + ".glb"