# KoFace-AI / app.py
# JuyeopDang's picture
# Update app.py
# c003113 verified
# raw
# history blame
# 1.78 kB
import gradio as gr
import numpy as np
import random
import torch
from helper.painter import Painter
from helper.trainer import Trainer
from helper.data_generator import DataGenerator
from helper.loader import Loader
from helper.cond_encoder import CLIPEncoder
from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
from clip.models.ko_clip import KoCLIPWrapper
from diffusion_model.sampler.ddim import DDIM
from diffusion_model.models.latent_diffusion_model import LatentDiffusionModel
from diffusion_model.network.unet import Unet
from diffusion_model.network.unet_wrapper import UnetWrapper
# import spaces #[uncomment to use ZeroGPU]
# Choose the execution device once, then derive the tensor dtype from it:
# half precision when running on CUDA, single precision on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
if __name__ == "__main__":
    # Script entry point: build the VAE from the composite config, download
    # the checkpoint file from the Hugging Face Hub, load its weights into the
    # VAE, and print the resulting module.
    from huggingface_hub import hf_hub_download

    CONFIG_PATH = 'configs/composite_config.yaml'
    repo_id = "JuyeopDang/KoFace-Diffusion"
    filename = "composite_epoch2472.pth"  # e.g. "pytorch_model.pt" or "model.pt"

    vae = VariationalAutoEncoder(CONFIG_PATH)

    try:
        # Download the file. Passing cache_dir would control where the
        # downloaded file is stored; by default it is saved under
        # ~/.cache/huggingface/hub.
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        print(f"파일 다운로드 또는 모델 로드 중 오류 발생: {e}")
        # Without the checkpoint there is nothing to load. Re-raise so the
        # real download error surfaces instead of a NameError on the unbound
        # `model_path` further down.
        raise
    else:
        print(f"모델 가중치 파일이 다음 경로에 다운로드되었습니다: {model_path}")

    # Map the checkpoint onto the device selected at module level so loading
    # also works on CPU-only hosts (a hard-coded 'cuda' fails there).
    # NOTE(review): torch.load unpickles the downloaded file; only point
    # repo_id/filename at a repository you trust.
    state_dict = torch.load(model_path, map_location=device)
    vae.load_state_dict(state_dict['model_state_dict'])
    print(vae)