sourav520 commited on
Commit
644d22d
·
verified ·
1 Parent(s): 7d3e050

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -50
app.py CHANGED
@@ -47,103 +47,69 @@ def load_image(im_path, hypar):
47
  im, im_shp = im_preprocess(im, hypar["cache_size"])
48
  im = torch.divide(im,255.0)
49
  shape = torch.from_numpy(np.array(im_shp))
50
- return transform(im).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape
51
 
52
 
53
  def build_model(hypar,device):
54
- net = hypar["model"]#GOSNETINC(3,1)
55
-
56
- # convert to half precision
57
  if(hypar["model_digit"]=="half"):
58
  net.half()
59
  for layer in net.modules():
60
  if isinstance(layer, nn.BatchNorm2d):
61
  layer.float()
62
-
63
  net.to(device)
64
-
65
  if(hypar["restore_model"]!=""):
66
  net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"], map_location=device))
67
  net.to(device)
68
  net.eval()
69
  return net
70
 
71
-
72
  def predict(net, inputs_val, shapes_val, hypar, device):
73
- '''
74
- Given an Image, predict the mask
75
- '''
76
  net.eval()
77
-
78
  if(hypar["model_digit"]=="full"):
79
  inputs_val = inputs_val.type(torch.FloatTensor)
80
  else:
81
  inputs_val = inputs_val.type(torch.HalfTensor)
82
 
83
-
84
- inputs_val_v = Variable(inputs_val, requires_grad=False).to(device) # wrap inputs in Variable
85
-
86
- ds_val = net(inputs_val_v)[0] # list of 6 results
87
-
88
- pred_val = ds_val[0][0,:,:,:] # B x 1 x H x W # we want the first one which is the most accurate prediction
89
-
90
- ## recover the prediction spatial size to the orignal image size
91
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))
92
-
93
  ma = torch.max(pred_val)
94
  mi = torch.min(pred_val)
95
- pred_val = (pred_val-mi)/(ma-mi) # max = 1
96
-
97
  if device == 'cuda': torch.cuda.empty_cache()
98
- return (pred_val.detach().cpu().numpy()*255).astype(np.uint8) # it is the mask we need
99
 
100
- # Set Parameters
101
- hypar = {} # paramters for inferencing
102
-
103
-
104
- hypar["model_path"] ="./saved_models" ## load trained weights from this path
105
- hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
106
- hypar["interm_sup"] = False ## indicate if activate intermediate feature supervision
107
-
108
- ## choose floating point accuracy --
109
- hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
110
  hypar["seed"] = 0
111
 
112
- hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
113
-
114
- ## data augmentation parameters ---
115
- hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
116
- hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
117
-
118
  hypar["model"] = ISNetDIS()
119
-
120
- # Build Model
121
  net = build_model(hypar, device)
122
 
123
 
124
  def inference(image):
125
  image_path = image
126
-
127
  image_tensor, orig_size = load_image(image_path, hypar)
128
  mask = predict(net, image_tensor, orig_size, hypar, device)
129
-
130
  pil_mask = Image.fromarray(mask).convert('L')
131
  im_rgb = Image.open(image).convert("RGB")
132
-
133
  im_rgba = im_rgb.copy()
134
  im_rgba.putalpha(pil_mask)
135
-
136
  return [im_rgba, pil_mask]
137
 
138
 
139
- title = ""
140
- description = ""
141
- article = ""
142
-
143
  interface = gr.Interface(
144
  fn=inference,
145
  inputs=gr.Image(type='filepath'),
146
  outputs=[gr.Image(type='filepath', format="png"), gr.Image(type='filepath', format="png", visible=False)],
147
  flagging_mode="never",
148
  cache_mode="lazy",
149
- ).queue(api_open=True).launch(show_error=True, show_api=True)
 
47
  im, im_shp = im_preprocess(im, hypar["cache_size"])
48
  im = torch.divide(im,255.0)
49
  shape = torch.from_numpy(np.array(im_shp))
50
+ return transform(im).unsqueeze(0), shape.unsqueeze(0)
51
 
52
 
53
  def build_model(hypar,device):
54
+ net = hypar["model"]
 
 
55
  if(hypar["model_digit"]=="half"):
56
  net.half()
57
  for layer in net.modules():
58
  if isinstance(layer, nn.BatchNorm2d):
59
  layer.float()
 
60
  net.to(device)
 
61
  if(hypar["restore_model"]!=""):
62
  net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"], map_location=device))
63
  net.to(device)
64
  net.eval()
65
  return net
66
 
 
67
  def predict(net, inputs_val, shapes_val, hypar, device):
 
 
 
68
  net.eval()
 
69
  if(hypar["model_digit"]=="full"):
70
  inputs_val = inputs_val.type(torch.FloatTensor)
71
  else:
72
  inputs_val = inputs_val.type(torch.HalfTensor)
73
 
74
+ inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
75
+ ds_val = net(inputs_val_v)[0]
76
+ pred_val = ds_val[0][0,:,:,:]
 
 
 
 
 
77
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))
 
78
  ma = torch.max(pred_val)
79
  mi = torch.min(pred_val)
80
+ pred_val = (pred_val-mi)/(ma-mi)
 
81
  if device == 'cuda': torch.cuda.empty_cache()
82
+ return (pred_val.detach().cpu().numpy()*255).astype(np.uint8)
83
 
84
+ hypar = {}
85
+ hypar["model_path"] ="./saved_models"
86
+ hypar["restore_model"] = "isnet.pth"
87
+ hypar["interm_sup"] = False
88
+ hypar["model_digit"] = "full"
 
 
 
 
 
89
  hypar["seed"] = 0
90
 
91
+ hypar["cache_size"] = [1024, 1024]
92
+ hypar["input_size"] = [1024, 1024]
93
+ hypar["crop_size"] = [1024, 1024]
 
 
 
94
  hypar["model"] = ISNetDIS()
 
 
95
  net = build_model(hypar, device)
96
 
97
 
98
  def inference(image):
99
  image_path = image
 
100
  image_tensor, orig_size = load_image(image_path, hypar)
101
  mask = predict(net, image_tensor, orig_size, hypar, device)
 
102
  pil_mask = Image.fromarray(mask).convert('L')
103
  im_rgb = Image.open(image).convert("RGB")
 
104
  im_rgba = im_rgb.copy()
105
  im_rgba.putalpha(pil_mask)
 
106
  return [im_rgba, pil_mask]
107
 
108
 
 
 
 
 
109
  interface = gr.Interface(
110
  fn=inference,
111
  inputs=gr.Image(type='filepath'),
112
  outputs=[gr.Image(type='filepath', format="png"), gr.Image(type='filepath', format="png", visible=False)],
113
  flagging_mode="never",
114
  cache_mode="lazy",
115
+ ).queue(api_open=True).launch(show_error=True, show_api=True, share=False)