Spaces:
mashroo
/
Running on Zero

YoussefAnso committed on
Commit
dfab55e
·
1 Parent(s): b2dc9cc

Refactor the generate3d function in inference.py to improve readability and maintainability: reformat the RGB and coordinate conversion, streamline noise addition for denoising, and update the mesh export process to use trimesh for the GLB format, ensuring proper handling of UV textures.

Browse files
Files changed (1) hide show
  1. inference.py +55 -63
inference.py CHANGED
@@ -7,93 +7,85 @@ import tempfile
7
  from mesh import Mesh
8
  import zipfile
9
  from util.renderer import Renderer
10
- def generate3d(model, rgb, ccm, device):
11
-
12
- model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
13
- scale=model.input.scale, geo_type = model.geo_type)
14
-
15
- color_tri = torch.from_numpy(rgb)/255
16
- xyz_tri = torch.from_numpy(ccm[:,:,(2,1,0)])/255
17
- color = color_tri.permute(2,0,1)
18
- xyz = xyz_tri.permute(2,0,1)
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def get_imgs(color):
22
- # color : [C, H, W*6]
23
- color_list = []
24
- color_list.append(color[:,:,256*5:256*(1+5)])
25
- for i in range(0,5):
26
- color_list.append(color[:,:,256*i:256*(1+i)])
27
- return torch.stack(color_list, dim=0)# [6, C, H, W]
28
-
29
- triplane_color = get_imgs(color).permute(0,2,3,1).unsqueeze(0).to(device)# [1, 6, H, W, C]
30
 
 
31
  color = get_imgs(color)
32
  xyz = get_imgs(xyz)
33
 
34
- color = get_tri(color, dim=0, blender= True, scale = 1).unsqueeze(0)
35
- xyz = get_tri(xyz, dim=0, blender= True, scale = 1, fix= True).unsqueeze(0)
 
 
36
 
37
- triplane = torch.cat([color,xyz],dim=1).to(device)
38
- # 3D visualize
39
  model.eval()
