import os
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms
from PIL import Image
import gradio as gr
# Suppress warnings
warnings.filterwarnings("ignore")
# Clone DIS repo if not exists
if not os.path.exists("DIS"):
os.system("git clone https://github.com/xuebinqin/DIS")
# Move model files
if not os.path.exists("models.py"):
    os.system("mv DIS/IS-Net/* .")
# Project imports
from data_loader_cache import normalize, im_reader, im_preprocess
from models import *
# Setup device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Prepare saved models folder
if not os.path.exists("saved_models"):
    os.mkdir("saved_models")
    # NOTE: make sure isnet.pth is available locally
    os.system("mv isnet.pth saved_models/")
# --- Helpers ---
class GOSNormalize(object):
    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        return normalize(image, self.mean, self.std)
# Override the ImageNet defaults above: this demo normalizes with mean 0.5 / std 1.0
transform = transforms.Compose([
    GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
])
def load_image(im_path, hypar):
    """Read an image, resize it to cache_size, and return (tensor, original shape)."""
    im = im_reader(im_path)
    im, im_shp = im_preprocess(im, hypar["cache_size"])
    im = torch.divide(im, 255.0)
    shape = torch.from_numpy(np.array(im_shp))
    return transform(im).unsqueeze(0), shape.unsqueeze(0)  # add a batch dimension
def build_model(hypar, device):
    """Instantiate the network, optionally cast it to half precision, and restore weights."""
    net = hypar["model"]

    # Convert to half precision but keep BatchNorm layers in float32 for stability
    if hypar["model_digit"] == "half":
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    net.to(device)

    if hypar["restore_model"] != "":
        net.load_state_dict(torch.load(
            os.path.join(hypar["model_path"], hypar["restore_model"]),
            map_location=device
        ))
    net.eval()
    return net
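# Note: weights are only restored when hypar["restore_model"] names a checkpoint
# inside hypar["model_path"]; otherwise the ISNetDIS instance keeps its randomly
# initialized parameters.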
def predict(net, inputs_val, shapes_val, hypar, device):
    """Run inference and return a uint8 mask resized to the original image shape."""
    net.eval()

    # Match the input dtype to the model precision
    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.type(torch.FloatTensor)
    else:
        inputs_val = inputs_val.type(torch.HalfTensor)

    inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
    ds_val = net(inputs_val_v)[0]     # list of side-output predictions
    pred_val = ds_val[0][0, :, :, :]  # finest side output, first image in the batch

    # Resize the prediction back to the original image size
    pred_val = torch.squeeze(F.interpolate(
        torch.unsqueeze(pred_val, 0),
        (shapes_val[0][0], shapes_val[0][1]),
        mode='bilinear'
    ))

    # Min-max normalize to [0, 1]
    ma = torch.max(pred_val)
    mi = torch.min(pred_val)
    pred_val = (pred_val - mi) / (ma - mi)

    if device == 'cuda':
        torch.cuda.empty_cache()
    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
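# The returned mask is a single-channel uint8 array in [0, 255] at the original
# image resolution; higher values indicate foreground.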
# --- Prepare model ---
hypar = {
    "model_path": "./saved_models",   # directory holding the checkpoint
    "restore_model": "isnet.pth",     # checkpoint file name
    "interm_sup": False,              # no intermediate supervision at inference
    "model_digit": "full",            # "full" = float32, "half" = float16
    "seed": 0,
    "cache_size": [1024, 1024],       # resolution the input image is resized to
    "input_size": [1024, 1024],
    "crop_size": [1024, 1024],
    "model": ISNetDIS()
}
net = build_model(hypar, device)
# --- Inference ---
def inference(image):
    """Segment the foreground and return an RGBA cut-out plus the raw mask."""
    image_path = image
    image_tensor, orig_size = load_image(image_path, hypar)
    mask = predict(net, image_tensor, orig_size, hypar, device)

    pil_mask = Image.fromarray(mask).convert('L')
    im_rgb = Image.open(image_path).convert("RGB")

    # Use the predicted mask as the alpha channel
    im_rgba = im_rgb.copy()
    im_rgba.putalpha(pil_mask)
    return [im_rgba, pil_mask]
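# Example usage outside Gradio (hypothetical local file "input.jpg"):
#   rgba, mask = inference("input.jpg")
#   rgba.save("cutout.png")  # foreground with transparent background
#   mask.save("mask.png")    # raw segmentation mask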
# --- Custom CSS to hide footer ---
css_hide_footer = """
footer {display: none !important;}
#component-12 {display: none !important;}
#huggingface-space-header {display: none !important;}
button[data-testid="ShareButton"] {display: none !important;}
"""
# --- Gradio Interface ---
interface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type='filepath', height=300, width=300),
    outputs=[
        gr.Image(type='filepath', format="png"),
        gr.Image(type='filepath', format="png", visible=False)
    ],
    flagging_mode="never",
    cache_mode="lazy",
    css=css_hide_footer  # CSS goes here on the Interface, not in launch()
)
interface.launch(
    show_error=False,
    show_api=False,
    share=False
)