File size: 5,185 Bytes
b2a27a7
 
5572b0e
 
 
 
 
 
 
 
 
 
 
 
b2a27a7
 
 
 
 
 
d23c9df
f3bc318
f2228f5
b2a27a7
 
 
 
4763a4b
 
 
 
f3bc318
 
 
b2a27a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5ef988
b2a27a7
 
 
 
 
 
 
 
 
 
 
 
b5ef988
 
 
b2a27a7
 
d22bdf6
 
 
4763a4b
 
 
 
 
 
 
b2a27a7
4763a4b
 
 
 
 
 
 
 
 
 
 
 
 
cf54aef
4763a4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2228f5
b2a27a7
d23c9df
f3bc318
b2a27a7
 
f3bc318
b2a27a7
4763a4b
b2a27a7
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import subprocess
import sys

# Drop the pyenv version override from this process's environment; child
# processes started below inherit the modified environment.
# NOTE(review): presumably the inherited PYENV_VERSION does not match an
# installed pyenv version on the host — confirm against the deployment image.
os.environ.pop("PYENV_VERSION", None)

# Install dependencies at startup.
# pip is invoked as a module of the *running* interpreter rather than through
# whatever `pip` executable happens to be first on PATH, so the packages are
# installed into the same environment that performs the imports further down.
_pip_install = [sys.executable, "-m", "pip", "install"]
subprocess.run([*_pip_install, "torch", "wheel"], check=True)
# --no-build-isolation makes the diso build use the packages installed in the
# current environment (torch and wheel from the step above) instead of a fresh
# isolated build environment.
subprocess.run(
    [
        *_pip_install,
        "--no-build-isolation",
        "diso@git+https://github.com/SarahWeiii/diso.git",
    ],
    check=True,
)

# Imports (placed after the dependency installation above)
import gradio as gr
import uuid
import torch
import zipfile
import requests
import traceback
import trimesh
from trimesh.exchange.gltf import export_glb

from inference_triposg import run_triposg
from triposg.pipelines.pipeline_triposg import TripoSGPipeline
from briarmbg import BriaRMBG

from pygltflib import GLTF2, Scene, Node, Mesh, Buffer, BufferView, Accessor, BufferTarget, ComponentType, AccessorType
import numpy as np
import base64


print("Trimesh version:", trimesh.__version__)

# Device settings: use CUDA with half precision when available, otherwise
# fall back to CPU with full precision.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Locations of the pretrained weights
weights_dir = "pretrained_weights"
triposg_path = os.path.join(weights_dir, "TripoSG")
rmbg_path = os.path.join(weights_dir, "RMBG-1.4")

