mac9087 committed on
Commit
74d32e4
·
verified ·
1 Parent(s): 2057821

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -65
app.py CHANGED
@@ -11,11 +11,12 @@ import io
11
  import zipfile
12
  import uuid
13
  import traceback
14
- from diffusers import ShapEImg2ImgPipeline
15
- from diffusers.utils import export_to_obj
16
  from huggingface_hub import snapshot_download
17
  from flask_cors import CORS
18
  import functools
 
 
 
19
 
20
  app = Flask(__name__)
21
  CORS(app) # Enable CORS for all routes
@@ -43,14 +44,14 @@ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max
43
  processing_jobs = {}
44
 
45
  # Global model variable
46
- pipe = None
47
  model_loaded = False
48
  model_loading = False
49
 
50
  # Configuration for processing
51
- TIMEOUT_SECONDS = 300 # 5 minutes max for processing
52
  MAX_DIMENSION = 512 # Max image dimension to process
53
- MAX_INFERENCE_STEPS = 64 # Maximum allowed inference steps to prevent the index error
54
 
55
  # TimeoutError for handling timeouts
56
  class TimeoutError(Exception):
@@ -104,65 +105,167 @@ def preprocess_image(image_path):
104
  new_width = int(img.width * (MAX_DIMENSION / img.height))
105
  img = img.resize((new_width, new_height), Image.LANCZOS)
106
 
107
- # Convert to RGB and return
108
- return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  def load_model():
111
- global pipe, model_loaded, model_loading
112
 
113
  if model_loaded:
114
- return pipe
115
 
116
  if model_loading:
117
  # Wait for model to load if it's already in progress
118
  while model_loading and not model_loaded:
119
  time.sleep(0.5)
120
- return pipe
121
 
122
  try:
123
  model_loading = True
124
  print("Starting model loading...")
125
 
126
- model_name = "openai/shap-e-img2img"
127
-
128
- # Download model with retry mechanism
129
- max_retries = 3
130
- retry_delay = 5
131
-
132
- for attempt in range(max_retries):
133
- try:
134
- snapshot_download(
135
- repo_id=model_name,
136
- cache_dir=CACHE_DIR,
137
- resume_download=True,
138
- )
139
- break
140
- except Exception as e:
141
- if attempt < max_retries - 1:
142
- print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying in {retry_delay} seconds...")
143
- time.sleep(retry_delay)
144
- retry_delay *= 2
145
- else:
146
- raise
147
-
148
- # Initialize pipeline with lower precision to save memory
149
  device = "cuda" if torch.cuda.is_available() else "cpu"
150
- dtype = torch.float16 if device == "cuda" else torch.float32
151
-
152
- pipe = ShapEImg2ImgPipeline.from_pretrained(
153
- model_name,
154
- torch_dtype=dtype,
155
- cache_dir=CACHE_DIR,
156
- )
157
- pipe = pipe.to(device)
158
-
159
- # Optimize for inference
160
- if device == "cuda":
161
- pipe.enable_model_cpu_offload()
162
 
163
  model_loaded = True
164
  print(f"Model loaded successfully on {device}")
165
- return pipe
166
 
167
  except Exception as e:
168
  print(f"Error loading model: {str(e)}")
@@ -175,7 +278,7 @@ def load_model():
175
  def health_check():
176
  return jsonify({
177
  "status": "healthy",
178
- "model": "Shap-E Image to 3D",
179
  "device": "cuda" if torch.cuda.is_available() else "cpu"
180
  }), 200
181
 
@@ -218,6 +321,32 @@ def progress(job_id):
218
 
219
  return Response(stream_with_context(generate()), mimetype='text/event-stream')
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  @app.route('/convert', methods=['POST'])
222
  def convert_image_to_3d():
223
  # Check if image is in the request
@@ -234,7 +363,7 @@ def convert_image_to_3d():
234
  # Get optional parameters with defaults
235
  try:
236
  guidance_scale = float(request.form.get('guidance_scale', 3.0))
237
- num_inference_steps = min(int(request.form.get('num_inference_steps', 64)), MAX_INFERENCE_STEPS)
238
  output_format = request.form.get('output_format', 'obj').lower()
239
  except ValueError:
240
  return jsonify({"error": "Invalid parameter values"}), 400
@@ -243,7 +372,7 @@ def convert_image_to_3d():
243
  if guidance_scale < 1.0 or guidance_scale > 5.0:
