Spaces:
mashroo
/
Runtime error

File size: 4,267 Bytes
d63ad23
 
 
a14c9ce
d454202
 
9ee53e8
854cd53
f6ac519
d454202
d63ad23
 
5777f44
d454202
5777f44
f4e8cf6
dfab55e
a14c9ce
 
dfab55e
 
 
 
 
cb29219
a14c9ce
 
f4e8cf6
dfab55e
f4e8cf6
 
 
dfab55e
 
 
f4e8cf6
 
a14c9ce
 
 
 
 
dfab55e
a14c9ce
 
76eeb7d
cb29219
a14c9ce
 
 
 
 
76eeb7d
f4e8cf6
a14c9ce
 
cb29219
a14c9ce
 
 
dfab55e
 
9ee53e8
 
76eeb7d
f6ac519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcb102c
 
f6ac519
 
 
 
 
 
 
 
 
dfab55e
9ee53e8
 
 
 
 
 
f6ac519
9ee53e8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import numpy as np
import torch
import time
import tempfile
import zipfile
import nvdiffrast.torch as dr
import xatlas
import cv2
import trimesh

from util.utils import get_tri
from mesh import Mesh
from util.renderer import Renderer
from kiui.mesh_utils import clean_mesh  


def _query_surface_colors(model, features, points, device, chunk_size=262144):
    """Evaluate the model's colour head at 3D ``points``.

    Args:
        model: object exposing ``decoder(features, points)`` and ``rgbMlp``.
        features: triplane features produced by ``model.unet2`` (left on ``device``).
        points: (N, 3) array of 3D positions.
        device: torch device on which the decoder and colour MLP are run.
        chunk_size: number of points per forward pass, to bound peak memory.

    Returns:
        (N, 3) float32 array: the colour-head output mapped with ``x * 0.5 + 0.5``
        and clipped to [0, 1].
    """
    # NOTE(review): chunking assumes the decoder evaluates each point independently
    # (results do not depend on which points share a batch) — confirm.
    chunks = []
    with torch.no_grad():
        for start in range(0, len(points), chunk_size):
            pts = torch.from_numpy(points[start:start + chunk_size]).float().unsqueeze(0).to(device)
            feats = model.decoder(features, pts)
            # reshape instead of squeeze so a single-point chunk still yields (1, 3).
            rgb = model.rgbMlp(feats).reshape(-1, 3)
            chunks.append(rgb.float().cpu().numpy())
    if not chunks:
        return np.zeros((0, 3), dtype=np.float32)
    colors = np.concatenate(chunks, axis=0)
    return np.clip(colors * 0.5 + 0.5, 0.0, 1.0).astype(np.float32)


def _splat_texture(uv, colors, height, width):
    """Average per-sample colours into the texel that each sample's UV falls in.

    Args:
        uv: (N, 2) texture coordinates; values are clamped to [0, 1] and NaNs zeroed.
        colors: (N, 3) colours for the same samples.
        height, width: texture size in texels.

    Returns:
        ``(texture, filled)``: an (H, W, 3) float32 image and an (H, W) bool mask of
        texels that received at least one sample.
    """
    uv = np.clip(np.nan_to_num(uv), 0.0, 1.0)
    cols = np.minimum((uv[:, 0] * width).astype(np.int64), width - 1)
    # NOTE(review): row 0 of the albedo is assumed to correspond to v = 0 (glTF-style);
    # flip ``rows`` if Mesh.write expects the opposite convention.
    rows = np.minimum((uv[:, 1] * height).astype(np.int64), height - 1)
    flat = rows * width + cols
    n_texels = height * width
    counts = np.bincount(flat, minlength=n_texels)
    texture = np.stack(
        [np.bincount(flat, weights=colors[:, c], minlength=n_texels) for c in range(3)],
        axis=-1,
    ).astype(np.float32)
    filled = counts > 0
    texture[filled] /= counts[filled][:, None]
    return texture.reshape(height, width, 3), filled.reshape(height, width)


def _dilate_texture(texture, filled, iterations=8):
    """Grow coloured texels into empty neighbours to close gaps and pad UV-chart borders.

    Each iteration gives every empty texel that has at least one filled 4-neighbour
    the mean colour of those neighbours. Stops early once nothing more can grow.

    Returns:
        The dilated (H, W, 3) texture (the inputs are not modified).
    """
    texture = texture.copy()
    filled = filled.copy()
    for _ in range(iterations):
        if filled.all():
            break
        tex_pad = np.pad(np.where(filled[..., None], texture, 0.0), ((1, 1), (1, 1), (0, 0)))
        mask_pad = np.pad(filled.astype(np.float32), 1)
        neighbour_sum = tex_pad[:-2, 1:-1] + tex_pad[2:, 1:-1] + tex_pad[1:-1, :-2] + tex_pad[1:-1, 2:]
        neighbour_cnt = mask_pad[:-2, 1:-1] + mask_pad[2:, 1:-1] + mask_pad[1:-1, :-2] + mask_pad[1:-1, 2:]
        grow = ~filled & (neighbour_cnt > 0)
        if not grow.any():
            break
        texture[grow] = neighbour_sum[grow] / neighbour_cnt[grow][:, None]
        filled = filled | grow
    return texture


