Stylique commited on
Commit
ff907d0
·
verified ·
1 Parent(s): 0de41d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -61
app.py CHANGED
@@ -11,6 +11,9 @@ import subprocess
11
  from glob import glob
12
  import requests
13
  from huggingface_hub import snapshot_download
 
 
 
14
 
15
  # Download models
16
  os.makedirs("ckpts", exist_ok=True)
@@ -37,49 +40,35 @@ images_examples = [
37
  ]
38
 
39
  def remove_background(input_pil, remove_bg):
40
-
41
- # Create a temporary folder for downloaded and processed images
42
  temp_dir = tempfile.mkdtemp()
43
  unique_id = str(uuid.uuid4())
44
  image_path = os.path.join(temp_dir, f'input_image_{unique_id}.png')
45
-
46
  try:
47
- # Check if input_url is already a PIL Image
48
  if isinstance(input_pil, Image.Image):
49
  image = input_pil
50
  else:
51
- # Otherwise, assume it's a file path and open it
52
  image = Image.open(input_pil)
53
-
54
- # Flip the image horizontally
55
  image = image.transpose(Image.FLIP_LEFT_RIGHT)
56
-
57
- # Save the resized image
58
  image.save(image_path)
59
  except Exception as e:
60
  shutil.rmtree(temp_dir)
61
  raise gr.Error(f"Error downloading or saving the image: {str(e)}")
62
 
63
  if remove_bg is True:
64
- # Run background removal
65
  removed_bg_path = os.path.join(temp_dir, f'output_image_rmbg_{unique_id}.png')
66
  try:
67
  img = Image.open(image_path)
68
  result = remove(img)
69
  result.save(removed_bg_path)
70
-
71
- # Remove the input image to keep the temp directory clean
72
  os.remove(image_path)
73
  except Exception as e:
74
  shutil.rmtree(temp_dir)
75
  raise gr.Error(f"Error removing background: {str(e)}")
76
-
77
  return removed_bg_path, temp_dir
78
  else:
79
  return image_path, temp_dir
80
 
81
  def run_inference(temp_dir, removed_bg_path):
82
- # Define the inference configuration
83
  inference_config = "configs/inference-768-6view.yaml"
84
  pretrained_model = "./ckpts"
85
  crop_size = 740
@@ -88,7 +77,6 @@ def run_inference(temp_dir, removed_bg_path):
88
  save_mode = "rgb"
89
 
90
  try:
91
- # Run the inference command
92
  subprocess.run(
93
  [
94
  "python", "inference.py",
@@ -103,19 +91,12 @@ def run_inference(temp_dir, removed_bg_path):
103
  ],
104
  check=True
105
  )
106
-
107
-
108
- # Retrieve the file name without the extension
109
  removed_bg_file_name = os.path.splitext(os.path.basename(removed_bg_path))[0]
110
-
111
- # List objects in the "out" folder
112
  out_folder_path = "out"
113
  out_folder_objects = os.listdir(out_folder_path)
114
  print(f"Objects in '{out_folder_path}':")
115
  for obj in out_folder_objects:
116
  print(f" - {obj}")
117
-
118
- # List objects in the "out/{removed_bg_file_name}" folder
119
  specific_out_folder_path = os.path.join(out_folder_path, removed_bg_file_name)
120
  if os.path.exists(specific_out_folder_path) and os.path.isdir(specific_out_folder_path):
121
  specific_out_folder_objects = os.listdir(specific_out_folder_path)
@@ -124,36 +105,23 @@ def run_inference(temp_dir, removed_bg_path):
124
  print(f" - {obj}")
125
  else:
126
  print(f"\nThe folder '{specific_out_folder_path}' does not exist.")
127
-
128
  output_video = glob(os.path.join(f"out/{removed_bg_file_name}", "*.mp4"))
129
  output_objects = glob(os.path.join(f"out/{removed_bg_file_name}", "*.obj"))
130
  return output_video, output_objects
131
-
132
  except subprocess.CalledProcessError as e:
133
  return f"Error during inference: {str(e)}"
134
 
135
  def process_image(input_pil, remove_bg, progress=gr.Progress(track_tqdm=True)):
136
-
137
  torch.cuda.empty_cache()
138
-
139
- # Remove background
140
  result = remove_background(input_pil, remove_bg)
141
-
142
  if isinstance(result, str) and result.startswith("Error"):
143
- raise gr.Error(f"{result}") # Return the error message if something went wrong
144
-
145
- removed_bg_path, temp_dir = result # Unpack only if successful
146
-
147
- # Run inference
148
  output_video, output_objects = run_inference(temp_dir, removed_bg_path)
