File size: 4,137 Bytes
976b5ca
2cad0ae
976b5ca
 
2cad0ae
 
976b5ca
 
2cad0ae
 
 
 
976b5ca
 
2cad0ae
 
 
976b5ca
2cad0ae
 
 
 
 
 
976b5ca
 
2cad0ae
976b5ca
 
2cad0ae
976b5ca
 
f59bc45
97cb15e
2cad0ae
 
976b5ca
f59bc45
976b5ca
 
 
2cad0ae
 
976b5ca
2cad0ae
f59bc45
2cad0ae
976b5ca
 
 
 
2cad0ae
976b5ca
644d22d
976b5ca
2cad0ae
644d22d
2cad0ae
976b5ca
 
 
 
 
2cad0ae
 
 
 
 
 
976b5ca
 
2cad0ae
976b5ca
2cad0ae
976b5ca
 
 
 
644d22d
2cad0ae
 
 
 
 
 
 
976b5ca
 
2cad0ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
976b5ca
2cad0ae
976b5ca
2cad0ae
dcb9e75
2cad0ae
 
 
 
 
 
 
 
 
 
976b5ca
2cad0ae
 
 
a16970a
2bc1ec9
bf9eb9c
2cad0ae
976b5ca
2cad0ae
976b5ca
 
70c1aa0
2cad0ae
 
 
 
0c332e2
f59bc45
 
2cad0ae
 
 
 
 
f59bc45
2cad0ae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms
from PIL import Image
import gradio as gr

# Silence all Python warnings for the remainder of the process.
warnings.filterwarnings("ignore")

# One-time bootstrap: clone the DIS repository, then move its IS-Net sources
# into the working directory so the project imports below can resolve.
# Each step runs only when its marker path is missing. Exit codes returned by
# os.system are ignored, so a failed step only shows up later (e.g. as an
# ImportError from the project imports).
# NOTE(review): confirm DIS/IS-Net really contains a top-level models.py; if it
# ships a `models/` package instead, the move step is re-attempted on every start.
_BOOTSTRAP_STEPS = (
    ("DIS", "git clone https://github.com/xuebinqin/DIS"),
    ("models.py", "mv DIS/IS-Net/* ."),
)
for _marker, _shell_cmd in _BOOTSTRAP_STEPS:
    if not os.path.exists(_marker):
        os.system(_shell_cmd)
del _marker, _shell_cmd

# Project imports
from data_loader_cache import normalize, im_reader, im_preprocess
from models import *

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Prepare the folder build_model loads the checkpoint from (hypar["model_path"]).
# NOTE: isnet.pth must be available locally, either in the working directory
# or already inside saved_models/; this script does not download it.
# The folder and the weights are handled independently: if saved_models/ already
# exists but isnet.pth is still in the working directory, it is still moved in.
os.makedirs("saved_models", exist_ok=True)
_weights_src = "isnet.pth"
_weights_dst = os.path.join("saved_models", "isnet.pth")
if os.path.exists(_weights_src) and not os.path.exists(_weights_dst):
    # Same-directory-tree rename; never clobbers an existing checkpoint because
    # of the existence check above.
    os.replace(_weights_src, _weights_dst)

