Spaces:
Configuration error
Configuration error
import os | |
from .inference import init_segmentor, inference_segmentor, show_result_pyplot | |
import warnings | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME | |
import torch | |
from custom_mmpkg.custom_mmseg.core.evaluation import get_palette | |
config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "upernet_global_small.py") | |
class UniformerSegmentor: | |
def __init__(self, netNetwork): | |
self.model = netNetwork | |
def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="upernet_global_small.pth"): | |
model_path = custom_hf_download(pretrained_model_or_path, filename) | |
netNetwork = init_segmentor(config_file, model_path, device="cpu") | |
netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path)['state_dict'].items()}) | |
netNetwork.eval() | |
return cls(netNetwork) | |
def to(self, device): | |
self.model.to(device) | |
return self | |
def _inference(self, img): | |
if next(self.model.parameters()).device.type == 'mps': | |
# adaptive_avg_pool2d can fail on MPS, workaround with CPU | |
import torch.nn.functional | |
orig_adaptive_avg_pool2d = torch.nn.functional.adaptive_avg_pool2d | |
def cpu_if_exception(input, *args, **kwargs): | |
try: | |
return orig_adaptive_avg_pool2d(input, *args, **kwargs) | |
except: | |
return orig_adaptive_avg_pool2d(input.cpu(), *args, **kwargs).to(input.device) | |
try: | |
torch.nn.functional.adaptive_avg_pool2d = cpu_if_exception | |
result = inference_segmentor(self.model, img) | |
finally: | |
torch.nn.functional.adaptive_avg_pool2d = orig_adaptive_avg_pool2d | |
else: | |
result = inference_segmentor(self.model, img) | |
res_img = show_result_pyplot(self.model, img, result, get_palette('ade'), opacity=1) | |
return res_img | |
def __call__(self, input_image=None, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): | |
input_image, output_type = common_input_validate(input_image, output_type, **kwargs) | |
input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) | |
detected_map = self._inference(input_image) | |
detected_map = remove_pad(HWC3(detected_map)) | |
if output_type == "pil": | |
detected_map = Image.fromarray(detected_map) | |
return detected_map | |