Spaces:
Runtime error
Runtime error
File size: 3,042 Bytes
f2dbf59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from .imagefunc import *
import torch.nn as nn
from torchvision import transforms
from .BiRefNet_legacy.baseline import BiRefNet
from .BiRefNet_legacy.config import Config
class BiRefNet_img_processor:
    """Convert an image array into a normalized input tensor for BiRefNet.

    The array is resized to ``config.size`` x ``config.size`` with OpenCV,
    converted to an 8-bit RGB PIL image, then passed through
    ``transform_image`` (resize, ``ToTensor``, per-channel normalization).
    """

    # Per-channel mean/std used by the Normalize step (standard ImageNet values).
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, config):
        """Store *config* and build the torchvision transform pipeline.

        Args:
            config: object exposing an integer ``size`` (square model input side).
        """
        self.config = config
        self.data_size = (config.size, config.size)
        self.transform_image = transforms.Compose([
            # __call__ already resizes to data_size, so this step should be a
            # no-op there; it is kept for callers using transform_image directly.
            transforms.Resize(self.data_size),
            transforms.ToTensor(),
            transforms.Normalize(self.MEAN, self.STD),
        ])

    @staticmethod
    def _to_uint8(array):
        """Map float values in [0, 1] to uint8 in [0, 255].

        Values are rounded (truncation would bias every value down by up to one
        level) and clipped (out-of-range floats would otherwise wrap around).
        """
        return np.clip(np.rint(array * 255), 0, 255).astype(np.uint8)

    def __call__(self, _image: np.ndarray):
        """Pre-process *_image* into a normalized tensor.

        Args:
            _image: image array whose values are expected to lie in [0, 1]
                (they are scaled by 255 before conversion to 8-bit).

        Returns:
            The output of ``transform_image`` applied to the resized RGB image.
        """
        size = self.config.size
        resized = cv2.resize(_image, (size, size), interpolation=cv2.INTER_LINEAR)
        pil_image = Image.fromarray(self._to_uint8(resized)).convert('RGB')
        return self.transform_image(pil_image)
class BiRefNetRemoveBackground:
    """Background-removal mask generator backed by the legacy BiRefNet model.

    The network is loaded lazily: the first call to ``generate_mask`` locates
    the default checkpoint and calls ``load`` unless ``load`` was already
    called explicitly. ``ready`` records whether loading has happened.
    """

    # Subfolder and file name of the default checkpoint.
    MODEL_FOLDER_NAME = 'BiRefNet'
    MODEL_NAME = 'BiRefNet-ep480.pth'
    # Key prefix stripped from checkpoint parameter names (presumably added
    # when the model was saved from a torch.compile wrapper).
    UNWANTED_PREFIX = '_orig_mod.'

    def __init__(self):
        self.ready = False

    @staticmethod
    def _strip_prefix(state_dict, prefix):
        """Rename, in place, every key of *state_dict* starting with *prefix*.

        Returns the same dict object for convenience.
        """
        for key in list(state_dict):
            if key.startswith(prefix):
                state_dict[key[len(prefix):]] = state_dict.pop(key)
        return state_dict

    def load(self, weight_path, device):
        """Build BiRefNet, load weights from *weight_path* and move it to *device*.

        Also creates the image pre-processor and sets ``ready`` to True.
        """
        self.model = BiRefNet()
        # Deserialize on CPU first so the checkpoint can be read regardless of
        # the device it was saved from.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(weight_path, map_location='cpu')
        self._strip_prefix(state_dict, self.UNWANTED_PREFIX)
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(device)
        self.model.eval()
        self.processor = BiRefNet_img_processor(Config())
        self.ready = True

    @staticmethod
    def _select_device():
        """Return the preferred torch device name: "mps", then "cuda", else "cpu"."""
        if torch.backends.mps.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def _resolve_weight_path(self):
        """Return the path of the default checkpoint.

        Tries the first directory registered for ``MODEL_FOLDER_NAME`` in
        ``folder_paths.folder_names_and_paths``. If the folder is not
        registered, or the file does not exist there, it falls back to
        ``<folder_paths.models_dir>/<MODEL_FOLDER_NAME>/<MODEL_NAME>``.
        """
        model_file_path = ""
        try:
            registered_dir = folder_paths.folder_names_and_paths[self.MODEL_FOLDER_NAME][0][0]
            model_file_path = os.path.join(os.path.normpath(registered_dir), self.MODEL_NAME)
        except (KeyError, IndexError, TypeError, AttributeError):
            # Folder not registered (or registered in an unexpected shape):
            # deliberately fall through to the default models directory.
            pass
        if not os.path.exists(model_file_path):
            model_file_path = os.path.join(
                folder_paths.models_dir, self.MODEL_FOLDER_NAME, self.MODEL_NAME)
        return model_file_path

    def generate_mask(self, image: "Image.Image") -> "Image.Image":
        """Predict a foreground mask for *image*.

        Args:
            image: PIL image to segment.

        Returns:
            A PIL image built by ``tensor2pil`` from the sigmoid prediction,
            resized back to the input's height/width and brightened by a
            factor of 1.01.
        """
        device = self._select_device()
        if not self.ready:
            self.load(self._resolve_weight_path(), device=device)

        # NOTE(review): assumes pil2tensor returns a (1, H, W[, C]) tensor with
        # values in [0, 1]; squeeze() drops the batch dim. TODO confirm.
        np_image = pil2tensor(image).squeeze().numpy()
        inputs = self.processor(np_image)[None, ...].to(device)

        with torch.no_grad():
            # Only the last element of the model's output is used.
            scaled_preds = self.model(inputs)[-1].sigmoid()
            # Resize the prediction back to the input's spatial size.
            _mask = nn.functional.interpolate(
                scaled_preds[0].unsqueeze(0),
                size=np_image.shape[:2],
                mode='bilinear',
                align_corners=True,
            )[0]

        brightness_image = ImageEnhance.Brightness(tensor2pil(_mask))
        return brightness_image.enhance(factor=1.01)
|