File size: 2,068 Bytes
79378ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import random
from typing import List, Union, Optional, Tuple
import torch
from PIL import Image
from sample import (arg_parse,
                    sampling,
                    load_fontdiffuer_pipeline)

def run_fontdiffuer(source_image,
                    character,
                    reference_image,
                    sampling_step,
                    guidance_scale,
                    batch_size=1):
    """Configure the shared ``args`` for one request and run ``sampling``.

    Relies on the module-level ``args`` and ``pipe`` objects, which this file
    only creates inside its ``__main__`` section; ``args`` is mutated in place
    on every call.

    Args:
        source_image: Content image handed to ``sampling`` as ``content_image``,
            or ``None``. When it is ``None``, ``args.character_input`` is set
            to ``True`` (presumably so the content glyph is rendered from
            ``character`` instead -- confirm in ``sample``).
        character: Stored on ``args.content_character``.
        reference_image: Style image handed to ``sampling`` as ``style_image``.
        sampling_step: Stored on ``args.sampling_step``.
        guidance_scale: Stored on ``args.guidance_scale``.
        batch_size: Stored on ``args.batch_size``.

    Returns:
        The object returned by ``sampling`` with its ``format`` attribute set
        to ``'PNG'``, or ``None`` when ``sampling`` returns ``None``.
    """
    per_call_settings = {
        'character_input': source_image is None,
        'content_character': character,
        'sampling_step': sampling_step,
        'guidance_scale': guidance_scale,
        'batch_size': batch_size,
        # A fresh seed is drawn for every call.
        'seed': random.randint(0, 10000),
    }
    for field_name, field_value in per_call_settings.items():
        setattr(args, field_name, field_value)

    generated = sampling(
        args=args,
        pipe=pipe,
        content_image=source_image,
        style_image=reference_image,
    )
    if generated is None:
        return None

    generated.format = 'PNG'
    return generated

def _open_rgb_image(image_path: Optional[str]):
    """Load the image at ``image_path`` as RGB, or return ``None`` for no path."""
    if image_path is None:
        return None
    # Image.open is lazy and holds the file open; convert() reads the pixels
    # into a new image, so the handle can be released as soon as it returns.
    with Image.open(image_path) as opened:
        return opened.convert('RGB')


def run_inference(
    source_image_path: Optional[str],
    character: Optional[str],
    reference_image_path: Optional[str],
    sampling_step: int = 50,
    guidance_scale: float = 7.5,
):
    """Load the input images from disk and run ``run_fontdiffuer`` on them.

    Args:
        source_image_path: Path of the content image, or ``None`` to pass no
            content image.
        character: Forwarded unchanged to ``run_fontdiffuer``.
        reference_image_path: Path of the style reference image, or ``None``
            to pass no reference image.
        sampling_step: Forwarded unchanged to ``run_fontdiffuer``.
        guidance_scale: Forwarded unchanged to ``run_fontdiffuer``.

    Returns:
        Whatever ``run_fontdiffuer`` returns for these inputs.

    Errors raised by ``Image.open`` (for example a missing or unreadable
    file) propagate to the caller.
    """
    return run_fontdiffuer(
        source_image=_open_rgb_image(source_image_path),
        character=character,
        reference_image=_open_rgb_image(reference_image_path),
        sampling_step=sampling_step,
        guidance_scale=guidance_scale,
    )

if __name__ == '__main__':
    # Demo entry point. `args` and `pipe` are deliberately left as module-level
    # names because run_fontdiffuer reads them as globals.
    args = arg_parse()
    args.demo = True
    args.ckpt_dir = 'ckpt'
    args.ttf_path = 'ttf/KaiXinSongA.ttf'
    # Prefer the GPU, but do not hard-require CUDA: select the CPU when no
    # CUDA device is available instead of always requesting 'cuda'.
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # load fontdiffuer pipeline
    pipe = load_fontdiffuer_pipeline(args=args)

    # Image-to-image run: a content image is supplied, so no character is given.
    image = run_inference(
        character=None,
        source_image_path="figures/ref_imgs/ref_壤.jpg",
        reference_image_path="figures/ref_imgs/ref_欟.jpg"
    )

    print(image)