import gradio as gr
import numpy as np
import random
import torch
from PIL import Image as im
from helper.cond_encoder import CLIPEncoder
from helper.loader import Loader
from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
from clip.models.ko_clip import KoCLIPWrapper
from diffusion_model.sampler.ddim import DDIM
from diffusion_model.models.latent_diffusion_model import LatentDiffusionModel
from diffusion_model.network.unet import Unet
from diffusion_model.network.unet_wrapper import UnetWrapper
from huggingface_hub import hf_hub_download
# import spaces #[uncomment to use ZeroGPU]
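# Gradio demo for KoFace-Diffusion: generates facial composites (montages) from
# text descriptions with a latent diffusion model conditioned through a KoCLIP text encoder.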
device = "cuda" if torch.cuda.is_available() else "cpu"
loader = Loader(device)
repo_id = "JuyeopDang/KoFace-Diffusion"
CONFIG_PATH = 'configs/composite_config.yaml'
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32
def load_model_from_HF(model, repo_id, filename, is_ema=False):
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        print(f"Error while downloading the file or loading the model: {e}")
        raise
    model_path = model_path[:-4]  # drop the trailing ".pth" extension
    model = loader.model_load(model_path, model, is_ema=is_ema, print_dict=False)
    return model
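# Example prompts for gr.Examples below: (text condition, guidance scale, example output image).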
examples = [
    ["When the Guidance Scale is 0, a random face is generated without using the Text Condition.", 0, im.open("./assets/Example0.webp")],
    ["He has a round face, slightly long cropped hair, and a barely visible Adam's apple, which gives him a feminine rather than masculine air. His distinct eyes and mouth further emphasize his delicate features and make him look intelligent.", 2, im.open("./assets/Example1.webp")],
    ["A slightly big man.", 4, im.open("./assets/Example2.webp")],
    ["He has a big, thick head of permed-looking hair and facial features that look slightly feminine; noticeable wrinkles and sagging eyes convey his age. His large eyes and a stubborn-set mouth free of deep nasolabial folds give the impression of an obstinate stickler for principles.", 4, im.open("./assets/Example3.webp")],
    ["His hair looks somewhat unkempt, so he lacks a polished impression. His upturned eyes make his gaze look sharp and edgy. He appears to have a lean build overall; he probably handles work very well, but his personal relationships may not be smooth.", 7, im.open("./assets/Example4.webp")]
]
if __name__ == "__main__":
    # Assemble the latent diffusion model from its components.
    vae = VariationalAutoEncoder(CONFIG_PATH)
    sampler = DDIM(CONFIG_PATH)
    clip = KoCLIPWrapper()
    cond_encoder = CLIPEncoder(clip, CONFIG_PATH)
    network = UnetWrapper(Unet, CONFIG_PATH, cond_encoder)
    dm = LatentDiffusionModel(network, sampler, vae)

    # Load the pretrained weights from the Hugging Face Hub.
    vae = load_model_from_HF(vae, repo_id, "composite_epoch2472.pth", False)
    clip = load_model_from_HF(clip, repo_id, "asian-composite-fine-tuned-koclip.pth", True)
    dm = load_model_from_HF(dm, repo_id, "asian-composite-clip-ldm.pth", True)
    def generate_image(y, gamma):
        # Sample a small batch conditioned on the text y with guidance scale gamma,
        # then map the first image from [-1, 1] to [0, 255] and return it as a PIL image.
        images = dm.sample(2, y=y, gamma=gamma)
        images = images.permute(0, 2, 3, 1)
        if isinstance(images, torch.Tensor):
            images = images.detach().cpu().numpy()
        images = np.clip(images / 2 + 0.5, 0, 1)
        return im.fromarray((images[0] * 255).astype(np.uint8))
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # KoFace AI - Korean Facial Composite (Montage) Generator
            * Enter a text description of the facial composite and press the 'Generate' button.
            * **Note**: This Space currently runs on the free tier, so **generation can take around 10 minutes**.
            * This AI was built by implementing a Latent Diffusion Model from scratch. See the GitHub repository for the full code and experiment results.
            [GitHub repository](https://github.com/Won-Seong/simple-latent-diffusion-model) | [Dataset used](https://www.aihub.or.kr/aihubdata/data/view.do?dataSetSn=618)
            """
        )
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    placeholder="Enter a text description of the facial composite.",
                    label="Text Condition"
                )
                guidance_slider = gr.Slider(
                    0, 10, value=2,
                    label="Guidance Scale",
                    info="The larger this value, the more strongly the input text is used when generating the composite."
                )
                submit_btn = gr.Button("Generate")
            with gr.Column():
                image_output = gr.Image(label="Generated composite")

        submit_btn.click(fn=generate_image, inputs=[text_input, guidance_slider], outputs=image_output)
        gr.Examples(examples, inputs=[text_input, guidance_slider, image_output])

    demo.launch()