Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| from typing import Dict, List, Union, Tuple | |
| from omegaconf import OmegaConf | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from PIL import Image, ImageDraw, ImageFont | |
| import models | |
# Key prefix that marks generator weights inside the checkpoint's
# 'state_dict'; it is stripped before the weights are loaded.
GENERATOR_PREFIX = "networks.g."

# Largest 8-bit channel value; pixel data is divided (and later multiplied)
# by it to move between the [0, 255] and [0, 1] ranges.
WHITE = 0xFF

# Characters rendered from the style font to serve as reference glyphs.
EXAMPLE_CHARACTERS = list("ABCDE")
class InferenceServicer:
    """Render a single character in the style of a given font file.

    The servicer loads a ``models.Generator`` from a checkpoint, indexes a
    directory of content glyph images (PNG files whose stem is the
    character's 4-digit hex code point), and exposes :meth:`inference` to
    generate a glyph from one content character plus a style font.
    """

    def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
        """Build the generator, load its weights and index the content glyphs.

        Args:
            hp: Configuration object; ``hp.models.G`` is passed to
                ``models.Generator``.
            checkpoint_path: Checkpoint file readable by ``torch.load`` that
                holds a ``'state_dict'`` mapping.
            content_image_dir: Directory searched recursively for ``*.png``
                content glyph images.
            imsize: Side length (pixels) of the square images fed to the model.
            gpu_id: CUDA device index. ``None`` selects ``cuda:0`` when CUDA
                is available. Whenever CUDA is unavailable the model runs on
                the CPU instead of failing.
        """
        self.hp = hp
        self.imsize = imsize

        if torch.cuda.is_available():
            self.device = torch.device('cuda:0' if gpu_id is None else f'cuda:{gpu_id}')
        else:
            if gpu_id is not None:
                # An explicit GPU request cannot be honoured; say so rather
                # than crashing later in ``model.to`` or switching silently.
                import warnings

                warnings.warn(
                    f"CUDA is not available; running on CPU instead of cuda:{gpu_id}.",
                    RuntimeWarning,
                )
            self.device = torch.device('cpu')

        self.model: nn.Module = models.Generator(self.hp.models.G)

        # Load Generator model weight.
        # NOTE(security): torch.load unpickles the file, so only trusted
        # checkpoints should be loaded here.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        self.model.load_state_dict(self.convert_generator_state_dict(checkpoint))
        self.model.to(device=self.device)
        self.model.eval()

        # Setting Content font files
        self.content_character_dict = self.load_content_character_dict(Path(content_image_dir))

    @staticmethod
    def convert_generator_state_dict(model_state_dict_pl):
        """Extract the generator weights from a full checkpoint.

        Args:
            model_state_dict_pl: Loaded checkpoint with a ``'state_dict'``
                mapping (presumably saved by PyTorch Lightning -- confirm).

        Returns:
            A dict holding only the entries whose key starts with
            ``GENERATOR_PREFIX``, with that prefix removed from the key.
        """
        prefix_length = len(GENERATOR_PREFIX)
        return {
            name[prefix_length:]: weights
            for name, weights in model_state_dict_pl['state_dict'].items()
            if name.startswith(GENERATOR_PREFIX)
        }

    @staticmethod
    def load_content_character_dict(content_image_dir: Path) -> Dict[str, Path]:
        """Map each PNG file stem under ``content_image_dir`` to its path.

        The directory is searched recursively. When several files share a
        stem, the one yielded last by the glob wins.
        """
        return {png_path.stem: png_path for png_path in content_image_dir.glob("**/*.png")}

    @staticmethod
    def center_align(bg_img: "Image.Image", item_img: "Image.Image", fit=False) -> "Image.Image":
        """Paste ``item_img`` centred onto a copy of ``bg_img``.

        Args:
            bg_img: Background image; it is copied, never modified.
            item_img: Image to paste; it is not modified.
            fit: When true, ``item_img`` is first scaled (aspect ratio kept)
                so that it touches the background along its tighter
                dimension.

        Returns:
            The new composited image.
        """
        canvas = bg_img.copy()
        canvas_w, canvas_h = canvas.size
        item_w, item_h = item_img.size

        if fit:
            if canvas_w / canvas_h > item_w / item_h:
                # Background is relatively wider: height is the limiting side.
                scale = canvas_h / item_h
            else:
                # Otherwise width is the limiting side.
                scale = canvas_w / item_w
            # Clamp to 1 px: truncation can give 0 for extreme aspect ratios,
            # which ``resize`` rejects.
            scaled_size = (max(1, int(item_w * scale)), max(1, int(item_h * scale)))
            item_img = item_img.resize(scaled_size)
            item_w, item_h = item_img.size

        canvas.paste(item_img, ((canvas_w - item_w) // 2, (canvas_h - item_h) // 2))
        return canvas

    def set_image(self, image: Union[str, Path, "Image.Image"]) -> "Image.Image":
        """Fit an image onto a white ``imsize`` x ``imsize`` RGB canvas.

        Args:
            image: A PIL image, or a path (``str`` / ``Path``) to open.

        Returns:
            A new RGB image with the input scaled to fit and centred.

        Raises:
            AssertionError: If ``image`` is neither a path nor a PIL image.
        """
        if isinstance(image, (str, Path)):
            # Copy the pixels out so the file handle is closed right away.
            with Image.open(image) as opened:
                image = opened.copy()
        assert isinstance(image, Image.Image)

        canvas = Image.new('RGB', (self.imsize, self.imsize), color='white')
        return self.center_align(canvas, image, fit=True)

    @staticmethod
    def pil_image_to_array(blend_img: "Image.Image") -> np.ndarray:
        """Convert an image to a float32 array scaled by ``1 / WHITE``.

        The pixel array is averaged over its last axis (the channel axis for
        the RGB images produced by :meth:`set_image`), giving a single-channel
        array normalised to [0, 1].
        """
        pixels = np.array(blend_img, dtype=np.float32)
        return np.mean(pixels, axis=-1) / WHITE

    def get_images_from_fontfile(self, font_file_path: Path, imgmode: str = 'RGB', position: tuple = (0, 0), font_size: int = 128, padding: int = 100) -> List["Image.Image"]:
        """Render every character of ``EXAMPLE_CHARACTERS`` with a font file.

        Args:
            font_file_path: Path to a font file loadable by
                ``ImageFont.truetype``.
            imgmode: PIL mode of the created images.
            position: Where the text is drawn on each canvas.
            font_size: Font size passed to ``ImageFont.truetype``.
            padding: Extra pixels added to the measured glyph width and
                height when sizing the canvas.

        Returns:
            One white-background image with black text per example character.
        """
        imagefont = ImageFont.truetype(str(font_file_path), size=font_size)
        font_images: List["Image.Image"] = []
        for character in EXAMPLE_CHARACTERS:
            # ``ImageDraw.textsize`` no longer exists in Pillow >= 10, so the
            # glyph is measured with ``getbbox``: ink width, and height
            # measured from the drawing origin (intended to match the extent
            # ``textsize`` used to report).
            left, _, right, bottom = imagefont.getbbox(character)
            canvas_size = (int(right - left) + padding, int(bottom) + padding)
            img = Image.new(imgmode, canvas_size, color='white')
            ImageDraw.Draw(img).text(position, text=character, font=imagefont, fill='black')
            font_images.append(img)
        return font_images

    @staticmethod
    def get_hex_from_char(char: str) -> str:
        """Return the code point of ``char`` as upper-case hex, at least 4 digits.

        Raises:
            AssertionError: If ``char`` is not exactly one character long.
        """
        assert len(char) == 1
        return f"{ord(char):04X}"  # the 'X' format already yields upper case

    def inference(self, content_char: str, style_font: Union[str, Path]) -> Tuple["Image.Image", List["Image.Image"], "Image.Image"]:
        """Generate ``content_char`` in the style of ``style_font``.

        Only the first character of ``content_char`` is used.

        Args:
            content_char: Non-empty string holding the character to generate.
            style_font: Path to the font file providing the style glyphs.

        Returns:
            A tuple ``(content_image, style_images, generated_image)``: the
            prepared content glyph, the prepared style glyphs, and the model
            output as an 8-bit single-channel image.

        Raises:
            AssertionError: If ``content_char`` is empty.
            ValueError: If no content image exists for the character.
        """
        assert len(content_char) > 0
        content_char = content_char[:1]  # only get the first character if the length > 1
        char_hex = self.get_hex_from_char(content_char)
        if char_hex not in self.content_character_dict:
            raise ValueError(f"The character {content_char} (hex: {char_hex}) is not supported in this model!")

        content_image = self.set_image(self.content_character_dict[char_hex])
        style_images: List["Image.Image"] = [
            self.set_image(glyph) for glyph in self.get_images_from_fontfile(Path(style_font))
        ]

        # 1 x 1 x H x W
        content_array = self.pil_image_to_array(content_image)[np.newaxis, np.newaxis, ...]
        # 1 x (number of style glyphs) x H x W
        style_array: np.ndarray = np.array(
            [self.pil_image_to_array(glyph) for glyph in style_images]
        )[np.newaxis, ...]

        content_tensor = torch.from_numpy(content_array).to(self.device)
        style_tensor = torch.from_numpy(style_array).to(self.device)

        # No gradients are needed for inference; without this the output may
        # require grad, which makes the ``.numpy()`` call below fail.
        with torch.no_grad():
            generated: torch.Tensor = self.model((content_tensor, style_tensor))
        generated = torch.clip(generated, 0, 1)
        assert generated.size(0) == 1

        generated_pixels = (generated[0].cpu().numpy() * WHITE).astype(np.uint8)[0, ...]  # H x W
        # A 2-D uint8 array is interpreted as a mode 'L' (8-bit grayscale) image.
        return content_image, style_images, Image.fromarray(generated_pixels)
if __name__ == '__main__':
    # Demo run: build a servicer from the bundled config and checkpoint,
    # generate the character "7" in the style of an example font, and save
    # the content glyph, each style glyph and the generated glyph as PNGs.
    servicer = InferenceServicer(
        OmegaConf.load("config/models/google-font.yaml"),
        "epoch=199-step=257400.ckpt",
        "../DATA/NotoSans",
    )

    content_image, style_images, result = servicer.inference(
        "7",
        "example_fonts/MaShanZheng-Regular.ttf",
    )

    content_image.save("result_content.png")
    for idx, style_image in enumerate(style_images):
        style_image.save(f"result_style_{idx:02d}.png")
    result.save("result_generated.png")