Commit
·
d72a5f9
1
Parent(s):
a3626f7
Refactor generate3d function in inference.py to enhance 3D mesh generation. Introduced new image processing steps for color and XYZ data, improved triplane feature extraction, and updated mesh export logic to utilize temporary file handling. This streamlines the rendering process and maintains compatibility with GPU devices.
Browse files- app.py +4 -2
- inference.py +65 -26
app.py
CHANGED
@@ -205,8 +205,10 @@ with gr.Blocks() as demo:
|
|
205 |
image_output = gr.Image(interactive=False, label="Output RGB image")
|
206 |
xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
|
207 |
|
208 |
-
output_model = gr.Model3D(
|
209 |
-
|
|
|
|
|
210 |
gr.Markdown("Note: Ensure that the input image is correctly pre-processed into a grey background, otherwise the results will be unpredictable.")
|
211 |
|
212 |
inputs = [
|
|
|
205 |
image_output = gr.Image(interactive=False, label="Output RGB image")
|
206 |
xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
|
207 |
|
208 |
+
output_model = gr.Model3D(
|
209 |
+
label="Output OBJ",
|
210 |
+
interactive=False,
|
211 |
+
)
|
212 |
gr.Markdown("Note: Ensure that the input image is correctly pre-processed into a grey background, otherwise the results will be unpredictable.")
|
213 |
|
214 |
inputs = [
|
inference.py
CHANGED
@@ -1,35 +1,74 @@
|
|
1 |
-
import
|
2 |
import torch
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
-
|
6 |
-
def ensure_dir(path):
|
7 |
-
os.makedirs(path, exist_ok=True)
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
ensure_dir(output_dir)
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
27 |
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
-
return
|
|
|
1 |
+
import numpy as np
|
2 |
import torch
|
3 |
+
import time
|
4 |
+
import nvdiffrast.torch as dr
|
5 |
+
from util.utils import get_tri
|
6 |
+
import tempfile
|
7 |
+
from util.renderer import Renderer
|
8 |
+
import os
|
9 |
|
10 |
+
def generate3d(model, rgb, ccm, device):
|
|
|
|
|
11 |
|
12 |
+
model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
|
13 |
+
scale=model.input.scale, geo_type=model.geo_type)
|
|
|
14 |
|
15 |
+
color_tri = torch.from_numpy(rgb) / 255
|
16 |
+
xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
|
17 |
+
color = color_tri.permute(2, 0, 1)
|
18 |
+
xyz = xyz_tri.permute(2, 0, 1)
|
19 |
|
20 |
+
def get_imgs(color):
|
21 |
+
color_list = []
|
22 |
+
color_list.append(color[:, :, 256 * 5:256 * (1 + 5)])
|
23 |
+
for i in range(0, 5):
|
24 |
+
color_list.append(color[:, :, 256 * i:256 * (1 + i)])
|
25 |
+
return torch.stack(color_list, dim=0)
|
26 |
|
27 |
+
triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
|
28 |
+
|
29 |
+
color = get_imgs(color)
|
30 |
+
xyz = get_imgs(xyz)
|
31 |
+
|
32 |
+
color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
|
33 |
+
xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
|
34 |
+
|
35 |
+
triplane = torch.cat([color, xyz], dim=1).to(device)
|
36 |
+
model.eval()
|
37 |
|
38 |
+
if model.denoising:
|
39 |
+
tnew = 20
|
40 |
+
tnew = torch.randint(tnew, tnew + 1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
|
41 |
+
noise_new = torch.randn_like(triplane) * 0.5 + 0.5
|
42 |
+
triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
|
43 |
+
with torch.no_grad():
|
44 |
+
triplane_feature2 = model.unet2(triplane, tnew)
|
45 |
+
else:
|
46 |
+
triplane_feature2 = model.unet2(triplane)
|
47 |
|
48 |
+
with torch.no_grad():
|
49 |
+
data_config = {
|
50 |
+
'resolution': [1024, 1024],
|
51 |
+
"triview_color": triplane_color.to(device),
|
52 |
+
}
|
53 |
+
|
54 |
+
verts, faces = model.decode(data_config, triplane_feature2)
|
55 |
+
data_config['verts'] = verts[0]
|
56 |
+
data_config['faces'] = faces
|
57 |
+
|
58 |
+
from kiui.mesh_utils import clean_mesh
|
59 |
+
verts, faces = clean_mesh(
|
60 |
+
data_config['verts'].squeeze().cuda().numpy().astype(np.float32),
|
61 |
+
data_config['faces'].squeeze().cuda().numpy().astype(np.int32),
|
62 |
+
repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
|
63 |
+
)
|
64 |
+
data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
|
65 |
+
data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()
|
66 |
+
|
67 |
+
with torch.no_grad():
|
68 |
+
mesh_path_base = tempfile.NamedTemporaryFile(suffix="", delete=False).name
|
69 |
+
|
70 |
+
# Export mesh with UV, texture, and MTL
|
71 |
+
ctx = dr.RasterizeCudaContext(device=device)
|
72 |
+
model.export_mesh_wt_uv(ctx, data_config, mesh_path_base, ind=0, device=device, res=(1024, 1024), tri_fea_2=triplane_feature2)
|
73 |
|
74 |
+
return mesh_path_base + ".obj"
|