mac9087 commited on
Commit
59e224e
·
verified ·
1 Parent(s): f21a9a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -34
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import os
3
  import torch
4
  import time
@@ -15,8 +14,8 @@ from huggingface_hub import snapshot_download
15
  from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
- from tsr.system import TripoSR
19
- from tsr.utils import remove_background, resize_foreground
20
 
21
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
22
  torch.set_default_device("cpu")
@@ -43,12 +42,12 @@ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
43
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
44
 
45
  processing_jobs = {}
46
- triposr_model = None
47
  model_loaded = False
48
  model_loading = False
49
 
50
  TIMEOUT_SECONDS = 300
51
- MAX_DIMENSION = 512 # TripoSR uses 512x512 inputs
52
 
53
  class TimeoutError(Exception):
54
  pass
@@ -86,34 +85,35 @@ def allowed_file(filename):
86
 
87
  def preprocess_image(image_path):
88
  try:
89
- with Image.open(image_path) as img:
90
- if img.mode == 'RGBA':
91
- img = img.convert('RGB')
92
- img = img.resize((512, 512), Image.LANCZOS)
93
- img_array = np.array(img) / 255.0
94
- img_array = remove_background(img_array)
95
- img_array = resize_foreground(img_array, 0.85)
96
- img_array = np.clip(img_array, 0, 1) * 255
97
- return Image.fromarray(img_array.astype(np.uint8))
 
98
  except Exception as e:
99
  raise Exception(f"Error preprocessing image: {str(e)}")
100
 
101
  def load_model():
102
- global triposr_model, model_loaded, model_loading
103
 
104
  if model_loaded:
105
- return triposr_model
106
 
107
  if model_loading:
108
  while model_loading and not model_loaded:
109
  time.sleep(0.5)
110
- return triposr_model
111
 
112
  try:
113
  model_loading = True
114
- print("Loading TripoSR...")
115
 
116
- model_name = "tripo3d/triposr"
117
 
118
  max_retries = 3
119
  retry_delay = 5
@@ -133,15 +133,15 @@ def load_model():
133
  else:
134
  raise
135
 
136
- triposr_model = TripoSR.from_pretrained(
137
  model_name,
138
  cache_dir=CACHE_DIR,
139
  device="cpu",
140
  )
141
 
142
  model_loaded = True
143
- print("TripoSR loaded successfully on CPU")
144
- return triposr_model
145
 
146
  except Exception as e:
147
  print(f"Error loading model: {str(e)}")
@@ -152,17 +152,19 @@ def load_model():
152
 
153
  def generate_3d_model(image, detail_level):
154
  try:
155
- chunk_size = {'low': 4096, 'medium': 8192, 'high': 16384}
156
- chunk = chunk_size[detail_level]
157
 
158
  with torch.no_grad():
159
- scene_codes = triposr_model(image, device="cpu")
160
- meshes = triposr_model.mesher(scene_codes, chunk_size=chunk)
 
 
 
161
 
162
- mesh = meshes[0]
163
- vertices = np.array(mesh.vertices)
164
- faces = np.array(mesh.faces)
165
- vertex_colors = np.array(mesh.vertex_colors) if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None else None
166
 