# Download and unpack the weights archive only when either model directory
# is missing.
if not (os.path.exists(triposg_path) and os.path.exists(rmbg_path)):
    print("📦 Downloading pretrained weights...")
    url = "https://huggingface.co/datasets/endlesstools/pretrained-assets/resolve/main/pretrained_models.zip"
    zip_path = "pretrained_models.zip"

    try:
        # timeout=(connect, read): without it a stalled connection would
        # block startup indefinitely. The read timeout applies to each wait
        # for data from the server, not to the download as a whole.
        with requests.get(url, stream=True, timeout=(10, 60)) as r:
            r.raise_for_status()
            with open(zip_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

        print("📦 Extracting weights...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(weights_dir)
    finally:
        # Remove the archive whether or not the download/extraction succeeded,
        # so a partial or corrupt file is never left behind.
        if os.path.exists(zip_path):
            os.remove(zip_path)

    print("✅ Weights ready.")

# Load the models onto the selected device. The TripoSG pipeline is also cast
# to `dtype`; the background-removal network keeps its default dtype.
pipe = TripoSGPipeline.from_pretrained(triposg_path).to(device, dtype)
rmbg_net = BriaRMBG.from_pretrained(rmbg_path).to(device)
rmbg_net.eval()

# .glb generation
def _write_glb(mesh, output_path):
    """Write the vertex positions and triangle indices of *mesh* to a binary glTF file.

    Only geometry is exported: one POSITION attribute (float32 VEC3) and one
    index accessor (uint32 scalar), packed back to back in a single buffer.
    `mesh` must expose `.vertices` and `.faces` arrays.
    """
    vertices = mesh.vertices.astype(np.float32)
    indices = mesh.faces.astype(np.uint32).flatten()

    # Binary payload layout: [vertex data][index data]
    vertex_bytes = vertices.tobytes()
    index_bytes = indices.tobytes()
    total_bytes = b"".join((vertex_bytes, index_bytes))

    buffer = Buffer(byteLength=len(total_bytes))
    buffer_view_vert = BufferView(
        buffer=0,
        byteOffset=0,
        byteLength=len(vertex_bytes),
        target=BufferTarget.ARRAY_BUFFER.value,
    )
    # The index view starts right after the vertex data.
    buffer_view_index = BufferView(
        buffer=0,
        byteOffset=len(vertex_bytes),
        byteLength=len(index_bytes),
        target=BufferTarget.ELEMENT_ARRAY_BUFFER.value,
    )

    # Per-axis min/max bounds are supplied for the POSITION accessor.
    accessor_vert = Accessor(
        bufferView=0,
        byteOffset=0,
        componentType=ComponentType.FLOAT.value,
        count=len(vertices),
        type=AccessorType.VEC3.value,
        min=vertices.min(axis=0).tolist(),
        max=vertices.max(axis=0).tolist(),
    )
    accessor_index = Accessor(
        bufferView=1,
        byteOffset=0,
        componentType=ComponentType.UNSIGNED_INT.value,
        count=len(indices),
        type=AccessorType.SCALAR.value,
    )

    gltf = GLTF2(
        buffers=[buffer],
        bufferViews=[buffer_view_vert, buffer_view_index],
        accessors=[accessor_vert, accessor_index],
        meshes=[Mesh(primitives=[{
            "attributes": {"POSITION": 0},
            "indices": 1
        }])],
        scenes=[Scene(nodes=[0])],
        nodes=[Node(mesh=0)],
        scene=0,
    )

    # Attach the binary payload and write the .glb file
    gltf.set_binary_blob(total_bytes)
    gltf.save_binary(output_path)


def generate(image_path, face_number=50000, guidance_scale=5.0, num_steps=25):
    """Generate a 3D mesh from an image and save it as a .glb file.

    Args:
        image_path: Filesystem path of the input image (the Gradio input
            component is configured with ``type="filepath"``).
        face_number: Passed to ``run_triposg`` as ``faces`` after ``int()``.
        guidance_scale: Passed to ``run_triposg`` after ``float()``.
        num_steps: Passed to ``run_triposg`` as ``num_inference_steps``
            after ``int()``.

    Returns:
        Path of the written ``/tmp/<uuid>.glb`` file, or ``None`` if the file
        does not exist after saving.

    Raises:
        gr.Error: If no image was supplied, or if mesh generation or export
            fails. Raising (instead of returning an error string) keeps the
            message from being handed to the file output component as if it
            were a file path.
    """
    print("[API CALL] image_path received:", image_path)
    # Gradio passes None when the request is submitted without an image;
    # reject it before any path handling.
    if not image_path:
        raise gr.Error("No input image was provided.")
    print("[API CALL] File exists:", os.path.exists(image_path))

    temp_id = str(uuid.uuid4())
    output_path = f"/tmp/{temp_id}.glb"

    try:
        # The seed is fixed at 42.
        mesh = run_triposg(
            pipe=pipe,
            image_input=image_path,
            rmbg_net=rmbg_net,
            seed=42,
            num_inference_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            faces=int(face_number),
        )

        if mesh is None or mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
            raise ValueError("Mesh generation returned an empty mesh")

        _write_glb(mesh, output_path)

        print(f"[DEBUG] Mesh saved to {output_path}")
        return output_path if os.path.exists(output_path) else None

    except Exception as e:
        # Top-level boundary for the request: log the full traceback, then
        # surface the message to the user through Gradio's error mechanism.
        print("[ERROR]", e)
        traceback.print_exc()
        raise gr.Error(f"Error: {e}") from e
    
# Gradio interface: a single image upload in, a downloadable .glb file out.
# The image component hands `generate` a file path rather than pixel data.
_image_input = gr.Image(type="filepath", label="Upload image")
_glb_output = gr.File(label="Download .glb")

demo = gr.Interface(
    fn=generate,
    inputs=_image_input,
    outputs=_glb_output,
    title="TripoSG Image to 3D",
    description="Upload an image to generate a 3D model (.glb)",
)

# Launch the app
demo.launch()