Spaces:
Runtime error
Runtime error
Update freesplatter/utils/infer_util.py
Browse files- freesplatter/utils/infer_util.py +31 -31
freesplatter/utils/infer_util.py
CHANGED
|
@@ -68,36 +68,10 @@ def get_obj_from_str(string, reload=False):
|
|
| 68 |
# return image
|
| 69 |
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
# force: bool = False,
|
| 76 |
-
# **rembg_kwargs,
|
| 77 |
-
# ) -> PIL.Image.Image:
|
| 78 |
-
# do_remove = True
|
| 79 |
-
# if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
|
| 80 |
-
# do_remove = False
|
| 81 |
-
# do_remove = do_remove or force
|
| 82 |
-
# if do_remove:
|
| 83 |
-
# transform_image = transforms.Compose([
|
| 84 |
-
# transforms.Resize((1024, 1024)),
|
| 85 |
-
# transforms.ToTensor(),
|
| 86 |
-
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 87 |
-
# ])
|
| 88 |
-
# image = image.convert('RGB')
|
| 89 |
-
# input_images = transform_image(image).unsqueeze(0).to(rembg.device)
|
| 90 |
-
# with torch.no_grad():
|
| 91 |
-
# preds = rembg(input_images)[-1].sigmoid().cpu()
|
| 92 |
-
# pred = preds[0].squeeze()
|
| 93 |
-
# pred_pil = transforms.ToPILImage()(pred)
|
| 94 |
-
# mask = pred_pil.resize(image.size)
|
| 95 |
-
# image.putalpha(mask)
|
| 96 |
-
# return image
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
def remove_background(image: PIL.Image.Image,
|
| 100 |
-
rembg_session: Any = None,
|
| 101 |
force: bool = False,
|
| 102 |
**rembg_kwargs,
|
| 103 |
) -> PIL.Image.Image:
|
|
@@ -106,10 +80,36 @@ def remove_background(image: PIL.Image.Image,
|
|
| 106 |
do_remove = False
|
| 107 |
do_remove = do_remove or force
|
| 108 |
if do_remove:
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
return image
|
| 111 |
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
def resize_foreground(
|
| 114 |
image: PIL.Image.Image,
|
| 115 |
ratio: float,
|
|
|
|
| 68 |
# return image
|
| 69 |
|
| 70 |
|
| 71 |
+
@torch.inference_mode()
|
| 72 |
+
def remove_background(
|
| 73 |
+
image: PIL.Image.Image,
|
| 74 |
+
rembg: Any = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
force: bool = False,
|
| 76 |
**rembg_kwargs,
|
| 77 |
) -> PIL.Image.Image:
|
|
|
|
| 80 |
do_remove = False
|
| 81 |
do_remove = do_remove or force
|
| 82 |
if do_remove:
|
| 83 |
+
transform_image = transforms.Compose([
|
| 84 |
+
transforms.Resize((1024, 1024)),
|
| 85 |
+
transforms.ToTensor(),
|
| 86 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 87 |
+
])
|
| 88 |
+
image = image.convert('RGB')
|
| 89 |
+
input_images = transform_image(image).unsqueeze(0).to(rembg.device)
|
| 90 |
+
with torch.no_grad():
|
| 91 |
+
preds = rembg(input_images)[-1].sigmoid().cpu()
|
| 92 |
+
pred = preds[0].squeeze()
|
| 93 |
+
pred_pil = transforms.ToPILImage()(pred)
|
| 94 |
+
mask = pred_pil.resize(image.size)
|
| 95 |
+
image.putalpha(mask)
|
| 96 |
return image
|
| 97 |
|
| 98 |
|
| 99 |
+
# def remove_background(image: PIL.Image.Image,
|
| 100 |
+
# rembg_session: Any = None,
|
| 101 |
+
# force: bool = False,
|
| 102 |
+
# **rembg_kwargs,
|
| 103 |
+
# ) -> PIL.Image.Image:
|
| 104 |
+
# do_remove = True
|
| 105 |
+
# if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
|
| 106 |
+
# do_remove = False
|
| 107 |
+
# do_remove = do_remove or force
|
| 108 |
+
# if do_remove:
|
| 109 |
+
# image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
|
| 110 |
+
# return image
|
| 111 |
+
|
| 112 |
+
|
| 113 |
def resize_foreground(
|
| 114 |
image: PIL.Image.Image,
|
| 115 |
ratio: float,
|