import gradio as gr
import numpy as np
import random
import torch
from PIL import Image as im
from helper.cond_encoder import CLIPEncoder
from helper.loader import Loader
from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
from clip.models.ko_clip import KoCLIPWrapper
from diffusion_model.sampler.ddim import DDIM
from diffusion_model.models.latent_diffusion_model import LatentDiffusionModel
from diffusion_model.network.unet import Unet
from diffusion_model.network.unet_wrapper import UnetWrapper
from huggingface_hub import hf_hub_download

# import spaces  # [uncomment to use ZeroGPU]
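# If this Space were switched to ZeroGPU, the usual pattern (a sketch of the
# standard `spaces` decorator API; not enabled here) is to decorate the
# GPU-bound inference function defined further below:
#   import spaces
#   @spaces.GPU
#   def generate_image(y, gamma):
#       ...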
device = "cuda" if torch.cuda.is_available() else "cpu"
loader = Loader(device)
repo_id = "JuyeopDang/KoFace-Diffusion"
CONFIG_PATH = 'configs/composite_config.yaml'

if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32
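# Note: torch_dtype is chosen above but never applied below. If half precision
# were wanted on GPU, one possible approach (an assumption, not part of this
# Space as written) would be to run sampling under autocast, e.g.:
#   with torch.autocast(device_type="cuda", dtype=torch_dtype):
#       images = dm.sample(2, y=y, gamma=gamma)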
def load_model_from_HF(model, repo_id, filename, is_ema=False):
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        print(f"Error while downloading the file or loading the model: {e}")
        return None
    model_path = model_path[:-4]  # strip the ".pth" extension before passing the path to the custom loader
    model = loader.model_load(model_path, model, is_ema=is_ema, print_dict=False)
    return model
examples = [
    ['Guidance Scale์ด 0์ธ ๊ฒฝ์ฐ, Text Condition์ ์ด์ฉํ์ง ์๊ณ  ์์์ ์ผ๊ตด์ ์์ฑํฉ๋๋ค.', 0, im.open("./assets/Example0.webp")],
    ['๋๊ทธ๋ ์ผ๊ตด์ ํ์ฑํ๊ณ  ์ด์ง ๊ธด ์ปคํธ๋จธ๋ฆฌ์ ๊ฑฐ์ ๋์ค์ง ์์ ์ฑ๋๊ฐ ๋จ์ฑ๋ณด๋ค๋ ์ฌ์ฑ์ ์ธ ๋ถ์๊ธฐ๋ฅผ ๋๋ค. ๋๋ ทํ ๋๊ณผ ์ ์ ์ด ์ธ๋ฌผ์ ์ฌ์ธํจ์ ๋์ฑ ๋ถ๊ฐ์ํค๊ณ  ์ง์ ์ผ๋ก ๋ณด์ด๊ฒ ๋ง๋ญ๋๋ค.', 2, im.open("./assets/Example1.webp")],
    ['์ฝ๊ฐ ํฐ ๋จ์๋ค.', 4, im.open("./assets/Example2.webp")],
    ['์์ค๋ง ํ๋ง๊ฐ์ด ํฌ๊ณ  ๋ถํ ๋จธ๋ฆฌ์ ์ด๋ชฉ๊ตฌ๋น๊ฐ ์ฝ๊ฐ ์ฌ์ฑ์ค๋ฌ์ ๋ณด์ด๋ ์ผ๊ตด์ด๋ฉฐ, ์ ์ง์์ ์ฃผ๋ฆ๊ณผ ์ฒ์ง ๋์์ ์ฐ๋ฅ์ด ๋๊บผ์ง๋ค. ํฐ ๋๊ณผ ๋์ ํ์ ์ฃผ๋ฆ์ด ์๋ ๊ณ ์ง์ค๋ฐ ์ ๋งค๋ ์๊ณ ํ ์์น์ฃผ์์ ๊ฐ์ ๋๋์ด ์๋ค.', 4, im.open("./assets/Example3.webp")],
    ['ํค์ด ์์ง์ด ๋ค์ ๋ฏธ์ํ์ฌ ์ธ๋ จ๋ ๋๋์ด ๋ถ์กฑํ๋ค. ๋๋์ด ์ฌ๋ผ๊ฐ ์์ด ๋๋น์ด ๋ ์นด๋กญ๊ณ  ์๋ฏผํด ๋ณด์ธ๋ค. ์ ์ฒด์ ์ผ๋ก ๋ง๋ฅธ ์ฒด๊ฒฉ์ผ ๊ฒ์ผ๋ก ๋ณด์ด๋ฉฐ, ์ ๋ฌด ์ฒ๋ฆฌ ๋ฅ๋ ฅ์ ๋ฐ์ด๋๊ฒ ์ง๋ง, ๊ต์ฐ ๊ด๊ณ๋ ์๋งํ์ง ์์ ์๋ ์๋ค.', 7, im.open("./assets/Example4.webp")]
]
if __name__ == "__main__":
    # Build the model components from the shared config, then pull the trained
    # weights from the Hugging Face Hub.
    vae = VariationalAutoEncoder(CONFIG_PATH)
    sampler = DDIM(CONFIG_PATH)
    clip = KoCLIPWrapper()
    cond_encoder = CLIPEncoder(clip, CONFIG_PATH)
    network = UnetWrapper(Unet, CONFIG_PATH, cond_encoder)
    dm = LatentDiffusionModel(network, sampler, vae)

    vae = load_model_from_HF(vae, repo_id, "composite_epoch2472.pth", False)
    clip = load_model_from_HF(clip, repo_id, "asian-composite-fine-tuned-koclip.pth", True)
    dm = load_model_from_HF(dm, repo_id, "asian-composite-clip-ldm.pth", True)
    def generate_image(y, gamma):
        # Sample a small batch conditioned on the Korean text prompt `y`;
        # `gamma` is the guidance scale exposed by the UI slider.
        images = dm.sample(2, y=y, gamma=gamma)
        images = images.permute(0, 2, 3, 1)  # NCHW -> NHWC for image conversion
        if isinstance(images, torch.Tensor):
            images = images.detach().cpu().numpy()
        images = np.clip(images / 2 + 0.5, 0, 1)  # rescale from [-1, 1] to [0, 1]
        return im.fromarray((images[0] * 255).astype(np.uint8))  # return the first sample as a PIL image
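    # For reference, the guidance scale is assumed here to act as in standard
    # classifier-free guidance (the sampler's internals are not shown in this file):
    #   eps = eps_uncond + gamma * (eps_cond - eps_uncond)
    # so gamma = 0 ignores the text condition and larger values follow it more closely.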
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # KoFace AI - ํ๊ตญ์ธ ๋ชฝํ์ฃผ ์์ฑ๊ธฐ
            * ๋ชฝํ์ฃผ๋ฅผ ์ค๋ชํ๋ ํ์คํธ๋ฅผ ์๋ ฅํ๊ณ  '์์ฑํ๊ธฐ' ๋ฒํผ์ ๋๋ฌ์ฃผ์ธ์.
            * **์ฐธ๊ณ **: ํ์ฌ ๋ฌด๋ฃ ํฐ์ด๋ฅผ ์ด์ฉํ๊ณ  ์์ด, **์์ฑ์ ์ฝ 10๋ถ ์ ๋์ ์๊ฐ**์ด ๊ฑธ๋ฆด ์ ์์ต๋๋ค.
            * ์ด AI๋ Latent Diffusion Model์ ์ง์  ๊ตฌํํ์ฌ ๋ง๋ค์์ต๋๋ค. ์์ธํ ์ฝ๋๋ ์คํ ๊ฒฐ๊ณผ๋ GitHub ์ ์ฅ์๋ฅผ ์ฐธ๊ณ ํด ์ฃผ์ธ์.
            [GitHub ์ ์ฅ์ ๋ฐ๋ก๊ฐ๊ธฐ](https://github.com/Won-Seong/simple-latent-diffusion-model) | [์ฌ์ฉํ ๋ฐ์ดํฐ์ธํธ](https://www.aihub.or.kr/aihubdata/data/view.do?dataSetSn=618)
            """
        )
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    placeholder="๋ชฝํ์ฃผ๋ฅผ ์ค๋ชํ๋ ํ์คํธ๋ฅผ ์๋ ฅํ์ธ์.",
                    label="Text Condition"
                )
                guidance_slider = gr.Slider(
                    0, 10, value=2,
                    label="Guidance Scale",
                    info="์ด ์ซ์๊ฐ ํฌ๋ฉด ํด์๋ก ์๋ ฅ ํ์คํธ๋ฅผ ๋ ๊ฐํ๊ฒ ์ด์ฉํ์ฌ ๋ชฝํ์ฃผ๋ฅผ ์์ฑํฉ๋๋ค."
                )
                submit_btn = gr.Button("์์ฑํ๊ธฐ")
            with gr.Column():
                image_output = gr.Image(label="์์ฑ๋ ๋ชฝํ์ฃผ")

        submit_btn.click(fn=generate_image, inputs=[text_input, guidance_slider], outputs=image_output)
        gr.Examples(examples, inputs=[text_input, guidance_slider, image_output])

    demo.launch()