Spaces:
Runtime error
Runtime error
import argparse
import os
import subprocess
import zipfile
from pathlib import Path
from typing import List, Optional, Tuple, Union

import gradio as gr
from omegaconf import DictConfig, OmegaConf
from PIL import Image

from inference import InferenceServicer
# Runtime configuration; each value can be overridden through an environment variable.
PATH_DOCS = os.environ.get("PATH_DOCS", "docs/ml-font-style-transfer.md")  # default for --docs
MODEL_CONFIG = os.environ.get("MODEL_CONFIG", "config/models/google-font.yaml")  # default for --config
MODEL_CHECKPOINT_PATH = os.environ.get("MODEL_CHECKPOINT_PATH")  # optional URL; downloaded with wget when set
NOTO_SANS_ZIP_PATH = os.environ.get("NOTO_SANS_ZIP_PATH")  # optional URL; downloaded and extracted when set

# Fixed local destinations for the downloaded assets.
LOCAL_CHECKPOINT_PATH = "checkpoint/checkpoint.ckpt"
LOCAL_NOTO_ZIP_PATH = "data/NotoSans.zip"
def _fetch(url: str, destination: Union[str, Path]) -> None:
    """Download *url* to *destination* using ``wget``.

    The destination's parent directory is created first because ``wget -O``
    does not create missing directories.

    Raises:
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero status
            (otherwise a failed download would leave an empty file behind that
            passes the existence check below).
    """
    destination = Path(destination)
    destination.parent.mkdir(parents=True, exist_ok=True)
    # Arguments are passed as a list (no shell), so a URL coming from the
    # environment is never interpreted by a shell.
    # NOTE(review): --no-check-certificate disables TLS verification for a file
    # that is later loaded as the model checkpoint; kept for compatibility with
    # the original deployment, consider removing it.
    subprocess.run(
        ["wget", "--no-check-certificate", "-O", str(destination), url],
        check=True,
    )


if MODEL_CHECKPOINT_PATH is not None:
    _fetch(MODEL_CHECKPOINT_PATH, LOCAL_CHECKPOINT_PATH)

if NOTO_SANS_ZIP_PATH is not None:
    _fetch(NOTO_SANS_ZIP_PATH, LOCAL_NOTO_ZIP_PATH)
    # Extract into the archive's parent directory (same target as `unzip -d`).
    # zipfile overwrites existing files instead of prompting, and needs no
    # external `unzip` binary.
    with zipfile.ZipFile(LOCAL_NOTO_ZIP_PATH) as _archive:
        _archive.extractall(Path(LOCAL_NOTO_ZIP_PATH).parent)

# Explicit checks instead of `assert`, so they still run under `python -O`.
# The content directory is the archive path without `.zip` (data/NotoSans).
_REQUIRED_PATHS = (
    Path(LOCAL_CHECKPOINT_PATH),
    Path(LOCAL_NOTO_ZIP_PATH).with_suffix(""),
)
_missing = [str(p) for p in _REQUIRED_PATHS if not p.exists()]
if _missing:
    raise FileNotFoundError(f"Required resources not found: {', '.join(_missing)}")
