Spaces:
mashroo
/
Runtime error

YoussefAnso commited on
Commit
fa6cc29
·
1 Parent(s): 83a8b53

Refactor remove_background function in app.py to enhance mask handling. Added checks for mask dimensions and data type, ensuring proper conversion to RGBA format for images. This improves the output quality when masks are applied.

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -68,17 +68,24 @@ def remove_background(
68
  force: bool = False,
69
  **rembg_kwargs,
70
  ) -> PIL.Image.Image:
71
- # Ensure the ONNX model exists (download if needed)
72
  ensure_dis_onnx_model()
73
- # Save PIL image to a temporary file
74
  with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as temp:
75
  image.save(temp.name)
76
  extracted_img, mask = dis_remove_background(DIS_ONNX_MODEL_PATH, temp.name)
77
- # Convert to PIL Image if needed
78
  if isinstance(extracted_img, np.ndarray):
79
- if extracted_img.dtype != np.uint8:
80
- extracted_img = (np.clip(extracted_img, 0, 1) * 255).astype(np.uint8)
81
- extracted_img = Image.fromarray(extracted_img)
 
 
 
 
 
 
 
 
 
82
  return extracted_img
83
 
84
  def do_resize_content(original_image: Image, scale_rate):
 
68
  force: bool = False,
69
  **rembg_kwargs,
70
  ) -> PIL.Image.Image:
 
71
  ensure_dis_onnx_model()
 
72
  with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as temp:
73
  image.save(temp.name)
74
  extracted_img, mask = dis_remove_background(DIS_ONNX_MODEL_PATH, temp.name)
75
+ # If extracted_img is a mask (single channel), use it as alpha for the original image
76
  if isinstance(extracted_img, np.ndarray):
77
+ # If mask is float, convert to uint8
78
+ if mask.dtype != np.uint8:
79
+ mask = (np.clip(mask, 0, 1) * 255).astype(np.uint8)
80
+ # Ensure mask is 2D
81
+ if mask.ndim == 3:
82
+ mask = mask[..., 0]
83
+ # Convert original image to RGBA
84
+ image = image.convert("RGBA")
85
+ image_np = np.array(image)
86
+ image_np[..., 3] = mask
87
+ return Image.fromarray(image_np)
88
+ # If extracted_img is already a color image, just return it
89
  return extracted_img
90
 
91
  def do_resize_content(original_image: Image, scale_rate):