mac9087 committed on
Commit
5a23d7c
·
verified ·
1 Parent(s): 44dd3d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -593
app.py CHANGED
@@ -11,11 +11,11 @@ import io
11
  import zipfile
12
  import uuid
13
  import traceback
 
 
14
  from huggingface_hub import snapshot_download
15
  from flask_cors import CORS
16
  import functools
17
- import numpy as np
18
- import trimesh
19
 
20
  app = Flask(__name__)
21
  CORS(app) # Enable CORS for all routes
@@ -43,14 +43,14 @@ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max
43
  processing_jobs = {}
44
 
45
  # Global model variable
46
- neus_model = None
47
  model_loaded = False
48
  model_loading = False
49
 
50
  # Configuration for processing
51
- TIMEOUT_SECONDS = 180 # 3 minutes max for processing
52
  MAX_DIMENSION = 512 # Max image dimension to process
53
- MAX_INFERENCE_STEPS = 32 # Maximum allowed inference steps
54
 
55
  # TimeoutError for handling timeouts
56
  class TimeoutError(Exception):
@@ -104,310 +104,65 @@ def preprocess_image(image_path):
104
  new_width = int(img.width * (MAX_DIMENSION / img.height))
105
  img = img.resize((new_width, new_height), Image.LANCZOS)
106
 