# Font files offered in the "Select example font" dropdown. The list is kept
# sorted; its first entry is the dropdown's initial value.
EXAMPLE_FONTS = sorted(
    (
        "example_fonts/BalooDa2-Bold.ttf",
        "example_fonts/BalooDa2-Regular.ttf",
        "example_fonts/Lalezar-Regular.ttf",
        "example_fonts/MaShanZheng-Regular.ttf",
    )
)
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the font style transfer demo.

    Unrecognized arguments are ignored (``parse_known_args``) instead of
    causing an error exit.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so calling ``parse_args()`` behaves as before.

    Returns:
        Namespace with ``docs`` (Path), ``config`` (Path), ``local`` (bool)
        and ``port`` (int).
    """
    # The description is shown by --help, so it must name this application.
    parser = argparse.ArgumentParser(
        description="Gradio demo for multilingual font style transfer")
    # -------- User arguments ----------------------------------------
    parser.add_argument(
        '--docs', type=Path, default=PATH_DOCS,
        help="Docs string file")
    parser.add_argument(
        '--config', type=Path, default=MODEL_CONFIG,
        help="Config for model")
    parser.add_argument(
        '--local', action='store_true',
        help="Whether to run in local environment or not")
    parser.add_argument(
        '--port', type=int, default=50003,
        help="Service port (only applicable when running on local server)")
    args, _ = parser.parse_known_args(argv)
    return args
class InferenceServiceResolver(InferenceServicer):
    """Gradio-facing adapter around ``InferenceServicer``.

    Flattens each inference result into a flat list of images and re-raises
    failures as ``gr.Error``, Gradio's user-facing error type.
    """

    def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
        # Same arguments as the base class, forwarded unchanged.
        super().__init__(hp, checkpoint_path, content_image_dir, imsize, gpu_id)

    def generate(self, content_char: str, style_font: Union[str, Path]) -> List[Image.Image]:
        """Run style transfer for one character.

        Args:
            content_char: Character entered by the user.
            style_font: Path of the font whose style is transferred.

        Returns:
            ``[content_image, *style_images, result]``: the content image, each
            style image, then the generated image.

        Raises:
            gr.Error: If inference (or flattening its result) fails.
        """
        try:
            content_image, style_images, result = self.inference(
                content_char=content_char, style_font=style_font)
            images = [content_image, *style_images, result]
        except gr.Error:
            # Already a user-facing error; do not wrap it a second time.
            raise
        except Exception as e:
            # Some exceptions stringify to "" (e.g. a bare AssertionError); fall
            # back to the type name so the UI never shows an empty message.
            # `from e` keeps the original traceback for server-side debugging.
            raise gr.Error(str(e) or type(e).__name__) from e
        return images
def launch_gradio(docs_path: Path, hp: DictConfig, checkpoint_path: Path, content_image_dir: Path,
                  is_local: bool, port: Optional[int] = None):
    """Build the font style transfer UI and launch it.

    Args:
        docs_path: Markdown file rendered at the top of the page (read as UTF-8).
        hp: Model configuration forwarded to the inference servicer.
        checkpoint_path: Checkpoint file forwarded to the inference servicer.
        content_image_dir: Content image directory forwarded to the inference servicer.
        is_local: If True, bind to 0.0.0.0 on ``port``; otherwise call
            ``demo.launch()`` with Gradio's defaults.
        port: Server port; only used when ``is_local`` is True.
    """
    # NOTE(review): gpu_id=None overrides the servicer's '0' default; presumably
    # this means "no GPU" inside InferenceServicer — confirm.
    servicer = InferenceServiceResolver(hp, checkpoint_path, content_image_dir, gpu_id=None)

    with gr.Blocks(title="Multilingual Font Style Transfer (training with Google Fonts)") as demo:
        # Explicit UTF-8: the page shows non-ASCII text, and the platform's
        # default encoding is not guaranteed to be UTF-8.
        gr.Markdown(docs_path.read_text(encoding="utf-8"))

        with gr.Row(equal_height=True):
            character_input = gr.Textbox(max_lines=1, value="7", info="Only single character is acceptable (e.g. '간', '7', or 'ជ')")
            style_font = gr.Dropdown(label="Select example font: ", choices=EXAMPLE_FONTS, value=EXAMPLE_FONTS[0])
            run_button = gr.Button(value="Generate", variant='primary')

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("<center><h3>Content character</h3></center>")
                    content_char = gr.Image(label="Content character", show_label=False)
            with gr.Column(scale=5):
                with gr.Group():
                    gr.Markdown("<center><h3>Style font images</h3></center>")
                    with gr.Row(equal_height=True):
                        # Created (and laid out) in order: Style #1 .. Style #5.
                        # NOTE(review): assumes inference returns exactly five
                        # style images — confirm against InferenceServicer.
                        style_chars = [gr.Image(label=f"Style #{i}", show_label=False) for i in range(1, 6)]
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("<center><h3>Generated font image</h3></center>")
                    generated_font = gr.Image(label="Generated font image", show_label=False)

        # Order must match the list returned by InferenceServiceResolver.generate:
        # content image, style images, generated image.
        outputs = [content_char, *style_chars, generated_font]
        run_inputs = [character_input, style_font]
        run_button.click(servicer.generate, inputs=run_inputs, outputs=outputs)

    if is_local:
        demo.launch(server_name="0.0.0.0", server_port=port)
    else:
        demo.launch()
if __name__ == "__main__":
    cli_args = parse_args()
    # Assets are read from the fixed local paths populated by the download step
    # at module import; the content directory is the archive path minus ".zip".
    launch_gradio(
        cli_args.docs,
        OmegaConf.load(cli_args.config),
        Path(LOCAL_CHECKPOINT_PATH),
        Path(LOCAL_NOTO_ZIP_PATH).with_suffix(""),
        cli_args.local,
        cli_args.port,
    )