File size: 4,358 Bytes
6c6eb37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
from PIL import Image
from iopaint.helper import encode_pil_to_base64, gen_frontend_mask
from iopaint.plugins.anime_seg import AnimeSeg
from iopaint.schema import Device, RunPluginRequest, RemoveBGModel, InteractiveSegModel
from iopaint.tests.utils import check_device, current_dir, save_dir
# Export the PyTorch MPS fallback flag for this test process before the
# remaining imports below (cv2, pytest, iopaint.plugins) are executed.
# Per the PyTorch MPS notes, "1" presumably lets operators that are not
# implemented for MPS fall back to the CPU — confirm for the installed version.
# NOTE(review): the iopaint imports above this line run before the flag is
# set; if they already import torch, verify the flag is still honored.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import cv2
import pytest
from iopaint.plugins import (
RemoveBG,
RealESRGANUpscaler,
GFPGANPlugin,
RestoreFormerPlugin,
InteractiveSeg,
)
# --- Shared fixtures, loaded once at import time ---------------------------

# Primary test image: raw bytes, BGR array (as read by OpenCV), RGB array,
# and base64 encodings of both channel orders.
img_p = current_dir / "bunny.jpeg"
# Read through a context manager so the file handle is closed deterministically
# instead of being left to the garbage collector.
with open(img_p, "rb") as _img_file:
    img_bytes = _img_file.read()
bgr_img = cv2.imread(str(img_p))
rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
rgb_img_base64 = encode_pil_to_base64(Image.fromarray(rgb_img), 100, {})
bgr_img_base64 = encode_pil_to_base64(Image.fromarray(bgr_img), 100, {})

# Second test image, converted to RGB and resized to 512x512.
person_p = current_dir / "image.png"
person_bgr_img = cv2.imread(str(person_p))
person_rgb_img = cv2.cvtColor(person_bgr_img, cv2.COLOR_BGR2RGB)
person_rgb_img = cv2.resize(person_rgb_img, (512, 512))
def _save(img, name):
    """Write ``img`` into ``save_dir`` under ``name``.

    Every "/" in ``name`` is replaced with "_" before the path is built, so
    the file always lands directly inside ``save_dir``.
    """
    flat_name = name.replace("/", "_")
    target = save_dir / flat_name
    cv2.imwrite(str(target), img)
@pytest.mark.parametrize("model_name", RemoveBGModel.values())
@pytest.mark.parametrize("device", Device.values())
def test_remove_bg(model_name, device):
    """Run one RemoveBG model on one device and save its image and mask outputs."""
    check_device(device)
    print(f"Testing {model_name} on {device}")
    remover = RemoveBG(model_name, device)

    # Image output: converted from RGBA to BGRA before it is written with OpenCV.
    image_request = RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
    cutout_rgba = remover.gen_image(rgb_img, image_request)
    cutout_bgra = cv2.cvtColor(cutout_rgba, cv2.COLOR_RGBA2BGRA)
    _save(cutout_bgra, f"test_remove_bg_{model_name}_{device}.png")

    # Mask output: a fresh request object is built for the second plugin call.
    mask_request = RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
    mask = remover.gen_mask(rgb_img, mask_request)
    frontend_mask = gen_frontend_mask(mask)
    _save(frontend_mask, f"test_remove_bg_frontend_mask_{model_name}_{device}.png")

    # The raw mask is expected to be two-dimensional.
    assert len(mask.shape) == 2
    _save(mask, f"test_remove_bg_mask_{model_name}_{device}.jpeg")
def test_anime_seg():
    """Run AnimeSeg on the anime fixture and check the image and mask shapes.

    ``cv2.imread`` yields BGR data. ``test_remove_bg`` in this file feeds an
    RGB array to the plugin and converts the RGBA result to BGRA before writing
    it with OpenCV; this test follows the same convention.
    NOTE(review): assumes AnimeSeg shares that RGB-in / RGBA-out convention —
    confirm against the plugin implementation.
    """
    model = AnimeSeg()
    img_path = current_dir / "anime_test.png"
    bgr = cv2.imread(str(img_path))
    # cv2.imread reports an unreadable file by returning None instead of raising.
    assert bgr is not None, f"failed to read test image: {img_path}"
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {})

    res = model.gen_image(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
    assert len(res.shape) == 3
    assert res.shape[-1] == 4
    # OpenCV writes channels in BGRA order, so swap before saving.
    _save(cv2.cvtColor(res, cv2.COLOR_RGBA2BGRA), "test_anime_seg.png")

    res = model.gen_mask(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
    assert len(res.shape) == 2
    _save(res, "test_anime_seg_mask.png")
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_upscale(device):
    """Upscale the shared RGB image at scale 2 and then 4, saving each result."""
    check_device(device)
    upscaler = RealESRGANUpscaler("realesr-general-x4v3", device)

    # (scale, output file name) pairs, processed in this order.
    runs = (
        (2, f"test_upscale_x2_{device}.png"),
        (4, f"test_upscale_x4_{device}.png"),
    )
    for scale, file_name in runs:
        request = RunPluginRequest(
            name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=scale
        )
        upscaled = upscaler.gen_image(rgb_img, request)
        _save(upscaled, file_name)
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_gfpgan(device):
    """Run GFPGANPlugin on ``person_rgb_img`` for one device and save the result."""
    check_device(device)
    model = GFPGANPlugin(device)
    # Encode the same image that is passed as the array argument. The
    # module-level rgb_img_base64 encodes the bunny image, so using it here
    # made the request payload describe a different picture than the array.
    person_base64 = encode_pil_to_base64(Image.fromarray(person_rgb_img), 100, {})
    res = model.gen_image(
        person_rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=person_base64)
    )
    _save(res, f"test_gfpgan_{device}.png")
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_restoreformer(device):
    """Run RestoreFormerPlugin on ``person_rgb_img`` for one device and save the result."""
    check_device(device)
    model = RestoreFormerPlugin(device)
    # Encode the same image that is passed as the array argument. The
    # module-level rgb_img_base64 encodes the bunny image, so using it here
    # made the request payload describe a different picture than the array.
    person_base64 = encode_pil_to_base64(Image.fromarray(person_rgb_img), 100, {})
    res = model.gen_image(
        person_rgb_img,
        RunPluginRequest(name=RestoreFormerPlugin.name, image=person_base64),
    )
    _save(res, f"test_restoreformer_{device}.png")
@pytest.mark.parametrize("name", InteractiveSegModel.values())
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_segment_anything(name, device):
    """Generate a mask from a single click with InteractiveSeg and save it."""
    check_device(device)
    segmenter = InteractiveSeg(name, device)

    # Single click entry; presumably [x, y, label] — TODO confirm against
    # the RunPluginRequest schema.
    click = [448 // 2, 394 // 2, 1]
    request = RunPluginRequest(
        name=InteractiveSeg.name,
        image=rgb_img_base64,
        clicks=[click],
    )
    mask = segmenter.gen_mask(rgb_img, request)

    out_name = f"test_segment_anything_{name}_{device}.png"
    _save(mask, out_name)
|