Commit
·
fa6cc29
1
Parent(s):
83a8b53
Refactor remove_background function in app.py to enhance mask handling. Added checks for mask dimensions and data type, ensuring proper conversion to RGBA format for images. This improves the output quality when masks are applied.
Browse files
app.py
CHANGED
@@ -68,17 +68,24 @@ def remove_background(
|
|
68 |
force: bool = False,
|
69 |
**rembg_kwargs,
|
70 |
) -> PIL.Image.Image:
|
71 |
-
# Ensure the ONNX model exists (download if needed)
|
72 |
ensure_dis_onnx_model()
|
73 |
-
# Save PIL image to a temporary file
|
74 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as temp:
|
75 |
image.save(temp.name)
|
76 |
extracted_img, mask = dis_remove_background(DIS_ONNX_MODEL_PATH, temp.name)
|
77 |
-
#
|
78 |
if isinstance(extracted_img, np.ndarray):
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
return extracted_img
|
83 |
|
84 |
def do_resize_content(original_image: Image, scale_rate):
|
|
|
68 |
force: bool = False,
|
69 |
**rembg_kwargs,
|
70 |
) -> PIL.Image.Image:
|
|
|
71 |
ensure_dis_onnx_model()
|
|
|
72 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as temp:
|
73 |
image.save(temp.name)
|
74 |
extracted_img, mask = dis_remove_background(DIS_ONNX_MODEL_PATH, temp.name)
|
75 |
+
# If extracted_img is a mask (single channel), use it as alpha for the original image
|
76 |
if isinstance(extracted_img, np.ndarray):
|
77 |
+
# If mask is float, convert to uint8
|
78 |
+
if mask.dtype != np.uint8:
|
79 |
+
mask = (np.clip(mask, 0, 1) * 255).astype(np.uint8)
|
80 |
+
# Ensure mask is 2D
|
81 |
+
if mask.ndim == 3:
|
82 |
+
mask = mask[..., 0]
|
83 |
+
# Convert original image to RGBA
|
84 |
+
image = image.convert("RGBA")
|
85 |
+
image_np = np.array(image)
|
86 |
+
image_np[..., 3] = mask
|
87 |
+
return Image.fromarray(image_np)
|
88 |
+
# If extracted_img is already a color image, just return it
|
89 |
return extracted_img
|
90 |
|
91 |
def do_resize_content(original_image: Image, scale_rate):
|