Spaces:
Build error
Build error
init
Browse files- app.py +11 -0
- requirements.txt +3 -1
app.py
CHANGED
|
@@ -27,6 +27,8 @@ from data_loader import SalObjDataset
|
|
| 27 |
from model import U2NET # full size version 173.6 MB
|
| 28 |
from model import U2NETP # small version u2net 4.7 MB
|
| 29 |
|
|
|
|
|
|
|
| 30 |
|
| 31 |
# normalize the predicted SOD probability map
|
| 32 |
def normPRED(d):
|
|
@@ -58,6 +60,12 @@ def save_output(image_name,pred,d_dir):
|
|
| 58 |
return d_dir+'/'+imidx+'.png'
|
| 59 |
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
# --------- 1. get image path and name ---------
|
| 62 |
model_name='u2net_portrait'#u2netp
|
| 63 |
|
|
@@ -82,6 +90,9 @@ net.eval()
|
|
| 82 |
|
| 83 |
|
| 84 |
def process(im):
|
|
|
|
|
|
|
|
|
|
| 85 |
img_name_list = glob.glob(im.name)
|
| 86 |
print("Number of images: ", len(img_name_list))
|
| 87 |
# --------- 2. dataloader ---------
|
|
|
|
| 27 |
from model import U2NET # full size version 173.6 MB
|
| 28 |
from model import U2NETP # small version u2net 4.7 MB
|
| 29 |
|
| 30 |
+
from modnet import ModNet
|
| 31 |
+
import huggingface_hub
|
| 32 |
|
| 33 |
# normalize the predicted SOD probability map
|
| 34 |
def normPRED(d):
|
|
|
|
| 60 |
return d_dir+'/'+imidx+'.png'
|
| 61 |
|
| 62 |
|
| 63 |
+
|
| 64 |
+
modnet_path = huggingface_hub.hf_hub_download('hylee/apdrawing_model',
|
| 65 |
+
'modnet.onnx',
|
| 66 |
+
force_filename='modnet.onnx')
|
| 67 |
+
modnet = ModNet(modnet_path)
|
| 68 |
+
|
| 69 |
# --------- 1. get image path and name ---------
|
| 70 |
model_name='u2net_portrait'#u2netp
|
| 71 |
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
def process(im):
|
| 93 |
+
image = modnet.segment(im.name)
|
| 94 |
+
Image.fromarray(np.uint8(image)).save(im.name)
|
| 95 |
+
|
| 96 |
img_name_list = glob.glob(im.name)
|
| 97 |
print("Number of images: ", len(img_name_list))
|
| 98 |
# --------- 2. dataloader ---------
|
requirements.txt
CHANGED
|
@@ -3,4 +3,6 @@ scikit-image
|
|
| 3 |
torch
|
| 4 |
torchvision
|
| 5 |
pillow
|
| 6 |
-
opencv-python-headless
|
|
|
|
|
|
|
|
|
| 3 |
torch
|
| 4 |
torchvision
|
| 5 |
pillow
|
| 6 |
+
opencv-python-headless
|
| 7 |
+
onnx==1.8.1
|
| 8 |
+
onnxruntime==1.6.0
|