sourav520 committed on
Commit
2cad0ae
·
verified ·
1 Parent(s): 70c1aa0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -62
app.py CHANGED
@@ -1,115 +1,145 @@
1
- import cv2
2
- import gradio as gr
3
  import os
4
- from PIL import Image
5
  import numpy as np
6
  import torch
 
 
7
  from torch.autograd import Variable
8
  from torchvision import transforms
9
- import torch.nn.functional as F
10
- import gdown
11
- import matplotlib.pyplot as plt
12
- import warnings
13
  warnings.filterwarnings("ignore")
14
 
15
- os.system("git clone https://github.com/xuebinqin/DIS")
16
- os.system("mv DIS/IS-Net/* .")
 
17
 
18
- # project imports
19
- from data_loader_cache import normalize, im_reader, im_preprocess
 
 
 
 
20
  from models import *
21
 
22
- #Helpers
23
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
24
 
25
- # Download official weights
26
  if not os.path.exists("saved_models"):
27
  os.mkdir("saved_models")
 
28
  os.system("mv isnet.pth saved_models/")
29
-
 
30
  class GOSNormalize(object):
31
- '''
32
- Normalize the Image using torch.transforms
33
- '''
34
  def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
35
  self.mean = mean
36
  self.std = std
37
 
38
- def __call__(self,image):
39
- image = normalize(image,self.mean,self.std)
40
- return image
41
-
42
 
43
- transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
 
 
44
 
45
  def load_image(im_path, hypar):
46
  im = im_reader(im_path)
47
  im, im_shp = im_preprocess(im, hypar["cache_size"])
48
- im = torch.divide(im,255.0)
49
  shape = torch.from_numpy(np.array(im_shp))
50
  return transform(im).unsqueeze(0), shape.unsqueeze(0)
51
 
52
-
53
- def build_model(hypar,device):
54
  net = hypar["model"]
55
- if(hypar["model_digit"]=="half"):
56
  net.half()
57
  for layer in net.modules():
58
  if isinstance(layer, nn.BatchNorm2d):
59
  layer.float()
60
  net.to(device)
61
- if(hypar["restore_model"]!=""):
62
- net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"], map_location=device))
63
- net.to(device)
64
- net.eval()
 
 
65
  return net
66
 
67
- def predict(net, inputs_val, shapes_val, hypar, device):
68
  net.eval()
69
- if(hypar["model_digit"]=="full"):
70
  inputs_val = inputs_val.type(torch.FloatTensor)
71
  else:
72
  inputs_val = inputs_val.type(torch.HalfTensor)
73
 
74
  inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
75
- ds_val = net(inputs_val_v)[0]
76
- pred_val = ds_val[0][0,:,:,:]
77
- pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))
 
 
 
 
78
  ma = torch.max(pred_val)
79
  mi = torch.min(pred_val)
80
- pred_val = (pred_val-mi)/(ma-mi)
81
- if device == 'cuda': torch.cuda.empty_cache()
82
- return (pred_val.detach().cpu().numpy()*255).astype(np.uint8)
83
-
84
- hypar = {}
85
- hypar["model_path"] ="./saved_models"
86
- hypar["restore_model"] = "isnet.pth"
87
- hypar["interm_sup"] = False
88
- hypar["model_digit"] = "full"
89
- hypar["seed"] = 0
90
-
91
- hypar["cache_size"] = [1024, 1024]
92
- hypar["input_size"] = [1024, 1024]
93
- hypar["crop_size"] = [1024, 1024]
94
- hypar["model"] = ISNetDIS()
95
- net = build_model(hypar, device)
 
96
 
 
97
 
 
98
  def inference(image):
99
- image_path = image
100
- image_tensor, orig_size = load_image(image_path, hypar)
101
- mask = predict(net, image_tensor, orig_size, hypar, device)
102
- pil_mask = Image.fromarray(mask).convert('L')
103
- im_rgb = Image.open(image).convert("RGB")
104
- im_rgba = im_rgb.copy()
105
- im_rgba.putalpha(pil_mask)
106
- return [im_rgba, pil_mask]
 
 
107
 
 
 
 
 
 
108
 
 
109
  interface = gr.Interface(
110
  fn=inference,
111
  inputs=gr.Image(type='filepath', height=300, width=300),
112
- outputs=[gr.Image(type='filepath', format="png"), gr.Image(type='filepath', format="png", visible=False)],
 
 
 
113
  flagging_mode="never",
114
- cache_mode="lazy",
115
- ).queue(api_open=False).launch(show_error=False, show_api=False, share=False)
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import warnings
3
  import numpy as np
4
  import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
  from torch.autograd import Variable