40
-
41
-
42
- if model.denoising == True:
43
- tnew = 20
44
- tnew = torch.randint(tnew, tnew+1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
45
- noise_new = torch.randn_like(triplane) *0.5+0.5
46
- triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
47
- start_time = time.time()
48
  with torch.no_grad():
49
- triplane_feature2 = model.unet2(triplane,tnew)
50
- end_time = time.time()
51
- elapsed_time = end_time - start_time
52
- print(f"unet takes {elapsed_time}s")
53
  else:
54
  triplane_feature2 = model.unet2(triplane)
55
-
56
 
57
  with torch.no_grad():
58
  data_config = {
59
  'resolution': [1024, 1024],
60
- "triview_color": triplane_color.to(device),
61
  }
62
 
63
  verts, faces = model.decode(data_config, triplane_feature2)
64
-
65
  data_config['verts'] = verts[0]
66
  data_config['faces'] = faces
67
-
68
 
 
69
  from kiui.mesh_utils import clean_mesh
70
- verts, faces = clean_mesh(data_config['verts'].squeeze().cpu().numpy().astype(np.float32), data_config['faces'].squeeze().cpu().numpy().astype(np.int32), repair = False, remesh=True, remesh_size=0.005, remesh_iters=1)
 
 
 
 
71
  data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
72
  data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()
73
 
74
- start_time = time.time()
 
 
 
 
 
 
75
  with torch.no_grad():
76
- mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
77
- model.export_mesh(data_config, mesh_path_glb, tri_fea_2 = triplane_feature2)
78
-
79
- # glctx = dr.RasterizeGLContext()#dr.RasterizeCudaContext()
80
- # mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
81
- # model.export_mesh_wt_uv(glctx, data_config, mesh_path_obj, "", device, res=(1024,1024), tri_fea_2=triplane_feature2)
82
-
83
- # mesh = Mesh.load(mesh_path_obj+".obj", bound=0.9, front_dir="+z")
84
- # mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
85
- # mesh.write(mesh_path_glb+".glb")
86
-
87
- # # mesh_obj2 = trimesh.load(mesh_path_glb+".glb", file_type='glb')
88
- # # mesh_path_obj2 = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
89
- # # mesh_obj2.export(mesh_path_obj2+".obj")
90
-
91
- # with zipfile.ZipFile(mesh_path_obj+'.zip', 'w') as myzip:
92
- # myzip.write(mesh_path_obj+'.obj', mesh_path_obj.split("/")[-1]+'.obj')
93
- # myzip.write(mesh_path_obj+'.png', mesh_path_obj.split("/")[-1]+'.png')
94
- # myzip.write(mesh_path_obj+'.mtl', mesh_path_obj.split("/")[-1]+'.mtl')
95
-
96
- end_time = time.time()
97
- elapsed_time = end_time - start_time
98
- print(f"uv takes {elapsed_time}s")
99
- return mesh_path_glb+".glb"
 
7
  from mesh import Mesh
8
  import zipfile
9
  from util.renderer import Renderer
10
+ import trimesh # Needed for glb export
 
 
 
 
 
 
 
 
11
 
12
+ def generate3d(model, rgb, ccm, device):
13
+ model.renderer = Renderer(
14
+ tet_grid_size=model.tet_grid_size,
15
+ camera_angle_num=model.camera_angle_num,
16
+ scale=model.input.scale,
17
+ geo_type=model.geo_type
18
+ )
19
+
20
+ # RGB and coordinate conversion
21
+ color_tri = torch.from_numpy(rgb) / 255
22
+ xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
23
+ color = color_tri.permute(2, 0, 1)
24
+ xyz = xyz_tri.permute(2, 0, 1)
25
 
26
  def get_imgs(color):
27
+ color_list = [color[:, :, 256 * 5:256 * (1 + 5)]]
28
+ for i in range(0, 5):
29
+ color_list.append(color[:, :, 256 * i:256 * (1 + i)])
30
+ return torch.stack(color_list, dim=0)
 
 
 
 
31
 
32
+ triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
33
  color = get_imgs(color)
34
  xyz = get_imgs(xyz)
35
 
36
+ color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
37
+ xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
38
+
39
+ triplane = torch.cat([color, xyz], dim=1).to(device)
40
 
 
 
41
  model.eval()
42
+
43
+ if model.denoising:
44
+ tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
45
+ noise_new = torch.randn_like(triplane) * 0.5 + 0.5
46
+ triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
 
 
 
47
  with torch.no_grad():
48
+ triplane_feature2 = model.unet2(triplane, tnew)
 
 
 
49
  else:
50
  triplane_feature2 = model.unet2(triplane)
 
51
 
52
  with torch.no_grad():
53
  data_config = {
54
  'resolution': [1024, 1024],
55
+ 'triview_color': triplane_color.to(device),
56
  }
57
 
58
  verts, faces = model.decode(data_config, triplane_feature2)
 
59
  data_config['verts'] = verts[0]
60
  data_config['faces'] = faces
 
61
 
62
+ # Optional mesh cleanup (reduce remesh for speed)
63
  from kiui.mesh_utils import clean_mesh
64
+ verts, faces = clean_mesh(
65
+ data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
66
+ data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
67
+ repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
68
+ )
69
  data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
70
  data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()
71
 
72
+ # Rasterization context
73
+ glctx = dr.RasterizeGLContext()
74
+
75
+ # Temporary output path
76
+ mesh_path_obj = tempfile.NamedTemporaryFile(suffix="", delete=False).name
77
+
78
+ # Export OBJ with UV and PNG
79
  with torch.no_grad():
80
+ model.export_mesh_wt_uv(
81
+ glctx, data_config, mesh_path_obj, "", device,
82
+ res=(512, 512), tri_fea_2=triplane_feature2
83
+ )
84
+
85
+ # Convert to .glb using trimesh
86
+ mesh = trimesh.load(mesh_path_obj + ".obj", force='mesh')
87
+ mesh_path_glb = mesh_path_obj + ".glb"
88
+ mesh.export(mesh_path_glb, file_type='glb')
89
+
90
+ print(f"✅ Exported GLB with UV texture: {mesh_path_glb}")
91
+ return mesh_path_glb