Commit
·
05a5508
1
Parent(s):
a34ef67
Add DIS background remover integration and model download functionality
Browse files- Introduced a new function to ensure the DIS ONNX model is downloaded if not present.
- Updated the background removal process to utilize the DIS background remover instead of the previous method.
- Implemented temporary file handling for image processing to enhance performance and maintainability.
- Improved error handling for background removal failures, ensuring user feedback on model availability.
app.py
CHANGED
@@ -16,12 +16,18 @@ import json
|
|
16 |
import os
|
17 |
import json
|
18 |
import argparse
|
|
|
|
|
19 |
|
20 |
from model import CRM
|
21 |
from inference import generate3d
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
pipeline = None
|
24 |
-
rembg_session = rembg.new_session()
|
25 |
|
26 |
|
27 |
def expand_to_square(image, bg_color=(0, 0, 0, 0)):
|
@@ -39,24 +45,37 @@ def check_input_image(input_image):
|
|
39 |
if input_image is None:
|
40 |
raise gr.Error("No image uploaded!")
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
def remove_background(
|
44 |
image: PIL.Image.Image,
|
45 |
-
rembg_session: Any = None,
|
46 |
-
force: bool = False,
|
47 |
-
**rembg_kwargs,
|
48 |
) -> PIL.Image.Image:
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
|
59 |
-
return image
|
60 |
|
61 |
def do_resize_content(original_image: Image, scale_rate):
|
62 |
# resize image content wile retain the original image size
|
@@ -88,11 +107,9 @@ def preprocess_image(image, background_choice, foreground_ratio, backgroud_color
|
|
88 |
background = Image.new("RGBA", image.size, (0, 0, 0, 0))
|
89 |
image = Image.alpha_composite(background, image)
|
90 |
else:
|
91 |
-
image = remove_background(image
|
92 |
-
|
93 |
-
|
94 |
-
image = add_background(image, backgroud_color)
|
95 |
-
return image.convert("RGB")
|
96 |
|
97 |
@spaces.GPU
|
98 |
def gen_image(input_image, seed, scale, step):
|
|
|
16 |
import os
|
17 |
import json
|
18 |
import argparse
|
19 |
+
import requests
|
20 |
+
import tempfile
|
21 |
|
22 |
from model import CRM
|
23 |
from inference import generate3d
|
24 |
+
from dis_bg_remover import remove_background as dis_remove_background
|
25 |
+
|
26 |
+
DIS_ONNX_MODEL_PATH = os.environ.get("DIS_ONNX_MODEL_PATH", "isnet_dis.onnx")
|
27 |
+
DIS_ONNX_MODEL_URL = "https://huggingface.co/stoned0651/isnet_dis.onnx/resolve/main/isnet_dis.onnx"
|
28 |
+
|
29 |
|
30 |
pipeline = None
|
|
|
31 |
|
32 |
|
33 |
def expand_to_square(image, bg_color=(0, 0, 0, 0)):
|
|
|
45 |
if input_image is None:
|
46 |
raise gr.Error("No image uploaded!")
|
47 |
|
48 |
+
def ensure_dis_onnx_model():
|
49 |
+
if not os.path.exists(DIS_ONNX_MODEL_PATH):
|
50 |
+
try:
|
51 |
+
print(f"Model file not found at {DIS_ONNX_MODEL_PATH}. Downloading from {DIS_ONNX_MODEL_URL}...")
|
52 |
+
response = requests.get(DIS_ONNX_MODEL_URL, stream=True)
|
53 |
+
response.raise_for_status()
|
54 |
+
with open(DIS_ONNX_MODEL_PATH, "wb") as f:
|
55 |
+
for chunk in response.iter_content(chunk_size=8192):
|
56 |
+
if chunk:
|
57 |
+
f.write(chunk)
|
58 |
+
print(f"Downloaded model to {DIS_ONNX_MODEL_PATH}")
|
59 |
+
except Exception as e:
|
60 |
+
raise gr.Error(
|
61 |
+
f"Failed to download DIS background remover model file: {e}\n"
|
62 |
+
f"Please manually download it from {DIS_ONNX_MODEL_URL} and place it in the project directory or set the DIS_ONNX_MODEL_PATH environment variable."
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
|
67 |
def remove_background(
|
68 |
image: PIL.Image.Image,
|
|
|
|
|
|
|
69 |
) -> PIL.Image.Image:
|
70 |
+
ensure_dis_onnx_model()
|
71 |
+
# Create a temporary file to save the image
|
72 |
+
with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as temp:
|
73 |
+
# Save the PIL image to the temporary file
|
74 |
+
image.save(temp.name)
|
75 |
+
# Call the background removal function with the file path
|
76 |
+
extracted_img, _ = dis_remove_background(DIS_ONNX_MODEL_PATH, temp.name)
|
77 |
+
# The function should return the PIL image directly
|
78 |
+
return extracted_img
|
|
|
|
|
79 |
|
80 |
def do_resize_content(original_image: Image, scale_rate):
|
81 |
# resize image content wile retain the original image size
|
|
|
107 |
background = Image.new("RGBA", image.size, (0, 0, 0, 0))
|
108 |
image = Image.alpha_composite(background, image)
|
109 |
else:
|
110 |
+
image = remove_background(image)
|
111 |
+
if image is None:
|
112 |
+
raise gr.Error("Background removal failed. Please check the input image and ensure the model file exists and is valid.")
|
|
|
|
|
113 |
|
114 |
@spaces.GPU
|
115 |
def gen_image(input_image, seed, scale, step):
|