import gradio as gr
import numpy as np
import random
import torch
from PIL import Image as im
from helper.cond_encoder import CLIPEncoder
from helper.loader import Loader
from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
from clip.models.ko_clip import KoCLIPWrapper
from diffusion_model.sampler.ddim import DDIM
from diffusion_model.models.latent_diffusion_model import LatentDiffusionModel
from diffusion_model.network.unet import Unet
from diffusion_model.network.unet_wrapper import UnetWrapper
from huggingface_hub import hf_hub_download
# import spaces #[uncomment to use ZeroGPU]
device = "cuda" if torch.cuda.is_available() else "cpu"
loader = Loader(device)
repo_id = "JuyeopDang/KoFace-Diffusion"
CONFIG_PATH = 'configs/composite_config.yaml'
# Use half precision on GPU, full precision on CPU.
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32
def load_model_from_HF(model, repo_id, filename, is_ema=False):
    """Download a checkpoint from the Hugging Face Hub and load its weights into `model`."""
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        print(f"Error while downloading the file or loading the model: {e}")
        return model  # fall back to the model's current weights if the download fails
    # Drop the ".pth" extension; Loader.model_load appears to expect the bare path.
    model_path = model_path[:-4]
    model = loader.model_load(model_path, model, is_ema=is_ema, print_dict=False)
    return model
examples = [
    ['Guidance Scale์ด 0์ธ ๊ฒฝ์šฐ, Text Condition์„ ์ด์šฉํ•˜์ง€ ์•Š๊ณ  ์ž„์˜์˜ ์–ผ๊ตด์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.', 0, im.open("./assets/Example0.webp")],
    ['๋™๊ทธ๋ž€ ์–ผ๊ตด์— ํ’์„ฑํ•˜๊ณ  ์‚ด์ง ๊ธด ์ปคํŠธ๋จธ๋ฆฌ์™€ ๊ฑฐ์˜ ๋‚˜์˜ค์ง€ ์•Š์€ ์„ฑ๋Œ€๊ฐ€ ๋‚จ์„ฑ๋ณด๋‹ค๋Š” ์—ฌ์„ฑ์ ์ธ ๋ถ„์œ„๊ธฐ๋ฅผ ๋ƒ…๋‹ˆ๋‹ค. ๋˜๋ ทํ•œ ๋ˆˆ๊ณผ ์ž…์ˆ ์ด ์ธ๋ฌผ์˜ ์„ฌ์„ธํ•จ์„ ๋”์šฑ ๋ถ€๊ฐ์‹œํ‚ค๊ณ  ์ง€์ ์œผ๋กœ ๋ณด์ด๊ฒŒ ๋งŒ๋“ญ๋‹ˆ๋‹ค.', 2, im.open("./assets/Example1.webp")],
    ['์ฝ”๊ฐ€ ํฐ ๋‚จ์ž๋‹ค.', 4, im.open("./assets/Example2.webp")],
    ['์•„์คŒ๋งˆ ํŒŒ๋งˆ๊ฐ™์ด ํฌ๊ณ  ๋ถ€ํ•œ ๋จธ๋ฆฌ์— ์ด๋ชฉ๊ตฌ๋น„๊ฐ€ ์•ฝ๊ฐ„ ์—ฌ์„ฑ์Šค๋Ÿฌ์›Œ ๋ณด์ด๋Š” ์–ผ๊ตด์ด๋ฉฐ, ์ ์ง€์•Š์€ ์ฃผ๋ฆ„๊ณผ ์ฒ˜์ง„ ๋ˆˆ์—์„œ ์—ฐ๋ฅœ์ด ๋А๊บผ์ง„๋‹ค. ํฐ ๋ˆˆ๊ณผ ๋„“์€ ํŒ”์ž ์ฃผ๋ฆ„์ด ์žˆ๋Š” ๊ณ ์ง‘์Šค๋Ÿฐ ์ž…๋งค๋Š” ์™„๊ณ ํ•œ ์›์น™์ฃผ์˜์ž ๊ฐ™์€ ๋А๋‚Œ์ด ์žˆ๋‹ค.', 4, im.open("./assets/Example3.webp")],
    ['ํ—ค์–ด ์†์งˆ์ด ๋‹ค์†Œ ๋ฏธ์ˆ™ํ•˜์—ฌ ์„ธ๋ จ๋œ ๋А๋‚Œ์ด ๋ถ€์กฑํ•˜๋‹ค. ๋ˆˆ ๋์ด ์˜ฌ๋ผ๊ฐ€ ์žˆ์–ด ๋ˆˆ๋น›์ด ๋‚ ์นด๋กญ๊ณ  ์˜ˆ๋ฏผํ•ด ๋ณด์ธ๋‹ค. ์ „์ฒด์ ์œผ๋กœ ๋งˆ๋ฅธ ์ฒด๊ฒฉ์ผ ๊ฒƒ์œผ๋กœ ๋ณด์ด๋ฉฐ, ์—…๋ฌด ์ฒ˜๋ฆฌ ๋Šฅ๋ ฅ์€ ๋›ฐ์–ด๋‚˜๊ฒ ์ง€๋งŒ, ๊ต์šฐ ๊ด€๊ณ„๋Š” ์›๋งŒํ•˜์ง€ ์•Š์„ ์ˆ˜๋„ ์žˆ๋‹ค.', 7, im.open("./assets/Example4.webp")]
]
if __name__ == "__main__":
    # Assemble the latent diffusion pipeline: VAE + DDIM sampler + KoCLIP text encoder + U-Net.
    vae = VariationalAutoEncoder(CONFIG_PATH)
    sampler = DDIM(CONFIG_PATH)
    clip = KoCLIPWrapper()
    cond_encoder = CLIPEncoder(clip, CONFIG_PATH)
    network = UnetWrapper(Unet, CONFIG_PATH, cond_encoder)
    dm = LatentDiffusionModel(network, sampler, vae)

    # Load pretrained weights from the Hub (EMA weights for the KoCLIP and diffusion checkpoints).
    vae = load_model_from_HF(vae, repo_id, "composite_epoch2472.pth", False)
    clip = load_model_from_HF(clip, repo_id, "asian-composite-fine-tuned-koclip.pth", True)
    dm = load_model_from_HF(dm, repo_id, "asian-composite-clip-ldm.pth", True)
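    # To run on ZeroGPU hardware, uncomment `import spaces` above and decorate the
    # generation function below with `@spaces.GPU` (see the hint at the top of this file).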
    def generate_image(y, gamma):
        # Sample a batch of 2 images conditioned on the text `y` with guidance scale `gamma`.
        images = dm.sample(2, y=y, gamma=gamma)
        images = images.permute(0, 2, 3, 1)  # NCHW -> NHWC
        if isinstance(images, torch.Tensor):
            images = images.detach().cpu().numpy()
        images = np.clip(images / 2 + 0.5, 0, 1)  # map from [-1, 1] to [0, 1]
        return im.fromarray((images[0] * 255).astype(np.uint8))  # return the first sample only
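    # Quick local sanity check (hypothetical prompt; assumes the checkpoints above loaded correctly):
    #   sample = generate_image("๋ˆˆ์ด ํฌ๊ณ  ๋จธ๋ฆฌ๊ฐ€ ์งง์€ ๋‚จ์ž", gamma=2)
    #   sample.save("sample.png")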
    with gr.Blocks() as demo:
        gr.Markdown(
            """
# KoFace AI - ํ•œ๊ตญ์ธ ๋ชฝํƒ€์ฃผ ์ƒ์„ฑ๊ธฐ
* ๋ชฝํƒ€์ฃผ๋ฅผ ์„ค๋ช…ํ•˜๋Š” ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜๊ณ  '์ƒ์„ฑํ•˜๊ธฐ' ๋ฒ„ํŠผ์„ ๋ˆŒ๋Ÿฌ์ฃผ์„ธ์š”.
* **์ฐธ๊ณ **: ํ˜„์žฌ ๋ฌด๋ฃŒ ํ‹ฐ์–ด๋ฅผ ์ด์šฉํ•˜๊ณ  ์žˆ์–ด, **์ƒ์„ฑ์— ์•ฝ 10๋ถ„ ์ •๋„์˜ ์‹œ๊ฐ„**์ด ๊ฑธ๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
* ์ด AI๋Š” Latent Diffusion Model์„ ์ง์ ‘ ๊ตฌํ˜„ํ•˜์—ฌ ๋งŒ๋“ค์—ˆ์Šต๋‹ˆ๋‹ค. ์ž์„ธํ•œ ์ฝ”๋“œ๋‚˜ ์‹คํ—˜ ๊ฒฐ๊ณผ๋Š” GitHub ์ €์žฅ์†Œ๋ฅผ ์ฐธ๊ณ ํ•ด ์ฃผ์„ธ์š”.

๐Ÿ”— [GitHub ์ €์žฅ์†Œ ๋ฐ”๋กœ๊ฐ€๊ธฐ](https://github.com/Won-Seong/simple-latent-diffusion-model) | ๐Ÿ“„ [์‚ฌ์šฉํ•œ ๋ฐ์ดํ„ฐ์„ธํŠธ](https://www.aihub.or.kr/aihubdata/data/view.do?dataSetSn=618)
            """
        )
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    placeholder="๋ชฝํƒ€์ฃผ๋ฅผ ์„ค๋ช…ํ•˜๋Š” ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”.",
                    label="Text Condition"
                )
                guidance_slider = gr.Slider(
                    0, 10, value=2,
                    label="Guidance Scale",
                    info="์ด ์ˆซ์ž๊ฐ€ ํฌ๋ฉด ํด์ˆ˜๋ก ์ž…๋ ฅ ํ…์ŠคํŠธ๋ฅผ ๋” ๊ฐ•ํ•˜๊ฒŒ ์ด์šฉํ•˜์—ฌ ๋ชฝํƒ€์ฃผ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค."
                )
                submit_btn = gr.Button("์ƒ์„ฑํ•˜๊ธฐ")
            with gr.Column():
                image_output = gr.Image(label="์ƒ์„ฑ๋œ ๋ชฝํƒ€์ฃผ")

        submit_btn.click(fn=generate_image, inputs=[text_input, guidance_slider], outputs=image_output)
        gr.Examples(examples, inputs=[text_input, guidance_slider, image_output])
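    # Generation can take several minutes on the free CPU tier; enabling Gradio's request
    # queue (e.g. `demo.queue().launch()`) is one option for handling long-running requests.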
    demo.launch()