8
  from torchvision import transforms
9
+ from PIL import Image
10
+ import gradio as gr
11
+
12
+ # Suppress warnings
13
  warnings.filterwarnings("ignore")
14
 
15
+ # Clone DIS repo if not exists
16
# Fetch the upstream DIS sources once and place the IS-Net files next to this
# script so that `data_loader_cache` and `models` can be imported below.
import glob
import shutil
import subprocess

if not os.path.exists("DIS"):
    # check=True: stop with a clear error if the clone fails, instead of
    # failing later with an unrelated ImportError.
    subprocess.run(
        ["git", "clone", "https://github.com/xuebinqin/DIS"],
        check=True,
    )

# The imports that follow need both modules; each may be a single file or a
# package directory, so accept either form.
_needed_modules = ("data_loader_cache", "models")
_sources_present = all(
    os.path.exists(name + ".py") or os.path.isdir(name)
    for name in _needed_modules
)

if not _sources_present:
    # Same effect as `mv DIS/IS-Net/* .`, without going through a shell.
    for entry in glob.glob(os.path.join("DIS", "IS-Net", "*")):
        target = os.path.basename(entry)
        if os.path.isdir(target):
            # A directory of that name is already in place (earlier run).
            continue
        shutil.move(entry, target)
22
+
23
+ # Project imports
24
+ from data_loader_cache import normalize, im_reader, im_preprocess
25
  from models import *
26
 
27
+ # Setup device
28
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
29
 
30
+ # Prepare saved models folder
31
# Make sure the checkpoint directory exists and relocate a checkpoint that
# was placed next to this script.
os.makedirs("saved_models", exist_ok=True)

# NOTE: nothing visible in this script downloads isnet.pth; it has to be
# provided next to the script or already be present in saved_models/.
if os.path.exists("isnet.pth"):
    # Move only when the file is really there: the unconditional shell `mv`
    # failed silently whenever it was missing or had already been moved.
    os.replace("isnet.pth", os.path.join("saved_models", "isnet.pth"))
35
+
36
+ # --- Helpers ---
37
class GOSNormalize:
    """Callable transform that normalizes an image with a fixed mean and std.

    The actual work is delegated to ``normalize`` imported from
    ``data_loader_cache``.

    Args:
        mean: Mean values handed to ``normalize``; defaults to
            ``[0.485, 0.456, 0.406]``.
        std: Standard-deviation values handed to ``normalize``; defaults to
            ``[0.229, 0.224, 0.225]``.
    """

    def __init__(self, mean=None, std=None):
        # Build fresh lists per instance: list literals used as defaults are
        # shared by every instance created without arguments.
        self.mean = [0.485, 0.456, 0.406] if mean is None else mean
        self.std = [0.229, 0.224, 0.225] if std is None else std

    def __call__(self, image):
        """Return ``image`` normalized with this instance's mean and std."""
        return normalize(image, self.mean, self.std)
 
 
44
 
45
# Preprocessing pipeline used by load_image: a single normalization step
# with mean 0.5 and std 1.0 for each of the three entries.
_isnet_normalize = GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
transform = transforms.Compose([_isnet_normalize])
48
 
49
def load_image(im_path, hypar):
    """Read an image file and turn it into a batched model input.

    Args:
        im_path: Path of the image, passed to ``im_reader``.
        hypar: Settings dict; only ``hypar["cache_size"]`` is read here.

    Returns:
        A pair ``(image, shape)``: the transformed image and the shape value
        reported by ``im_preprocess``, each with a new leading dimension.
    """
    raw = im_reader(im_path)
    prepared, im_shape = im_preprocess(raw, hypar["cache_size"])

    # Bring pixel values from the 0..255 range down by a factor of 255.
    scaled = torch.divide(prepared, 255.0)

    batched_shape = torch.from_numpy(np.array(im_shape)).unsqueeze(0)
    batched_image = transform(scaled).unsqueeze(0)
    return batched_image, batched_shape
55
 
56
def build_model(hypar, device):
    """Prepare the network stored in ``hypar`` for inference.

    Args:
        hypar: Settings dict. Reads ``"model"`` (the network),
            ``"model_digit"`` (``"half"`` selects half precision) and, when a
            checkpoint is to be restored, ``"restore_model"`` (file name) and
            ``"model_path"`` (directory).
        device: Device passed to ``net.to`` and used as ``map_location``.

    Returns:
        The same network object, moved to ``device`` and set to eval mode.
    """
    net = hypar["model"]

    if hypar["model_digit"] == "half":
        net.half()
        # Keep batch-norm layers in float32 while the rest runs in half.
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    net.to(device)

    # Treat a missing key, None and "" alike as "no checkpoint to restore".
    restore_model = hypar.get("restore_model")
    if restore_model:
        checkpoint_path = os.path.join(hypar["model_path"], restore_model)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(checkpoint_path, map_location=device)
        net.load_state_dict(state_dict)

    net.eval()
    return net
