Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -14,11 +14,12 @@ from huggingface_hub import snapshot_download
|
|
14 |
from flask_cors import CORS
|
15 |
import numpy as np
|
16 |
import trimesh
|
17 |
-
from
|
18 |
-
import cv2
|
19 |
|
20 |
# Force CPU usage
|
21 |
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
|
|
|
|
22 |
torch.set_default_device("cpu")
|
23 |
torch.cuda.is_available = lambda: False
|
24 |
torch.cuda.device_count = lambda: 0
|
@@ -49,13 +50,13 @@ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
|
|
49 |
processing_jobs = {}
|
50 |
|
51 |
# Global model
|
52 |
-
|
53 |
model_loaded = False
|
54 |
model_loading = False
|
55 |
|
56 |
# Configuration
|
57 |
-
TIMEOUT_SECONDS =
|
58 |
-
MAX_DIMENSION = 256 #
|
59 |
|
60 |
class TimeoutError(Exception):
|
61 |
pass
|
@@ -100,33 +101,26 @@ def preprocess_image(image_path):
|
|
100 |
img = img.convert('RGB')
|
101 |
# Resize to 256x256
|
102 |
img = img.resize((256, 256), Image.LANCZOS)
|
103 |
-
|
104 |
-
# Basic cv2 cleanup
|
105 |
-
img_array = np.array(img)
|
106 |
-
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
107 |
-
_, mask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
|
108 |
-
img_array = cv2.bitwise_and(img_array, img_array, mask=mask)
|
109 |
-
|
110 |
-
return Image.fromarray(img_array)
|
111 |
except Exception as e:
|
112 |
raise Exception(f"Error preprocessing image: {str(e)}")
|
113 |
|
114 |
def load_model():
|
115 |
-
global
|
116 |
|
117 |
if model_loaded:
|
118 |
-
return
|
119 |
|
120 |
if model_loading:
|
121 |
while model_loading and not model_loaded:
|
122 |
time.sleep(0.5)
|
123 |
-
return
|
124 |
|
125 |
try:
|
126 |
model_loading = True
|
127 |
-
print("Loading
|
128 |
|
129 |
-
model_name = "
|
130 |
|
131 |
# Download model
|
132 |
max_retries = 3
|
@@ -147,17 +141,17 @@ def load_model():
|
|
147 |
else:
|
148 |
raise
|
149 |
|
150 |
-
# Load
|
151 |
-
|
152 |
model_name,
|
153 |
cache_dir=CACHE_DIR,
|
154 |
torch_dtype=torch.float32,
|
155 |
)
|
156 |
-
|
157 |
|
158 |
model_loaded = True
|
159 |
-
print("
|
160 |
-
return
|
161 |
|
162 |
except Exception as e:
|
163 |
print(f"Error loading model: {str(e)}")
|
@@ -169,20 +163,17 @@ def load_model():
|
|
169 |
def generate_3d_model(image, detail_level):
|
170 |
try:
|
171 |
# Parameters
|
172 |
-
num_steps = {'low':
|
173 |
steps = num_steps[detail_level]
|
174 |
|
175 |
# Generate 3D model
|
176 |
with torch.no_grad():
|
177 |
-
result =
|
178 |
|
179 |
# Extract mesh
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
vertices = np.array(mesh.vertices)
|
184 |
-
faces = np.array(mesh.faces)
|
185 |
-
vertex_colors = np.array(mesh.vertex_colors) if mesh.vertex_colors is not None else None
|
186 |
|
187 |
trimesh_mesh = trimesh.Trimesh(
|
188 |
vertices=vertices,
|
@@ -201,7 +192,7 @@ def generate_3d_model(image, detail_level):
|
|
201 |
def health_check():
|
202 |
return jsonify({
|
203 |
"status": "healthy",
|
204 |
-
"model": "
|
205 |
"device": "cpu"
|
206 |
}), 200
|
207 |
|
@@ -442,7 +433,7 @@ def model_info(job_id):
|
|
442 |
@app.route('/', methods=['GET'])
|
443 |
def index():
|
444 |
return jsonify({
|
445 |
-
"message": "Image to 3D API (
|
446 |
"endpoints": [
|
447 |
"/convert",
|
448 |
"/progress/<job_id>",
|
@@ -454,7 +445,7 @@ def index():
|
|
454 |
"output_format": "glb or obj",
|
455 |
"detail_level": "low, medium, or high - controls inference steps"
|
456 |
},
|
457 |
-
"description": "Creates 3D models from 2D images using
|
458 |
}), 200
|
459 |
|
460 |
if __name__ == '__main__':
|
|
|
14 |
from flask_cors import CORS
|
15 |
import numpy as np
|
16 |
import trimesh
|
17 |
+
from trellis.pipelines import TrellisImageTo3DPipeline
|
|
|
18 |
|
19 |
# Force CPU usage
|
20 |
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
21 |
+
os.environ["ATTN_BACKEND"] = "native" # Disable xformers/flash-attn
|
22 |
+
os.environ["SPCONV_ALGO"] = "native" # Optimize for CPU
|
23 |
torch.set_default_device("cpu")
|
24 |
torch.cuda.is_available = lambda: False
|
25 |
torch.cuda.device_count = lambda: 0
|
|
|
50 |
processing_jobs = {}
|
51 |
|
52 |
# Global model
|
53 |
+
trellis_pipeline = None
|
54 |
model_loaded = False
|
55 |
model_loading = False
|
56 |
|
57 |
# Configuration
|
58 |
+
TIMEOUT_SECONDS = 360 # 6 minutes for TRELLIS
|
59 |
+
MAX_DIMENSION = 256 # TRELLIS works with smaller images
|
60 |
|
61 |
class TimeoutError(Exception):
|
62 |
pass
|
|
|
101 |
img = img.convert('RGB')
|
102 |
# Resize to 256x256
|
103 |
img = img.resize((256, 256), Image.LANCZOS)
|
104 |
+
return img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
except Exception as e:
|
106 |
raise Exception(f"Error preprocessing image: {str(e)}")
|
107 |
|
108 |
def load_model():
|
109 |
+
global trellis_pipeline, model_loaded, model_loading
|
110 |
|
111 |
if model_loaded:
|
112 |
+
return trellis_pipeline
|
113 |
|
114 |
if model_loading:
|
115 |
while model_loading and not model_loaded:
|
116 |
time.sleep(0.5)
|
117 |
+
return trellis_pipeline
|
118 |
|
119 |
try:
|
120 |
model_loading = True
|
121 |
+
print("Loading TRELLIS-image-large...")
|
122 |
|
123 |
+
model_name = "JeffreyXiang/TRELLIS-image-large"
|
124 |
|
125 |
# Download model
|
126 |
max_retries = 3
|
|
|
141 |
else:
|
142 |
raise
|
143 |
|
144 |
+
# Load TRELLIS pipeline
|
145 |
+
trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained(
|
146 |
model_name,
|
147 |
cache_dir=CACHE_DIR,
|
148 |
torch_dtype=torch.float32,
|
149 |
)
|
150 |
+
trellis_pipeline.to("cpu")
|
151 |
|
152 |
model_loaded = True
|
153 |
+
print("TRELLIS loaded successfully on CPU")
|
154 |
+
return trellis_pipeline
|
155 |
|
156 |
except Exception as e:
|
157 |
print(f"Error loading model: {str(e)}")
|
|
|
163 |
def generate_3d_model(image, detail_level):
|
164 |
try:
|
165 |
# Parameters
|
166 |
+
num_steps = {'low': 50, 'medium': 75, 'high': 100}
|
167 |
steps = num_steps[detail_level]
|
168 |
|
169 |
# Generate 3D model
|
170 |
with torch.no_grad():
|
171 |
+
result = trellis_pipeline(image, num_inference_steps=steps, output_type="mesh")
|
172 |
|
173 |
# Extract mesh
|
174 |
+
vertices = np.array(result.vertices)
|
175 |
+
faces = np.array(result.faces)
|
176 |
+
vertex_colors = np.array(result.vertex_colors) if result.vertex_colors is not None else None
|
|
|
|
|
|
|
177 |
|
178 |
trimesh_mesh = trimesh.Trimesh(
|
179 |
vertices=vertices,
|
|
|
192 |
def health_check():
|
193 |
return jsonify({
|
194 |
"status": "healthy",
|
195 |
+
"model": "TRELLIS-image-large",
|
196 |
"device": "cpu"
|
197 |
}), 200
|
198 |
|
|
|
433 |
@app.route('/', methods=['GET'])
|
434 |
def index():
|
435 |
return jsonify({
|
436 |
+
"message": "Image to 3D API (TRELLIS-image-large)",
|
437 |
"endpoints": [
|
438 |
"/convert",
|
439 |
"/progress/<job_id>",
|
|
|
445 |
"output_format": "glb or obj",
|
446 |
"detail_level": "low, medium, or high - controls inference steps"
|
447 |
},
|
448 |
+
"description": "Creates 3D models from 2D images using TRELLIS-image-large. Use transparent PNGs for best results."
|
449 |
}), 200
|
450 |
|
451 |
if __name__ == '__main__':
|