107
- # Convert to RGB and convert to tensor
108
- img_array = np.array(img) / 255.0 # Normalize to [0, 1]
109
- img_tensor = torch.from_numpy(img_array).float().permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
110
- return img_tensor
111
-
112
- # Simple NeuS2-inspired implementation for reconstructing 3D surfaces from images
113
- class NeuS2Model:
114
- def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
115
- self.device = device
116
- self.encoder = self._create_encoder().to(device)
117
- self.volume_network = self._create_volume_network().to(device)
118
-
119
- def _create_encoder(self):
120
- # Simple convolutional encoder
121
- return torch.nn.Sequential(
122
- torch.nn.Conv2d(3, 32, 3, stride=2, padding=1),
123
- torch.nn.ReLU(),
124
- torch.nn.Conv2d(32, 64, 3, stride=2, padding=1),
125
- torch.nn.ReLU(),
126
- torch.nn.Conv2d(64, 128, 3, stride=2, padding=1),
127
- torch.nn.ReLU(),
128
- torch.nn.AdaptiveAvgPool2d((8, 8)),
129
- torch.nn.Flatten(),
130
- torch.nn.Linear(8192, 512)
131
- )
132
-
133
- def _create_volume_network(self):
134
- # MLP to predict occupancy and SDF values
135
- return torch.nn.Sequential(
136
- torch.nn.Linear(515, 256), # 512 features + 3 coordinates
137
- torch.nn.ReLU(),
138
- torch.nn.Linear(256, 256),
139
- torch.nn.ReLU(),
140
- torch.nn.Linear(256, 1) # SDF value
141
- )
142
-
143
- def extract_features(self, image):
144
- with torch.no_grad():
145
- return self.encoder(image.to(self.device))
146
-
147
- def query_points(self, points, features):
148
- # points shape: [batch, num_points, 3]
149
- # features shape: [batch, 512]
150
- batch_size, num_points, _ = points.shape
151
-
152
- # Expand features to match points
153
- features = features.unsqueeze(1).expand(-1, num_points, -1) # [batch, num_points, 512]
154
-
155
- # Concatenate points with features
156
- points_features = torch.cat([points, features], dim=-1) # [batch, num_points, 515]
157
- points_features = points_features.reshape(-1, 515) # [batch*num_points, 515]
158
-
159
- # Query network
160
- with torch.no_grad():
161
- sdf = self.volume_network(points_features.to(self.device))
162
-
163
- return sdf.reshape(batch_size, num_points, 1)
164
-
165
- def generate_mesh(self, image, resolution=64, threshold=0.0, num_steps=16):
166
- # Extract image features
167
- features = self.extract_features(image) # [1, 512]
168
-
169
- # Create grid points
170
- x = torch.linspace(-1, 1, resolution)
171
- y = torch.linspace(-1, 1, resolution)
172
- z = torch.linspace(-1, 1, resolution)
173
- grid_x, grid_y, grid_z = torch.meshgrid(x, y, z, indexing='ij')
174
- points = torch.stack([grid_x, grid_y, grid_z], dim=-1).reshape(1, -1, 3) # [1, res^3, 3]
175
-
176
- # Process in batches to avoid OOM
177
- batch_size = 32768 # Adjust based on available memory
178
- sdf_values = []
179
-
180
- for i in range(0, points.shape[1], batch_size):
181
- batch_points = points[:, i:i+batch_size]
182
- batch_sdf = self.query_points(batch_points, features)
183
- sdf_values.append(batch_sdf)
184
-
185
- sdf_volume = torch.cat(sdf_values, dim=1).reshape(resolution, resolution, resolution).cpu().numpy()
186
-
187
- # Extract mesh - alternative to marching cubes since we don't have scipy
188
- vertices, faces = self._simple_mesh_extraction(sdf_volume, threshold)
189
-
190
- # Create a mesh object with vertices and faces
191
- mesh = type('Mesh', (), {'verts': vertices, 'faces': faces})
192
-
193
- return [mesh] # Returning in list format to match ShapE's output format
194
-
195
- def _simple_mesh_extraction(self, sdf_volume, threshold=0.0):
196
- """Simple mesh extraction without scipy dependency"""
197
- resolution = sdf_volume.shape[0]
198
-
199
- # Find surface points (approximate)
200
- surface_points = []
201
- surface_normals = []
202
-
203
- # Sample points on three orthogonal grids
204
- for axis in range(3):
205
- for i in range(resolution):
206
- for j in range(resolution):
207
- # Create a line along the current axis
208
- line = np.zeros((resolution, 3), dtype=int)
209
- for k in range(resolution):
210
- if axis == 0:
211
- line[k] = [k, i, j]
212
- elif axis == 1:
213
- line[k] = [i, k, j]
214
- else:
215
- line[k] = [i, j, k]
216
-
217
- # Get SDF values along this line
218
- sdf_line = np.array([sdf_volume[tuple(idx)] for idx in line])
219
-
220
- # Find zero crossings
221
- signs = np.sign(sdf_line)
222
- zero_crossings = np.where(np.diff(signs) != 0)[0]
223
-
224
- for idx in zero_crossings:
225
- # Linear interpolation to find more accurate zero crossing
226
- t = sdf_line[idx] / (sdf_line[idx] - sdf_line[idx + 1])
227
- point = line[idx] * (1 - t) + line[idx + 1] * t
228
-
229
- # Normalize to [-1, 1] range
230
- normalized_point = 2 * point / resolution - 1
231
- surface_points.append(normalized_point)
232
-
233
- # Compute normal (gradient of SDF)
234
- normal = np.zeros(3)
235
- idx_3d = tuple(np.round(point).astype(int).clip(0, resolution - 1))
236
-
237
- # Compute gradient using central differences where possible
238
- for d in range(3):
239
- if 0 < idx_3d[d] < resolution - 1:
240
- idx_plus = list(idx_3d)
241
- idx_minus = list(idx_3d)
242
- idx_plus[d] += 1
243
- idx_minus[d] -= 1
244
- normal[d] = (sdf_volume[tuple(idx_plus)] - sdf_volume[tuple(idx_minus)]) / 2
245
- else:
246
- # Forward or backward difference at boundaries
247
- idx_curr = list(idx_3d)
248
- if idx_3d[d] == 0:
249
- idx_other = list(idx_3d)
250
- idx_other[d] = 1
251
- normal[d] = sdf_volume[tuple(idx_other)] - sdf_volume[tuple(idx_curr)]
252
- else:
253
- idx_other = list(idx_3d)
254
- idx_other[d] = resolution - 2
255
- normal[d] = sdf_volume[tuple(idx_curr)] - sdf_volume[tuple(idx_other)]
256
-
257
- if np.linalg.norm(normal) > 0:
258
- normal = normal / np.linalg.norm(normal)
259
- surface_normals.append(normal)
260
-
261
- # Limit the number of points to avoid OOM
262
- max_points = 5000
263
- if len(surface_points) > max_points:
264
- indices = np.random.choice(len(surface_points), max_points, replace=False)
265
- surface_points = [surface_points[i] for i in indices]
266
- surface_normals = [surface_normals[i] for i in indices]
267
-
268
- if len(surface_points) < 4:
269
- # Not enough points found, create a simple cube
270
- vertices = np.array([
271
- [-0.5, -0.5, -0.5],
272
- [0.5, -0.5, -0.5],
273
- [0.5, 0.5, -0.5],
274
- [-0.5, 0.5, -0.5],
275
- [-0.5, -0.5, 0.5],
276
- [0.5, -0.5, 0.5],
277
- [0.5, 0.5, 0.5],
278
- [-0.5, 0.5, 0.5]
279
- ])
280
- faces = np.array([
281
- [0, 1, 2], [0, 2, 3], # Bottom face
282
- [4, 5, 6], [4, 6, 7], # Top face
283
- [0, 1, 5], [0, 5, 4], # Front face
284
- [2, 3, 7], [2, 7, 6], # Back face
285
- [0, 3, 7], [0, 7, 4], # Left face
286
- [1, 2, 6], [1, 6, 5] # Right face
287
- ])
288
- return vertices, faces
289
-
290
- # Convert points to numpy array
291
- points = np.array(surface_points)
292
-
293
- # Create a simple mesh using ball-pivoting like algorithm
294
- vertices = points
295
-
296
- # For simplicity, create faces from nearest neighbors
297
- # This is a very simple approach, not as good as Delaunay but doesn't require scipy
298
- faces = []
299
-
300
- # Use a simple approach to creating faces
301
- # We'll use a greedy algorithm to connect nearby points
302
- n_points = len(vertices)
303
-
304
- # Create a simple connectivity graph
305
- # For each point, find the N closest points
306
- n_neighbors = min(12, n_points - 1)
307
- adjacency = [[] for _ in range(n_points)]
308
-
309
- # Compute all pairwise distances (this is O(n²) but should be ok for small point clouds)
310
- for i in range(n_points):
311
- distances = []
312
- for j in range(n_points):
313
- if i != j:
314
- dist = np.linalg.norm(vertices[i] - vertices[j])
315
- distances.append((dist, j))
316
- distances.sort()
317
- for k in range(min(n_neighbors, len(distances))):
318
- adjacency[i].append(distances[k][1])
319
-
320
- # Create triangles from the adjacency list
321
- added_edges = set()
322
- for i in range(n_points):
323
- for j in adjacency[i]:
324
- if j > i: # Avoid duplicates
325
- edge_ij = (i, j)
326
- added_edges.add(edge_ij)
327
-
328
- # Find common neighbors between i and j to form triangles
329
- common_neighbors = set(adjacency[i]) & set(adjacency[j])
330
- for k in common_neighbors:
331
- if k != i and k != j:
332
- edge_ik = (i, k) if i < k else (k, i)
333
- edge_jk = (j, k) if j < k else (k, j)
334
-
335
- # Check if the other two edges exist
336
- if edge_ik in added_edges and edge_jk in added_edges:
337
- # Ensure consistent winding
338
- normal = np.cross(vertices[j] - vertices[i], vertices[k] - vertices[i])
339
- center_normal = np.mean(surface_normals, axis=0)
340
-
341
- if np.dot(normal, center_normal) > 0:
342
- faces.append([i, j, k])
343
- else:
344
- faces.append([i, k, j])
345
-
346
- # If we couldn't create enough faces, create a simple shape
347
- if len(faces) < 4:
348
- # Create a convex hull-like shape
349
- center = np.mean(vertices, axis=0)
350
- n_points = len(vertices)
351
-
352
- # Sort points by distance from center
353
- dists = np.linalg.norm(vertices - center, axis=1)
354
- sorted_indices = np.argsort(dists)
355
-
356
- # Create a star-like structure connecting to center
357
- center_idx = sorted_indices[0]
358
- faces = []
359
-
360
- for i in range(1, min(n_points, 10)):
361
- if i + 1 < n_points:
362
- faces.append([center_idx, sorted_indices[i], sorted_indices[i + 1]])
363
- else:
364
- faces.append([center_idx, sorted_indices[i], sorted_indices[1]])
365
-
366
- # If still not enough, create a simple cube
367
- if len(faces) < 4:
368
- vertices = np.array([
369
- [-0.5, -0.5, -0.5],
370
- [0.5, -0.5, -0.5],
371
- [0.5, 0.5, -0.5],
372
- [-0.5, 0.5, -0.5],
373
- [-0.5, -0.5, 0.5],
374
- [0.5, -0.5, 0.5],
375
- [0.5, 0.5, 0.5],
376
- [-0.5, 0.5, 0.5]
377
- ])
378
- faces = np.array([
379
- [0, 1, 2], [0, 2, 3], # Bottom face
380
- [4, 5, 6], [4, 6, 7], # Top face
381
- [0, 1, 5], [0, 5, 4], # Front face
382
- [2, 3, 7], [2, 7, 6], # Back face
383
- [0, 3, 7], [0, 7, 4], # Left face
384
- [1, 2, 6], [1, 6, 5] # Right face
385
- ])
386
-
387
- return np.array(vertices), np.array(faces)
388
 
