import gradio as gr
import numpy as np
import random
import torch

from PIL import Image as im

from helper.cond_encoder import CLIPEncoder
from helper.loader import Loader

from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
from clip.models.ko_clip import KoCLIPWrapper
from diffusion_model.sampler.ddim import DDIM
from diffusion_model.models.latent_diffusion_model import LatentDiffusionModel
from diffusion_model.network.unet import Unet
from diffusion_model.network.unet_wrapper import UnetWrapper
from huggingface_hub import hf_hub_download

# import spaces #[uncomment to use ZeroGPU]

device = "cuda" if torch.cuda.is_available() else "cpu"
loader = Loader(device)
repo_id = "JuyeopDang/KoFace-Diffusion"
CONFIG_PATH = 'configs/composite_config.yaml'

# use half precision on GPU to reduce memory; fall back to full precision on CPU
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

def load_model_from_HF(model, repo_id, filename, is_ema=False):
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        print(f"Error while downloading the file or loading the model: {e}")
        raise  # without a valid local path, the load below cannot proceed
    model_path = model_path[:-4]  # strip the '.pth' extension before handing the path to Loader
    model = loader.model_load(model_path, model, is_ema=is_ema, print_dict=False)
    return model

# Each example is [Korean prompt, guidance scale, reference image], matching the gr.Examples inputs below.
# Prompts stay in Korean because the condition encoder is KoCLIP; English glosses follow in the comments.
examples = [
    # "When Guidance Scale is 0, a random face is generated without using the text condition."
    ['Guidance Scale이 0인 경우, Text Condition을 이용하지 않고 임의의 얼굴을 생성합니다.', 0, im.open("./assets/Example0.webp")],
    # "A round face, full and slightly long cropped hair, and a barely visible Adam's apple give a feminine rather than masculine air; clear eyes and lips add delicacy and an intellectual look."
    ['동그란 얼굴에 풍성하고 살짝 긴 커트머리와 거의 나오지 않은 성대가 남성보다는 여성적인 분위기를 냅니다. 또렷한 눈과 입술이 인물의 섬세함을 더욱 부각시키고 지적으로 보이게 만듭니다.', 2, im.open("./assets/Example1.webp")],
    # "A man with a big nose."
    ['코가 큰 남자다.', 4, im.open("./assets/Example2.webp")],
    # "Big, voluminous permed hair and slightly feminine features; many wrinkles and drooping eyes convey age, while big eyes and a stubborn mouth with deep nasolabial folds suggest a rigid stickler for principle."
    ['아줌마 파마같이 크고 부한 머리에 이목구비가 약간 여성스러워 보이는 얼굴이며, 적지않은 주름과 처진 눈에서 연륜이 느껴진다. 큰 눈과 넓은 팔자 주름이 있는 고집스런 입매는 완고한 원칙주의자 같은 느낌이 있다.', 4, im.open("./assets/Example3.webp")],
    # "Somewhat unpolished hair lacking refinement; upturned eye corners make the gaze sharp and sensitive; a thin build overall, likely competent at work but possibly not smooth in personal relationships."
    ['헤어 손질이 다소 미숙하여 세련된 느낌이 부족하다. 눈 끝이 올라가 있어 눈빛이 날카롭고 예민해 보인다. 전체적으로 마른 체격일 것으로 보이며, 업무 처리 능력은 뛰어나겠지만, 교우 관계는 원만하지 않을 수도 있다.', 7, im.open("./assets/Example4.webp")]
]


if __name__ == "__main__":
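    # assemble the latent diffusion pipeline: VAE latent space, DDIM sampler,
    # KoCLIP as the text condition encoder, and a U-Net denoising backbone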
    vae = VariationalAutoEncoder(CONFIG_PATH)
    sampler = DDIM(CONFIG_PATH)
    clip = KoCLIPWrapper()
    cond_encoder = CLIPEncoder(clip, CONFIG_PATH)
    network = UnetWrapper(Unet, CONFIG_PATH, cond_encoder)
    dm = LatentDiffusionModel(network, sampler, vae)

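    # fetch pretrained weights from the Hugging Face Hub (EMA weights for CLIP and the diffusion model)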
    vae = load_model_from_HF(vae, repo_id, "composite_epoch2472.pth", False)
    clip = load_model_from_HF(clip, repo_id, "asian-composite-fine-tuned-koclip.pth", True)
    dm = load_model_from_HF(dm, repo_id, "asian-composite-clip-ldm.pth", True)

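    # @spaces.GPU  # [uncomment, together with `import spaces` above, to run generation on ZeroGPU]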
    def generate_image(y, gamma):
        images = dm.sample(2, y=y, gamma=gamma)  # sample a batch conditioned on prompt y with guidance scale gamma
        if isinstance(images, torch.Tensor):
            images = images.detach().cpu().numpy()
        images = np.transpose(images, (0, 2, 3, 1))  # NCHW -> NHWC
        images = np.clip(images / 2 + 0.5, 0, 1)  # rescale [-1, 1] -> [0, 1]
        return im.fromarray((images[0] * 255).astype(np.uint8))

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # KoFace AI - Korean Facial Composite (Montage) Generator
            * Enter text describing the composite and press the 'Generate' button.

            * **Note**: this demo runs on the free tier, so **generation may take around 10 minutes**.

            * This AI is a from-scratch implementation of a Latent Diffusion Model. See the GitHub repository for the full code and experimental results.

            🔗 [GitHub repository](https://github.com/Won-Seong/simple-latent-diffusion-model) | 📄 [Dataset used](https://www.aihub.or.kr/aihubdata/data/view.do?dataSetSn=618)
            """
        )
    
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    placeholder="Describe the facial composite (in Korean).",
                    label="Text Condition"
                )
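                # gamma = 0 disables text conditioning entirely (unconditional sampling); see Example 0 below.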
                guidance_slider = gr.Slider(
                    0, 10, value=2,
                    label="Guidance Scale",
                    info="Higher values follow the input text more strongly when generating the composite."
                )
                submit_btn = gr.Button("Generate")
    
            with gr.Column():
                image_output = gr.Image(label="Generated Composite")
    
        submit_btn.click(fn=generate_image, inputs=[text_input, guidance_slider], outputs=image_output)

        gr.Examples(examples, inputs=[text_input, guidance_slider, image_output])
    
    demo.launch()