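"""Canonicalization inference: for each input PNG, segment the subject from the
background (if the image lacks an alpha channel), recenter it on a padded
canvas, and run the canonicalization pipeline to produce a canonical-pose image."""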

import argparse
import glob
import inspect
import io
import json
import os
import random
from typing import Dict, Optional, Tuple

import cv2
import numpy as np
import onnxruntime as rt
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils import check_min_version
from einops import rearrange
from huggingface_hub.file_download import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image
from rm_anime_bg.cli import SCALE, get_mask
from torchvision import transforms
from torchvision.utils import save_image
from tqdm.auto import tqdm
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from canonicalize.models.unet_mv2d_condition import UNetMV2DConditionModel
from canonicalize.models.unet_mv2d_ref import UNetMV2DRefModel
from canonicalize.pipeline_canonicalize import CanonicalizationPipeline

check_min_version("0.24.0")

weight_dtype = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class BkgRemover:
    """Anime-oriented background remover backed by the skytnt/anime-seg ONNX model."""

    def __init__(self, force_cpu: Optional[bool] = True):
        session_infer_path = hf_hub_download(
            repo_id="skytnt/anime-seg", filename="isnetis.onnx",
        )
        providers: list[str] = ["CPUExecutionProvider"]
        if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers():
            providers = ["CUDAExecutionProvider"]
        self.session_infer = rt.InferenceSession(
            session_infer_path, providers=providers,
        )

    def remove_background(
        self,
        img: np.ndarray,
        alpha_min: float,
        alpha_max: float,
    ) -> Image.Image:
        img = np.array(img)
        mask = get_mask(self.session_infer, img)
        # Snap soft mask values outside [alpha_min, alpha_max] to hard 0/1.
        mask[mask < alpha_min] = 0.0
        mask[mask > alpha_max] = 1.0
        img_after = (mask * img).astype(np.uint8)
        mask = (mask * SCALE).astype(np.uint8)
        # Append the mask as an alpha channel, yielding an RGBA image.
        img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)
        return Image.fromarray(img_after)
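

# Usage sketch (the file name is hypothetical): alpha below 0.1 is dropped,
# alpha above 0.9 is kept, and values in between stay soft:
#   remover = BkgRemover()
#   rgba = remover.remove_background(Image.open("girl.png").convert("RGB"), 0.1, 0.9)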


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def process_image(image, totensor, width, height):
    assert image.mode == "RGBA"

    # Crop to the bounding box of non-transparent pixels.
    non_transparent = np.nonzero(np.array(image)[..., 3])
    min_x, max_x = non_transparent[1].min(), non_transparent[1].max()
    min_y, max_y = non_transparent[0].min(), non_transparent[0].max()
    image = image.crop((min_x, min_y, max_x, max_y))

    # Paste onto the center of a padded canvas with the target aspect ratio.
    max_dim = max(image.width, image.height)
    max_height = int(max_dim * 1.2)
    max_width = int(max_dim / (height / width) * 1.2)
    new_image = Image.new("RGBA", (max_width, max_height))
    left = (max_width - image.width) // 2
    top = (max_height - image.height) // 2
    new_image.paste(image, (left, top))
    image = new_image.resize((width, height), resample=Image.BICUBIC)

    # Composite onto a white background using the alpha channel.
    image = np.array(image).astype(np.float32) / 255.0
    assert image.shape[-1] == 4  # RGBA
    alpha = image[..., 3:4]
    bg_color = np.array([1.0, 1.0, 1.0], dtype=np.float32)
    image = image[..., :3] * alpha + bg_color * (1 - alpha)
    return totensor(image)
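

# Shape contract (hypothetical call): for a 640x1024 target,
#   process_image(rgba_img, transforms.ToTensor(), 640, 1024)
# returns a float tensor of shape (3, 1024, 640) in [0, 1], composited on white.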


def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extractor, image_encoder,
              unet, ref_unet, tokenizer, text_encoder, pretrained_model_path, generator, validation,
              val_width, val_height, unet_condition_type,
              use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
    set_seed(seed)
    totensor = transforms.ToTensor()

    prompts = "high quality, best quality"
    prompt_ids = tokenizer(
        prompts, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
        return_tensors="pt"
    ).input_ids[0]

    # (B*Nv, 3, H, W)
    B = 1
    if input_image.mode != "RGBA":
        # No alpha channel yet: segment the subject to get an RGBA image.
        input_image = bkg_remover.remove_background(input_image, 0.1, 0.9)
    imgs_in = process_image(input_image, totensor, val_width, val_height)
    imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")

    with torch.autocast("cuda" if torch.cuda.is_available() else "cpu", dtype=weight_dtype):
        imgs_in = imgs_in.to(device=device)
        # B*Nv images
        out = validation_pipeline(prompt=prompts, image=imgs_in.to(weight_dtype), generator=generator,
                                  num_inference_steps=timestep, prompt_ids=prompt_ids,
                                  height=val_height, width=val_width, unet_condition_type=unet_condition_type,
                                  use_noise=use_noise, **validation)
        out = rearrange(out, "B C f H W -> (B f) C H W", f=1)

    # Round-trip through an in-memory PNG so the caller receives a plain PIL image.
    img_buf = io.BytesIO()
    save_image(out[0], img_buf, format="PNG")
    img_buf.seek(0)
    img = Image.open(img_buf)
    torch.cuda.empty_cache()
    return img
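

# Note: inference() seeds Python, NumPy, and Torch via set_seed() and also
# receives an explicit torch.Generator, so repeated calls with the same seed
# produce the same output.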


def main(
    input_dir: str,
    output_dir: str,
    pretrained_model_path: str,
    validation: Dict,
    local_crossattn: bool = True,
    unet_from_pretrained_kwargs=None,
    unet_condition_type=None,
    use_noise=True,
    noise_d=256,
    seed: int = 42,
    timestep: int = 40,
    width_input: int = 640,
    height_input: int = 1024,
):
    *_, config = inspect.getargvalues(inspect.currentframe())

    # Load the tokenizer, text/image encoders, VAE, and both UNets from the checkpoint.
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_path, subfolder="image_encoder")
    feature_extractor = CLIPImageProcessor()
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet",
                                                     local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
    ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="ref_unet",
                                                   local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)

    text_encoder.to(device, dtype=weight_dtype)
    image_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    ref_unet.to(device, dtype=weight_dtype)
    unet.to(device, dtype=weight_dtype)

    # Inference only: freeze all model weights.
    vae.requires_grad_(False)
    unet.requires_grad_(False)
    ref_unet.requires_grad_(False)

    # Set up the pipeline with a zero-SNR DDIM scheduler.
    noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler-zerosnr")
    validation_pipeline = CanonicalizationPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, ref_unet=ref_unet,
        feature_extractor=feature_extractor, image_encoder=image_encoder, scheduler=noise_scheduler,
    )
    validation_pipeline.set_progress_bar_config(disable=True)

    bkg_remover = BkgRemover()

    def canonicalize(image, width, height, seed, timestep):
        generator = torch.Generator(device=device).manual_seed(seed)
        return inference(
            validation_pipeline, bkg_remover, image, vae, feature_extractor, image_encoder, unet, ref_unet,
            tokenizer, text_encoder, pretrained_model_path, generator, validation, width, height,
            unet_condition_type, use_noise=use_noise, noise_d=noise_d, crop=True, seed=seed, timestep=timestep,
        )

    img_paths = sorted(glob.glob(os.path.join(input_dir, "*.png")))
    os.makedirs(output_dir, exist_ok=True)

    for path in tqdm(img_paths):
        img_input = Image.open(path)
        if np.array(img_input)[..., 3].min() == 255:
            # Alpha is fully opaque, so treat the image as RGB and re-segment it.
            img_input = img_input.convert("RGB")
        img_output = canonicalize(img_input, width_input, height_input, seed, timestep)
        img_output.save(os.path.join(output_dir, f"{os.path.basename(path).split('.')[0]}.png"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/canonicalization-infer.yaml")
    parser.add_argument("--input_dir", type=str, default="./input_cases")
    parser.add_argument("--output_dir", type=str, default="./result/apose")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    main(**OmegaConf.load(args.config), seed=args.seed, input_dir=args.input_dir, output_dir=args.output_dir)
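
# Example invocation (script filename is hypothetical; the paths are the
# argparse defaults above):
#   python infer_canonicalize.py --config ./configs/canonicalization-infer.yaml \
#       --input_dir ./input_cases --output_dir ./result/apose --seed 42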