167
  trimesh_mesh = trimesh.Trimesh(
168
  vertices=vertices,
@@ -180,7 +182,7 @@ def generate_3d_model(image, detail_level):
180
  def health_check():
181
  return jsonify({
182
  "status": "healthy",
183
- "model": "TripoSR",
184
  "device": "cpu"
185
  }), 200
186
 
@@ -294,7 +296,7 @@ def convert_image_to_3d():
294
  file_path = os.path.join(output_dir, f"model.{output_format}")
295
  mesh.export(file_path, file_type=output_format)
296
 
297
- processing_jobs[job_id]['result_url'] = f"/download/{job_id.ConcurrentHashMap}"
298
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
299
 
300
  processing_jobs[job_id]['status'] = 'completed'
@@ -420,7 +422,7 @@ def model_info(job_id):
420
  @app.route('/', methods=['GET'])
421
  def index():
422
  return jsonify({
423
- "message": "Image to 3D API (TripoSR)",
424
  "endpoints": [
425
  "/convert",
426
  "/progress/<job_id>",
@@ -432,7 +434,7 @@ def index():
432
  "output_format": "glb or obj",
433
  "detail_level": "low, medium, or high"
434
  },
435
- "description": "Creates 3D models from 2D images using TripoSR."
436
  }), 200
437
 
438
  if __name__ == '__main__':
 
 
1
  import os
2
  import torch
3
  import time
 
14
  from flask_cors import CORS
15
  import numpy as np
16
  import trimesh
17
+ import cv2
18
+ from lgm.models import LGM
19
 
20
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
21
  torch.set_default_device("cpu")
 
42
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
43
 
44
  processing_jobs = {}
45
+ lgm_model = None
46
  model_loaded = False
47
  model_loading = False
48
 
49
  TIMEOUT_SECONDS = 300
50
+ MAX_DIMENSION = 512 # LGM uses 512x512 inputs
51
 
52
  class TimeoutError(Exception):
53
  pass
 
85
 
86
  def preprocess_image(image_path):
87
  try:
88
+ img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
89
+ if img.shape[2] == 4: # RGBA
90
+ alpha = img[:, :, 3]
91
+ rgb = img[:, :, :3]
92
+ white_bg = np.ones_like(rgb) * 255
93
+ mask = alpha[:, :, np.newaxis] / 255.0
94
+ img = rgb * mask + white_bg * (1 - mask)
95
+ img = img.astype(np.uint8)
96
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LANCZOS4)
97
+ return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
98
  except Exception as e:
99
  raise Exception(f"Error preprocessing image: {str(e)}")
100
 
101
  def load_model():
102
+ global lgm_model, model_loaded, model_loading
103
 
104
  if model_loaded:
105
+ return lgm_model
106
 
107
  if model_loading:
108
  while model_loading and not model_loaded:
109
  time.sleep(0.5)
110
+ return lgm_model
111
 
112
  try:
113
  model_loading = True
114
+ print("Loading LGM...")
115
 
116
+ model_name = "large-gaussian-model/lgm"
117
 
118
  max_retries = 3
119
  retry_delay = 5
 
133
  else:
134
  raise
135
 
136
+ lgm_model = LGM.from_pretrained(
137
  model_name,
138
  cache_dir=CACHE_DIR,
139
  device="cpu",
140
  )
141
 
142
  model_loaded = True
143
+ print("LGM loaded successfully on CPU")
144
+ return lgm_model
145
 
146
  except Exception as e:
147
  print(f"Error loading model: {str(e)}")
 
152
 
153
  def generate_3d_model(image, detail_level):
154
  try:
155
+ resolution = {'low': 256, 'medium': 512, 'high': 1024}
156
+ res = resolution[detail_level]
157
 
158
  with torch.no_grad():
159
+ mesh_data = lgm_model.generate_mesh(
160
+ image,
161
+ resolution=res,
162
+ device="cpu"
163
+ )
164
 
165
+ vertices = np.array(mesh_data['vertices'])
166
+ faces = np.array(mesh_data['faces'])
167
+ vertex_colors = np.array(mesh_data['vertex_colors']) if 'vertex_colors' in mesh_data else None
 
168
 
169
  trimesh_mesh = trimesh.Trimesh(
170
  vertices=vertices,
 
182
  def health_check():
183
  return jsonify({
184
  "status": "healthy",
185
+ "model": "LGM",
186
  "device": "cpu"
187
  }), 200
188
 
 
296
  file_path = os.path.join(output_dir, f"model.{output_format}")
297
  mesh.export(file_path, file_type=output_format)
298
 
299
+ processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
300
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
301
 
302
  processing_jobs[job_id]['status'] = 'completed'
 
422
  @app.route('/', methods=['GET'])
423
  def index():
424
  return jsonify({
425
+ "message": "Image to 3D API (LGM)",
426
  "endpoints": [
427
  "/convert",
428
  "/progress/<job_id>",
 
434
  "output_format": "glb or obj",
435
  "detail_level": "low, medium, or high"
436
  },
437
+ "description": "Creates 3D models from 2D images using LGM."
438
  }), 200
439
 
440
  if __name__ == '__main__':