Spaces:
mashroo
/
Running on Zero

File size: 3,956 Bytes
d72a5f9
d63ad23
d72a5f9
 
 
 
 
e5c94c9
3d787d3
 
9d0b3b4
 
e5c94c9
 
 
 
 
 
 
9d0b3b4
d72a5f9
 
 
 
9d0b3b4
d72a5f9
 
 
 
 
 
c6a8a22
d72a5f9
 
 
 
 
 
 
 
 
 
db2fd1d
d72a5f9
 
 
 
 
 
 
 
 
d65fe1c
d72a5f9
 
 
 
 
 
 
 
 
 
 
411e8a2
263eddc
 
 
 
 
8a29dd3
 
 
 
bd713a9
 
d72a5f9
411e8a2
b6dcd98
1932e80
d72a5f9
 
 
 
68935d6
 
411e8a2
68935d6
 
 
 
 
411e8a2
 
2625296
411e8a2
e5c94c9
 
 
 
 
 
 
a34ef67
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import numpy as np
import torch
import time
from util.utils import get_tri
import tempfile
from util.renderer import Renderer
import os
from PIL import Image
import trimesh
from scipy.spatial import cKDTree


def _split_views(img, view_size=256, num_views=6):
    """Cut a horizontally tiled image into ``num_views`` panels and stack them.

    ``img`` is sliced along its last axis (the tiled width) into consecutive
    panels of ``view_size`` columns. The last panel is moved to the front, so
    with the defaults the returned order is panels ``[5, 0, 1, 2, 3, 4]``.

    Args:
        img: tensor whose last axis is the tiled width; with the channel-first
            layout used by ``generate3d`` this is ``(C, H, W)``.
        view_size: width in pixels of a single panel.
        num_views: number of panels to extract.

    Returns:
        Tensor of shape ``(num_views, *img.shape[:-1], view_size)`` when the
        width holds at least ``num_views * view_size`` columns.
    """
    panels = [
        img[..., view_size * i:view_size * (i + 1)] for i in range(num_views)
    ]
    # Rotate right by one: the final panel leads, the rest keep their order.
    return torch.stack(panels[-1:] + panels[:-1], dim=0)


def generate3d(model, rgb, ccm, device):
    """Reconstruct a mesh from tiled multi-view colour and CCM images and export it.

    Steps: split both images into six views, build the colour+xyz triplane
    with ``get_tri``, run ``model.unet2`` (after adding noise at a fixed
    timestep when ``model.denoising`` is set), decode vertices/faces with
    ``model.decode``, repair the mesh with kiui's ``clean_mesh`` and hand it
    to ``model.export_mesh``.

    Args:
        model: reconstruction model. This function reads ``tet_grid_size``,
            ``camera_angle_num``, ``input.scale``, ``geo_type``,
            ``denoising``, ``scheduler``, ``unet2``, ``decode`` and
            ``export_mesh``, and replaces ``model.renderer`` with a freshly
            built ``Renderer``.
        rgb: numpy image with channels on the last axis, divided by 255
            before use; its width must hold six 256-pixel views side by side
            (presumably uint8 — confirm with caller).
        ccm: coordinate-map image with the same layout as ``rgb``; its three
            channels are reversed before use.
        device: torch device the triplane and the repaired mesh are put on.

    Returns:
        Path of a fresh temporary file ending in ``.obj``. ``export_mesh`` is
        given the same path without the suffix as its output stem.
    """
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    # HWC numpy -> CHW float tensors scaled by 1/255; CCM channels reversed.
    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    color_views = _split_views(color)
    xyz_views = _split_views(xyz)

    # Channel-last batch of the raw colour views, passed to model.decode.
    triplane_color = color_views.permute(0, 2, 3, 1).unsqueeze(0).to(device)

    color_tri = get_tri(color_views, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz_tri = get_tri(xyz_views, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color_tri, xyz_tri], dim=1).to(device)

    model.eval()
    # Inference only: keep autograd off for every network call, including
    # unet2 on the non-denoising path and the final export.
    with torch.no_grad():
        if model.denoising:
            # Fixed timestep 20 for every batch element. The randint(20, 21)
            # form is kept so random-number consumption (and therefore the
            # noise drawn next) is unchanged.
            t = torch.randint(
                20, 21, [triplane.shape[0]],
                dtype=torch.long, device=triplane.device,
            )
            noise = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise, t)
            triplane_feature2 = model.unet2(triplane, t)
        else:
            triplane_feature2 = model.unet2(triplane)

        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color,
        }
        verts, faces = model.decode(data_config, triplane_feature2)

        from kiui.mesh_utils import clean_mesh  # imported at call time

        # Repair the decoded mesh on the CPU, then move it back to `device`.
        clean_verts, clean_faces = clean_mesh(
            verts[0].squeeze().cpu().numpy().astype(np.float32),
            faces.squeeze().cpu().numpy().astype(np.int32),
            repair=True, remesh=False, remesh_size=0.005, remesh_iters=1,
        )
        data_config['verts'] = torch.from_numpy(clean_verts).to(device).contiguous()
        data_config['faces'] = torch.from_numpy(clean_faces).to(device).contiguous()

        # Reserve a unique temp path. delete=False keeps the file on disk after
        # the handle is closed at the end of this `with`.
        with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as tmp:
            obj_path = tmp.name
        base_path = os.path.splitext(obj_path)[0]

        # export_mesh receives the suffix-less stem; presumably it writes
        # <base_path>.obj, i.e. obj_path — confirm against export_mesh.
        model.export_mesh(data_config, base_path, tri_fea_2=triplane_feature2)

    return obj_path