File size: 5,213 Bytes
8aa156a
 
b2a27a7
 
5572b0e
 
 
 
 
 
 
 
 
 
 
 
b2a27a7
 
 
 
 
 
d23c9df
530f16a
972e6a2
 
f3bc318
f2228f5
b2a27a7
 
 
 
972e6a2
f3bc318
 
 
b2a27a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9a84f7
0dcf605
 
 
 
b2a27a7
 
 
 
 
8aa156a
b2a27a7
 
 
 
 
 
 
b5ef988
 
 
0dcf605
b2a27a7
 
d22bdf6
 
 
a9a84f7
 
91756de
a9a84f7
 
31cbe0e
a9a84f7
 
 
4763a4b
a9a84f7
 
b2a27a7
6eac623
 
 
 
 
a9a84f7
 
8aa156a
 
 
f2228f5
b2a27a7
d23c9df
f3bc318
b2a27a7
 
f3bc318
b2a27a7
8aa156a
a9a84f7
b2a27a7
 
 
0dcf605
 
 
 
 
 
 
 
b2a27a7
 
 
 
 
ecfd160
 
b2a27a7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149


import os
import subprocess

# Drop any pyenv version override from this process's environment
# (child processes started below inherit the modified environment).
os.environ.pop("PYENV_VERSION", None)

# Install dependencies at startup. torch and wheel go first; diso is then
# installed straight from GitHub with build isolation disabled
# (presumably because its build needs the already-installed torch -- confirm).
_PIP_INSTALL = ["pip", "install"]
_DISO_REQUIREMENT = "diso@git+https://github.com/SarahWeiii/diso.git"

subprocess.run([*_PIP_INSTALL, "torch", "wheel"], check=True)
subprocess.run(
    [*_PIP_INSTALL, "--no-build-isolation", _DISO_REQUIREMENT],
    check=True,
)

# Imports (moved below the dependency installation step)
import gradio as gr
import uuid
import torch
import zipfile
import requests
import traceback
import trimesh
import numpy as np


from trimesh.exchange.gltf import export_glb

from inference_triposg import run_triposg
from triposg.pipelines.pipeline_triposg import TripoSGPipeline
from briarmbg import BriaRMBG

# NOTE(review): presumably the location of a gltfpack binary; it is not
# referenced anywhere in the visible part of this file -- confirm it is needed.
GLTF_PACK = "/tmp/gltfpack"

print("Trimesh version:", trimesh.__version__)

# Device settings: half precision on CUDA, full precision on CPU.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

# Pretrained weights: both model directories are expected under weights_dir.
weights_dir = "pretrained_weights"
triposg_path = os.path.join(weights_dir, "TripoSG")
rmbg_path = os.path.join(weights_dir, "RMBG-1.4")

