mac9087 commited on
Commit
5264cd4
·
verified ·
1 Parent(s): f71a5f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -20
app.py CHANGED
@@ -15,7 +15,7 @@ from huggingface_hub import snapshot_download
15
  from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
- from diffusers import Hunyuan3DDiTPipeline
19
 
20
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
21
  torch.set_default_device("cpu")
@@ -42,12 +42,12 @@ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
42
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
43
 
44
  processing_jobs = {}
45
- hunyuan_pipeline = None
46
  model_loaded = False
47
  model_loading = False
48
 
49
  TIMEOUT_SECONDS = 300
50
- MAX_DIMENSION = 512 # Hunyuan3D-1.0 uses 512x512 inputs
51
 
52
  class TimeoutError(Exception):
53
  pass
@@ -94,21 +94,21 @@ def preprocess_image(image_path):
94
  raise Exception(f"Error preprocessing image: {str(e)}")
95
 
96
  def load_model():
97
- global hunyuan_pipeline, model_loaded, model_loading
98
 
99
  if model_loaded:
100
- return hunyuan_pipeline
101
 
102
  if model_loading:
103
  while model_loading and not model_loaded:
104
  time.sleep(0.5)
105
- return hunyuan_pipeline
106
 
107
  try:
108
  model_loading = True
109
- print("Loading Hunyuan3D-1.0 Lite...")
110
 
111
- model_name = "tencent/Hunyuan3D-1"
112
 
113
  max_retries = 3
114
  retry_delay = 5
@@ -128,17 +128,16 @@ def load_model():
128
  else:
129
  raise
130
 
131
- hunyuan_pipeline = Hunyuan3DDiTPipeline.from_pretrained(
132
  model_name,
133
- subfolder="lite",
134
  cache_dir=CACHE_DIR,
135
  torch_dtype=torch.float32,
136
  )
137
- hunyuan_pipeline.to("cpu")
138
 
139
  model_loaded = True
140
- print("Hunyuan3D-1.0 Lite loaded successfully on CPU")
141
- return hunyuan_pipeline
142
 
143
  except Exception as e:
144
  print(f"Error loading model: {str(e)}")
@@ -153,12 +152,17 @@ def generate_3d_model(image, detail_level):
153
  steps = num_steps[detail_level]
154
 
155
  with torch.no_grad():
156
- result = hunyuan_pipeline(image, num_inference_steps=steps, max_faces_num=20000)
 
 
 
 
 
157
 
158
- mesh = result[0] # Hunyuan3D returns a trimesh object
159
  vertices = np.array(mesh.vertices)
160
  faces = np.array(mesh.faces)
161
- vertex_colors = np.array(mesh.vertex_colors) if mesh.vertex_colors is not None else None
162
 
163
  trimesh_mesh = trimesh.Trimesh(
164
  vertices=vertices,
@@ -176,7 +180,7 @@ def generate_3d_model(image, detail_level):
176
  def health_check():
177
  return jsonify({
178
  "status": "healthy",
179
- "model": "Hunyuan3D-1.0 Lite",
180
  "device": "cpu"
181
  }), 200
182
 
@@ -355,7 +359,7 @@ def preview_model(job_id):
355
  else:
356
  return send_file(file_path, mimetype='text/plain')
357
 
358
- return jsonify({"error": "Model file not found"}), 404
359
 
360
  def cleanup_old_jobs():
361
  current_time = time.time()
@@ -416,7 +420,7 @@ def model_info(job_id):
416
  @app.route('/', methods=['GET'])
417
  def index():
418
  return jsonify({
419
- "message": "Image to 3D API (Hunyuan3D-1.0 Lite)",
420
  "endpoints": [
421
  "/convert",
422
  "/progress/<job_id>",
@@ -428,7 +432,7 @@ def index():
428
  "output_format": "glb or obj",
429
  "detail_level": "low, medium, or high"
430
  },
431
- "description": "Creates 3D models from 2D images using Hunyuan3D-1.0 Lite."
432
  }), 200
433
 
434
  if __name__ == '__main__':
 
15
  from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
+ from diffusers import StableFast3DPipeline
19
 
20
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
21
  torch.set_default_device("cpu")
 
42
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
43
 
44
  processing_jobs = {}
45
+ sf3d_pipeline = None
46
  model_loaded = False
47
  model_loading = False
48
 
49
  TIMEOUT_SECONDS = 300
50
+ MAX_DIMENSION = 512 # Stable-Fast-3D uses 512x512 inputs
51
 
52
  class TimeoutError(Exception):
53
  pass
 
94
  raise Exception(f"Error preprocessing image: {str(e)}")
95
 
96
  def load_model():
97
+ global sf3d_pipeline, model_loaded, model_loading
98
 
99
  if model_loaded:
100
+ return sf3d_pipeline
101
 
102
  if model_loading:
103
  while model_loading and not model_loaded:
104
  time.sleep(0.5)
105
+ return sf3d_pipeline
106
 
107
  try:
108
  model_loading = True
109
+ print("Loading Stable-Fast-3D...")
110
 
111
+ model_name = "stabilityai/stable-fast-3d"
112
 
113
  max_retries = 3
114
  retry_delay = 5
 
128
  else:
129
  raise
130
 
131
+ sf3d_pipeline = StableFast3DPipeline.from_pretrained(
132
  model_name,
 
133
  cache_dir=CACHE_DIR,
134
  torch_dtype=torch.float32,
135
  )
136
+ sf3d_pipeline.to("cpu")
137
 
138
  model_loaded = True
139
+ print("Stable-Fast-3D loaded successfully on CPU")
140
+ return sf3d_pipeline
141
 
142
  except Exception as e:
143
  print(f"Error loading model: {str(e)}")
 
152
  steps = num_steps[detail_level]
153
 
154
  with torch.no_grad():
155
+ result = sf3d_pipeline(
156
+ image,
157
+ num_inference_steps=steps,
158
+ normal_num_inference_steps=steps // 2,
159
+ guidance_scale=7.0,
160
+ )
161
 
162
+ mesh = result.trimesh_meshes[0]
163
  vertices = np.array(mesh.vertices)
164
  faces = np.array(mesh.faces)
165
+ vertex_colors = np.array(mesh.visual.vertex_colors) if mesh.visual.vertex_colors is not None else None
166
 
167
  trimesh_mesh = trimesh.Trimesh(
168
  vertices=vertices,
 
180
  def health_check():
181
  return jsonify({
182
  "status": "healthy",
183
+ "model": "Stable-Fast-3D",
184
  "device": "cpu"
185
  }), 200
186
 
 
359
  else:
360
  return send_file(file_path, mimetype='text/plain')
361
 
362
+ return jsonify({"error": "Model not found"}), 404
363
 
364
  def cleanup_old_jobs():
365
  current_time = time.time()
 
420
  @app.route('/', methods=['GET'])
421
  def index():
422
  return jsonify({
423
+ "message": "Image to 3D API (Stable-Fast-3D)",
424
  "endpoints": [
425
  "/convert",
426
  "/progress/<job_id>",
 
432
  "output_format": "glb or obj",
433
  "detail_level": "low, medium, or high"
434
  },
435
+ "description": "Creates 3D models from 2D images using Stable-Fast-3D."
436
  }), 200
437
 
438
  if __name__ == '__main__':