import os
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

warnings.filterwarnings("ignore")

os.system("git clone https://github.com/xuebinqin/DIS")
os.system("mv DIS/IS-Net/* .")

# Project imports from the cloned IS-Net sources
from data_loader_cache import normalize, im_reader, im_preprocess
from models import *  # provides ISNetDIS

# Helpers
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Place the pretrained ISNet weights (isnet.pth, expected in the working directory)
# into saved_models/ on first run
if not os.path.exists("saved_models"):
    os.mkdir("saved_models")
    os.system("mv isnet.pth saved_models/")
    
class GOSNormalize(object):
    '''
    Normalize an image with the given per-channel mean and std.
    '''
    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        return normalize(image, self.mean, self.std)


transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])

def load_image(im_path, hypar):
    '''Read an image, resize it to cache_size and return a normalized 1xCxHxW tensor plus its original shape.'''
    im = im_reader(im_path)
    im, im_shp = im_preprocess(im, hypar["cache_size"])
    im = torch.divide(im, 255.0)
    shape = torch.from_numpy(np.array(im_shp))
    return transform(im).unsqueeze(0), shape.unsqueeze(0)


def build_model(hypar, device):
    '''Build the ISNet model, optionally in half precision, and restore the pretrained weights.'''
    net = hypar["model"]

    # Half precision: keep BatchNorm layers in float32 for numerical stability
    if hypar["model_digit"] == "half":
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()
    net.to(device)

    # Restore the checkpoint
    if hypar["restore_model"] != "":
        net.load_state_dict(torch.load(hypar["model_path"] + "/" + hypar["restore_model"],
                                       map_location=device))
        net.to(device)
    net.eval()
    return net

def predict(net, inputs_val, shapes_val, hypar, device):
    '''Run the network and return the predicted mask resized to the original image shape.'''
    net.eval()

    # Match the input precision to the model precision
    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.type(torch.FloatTensor)
    else:
        inputs_val = inputs_val.type(torch.HalfTensor)

    with torch.no_grad():
        inputs_val_v = inputs_val.to(device)
        ds_val = net(inputs_val_v)[0]

    # Take the highest-resolution side output and resize it back to the original image size
    pred_val = ds_val[0][0, :, :, :]
    pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val, 0),
                                           (shapes_val[0][0], shapes_val[0][1]),
                                           mode='bilinear'))

    # Min-max normalize the prediction to [0, 1]
    ma = torch.max(pred_val)
    mi = torch.min(pred_val)
    pred_val = (pred_val - mi) / (ma - mi)

    if device == 'cuda':
        torch.cuda.empty_cache()
    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
    
# Hyperparameters for loading and running the model
hypar = {}
hypar["model_path"] = "./saved_models"   # directory holding the checkpoint
hypar["restore_model"] = "isnet.pth"     # checkpoint file name
hypar["interm_sup"] = False              # intermediate supervision (training-time option)
hypar["model_digit"] = "full"            # "full" (float32) or "half" (float16)
hypar["seed"] = 0

hypar["cache_size"] = [1024, 1024]       # resolution the image is resized to for the network
hypar["input_size"] = [1024, 1024]
hypar["crop_size"] = [1024, 1024]
hypar["model"] = ISNetDIS()

net = build_model(hypar, device)


def inference(image):
    # Predict the segmentation mask and use it as the alpha channel of the input image
    image_tensor, orig_size = load_image(image, hypar)
    mask = predict(net, image_tensor, orig_size, hypar, device)
    pil_mask = Image.fromarray(mask).convert('L')
    im_rgb = Image.open(image).convert("RGB")
    im_rgba = im_rgb.copy()
    im_rgba.putalpha(pil_mask)
    return [im_rgba, pil_mask]
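
# Gradio UI: the first output shows the RGBA cutout; the raw mask is returned
# as a second, hidden output (visible=False).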


interface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type='filepath'),
    outputs=[gr.Image(type='filepath', format="png"), gr.Image(type='filepath', format="png", visible=False)],
    flagging_mode="never",
    cache_mode="lazy",
    ).queue(api_open=False).launch(show_error=False, show_api=False, share=False)
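
# Note: launch() starts the Gradio app and blocks. For a quick sanity check without the UI,
# something like `inference("example.jpg")` (a hypothetical local file) returns the RGBA
# cutout and the mask directly; such a call would need to run before launch().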