# Download and unpack the weights archive only when either directory is missing.
if not (os.path.exists(triposg_path) and os.path.exists(rmbg_path)):
    print("📦 Downloading pretrained weights...")
    url = "https://huggingface.co/datasets/endlesstools/pretrained-assets/resolve/main/pretrained_models.zip"
    zip_path = "pretrained_models.zip"

    try:
        # (connect, read) timeouts: without them a stalled connection would
        # block startup forever. The read timeout applies per socket read,
        # not to the whole transfer, so large archives still download fine.
        with requests.get(url, stream=True, timeout=(10, 120)) as r:
            r.raise_for_status()
            with open(zip_path, "wb") as f:
                # Stream to disk in 1 MiB chunks so the archive is never
                # held in memory in full.
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)

        print("📦 Extracting weights...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(weights_dir)
    finally:
        # Always remove the archive, including a partial one left behind by
        # a failed download or extraction.
        if os.path.exists(zip_path):
            os.remove(zip_path)

    print("✅ Weights ready.")

# Load the models: the TripoSG pipeline is moved to the selected device and
# dtype; the background-removal network is moved to the device only and
# switched to eval mode.
pipe = TripoSGPipeline.from_pretrained(triposg_path)
pipe = pipe.to(device, dtype)

rmbg_net = BriaRMBG.from_pretrained(rmbg_path)
rmbg_net = rmbg_net.to(device)
rmbg_net.eval()


def _normalize_mesh(raw_mesh):
    """Rebuild *raw_mesh* as a cleaned, centered, unit-radius ``trimesh.Trimesh``.

    Raises:
        ValueError: If the mesh is empty after cleanup, or its extent is zero
            or non-finite so it cannot be rescaled.
    """
    # Re-create the Trimesh with process=True so trimesh cleans the geometry
    # on construction.
    mesh = trimesh.Trimesh(vertices=raw_mesh.vertices, faces=raw_mesh.faces, process=True)
    if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
        raise ValueError("Mesh is empty after cleanup")

    # Center the model on its center of mass.
    mesh.apply_translation(-mesh.center_mass)

    # Scale so the farthest vertex lies at distance 1 from the origin
    # (all generated models end up roughly the same size).
    radius = float(np.max(np.linalg.norm(mesh.vertices, axis=1)))
    if not np.isfinite(radius) or radius <= 0.0:
        raise ValueError("Mesh has zero or non-finite extent; cannot normalize scale")
    mesh.apply_scale(1.0 / radius)

    # Recompute normals on the final geometry.
    mesh.fix_normals()
    return mesh


def generate(image_path, face_number=50000, guidance_scale=5.0, num_steps=25, octree_depth=9):
    """Generate a 3D mesh from an image and save it as a .glb file under /tmp.

    Args:
        image_path: Path of the input image, forwarded to ``run_triposg`` as
            ``image_input``.
        face_number: Forwarded to ``run_triposg`` as ``faces`` (cast to int).
        guidance_scale: Forwarded as ``guidance_scale`` (cast to float).
        num_steps: Forwarded as ``num_inference_steps`` (cast to int).
        octree_depth: Forwarded as ``octree_depth`` (cast to int).

    Returns:
        The path of the written .glb file on success, ``None`` if the file is
        missing after the write, or a string starting with ``"Error: "`` when
        no image was provided or any step failed.
    """
    print(f"[INPUT] face_number={face_number}, guidance_scale={guidance_scale}, num_steps={num_steps}, octree_depth={octree_depth}")

    # Guard against a missing image (e.g. submit without an upload):
    # os.path.exists(None) would raise TypeError before any error handling.
    if not image_path:
        print("[ERROR] No input image provided")
        return "Error: no input image provided"

    print("[API CALL] image_path received:", image_path)
    print("[API CALL] File exists:", os.path.exists(image_path))

    temp_id = str(uuid.uuid4())
    output_path = f"/tmp/{temp_id}.glb"
    print("[DEBUG] Generating mesh from:", image_path)

    try:
        mesh = run_triposg(
            pipe=pipe,
            image_input=image_path,
            rmbg_net=rmbg_net,
            seed=42,  # fixed seed, same value on every call
            num_inference_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            faces=int(face_number),
            octree_depth=int(octree_depth),
        )

        if mesh is None or mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
            raise ValueError("Mesh generation returned an empty mesh")

        mesh = _normalize_mesh(mesh)
        print("[DEBUG] Normals shape:", mesh.vertex_normals.shape)

        # Save the GLB.
        glb_data = mesh.export(file_type='glb')
        with open(output_path, "wb") as f:
            f.write(glb_data)

        print(f"[DEBUG] Mesh saved to {output_path}")
        return output_path if os.path.exists(output_path) else None

    except Exception as e:
        # Top-level boundary for the UI/API callback: log the full traceback
        # and report the failure as a string instead of propagating.
        # NOTE(review): the Gradio output component is a file; confirm callers
        # really expect an "Error: ..." string rather than a raised error.
        print("[ERROR]", e)
        traceback.print_exc()
        return f"Error: {e}"


# Gradio interface. The input components are listed in the same order as the
# positional parameters of generate(): image, face count, guidance scale,
# steps, octree depth.
_input_components = [
    gr.Image(type="filepath", label="Upload image"),
    gr.Slider(minimum=10000, maximum=150000, step=10000, value=50000, label="Face count"),
    gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=5.0, label="Guidance Scale"),
    gr.Slider(minimum=10, maximum=100, step=5, value=25, label="Steps"),
    gr.Slider(minimum=6, maximum=9, step=1, value=9, label="Octree Depth"),
]

demo = gr.Interface(
    fn=generate,
    inputs=_input_components,
    outputs=gr.File(label="Download .glb"),
    title="TripoSG Image to 3D",
    description="Upload an image to generate a 3D model (.glb)",
)

# Launch the app.
demo.launch()