import random
from typing import List, Union, Optional, Tuple

import torch
from PIL import Image

from sample import (arg_parse,
                    sampling,
                    load_fontdiffuer_pipeline)


def run_fontdiffuer(source_image,
                    character,
                    reference_image,
                    sampling_step,
                    guidance_scale,
                    batch_size=1):
    """Configure the shared ``args`` namespace and run one sampling pass.

    NOTE: this function reads and mutates the module-level ``args`` and uses
    the module-level ``pipe``. In this file both are only created inside the
    ``if __name__ == '__main__'`` section, so they must exist before this
    function is called.

    Args:
        source_image: Content image handed to ``sampling`` as
            ``content_image``, or ``None``.
        character: Value stored in ``args.content_character``.
        reference_image: Style image handed to ``sampling`` as
            ``style_image``.
        sampling_step: Value stored in ``args.sampling_step``.
        guidance_scale: Value stored in ``args.guidance_scale``.
        batch_size: Value stored in ``args.batch_size``.

    Returns:
        Whatever ``sampling`` returns. When that value is not ``None`` its
        ``format`` attribute is set to ``'PNG'`` before it is returned.
    """
    # Character-input mode is enabled exactly when no source image is given.
    # Presumably ``sampling`` then builds the content from ``character``;
    # confirm against ``sample.sampling``.
    args.character_input = source_image is None
    args.content_character = character
    args.sampling_step = sampling_step
    args.guidance_scale = guidance_scale
    args.batch_size = batch_size
    # A fresh seed in [0, 10000] (inclusive) is drawn on every call.
    args.seed = random.randint(0, 10000)

    out_image = sampling(
        args=args,
        pipe=pipe,
        content_image=source_image,
        style_image=reference_image)
    if out_image is not None:
        out_image.format = 'PNG'
    return out_image


def _load_rgb_image(image_path: Optional[str]):
    """Return the image at ``image_path`` converted to RGB, or ``None``.

    The file is opened in a ``with`` block so the underlying handle is
    released deterministically; ``convert`` returns an independent image that
    remains usable after the source image is closed.
    """
    if image_path is None:
        return None
    with Image.open(image_path) as opened:
        return opened.convert('RGB')


def run_inference(
    source_image_path: Union[str, None],
    character: Union[str, None],
    reference_image_path: Optional[str],
    sampling_step: int = 50,
    guidance_scale: float = 7.5,
):
    """Load the input images from disk and run ``run_fontdiffuer`` on them.

    Args:
        source_image_path: Path of the content image, or ``None`` to run
            without a source image.
        character: Content character forwarded to ``run_fontdiffuer``.
        reference_image_path: Path of the style reference image. ``None`` is
            accepted and forwarded as a ``None`` reference image.
        sampling_step: Forwarded to ``run_fontdiffuer``.
        guidance_scale: Forwarded to ``run_fontdiffuer``.

    Returns:
        The value returned by ``run_fontdiffuer``.
    """
    source_image = _load_rgb_image(source_image_path)
    reference_image = _load_rgb_image(reference_image_path)

    return run_fontdiffuer(
        source_image=source_image,
        character=character,
        reference_image=reference_image,
        sampling_step=sampling_step,
        guidance_scale=guidance_scale,
    )


if __name__ == '__main__':
    # ``args`` and ``pipe`` are deliberately module-level names:
    # ``run_fontdiffuer`` looks them up as globals.
    args = arg_parse()
    args.demo = True
    args.ckpt_dir = 'ckpt'
    args.ttf_path = 'ttf/KaiXinSongA.ttf'
    args.device = 'cuda'

    # load fontdiffuer pipeline
    pipe = load_fontdiffuer_pipeline(args=args)

    image = run_inference(
        character=None,
        source_image_path="figures/ref_imgs/ref_壤.jpg",
        reference_image_path="figures/ref_imgs/ref_欟.jpg"
    )
    print(image)