mac9087 commited on
Commit
db29a74
·
verified ·
1 Parent(s): e4c93be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -111
app.py CHANGED
@@ -11,11 +11,11 @@ import io
11
  import zipfile
12
  import uuid
13
  import traceback
14
- from transformers import AutoImageProcessor, AutoModel
15
  from huggingface_hub import snapshot_download
16
  from flask_cors import CORS
17
  import numpy as np
18
  import trimesh
 
19
 
20
  app = Flask(__name__)
21
  CORS(app) # Enable CORS for all routes
@@ -43,8 +43,7 @@ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max
43
  processing_jobs = {}
44
 
45
  # Global model variables
46
- image_processor = None
47
- model = None
48
  model_loaded = False
49
  model_loading = False
50
 
@@ -107,23 +106,23 @@ def preprocess_image(image_path):
107
  return img
108
 
109
  def load_model():
110
- global image_processor, model, model_loaded, model_loading
111
 
112
  if model_loaded:
113
- return image_processor, model
114
 
115
  if model_loading:
116
  # Wait for model to load if it's already in progress
117
  while model_loading and not model_loaded:
118
  time.sleep(0.5)
119
- return image_processor, model
120
 
121
  try:
122
  model_loading = True
123
  print("Starting model loading...")
124
 
125
- # Using a lightweight model: Pictorial 3D Scene Representation
126
- model_name = "damo-vilab/text-to-3d-texture-base" # Smaller model than ShapE-img2img
127
 
128
  # Download model with retry mechanism
129
  max_retries = 3
@@ -147,24 +146,22 @@ def load_model():
147
 
148
  # Initialize model with lower precision to save memory
149
  device = "cuda" if torch.cuda.is_available() else "cpu"
150
- dtype = torch.float16 if device == "cuda" else torch.float32
151
 
152
- image_processor = AutoImageProcessor.from_pretrained(model_name, cache_dir=CACHE_DIR)
153
- model = AutoModel.from_pretrained(
154
- model_name,
155
- torch_dtype=dtype,
156
- cache_dir=CACHE_DIR,
157
- low_cpu_mem_usage=True,
158
  )
159
- model = model.to(device)
160
 
161
- # Optimize for inference
162
  if device == "cuda":
163
- model = model.half() # Use half precision on GPU
164
 
165
  model_loaded = True
166
  print(f"Model loaded successfully on {device}")
167
- return image_processor, model
168
 
169
  except Exception as e:
170
  print(f"Error loading model: {str(e)}")
@@ -173,89 +170,62 @@ def load_model():
173
  finally:
174
  model_loading = False
175
 
