# Source: Hugging Face Space file app.py, uploaded by openfree
# (commit d5beeda, "Update app.py", verified).
import time
import gradio as gr
import torch
from einops import rearrange, repeat
from PIL import Image
import numpy as np
import spaces # Hugging Face Spaces ์ž„ํฌํŠธ ์ถ”๊ฐ€
import threading
import sys
import os
# Module-level state shared by the initialization and generation callbacks.
model_initialized = False  # mirrors FluxGenerator.initialized; set by initialize_models()
flux_generator = None  # FluxGenerator instance; created by initialize_models()
# Status text shown in the UI; rewritten by the initialization code paths.
initialization_message = "๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘... ์ž ์‹œ๋งŒ ๊ธฐ๋‹ค๋ ค์ฃผ์„ธ์š”."
# Short citation text rendered as Markdown in the demo page.
_CITE_ = """PuLID: Person-under-Language Image Diffusion Model"""
# Check GPU availability and choose the device. Per the original note this is
# not meant to be called from the main process.
def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device.

    A notice is printed when falling back to the CPU.
    """
    if not torch.cuda.is_available():
        print("CUDA GPU๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. CPU๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.")
        return torch.device('cpu')
    return torch.device('cuda')
def get_models(name: str, device, offload: bool):
    """Load the flow model, autoencoder, T5 encoder and CLIP encoder.

    T5 and CLIP are always loaded on ``device``. The flow model and the
    autoencoder are loaded on ``"cpu"`` instead when ``offload`` is true.

    Returns:
        ``(model, ae, t5, clip_model)``. If importing the loaders or loading
        any component raises, the error is printed and four ``None`` values
        are returned instead.
    """
    try:
        # Lazy import: only pull in the loader helpers when they are needed.
        from flux.util import load_ae, load_clip, load_flow_model, load_t5

        print(f"๋ชจ๋ธ์„ {device}์— ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.")
        # Target for the two components that honor the offload flag.
        offload_target = "cpu" if offload else device
        text_encoder = load_t5(device, max_length=128)
        clip_encoder = load_clip(device)
        flow_model = load_flow_model(name, device=offload_target)
        flow_model.eval()
        autoencoder = load_ae(name, device=offload_target)
    except Exception as e:
        print(f"๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
        return None, None, None, None
    return flow_model, autoencoder, text_encoder, clip_encoder
class FluxGenerator:
    """Container for the FLUX model components and the PuLID pipeline.

    Construction only sets placeholders and never touches a device; all
    loading happens in :meth:`initialize`.
    """

    def __init__(self):
        # Device selection is deferred to initialize() so that constructing
        # the object does not query or initialize CUDA.
        self.device = None
        self.offload = False
        self.model_name = 'flux-dev'
        self.initialized = False
        self.model = None
        self.ae = None
        self.t5 = None
        self.clip_model = None
        self.pulid_model = None

    def initialize(self):
        """Load every model component and build the PuLID pipeline.

        On success ``self.initialized`` becomes ``True``. On any failure it is
        set to ``False``. In both cases the module-level
        ``initialization_message`` is updated with the outcome. Exceptions are
        caught and printed with their traceback rather than propagated.
        """
        global initialization_message
        try:
            # Lazy import of the pipeline class.
            from pulid.pipeline_flux import PuLIDPipeline

            # Choose the device only now (the original note says this runs
            # inside the GPU-decorated call).
            self.device = get_device()
            print("๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์‹œ์ž‘...")
            self.model, self.ae, self.t5, self.clip_model = get_models(
                self.model_name,
                device=self.device,
                offload=self.offload,
            )
            components = (self.model, self.ae, self.t5, self.clip_model)
            # Identity check: never invoke ``==`` on model objects.
            if any(part is None for part in components):
                print("๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์‹คํŒจ: ํ•˜๋‚˜ ์ด์ƒ์˜ ๋ชจ๋ธ ์ปดํฌ๋„ŒํŠธ๋ฅผ ๋กœ๋“œํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
                self.initialized = False
                initialization_message = "๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: ์ผ๋ถ€ ์ปดํฌ๋„ŒํŠธ๋ฅผ ๋กœ๋“œํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
                return
            # Derive both the device string and the weight dtype from the
            # device that was actually selected above, so they cannot diverge.
            on_cuda = self.device.type == 'cuda'
            self.pulid_model = PuLIDPipeline(
                self.model,
                self.device.type,
                weight_dtype=torch.bfloat16 if on_cuda else torch.float32,
            )
            self.pulid_model.load_pretrain()
            self.initialized = True
            print("๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์™„๋ฃŒ!")
            # Update the message shown in the UI.
            initialization_message = "๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ! ์ด์ œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค."
        except Exception as e:
            import traceback
            error_msg = f"๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            self.initialized = False
            # Update the message shown in the UI.
            initialization_message = f"๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ: {str(e)}"
# Explicit model initialization entry point, run under the Spaces GPU
# decorator (originally a background initializer for lazy loading).
@spaces.GPU(duration=60)
def initialize_models():
    """Create and initialize the global ``FluxGenerator``.

    Updates the module-level ``flux_generator``, ``model_initialized`` and
    ``initialization_message``. Exceptions are caught, printed with their
    traceback and reported through the returned message.

    Returns:
        The current ``initialization_message`` for display in the status box.
    """
    global flux_generator, model_initialized, initialization_message
    # Repeated clicks on the init button must not rebuild and reload models
    # that are already loaded.
    if model_initialized and flux_generator is not None:
        return initialization_message
    print("GPU ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ ๋‚ด์—์„œ ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์‹œ์ž‘...")
    try:
        # Model initialization.
        flux_generator = FluxGenerator()
        flux_generator.initialize()
        model_initialized = flux_generator.initialized
    except Exception as e:
        import traceback
        error_msg = f"์ดˆ๊ธฐํ™” ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        model_initialized = False
        initialization_message = f"๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์˜ค๋ฅ˜: {str(e)}"
    return initialization_message
# ๋ชจ๋ธ ์ƒํƒœ ํ™•์ธ ํ•จ์ˆ˜
def check_model_status():
return initialization_message
# Runs under the Spaces GPU decorator with a 120-second duration.
@spaces.GPU(duration=120)
@torch.inference_mode()
def generate_image(
    prompt: str,
    id_image,
    num_steps: int,
    guidance: float,
    seed,
    id_weight: float,
    neg_prompt: str,
    true_cfg: float,
    gamma: float,
    eta: float,
):
    """Generate an edited 512x512 image from an ID image and a text prompt.

    The ID image is encoded with the autoencoder, passed through
    ``rf_inversion`` and then ``rf_denoise``, and the result is decoded back
    to a PIL image.

    Args:
        prompt: Text prompt used for the denoising pass.
        id_image: PIL image; used both as the image to encode and as the
            source of the ID embedding. Required.
        num_steps: Number of sampling steps (coerced to ``int``).
        guidance: Guidance value forwarded to the sampler functions.
        seed: Seed as an int or string. ``-1`` or an unparsable value selects
            a random seed.
        id_weight: Weight forwarded together with the ID embedding.
        neg_prompt: Negative prompt; only prepared when ``true_cfg`` differs
            from 1.0.
        true_cfg: CFG scale; values other than 1.0 enable the negative branch.
        gamma: Forwarded to ``rf_inversion``.
        eta: Forwarded to ``rf_denoise``.

    Returns:
        ``(image, message)``. On success a PIL image and the seed used, as a
        string. On failure ``None`` and an error message (including the
        traceback for unexpected exceptions).
    """
    # Refuse to run until the models have been initialized.
    if not model_initialized:
        return None, "๋ชจ๋ธ ์ดˆ๊ธฐํ™”๊ฐ€ ์™„๋ฃŒ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ๋ฒ„ํŠผ์„ ๋ˆŒ๋Ÿฌ์ฃผ์„ธ์š”."
    # Nothing can be generated without an ID image.
    if id_image is None:
        return None, "์˜ค๋ฅ˜: ID ์ด๋ฏธ์ง€๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค."
    try:
        # Lazy imports of the project modules that are actually used.
        from flux.sampling import get_noise, get_schedule, prepare, rf_denoise, rf_inversion, unpack
        from flux.util import SamplingOptions
        from pulid.utils import resize_numpy_image_long, seed_everything

        # Fixed parameters.
        width = 512
        height = 512
        start_step = 0
        timestep_to_start_cfg = 1
        max_sequence_length = 128
        s = 0
        tau = 5
        flux_generator.t5.max_length = max_sequence_length
        device = flux_generator.device
        # bfloat16 on CUDA, float32 otherwise; reused for noise and autocast.
        dtype = torch.bfloat16 if device.type == 'cuda' else torch.float32

        # Seed handling: anything that is not a valid integer means "random".
        try:
            seed = int(seed)
        except (TypeError, ValueError, OverflowError):
            seed = -1
        if seed == -1:
            seed = None
        opts = SamplingOptions(
            prompt=prompt,
            width=width,
            height=height,
            # Sliders may deliver a float; the step count must be integral.
            num_steps=int(num_steps),
            guidance=guidance,
            seed=seed,
        )
        if opts.seed is None:
            opts.seed = torch.Generator(device="cpu").seed()
        seed_everything(opts.seed)
        print(f"Generating prompt: '{opts.prompt}' (seed={opts.seed})...")
        t0 = time.perf_counter()
        use_true_cfg = abs(true_cfg - 1.0) > 1e-6

        # 1) Prepare the input noise and pack it into 2x2 patches.
        noise = get_noise(
            num_samples=1,
            height=opts.height,
            width=opts.width,
            device=device,
            dtype=dtype,
            seed=opts.seed,
        )
        noise = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

        # Encode the ID image.
        encode_t0 = time.perf_counter()
        # Normalize to 3-channel RGB so RGBA / grayscale uploads do not reach
        # the autoencoder with an unexpected channel count.
        id_image = id_image.convert("RGB")
        id_image = id_image.resize((opts.width, opts.height), resample=Image.LANCZOS)
        x = torch.from_numpy(np.array(id_image).astype(np.float32))
        x = (x / 127.5) - 1.0  # scale pixel values from [0, 255] to [-1, 1]
        x = rearrange(x, "h w c -> 1 c h w")
        x = x.to(device)
        with torch.autocast(device_type=device.type, dtype=dtype):
            x = flux_generator.ae.encode(x)
        x = x.to(dtype)
        encode_t1 = time.perf_counter()
        print(f"Encoded in {encode_t1 - encode_t0:.2f} seconds.")
        timesteps = get_schedule(opts.num_steps, x.shape[-1] * x.shape[-2] // 4, shift=False)

        # 2) Prepare the text embeddings.
        inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=opts.prompt)
        inp_inversion = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt="")
        inp_neg = None
        if use_true_cfg:
            inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=neg_prompt)

        # 3) Build the ID embeddings (id_image is known to be non-None here).
        id_image = np.array(id_image)
        id_image = resize_numpy_image_long(id_image, 1024)
        id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(
            id_image, cal_uncond=use_true_cfg
        )
        y_0 = inp["img"].clone().detach()

        # Keyword arguments shared by the inversion and the denoising call.
        shared_kwargs = dict(
            timesteps=timesteps,
            guidance=opts.guidance,
            id=id_embeddings,
            id_weight=id_weight,
            start_step=start_step,
            uncond_id=uncond_id_embeddings,
            true_cfg=true_cfg,
            timestep_to_start_cfg=timestep_to_start_cfg,
            neg_txt=inp_neg["txt"] if use_true_cfg else None,
            neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
            neg_vec=inp_neg["vec"] if use_true_cfg else None,
            aggressive_offload=False,
        )

        # Image processing: inversion first, then denoising from its result.
        inverted = rf_inversion(
            flux_generator.model,
            **inp_inversion,
            **shared_kwargs,
            y_1=noise,
            gamma=gamma,
        )
        inp["img"] = inverted
        inp_inversion["img"] = inverted
        edited = rf_denoise(
            flux_generator.model,
            **inp,
            **shared_kwargs,
            y_0=y_0,
            eta=eta,
            s=s,
            tau=tau,
        )

        # Decode the result image.
        edited = unpack(edited.float(), opts.height, opts.width)
        with torch.autocast(device_type=device.type, dtype=dtype):
            edited = flux_generator.ae.decode(edited)
        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.2f} seconds.")

        # Convert to a PIL image: [-1, 1] floats back to uint8 pixels.
        edited = edited.clamp(-1, 1)
        edited = rearrange(edited[0], "c h w -> h w c")
        edited = Image.fromarray((127.5 * (edited + 1.0)).cpu().byte().numpy())
        return edited, str(opts.seed)
    except Exception as e:
        import traceback
        error_msg = f"์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, error_msg
def create_demo():
    """Build and return the Gradio Blocks UI.

    The page contains a status box with init/refresh buttons, the generation
    controls, the output widgets, two examples, and the wiring of the generate
    button to ``generate_image``. Nothing is launched here.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# PuLID: ์ธ๋ฌผ ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ๋„๊ตฌ")
        # Model status display; starts with the message current at build time.
        status_box = gr.Textbox(label="๋ชจ๋ธ ์ƒํƒœ", value=initialization_message)
        # Initialization button (explicit button instead of background
        # initialization).
        init_btn = gr.Button("๋ชจ๋ธ ์ดˆ๊ธฐํ™”")
        init_btn.click(fn=initialize_models, inputs=[], outputs=[status_box])
        refresh_btn = gr.Button("์ƒํƒœ ์ƒˆ๋กœ๊ณ ์นจ")
        refresh_btn.click(fn=check_model_status, inputs=[], outputs=[status_box])
        with gr.Row():
            # Left column: generation inputs.
            with gr.Column():
                prompt = gr.Textbox(label="ํ”„๋กฌํ”„ํŠธ", value="portrait, color, cinematic")
                id_image = gr.Image(label="ID ์ด๋ฏธ์ง€", type="pil")
                id_weight = gr.Slider(0.0, 1.0, 0.4, step=0.05, label="ID ๊ฐ€์ค‘์น˜")
                num_steps = gr.Slider(1, 24, 16, step=1, label="๋‹จ๊ณ„ ์ˆ˜")
                guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="๊ฐ€์ด๋˜์Šค")
                # Advanced options, collapsed by default.
                with gr.Accordion("๊ณ ๊ธ‰ ์˜ต์…˜", open=False):
                    neg_prompt = gr.Textbox(label="๋„ค๊ฑฐํ‹ฐ๋ธŒ ํ”„๋กฌํ”„ํŠธ", value="")
                    true_cfg = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="CFG ์Šค์ผ€์ผ")
                    seed = gr.Textbox(value="-1", label="์‹œ๋“œ (-1: ๋žœ๋ค)")
                    gr.Markdown("### ๊ธฐํƒ€ ์˜ต์…˜")
                    gamma = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="๊ฐ๋งˆ")
                    eta = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="์—ํƒ€")
                generate_btn = gr.Button("์ด๋ฏธ์ง€ ์ƒ์„ฑ")
            # Right column: outputs and citation text.
            with gr.Column():
                output_image = gr.Image(label="์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€")
                seed_output = gr.Textbox(label="๊ฒฐ๊ณผ/์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€")
                gr.Markdown(_CITE_)
        # Examples.
        with gr.Row():
            gr.Markdown("## ์˜ˆ์ œ")
            # Each row lists values in the same order as the ``inputs`` list
            # passed to gr.Examples below.
            example_inps = [
                [
                    'a portrait of a clown',
                    'example_inputs/unsplash/lhon-karwan-11tbHtK5STE-unsplash.jpg',
                    16, 3.5, "-1", 0.4, "", 3.5, 0.5, 0.8
                ],
                [
                    'a portrait of a zombie',
                    'example_inputs/unsplash/baruk-granda-cfLL_jHQ-Iw-unsplash.jpg',
                    16, 3.5, "42", 0.4, "", 3.5, 0.5, 0.8
                ]
            ]
            gr.Examples(
                examples=example_inps,
                inputs=[prompt, id_image, num_steps, guidance, seed,
                        id_weight, neg_prompt, true_cfg, gamma, eta]
            )
        # Gradio event wiring; input order matches generate_image's parameters.
        generate_btn.click(
            fn=generate_image,
            inputs=[
                prompt, id_image, num_steps, guidance, seed,
                id_weight, neg_prompt, true_cfg, gamma, eta
            ],
            outputs=[output_image, seed_output],
        )
    return demo
if __name__ == "__main__":
    import argparse

    # Command-line options as (flag, type, default) rows.
    parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
    for flag, value_type, fallback in (
        ('--version', str, 'v0.9.1'),
        ("--name", str, "flux-dev"),
        ("--port", int, 8080),
    ):
        parser.add_argument(flag, type=value_type, default=fallback)
    args = parser.parse_args()

    print("Hugging Face Spaces ํ™˜๊ฒฝ์—์„œ ์‹คํ–‰ ์ค‘์ž…๋‹ˆ๋‹ค. GPU ํ• ๋‹น์„ ์š”์ฒญํ•ฉ๋‹ˆ๋‹ค.")

    # CUDA is not initialized in the main process; models are loaded through
    # the explicit init button rather than a background thread.
    demo = create_demo()
    # launch() is called on the Blocks instance, not on the factory function.
    demo.launch(server_port=args.port, server_name="0.0.0.0")