Spaces:
mashroo
/
Running on Zero

File size: 3,241 Bytes
d72a5f9
d63ad23
d72a5f9
 
 
 
 
e5c94c9
9d0b3b4
 
e5c94c9
 
 
 
 
 
 
9d0b3b4
d72a5f9
 
 
 
9d0b3b4
d72a5f9
 
 
 
 
 
c6a8a22
d72a5f9
 
 
 
 
 
 
 
 
 
db2fd1d
d72a5f9
 
 
 
 
 
 
 
 
d65fe1c
d72a5f9
 
 
 
 
 
 
 
 
 
 
 
b6dcd98
 
d72a5f9
 
 
 
 
e5c94c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d72a5f9
e5c94c9
 
 
 
 
 
f4e8cf6
e5c94c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
import torch
import time
from util.utils import get_tri
import tempfile
from util.renderer import Renderer
import os
from PIL import Image


def _split_views(images, tile=256):
    """Slice a horizontally tiled image tensor into six stacked view tensors.

    The tiles are taken along the last dimension (``images[:, :, start:stop]``).
    Tile 5 is emitted first, followed by tiles 0-4, which is the view order the
    rest of ``generate3d`` relies on.

    Args:
        images: 3-D tensor; presumably (C, H, 6 * tile) — TODO confirm with callers.
        tile: width in pixels of one view tile (256 in the original pipeline).

    Returns:
        Tensor stacked on a new leading dim: (6, images.shape[0], images.shape[1], tile).
    """
    order = (5, 0, 1, 2, 3, 4)
    return torch.stack([images[:, :, tile * i:tile * (i + 1)] for i in order], dim=0)


def _export_textured_obj(model, data_config):
    """Export the mesh in ``data_config`` as OBJ + MTL + PNG texture; return the OBJ path.

    All three files share one temporary base name in the system temp directory
    and are NOT deleted here; the caller owns them.
    """
    # The temp file is created only to reserve a unique name. Close the handle
    # immediately so it is not leaked and the exporter can reopen the path.
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as tmp:
        obj_path = tmp.name
    base_path = os.path.splitext(obj_path)[0]
    texture_path = base_path + ".png"
    mtl_path = base_path + ".mtl"

    model.export_mesh_geometry(data_config, obj_path)  # writes .obj with UVs
    model.export_texture_image(data_config, texture_path)  # saves PNG

    # Minimal material: white diffuse modulated by the texture, referenced by
    # basename so the three files stay relocatable together.
    with open(mtl_path, "w") as f:
        f.write("newmtl material0\n")
        f.write("Kd 1.000000 1.000000 1.000000\n")
        f.write(f"map_Kd {os.path.basename(texture_path)}\n")

    with open(obj_path, "r") as original:
        lines = original.readlines()

    header = [f"mtllib {os.path.basename(mtl_path)}\n"]
    # `mtllib` only loads a library. OBJ assigns a material to elements that
    # follow a `usemtl` statement. Add one unless the exporter already did.
    if not any(line.split(maxsplit=1)[:1] == ["usemtl"] for line in lines):
        header.append("usemtl material0\n")

    with open(obj_path, "w") as modified:
        modified.writelines(header)
        modified.writelines(lines)

    return obj_path


def generate3d(model, rgb, ccm, device):
    """Reconstruct a textured mesh from tiled color / CCM images and export it as OBJ.

    Args:
        model: the reconstruction model. This function reads ``tet_grid_size``,
            ``camera_angle_num``, ``input.scale``, ``geo_type``, ``denoising``,
            ``scheduler``, ``unet2``, ``decode`` and the ``export_*`` methods.
            It also (re)assigns ``model.renderer`` and puts the model in eval mode.
        rgb: numpy image array with six 256-px view tiles laid out horizontally;
            assumed HWC uint8 — TODO confirm.
        ccm: numpy array laid out like ``rgb``; its channel order is reversed
            before use.
        device: torch device for model inputs.

    Returns:
        Path of the exported ``.obj``. A ``.mtl`` and a ``.png`` with the same base
        name are written next to it.
    """
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    # numpy (H, W, C) -> torch (C, H, W), scaled by 1/255.
    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    # Channel order reversed (presumably BGR -> RGB; verify against the CCM producer).
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    color_views = _split_views(color)
    xyz_views = _split_views(xyz)

    # Per-view colors for the decoder: (1, 6, H, tile, C). copy=True keeps this
    # tensor independent of color_views, as the original's separate stack did,
    # in case get_tri modifies its input in place.
    triplane_color = color_views.permute(0, 2, 3, 1).unsqueeze(0).to(device, copy=True)

    color_tri = get_tri(color_views, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz_tri = get_tri(xyz_views, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color_tri, xyz_tri], dim=1).to(device)

    model.eval()

    # Inference only. Keep the whole forward pass (both branches) out of
    # autograd so no graph / activations are retained.
    with torch.no_grad():
        if model.denoising:
            timestep = 20
            # randint over [t, t+1) always yields t. It is kept, rather than
            # torch.full, so global RNG consumption and seeded results are unchanged.
            tnew = torch.randint(
                timestep, timestep + 1, [triplane.shape[0]],
                dtype=torch.long, device=triplane.device,
            )
            noise_new = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
            triplane_feature2 = model.unet2(triplane, tnew)
        else:
            triplane_feature2 = model.unet2(triplane)

        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color,
        }
        verts, faces = model.decode(data_config, triplane_feature2)
        data_config['verts'] = verts[0]
        data_config['faces'] = faces

    from kiui.mesh_utils import clean_mesh

    # Remesh on CPU/numpy, then move the cleaned mesh back to the target device.
    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1,
    )
    data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()

    return _export_textured_obj(model, data_config)