# KoFace-AI / app.py
# (Hugging Face Spaces page residue, kept as a comment so the file parses:
#  JuyeopDang's picture — Update app.py — c4c4167 verified — raw · history blame · 1.63 kB)
import gradio as gr
import numpy as np
import random
import torch
from helper.cond_encoder import CLIPEncoder
from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
from clip.models.ko_clip import KoCLIPWrapper
from diffusion_model.sampler.ddim import DDIM
from diffusion_model.models.latent_diffusion_model import LatentDiffusionModel
from diffusion_model.network.unet import Unet
from diffusion_model.network.unet_wrapper import UnetWrapper
# import spaces #[uncomment to use ZeroGPU]
# Pick the compute device and a matching dtype: half precision on GPU for
# speed and memory, full precision on CPU (fp16 ops are poorly supported there).
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
if __name__ == "__main__":
    from huggingface_hub import hf_hub_download

    CONFIG_PATH = 'configs/composite_config.yaml'
    repo_id = "JuyeopDang/KoFace-Diffusion"
    filename = "composite_epoch2472.pth"  # e.g. "pytorch_model.pt" or "model.pt"

    # Build the VAE from its YAML config; weights are loaded below.
    vae = VariationalAutoEncoder(CONFIG_PATH)
    try:
        # Download the checkpoint from the Hub.
        # Passing cache_dir= would control where the file is stored;
        # by default it lands in ~/.cache/huggingface/hub.
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        print(f"λͺ¨λΈ κ°€μ€‘μΉ˜ 파일이 λ‹€μŒ κ²½λ‘œμ— λ‹€μš΄λ‘œλ“œλ˜μ—ˆμŠ΅λ‹ˆλ‹€: {model_path}")
        # Load onto whichever device was selected at module top — the original
        # hard-coded map_location='cuda', which crashes on CPU-only hosts.
        state_dict = torch.load(model_path, map_location=device)
        vae.load_state_dict(state_dict['model_state_dict'])
    except Exception as e:
        # Fail loudly instead of continuing: the original swallowed this error
        # and then hit a NameError on the undefined `model_path` right after.
        raise SystemExit(f"파일 λ‹€μš΄λ‘œλ“œ λ˜λŠ” λͺ¨λΈ λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
    print(vae)