Spaces:
mashroo
/
Running on Zero

File size: 3,368 Bytes
f4e8cf6
 
 
 
76eeb7d
 
 
 
 
 
8216438
f4e8cf6
dfab55e
 
 
 
 
 
 
 
 
 
76eeb7d
dfab55e
 
cb29219
76eeb7d
 
 
 
 
 
f4e8cf6
dfab55e
f4e8cf6
 
 
dfab55e
 
 
 
f4e8cf6
 
dfab55e
 
 
 
 
f4e8cf6
dfab55e
f4e8cf6
76eeb7d
 
cb29219
76eeb7d
 
 
 
f4e8cf6
76eeb7d
f4e8cf6
 
 
cb29219
f4e8cf6
dfab55e
 
 
 
 
f4e8cf6
 
 
30b6055
dfab55e
76eeb7d
30b6055
76eeb7d
 
 
 
 
 
 
dfab55e
30b6055
 
76eeb7d
dfab55e
30b6055
 
76eeb7d
30b6055
 
76eeb7d
30b6055
 
76eeb7d
30b6055
 
dfab55e
30b6055
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import numpy as np
import torch
import time
import tempfile
import os
from PIL import Image
import trimesh

from util.utils import get_tri
from mesh import Mesh
from util.renderer import Renderer

def _split_views(img_tensor, tile=256):
    """Cut a horizontally tiled channels-first image into a stack of six tiles.

    Tile 5 (columns ``tile*5 : tile*6``) is placed first, followed by tiles 0-4
    in order; this is the ordering the pipeline hands to ``get_tri``.

    Args:
        img_tensor: channels-first tensor ``[C, H, W]``; presumably ``W >= 6 * tile``
            (narrower input yields short/empty slices and ``torch.stack`` may fail).
        tile: width in pixels of one tile.

    Returns:
        Tensor of shape ``[6, C, H, tile]``.
    """
    order = (5, 0, 1, 2, 3, 4)
    return torch.stack([img_tensor[:, :, tile * i:tile * (i + 1)] for i in order], dim=0)


def _obj_to_glb(obj_path, texture_path, glb_path):
    """Load *obj_path*, attach the image at *texture_path* to every mesh, write a GLB.

    Returns:
        *glb_path*.
    """
    mesh_scene = trimesh.load(obj_path, force='scene')

    # Use a context manager so the file handle is released; ``copy()`` makes
    # PIL read the pixel data before the file is closed.
    with Image.open(texture_path) as img:
        texture_image = img.copy()

    for geom in mesh_scene.geometry.values():
        # Keep the UV coordinates parsed from the OBJ. Building TextureVisuals
        # without them leaves the texture unmapped in the exported GLB.
        uv = getattr(geom.visual, 'uv', None)
        geom.visual = trimesh.visual.texture.TextureVisuals(uv=uv, image=texture_image)

    mesh_scene.export(glb_path)
    return glb_path


def generate3d(model, rgb, ccm, device):
    """Reconstruct a textured mesh from tiled colour / CCM views and export it as GLB.

    Side effects: replaces ``model.renderer`` with a fresh ``Renderer`` and puts
    ``model`` in eval mode. The intermediate ``.obj``/``.png`` files and the
    returned ``.glb`` are temporary files that are NOT deleted.

    Args:
        model: reconstruction model exposing ``tet_grid_size``, ``camera_angle_num``,
            ``input.scale``, ``geo_type``, ``denoising``, ``scheduler``, ``unet2``,
            ``decode`` and ``export_mesh_wt_uv``.
        rgb: 3-D channels-last array (permuted with ``(2, 0, 1)``); values are
            divided by 255. Presumably six 256-px-wide tiles side by side — confirm.
        ccm: 3-D channels-last array laid out like *rgb*; its first three channels
            are used in reverse order ``(2, 1, 0)`` and divided by 255.
        device: torch device for inference and for the mesh tensors.

    Returns:
        Filesystem path (str) of the exported ``.glb`` file.

    Raises:
        RuntimeError: if ``export_mesh_wt_uv`` did not produce both ``<base>.obj``
            and ``<base>.png``.
    """
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    # Scale to [0, 1] and move channels first: [H, W, C] -> [C, H, W].
    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    # Channels-last view stack with a batch dim: [1, 6, H, 256, C]. It is built
    # from its own stack so it never aliases the tensors handed to get_tri.
    triplane_color = _split_views(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)

    color_tri = get_tri(_split_views(color), dim=0, blender=True, scale=1).unsqueeze(0)
    xyz_tri = get_tri(_split_views(xyz), dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color_tri, xyz_tri], dim=1).to(device)

    model.eval()

    with torch.no_grad():
        if model.denoising:
            # randint(20, 21) always yields timestep 20. It is kept (instead of
            # torch.full) so the global RNG stream, and hence the noise, is unchanged.
            tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
            noise_new = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
            triplane_feature2 = model.unet2(triplane, tnew)
        else:
            triplane_feature2 = model.unet2(triplane)

        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color,
        }
        verts, faces = model.decode(data_config, triplane_feature2)
        data_config['verts'] = verts[0]
        data_config['faces'] = faces

    from kiui.mesh_utils import clean_mesh
    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1,
    )
    # Honour the caller's device. A hard-coded .cuda() breaks on CPU and always
    # targets the current CUDA device rather than e.g. "cuda:1".
    data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()

    # Reserve a unique base path. Close the handle immediately, since only the
    # name is needed: export_mesh_wt_uv is expected to write <base>.obj and <base>.png.
    with tempfile.NamedTemporaryFile(suffix="", delete=False) as tmp:
        mesh_path_obj = tmp.name

    model.export_mesh_wt_uv(
        None,  # GL context skipped (no dr.RasterizeGLContext)
        data_config,
        mesh_path_obj,
        "",
        device,
        res=(1024, 1024),
        tri_fea_2=triplane_feature2,
    )

    obj_path = mesh_path_obj + ".obj"
    texture_path = mesh_path_obj + ".png"
    if not os.path.exists(obj_path) or not os.path.exists(texture_path):
        raise RuntimeError("OBJ or texture file missing — cannot convert to GLB.")

    return _obj_to_glb(obj_path, texture_path, mesh_path_obj + "_final.glb")