# KoFace-AI / app.py
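# Gradio demo for JuyeopDang/KoFace-Diffusion: a Korean text prompt is encoded
# with KoCLIP and conditions a latent diffusion model (DDIM sampler + VAE
# decoder) that generates a face-composite image.
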
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
from helper.cond_encoder import CLIPEncoder
from helper.loader import Loader
from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
from clip.models.ko_clip import KoCLIPWrapper
from diffusion_model.sampler.ddim import DDIM
from diffusion_model.models.latent_diffusion_model import LatentDiffusionModel
from diffusion_model.network.unet import Unet
from diffusion_model.network.unet_wrapper import UnetWrapper
from huggingface_hub import hf_hub_download
# import spaces #[uncomment to use ZeroGPU]
device = "cuda" if torch.cuda.is_available() else "cpu"
loader = Loader(device)
repo_id = "JuyeopDang/KoFace-Diffusion"
CONFIG_PATH = 'configs/composite_config.yaml'
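
# Half precision halves memory use on GPU; most CPUs lack fast float16
# kernels, so fall back to float32 there.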
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32
def load_model_from_HF(model, repo_id, filename, is_ema=False):
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        print(f"Error while downloading the file or loading the model: {e}")
        return model  # download failed; return the model unchanged
    model_path = model_path[:-4]  # drop the ".pth" suffix; Loader seems to expect the bare path
    model = loader.model_load(model_path, model, is_ema=is_ema, print_dict=False)
    return model
if __name__ == "__main__":
    # Assemble the pipeline: a VAE defines the latent space, DDIM does the
    # sampling, and a KoCLIP text encoder conditions the U-Net.
    vae = VariationalAutoEncoder(CONFIG_PATH)
    sampler = DDIM(CONFIG_PATH)
    clip = KoCLIPWrapper()
    cond_encoder = CLIPEncoder(clip, CONFIG_PATH)
    network = UnetWrapper(Unet, CONFIG_PATH, cond_encoder)
    dm = LatentDiffusionModel(network, sampler, vae)

    # Load pretrained weights (EMA weights for KoCLIP and the diffusion model).
    vae = load_model_from_HF(vae, repo_id, "composite_epoch2472.pth", False)
    clip = load_model_from_HF(clip, repo_id, "asian-composite-fine-tuned-koclip.pth", True)
    dm = load_model_from_HF(dm, repo_id, "asian-composite-clip-ldm.pth", True)
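
    # @spaces.GPU  # [uncomment, with the "import spaces" line above, to use ZeroGPU]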
    def generate_image(y, gamma):
        # Sample a small batch conditioned on the prompt y; gamma is
        # presumably the classifier-free guidance scale.
        images = dm.sample(2, y=y, gamma=gamma)
        images = images.permute(0, 2, 3, 1)  # NCHW -> NHWC for NumPy/PIL
        if isinstance(images, torch.Tensor):
            images = images.detach().cpu().numpy()
        images = np.clip(images / 2 + 0.5, 0, 1)  # map [-1, 1] to [0, 1]
        return Image.fromarray((images[0] * 255).astype(np.uint8))
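
    # Hypothetical direct use, bypassing the UI (KoCLIP expects Korean prompts):
    #   generate_image("<Korean prompt>", 3.0).save("sample.png")
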
    demo = gr.Interface(
        generate_image,
        inputs=["textbox", gr.Slider(0, 10)],
        outputs=["image"],
    )
    demo.launch()