import time import gradio as gr import torch from einops import rearrange, repeat from PIL import Image import numpy as np import spaces # Hugging Face Spaces 임포트 추가 import threading import sys import os # 전역 변수 정의 model_initialized = False flux_generator = None initialization_message = "모델 로딩 중... 잠시만 기다려주세요." # 간단한 인용 정보 추가 _CITE_ = """PuLID: Person-under-Language Image Diffusion Model""" # GPU 사용 가능 여부 확인 및 장치 설정 - 메인 프로세스에서는 호출하지 않음 def get_device(): if torch.cuda.is_available(): return torch.device('cuda') else: print("CUDA GPU를 찾을 수 없습니다. CPU를 사용합니다.") return torch.device('cpu') def get_models(name: str, device, offload: bool): try: # 필요한 모듈만 지연 임포트 from flux.util import load_ae, load_clip, load_flow_model, load_t5 print(f"모델을 {device}에 로드합니다.") t5 = load_t5(device, max_length=128) clip_model = load_clip(device) model = load_flow_model(name, device="cpu" if offload else device) model.eval() ae = load_ae(name, device="cpu" if offload else device) return model, ae, t5, clip_model except Exception as e: print(f"모델 로드 중 오류 발생: {e}") return None, None, None, None class FluxGenerator: def __init__(self): # GPU 초기화는 Spaces GPU 데코레이터 안에서만 수행 self.device = None # 초기화 시점에는 device를 할당하지 않음 self.offload = False self.model_name = 'flux-dev' self.initialized = False self.model = None self.ae = None self.t5 = None self.clip_model = None self.pulid_model = None def initialize(self): global initialization_message try: # 필요한 모듈 지연 임포트 from pulid.pipeline_flux import PuLIDPipeline from flux.sampling import prepare # 이 시점에서 장치 설정 (GPU 데코레이터 내에서만 호출됨) self.device = get_device() print("모델 초기화 시작...") self.model, self.ae, self.t5, self.clip_model = get_models( self.model_name, device=self.device, offload=self.offload, ) if None in [self.model, self.ae, self.t5, self.clip_model]: print("모델 초기화 실패: 하나 이상의 모델 컴포넌트를 로드할 수 없습니다.") self.initialized = False initialization_message = "모델 로드 실패: 일부 컴포넌트를 로드할 수 없습니다." return self.pulid_model = PuLIDPipeline( self.model, 'cuda' if torch.cuda.is_available() else 'cpu', weight_dtype=torch.bfloat16 if self.device.type == 'cuda' else torch.float32 ) self.pulid_model.load_pretrain() self.initialized = True print("모델 초기화 완료!") # UI 메시지 업데이트 initialization_message = "모델 로딩 완료! 이제 이미지를 생성할 수 있습니다." except Exception as e: import traceback error_msg = f"모델 초기화 중 오류 발생: {str(e)}\n{traceback.format_exc()}" print(error_msg) self.initialized = False # UI 메시지 업데이트 initialization_message = f"모델 로딩 실패: {str(e)}" # 지연 로딩을 위한 백그라운드 초기화 함수 - GPU 데코레이터로 변경 @spaces.GPU(duration=60) def initialize_models(): global flux_generator, model_initialized, initialization_message print("GPU 데코레이터 내에서 모델 초기화 시작...") try: # 지연 임포트 from flux.sampling import denoise, get_noise, get_schedule, prepare, rf_denoise, rf_inversion, unpack from flux.util import SamplingOptions from pulid.utils import resize_numpy_image_long, seed_everything # 모델 초기화 flux_generator = FluxGenerator() flux_generator.initialize() model_initialized = flux_generator.initialized except Exception as e: import traceback error_msg = f"초기화 중 오류 발생: {str(e)}\n{traceback.format_exc()}" print(error_msg) model_initialized = False initialization_message = f"모델 초기화 오류: {str(e)}" return initialization_message # 모델 상태 확인 함수 def check_model_status(): return initialization_message # Spaces GPU 데코레이터 추가 (120초 GPU 사용) @spaces.GPU(duration=120) @torch.inference_mode() def generate_image( prompt: str, id_image, num_steps: int, guidance: float, seed, id_weight: float, neg_prompt: str, true_cfg: float, gamma: float, eta: float, ): global flux_generator, model_initialized # 모델이 초기화되지 않았으면 오류 메시지 반환 if not model_initialized: return None, "모델 초기화가 완료되지 않았습니다. 모델 초기화 버튼을 눌러주세요." # ID 이미지가 없으면 실행 불가 if id_image is None: return None, "오류: ID 이미지가 필요합니다." try: # 필요한 모듈 지연 임포트 from flux.sampling import denoise, get_noise, get_schedule, prepare, rf_denoise, rf_inversion, unpack from flux.util import SamplingOptions from pulid.utils import resize_numpy_image_long, seed_everything # 고정 매개변수 width = 512 height = 512 start_step = 0 timestep_to_start_cfg = 1 max_sequence_length = 128 s = 0 tau = 5 flux_generator.t5.max_length = max_sequence_length # 시드 설정 try: seed = int(seed) except: seed = -1 if seed == -1: seed = None opts = SamplingOptions( prompt=prompt, width=width, height=height, num_steps=num_steps, guidance=guidance, seed=seed, ) if opts.seed is None: opts.seed = torch.Generator(device="cpu").seed() seed_everything(opts.seed) print(f"Generating prompt: '{opts.prompt}' (seed={opts.seed})...") t0 = time.perf_counter() use_true_cfg = abs(true_cfg - 1.0) > 1e-6 # 1) 입력 노이즈 준비 noise = get_noise( num_samples=1, height=opts.height, width=opts.width, device=flux_generator.device, dtype=torch.bfloat16 if flux_generator.device.type == 'cuda' else torch.float32, seed=opts.seed, ) bs, c, h, w = noise.shape noise = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if noise.shape[0] == 1 and bs > 1: noise = repeat(noise, "1 ... -> bs ...", bs=bs) # ID 이미지 인코딩 encode_t0 = time.perf_counter() id_image = id_image.resize((opts.width, opts.height), resample=Image.LANCZOS) x = torch.from_numpy(np.array(id_image).astype(np.float32)) x = (x / 127.5) - 1.0 x = rearrange(x, "h w c -> 1 c h w") x = x.to(flux_generator.device) dtype = torch.bfloat16 if flux_generator.device.type == 'cuda' else torch.float32 with torch.autocast(device_type=flux_generator.device.type, dtype=dtype): x = flux_generator.ae.encode(x) x = x.to(dtype) encode_t1 = time.perf_counter() print(f"Encoded in {encode_t1 - encode_t0:.2f} seconds.") timesteps = get_schedule(opts.num_steps, x.shape[-1] * x.shape[-2] // 4, shift=False) # 2) 텍스트 임베딩 준비 inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=opts.prompt) inp_inversion = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt="") inp_neg = None if use_true_cfg: inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=neg_prompt) # 3) ID 임베딩 생성 id_embeddings = None uncond_id_embeddings = None if id_image is not None: id_image = np.array(id_image) id_image = resize_numpy_image_long(id_image, 1024) id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg) y_0 = inp["img"].clone().detach() # 이미지 처리 과정 inverted = rf_inversion( flux_generator.model, **inp_inversion, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight, start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg, timestep_to_start_cfg=timestep_to_start_cfg, neg_txt=inp_neg["txt"] if use_true_cfg else None, neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None, neg_vec=inp_neg["vec"] if use_true_cfg else None, aggressive_offload=False, y_1=noise, gamma=gamma ) inp["img"] = inverted inp_inversion["img"] = inverted edited = rf_denoise( flux_generator.model, **inp, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight, start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg, timestep_to_start_cfg=timestep_to_start_cfg, neg_txt=inp_neg["txt"] if use_true_cfg else None, neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None, neg_vec=inp_neg["vec"] if use_true_cfg else None, aggressive_offload=False, y_0=y_0, eta=eta, s=s, tau=tau, ) # 결과 이미지 디코딩 edited = unpack(edited.float(), opts.height, opts.width) with torch.autocast(device_type=flux_generator.device.type, dtype=dtype): edited = flux_generator.ae.decode(edited) t1 = time.perf_counter() print(f"Done in {t1 - t0:.2f} seconds.") # PIL 이미지로 변환 edited = edited.clamp(-1, 1) edited = rearrange(edited[0], "c h w -> h w c") edited = Image.fromarray((127.5 * (edited + 1.0)).cpu().byte().numpy()) return edited, str(opts.seed) except Exception as e: import traceback error_msg = f"이미지 생성 중 오류 발생: {str(e)}\n{traceback.format_exc()}" print(error_msg) return None, error_msg def create_demo(): with gr.Blocks() as demo: gr.Markdown("# PuLID: 인물 이미지 변환 도구") # 모델 상태 표시 status_box = gr.Textbox(label="모델 상태", value=initialization_message) # 초기화 버튼 추가 (백그라운드 초기화 대신 명시적 초기화 버튼 사용) init_btn = gr.Button("모델 초기화") init_btn.click(fn=initialize_models, inputs=[], outputs=[status_box]) refresh_btn = gr.Button("상태 새로고침") refresh_btn.click(fn=check_model_status, inputs=[], outputs=[status_box]) with gr.Row(): with gr.Column(): prompt = gr.Textbox(label="프롬프트", value="portrait, color, cinematic") id_image = gr.Image(label="ID 이미지", type="pil") id_weight = gr.Slider(0.0, 1.0, 0.4, step=0.05, label="ID 가중치") num_steps = gr.Slider(1, 24, 16, step=1, label="단계 수") guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="가이던스") with gr.Accordion("고급 옵션", open=False): neg_prompt = gr.Textbox(label="네거티브 프롬프트", value="") true_cfg = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="CFG 스케일") seed = gr.Textbox(value="-1", label="시드 (-1: 랜덤)") gr.Markdown("### 기타 옵션") gamma = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="감마") eta = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="에타") generate_btn = gr.Button("이미지 생성") with gr.Column(): output_image = gr.Image(label="생성된 이미지") seed_output = gr.Textbox(label="결과/오류 메시지") gr.Markdown(_CITE_) # 예제 추가 with gr.Row(): gr.Markdown("## 예제") example_inps = [ [ 'a portrait of a clown', 'example_inputs/unsplash/lhon-karwan-11tbHtK5STE-unsplash.jpg', 16, 3.5, "-1", 0.4, "", 3.5, 0.5, 0.8 ], [ 'a portrait of a zombie', 'example_inputs/unsplash/baruk-granda-cfLL_jHQ-Iw-unsplash.jpg', 16, 3.5, "42", 0.4, "", 3.5, 0.5, 0.8 ] ] gr.Examples( examples=example_inps, inputs=[prompt, id_image, num_steps, guidance, seed, id_weight, neg_prompt, true_cfg, gamma, eta] ) # Gradio 이벤트 연결 generate_btn.click( fn=generate_image, inputs=[ prompt, id_image, num_steps, guidance, seed, id_weight, neg_prompt, true_cfg, gamma, eta ], outputs=[output_image, seed_output], ) return demo if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev") parser.add_argument('--version', type=str, default='v0.9.1') parser.add_argument("--name", type=str, default="flux-dev") parser.add_argument("--port", type=int, default=8080) args = parser.parse_args() print("Hugging Face Spaces 환경에서 실행 중입니다. GPU 할당을 요청합니다.") # 메인 프로세스에서는 CUDA 초기화하지 않음 # 백그라운드 스레드 대신 명시적 버튼으로 초기화 demo = create_demo() # 디버그 모드 활성화