mac9087 commited on
Commit
13b098e
·
verified ·
1 Parent(s): 64188d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -34
app.py CHANGED
@@ -14,11 +14,12 @@ from huggingface_hub import snapshot_download
14
  from flask_cors import CORS
15
  import numpy as np
16
  import trimesh
17
- from diffusers import DiffusionPipeline
18
- import cv2
19
 
20
  # Force CPU usage
21
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
 
22
  torch.set_default_device("cpu")
23
  torch.cuda.is_available = lambda: False
24
  torch.cuda.device_count = lambda: 0
@@ -49,13 +50,13 @@ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
49
  processing_jobs = {}
50
 
51
  # Global model
52
- tripo_pipeline = None
53
  model_loaded = False
54
  model_loading = False
55
 
56
  # Configuration
57
- TIMEOUT_SECONDS = 300 # 5 minutes for TripoSG
58
- MAX_DIMENSION = 256 # TripoSG works with smaller images
59
 
60
  class TimeoutError(Exception):
61
  pass
@@ -100,33 +101,26 @@ def preprocess_image(image_path):
100
  img = img.convert('RGB')
101
  # Resize to 256x256
102
  img = img.resize((256, 256), Image.LANCZOS)
103
-
104
- # Basic cv2 cleanup
105
- img_array = np.array(img)
106
- gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
107
- _, mask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
108
- img_array = cv2.bitwise_and(img_array, img_array, mask=mask)
109
-
110
- return Image.fromarray(img_array)
111
  except Exception as e:
112
  raise Exception(f"Error preprocessing image: {str(e)}")
113
 
114
  def load_model():
115
- global tripo_pipeline, model_loaded, model_loading
116
 
117
  if model_loaded:
118
- return tripo_pipeline
119
 
120
  if model_loading:
121
  while model_loading and not model_loaded:
122
  time.sleep(0.5)
123
- return tripo_pipeline
124
 
125
  try:
126
  model_loading = True
127
- print("Loading TripoSG model...")
128
 
129
- model_name = "tripo3d/tripo-sg-3d"
130
 
131
  # Download model
132
  max_retries = 3
@@ -147,17 +141,17 @@ def load_model():
147
  else:
148
  raise
149
 
150
- # Load TripoSG pipeline
151
- tripo_pipeline = DiffusionPipeline.from_pretrained(
152
  model_name,
153
  cache_dir=CACHE_DIR,
154
  torch_dtype=torch.float32,
155
  )
156
- tripo_pipeline.to("cpu")
157
 
158
  model_loaded = True
159
- print("TripoSG loaded successfully on CPU")
160
- return tripo_pipeline
161
 
162
  except Exception as e:
163
  print(f"Error loading model: {str(e)}")
@@ -169,20 +163,17 @@ def load_model():
169
  def generate_3d_model(image, detail_level):
170
  try:
171
  # Parameters
172
- num_steps = {'low': 20, 'medium': 30, 'high': 40}
173
  steps = num_steps[detail_level]
174
 
175
  # Generate 3D model
176
  with torch.no_grad():
177
- result = tripo_pipeline(image, num_inference_steps=steps)
178
 
179
  # Extract mesh
180
- mesh = result.meshes[0]
181
-
182
- # Convert to trimesh
183
- vertices = np.array(mesh.vertices)
184
- faces = np.array(mesh.faces)
185
- vertex_colors = np.array(mesh.vertex_colors) if mesh.vertex_colors is not None else None
186
 