# --- Helpers ---
class GOSNormalize:
    """Callable transform that forwards an image plus a stored mean/std to
    ``normalize`` (imported from ``data_loader_cache``).

    Args:
        mean: Per-channel mean values, copied into a new list.
        std: Per-channel standard-deviation values, copied into a new list.
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # Immutable defaults plus a defensive copy: instances never share a
        # mutable default list, and later mutation of a caller's list cannot
        # silently change this transform.
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image):
        """Return ``normalize(image, self.mean, self.std)``."""
        return normalize(image, self.mean, self.std)

# Normalization pipeline applied by load_image after pixels are divided by 255.
# Mean 0.5 / std 1.0 per channel; presumably this just shifts values by -0.5
# (exact behavior depends on data_loader_cache.normalize).
transform = transforms.Compose(
    [GOSNormalize(mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])]
)

def load_image(im_path, hypar):
    """Read an image file and turn it into a batched, normalized model input.

    Args:
        im_path: Path of the image to read (passed to ``im_reader``).
        hypar: Settings dict; only ``hypar["cache_size"]`` is read, and it is
            forwarded to ``im_preprocess``.

    Returns:
        A pair ``(input, shape)``. ``input`` is the preprocessed image divided
        by 255, passed through ``transform``, with a leading batch dimension.
        ``shape`` is the size reported by ``im_preprocess`` as a tensor, also
        with a leading batch dimension; predict uses it as the output (H, W).
    """
    raw_image = im_reader(im_path)
    preprocessed, original_shape = im_preprocess(raw_image, hypar["cache_size"])
    # Scale to [0, 1]; assumes im_preprocess keeps the 0-255 pixel range.
    scaled = torch.divide(preprocessed, 255.0)
    batched_shape = torch.from_numpy(np.array(original_shape)).unsqueeze(0)
    batched_input = transform(scaled).unsqueeze(0)
    return batched_input, batched_shape

def build_model(hypar, device):
    """Prepare ``hypar["model"]`` for inference and return it.

    Steps, in order:
    - If ``hypar["model_digit"] == "half"``, convert the weights to float16.
    - Move the network to ``device``.
    - If ``hypar["restore_model"]`` is non-empty, load the state dict stored
      at ``hypar["model_path"]/hypar["restore_model"]``.
    - Switch to eval mode.

    The module is modified in place and the same object is returned.
    """
    model = hypar["model"]

    if hypar["model_digit"] == "half":
        model.half()
        # Cast BatchNorm2d layers back to float32 after halving the whole net
        # (presumably for numerical stability of the normalization statistics).
        batch_norms = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
        for bn in batch_norms:
            bn.float()

    model.to(device)

    checkpoint_name = hypar["restore_model"]
    if checkpoint_name != "":
        checkpoint_path = os.path.join(hypar["model_path"], checkpoint_name)
        # torch.load unpickles the file: only point this at trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)

    model.eval()
    return model

def predict(net, inputs_val, shapes_val, hypar, device):
    """Run ``net`` on one preprocessed image and return a min-max scaled mask.

    Args:
        net: Model whose output ``net(x)[0][0]`` is indexed as a 4-D tensor.
            Only the first batch element is used. Assumes a single output
            channel, so that the final squeeze yields (H, W); TODO confirm.
        inputs_val: Batched input tensor, as produced by ``load_image``.
        shapes_val: Batched target size. ``shapes_val[0][0]`` and
            ``shapes_val[0][1]`` are used as the (H, W) given to
            ``F.interpolate``.
        hypar: Settings dict. Only ``"model_digit"`` is read: ``"full"`` means
            float32 input, anything else means float16.
        device: Device string. ``'cuda'`` also triggers
            ``torch.cuda.empty_cache()``.

    Returns:
        A uint8 numpy array with values scaled to 0..255. It is all zeros when
        the prediction is constant; without that guard the min-max division
        would produce NaNs.
    """
    net.eval()
    dtype = torch.float32 if hypar["model_digit"] == "full" else torch.float16
    inputs = inputs_val.to(device=device, dtype=dtype)

    # Inference only: skip building an autograd graph, which would otherwise
    # hold every intermediate activation in memory.
    with torch.no_grad():
        side_output = net(inputs)[0]
        pred_val = side_output[0][0, :, :, :]
        target_size = (int(shapes_val[0][0]), int(shapes_val[0][1]))
        pred_val = torch.squeeze(F.interpolate(
            torch.unsqueeze(pred_val, 0),
            target_size,
            mode='bilinear',
            align_corners=False,  # the default, made explicit
        ))
        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        span = ma - mi
        if span > 0:
            pred_val = (pred_val - mi) / span
        else:
            # Constant (or non-finite) prediction: no usable contrast.
            pred_val = torch.zeros_like(pred_val)

    if device == 'cuda':
        torch.cuda.empty_cache()
    return (pred_val.cpu().numpy() * 255).astype(np.uint8)

# --- Prepare model ---
# Settings read by build_model, load_image and predict.
# NOTE(review): "interm_sup", "seed", "input_size" and "crop_size" are not read
# by any code in this file.
hypar = dict(
    model_path="./saved_models",    # directory holding the checkpoint
    restore_model="isnet.pth",      # checkpoint file name; "" skips loading
    interm_sup=False,
    model_digit="full",             # "full" -> float32, "half" -> float16
    seed=0,
    cache_size=[1024, 1024],        # forwarded to im_preprocess by load_image
    input_size=[1024, 1024],
    crop_size=[1024, 1024],
    model=ISNetDIS(),
)

net = build_model(hypar, device)

# --- Inference ---
def inference(image):
    """Remove the background from the image stored at ``image``.

    Args:
        image: Path to the uploaded image file (the Gradio input component
            uses ``type='filepath'``), or None when nothing was uploaded.

    Returns:
        ``[rgba, mask]``. ``rgba`` is the image converted to RGB with the
        predicted mask as its alpha channel; ``mask`` is the mask itself as an
        8-bit grayscale ("L") image. Returns ``[None, None]`` when no image
        was provided.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading; without
        # this guard the call would fail inside load_image.
        return [None, None]

    image_tensor, orig_size = load_image(image, hypar)
    mask = predict(net, image_tensor, orig_size, hypar, device)
    pil_mask = Image.fromarray(mask).convert('L')

    # Context manager closes the file handle promptly. convert() returns a new,
    # fully loaded image, so it stays valid after the file is closed.
    with Image.open(image) as src:
        im_rgba = src.convert("RGB")
    im_rgba.putalpha(pil_mask)

    return [im_rgba, pil_mask]

# --- Custom CSS to hide footer ---
# Each selector below is hidden with `display: none !important`.
# NOTE(review): "#component-12" looks like a Gradio auto-assigned element id;
# which element it matches presumably depends on component creation order, so
# re-check it whenever the UI layout changes.
_HIDDEN_SELECTORS = (
    "footer",
    "#component-12",
    "#huggingface-space-header",
    'button[data-testid="ShareButton"]',
)
css_hide_footer = "\n" + "".join(
    f"{selector} {{display: none !important;}}\n" for selector in _HIDDEN_SELECTORS
)

# --- Gradio Interface ---
# Components are instantiated in the same order as they would be inside the
# gr.Interface(...) call: the input first, then the two outputs.
# NOTE(review): element ids such as "#component-12" (hidden via
# css_hide_footer) presumably depend on this creation order.
_input_image = gr.Image(type='filepath', height=300, width=300)
_output_images = [
    gr.Image(type='filepath', format="png"),                 # RGBA cut-out
    gr.Image(type='filepath', format="png", visible=False),  # alpha mask, hidden in the UI
]
interface = gr.Interface(
    fn=inference,
    inputs=_input_image,
    outputs=_output_images,
    flagging_mode="never",
    cache_mode="lazy",
    css=css_hide_footer,  # styling is passed to Interface, not to launch()
)

# Start the app: no public share link, API page hidden, errors not shown in the UI.
interface.launch(share=False, show_api=False, show_error=False)