# FontDiffuser-Gradio / simple_inference.py
# Committed by chulanpro5 in 79378ee: "update: add simple inference example"
import random
from typing import List, Union, Optional, Tuple
import torch
from PIL import Image
from sample import (arg_parse,
sampling,
load_fontdiffuer_pipeline)
def _module_global(name):
    """Return module-level global *name*; raise NameError if it is not defined.

    In this script ``args`` and ``pipe`` are only created in the ``__main__``
    section, so they are absent when the module is imported elsewhere.
    """
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"{name!r} was not passed to run_fontdiffuer and no module-level "
            f"{name!r} is defined (it is only created when this file runs as "
            f"a script)"
        ) from None


def run_fontdiffuer(source_image,
                    character,
                    reference_image,
                    sampling_step,
                    guidance_scale,
                    batch_size=1,
                    *,
                    args=None,
                    pipe=None,
                    seed=None):
    """Run one FontDiffuser sampling pass and return the generated image.

    The sampling options are written onto ``args`` (mutated in place) and
    then handed to ``sampling`` together with ``pipe``.

    Args:
        source_image: Content glyph image, or None to use ``character``.
        character: Content character, stored in ``args.content_character``.
        reference_image: Style reference image, passed as ``style_image``.
        sampling_step: Stored in ``args.sampling_step``.
        guidance_scale: Stored in ``args.guidance_scale``.
        batch_size: Stored in ``args.batch_size``.
        args: Options object to configure. Defaults to the module-level
            ``args`` created when this file runs as a script.
        pipe: Loaded FontDiffuser pipeline. Defaults to the module-level
            ``pipe`` created when this file runs as a script.
        seed: Value for ``args.seed``; when None a random integer in
            [0, 10000] is drawn, matching the previous behaviour.

    Returns:
        The result of ``sampling``. When it is not None, its ``format``
        attribute is set to 'PNG'.

    Raises:
        NameError: If ``args`` or ``pipe`` is omitted and no module-level
            global of that name exists.
    """
    # Resolve both dependencies before touching ``args`` so that a missing
    # pipeline does not leave a half-updated options object behind.
    if args is None:
        args = _module_global('args')
    if pipe is None:
        pipe = _module_global('pipe')

    # ``character_input`` is True exactly when no content image was given
    # (presumably telling ``sampling`` to use ``content_character``).
    args.character_input = source_image is None
    args.content_character = character
    args.sampling_step = sampling_step
    args.guidance_scale = guidance_scale
    args.batch_size = batch_size
    # Fresh random seed per call unless the caller pins one.
    args.seed = random.randint(0, 10000) if seed is None else seed

    out_image = sampling(
        args=args,
        pipe=pipe,
        content_image=source_image,
        style_image=reference_image)
    if out_image is not None:
        out_image.format = 'PNG'
    return out_image
def _open_rgb(path):
    """Load the image at *path* converted to RGB, or return None for None.

    The file is opened in a ``with`` block so its handle is released as soon
    as the pixels have been converted into a new in-memory image.
    """
    if path is None:
        return None
    with Image.open(path) as img:
        return img.convert('RGB')


def run_inference(
    source_image_path: Optional[str],
    character: Optional[str],
    reference_image_path: Optional[str],
    sampling_step: int = 50,
    guidance_scale: float = 7.5,
    batch_size: int = 1,
):
    """Load the input images from disk and run FontDiffuser on them.

    Args:
        source_image_path: Path to the content glyph image, or None to pass
            no content image (``character`` is then used instead).
        character: Content character forwarded to ``run_fontdiffuer``.
        reference_image_path: Path to the style reference image, or None.
        sampling_step: Number of sampling steps forwarded to the sampler.
        guidance_scale: Guidance scale forwarded to the sampler.
        batch_size: Batch size forwarded to ``run_fontdiffuer`` (default 1,
            as before this parameter existed).

    Returns:
        The image returned by ``run_fontdiffuer`` (may be None).
    """
    # Keyword arguments are evaluated left to right, so the source image is
    # loaded before the reference image, as in the original code.
    return run_fontdiffuer(
        source_image=_open_rgb(source_image_path),
        character=character,
        reference_image=_open_rgb(reference_image_path),
        sampling_step=sampling_step,
        guidance_scale=guidance_scale,
        batch_size=batch_size,
    )
if __name__ == '__main__':
    # Start from the command-line options, then force the demo settings.
    # ``args`` and ``pipe`` stay module-level names because run_fontdiffuer
    # looks them up as globals.
    args = arg_parse()
    demo_overrides = {
        'demo': True,
        'ckpt_dir': 'ckpt',
        'ttf_path': 'ttf/KaiXinSongA.ttf',
        'device': 'cuda',
    }
    for option_name, option_value in demo_overrides.items():
        setattr(args, option_name, option_value)

    # Build the FontDiffuser pipeline once for the demo call below.
    pipe = load_fontdiffuer_pipeline(args=args)

    # Demo: use one reference glyph image as content and another as style.
    image = run_inference(
        source_image_path="figures/ref_imgs/ref_壤.jpg",
        reference_image_path="figures/ref_imgs/ref_欟.jpg",
        character=None,
    )
    print(image)