149
-
150
  if isinstance(output_video, str) and output_video.startswith("Error"):
151
  shutil.rmtree(temp_dir)
152
- raise gr.Error(f"{output_video}") # Return the error message if inference failed
153
-
154
-
155
- shutil.rmtree(temp_dir) # Cleanup temporary folder
156
- print(output_video)
157
  torch.cuda.empty_cache()
158
  return output_video[0], output_objects[0], output_objects[1]
159
 
@@ -166,6 +134,7 @@ div#video-out-elm{
166
  height: 323px;
167
  }
168
  """
 
169
  def gradio_interface():
170
  with gr.Blocks(css=css) as app:
171
  with gr.Column(elem_id="col-container"):
@@ -178,48 +147,55 @@ def gradio_interface():
178
  <a href="https://penghtyx.github.io/PSHuman/">
179
  <img src='https://img.shields.io/badge/Project-Page-green'>
180
  </a>
181
- <a href="https://arxiv.org/pdf/2409.10141">
182
  <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
183
  </a>
184
  <a href="https://huggingface.co/spaces/fffiloni/PSHuman?duplicate=true">
185
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
186
- </a>
187
- <a href="https://huggingface.co/fffiloni">
188
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
189
- </a>
190
  </div>
191
  """)
192
  with gr.Group():
193
  with gr.Row():
194
  with gr.Column(scale=2):
195
-
196
- input_image = gr.Image(
197
- label="Image input",
198
- type="pil",
199
- image_mode="RGBA",
200
- height=480
201
- )
202
-
203
  remove_bg = gr.Checkbox(label="Need to remove BG ?", value=False)
204
-
205
  submit_button = gr.Button("Process")
206
-
207
  with gr.Column(scale=4):
208
  output_video= gr.Video(label="Output Video", elem_id="video-out-elm")
209
  with gr.Row():
210
  output_object_mesh = gr.Model3D(label=".OBJ Mesh", height=240)
211
  output_object_color = gr.Model3D(label=".OBJ colored", height=240)
212
-
213
  gr.Examples(
214
  examples = examples_folder,
215
  inputs = [input_image],
216
  examples_per_page = 11
217
  )
218
-
219
  submit_button.click(process_image, inputs=[input_image, remove_bg], outputs=[output_video, output_object_mesh, output_object_color])
220
-
221
  return app
222
 
223
- # Launch the Gradio app
224
- app = gradio_interface()
225
- app.launch(show_api=False, show_error=True, ssr_mode=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from glob import glob
12
  import requests
13
  from huggingface_hub import snapshot_download
14
+ import io
15
+ from fastapi import UploadFile, File, HTTPException
16
+ from fastapi.responses import FileResponse
17
 
18
  # Download models
19
  os.makedirs("ckpts", exist_ok=True)
 
40
  ]
41
 
42
  def remove_background(input_pil, remove_bg):
 
 
43
  temp_dir = tempfile.mkdtemp()
44
  unique_id = str(uuid.uuid4())
45
  image_path = os.path.join(temp_dir, f'input_image_{unique_id}.png')
 
46
  try:
 
47
  if isinstance(input_pil, Image.Image):
48
  image = input_pil
49
  else:
 
50
  image = Image.open(input_pil)
 
 
51
  image = image.transpose(Image.FLIP_LEFT_RIGHT)
 
 
52
  image.save(image_path)
53
  except Exception as e:
54
  shutil.rmtree(temp_dir)
55
  raise gr.Error(f"Error downloading or saving the image: {str(e)}")
56
 
57
  if remove_bg is True:
 
58
  removed_bg_path = os.path.join(temp_dir, f'output_image_rmbg_{unique_id}.png')
59
  try:
60
  img = Image.open(image_path)
61
  result = remove(img)
62
  result.save(removed_bg_path)
 
 
63
  os.remove(image_path)
64
  except Exception as e:
65
  shutil.rmtree(temp_dir)
66
  raise gr.Error(f"Error removing background: {str(e)}")
 
67
  return removed_bg_path, temp_dir
68
  else:
69
  return image_path, temp_dir
70
 
71
  def run_inference(temp_dir, removed_bg_path):
 
72
  inference_config = "configs/inference-768-6view.yaml"
73
  pretrained_model = "./ckpts"
74
  crop_size = 740
 
77
  save_mode = "rgb"
78
 
79
  try:
 
80
  subprocess.run(
81
  [
82
  "python", "inference.py",
 
91
  ],
92
  check=True
93
  )
 
 
 
