import time
import gradio as gr
import torch
from einops import rearrange, repeat
from PIL import Image
import numpy as np
import spaces  # Hugging Face Spaces SDK (ZeroGPU decorator)
import threading
import sys
import os

# Global state
model_initialized = False
flux_generator = None
initialization_message = "Loading models... please wait."

# Short citation string
_CITE_ = """PuLID: Pure and Lightning ID Customization via Contrastive Alignment"""


# Check GPU availability and pick a device -- must not be called in the main process
def get_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        print("No CUDA GPU found; falling back to CPU.")
        return torch.device('cpu')


def get_models(name: str, device, offload: bool):
    try:
        # Lazily import only the modules that are needed here
        from flux.util import load_ae, load_clip, load_flow_model, load_t5
        print(f"Loading models onto {device}.")
        t5 = load_t5(device, max_length=128)
        clip_model = load_clip(device)
        model = load_flow_model(name, device="cpu" if offload else device)
        model.eval()
        ae = load_ae(name, device="cpu" if offload else device)
        return model, ae, t5, clip_model
    except Exception as e:
        print(f"Error while loading models: {e}")
        return None, None, None, None


class FluxGenerator:
    def __init__(self):
        # GPU initialization happens only inside the Spaces GPU decorator
        self.device = None  # do not pick a device at construction time
        self.offload = False
        self.model_name = 'flux-dev'
        self.initialized = False
        self.model = None
        self.ae = None
        self.t5 = None
        self.clip_model = None
        self.pulid_model = None

    def initialize(self):
        global initialization_message
        try:
            # Lazily import the required modules
            from pulid.pipeline_flux import PuLIDPipeline
            from flux.sampling import prepare
            # Pick the device here (this method is only called inside the GPU decorator)
            self.device = get_device()
            print("Starting model initialization...")
            self.model, self.ae, self.t5, self.clip_model = get_models(
                self.model_name,
                device=self.device,
                offload=self.offload,
            )
            if None in [self.model, self.ae, self.t5, self.clip_model]:
                print("Model initialization failed: one or more components could not be loaded.")
                self.initialized = False
                initialization_message = "Model loading failed: some components could not be loaded."
                return
            self.pulid_model = PuLIDPipeline(
                self.model,
                'cuda' if torch.cuda.is_available() else 'cpu',
                weight_dtype=torch.bfloat16 if self.device.type == 'cuda' else torch.float32
            )
            self.pulid_model.load_pretrain()
            self.initialized = True
            print("Model initialization complete!")
            # Update the UI status message
            initialization_message = "Models loaded! You can now generate images."
        except Exception as e:
            import traceback
            error_msg = f"Error during model initialization: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            self.initialized = False
            # Update the UI status message
            initialization_message = f"Model loading failed: {str(e)}"


# Lazy-loading initializer -- runs inside the Spaces GPU decorator instead of a background thread
@spaces.GPU
def initialize_models():
    global flux_generator, model_initialized, initialization_message
    print("Starting model initialization inside the GPU decorator...")
    try:
        # Lazy imports (also validates that the flux/pulid packages are importable)
        from flux.sampling import denoise, get_noise, get_schedule, prepare, rf_denoise, rf_inversion, unpack
        from flux.util import SamplingOptions
        from pulid.utils import resize_numpy_image_long, seed_everything
        # Build and initialize the generator
        flux_generator = FluxGenerator()
        flux_generator.initialize()
        model_initialized = flux_generator.initialized
    except Exception as e:
        import traceback
        error_msg = f"Error during initialization: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        model_initialized = False
        initialization_message = f"Model initialization error: {str(e)}"
    return initialization_message


# Report the current model status to the UI
def check_model_status():
    return initialization_message


# Spaces GPU decorator: requests the GPU for up to 120 seconds per call
@spaces.GPU(duration=120)
def generate_image(
    prompt: str,
    id_image,
    num_steps: int,
    guidance: float,
    seed,
    id_weight: float,
    neg_prompt: str,
    true_cfg: float,
    gamma: float,
    eta: float,
):
    global flux_generator, model_initialized
    # Refuse to run if the models have not been initialized yet
    if not model_initialized:
        return None, "Model initialization has not finished. Please press the 'Initialize models' button."
    # An ID image is required
    if id_image is None:
        return None, "Error: an ID image is required."
    try:
        # Lazily import the required modules
        from flux.sampling import denoise, get_noise, get_schedule, prepare, rf_denoise, rf_inversion, unpack
        from flux.util import SamplingOptions
        from pulid.utils import resize_numpy_image_long, seed_everything
        # Fixed parameters
        width = 512
        height = 512
        start_step = 0
        timestep_to_start_cfg = 1
        max_sequence_length = 128
        s = 0
        tau = 5
        flux_generator.t5.max_length = max_sequence_length
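        # Capping the T5 encoder at 128 tokens keeps prompt encoding cheap; longer prompts are truncated.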
        # Seed handling: accept an integer string, treat anything else (or -1) as "random"
        try:
            seed = int(seed)
        except (TypeError, ValueError):
            seed = -1
        if seed == -1:
            seed = None
        opts = SamplingOptions(
            prompt=prompt,
            width=width,
            height=height,
            num_steps=num_steps,
            guidance=guidance,
            seed=seed,
        )
        if opts.seed is None:
            opts.seed = torch.Generator(device="cpu").seed()
        seed_everything(opts.seed)
        print(f"Generating prompt: '{opts.prompt}' (seed={opts.seed})...")
        t0 = time.perf_counter()
        use_true_cfg = abs(true_cfg - 1.0) > 1e-6
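        # True CFG is only enabled when the slider moves away from 1.0; it adds a separate
        # negative-prompt branch that is fed into rf_inversion/rf_denoise below.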
        # 1) Prepare the input noise
        noise = get_noise(
            num_samples=1,
            height=opts.height,
            width=opts.width,
            device=flux_generator.device,
            dtype=torch.bfloat16 if flux_generator.device.type == 'cuda' else torch.float32,
            seed=opts.seed,
        )
        bs, c, h, w = noise.shape
        noise = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
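        # Pack the noise latent into 2x2 patches: each token carries c*2*2 channels,
        # matching the token layout the FLUX transformer expects.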
        if noise.shape[0] == 1 and bs > 1:
            noise = repeat(noise, "1 ... -> bs ...", bs=bs)
        # Encode the ID image with the autoencoder
        encode_t0 = time.perf_counter()
        id_image = id_image.resize((opts.width, opts.height), resample=Image.LANCZOS)
        x = torch.from_numpy(np.array(id_image).astype(np.float32))
        x = (x / 127.5) - 1.0
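        # Map uint8 pixel values in [0, 255] to [-1, 1] before VAE encoding.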
        x = rearrange(x, "h w c -> 1 c h w")
        x = x.to(flux_generator.device)
        dtype = torch.bfloat16 if flux_generator.device.type == 'cuda' else torch.float32
        with torch.autocast(device_type=flux_generator.device.type, dtype=dtype):
            x = flux_generator.ae.encode(x)
        x = x.to(dtype)
        encode_t1 = time.perf_counter()
        print(f"Encoded in {encode_t1 - encode_t0:.2f} seconds.")
        timesteps = get_schedule(opts.num_steps, x.shape[-1] * x.shape[-2] // 4, shift=False)
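        # The schedule length is keyed to the number of packed latent tokens (latent H*W / 4);
        # shift=False presumably keeps the plain rectified-flow schedule without the resolution-dependent shift.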
        # 2) Prepare the text embeddings
        inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=opts.prompt)
        inp_inversion = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt="")
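        # The inversion pass uses an empty prompt so the source image is inverted to noise
        # without being steered toward the edit prompt.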
        inp_neg = None
        if use_true_cfg:
            inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=neg_prompt)
        # 3) Build the ID embeddings
        id_embeddings = None
        uncond_id_embeddings = None
        if id_image is not None:
            id_image = np.array(id_image)
            id_image = resize_numpy_image_long(id_image, 1024)
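            # resize_numpy_image_long scales the longest side to 1024 px before face-ID feature extraction.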
            id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
        y_0 = inp["img"].clone().detach()
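        # y_0 keeps a copy of the clean source latent; rf_denoise appears to use it
        # (together with eta, s and tau) to pull the edited trajectory back toward the original image.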
        # Editing pipeline: invert the source latent to noise, then denoise with the edit prompt
        inverted = rf_inversion(
            flux_generator.model,
            **inp_inversion,
            timesteps=timesteps,
            guidance=opts.guidance,
            id=id_embeddings,
            id_weight=id_weight,
            start_step=start_step,
            uncond_id=uncond_id_embeddings,
            true_cfg=true_cfg,
            timestep_to_start_cfg=timestep_to_start_cfg,
            neg_txt=inp_neg["txt"] if use_true_cfg else None,
            neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
            neg_vec=inp_neg["vec"] if use_true_cfg else None,
            aggressive_offload=False,
            y_1=noise,
            gamma=gamma
        )
        inp["img"] = inverted
        inp_inversion["img"] = inverted
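        # The inverted latent now replaces the source latent as the starting point for the guided denoise pass.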
        edited = rf_denoise(
            flux_generator.model,
            **inp,
            timesteps=timesteps,
            guidance=opts.guidance,
            id=id_embeddings,
            id_weight=id_weight,
            start_step=start_step,
            uncond_id=uncond_id_embeddings,
            true_cfg=true_cfg,
            timestep_to_start_cfg=timestep_to_start_cfg,
            neg_txt=inp_neg["txt"] if use_true_cfg else None,
            neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
            neg_vec=inp_neg["vec"] if use_true_cfg else None,
            aggressive_offload=False,
            y_0=y_0,
            eta=eta,
            s=s,
            tau=tau,
        )
        # Decode the resulting image
        edited = unpack(edited.float(), opts.height, opts.width)
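        # unpack reverses the 2x2 token packing back into a spatial latent the VAE decoder can consume.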
        with torch.autocast(device_type=flux_generator.device.type, dtype=dtype):
            edited = flux_generator.ae.decode(edited)
        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.2f} seconds.")
        # Convert to a PIL image
        edited = edited.clamp(-1, 1)
        edited = rearrange(edited[0], "c h w -> h w c")
        edited = Image.fromarray((127.5 * (edited + 1.0)).cpu().byte().numpy())
        return edited, str(opts.seed)
    except Exception as e:
        import traceback
        error_msg = f"Error during image generation: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, error_msg


def create_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# PuLID: Portrait Editing Tool")
        # Show the model status
        status_box = gr.Textbox(label="Model status", value=initialization_message)
        # Explicit initialization button (instead of initializing in a background thread)
        init_btn = gr.Button("Initialize models")
        init_btn.click(fn=initialize_models, inputs=[], outputs=[status_box])
        refresh_btn = gr.Button("Refresh status")
        refresh_btn.click(fn=check_model_status, inputs=[], outputs=[status_box])
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
                id_image = gr.Image(label="ID image", type="pil")
                id_weight = gr.Slider(0.0, 1.0, 0.4, step=0.05, label="ID weight")
                num_steps = gr.Slider(1, 24, 16, step=1, label="Number of steps")
                guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance")
                with gr.Accordion("Advanced options", open=False):
                    neg_prompt = gr.Textbox(label="Negative prompt", value="")
                    true_cfg = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="CFG scale")
                    seed = gr.Textbox(value="-1", label="Seed (-1: random)")
                    gr.Markdown("### Other options")
                    gamma = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="Gamma")
                    eta = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Eta")
                generate_btn = gr.Button("Generate image")
            with gr.Column():
                output_image = gr.Image(label="Generated image")
                seed_output = gr.Textbox(label="Result / error message")
                gr.Markdown(_CITE_)
        # Examples
        with gr.Row():
            gr.Markdown("## Examples")
            example_inps = [
                [
                    'a portrait of a clown',
                    'example_inputs/unsplash/lhon-karwan-11tbHtK5STE-unsplash.jpg',
                    16, 3.5, "-1", 0.4, "", 3.5, 0.5, 0.8
                ],
                [
                    'a portrait of a zombie',
                    'example_inputs/unsplash/baruk-granda-cfLL_jHQ-Iw-unsplash.jpg',
                    16, 3.5, "42", 0.4, "", 3.5, 0.5, 0.8
                ]
            ]
            gr.Examples(
                examples=example_inps,
                inputs=[prompt, id_image, num_steps, guidance, seed,
                        id_weight, neg_prompt, true_cfg, gamma, eta]
            )
        # Wire up the Gradio events
        generate_btn.click(
            fn=generate_image,
            inputs=[
                prompt, id_image, num_steps, guidance, seed,
                id_weight, neg_prompt, true_cfg, gamma, eta
            ],
            outputs=[output_image, seed_output],
        )
    return demo


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
    parser.add_argument("--version", type=str, default="v0.9.1")
    parser.add_argument("--name", type=str, default="flux-dev")
    parser.add_argument("--port", type=int, default=8080)
    args = parser.parse_args()
    print("Running in a Hugging Face Spaces environment; GPU time is requested per call.")
    # Do not initialize CUDA in the main process
    # Models are initialized via the explicit button instead of a background thread
    demo = create_demo()
    # Fixed: create_demo.launch() -> demo.launch()
    demo.launch(server_name="0.0.0.0", server_port=args.port)