176
- # Convert model output to 3D mesh
177
- def create_mesh_from_output(output, resolution=64):
178
- """Create a mesh from model output"""
179
- # Extract features from model output and create mesh
180
- # This is a simplified implementation - adapt based on your specific model
181
- features = output.last_hidden_state.detach().cpu().numpy()[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
- # Create a simple cube mesh as placeholder - replace with actual mesh generation
184
- vertices, faces = create_primitive_mesh(features, resolution)
185
 
 
186
  mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
 
 
 
 
 
 
187
  return mesh
188
 
189
- def create_primitive_mesh(features, resolution=64):
190
- """Create a simple primitive mesh based on features"""
191
- # Create a mesh using features as modifiers
192
- # This is a simplified implementation - adapt based on your specific model's output
193
-
194
- # Create a cube/sphere mesh as a placeholder
195
- use_sphere = True # Change to False for cube
196
-
197
- if use_sphere:
198
- # Create a sphere
199
- u = np.linspace(0, 2 * np.pi, resolution)
200
- v = np.linspace(0, np.pi, resolution)
201
-
202
- # Base radius and modifiers
203
- base_radius = 1.0
204
-
205
- # Use some features to modify the radius (just as an example)
206
- feature_sum = np.sum(features[:10]) # Use first 10 features
207
- radius_mod = 0.5 + (feature_sum % 1.0) # Simple modifier between 0.5 and 1.5
208
-
209
- # Create vertices
210
- x = base_radius * radius_mod * np.outer(np.cos(u), np.sin(v))
211
- y = base_radius * radius_mod * np.outer(np.sin(u), np.sin(v))
212
- z = base_radius * radius_mod * np.outer(np.ones_like(u), np.cos(v))
213
-
214
- # Reshape to get list of vertices
215
- vertices = np.vstack([x.flatten(), y.flatten(), z.flatten()]).T
216
-
217
- # Create faces (triangles)
218
- faces = []
219
- for i in range(resolution-1):
220
- for j in range(resolution-1):
221
- p1 = i * resolution + j
222
- p2 = i * resolution + (j + 1)
223
- p3 = (i + 1) * resolution + j
224
- p4 = (i + 1) * resolution + (j + 1)
225
-
226
- faces.append([p1, p2, p4])
227
- faces.append([p1, p4, p3])
228
-
229
- faces = np.array(faces)
230
- else:
231
- # Create a cube
232
- vertices = np.array([
233
- [-1, -1, -1], [1, -1, -1], [1, 1, -1], [-1, 1, -1],
234
- [-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]
235
- ])
236
-
237
- # Apply some feature-based modifications
238
- feature_sum = np.sum(features[:10]) # Use first 10 features
239
- scale_factor = 0.5 + (feature_sum % 1.0) # Simple modifier between 0.5 and 1.5
240
- vertices *= scale_factor
241
-
242
- # Faces (triangles)
243
- faces = np.array([
244
- [0, 1, 2], [0, 2, 3], # Bottom face
245
- [4, 5, 6], [4, 6, 7], # Top face
246
- [0, 1, 5], [0, 5, 4], # Front face
247
- [2, 3, 7], [2, 7, 6], # Back face
248
- [0, 3, 7], [0, 7, 4], # Left face
249
- [1, 2, 6], [1, 6, 5] # Right face
250
- ])
251
-
252
- return vertices, faces
253
-
254
  @app.route('/health', methods=['GET'])
255
  def health_check():
256
  return jsonify({
257
  "status": "healthy",
258
- "model": "Lightweight 3D Model Generator",
259
  "device": "cuda" if torch.cuda.is_available() else "cpu"
260
  }), 200
261
 
@@ -313,15 +283,11 @@ def convert_image_to_3d():
313
 
314
  # Get optional parameters with defaults
315
  try:
316
- guidance_scale = float(request.form.get('guidance_scale', 3.0))
317
  output_format = request.form.get('output_format', 'obj').lower()
318
  except ValueError:
319
  return jsonify({"error": "Invalid parameter values"}), 400
320
 
321
- # Validate parameters
322
- if guidance_scale < 1.0 or guidance_scale > 5.0:
323
- return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400
324
-
325
  # Validate output format
326
  if output_format not in ['obj', 'glb']:
327
  return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400
@@ -360,7 +326,7 @@ def convert_image_to_3d():
360
 
361
  # Load model
362
  try:
363
- processor, model_instance = load_model()
364
  processing_jobs[job_id]['progress'] = 30
365
  except Exception as e:
366
  processing_jobs[job_id]['status'] = 'error'
@@ -369,18 +335,20 @@ def convert_image_to_3d():
369
 
370
  # Process image with thread-safe timeout
371
  try:
372
- def generate_3d():
373
- # Process the image
374
- device = model_instance.device
375
- inputs = processor(images=image, return_tensors="pt").to(device)
376
 
377
- # Forward pass through model
378
- with torch.no_grad():
379
- outputs = model_instance(**inputs)
 
 
380
 
381
- return outputs
382
 
383
- outputs, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
384
 
385
  if error:
386
  if isinstance(error, TimeoutError):
@@ -390,10 +358,11 @@ def convert_image_to_3d():
390
  else:
391
  raise error
392
 
393
- processing_jobs[job_id]['progress'] = 80
394
 
395
- # Create mesh from outputs
396
- mesh = create_mesh_from_output(outputs)
 
397
 
398
  except Exception as e:
399
  error_details = traceback.format_exc()
 
11
  import zipfile
12
  import uuid
13
  import traceback
 
14
  from huggingface_hub import snapshot_download
15
  from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
+ from transformers import pipeline
19
 
20
  app = Flask(__name__)
21
  CORS(app) # Enable CORS for all routes
 
43
  processing_jobs = {}
44
 
45
  # Global model variables
46
+ depth_estimator = None
 
47
  model_loaded = False
48
  model_loading = False
49
 
 
106
  return img
107
 
108
  def load_model():
109
+ global depth_estimator, model_loaded, model_loading
110
 
111
  if model_loaded:
112
+ return depth_estimator
113
 
114
  if model_loading:
115
  # Wait for model to load if it's already in progress
116
  while model_loading and not model_loaded:
117
  time.sleep(0.5)
118
+ return depth_estimator
119
 
120
  try:
121
  model_loading = True
122
  print("Starting model loading...")
123
 
124
+ # Using DPT-Hybrid which is smaller than other depth estimation models
125
+ model_name = "Intel/dpt-hybrid-midas"
126
 
127
  # Download model with retry mechanism
128
  max_retries = 3
 
146
 
147
  # Initialize model with lower precision to save memory
148
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
149
 
150
+ # Load depth estimator pipeline
151
+ depth_estimator = pipeline(
152
+ "depth-estimation",
153
+ model=model_name,
154
+ device=device if device == "cuda" else -1,
155
+ cache_dir=CACHE_DIR
156
  )
 
157
 
158
+ # Optimize memory usage
159
  if device == "cuda":
160
+ torch.cuda.empty_cache()
161
 
162
  model_loaded = True
163
  print(f"Model loaded successfully on {device}")
164
+ return depth_estimator
165
 
166
  except Exception as e:
167
  print(f"Error loading model: {str(e)}")
 
170
  finally:
171
  model_loading = False
172
 
173
+ # Convert depth map to 3D mesh
174
+ def depth_to_mesh(depth_map, image, resolution=100):
175
+ """Convert depth map to 3D mesh"""
176
+ # Get dimensions
177
+ h, w = depth_map.shape
178
+
179
+ # Create a grid of points
180
+ x = np.linspace(0, w-1, resolution)
181
+ y = np.linspace(0, h-1, resolution)
182
+ x_grid, y_grid = np.meshgrid(x, y)
183
+
184
+ # Sample depth at grid points
185
+ x_indices = x_grid.astype(int)
186
+ y_indices = y_grid.astype(int)
187
+ z_values = depth_map[y_indices, x_indices]
188
+
189
+ # Normalize depth values to suitable range
190
+ z_min, z_max = z_values.min(), z_values.max()
191
+ z_values = (z_values - z_min) / (z_max - z_min) * 2.0 # Map to 0-2 range
192
+
193
+ # Normalize x and y coordinates
194
+ x_grid = (x_grid / w - 0.5) * 2.0 # Map to -1 to 1
195
+ y_grid = (y_grid / h - 0.5) * 2.0 # Map to -1 to 1
196
+
197
+ # Create vertices
198
+ vertices = np.vstack([x_grid.flatten(), -y_grid.flatten(), -z_values.flatten()]).T
199
+
200
+ # Create faces (triangles)
201
+ faces = []
202
+ for i in range(resolution-1):
203
+ for j in range(resolution-1):
204
+ p1 = i * resolution + j
205
+ p2 = i * resolution + (j + 1)
206
+ p3 = (i + 1) * resolution + j
207
+ p4 = (i + 1) * resolution + (j + 1)
208
+
209
+ faces.append([p1, p2, p4])
210
+ faces.append([p1, p4, p3])
211
 
212
+ faces = np.array(faces)
 
213
 
214
+ # Create mesh
215
  mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
216
+
217
+ # Optional: Apply texture from original image
218
+ if image:
219
+ # This is simplified - proper UV mapping would be needed for accurate texturing
220
+ pass
221
+
222
  return mesh
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  @app.route('/health', methods=['GET'])
225
  def health_check():
226
  return jsonify({
227
  "status": "healthy",
228
+ "model": "Depth-Based 3D Model Generator",
229
  "device": "cuda" if torch.cuda.is_available() else "cpu"
230
  }), 200
231
 
 
283
 
284
  # Get optional parameters with defaults
285
  try:
286
+ mesh_resolution = min(int(request.form.get('mesh_resolution', 100)), 200) # Limit max resolution
287
  output_format = request.form.get('output_format', 'obj').lower()
288
  except ValueError:
289
  return jsonify({"error": "Invalid parameter values"}), 400
290
 
 
 
 
 
291
  # Validate output format
292
  if output_format not in ['obj', 'glb']:
293
  return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400
 
326
 
327
  # Load model
328
  try:
329
+ model = load_model()
330
  processing_jobs[job_id]['progress'] = 30
331
  except Exception as e:
332
  processing_jobs[job_id]['status'] = 'error'
 
335
 
336
  # Process image with thread-safe timeout
337
  try:
338
+ def estimate_depth():
339
+ # Get depth map
340
+ result = model(image)
341
+ depth_map = result["depth"]
342
 
343
+ # Convert to numpy array if needed
344
+ if isinstance(depth_map, torch.Tensor):
345
+ depth_map = depth_map.cpu().numpy()
346
+ elif hasattr(depth_map, 'numpy'):
347
+ depth_map = depth_map.numpy()
348
 
349
+ return depth_map
350
 
351
+ depth_map, error = process_with_timeout(estimate_depth, [], TIMEOUT_SECONDS)
352
 
353
  if error:
354
  if isinstance(error, TimeoutError):
 
358
  else:
359
  raise error
360
 
361
+ processing_jobs[job_id]['progress'] = 60
362
 
363
+ # Create mesh from depth map
364
+ mesh = depth_to_mesh(depth_map, image, resolution=mesh_resolution)
365
+ processing_jobs[job_id]['progress'] = 80
366
 
367
  except Exception as e:
368
  error_details = traceback.format_exc()