anvilinteractiv commited on
Commit
84a40b9
·
verified ·
1 Parent(s): 65dfc03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -2
app.py CHANGED
@@ -291,12 +291,46 @@ def get_random_seed(randomize_seed, seed):
291
  logger.error(f"Error in get_random_seed: {str(e)}")
292
  raise
293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  @spaces.GPU()
295
  @torch.no_grad()
296
- def run_segmentation(image: str):
297
  try:
298
  logger.info("Running segmentation")
299
- image = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  logger.info("Segmentation complete")
301
  return image
302
  except Exception as e:
 
291
  logger.error(f"Error in get_random_seed: {str(e)}")
292
  raise
293
 
294
+
295
+ def download_image(url: str, save_path: str) -> str:
296
+ """Download an image from a URL and save it locally."""
297
+ try:
298
+ logger.info(f"Downloading image from {url}")
299
+ response = requests.get(url, stream=True)
300
+ response.raise_for_status()
301
+ with open(save_path, "wb") as f:
302
+ for chunk in response.iter_content(chunk_size=8192):
303
+ f.write(chunk)
304
+ logger.info(f"Saved image to {save_path}")
305
+ return save_path
306
+ except Exception as e:
307
+ logger.error(f"Failed to download image from {url}: {str(e)}")
308
+ raise
309
+
310
  @spaces.GPU()
311
  @torch.no_grad()
312
+ def run_segmentation(image):
313
  try:
314
  logger.info("Running segmentation")
315
+ # Handle FileData dict or URL
316
+ if isinstance(image, dict):
317
+ image_path = image.get("path") or image.get("url")
318
+ if not image_path:
319
+ logger.error("Invalid image input: no path or URL provided")
320
+ raise ValueError("Invalid image input: no path or URL provided")
321
+ if image_path.startswith("http"):
322
+ temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
323
+ image_path = download_image(image_path, temp_image_path)
324
+ elif isinstance(image, str) and image.startswith("http"):
325
+ temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
326
+ image_path = download_image(image, temp_image_path)
327
+ else:
328
+ image_path = image
329
+ if not isinstance(image, (str, bytes)) or (isinstance(image, str) and not os.path.exists(image)):
330
+ logger.error(f"Invalid image type or path: {type(image)}")
331
+ raise ValueError(f"Expected str (path/URL), bytes, or FileData dict, got {type(image)}")
332
+
333
+ image = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
334
  logger.info("Segmentation complete")
335
  return image
336
  except Exception as e: