import numpy as np
import torch
import time
import tempfile
import os
from PIL import Image
import trimesh
import nvdiffrast.torch as dr  # required for the CUDA rasterization context used at export time
from util.utils import get_tri
from mesh import Mesh
from util.renderer import Renderer
# clean_mesh() is called below but was not imported; kiui.mesh_utils provides a
# clean_mesh with matching keyword arguments (assumed source).
from kiui.mesh_utils import clean_mesh


def generate3d(model, rgb, ccm, device):
    model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
                              scale=model.input.scale, geo_type=model.geo_type)

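    # Convert the RGB and ccm images to float tensors in [0, 1]; the ccm channels
    # are reordered and treated as xyz coordinates. Both go from HWC to CHW.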
    color_tri = torch.from_numpy(rgb) / 255
    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
    color = color_tri.permute(2, 0, 1)
    xyz = xyz_tri.permute(2, 0, 1)

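    # Each image is six 256-pixel-wide views concatenated along the width; slice
    # them apart and stack them with the last view moved to the front.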
    def get_imgs(color):
        return torch.stack([color[:, :, 256 * i:256 * (i + 1)] for i in [5, 0, 1, 2, 3, 4]], dim=0)

    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
    color = get_imgs(color)
    xyz = get_imgs(xyz)

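    # Build the color and xyz triplanes from the six views and concatenate them
    # along the channel dimension.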
    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color, xyz], dim=1).to(device)

    model.eval()
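    # If denoising is enabled, perturb the triplane with scheduler noise at a
    # fixed timestep (t = 20) before passing it through the second UNet;
    # otherwise run the UNet on the triplane directly.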
    if model.denoising:
        tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, tnew)
    else:
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane)

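    # Decode the triplane features into mesh vertices and faces at 1024 x 1024 resolution.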
    data_config = {
        'resolution': [1024, 1024],
        "triview_color": triplane_color.to(device),
    }

    with torch.no_grad():
        verts, faces = model.decode(data_config, triplane_feature2)
        data_config['verts'] = verts[0]
        data_config['faces'] = faces

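    # Clean the raw mesh and remesh it at a fixed target edge length before export.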
    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
    )
    data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()

    # Create base filename
    temp_path = tempfile.NamedTemporaryFile(suffix="", delete=False).name
    obj_base = temp_path  # no extension

    # Export mesh with UV and PNG
    glctx = dr.RasterizeCudaContext()
    model.export_mesh_wt_uv(
        glctx, data_config, obj_base, "", device, res=(1024, 1024), tri_fea_2=triplane_feature2
    )

    # Load .obj with texture and export .glb
    mesh = trimesh.load(obj_base + ".obj", process=False)
    mesh.export(obj_base + ".glb")

    return obj_base + ".glb"
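

# --- Usage sketch (illustrative; not part of the original Space) ---
# A minimal example of how generate3d might be invoked, assuming `model` is the
# loaded reconstruction model this module expects and the inputs are the six
# 256 x 256 views concatenated horizontally (i.e. 256 x 1536 x 3 uint8 arrays).
# The file names below are hypothetical.
#
#     rgb = np.asarray(Image.open("pixel_images.png").convert("RGB"))
#     ccm = np.asarray(Image.open("xyz_images.png").convert("RGB"))
#     glb_path = generate3d(model, rgb, ccm, "cuda")
#     print("Exported:", glb_path)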