alexnasa commited on
Commit
2a9686f
·
verified ·
1 Parent(s): f96485a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -12
app.py CHANGED
@@ -8,6 +8,7 @@ from typing import Any, Union
8
 
9
  import numpy as np
10
  import torch
 
11
 
12
  print(f'torch version:{torch.__version__}')
13
 
@@ -54,8 +55,8 @@ print("installing cuda toolkit")
54
  install_cuda_toolkit()
55
  print("finished")
56
 
57
- header_path = "/usr/local/cuda/include/cuda_runtime.h"
58
- print(f"{header_path} exists:", os.path.exists(header_path))
59
 
60
  def sh(cmd_list, extra_env=None):
61
  env = os.environ.copy()
@@ -107,14 +108,19 @@ def run_triposg(image_path: str,
107
  num_inference_steps: int = 50,
108
  guidance_scale: float = 7.0,
109
  use_flash_decoder: bool = False,
110
- rmbg: bool = True):
 
 
111
 
112
- max_num_expanded_coords = 1e9
113
-
114
  """
115
  Generate 3D part meshes from an input image.
116
  """
117
 
 
 
 
 
 
118
  if rmbg:
119
  img_pil = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
120
  else:
@@ -143,9 +149,7 @@ def run_triposg(image_path: str,
143
  # Merge and color
144
  merged = get_colored_mesh_composition(outputs)
145
 
146
- # Export meshes and return results
147
- timestamp = time.strftime("%Y%m%d_%H%M%S")
148
- export_dir = os.path.join("results", timestamp)
149
  os.makedirs(export_dir, exist_ok=True)
150
  for idx, mesh in enumerate(outputs):
151
  mesh.export(os.path.join(export_dir, f"part_{idx:02}.glb"))
@@ -157,7 +161,17 @@ def run_triposg(image_path: str,
157
 
158
  return mesh_file, export_dir
159
 
160
- # Gradio Interface
 
 
 
 
 
 
 
 
 
 
161
  def build_demo():
162
  css = """
163
  #col-container {
@@ -168,7 +182,9 @@ def build_demo():
168
  theme = gr.themes.Ocean()
169
 
170
  with gr.Blocks(css=css, theme=theme) as demo:
171
-
 
 
172
  with gr.Column(elem_id="col-container"):
173
 
174
  gr.Markdown(
@@ -195,7 +211,7 @@ def build_demo():
195
  gr.HTML(
196
  """
197
  <p style="opacity: 0.6; font-style: italic;">
198
- This might take a few seconds to load the 3D model
199
  </p>
200
  """
201
  )
@@ -225,10 +241,12 @@ def build_demo():
225
 
226
  run_button.click(fn=run_triposg,
227
  inputs=[input_image, num_parts, seed, num_tokens, num_steps,
228
- guidance, flash_decoder, remove_bg],
229
  outputs=[output_model, output_dir])
230
  return demo
231
 
232
  if __name__ == "__main__":
233
  demo = build_demo()
 
 
234
  demo.launch()
 
8
 
9
  import numpy as np
10
  import torch
11
+ import uuid
12
 
13
  print(f'torch version:{torch.__version__}')
14
 
 
55
  install_cuda_toolkit()
56
  print("finished")
57
 
58
+ os.environ["PARTCRAFTER_PROCESSED"] = f"{os.getcwd()}/proprocess_results"
59
+
60
 
61
  def sh(cmd_list, extra_env=None):
62
  env = os.environ.copy()
 
108
  num_inference_steps: int = 50,
109
  guidance_scale: float = 7.0,
110
  use_flash_decoder: bool = False,
111
+ rmbg: bool = True,
112
+ session_id = None,
113
+ progress=gr.Progress(track_tqdm=True),):
114
 
 
 
115
  """
116
  Generate 3D part meshes from an input image.
117
  """
118
 
119
+ max_num_expanded_coords = 1e9
120
+
121
+ if session_id is None:
122
+ session_id = uuid.uuid4().hex
123
+
124
  if rmbg:
125
  img_pil = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
126
  else:
 
149
  # Merge and color
150
  merged = get_colored_mesh_composition(outputs)
151
 
152
+ export_dir = os.path.join(os.environ["PARTCRAFTER_PROCESSED"], session_id)
 
 
153
  os.makedirs(export_dir, exist_ok=True)
154
  for idx, mesh in enumerate(outputs):
155
  mesh.export(os.path.join(export_dir, f"part_{idx:02}.glb"))
 
161
 
162
  return mesh_file, export_dir
163
 
164
+ def cleanup(request: gr.Request):
165
+
166
+ sid = request.session_hash
167
+ if sid:
168
+ d1 = os.path.join(os.environ["PARTCRAFTER_PROCESSED"], sid)
169
+ shutil.rmtree(d1, ignore_errors=True)
170
+
171
+ def start_session(request: gr.Request):
172
+
173
+ return request.session_hash
174
+
175
  def build_demo():
176
  css = """
177
  #col-container {
 
182
  theme = gr.themes.Ocean()
183
 
184
  with gr.Blocks(css=css, theme=theme) as demo:
185
+ session_state = gr.State()
186
+ demo.load(start_session, outputs=[session_state])
187
+
188
  with gr.Column(elem_id="col-container"):
189
 
190
  gr.Markdown(
 
211
  gr.HTML(
212
  """
213
  <p style="opacity: 0.6; font-style: italic;">
214
+ The 3D Preview might take a few seconds to load the 3D model
215
  </p>
216
  """
217
  )
 
241
 
242
  run_button.click(fn=run_triposg,
243
  inputs=[input_image, num_parts, seed, num_tokens, num_steps,
244
+ guidance, flash_decoder, remove_bg, session_state],
245
  outputs=[output_model, output_dir])
246
  return demo
247
 
248
  if __name__ == "__main__":
249
  demo = build_demo()
250
+ demo.unload(cleanup)
251
+ demo.queue()
252
  demo.launch()