import cv2
import numpy as np
import tensorflow as tf
from PIL import Image

if tf.__version__ >= '2.0':
    tf = tf.compat.v1

class ImageUniversalMatting:

    def __init__(self, weight_path):
        super().__init__()
        # Run the frozen graph on CPU; allow_soft_placement lets TF fall back
        # if an op has no kernel on the requested device.
        config = tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 0})
        config.gpu_options.allow_growth = True
        self._session = tf.Session(config=config)
        with self._session.as_default():
            print(f'loading model from {weight_path}')
            with tf.gfile.FastGFile(weight_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                tf.import_graph_def(graph_def, name='')
                self.output = self._session.graph.get_tensor_by_name(
                    'output_png:0')
                self.input_name = 'input_image:0'
            print('load model done')
        self._session.graph.finalize()

    def __call__(self, image):
        output = self.preprocess(image)
        output = self.forward(output)
        output = self.postprocess(output)
        return output
    def resize_image(self, img, limit_side_len):
        """
        Resize the image so that both sides are multiples of 32, as required
        by the network, while limiting the longer side to `limit_side_len`.
        args:
            img (np.ndarray): array with shape [h, w, c]
            limit_side_len (int): maximum allowed length of the longer side
        return:
            np.ndarray: the resized image
        """
        h, w, _ = img.shape
        # limit the max side
        if max(h, w) > limit_side_len:
            if h > w:
                ratio = float(limit_side_len) / h
            else:
                ratio = float(limit_side_len) / w
        else:
            ratio = 1.
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)
        # snap both sides to the nearest multiple of 32
        resize_h = int(round(resize_h / 32) * 32)
        resize_w = int(round(resize_w / 32) * 32)
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img
    def convert_to_ndarray(self, img):
        if isinstance(img, Image.Image):
            img = np.array(img.convert('RGB'))
        elif isinstance(img, np.ndarray):
            if len(img.shape) == 2:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            img = img[:, :, ::-1]  # convert BGR to RGB
        else:
            raise TypeError(f'input should be either PIL.Image or'
                            f' np.ndarray, but got {type(img)}')
        return img
    def preprocess(self, input, limit_side_len=800):
        img = self.convert_to_ndarray(input)  # rgb input
        img = img.astype(float)
        orig_h, orig_w, _ = img.shape
        img = self.resize_image(img, limit_side_len)
        result = {'img': img, 'orig_h': orig_h, 'orig_w': orig_w}
        return result

    def forward(self, input):
        orig_h, orig_w = input['orig_h'], input['orig_w']
        with self._session.as_default():
            feed_dict = {self.input_name: input['img']}
            output_img = self._session.run(self.output, feed_dict=feed_dict)  # RGBA
            # output_img = cv2.cvtColor(output_img, cv2.COLOR_RGBA2BGRA)
            output_img = cv2.resize(output_img, (int(orig_w), int(orig_h)))
        return {"output_img": output_img}

    def postprocess(self, inputs):
        return inputs["output_img"]
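
A minimal usage sketch follows. It assumes a frozen TF graph file exposing the input_image:0 and output_png:0 tensors referenced above; the file names (matting_model.pb, test_input.jpg, test_output.png) are placeholders, not paths shipped with this code.

if __name__ == '__main__':
    # Hypothetical weight path; the frozen graph must provide the tensor
    # names expected by ImageUniversalMatting.
    matting = ImageUniversalMatting('matting_model.pb')
    image = Image.open('test_input.jpg')
    rgba = matting(image)  # RGBA array resized back to the original resolution
    Image.fromarray(rgba.astype(np.uint8)).save('test_output.png')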