389
  def load_model():
390
- global neus_model, model_loaded, model_loading
391
 
392
  if model_loaded:
393
- return neus_model
394
 
395
  if model_loading:
396
  # Wait for model to load if it's already in progress
397
  while model_loading and not model_loaded:
398
  time.sleep(0.5)
399
- return neus_model
400
 
401
  try:
402
  model_loading = True
403
  print("Starting model loading...")
404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  device = "cuda" if torch.cuda.is_available() else "cpu"
406
- neus_model = NeuS2Model(device=device)
 
 
 
 
 
 
 
 
 
 
 
407
 
408
  model_loaded = True
409
  print(f"Model loaded successfully on {device}")
410
- return neus_model
411
 
412
  except Exception as e:
413
  print(f"Error loading model: {str(e)}")
@@ -420,7 +175,7 @@ def load_model():
420
  def health_check():
421
  return jsonify({
422
  "status": "healthy",
423
- "model": "NeuS2 Image to 3D",
424
  "device": "cuda" if torch.cuda.is_available() else "cpu"
425
  }), 200
426
 
@@ -463,291 +218,6 @@ def progress(job_id):
463
 
464
  return Response(stream_with_context(generate()), mimetype='text/event-stream')
465
 
466
- def export_to_obj(mesh, obj_path):
467
- """Export mesh to OBJ file format"""
468
- vertices = mesh.verts
469
- faces = mesh.faces
470
-
471
- with open(obj_path, 'w') as f:
472
- # Write vertices
473
- for v in vertices:
474
- f.write(f"v {v[0]} {v[1]} {v[2]}\n")
475
-
476
- # Write faces (OBJ uses 1-indexed vertices)
477
- for face in faces:
478
- f.write(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")
479
-
480
- # Create a simple MTL file
481
- mtl_path = obj_path.replace('.obj', '.mtl')
482
- with open(mtl_path, 'w') as f:
483
- f.write("newmtl material0\n")
484
- f.write("Ka 1.0 1.0 1.0\n") # ambient color
485
- f.write("Kd 0.8 0.8 0.8\n") # diffuse color
486
- f.write("Ks 0.0 0.0 0.0\n") # specular color
487
- f.write("Ns 0.0\n") # specular exponent
488
- f.write("illum 2\n") # illumination model
489
-
490
- return obj_path, mtl_path
491
-
492
- @app.route('/convert', methods=['POST'])
493
- def convert_image_to_3d():
494
- # Check if image is in the request
495
- if 'image' not in request.files:
496
- return jsonify({"error": "No image provided"}), 400
497
-
498
- file = request.files['image']
499
- if file.filename == '':
500
- return jsonify({"error": "No image selected"}), 400
501
-
502
- if not allowed_file(file.filename):
503
- return jsonify({"error": f"File type not allowed. Supported types: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
504
-
505
- # Get optional parameters with defaults
506
- try:
507
- guidance_scale = float(request.form.get('guidance_scale', 3.0))
508
- num_inference_steps = min(int(request.form.get('num_inference_steps', 32)), MAX_INFERENCE_STEPS)
509
- output_format = request.form.get('output_format', 'obj').lower()
510
- except ValueError:
511
- return jsonify({"error": "Invalid parameter values"}), 400
512
-
513
- # Validate parameters
514
- if guidance_scale < 1.0 or guidance_scale > 5.0:
515
- return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400
516
-
517
- if num_inference_steps < 16 or num_inference_steps > MAX_INFERENCE_STEPS:
518
- num_inference_steps = min(num_inference_steps, MAX_INFERENCE_STEPS)
519
-
520
- # Validate output format
521
- if output_format not in ['obj', 'glb']:
522
- return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400
523
-
524
- # Create a job ID
525
- job_id = str(uuid.uuid4())
526
- output_dir = os.path.join(RESULTS_FOLDER, job_id)
527
- os.makedirs(output_dir, exist_ok=True)
528
-
529
- # Save the uploaded file
530
- filename = secure_filename(file.filename)
531
- filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}")
532
- file.save(filepath)
533
-
534
- # Initialize job tracking
535
- processing_jobs[job_id] = {
536
- 'status': 'processing',
537
- 'progress': 0,
538
- 'result_url': None,
539
- 'preview_url': None,
540
- 'error': None,
541
- 'output_format': output_format,
542
- 'created_at': time.time()
543
- }
544
-
545
- # Start processing in a separate thread
546
- def process_image():
547
- thread = threading.current_thread()
548
- processing_jobs[job_id]['thread_alive'] = lambda: thread.is_alive()
549
-
550
- try:
551
- # Preprocess image (resize if needed)
552
- processing_jobs[job_id]['progress'] = 5
553
- image_tensor = preprocess_image(filepath)
554
- processing_jobs[job_id]['progress'] = 10
555
-
556
- # Load model
557
- try:
558
- model = load_model()
559
- processing_jobs[job_id]['progress'] = 30
560
- except Exception as e:
561
- processing_jobs[job_id]['status'] = 'error'
562
- processing_jobs[job_id]['error'] = f"Error loading model: {str(e)}"
563
- return
564
-
565
- # Process image with thread-safe timeout
566
- try:
567
- def generate_mesh():
568
- return model.generate_mesh(
569
- image_tensor,
570
- resolution=min(32 + num_inference_steps // 2, 64), # Adjust resolution based on steps
571
- threshold=0.0,
572
- num_steps=num_inference_steps
573
- )
574
-
575
- images, error = process_with_timeout(generate_mesh, [], TIMEOUT_SECONDS)
576
-
577
- if error:
578
- if isinstance(error, TimeoutError):
579
- processing_jobs[job_id]['status'] = 'error'
580
- processing_jobs[job_id]['error'] = f"Processing timed out after {TIMEOUT_SECONDS} seconds"
581
- return
582
- else:
583
- raise error
584
-
585
- processing_jobs[job_id]['progress'] = 80
586
- except Exception as e:
587
- error_details = traceback.format_exc()
588
- processing_jobs[job_id]['status'] = 'error'
589
- processing_jobs[job_id]['error'] = f"Error during processing: {str(e)}"
590
- print(f"Error processing job {job_id}: {str(e)}")
591
- print(error_details)
592
- return
593
-
594
- # Export based on requested format
595
- try:
596
- if output_format == 'obj':
597
- obj_path = os.path.join(output_dir, "model.obj")
598
- obj_path, mtl_path = export_to_obj(images[0], obj_path)
599
-
600
- # Create a zip file with OBJ and MTL
601
- zip_path = os.path.join(output_dir, "model.zip")
602
- with zipfile.ZipFile(zip_path, 'w') as zipf:
603
- zipf.write(obj_path, arcname="model.obj")
604
- if os.path.exists(mtl_path):
605
- zipf.write(mtl_path, arcname="model.mtl")
606
-
607
- processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
608
- processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
609
-
610
- elif output_format == 'glb':
611
- # Convert to trimesh format
612
- vertices = images[0].verts
613
- faces = images[0].faces
614
-
615
- # Create a trimesh object
616
- trimesh_obj = trimesh.Trimesh(vertices=vertices, faces=faces)
617
-
618
- # Export as GLB
619
- glb_path = os.path.join(output_dir, "model.glb")
620
- trimesh_obj.export(glb_path)
621
-
622
- processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
623
- processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
624
-
625
- # Update job status
626
- processing_jobs[job_id]['status'] = 'completed'
627
- processing_jobs[job_id]['progress'] = 100
628
- print(f"Job {job_id} completed successfully")
629
- except Exception as e:
630
- error_details = traceback.format_exc()
631
- processing_jobs[job_id]['status'] = 'error'
632
- processing_jobs[job_id]['error'] = f"Error exporting model: {str(e)}"
633
- print(f"Error exporting model for job {job_id}: {str(e)}")
634
- print(error_details)
635
-
636
- # Clean up temporary file
637
- if os.path.exists(filepath):
638
- os.remove(filepath)
639
-
640
- # Force garbage collection to free memory
641
- gc.collect()
642
- if torch.cuda.is_available():
643
- torch.cuda.empty_cache()
644
-
645
- except Exception as e:
646
- # Handle errors
647
- error_details = traceback.format_exc()
648
- processing_jobs[job_id]['status'] = 'error'
649
- processing_jobs[job_id]['error'] = f"{str(e)}\n{error_details}"
650
- print(f"Error processing job {job_id}: {str(e)}")
651
- print(error_details)
652
-
653
- # Clean up on error
654
- if os.path.exists(filepath):
655
- os.remove(filepath)
656
-
657
- # Start processing thread
658
- processing_thread = threading.Thread(target=process_image)
659
- processing_thread.daemon = True
660
- processing_thread.start()
661
-
662
- # Return job ID immediately
663
- return jsonify({"job_id": job_id}), 202 # 202 Accepted
664
-
665
- @app.route('/download/<job_id>', methods=['GET'])
666
- def download_model(job_id):
667
- if job_id not in processing_jobs or processing_jobs[job_id]['status'] != 'completed':
668
- return jsonify({"error": "Model not found or processing not complete"}), 404
669
-
670
- # Get the output directory for this job
671
- output_dir = os.path.join(RESULTS_FOLDER, job_id)
672
-
673
- # Determine file format from the job data
674
- output_format = processing_jobs[job_id].get('output_format', 'obj')
675
-
676
- if output_format == 'obj':
677
- zip_path = os.path.join(output_dir, "model.zip")
678
- if os.path.exists(zip_path):
679
- return send_file(zip_path, as_attachment=True, download_name="model.zip")
680
- else: # glb
681
- glb_path = os.path.join(output_dir, "model.glb")
682
- if os.path.exists(glb_path):
683
- return send_file(glb_path, as_attachment=True, download_name="model.glb")
684
-
685
- return jsonify({"error": "File not found"}), 404
686
- @app.route('/progress/<job_id>', methods=['GET'])
687
- def progress(job_id):
688
- def generate():
689
- if job_id not in processing_jobs:
690
- yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
691
- return
692
-
693
- job = processing_jobs[job_id]
694
-
695
- # Send initial progress
696
- yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
697
-
698
- # Wait for job to complete or update
699
- last_progress = job['progress']
700
- check_count = 0
701
- while job['status'] == 'processing':
702
- if job['progress'] != last_progress:
703
- yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
704
- last_progress = job['progress']
705
-
706
- time.sleep(0.5)
707
- check_count += 1
708
-
709
- # If client hasn't received updates for a while, check if job is still running
710
- if check_count > 60: # 30 seconds with no updates
711
- if 'thread_alive' in job and not job['thread_alive']():
712
- job['status'] = 'error'
713
- job['error'] = 'Processing thread died unexpectedly'
714
- break
715
- check_count = 0
716
-
717
- # Send final status
718
- if job['status'] == 'completed':
719
- yield f"data: {json.dumps({'status': 'completed', 'progress': 100, 'result_url': job['result_url'], 'preview_url': job['preview_url']})}\n\n"
720
- else:
721
- yield f"data: {json.dumps({'status': 'error', 'error': job['error']})}\n\n"
722
-
723
- return Response(stream_with_context(generate()), mimetype='text/event-stream')
724
-
725
- def export_to_obj(mesh, obj_path):
726
- """Export mesh to OBJ file format"""
727
- vertices = mesh.verts
728
- faces = mesh.faces
729
-
730
- with open(obj_path, 'w') as f:
731
- # Write vertices
732
- for v in vertices:
733
- f.write(f"v {v[0]} {v[1]} {v[2]}\n")
734
-
735
- # Write faces (OBJ uses 1-indexed vertices)
736
- for face in faces:
737
- f.write(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")
738
-
739
- # Create a simple MTL file
740
- mtl_path = obj_path.replace('.obj', '.mtl')
741
- with open(mtl_path, 'w') as f:
742
- f.write("newmtl material0\n")
743
- f.write("Ka 1.0 1.0 1.0\n") # ambient color
744
- f.write("Kd 0.8 0.8 0.8\n") # diffuse color
745
- f.write("Ks 0.0 0.0 0.0\n") # specular color
746
- f.write("Ns 0.0\n") # specular exponent
747
- f.write("illum 2\n") # illumination model
748
-
749
- return obj_path, mtl_path
750
-
751
  @app.route('/convert', methods=['POST'])
752
  def convert_image_to_3d():
753
  # Check if image is in the request
@@ -764,7 +234,7 @@ def convert_image_to_3d():
764
  # Get optional parameters with defaults
765
  try:
766
  guidance_scale = float(request.form.get('guidance_scale', 3.0))
767
- num_inference_steps = min(int(request.form.get('num_inference_steps', 32)), MAX_INFERENCE_STEPS)
768
  output_format = request.form.get('output_format', 'obj').lower()
769
  except ValueError:
770
  return jsonify({"error": "Invalid parameter values"}), 400
@@ -773,7 +243,7 @@ def convert_image_to_3d():
773
  if guidance_scale < 1.0 or guidance_scale > 5.0:
774
  return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400
775
 
776
- if num_inference_steps < 16 or num_inference_steps > MAX_INFERENCE_STEPS:
777
  num_inference_steps = min(num_inference_steps, MAX_INFERENCE_STEPS)
778
 
779
  # Validate output format
@@ -809,12 +279,12 @@ def convert_image_to_3d():
809
  try:
810
  # Preprocess image (resize if needed)
811
  processing_jobs[job_id]['progress'] = 5
812
- image_tensor = preprocess_image(filepath)
813
  processing_jobs[job_id]['progress'] = 10
814
 
815
  # Load model
816
  try:
817
- model = load_model()
818
  processing_jobs[job_id]['progress'] = 30
819
  except Exception as e:
820
  processing_jobs[job_id]['status'] = 'error'
@@ -824,12 +294,12 @@ def convert_image_to_3d():
824
  # Process image with thread-safe timeout
825
  try:
826
  def generate_mesh():
827
- return model.generate_mesh(
828
- image_tensor,
829
- resolution=min(32 + num_inference_steps, 64), # Adjust resolution based on steps
830
- threshold=0.0,
831
- num_steps=num_inference_steps
832
- )
833
 
834
  images, error = process_with_timeout(generate_mesh, [], TIMEOUT_SECONDS)
835
 
@@ -854,12 +324,13 @@ def convert_image_to_3d():
854
  try:
855
  if output_format == 'obj':
856
  obj_path = os.path.join(output_dir, "model.obj")
857
- obj_path, mtl_path = export_to_obj(images[0], obj_path)
858
 
859
  # Create a zip file with OBJ and MTL
860
  zip_path = os.path.join(output_dir, "model.zip")
861
  with zipfile.ZipFile(zip_path, 'w') as zipf:
862
  zipf.write(obj_path, arcname="model.obj")
 
863
  if os.path.exists(mtl_path):
864
  zipf.write(mtl_path, arcname="model.mtl")
865
 
@@ -867,12 +338,13 @@ def convert_image_to_3d():
867
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
868
 
869
  elif output_format == 'glb':
870
- # Convert to trimesh format
871
- vertices = images[0].verts
872
- faces = images[0].faces
 
873
 
874
  # Create a trimesh object
875
- trimesh_obj = trimesh.Trimesh(vertices=vertices, faces=faces)
876
 
877
  # Export as GLB
878
  glb_path = os.path.join(output_dir, "model.glb")
@@ -996,7 +468,7 @@ def cleanup_old_jobs():
996
  @app.route('/', methods=['GET'])
997
  def index():
998
  return jsonify({
999
- "message": "Image to 3D API using NeuS2 is running",
1000
  "endpoints": ["/convert", "/progress/<job_id>", "/download/<job_id>", "/preview/<job_id>"]
1001
  }), 200
1002
 
 
11
  import zipfile
12
  import uuid
13
  import traceback
14
+ from diffusers import ShapEImg2ImgPipeline
15
+ from diffusers.utils import export_to_obj
16
  from huggingface_hub import snapshot_download
17
  from flask_cors import CORS
18
  import functools
 
 
19
 
20
  app = Flask(__name__)
21
  CORS(app) # Enable CORS for all routes
 
43
  processing_jobs = {}
44
 
45
  # Global model variable
46
+ pipe = None
47
  model_loaded = False
48
  model_loading = False
49
 
50
  # Configuration for processing
51
+ TIMEOUT_SECONDS = 300 # 5 minutes max for processing
52
  MAX_DIMENSION = 512 # Max image dimension to process
53
+ MAX_INFERENCE_STEPS = 64 # Maximum allowed inference steps to prevent the index error
54
 
55
  # TimeoutError for handling timeouts
56
  class TimeoutError(Exception):
 
104
  new_width = int(img.width * (MAX_DIMENSION / img.height))
105
  img = img.resize((new_width, new_height), Image.LANCZOS)
106
 
107
+ # Convert to RGB and return
108
+ return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  def load_model():
111
+ global pipe, model_loaded, model_loading
112
 
113
  if model_loaded:
114
+ return pipe
115
 
116
  if model_loading:
117
  # Wait for model to load if it's already in progress
118
  while model_loading and not model_loaded:
119
  time.sleep(0.5)
120
+ return pipe
121
 
122
  try:
123
  model_loading = True
124
  print("Starting model loading...")
125
 
126
+ model_name = "openai/shap-e-img2img"
127
+
128
+ # Download model with retry mechanism
129
+ max_retries = 3
130
+ retry_delay = 5
131
+
132
+ for attempt in range(max_retries):
133
+ try:
134
+ snapshot_download(
135
+ repo_id=model_name,
136
+ cache_dir=CACHE_DIR,
137
+ resume_download=True,
138
+ )
139
+ break
140
+ except Exception as e:
141
+ if attempt < max_retries - 1:
142
+ print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying in {retry_delay} seconds...")
143
+ time.sleep(retry_delay)
144
+ retry_delay *= 2
145
+ else:
146
+ raise
147
+
148
+ # Initialize pipeline with lower precision to save memory
149
  device = "cuda" if torch.cuda.is_available() else "cpu"
150
+ dtype = torch.float16 if device == "cuda" else torch.float32
151
+
152
+ pipe = ShapEImg2ImgPipeline.from_pretrained(
153
+ model_name,
154
+ torch_dtype=dtype,
155
+ cache_dir=CACHE_DIR,
156
+ )
157
+ pipe = pipe.to(device)
158
+
159
+ # Optimize for inference
160
+ if device == "cuda":
161
+ pipe.enable_model_cpu_offload()
162
 
163
  model_loaded = True
164
  print(f"Model loaded successfully on {device}")
165
+ return pipe
166
 
167
  except Exception as e:
168
  print(f"Error loading model: {str(e)}")
 
175
  def health_check():
176
  return jsonify({
177
  "status": "healthy",
178
+ "model": "Shap-E Image to 3D",
179
  "device": "cuda" if torch.cuda.is_available() else "cpu"
180
  }), 200
181
 
 
218
 
219
  return Response(stream_with_context(generate()), mimetype='text/event-stream')
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  @app.route('/convert', methods=['POST'])
222
  def convert_image_to_3d():
223
  # Check if image is in the request
 
234
  # Get optional parameters with defaults
235
  try:
236
  guidance_scale = float(request.form.get('guidance_scale', 3.0))
237
+ num_inference_steps = min(int(request.form.get('num_inference_steps', 64)), MAX_INFERENCE_STEPS)
238
  output_format = request.form.get('output_format', 'obj').lower()
239
  except ValueError:
240
  return jsonify({"error": "Invalid parameter values"}), 400
 
243
  if guidance_scale < 1.0 or guidance_scale > 5.0:
244
  return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400
245
 
246
+ if num_inference_steps < 32 or num_inference_steps > MAX_INFERENCE_STEPS:
247
  num_inference_steps = min(num_inference_steps, MAX_INFERENCE_STEPS)
248
 
249
  # Validate output format
 
279
  try:
280
  # Preprocess image (resize if needed)
281
  processing_jobs[job_id]['progress'] = 5
282
+ image = preprocess_image(filepath)
283
  processing_jobs[job_id]['progress'] = 10
284
 
285
  # Load model
286
  try:
287
+ pipe = load_model()
288
  processing_jobs[job_id]['progress'] = 30
289
  except Exception as e:
290
  processing_jobs[job_id]['status'] = 'error'
 
294
  # Process image with thread-safe timeout
295
  try:
296
  def generate_mesh():
297
+ return pipe(
298
+ image,
299
+ guidance_scale=guidance_scale,
300
+ num_inference_steps=num_inference_steps,
301
+ output_type="mesh",
302
+ ).images
303
 
304
  images, error = process_with_timeout(generate_mesh, [], TIMEOUT_SECONDS)
305
 
 
324
  try:
325
  if output_format == 'obj':
326
  obj_path = os.path.join(output_dir, "model.obj")
327
+ export_to_obj(images[0], obj_path)
328
 
329
  # Create a zip file with OBJ and MTL
330
  zip_path = os.path.join(output_dir, "model.zip")
331
  with zipfile.ZipFile(zip_path, 'w') as zipf:
332
  zipf.write(obj_path, arcname="model.obj")
333
+ mtl_path = os.path.join(output_dir, "model.mtl")
334
  if os.path.exists(mtl_path):
335
  zipf.write(mtl_path, arcname="model.mtl")
336
 
 
338
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
339
 
340
  elif output_format == 'glb':
341
+ from trimesh import Trimesh
342
+ mesh = images[0]
343
+ vertices = mesh.verts
344
+ faces = mesh.faces
345
 
346
  # Create a trimesh object
347
+ trimesh_obj = Trimesh(vertices=vertices, faces=faces)
348
 
349
  # Export as GLB
350
  glb_path = os.path.join(output_dir, "model.glb")
 
468
  @app.route('/', methods=['GET'])
469
  def index():
470
  return jsonify({
471
+ "message": "Image to 3D API is running",
472
  "endpoints": ["/convert", "/progress/<job_id>", "/download/<job_id>", "/preview/<job_id>"]
473
  }), 200
474