187
  trimesh_mesh = trimesh.Trimesh(
188
  vertices=vertices,
@@ -201,7 +192,7 @@ def generate_3d_model(image, detail_level):
201
  def health_check():
202
  return jsonify({
203
  "status": "healthy",
204
- "model": "TripoSG",
205
  "device": "cpu"
206
  }), 200
207
 
@@ -442,7 +433,7 @@ def model_info(job_id):
442
  @app.route('/', methods=['GET'])
443
  def index():
444
  return jsonify({
445
- "message": "Image to 3D API (TripoSG)",
446
  "endpoints": [
447
  "/convert",
448
  "/progress/<job_id>",
@@ -454,7 +445,7 @@ def index():
454
  "output_format": "glb or obj",
455
  "detail_level": "low, medium, or high - controls inference steps"
456
  },
457
- "description": "Creates 3D models from 2D images using TripoSG. Use transparent PNGs for best results."
458
  }), 200
459
 
460
  if __name__ == '__main__':
 
14
  from flask_cors import CORS
15
  import numpy as np
16
  import trimesh
17
+ from trellis.pipelines import TrellisImageTo3DPipeline
 
18
 
19
  # Force CPU usage
20
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
21
+ os.environ["ATTN_BACKEND"] = "native" # Disable xformers/flash-attn
22
+ os.environ["SPCONV_ALGO"] = "native" # Optimize for CPU
23
  torch.set_default_device("cpu")
24
  torch.cuda.is_available = lambda: False
25
  torch.cuda.device_count = lambda: 0
 
50
  processing_jobs = {}
51
 
52
  # Global model
53
+ trellis_pipeline = None
54
  model_loaded = False
55
  model_loading = False
56
 
57
  # Configuration
58
+ TIMEOUT_SECONDS = 360 # 6 minutes for TRELLIS
59
+ MAX_DIMENSION = 256 # TRELLIS works with smaller images
60
 
61
  class TimeoutError(Exception):
62
  pass
 
101
  img = img.convert('RGB')
102
  # Resize to 256x256
103
  img = img.resize((256, 256), Image.LANCZOS)
104
+ return img
 
 
 
 
 
 
 
105
  except Exception as e:
106
  raise Exception(f"Error preprocessing image: {str(e)}")
107
 
108
  def load_model():
109
+ global trellis_pipeline, model_loaded, model_loading
110
 
111
  if model_loaded:
112
+ return trellis_pipeline
113
 
114
  if model_loading:
115
  while model_loading and not model_loaded:
116
  time.sleep(0.5)
117
+ return trellis_pipeline
118
 
119
  try:
120
  model_loading = True
121
+ print("Loading TRELLIS-image-large...")
122
 
123
+ model_name = "JeffreyXiang/TRELLIS-image-large"
124
 
125
  # Download model
126
  max_retries = 3
 
141
  else:
142
  raise
143
 
144
+ # Load TRELLIS pipeline
145
+ trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained(
146
  model_name,
147
  cache_dir=CACHE_DIR,
148
  torch_dtype=torch.float32,
149
  )
150
+ trellis_pipeline.to("cpu")
151
 
152
  model_loaded = True
153
+ print("TRELLIS loaded successfully on CPU")
154
+ return trellis_pipeline
155
 
156
  except Exception as e:
157
  print(f"Error loading model: {str(e)}")
 
163
  def generate_3d_model(image, detail_level):
164
  try:
165
  # Parameters
166
+ num_steps = {'low': 50, 'medium': 75, 'high': 100}
167
  steps = num_steps[detail_level]
168
 
169
  # Generate 3D model
170
  with torch.no_grad():
171
+ result = trellis_pipeline(image, num_inference_steps=steps, output_type="mesh")
172
 
173
  # Extract mesh
174
+ vertices = np.array(result.vertices)
175
+ faces = np.array(result.faces)
176
+ vertex_colors = np.array(result.vertex_colors) if result.vertex_colors is not None else None
 
 
 
177
 
178
  trimesh_mesh = trimesh.Trimesh(
179
  vertices=vertices,
 
192
  def health_check():
193
  return jsonify({
194
  "status": "healthy",
195
+ "model": "TRELLIS-image-large",
196
  "device": "cpu"
197
  }), 200
198
 
 
433
  @app.route('/', methods=['GET'])
434
  def index():
435
  return jsonify({
436
+ "message": "Image to 3D API (TRELLIS-image-large)",
437
  "endpoints": [
438
  "/convert",
439
  "/progress/<job_id>",
 
445
  "output_format": "glb or obj",
446
  "detail_level": "low, medium, or high - controls inference steps"
447
  },
448
+ "description": "Creates 3D models from 2D images using TRELLIS-image-large. Use transparent PNGs for best results."
449
  }), 200
450
 
451
  if __name__ == '__main__':