def generate3d(model, rgb, ccm, device):
    """Reconstruct a textured mesh from tiled RGB and CCM images and export it as .glb.

    Args:
        model: reconstruction network. Reads ``tet_grid_size``, ``camera_angle_num``,
            ``input.scale``, ``geo_type``, ``denoising`` and ``scheduler``; calls
            ``unet2``, ``decode``, ``decoder`` and ``rgbMlp``. Its ``renderer`` attribute
            is overwritten.
        rgb: (H, W, 3) image array, divided by 255. The width is cut into 256-pixel
            tiles 0..5, restacked in the order [5, 0, 1, 2, 3, 4] (so W >= 1536).
        ccm: array with the same layout as ``rgb``; its channel order is reversed
            before use (presumably BGR -> RGB — confirm with the caller).
        device: torch device used for all network inference.

    Returns:
        Path of a newly created ``.glb`` file. It is created with ``delete=False``,
        so the caller owns it and is responsible for removing it.
    """
    model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
                              scale=model.input.scale, geo_type=model.geo_type)

    color_tri = torch.from_numpy(rgb) / 255
    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
    color = color_tri.permute(2, 0, 1)
    xyz = xyz_tri.permute(2, 0, 1)

    def get_imgs(img):
        # Split a (C, H, W) strip into six 256-px-wide tiles, stacked in view order 5, 0..4.
        return torch.stack([img[:, :, 256 * i:256 * (i + 1)] for i in [5, 0, 1, 2, 3, 4]], dim=0)

    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
    color = get_imgs(color)
    xyz = get_imgs(xyz)

    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color, xyz], dim=1).to(device)

    model.eval()
    with torch.no_grad():
        if model.denoising:
            # randint's upper bound is exclusive, so every element uses timestep 20.
            tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
            noise_new = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
            triplane_feature2 = model.unet2(triplane, tnew)
        else:
            triplane_feature2 = model.unet2(triplane)

    data_config = {
        'resolution': [1024, 1024],
        "triview_color": triplane_color.to(device),
    }
    with torch.no_grad():
        verts, faces = model.decode(data_config, triplane_feature2)

    mesh_v, mesh_f = clean_mesh(
        verts[0].squeeze().cpu().numpy().astype(np.float32),
        faces.squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
    )
    mesh_v = np.ascontiguousarray(mesh_v)
    mesh_f = np.ascontiguousarray(mesh_f, dtype=np.int64)

    # UV-unwrap on the CPU. Per the xatlas API the output has one UV triangle per input
    # triangle in input order, so ``ft`` row i indexes ``vt`` for the triangle ``mesh_f[i]``.
    _, ft, vt = xatlas.parametrize(mesh_v, mesh_f)
    # xatlas returns uint32 indices, which torch.from_numpy rejects on older PyTorch.
    ft = ft.astype(np.int64)

    # Bake the albedo: sample surface points, look up each sample's UV via barycentric
    # interpolation over its face, query the colour head there and splat into the texture.
    tex_w, tex_h = 1024, 1024
    mesh_trimesh = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f, process=False)
    points, face_idx = trimesh.sample.sample_surface(mesh_trimesh, tex_w * tex_h)
    bary = trimesh.triangles.points_to_barycentric(mesh_trimesh.triangles[face_idx], points)
    uv = np.einsum("nk,nkc->nc", bary, vt[ft[face_idx]])

    colors = _query_surface_colors(model, triplane_feature2, points, device)
    texture, filled = _splat_texture(uv, colors, tex_h, tex_w)
    # Random sampling leaves empty texels inside charts; dilation fills them and
    # adds a gutter so bilinear filtering at chart borders doesn't pull in black.
    texture = _dilate_texture(texture, filled)

    mesh = Mesh(
        v=torch.from_numpy(mesh_v).float(),
        f=torch.from_numpy(mesh_f.astype(np.int32)),
        vt=torch.from_numpy(vt).float(),
        ft=torch.from_numpy(ft.astype(np.int32)),
        albedo=torch.from_numpy(np.clip(texture, 0.0, 1.0)).float()
    )
    # Close the handle before writing; delete=False keeps the file for the caller.
    with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmp:
        temp_path = tmp.name
    mesh.write(temp_path)
    return temp_path