Spaces:
mashroo
/
Running on Zero

YoussefAnso commited on
Commit
05a5508
·
1 Parent(s): a34ef67

Add DIS background remover integration and model download functionality

Browse files

- Introduced a new function to ensure the DIS ONNX model is downloaded if not present.
- Updated the background removal process to utilize the DIS background remover instead of the previous method.
- Implemented temporary file handling for image processing to enhance performance and maintainability.
- Improved error handling for background removal failures, ensuring user feedback on model availability.

Files changed (1) hide show
  1. app.py +37 -20
app.py CHANGED
@@ -16,12 +16,18 @@ import json
16
  import os
17
  import json
18
  import argparse
 
 
19
 
20
  from model import CRM
21
  from inference import generate3d
 
 
 
 
 
22
 
23
  pipeline = None
24
- rembg_session = rembg.new_session()
25
 
26
 
27
  def expand_to_square(image, bg_color=(0, 0, 0, 0)):
@@ -39,24 +45,37 @@ def check_input_image(input_image):
39
  if input_image is None:
40
  raise gr.Error("No image uploaded!")
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def remove_background(
44
  image: PIL.Image.Image,
45
- rembg_session: Any = None,
46
- force: bool = False,
47
- **rembg_kwargs,
48
  ) -> PIL.Image.Image:
49
- do_remove = True
50
- if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
51
- # explain why current do not rm bg
52
- print("alhpa channl not enpty, skip remove background, using alpha channel as mask")
53
- background = Image.new("RGBA", image.size, (0, 0, 0, 0))
54
- image = Image.alpha_composite(background, image)
55
- do_remove = False
56
- do_remove = do_remove or force
57
- if do_remove:
58
- image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
59
- return image
60
 
61
  def do_resize_content(original_image: Image, scale_rate):
62
  # resize image content wile retain the original image size
@@ -88,11 +107,9 @@ def preprocess_image(image, background_choice, foreground_ratio, backgroud_color
88
  background = Image.new("RGBA", image.size, (0, 0, 0, 0))
89
  image = Image.alpha_composite(background, image)
90
  else:
91
- image = remove_background(image, rembg_session, force=True)
92
- image = do_resize_content(image, foreground_ratio)
93
- image = expand_to_square(image)
94
- image = add_background(image, backgroud_color)
95
- return image.convert("RGB")
96
 
97
  @spaces.GPU
98
  def gen_image(input_image, seed, scale, step):
 
16
  import os
17
  import json
18
  import argparse
19
+ import requests
20
+ import tempfile
21
 
22
  from model import CRM
23
  from inference import generate3d
24
+ from dis_bg_remover import remove_background as dis_remove_background
25
+
26
+ DIS_ONNX_MODEL_PATH = os.environ.get("DIS_ONNX_MODEL_PATH", "isnet_dis.onnx")
27
+ DIS_ONNX_MODEL_URL = "https://huggingface.co/stoned0651/isnet_dis.onnx/resolve/main/isnet_dis.onnx"
28
+
29
 
30
  pipeline = None
 
31
 
32
 
33
  def expand_to_square(image, bg_color=(0, 0, 0, 0)):
 
45
  if input_image is None:
46
  raise gr.Error("No image uploaded!")
47
 
48
+ def ensure_dis_onnx_model():
49
+ if not os.path.exists(DIS_ONNX_MODEL_PATH):
50
+ try:
51
+ print(f"Model file not found at {DIS_ONNX_MODEL_PATH}. Downloading from {DIS_ONNX_MODEL_URL}...")
52
+ response = requests.get(DIS_ONNX_MODEL_URL, stream=True)
53
+ response.raise_for_status()
54
+ with open(DIS_ONNX_MODEL_PATH, "wb") as f:
55
+ for chunk in response.iter_content(chunk_size=8192):
56
+ if chunk:
57
+ f.write(chunk)
58
+ print(f"Downloaded model to {DIS_ONNX_MODEL_PATH}")
59
+ except Exception as e:
60
+ raise gr.Error(
61
+ f"Failed to download DIS background remover model file: {e}\n"
62
+ f"Please manually download it from {DIS_ONNX_MODEL_URL} and place it in the project directory or set the DIS_ONNX_MODEL_PATH environment variable."
63
+ )
64
+
65
+
66
 
67
  def remove_background(
68
  image: PIL.Image.Image,
 
 
 
69
  ) -> PIL.Image.Image:
70
+ ensure_dis_onnx_model()
71
+ # Create a temporary file to save the image
72
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as temp:
73
+ # Save the PIL image to the temporary file
74
+ image.save(temp.name)
75
+ # Call the background removal function with the file path
76
+ extracted_img, _ = dis_remove_background(DIS_ONNX_MODEL_PATH, temp.name)
77
+ # The function should return the PIL image directly
78
+ return extracted_img
 
 
79
 
80
  def do_resize_content(original_image: Image, scale_rate):
81
  # resize image content wile retain the original image size
 
107
  background = Image.new("RGBA", image.size, (0, 0, 0, 0))
108
  image = Image.alpha_composite(background, image)
109
  else:
110
+ image = remove_background(image)
111
+ if image is None:
112
+ raise gr.Error("Background removal failed. Please check the input image and ensure the model file exists and is valid.")
 
 
113
 
114
  @spaces.GPU
115
  def gen_image(input_image, seed, scale, step):