import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
from helper.cond_encoder import CLIPEncoder
from helper.loader import Loader
from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
from clip.models.ko_clip import KoCLIPWrapper
from diffusion_model.sampler.ddim import DDIM
from diffusion_model.models.latent_diffusion_model import LatentDiffusionModel
from diffusion_model.network.unet import Unet
from diffusion_model.network.unet_wrapper import UnetWrapper
from huggingface_hub import hf_hub_download

# import spaces  # [uncomment to use ZeroGPU]
device = "cuda" if torch.cuda.is_available() else "cpu"
loader = Loader(device)
repo_id = "JuyeopDang/KoFace-Diffusion"
CONFIG_PATH = 'configs/composite_config.yaml'

if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32
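# Note: torch_dtype is selected above but never applied below. If half-precision
# inference is wanted on GPU, the assembled model could be cast explicitly,
# e.g. dm.to(device=device, dtype=torch_dtype). That cast is a sketch, not part
# of the original pipeline.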
def load_model_from_HF(model, repo_id, filename, is_ema=False):
    """Download a checkpoint from the Hugging Face Hub and load its weights."""
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        print(f"Error while downloading the file or loading the model: {e}")
        return None
    model_path = model_path[:-4]  # strip the ".pth" extension
    model = loader.model_load(model_path, model, is_ema=is_ema, print_dict=False)
    return model
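# The [:-4] slice assumes hf_hub_download returns a local path ending in ".pth"
# and that Loader.model_load re-appends the extension itself; both are
# assumptions inferred from how the helper is called below.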
if __name__ == "__main__":
    # Assemble the latent diffusion pipeline: VAE latent space, DDIM sampler,
    # KoCLIP text encoder for conditioning, and a U-Net denoising backbone.
    vae = VariationalAutoEncoder(CONFIG_PATH)
    sampler = DDIM(CONFIG_PATH)
    clip = KoCLIPWrapper()
    cond_encoder = CLIPEncoder(clip, CONFIG_PATH)
    network = UnetWrapper(Unet, CONFIG_PATH, cond_encoder)
    dm = LatentDiffusionModel(network, sampler, vae)

    # Load pretrained weights; the CLIP and LDM checkpoints are EMA weights.
    vae = load_model_from_HF(vae, repo_id, "composite_epoch2472.pth", False)
    clip = load_model_from_HF(clip, repo_id, "asian-composite-fine-tuned-koclip.pth", True)
    dm = load_model_from_HF(dm, repo_id, "asian-composite-clip-ldm.pth", True)
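    # Optional sanity check (a sketch; assumes the wrappers are torch.nn.Modules
    # exposing .parameters()):
    # for name, m in [("vae", vae), ("clip", clip), ("ldm", dm)]:
    #     print(name, sum(p.numel() for p in m.parameters()))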
    def generate_image(y, gamma):
        # Gradio supplies only the textbox and slider values, so the diffusion
        # model is taken from the enclosing scope instead of a third parameter.
        images = dm.sample(2, y=y, gamma=gamma)
        if isinstance(images, torch.Tensor):
            images = images.permute(0, 2, 3, 1)  # NCHW -> NHWC
            images = images.detach().cpu().numpy()
        images = np.clip(images / 2 + 0.5, 0, 1)  # rescale from [-1, 1] to [0, 1]
        return Image.fromarray((images[0] * 255).astype(np.uint8))
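    # Usage sketch outside Gradio (hypothetical prompt; gamma is assumed to be
    # the classifier-free guidance scale consumed by dm.sample):
    #     img = generate_image("a young face with short dark hair", 3.0)
    #     img.save("sample.png")
    # To run on ZeroGPU hardware, uncomment `import spaces` above and decorate
    # this function with @spaces.GPU.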
    demo = gr.Interface(
        generate_image,
        inputs=["textbox", gr.Slider(0, 10)],
        outputs=["image"],
    )
    demo.launch()