244
  return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400
245
 
246
- if num_inference_steps < 32 or num_inference_steps > MAX_INFERENCE_STEPS:
247
  num_inference_steps = min(num_inference_steps, MAX_INFERENCE_STEPS)
248
 
249
  # Validate output format
@@ -279,12 +408,12 @@ def convert_image_to_3d():
279
  try:
280
  # Preprocess image (resize if needed)
281
  processing_jobs[job_id]['progress'] = 5
282
- image = preprocess_image(filepath)
283
  processing_jobs[job_id]['progress'] = 10
284
 
285
  # Load model
286
  try:
287
- pipe = load_model()
288
  processing_jobs[job_id]['progress'] = 30
289
  except Exception as e:
290
  processing_jobs[job_id]['status'] = 'error'
@@ -294,12 +423,12 @@ def convert_image_to_3d():
294
  # Process image with thread-safe timeout
295
  try:
296
  def generate_mesh():
297
- return pipe(
298
- image,
299
- guidance_scale=guidance_scale,
300
- num_inference_steps=num_inference_steps,
301
- output_type="mesh",
302
- ).images
303
 
304
  images, error = process_with_timeout(generate_mesh, [], TIMEOUT_SECONDS)
305
 
@@ -324,13 +453,12 @@ def convert_image_to_3d():
324
  try:
325
  if output_format == 'obj':
326
  obj_path = os.path.join(output_dir, "model.obj")
327
- export_to_obj(images[0], obj_path)
328
 
329
  # Create a zip file with OBJ and MTL
330
  zip_path = os.path.join(output_dir, "model.zip")
331
  with zipfile.ZipFile(zip_path, 'w') as zipf:
332
  zipf.write(obj_path, arcname="model.obj")
333
- mtl_path = os.path.join(output_dir, "model.mtl")
334
  if os.path.exists(mtl_path):
335
  zipf.write(mtl_path, arcname="model.mtl")
336
 
@@ -338,13 +466,12 @@ def convert_image_to_3d():
338
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
339
 
340
  elif output_format == 'glb':
341
- from trimesh import Trimesh
342
- mesh = images[0]
343
- vertices = mesh.verts
344
- faces = mesh.faces
345
 
346
  # Create a trimesh object
347
- trimesh_obj = Trimesh(vertices=vertices, faces=faces)
348
 
349
  # Export as GLB
350
  glb_path = os.path.join(output_dir, "model.glb")
@@ -468,7 +595,7 @@ def cleanup_old_jobs():
468
  @app.route('/', methods=['GET'])
469
  def index():
470
  return jsonify({
471
- "message": "Image to 3D API is running",
472
  "endpoints": ["/convert", "/progress/<job_id>", "/download/<job_id>", "/preview/<job_id>"]
473
  }), 200
474
 
 
11
  import zipfile
12
  import uuid
13
  import traceback
 
 
14
  from huggingface_hub import snapshot_download
15
  from flask_cors import CORS
16
  import functools
17
+ import numpy as np
18
+ import trimesh
19
+ from scipy.spatial import Delaunay
20
 
21
  app = Flask(__name__)
22
  CORS(app) # Enable CORS for all routes
 
44
  processing_jobs = {}
45
 
46
  # Global model variable
47
+ neus_model = None
48
  model_loaded = False
49
  model_loading = False
50
 
51
  # Configuration for processing
52
+ TIMEOUT_SECONDS = 180 # 3 minutes max for processing
53
  MAX_DIMENSION = 512 # Max image dimension to process
54
+ MAX_INFERENCE_STEPS = 32 # Maximum allowed inference steps
55
 
56
  # TimeoutError for handling timeouts
57
  class TimeoutError(Exception):
 
105
  new_width = int(img.width * (MAX_DIMENSION / img.height))
106
  img = img.resize((new_width, new_height), Image.LANCZOS)
107
 