94
  removed_bg_file_name = os.path.splitext(os.path.basename(removed_bg_path))[0]
 
 
95
  out_folder_path = "out"
96
  out_folder_objects = os.listdir(out_folder_path)
97
  print(f"Objects in '{out_folder_path}':")
98
  for obj in out_folder_objects:
99
  print(f" - {obj}")
 
 
100
  specific_out_folder_path = os.path.join(out_folder_path, removed_bg_file_name)
101
  if os.path.exists(specific_out_folder_path) and os.path.isdir(specific_out_folder_path):
102
  specific_out_folder_objects = os.listdir(specific_out_folder_path)
 
105
  print(f" - {obj}")
106
  else:
107
  print(f"\nThe folder '{specific_out_folder_path}' does not exist.")
 
108
  output_video = glob(os.path.join(f"out/{removed_bg_file_name}", "*.mp4"))
109
  output_objects = glob(os.path.join(f"out/{removed_bg_file_name}", "*.obj"))
110
  return output_video, output_objects
 
111
  except subprocess.CalledProcessError as e:
112
  return f"Error during inference: {str(e)}"
113
 
114
  def process_image(input_pil, remove_bg, progress=gr.Progress(track_tqdm=True)):
 
115
  torch.cuda.empty_cache()
 
 
116
  result = remove_background(input_pil, remove_bg)
 
117
  if isinstance(result, str) and result.startswith("Error"):
118
+ raise gr.Error(f"{result}")
119
+ removed_bg_path, temp_dir = result
 
 
 
120
  output_video, output_objects = run_inference(temp_dir, removed_bg_path)
 
121
  if isinstance(output_video, str) and output_video.startswith("Error"):
122
  shutil.rmtree(temp_dir)
123
+ raise gr.Error(f"{output_video}")
124
+ shutil.rmtree(temp_dir)
 
 
 
125
  torch.cuda.empty_cache()
126
  return output_video[0], output_objects[0], output_objects[1]
127
 
 
134
  height: 323px;
135
  }
136
  """
137
+
138
  def gradio_interface():
139
  with gr.Blocks(css=css) as app:
140
  with gr.Column(elem_id="col-container"):
 
147
  <a href="https://penghtyx.github.io/PSHuman/">
148
  <img src='https://img.shields.io/badge/Project-Page-green'>
149
  </a>
150
+ <a href="https://arxiv.org/pdf/2409.10141">
151
  <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
152
  </a>
153
  <a href="https://huggingface.co/spaces/fffiloni/PSHuman?duplicate=true">
154
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
155
+ </a>
156
+ <a href="https://huggingface.co/fffiloni">
157
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
158
+ </a>
159
  </div>
160
  """)
161
  with gr.Group():
162
  with gr.Row():
163
  with gr.Column(scale=2):
164
+ input_image = gr.Image(label="Image input", type="pil", image_mode="RGBA", height=480)
 
 
 
 
 
 
 
165
  remove_bg = gr.Checkbox(label="Need to remove BG ?", value=False)
 
166
  submit_button = gr.Button("Process")
 
167
  with gr.Column(scale=4):
168
  output_video= gr.Video(label="Output Video", elem_id="video-out-elm")
169
  with gr.Row():
170
  output_object_mesh = gr.Model3D(label=".OBJ Mesh", height=240)
171
  output_object_color = gr.Model3D(label=".OBJ colored", height=240)
 
172
  gr.Examples(
173
  examples = examples_folder,
174
  inputs = [input_image],
175
  examples_per_page = 11
176
  )
 
177
  submit_button.click(process_image, inputs=[input_image, remove_bg], outputs=[output_video, output_object_mesh, output_object_color])
 
178
  return app
179
 
180
+ if __name__ == "__main__":
181
+ gradio_app = gradio_interface()
182
+ fastapi_app = gradio_app.app
183
+
184
+ @fastapi_app.post("/api/3d-reconstruct")
185
+ async def reconstruct(
186
+ image_file: UploadFile = File(...),
187
+ remove_bg: bool = False
188
+ ):
189
+ try:
190
+ contents = await image_file.read()
191
+ pil_image = Image.open(io.BytesIO(contents)).convert("RGBA")
192
+ video_path, mesh_path, colored_path = process_image(pil_image, remove_bg)
193
+ return FileResponse(
194
+ colored_path,
195
+ media_type="application/octet-stream",
196
+ filename=os.path.basename(colored_path)
197
+ )
198
+ except Exception as e:
199
+ raise HTTPException(status_code=500, detail=str(e))
200
+
201
+ gradio_app.launch(show_api=False, show_error=True, ssr_mode=False)