mac9087 committed on
Commit 44dd3d7 · verified · 1 Parent(s): 74d32e4

Update app.py

Files changed (1): app.py (+430, -29)
app.py CHANGED
@@ -16,7 +16,6 @@ from flask_cors import CORS
 import functools
 import numpy as np
 import trimesh
-from scipy.spatial import Delaunay
 
 app = Flask(__name__)
 CORS(app)  # Enable CORS for all routes
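With this hunk, scipy stops being a hard requirement of app.py. A sketch (not part of the commit) of how Delaunay could instead be kept as an optional fast path, falling back to the scipy-free extraction introduced further down:

    # Hypothetical guard, not in app.py: tolerate a missing scipy install.
    try:
        from scipy.spatial import Delaunay
        HAS_SCIPY = True
    except ImportError:
        Delaunay = None
        HAS_SCIPY = False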
@@ -185,23 +184,89 @@ class NeuS2Model:
 
         sdf_volume = torch.cat(sdf_values, dim=1).reshape(resolution, resolution, resolution).cpu().numpy()
 
-        # Extract mesh using marching cubes
-        vertices, faces = self._marching_cubes(sdf_volume, threshold)
 
         # Create a mesh object with vertices and faces
         mesh = type('Mesh', (), {'verts': vertices, 'faces': faces})
 
         return [mesh]  # Returning in list format to match ShapE's output format
 
-    def _marching_cubes(self, sdf_volume, threshold=0.0):
-        # Simple implementation using surface points and Delaunay triangulation
-        # For production, you'd want to use proper marching cubes from scikit-image
 
-        # Find points near the surface
-        x, y, z = np.where(np.abs(sdf_volume) < 0.1)
 
-        if len(x) < 4:  # Need at least 4 points for Delaunay
-            # Create a simple cube if not enough points
             vertices = np.array([
                 [-0.5, -0.5, -0.5],
                 [0.5, -0.5, -0.5],
@@ -222,27 +287,104 @@ class NeuS2Model:
             ])
             return vertices, faces
 
-        # Convert indices to 3D coordinates in [-1, 1] range
-        res = sdf_volume.shape[0]
-        points = np.stack([
-            2 * x / res - 1,
-            2 * y / res - 1,
-            2 * z / res - 1
-        ], axis=1)
 
-        # Limit to a reasonable number of points for Delaunay
-        max_points = 1000
-        if len(points) > max_points:
-            indices = np.random.choice(len(points), max_points, replace=False)
-            points = points[indices]
 
-        try:
-            # Create triangular mesh using Delaunay
-            tri = Delaunay(points)
-            return points, tri.simplices
-        except Exception:
-            # Fallback to simple shape if Delaunay fails
-            return np.array([[-1, -1, -1], [1, -1, -1], [1, 1, -1], [-1, 1, -1]]), np.array([[0, 1, 2], [0, 2, 3]])
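The removed helper's own comment points at scikit-image's marching cubes as the production-grade replacement. A sketch of that route, assuming scikit-image were added as a dependency (it is not part of this commit):

    # Sketch only, assuming scikit-image were installed.
    from skimage import measure

    def marching_cubes_mesh(sdf_volume, threshold=0.0):
        # skimage returns vertices in voxel coordinates; rescale to the [-1, 1] cube used here
        verts, faces, normals, values = measure.marching_cubes(sdf_volume, level=threshold)
        res = sdf_volume.shape[0]
        verts = 2.0 * verts / (res - 1) - 1.0
        return verts, faces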
 
 def load_model():
     global neus_model, model_loaded, model_loading
@@ -347,6 +489,265 @@ def export_to_obj(mesh, obj_path):
 
     return obj_path, mtl_path
 
 @app.route('/convert', methods=['POST'])
 def convert_image_to_3d():
     # Check if image is in the request
 
@@ -16,7 +16,6 @@ from flask_cors import CORS
 import functools
 import numpy as np
 import trimesh
 
 app = Flask(__name__)
 CORS(app)  # Enable CORS for all routes
 
@@ -185,23 +184,89 @@ class NeuS2Model:
 
         sdf_volume = torch.cat(sdf_values, dim=1).reshape(resolution, resolution, resolution).cpu().numpy()
 
+        # Extract mesh - alternative to marching cubes since we don't have scipy
+        vertices, faces = self._simple_mesh_extraction(sdf_volume, threshold)
 
         # Create a mesh object with vertices and faces
         mesh = type('Mesh', (), {'verts': vertices, 'faces': faces})
 
         return [mesh]  # Returning in list format to match ShapE's output format
 
+    def _simple_mesh_extraction(self, sdf_volume, threshold=0.0):
+        """Simple mesh extraction without scipy dependency"""
+        resolution = sdf_volume.shape[0]
+
+        # Find surface points (approximate)
+        surface_points = []
+        surface_normals = []
+
+        # Sample points on three orthogonal grids
+        for axis in range(3):
+            for i in range(resolution):
+                for j in range(resolution):
+                    # Create a line along the current axis
+                    line = np.zeros((resolution, 3), dtype=int)
+                    for k in range(resolution):
+                        if axis == 0:
+                            line[k] = [k, i, j]
+                        elif axis == 1:
+                            line[k] = [i, k, j]
+                        else:
+                            line[k] = [i, j, k]
+
+                    # Get SDF values along this line
+                    sdf_line = np.array([sdf_volume[tuple(idx)] for idx in line])
+
+                    # Find zero crossings
+                    signs = np.sign(sdf_line)
+                    zero_crossings = np.where(np.diff(signs) != 0)[0]
+
+                    for idx in zero_crossings:
+                        # Linear interpolation to find more accurate zero crossing
+                        t = sdf_line[idx] / (sdf_line[idx] - sdf_line[idx + 1])
+                        point = line[idx] * (1 - t) + line[idx + 1] * t
+
+                        # Normalize to [-1, 1] range
+                        normalized_point = 2 * point / resolution - 1
+                        surface_points.append(normalized_point)
+
+                        # Compute normal (gradient of SDF)
+                        normal = np.zeros(3)
+                        idx_3d = tuple(np.round(point).astype(int).clip(0, resolution - 1))
+
+                        # Compute gradient using central differences where possible
+                        for d in range(3):
+                            if 0 < idx_3d[d] < resolution - 1:
+                                idx_plus = list(idx_3d)
+                                idx_minus = list(idx_3d)
+                                idx_plus[d] += 1
+                                idx_minus[d] -= 1
+                                normal[d] = (sdf_volume[tuple(idx_plus)] - sdf_volume[tuple(idx_minus)]) / 2
+                            else:
+                                # Forward or backward difference at boundaries
+                                idx_curr = list(idx_3d)
+                                if idx_3d[d] == 0:
+                                    idx_other = list(idx_3d)
+                                    idx_other[d] = 1
+                                    normal[d] = sdf_volume[tuple(idx_other)] - sdf_volume[tuple(idx_curr)]
+                                else:
+                                    idx_other = list(idx_3d)
+                                    idx_other[d] = resolution - 2
+                                    normal[d] = sdf_volume[tuple(idx_curr)] - sdf_volume[tuple(idx_other)]
+
+                        if np.linalg.norm(normal) > 0:
+                            normal = normal / np.linalg.norm(normal)
+                        surface_normals.append(normal)
+
+        # Limit the number of points to avoid OOM
+        max_points = 5000
+        if len(surface_points) > max_points:
+            indices = np.random.choice(len(surface_points), max_points, replace=False)
+            surface_points = [surface_points[i] for i in indices]
+            surface_normals = [surface_normals[i] for i in indices]
+
+        if len(surface_points) < 4:
+            # Not enough points found, create a simple cube
             vertices = np.array([
                 [-0.5, -0.5, -0.5],
                 [0.5, -0.5, -0.5],
@@ -222,27 +287,104 @@ class NeuS2Model:
             ])
             return vertices, faces
 
+        # Convert points to numpy array
+        points = np.array(surface_points)
+
+        # Create a simple mesh using ball-pivoting like algorithm
+        vertices = points
+
+        # For simplicity, create faces from nearest neighbors
+        # This is a very simple approach, not as good as Delaunay but doesn't require scipy
+        faces = []
+
+        # Use a simple approach to creating faces
+        # We'll use a greedy algorithm to connect nearby points
+        n_points = len(vertices)
+
+        # Create a simple connectivity graph
+        # For each point, find the N closest points
+        n_neighbors = min(12, n_points - 1)
+        adjacency = [[] for _ in range(n_points)]
+
+        # Compute all pairwise distances (this is O(n²) but should be ok for small point clouds)
+        for i in range(n_points):
+            distances = []
+            for j in range(n_points):
+                if i != j:
+                    dist = np.linalg.norm(vertices[i] - vertices[j])
+                    distances.append((dist, j))
+            distances.sort()
+            for k in range(min(n_neighbors, len(distances))):
+                adjacency[i].append(distances[k][1])
+
+        # Create triangles from the adjacency list
+        added_edges = set()
+        for i in range(n_points):
+            for j in adjacency[i]:
+                if j > i:  # Avoid duplicates
+                    edge_ij = (i, j)
+                    added_edges.add(edge_ij)
+
+                    # Find common neighbors between i and j to form triangles
+                    common_neighbors = set(adjacency[i]) & set(adjacency[j])
+                    for k in common_neighbors:
+                        if k != i and k != j:
+                            edge_ik = (i, k) if i < k else (k, i)
+                            edge_jk = (j, k) if j < k else (k, j)
+
+                            # Check if the other two edges exist
+                            if edge_ik in added_edges and edge_jk in added_edges:
+                                # Ensure consistent winding
+                                normal = np.cross(vertices[j] - vertices[i], vertices[k] - vertices[i])
+                                center_normal = np.mean(surface_normals, axis=0)
+
+                                if np.dot(normal, center_normal) > 0:
+                                    faces.append([i, j, k])
+                                else:
+                                    faces.append([i, k, j])
+
+        # If we couldn't create enough faces, create a simple shape
+        if len(faces) < 4:
+            # Create a convex hull-like shape
+            center = np.mean(vertices, axis=0)
+            n_points = len(vertices)
+
+            # Sort points by distance from center
+            dists = np.linalg.norm(vertices - center, axis=1)
+            sorted_indices = np.argsort(dists)
+
+            # Create a star-like structure connecting to center
+            center_idx = sorted_indices[0]
+            faces = []
+
+            for i in range(1, min(n_points, 10)):
+                if i + 1 < n_points:
+                    faces.append([center_idx, sorted_indices[i], sorted_indices[i + 1]])
+                else:
+                    faces.append([center_idx, sorted_indices[i], sorted_indices[1]])
+
+        # If still not enough, create a simple cube
+        if len(faces) < 4:
+            vertices = np.array([
+                [-0.5, -0.5, -0.5],
+                [0.5, -0.5, -0.5],
+                [0.5, 0.5, -0.5],
+                [-0.5, 0.5, -0.5],
+                [-0.5, -0.5, 0.5],
+                [0.5, -0.5, 0.5],
+                [0.5, 0.5, 0.5],
+                [-0.5, 0.5, 0.5]
+            ])
+            faces = np.array([
+                [0, 1, 2], [0, 2, 3],  # Bottom face
+                [4, 5, 6], [4, 6, 7],  # Top face
+                [0, 1, 5], [0, 5, 4],  # Front face
+                [2, 3, 7], [2, 7, 6],  # Back face
+                [0, 3, 7], [0, 7, 4],  # Left face
+                [1, 2, 6], [1, 6, 5]   # Right face
+            ])
+
+        return np.array(vertices), np.array(faces)
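The _simple_mesh_extraction method added above walks every axis-aligned line of the SDF grid in nested Python loops to find sign changes. The same zero-crossing search can be expressed with array operations, which matters once the grid reaches the 64-cube resolution used later in this commit. A sketch for axis 0 only, mirroring the linear interpolation used in the method (not part of the commit):

    # Sketch only: vectorized zero-crossing scan along axis 0.
    import numpy as np

    def zero_crossings_along_axis0(sdf_volume):
        signs = np.sign(sdf_volume)
        # True wherever the sign flips between voxel k and k+1 along axis 0
        flips = np.diff(signs, axis=0) != 0
        k, i, j = np.nonzero(flips)
        # Linear interpolation factor between the two voxels, as in the loop version
        a = sdf_volume[k, i, j]
        b = sdf_volume[k + 1, i, j]
        t = a / (a - b)
        points = np.stack([k + t, i, j], axis=1)
        res = sdf_volume.shape[0]
        return 2 * points / res - 1  # normalize to [-1, 1], matching the method above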
 
 def load_model():
     global neus_model, model_loaded, model_loading
 
@@ -347,6 +489,265 @@ def export_to_obj(mesh, obj_path):
 
     return obj_path, mtl_path
 
+@app.route('/convert', methods=['POST'])
+def convert_image_to_3d():
+    # Check if image is in the request
+    if 'image' not in request.files:
+        return jsonify({"error": "No image provided"}), 400
+
+    file = request.files['image']
+    if file.filename == '':
+        return jsonify({"error": "No image selected"}), 400
+
+    if not allowed_file(file.filename):
+        return jsonify({"error": f"File type not allowed. Supported types: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
+
+    # Get optional parameters with defaults
+    try:
+        guidance_scale = float(request.form.get('guidance_scale', 3.0))
+        num_inference_steps = min(int(request.form.get('num_inference_steps', 32)), MAX_INFERENCE_STEPS)
+        output_format = request.form.get('output_format', 'obj').lower()
+    except ValueError:
+        return jsonify({"error": "Invalid parameter values"}), 400
+
+    # Validate parameters
+    if guidance_scale < 1.0 or guidance_scale > 5.0:
+        return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400
+
+    if num_inference_steps < 16 or num_inference_steps > MAX_INFERENCE_STEPS:
+        num_inference_steps = min(num_inference_steps, MAX_INFERENCE_STEPS)
+
+    # Validate output format
+    if output_format not in ['obj', 'glb']:
+        return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400
+
+    # Create a job ID
+    job_id = str(uuid.uuid4())
+    output_dir = os.path.join(RESULTS_FOLDER, job_id)
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Save the uploaded file
+    filename = secure_filename(file.filename)
+    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}")
+    file.save(filepath)
+
+    # Initialize job tracking
+    processing_jobs[job_id] = {
+        'status': 'processing',
+        'progress': 0,
+        'result_url': None,
+        'preview_url': None,
+        'error': None,
+        'output_format': output_format,
+        'created_at': time.time()
+    }
+
+    # Start processing in a separate thread
+    def process_image():
+        thread = threading.current_thread()
+        processing_jobs[job_id]['thread_alive'] = lambda: thread.is_alive()
+
+        try:
+            # Preprocess image (resize if needed)
+            processing_jobs[job_id]['progress'] = 5
+            image_tensor = preprocess_image(filepath)
+            processing_jobs[job_id]['progress'] = 10
+
+            # Load model
+            try:
+                model = load_model()
+                processing_jobs[job_id]['progress'] = 30
+            except Exception as e:
+                processing_jobs[job_id]['status'] = 'error'
+                processing_jobs[job_id]['error'] = f"Error loading model: {str(e)}"
+                return
+
+            # Process image with thread-safe timeout
+            try:
+                def generate_mesh():
+                    return model.generate_mesh(
+                        image_tensor,
+                        resolution=min(32 + num_inference_steps // 2, 64),  # Adjust resolution based on steps
+                        threshold=0.0,
+                        num_steps=num_inference_steps
+                    )
+
+                images, error = process_with_timeout(generate_mesh, [], TIMEOUT_SECONDS)
+
+                if error:
+                    if isinstance(error, TimeoutError):
+                        processing_jobs[job_id]['status'] = 'error'
+                        processing_jobs[job_id]['error'] = f"Processing timed out after {TIMEOUT_SECONDS} seconds"
+                        return
+                    else:
+                        raise error
+
+                processing_jobs[job_id]['progress'] = 80
+            except Exception as e:
+                error_details = traceback.format_exc()
+                processing_jobs[job_id]['status'] = 'error'
+                processing_jobs[job_id]['error'] = f"Error during processing: {str(e)}"
+                print(f"Error processing job {job_id}: {str(e)}")
+                print(error_details)
+                return
+
+            # Export based on requested format
+            try:
+                if output_format == 'obj':
+                    obj_path = os.path.join(output_dir, "model.obj")
+                    obj_path, mtl_path = export_to_obj(images[0], obj_path)
+
+                    # Create a zip file with OBJ and MTL
+                    zip_path = os.path.join(output_dir, "model.zip")
+                    with zipfile.ZipFile(zip_path, 'w') as zipf:
+                        zipf.write(obj_path, arcname="model.obj")
+                        if os.path.exists(mtl_path):
+                            zipf.write(mtl_path, arcname="model.mtl")
+
+                    processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
+                    processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
+
+                elif output_format == 'glb':
+                    # Convert to trimesh format
+                    vertices = images[0].verts
+                    faces = images[0].faces
+
+                    # Create a trimesh object
+                    trimesh_obj = trimesh.Trimesh(vertices=vertices, faces=faces)
+
+                    # Export as GLB
+                    glb_path = os.path.join(output_dir, "model.glb")
+                    trimesh_obj.export(glb_path)
+
+                    processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
+                    processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
+
+                # Update job status
+                processing_jobs[job_id]['status'] = 'completed'
+                processing_jobs[job_id]['progress'] = 100
+                print(f"Job {job_id} completed successfully")
+            except Exception as e:
+                error_details = traceback.format_exc()
+                processing_jobs[job_id]['status'] = 'error'
+                processing_jobs[job_id]['error'] = f"Error exporting model: {str(e)}"
+                print(f"Error exporting model for job {job_id}: {str(e)}")
+                print(error_details)
+
+            # Clean up temporary file
+            if os.path.exists(filepath):
+                os.remove(filepath)
+
+            # Force garbage collection to free memory
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        except Exception as e:
+            # Handle errors
+            error_details = traceback.format_exc()
+            processing_jobs[job_id]['status'] = 'error'
+            processing_jobs[job_id]['error'] = f"{str(e)}\n{error_details}"
+            print(f"Error processing job {job_id}: {str(e)}")
+            print(error_details)
+
+            # Clean up on error
+            if os.path.exists(filepath):
+                os.remove(filepath)
+
+    # Start processing thread
+    processing_thread = threading.Thread(target=process_image)
+    processing_thread.daemon = True
+    processing_thread.start()
+
+    # Return job ID immediately
+    return jsonify({"job_id": job_id}), 202  # 202 Accepted
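The route added above accepts a multipart 'image' file plus optional guidance_scale, num_inference_steps and output_format form fields, and immediately answers 202 with a job_id while generation runs on a daemon thread. A minimal client sketch (base URL and file name are placeholders, not taken from the repo):

    # Sketch only: submitting a job to the /convert endpoint added above.
    import requests

    BASE_URL = "http://localhost:7860"  # assumption: wherever this Flask app is served

    with open("chair.jpg", "rb") as fh:
        resp = requests.post(
            f"{BASE_URL}/convert",
            files={"image": fh},
            data={"guidance_scale": "3.0", "num_inference_steps": "32", "output_format": "glb"},
        )
    resp.raise_for_status()
    job_id = resp.json()["job_id"]  # returned with HTTP 202 while work continues in the background
    print("job:", job_id)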
+
+@app.route('/download/<job_id>', methods=['GET'])
+def download_model(job_id):
+    if job_id not in processing_jobs or processing_jobs[job_id]['status'] != 'completed':
+        return jsonify({"error": "Model not found or processing not complete"}), 404
+
+    # Get the output directory for this job
+    output_dir = os.path.join(RESULTS_FOLDER, job_id)
+
+    # Determine file format from the job data
+    output_format = processing_jobs[job_id].get('output_format', 'obj')
+
+    if output_format == 'obj':
+        zip_path = os.path.join(output_dir, "model.zip")
+        if os.path.exists(zip_path):
+            return send_file(zip_path, as_attachment=True, download_name="model.zip")
+    else:  # glb
+        glb_path = os.path.join(output_dir, "model.glb")
+        if os.path.exists(glb_path):
+            return send_file(glb_path, as_attachment=True, download_name="model.glb")
+
+    return jsonify({"error": "File not found"}), 404
+@app.route('/progress/<job_id>', methods=['GET'])
+def progress(job_id):
+    def generate():
+        if job_id not in processing_jobs:
+            yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
+            return
+
+        job = processing_jobs[job_id]
+
+        # Send initial progress
+        yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
+
+        # Wait for job to complete or update
+        last_progress = job['progress']
+        check_count = 0
+        while job['status'] == 'processing':
+            if job['progress'] != last_progress:
+                yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
+                last_progress = job['progress']
+
+            time.sleep(0.5)
+            check_count += 1
+
+            # If client hasn't received updates for a while, check if job is still running
+            if check_count > 60:  # 30 seconds with no updates
+                if 'thread_alive' in job and not job['thread_alive']():
+                    job['status'] = 'error'
+                    job['error'] = 'Processing thread died unexpectedly'
+                    break
+                check_count = 0
+
+        # Send final status
+        if job['status'] == 'completed':
+            yield f"data: {json.dumps({'status': 'completed', 'progress': 100, 'result_url': job['result_url'], 'preview_url': job['preview_url']})}\n\n"
+        else:
+            yield f"data: {json.dumps({'status': 'error', 'error': job['error']})}\n\n"
+
+    return Response(stream_with_context(generate()), mimetype='text/event-stream')
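The /progress/<job_id> route added above streams Server-Sent Events, one "data: {...}" JSON payload per update, until the job completes or errors. A sketch of consuming that stream from Python (placeholder base URL; requests is an assumed client-side dependency):

    # Sketch only: reading the SSE stream emitted by /progress/<job_id>.
    import json
    import requests

    def wait_for_job(base_url, job_id):
        with requests.get(f"{base_url}/progress/{job_id}", stream=True) as resp:
            for raw in resp.iter_lines(decode_unicode=True):
                if not raw or not raw.startswith("data: "):
                    continue  # skip the blank separators between events
                event = json.loads(raw[len("data: "):])
                print(event)
                if event.get("status") in ("completed", "error"):
                    return event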
+
+def export_to_obj(mesh, obj_path):
+    """Export mesh to OBJ file format"""
+    vertices = mesh.verts
+    faces = mesh.faces
+
+    with open(obj_path, 'w') as f:
+        # Write vertices
+        for v in vertices:
+            f.write(f"v {v[0]} {v[1]} {v[2]}\n")
+
+        # Write faces (OBJ uses 1-indexed vertices)
+        for face in faces:
+            f.write(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")
+
+    # Create a simple MTL file
+    mtl_path = obj_path.replace('.obj', '.mtl')
+    with open(mtl_path, 'w') as f:
+        f.write("newmtl material0\n")
+        f.write("Ka 1.0 1.0 1.0\n")  # ambient color
+        f.write("Kd 0.8 0.8 0.8\n")  # diffuse color
+        f.write("Ks 0.0 0.0 0.0\n")  # specular color
+        f.write("Ns 0.0\n")  # specular exponent
+        f.write("illum 2\n")  # illumination model
+
+    return obj_path, mtl_path
+
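The export_to_obj added above writes raw "v" and "f" records by hand. One optional way to sanity-check its output is to reload the file with trimesh, which app.py already imports; the path below is a placeholder:

    # Sketch only: reload the written OBJ and inspect it.
    import trimesh

    checked = trimesh.load("model.obj", force="mesh")
    print(checked.vertices.shape, checked.faces.shape, "watertight:", checked.is_watertight)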
 @app.route('/convert', methods=['POST'])
 def convert_image_to_3d():
     # Check if image is in the request
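If the trailing context lines are unchanged (as the diff suggests), app.py now carries two @app.route('/convert') registrations whose view functions are both named convert_image_to_3d, plus a second export_to_obj definition that silently replaces the first. Flask derives endpoint names from the function name, so the later /convert registration raises an AssertionError ("View function mapping is overwriting an existing endpoint function") when the module is imported. A minimal sketch of keeping both handlers importable, assuming one is renamed rather than removed:

    # Hypothetical rename, not in the commit: give the older handler its own path and endpoint.
    @app.route('/convert-legacy', methods=['POST'], endpoint='convert_image_to_3d_legacy')
    def convert_image_to_3d_legacy():
        ...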