Spaces:
mashroo
/
Runtime error

YoussefAnso commited on
Commit
d72a5f9
·
1 Parent(s): a3626f7

Refactor generate3d function in inference.py to enhance 3D mesh generation. Introduced new image processing steps for color and XYZ data, improved triplane feature extraction, and updated mesh export logic to utilize temporary file handling. This streamlines the rendering process and maintains compatibility with GPU devices.

Browse files
Files changed (2) hide show
  1. app.py +4 -2
  2. inference.py +65 -26
app.py CHANGED
@@ -205,8 +205,10 @@ with gr.Blocks() as demo:
205
  image_output = gr.Image(interactive=False, label="Output RGB image")
206
  xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
207
 
208
- output_model = gr.Model3D(label="Output GLB", clear_color=[1, 1, 1, 0])
209
-
 
 
210
  gr.Markdown("Note: Ensure that the input image is correctly pre-processed into a grey background, otherwise the results will be unpredictable.")
211
 
212
  inputs = [
 
205
  image_output = gr.Image(interactive=False, label="Output RGB image")
206
  xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
207
 
208
+ output_model = gr.Model3D(
209
+ label="Output OBJ",
210
+ interactive=False,
211
+ )
212
  gr.Markdown("Note: Ensure that the input image is correctly pre-processed into a grey background, otherwise the results will be unpredictable.")
213
 
214
  inputs = [
inference.py CHANGED
@@ -1,35 +1,74 @@
1
- import os
2
  import torch
3
- from mesh import Mesh
 
 
 
 
 
4
 
5
- # Ensure the output directory exists
6
- def ensure_dir(path):
7
- os.makedirs(path, exist_ok=True)
8
 
9
- def generate3d(model, rgb_image, xyz_image, device="cpu"):
10
- output_dir = "outputs"
11
- ensure_dir(output_dir)
12
 
13
- prompt_id = "mesh_output"
14
- base_path = os.path.join(output_dir, prompt_id)
15
- obj_path = base_path + ".obj"
16
- glb_path = base_path + ".glb"
17
 
18
- # CRM export expects a data dictionary
19
- data = {"rgb": rgb_image, "xyz": xyz_image}
 
 
 
 
20
 
21
- # Get rendering context required by xatlas and nvdiffrast
22
- try:
23
- import nvdiffrast.torch as dr
24
- ctx = dr.RasterizeCudaContext(device=device)
25
- except Exception as e:
26
- raise RuntimeError("Failed to initialize nvdiffrast context: " + str(e))
 
 
 
 
27
 
28
- # Export mesh with UVs and texture image
29
- model.export_mesh_wt_uv(ctx, data, base_path, 0, device, 512)
 
 
 
 
 
 
 
30
 
31
- # Convert .obj to .glb
32
- mesh = Mesh.load(obj_path, device=torch.device("cpu"))
33
- mesh.write(glb_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- return glb_path
 
1
+ import numpy as np
2
  import torch
3
+ import time
4
+ import nvdiffrast.torch as dr
5
+ from util.utils import get_tri
6
+ import tempfile
7
+ from util.renderer import Renderer
8
+ import os
9
 
10
+ def generate3d(model, rgb, ccm, device):
 
 
11
 
12
+ model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
13
+ scale=model.input.scale, geo_type=model.geo_type)
 
14
 
15
+ color_tri = torch.from_numpy(rgb) / 255
16
+ xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
17
+ color = color_tri.permute(2, 0, 1)
18
+ xyz = xyz_tri.permute(2, 0, 1)
19
 
20
+ def get_imgs(color):
21
+ color_list = []
22
+ color_list.append(color[:, :, 256 * 5:256 * (1 + 5)])
23
+ for i in range(0, 5):
24
+ color_list.append(color[:, :, 256 * i:256 * (1 + i)])
25
+ return torch.stack(color_list, dim=0)
26
 
27
+ triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
28
+
29
+ color = get_imgs(color)
30
+ xyz = get_imgs(xyz)
31
+
32
+ color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
33
+ xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
34
+
35
+ triplane = torch.cat([color, xyz], dim=1).to(device)
36
+ model.eval()
37
 
38
+ if model.denoising:
39
+ tnew = 20
40
+ tnew = torch.randint(tnew, tnew + 1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
41
+ noise_new = torch.randn_like(triplane) * 0.5 + 0.5
42
+ triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
43
+ with torch.no_grad():
44
+ triplane_feature2 = model.unet2(triplane, tnew)
45
+ else:
46
+ triplane_feature2 = model.unet2(triplane)
47
 
48
+ with torch.no_grad():
49
+ data_config = {
50
+ 'resolution': [1024, 1024],
51
+ "triview_color": triplane_color.to(device),
52
+ }
53
+
54
+ verts, faces = model.decode(data_config, triplane_feature2)
55
+ data_config['verts'] = verts[0]
56
+ data_config['faces'] = faces
57
+
58
+ from kiui.mesh_utils import clean_mesh
59
+ verts, faces = clean_mesh(
60
+ data_config['verts'].squeeze().cuda().numpy().astype(np.float32),
61
+ data_config['faces'].squeeze().cuda().numpy().astype(np.int32),
62
+ repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
63
+ )
64
+ data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
65
+ data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()
66
+
67
+ with torch.no_grad():
68
+ mesh_path_base = tempfile.NamedTemporaryFile(suffix="", delete=False).name
69
+
70
+ # Export mesh with UV, texture, and MTL
71
+ ctx = dr.RasterizeCudaContext(device=device)
72
+ model.export_mesh_wt_uv(ctx, data_config, mesh_path_base, ind=0, device=device, res=(1024, 1024), tri_fea_2=triplane_feature2)
73
 
74
+ return mesh_path_base + ".obj"