108
+ # Convert to RGB and convert to tensor
109
+ img_array = np.array(img) / 255.0 # Normalize to [0, 1]
110
+ img_tensor = torch.from_numpy(img_array).float().permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
111
+ return img_tensor
112
+
113
+ # Simple NeuS2-inspired implementation for reconstructing 3D surfaces from images
114
class NeuS2Model:
    """Small NeuS2-inspired network that turns one image into a coarse mesh.

    The image is encoded into a 512-dimensional feature vector. An MLP then
    predicts one signed-distance value for every point of a regular grid over
    [-1, 1]^3, and the grid points close to the requested iso-level are
    triangulated into a surface.

    NOTE: the layers are created with PyTorch's default initialisation and
    nothing in this class loads trained weights, so the predicted field is
    only as meaningful as whatever weights the caller puts into ``encoder``
    and ``volume_network``.
    """

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        """Build the encoder and the SDF network and move both to *device*."""
        self.device = device
        self.encoder = self._create_encoder().to(device)
        self.volume_network = self._create_volume_network().to(device)

    def _create_encoder(self):
        """Return the convolutional image encoder ([N, 3, H, W] -> [N, 512])."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 3, stride=2, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 3, stride=2, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, 3, stride=2, padding=1),
            torch.nn.ReLU(),
            # Pool to a fixed 8x8 map so any input size yields 128*8*8 = 8192.
            torch.nn.AdaptiveAvgPool2d((8, 8)),
            torch.nn.Flatten(),
            torch.nn.Linear(8192, 512)
        )

    def _create_volume_network(self):
        """Return the MLP mapping (xyz + image features) to one SDF value."""
        return torch.nn.Sequential(
            torch.nn.Linear(515, 256),  # 512 features + 3 coordinates
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 1)  # SDF value
        )

    def extract_features(self, image):
        """Encode *image* (moved to the model device) without tracking gradients."""
        with torch.no_grad():
            return self.encoder(image.to(self.device))

    def query_points(self, points, features):
        """Evaluate the SDF network at *points*.

        Args:
            points: tensor of shape [batch, num_points, 3].
            features: tensor of shape [batch, 512].

        Returns:
            Tensor of shape [batch, num_points, 1] on the model device.
        """
        batch_size, num_points, _ = points.shape

        # The grid is created on the CPU while the features come out of the
        # encoder on the model device; torch.cat refuses mixed devices, so
        # align the points with the features before concatenating.
        points = points.to(features.device)

        # Repeat the per-image feature vector for every query point.
        features = features.unsqueeze(1).expand(-1, num_points, -1)  # [batch, num_points, 512]

        points_features = torch.cat([points, features], dim=-1)  # [batch, num_points, 515]
        points_features = points_features.reshape(-1, 515)  # [batch*num_points, 515]

        with torch.no_grad():
            sdf = self.volume_network(points_features.to(self.device))

        return sdf.reshape(batch_size, num_points, 1)

    def generate_mesh(self, image, resolution=64, threshold=0.0, num_steps=16):
        """Predict an SDF grid for *image* and extract a mesh from it.

        Args:
            image: image tensor of shape [1, 3, H, W].
            resolution: number of grid samples along each axis.
            threshold: iso-level of the SDF treated as the surface.
            num_steps: accepted for interface compatibility; not used by
                this implementation.

        Returns:
            A one-element list holding an object with ``verts`` and ``faces``
            attributes (list form mirrors the previous pipeline's output).
        """
        features = self.extract_features(image)  # [1, 512]

        # Regular grid over [-1, 1]^3; both end points are included.
        x = torch.linspace(-1, 1, resolution)
        y = torch.linspace(-1, 1, resolution)
        z = torch.linspace(-1, 1, resolution)
        grid_x, grid_y, grid_z = torch.meshgrid(x, y, z, indexing='ij')
        points = torch.stack([grid_x, grid_y, grid_z], dim=-1).reshape(1, -1, 3)  # [1, res^3, 3]

        # Query in chunks to bound peak memory.
        batch_size = 32768
        sdf_values = []

        for i in range(0, points.shape[1], batch_size):
            batch_points = points[:, i:i+batch_size]
            batch_sdf = self.query_points(batch_points, features)
            sdf_values.append(batch_sdf)

        sdf_volume = torch.cat(sdf_values, dim=1).reshape(resolution, resolution, resolution).cpu().numpy()

        vertices, faces = self._marching_cubes(sdf_volume, threshold)

        # Lightweight holder exposing .verts / .faces like the previous output.
        mesh = type('Mesh', (), {'verts': vertices, 'faces': faces})

        return [mesh]

    def _marching_cubes(self, sdf_volume, threshold=0.0):
        """Approximate the iso-surface of *sdf_volume* with a triangle mesh.

        This is not real marching cubes: grid points whose SDF lies within a
        small band around *threshold* are collected and the outer triangles of
        their Delaunay tetrahedralisation are returned. The result is therefore
        the convex hull of the near-surface samples; for a faithful surface a
        proper marching-cubes implementation should be used instead.

        Args:
            sdf_volume: 3-D numpy array of SDF samples taken on a grid that
                spans [-1, 1] along every axis (end points included).
            threshold: iso-level that defines the surface.

        Returns:
            ``(vertices, faces)`` where ``vertices`` is [N, 3] and ``faces``
            is [M, 3] with indices into ``vertices``.
        """
        surface_band = 0.1

        # Measure distance from the requested iso-level, not from zero.
        x, y, z = np.where(np.abs(sdf_volume - threshold) < surface_band)

        if len(x) < 4:  # Need at least 4 points for Delaunay
            # Not enough samples: return a simple cube as a placeholder.
            vertices = np.array([
                [-0.5, -0.5, -0.5],
                [0.5, -0.5, -0.5],
                [0.5, 0.5, -0.5],
                [-0.5, 0.5, -0.5],
                [-0.5, -0.5, 0.5],
                [0.5, -0.5, 0.5],
                [0.5, 0.5, 0.5],
                [-0.5, 0.5, 0.5]
            ])
            faces = np.array([
                [0, 1, 2], [0, 2, 3],  # Bottom face
                [4, 5, 6], [4, 6, 7],  # Top face
                [0, 1, 5], [0, 5, 4],  # Front face
                [2, 3, 7], [2, 7, 6],  # Back face
                [0, 3, 7], [0, 7, 4],  # Left face
                [1, 2, 6], [1, 6, 5]   # Right face
            ])
            return vertices, faces

        # Map indices back to coordinates. The grid has n samples spanning
        # [-1, 1] inclusive, so the step is 2 / (n - 1); a single-sample axis
        # sits at -1, matching linspace(-1, 1, 1).
        spans = [max(n - 1, 1) for n in sdf_volume.shape]
        points = np.stack([
            2 * x / spans[0] - 1,
            2 * y / spans[1] - 1,
            2 * z / spans[2] - 1
        ], axis=1)

        # Limit to a reasonable number of points for Delaunay
        max_points = 1000
        if len(points) > max_points:
            indices = np.random.choice(len(points), max_points, replace=False)
            points = points[indices]

        try:
            tri = Delaunay(points)
            # In 3-D `simplices` are tetrahedra (4 indices each); the
            # triangular boundary facets are exposed as `convex_hull`.
            faces = tri.convex_hull
        except Exception:
            # Best-effort fallback (e.g. degenerate/coplanar input): a flat quad.
            return np.array([[-1, -1, -1], [1, -1, -1], [1, 1, -1], [-1, 1, -1]]), np.array([[0, 1, 2], [0, 2, 3]])

        return points, faces
246
 
247
  def load_model():
248
+ global neus_model, model_loaded, model_loading
249
 
250
  if model_loaded:
251
+ return neus_model
252
 
253
  if model_loading:
254
  # Wait for model to load if it's already in progress
255
  while model_loading and not model_loaded:
256
  time.sleep(0.5)
257
+ return neus_model
258
 
259
  try:
260
  model_loading = True
261
  print("Starting model loading...")
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  device = "cuda" if torch.cuda.is_available() else "cpu"
264
+ neus_model = NeuS2Model(device=device)
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  model_loaded = True
267
  print(f"Model loaded successfully on {device}")
268
+ return neus_model
269
 
270
  except Exception as e:
271
  print(f"Error loading model: {str(e)}")
 
278
  def health_check():
279
  return jsonify({
280
  "status": "healthy",
281
+ "model": "NeuS2 Image to 3D",
282
  "device": "cuda" if torch.cuda.is_available() else "cpu"
283
  }), 200
284
 
 
321
 
322
  return Response(stream_with_context(generate()), mimetype='text/event-stream')
323
 
324
def export_to_obj(mesh, obj_path):
    """Write *mesh* as a Wavefront OBJ file with a companion MTL file.

    Args:
        mesh: object exposing ``verts`` (iterable of xyz triples) and
            ``faces`` (iterable of zero-based vertex-index triples).
        obj_path: destination path of the OBJ file.

    Returns:
        ``(obj_path, mtl_path)``; the MTL file is written next to the OBJ
        file, with the same base name and a ``.mtl`` extension.
    """
    material_name = "material0"

    # Swap only the file extension. A plain str.replace would also rewrite
    # ".obj" inside directory names and, for a path without that suffix,
    # would make the MTL file overwrite the OBJ file.
    stem, _ = os.path.splitext(obj_path)
    mtl_path = stem + '.mtl'
    if mtl_path == obj_path:
        mtl_path = obj_path + '.mtl'

    obj_lines = [
        # Reference the material library so viewers actually apply it.
        f"mtllib {os.path.basename(mtl_path)}\n",
        f"usemtl {material_name}\n",
    ]
    obj_lines.extend(f"v {v[0]} {v[1]} {v[2]}\n" for v in mesh.verts)
    # OBJ uses 1-indexed vertices.
    obj_lines.extend(
        f"f {face[0]+1} {face[1]+1} {face[2]+1}\n" for face in mesh.faces
    )

    with open(obj_path, 'w') as f:
        f.writelines(obj_lines)

    # Create a simple MTL file
    with open(mtl_path, 'w') as f:
        f.writelines([
            f"newmtl {material_name}\n",
            "Ka 1.0 1.0 1.0\n",  # ambient color
            "Kd 0.8 0.8 0.8\n",  # diffuse color
            "Ks 0.0 0.0 0.0\n",  # specular color
            "Ns 0.0\n",          # specular exponent
            "illum 2\n",         # illumination model
        ])

    return obj_path, mtl_path
349
+
350
  @app.route('/convert', methods=['POST'])
351
  def convert_image_to_3d():
352
  # Check if image is in the request
 
363
  # Get optional parameters with defaults
364
  try:
365
  guidance_scale = float(request.form.get('guidance_scale', 3.0))
366
+ num_inference_steps = min(int(request.form.get('num_inference_steps', 32)), MAX_INFERENCE_STEPS)
367
  output_format = request.form.get('output_format', 'obj').lower()
368
  except ValueError:
369
  return jsonify({"error": "Invalid parameter values"}), 400
 
372
  if guidance_scale < 1.0 or guidance_scale > 5.0:
373
  return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400
374
 
375
+ if num_inference_steps < 16 or num_inference_steps > MAX_INFERENCE_STEPS:
376
  num_inference_steps = min(num_inference_steps, MAX_INFERENCE_STEPS)
377
 
378
  # Validate output format
 
408
  try:
409
  # Preprocess image (resize if needed)
410
  processing_jobs[job_id]['progress'] = 5
411
+ image_tensor = preprocess_image(filepath)
412
  processing_jobs[job_id]['progress'] = 10
413
 
414
  # Load model
415
  try:
416
+ model = load_model()
417
  processing_jobs[job_id]['progress'] = 30
418
  except Exception as e:
419
  processing_jobs[job_id]['status'] = 'error'
 
423
  # Process image with thread-safe timeout
424
  try:
425
  def generate_mesh():
426
+ return model.generate_mesh(
427
+ image_tensor,
428
+ resolution=min(32 + num_inference_steps, 64), # Adjust resolution based on steps
429
+ threshold=0.0,
430
+ num_steps=num_inference_steps
431
+ )
432
 
433
  images, error = process_with_timeout(generate_mesh, [], TIMEOUT_SECONDS)
434
 
 
453
  try:
454
  if output_format == 'obj':
455
  obj_path = os.path.join(output_dir, "model.obj")
456
+ obj_path, mtl_path = export_to_obj(images[0], obj_path)
457
 
458
  # Create a zip file with OBJ and MTL
459
  zip_path = os.path.join(output_dir, "model.zip")
460
  with zipfile.ZipFile(zip_path, 'w') as zipf:
461
  zipf.write(obj_path, arcname="model.obj")
 
462
  if os.path.exists(mtl_path):
463
  zipf.write(mtl_path, arcname="model.mtl")
464
 
 
466
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
467
 
468
  elif output_format == 'glb':
469
+ # Convert to trimesh format
470
+ vertices = images[0].verts
471
+ faces = images[0].faces
 
472
 
473
  # Create a trimesh object
474
+ trimesh_obj = trimesh.Trimesh(vertices=vertices, faces=faces)
475
 
476
  # Export as GLB
477
  glb_path = os.path.join(output_dir, "model.glb")
 
595
  @app.route('/', methods=['GET'])
596
  def index():
597
  return jsonify({
598
+ "message": "Image to 3D API using NeuS2 is running",
599
  "endpoints": ["/convert", "/progress/<job_id>", "/download/<job_id>", "/preview/<job_id>"]
600
  }), 200
601