71
 
72
def predict(net, inputs_val, shapes_val, hypar, device):
    """Run the network on one input and return its mask as a uint8 array.

    Args:
        net: The network; ``net(x)[0][0]`` is taken as the prediction batch
            and only its first item is used.
        inputs_val: Input batch tensor.
        shapes_val: Indexable whose ``[0][0]`` and ``[0][1]`` give the height
            and width the prediction is resized to.
        hypar: Settings dict; ``hypar["model_digit"] == "full"`` selects
            float32 inputs, any other value float16.
        device: Device the input is moved to (string or ``torch.device``).

    Returns:
        A ``numpy.uint8`` array with the min-max normalized prediction scaled
        to 0..255 at the requested size (2-D for a single-channel
        prediction). A constant prediction yields an all-zero mask.
    """
    net.eval()

    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.float()
    else:
        inputs_val = inputs_val.half()
    inputs_val = inputs_val.to(device)

    # Plain ints, so tensors, arrays and tuples are all accepted as shapes.
    target_size = (int(shapes_val[0][0]), int(shapes_val[0][1]))

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        ds_val = net(inputs_val)[0]
        pred_val = ds_val[0][0, :, :, :]
        pred_val = torch.squeeze(
            F.interpolate(
                torch.unsqueeze(pred_val, 0),
                target_size,
                mode="bilinear",
                align_corners=False,
            )
        )

    ma = torch.max(pred_val)
    mi = torch.min(pred_val)
    span = ma - mi
    if span > 0:
        pred_val = (pred_val - mi) / span
    else:
        # Flat prediction: avoid 0/0 (NaN) and return an empty mask instead.
        pred_val = torch.zeros_like(pred_val)

    # Works for 'cuda', 'cuda:0' and torch.device objects alike.
    if torch.device(device).type == "cuda":
        torch.cuda.empty_cache()

    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
93
+
94
+ # --- Prepare model ---
95
# Settings shared by load_image, build_model and predict.
hypar = dict(
    model_path="./saved_models",
    restore_model="isnet.pth",
    interm_sup=False,
    model_digit="full",
    seed=0,
    cache_size=[1024, 1024],
    input_size=[1024, 1024],
    crop_size=[1024, 1024],
    model=ISNetDIS(),
)

# Network instance reused by every inference call.
net = build_model(hypar, device)
108
 
109
+ # --- Inference ---
110
def inference(image):
    """Cut the foreground out of an image file using the predicted mask.

    Args:
        image: Path of the input image (the Gradio input component is
            declared with ``type='filepath'``).

    Returns:
        ``[im_rgba, pil_mask]``: the input as a PIL image whose alpha channel
        is the predicted mask, and the mask itself as an ``'L'`` image.

    Raises:
        gr.Error: If no image was provided.
    """
    if not image:
        raise gr.Error("Please upload an image first.")

    image_tensor, orig_size = load_image(image, hypar)
    mask = predict(net, image_tensor, orig_size, hypar, device)
    pil_mask = Image.fromarray(mask).convert("L")

    # Close the file handle as soon as the pixels have been converted.
    with Image.open(image) as source:
        im_rgba = source.convert("RGB")

    # putalpha needs a mask of exactly the image's size.
    if pil_mask.size != im_rgba.size:
        pil_mask = pil_mask.resize(im_rgba.size)

    im_rgba.putalpha(pil_mask)
    return [im_rgba, pil_mask]
121
 
122
+ # --- Custom CSS to hide footer ---
123
# --- Custom CSS: hide the page footer and the share button container ---
css_hide_footer = """
footer {display: none !important;}
#share-btn-container {display: none !important;}
"""

# --- Gradio Interface ---
# NOTE(review): in the Gradio releases that accept `flagging_mode` and
# `show_api`, `css` is an argument of the Interface/Blocks constructor and
# launch() rejects it. Confirm against the Gradio version pinned for this app.
interface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type='filepath', height=300, width=300),
    outputs=[
        gr.Image(type='filepath', format="png"),
        # Second output (the raw mask) is returned but not displayed.
        gr.Image(type='filepath', format="png", visible=False),
    ],
    flagging_mode="never",
    cache_mode="lazy",
    css=css_hide_footer,
)

interface.launch(
    show_error=False,
    show_api=False,
    share=False,
)