import spaces
import random
import torch
import cv2
import insightface
import gradio as gr
import numpy as np
import os
import shutil
from huggingface_hub import snapshot_download, login
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter_FaceID import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from diffusers import AutoencoderKL
from kolors.models.unet_2d_condition import UNet2DConditionModel
from diffusers import EulerDiscreteScheduler
from PIL import Image
from insightface.app import FaceAnalysis
from insightface.data import get_image as ins_get_image
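# This Space assembles a Kwai-Kolors SDXL pipeline with the Kolors IP-Adapter-FaceID-Plus
# weights behind a Gradio UI: insightface extracts an identity embedding from the uploaded
# photo, and generation is conditioned on both the text prompt and that face.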
# Clear cache (optional)
def clear_cache():
    cache_dir = "/home/user/.cache/huggingface/hub"
    if os.path.exists(cache_dir):
        try:
            # Only remove the CLIP model cache
            clip_cache = os.path.join(cache_dir, "models--openai--clip-vit-large-patch14-336")
            if os.path.exists(clip_cache):
                shutil.rmtree(clip_cache)
                print("Cleared CLIP cache")
        except Exception as e:
            print(f"Could not clear cache: {e}")
# Clear the cache if needed
# clear_cache()
# Log in to the Hugging Face Hub with a token
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("Successfully logged in to Hugging Face Hub")
else:
    print("Warning: HF_TOKEN not found. Using public access only.")
# Check whether a GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Using device: {device}")
print(f"Using dtype: {dtype}")
# Download the model checkpoints (using the token if set)
try:
    print("Downloading Kolors models...")
    ckpt_dir = snapshot_download(
        repo_id="Kwai-Kolors/Kolors",
        token=HF_TOKEN,
        local_dir_use_symlinks=False,
        resume_download=True
    )
    print("Downloading FaceID models...")
    ckpt_dir_faceid = snapshot_download(
        repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
        token=HF_TOKEN,
        local_dir_use_symlinks=False,
        resume_download=True
    )
except Exception as e:
    print(f"Error downloading models: {e}")
    raise
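# snapshot_download() returns the local path of the cached repo; ckpt_dir and ckpt_dir_faceid
# are used below to load the individual sub-folders (text_encoder, vae, unet, ...).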
# Load the model components
print("Loading text encoder...")
text_encoder = ChatGLMModel.from_pretrained(
    f'{ckpt_dir}/text_encoder',
    torch_dtype=dtype,
    token=HF_TOKEN,
    trust_remote_code=True
)
if device == "cuda":
    text_encoder = text_encoder.half().to(device)

print("Loading tokenizer...")
tokenizer = ChatGLMTokenizer.from_pretrained(
    f'{ckpt_dir}/text_encoder',
    token=HF_TOKEN,
    trust_remote_code=True
)

print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(
    f"{ckpt_dir}/vae",
    revision=None,
    torch_dtype=dtype,
    token=HF_TOKEN
)
if device == "cuda":
    vae = vae.half().to(device)

print("Loading scheduler...")
scheduler = EulerDiscreteScheduler.from_pretrained(
    f"{ckpt_dir}/scheduler",
    token=HF_TOKEN
)

print("Loading UNet...")
unet = UNet2DConditionModel.from_pretrained(
    f"{ckpt_dir}/unet",
    revision=None,
    torch_dtype=dtype,
    token=HF_TOKEN
)
if device == "cuda":
    unet = unet.half().to(device)
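# Note: Kolors uses a ChatGLM text encoder/tokenizer instead of SDXL's CLIP text encoders,
# so the pipeline below is assembled from individual components rather than a single
# from_pretrained() call.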
# Load the CLIP image encoder, preferring safetensors weights
print("Loading CLIP model...")
try:
    # First try the copy bundled with the local FaceID download
    local_clip_path = f'{ckpt_dir_faceid}/clip-vit-large-patch14-336'
    if os.path.exists(local_clip_path):
        print(f"Trying to load CLIP from local: {local_clip_path}")
        clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            local_clip_path,
            torch_dtype=dtype,
            ignore_mismatched_sizes=True,
            token=HF_TOKEN,
            use_safetensors=True,  # prefer safetensors
            local_files_only=True
        )
    else:
        raise FileNotFoundError("Local CLIP not found")
except Exception as e:
    print(f"Local loading failed: {e}")
    try:
        # Download directly from the openai repo (safetensors version)
        print("Downloading CLIP from OpenAI...")
        clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            'openai/clip-vit-large-patch14-336',
            torch_dtype=dtype,
            ignore_mismatched_sizes=True,
            token=HF_TOKEN,
            use_safetensors=True,  # prefer safetensors
            revision="main"
        )
    except Exception as e2:
        print(f"SafeTensors loading failed: {e2}")
        # Last resort: use the pytorch_model.bin weights
        print("Trying with pytorch format...")
        clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            'openai/clip-vit-large-patch14-336',
            torch_dtype=dtype,
            ignore_mismatched_sizes=True,
            token=HF_TOKEN,
            use_safetensors=False
        )

clip_image_encoder.to(device)
clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)
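# The 336x336 processor size matches both the CLIP ViT-L/14-336 input resolution and
# the face crop produced in infer() below.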
print("Creating pipeline...")
pipe = StableDiffusionXLPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
face_clip_encoder=clip_image_encoder,
face_clip_processor=clip_image_processor,
force_zeros_for_empty_prompt=False,
)
print("Models loaded successfully!")
class FaceInfoGenerator():
    def __init__(self, root_dir="./.insightface/"):
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if device == "cuda" else ['CPUExecutionProvider']
        self.app = FaceAnalysis(name='antelopev2', root=root_dir, providers=providers)
        self.app.prepare(ctx_id=0, det_size=(640, 640))

    def get_faceinfo_one_img(self, face_image):
        if face_image is None:
            return None
        face_info = self.app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
        if len(face_info) == 0:
            return None
        else:
            # only use the maximum face
            face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
        return face_info
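# get_faceinfo_one_img() returns an insightface face record; infer() below relies on its
# 'bbox' (l, t, r, b) and 'embedding' (ArcFace identity vector) fields.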
def face_bbox_to_square(bbox):
    # expand an (l, t, r, b) bbox into a square (l, t, r, b) box around the same center,
    # sized by the longer side
    l, t, r, b = bbox
    cent_x = (l + r) / 2
    cent_y = (t + b) / 2
    w, h = r - l, b - t
    r = max(w, h) / 2
    l0 = cent_x - r
    r0 = cent_x + r
    t0 = cent_y - r
    b0 = cent_y + r
    return [l0, t0, r0, b0]
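# Example: a 100x160 bbox [100, 120, 200, 280] has center (150, 200) and longer side 160,
# so it becomes the square [70.0, 120.0, 230.0, 280.0]; the square may extend past the
# image border, which PIL's crop() tolerates by padding.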
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
face_info_generator = FaceInfoGenerator()
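# Note: FaceAnalysis is expected to fetch the 'antelopev2' model pack into root_dir on
# first use if it is not already present (auto-download behavior varies across
# insightface versions).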
@spaces.GPU(duration=60)
def infer(prompt,
          image=None,
          negative_prompt="low quality, blurry, distorted",
          seed=66,
          randomize_seed=False,
          guidance_scale=5.0,
          num_inference_steps=50
          ):
    if image is None:
        gr.Warning("Please upload an image with a face.")
        return None, 0

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    global pipe
    pipe = pipe.to(device)

    # Load the IP-Adapter FaceID weights
    try:
        pipe.load_ip_adapter_faceid_plus(f'{ckpt_dir_faceid}/ipa-faceid-plus.bin', device=device)
        scale = 0.8
        pipe.set_face_fidelity_scale(scale)
    except Exception as e:
        print(f"Error loading IP adapter: {e}")
        raise gr.Error(f"Failed to load face adapter: {str(e)}")

    # Extract face information from the uploaded image
    face_info = face_info_generator.get_faceinfo_one_img(image)
    if face_info is None:
        raise gr.Error("No face detected in the image. Please provide an image with a clear face.")

    try:
        face_bbox_square = face_bbox_to_square(face_info["bbox"])
        crop_image = image.crop(face_bbox_square)
        crop_image = crop_image.resize((336, 336))
        crop_image = [crop_image]
        face_embeds = torch.from_numpy(np.array([face_info["embedding"]]))
        face_embeds = face_embeds.to(device, dtype=dtype)
    except Exception as e:
        print(f"Error processing face: {e}")
        raise gr.Error(f"Failed to process face: {str(e)}")

    # Generate the image
    try:
        with torch.no_grad():
            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=1024,
                width=1024,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
                generator=generator,
                face_crop_image=crop_image,
                face_insightface_embeds=face_embeds
            ).images[0]
    except Exception as e:
        print(f"Error during inference: {e}")
        raise gr.Error(f"Failed to generate image: {str(e)}")

    return image, seed
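# infer() returns (generated PIL image, seed used); the order matches the
# [result, seed_used] outputs wired to button.click() below.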
css = """
footer {
    visibility: hidden;
}
.container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
}
"""
# Gradio Interface
with gr.Blocks(theme="soft", css=css) as Kolors:
    gr.HTML(
        """
        <div class='container' style='display:flex; justify-content:center; gap:12px;'>
            <a href="https://huggingface.co/spaces/openfree/Best-AI" target="_blank">
                <img src="https://img.shields.io/static/v1?label=OpenFree&message=BEST%20AI%20Services&color=%230000ff&labelColor=%23000080&logo=huggingface&logoColor=%23ffa500&style=for-the-badge" alt="OpenFree badge">
            </a>
            <a href="https://discord.gg/openfreeai" target="_blank">
                <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">
            </a>
        </div>
        <h1 style="text-align: center;">Kolors Face ID - AI Portrait Generator</h1>
        <p style="text-align: center;">Upload a face photo and create stunning AI portraits with text prompts!</p>
        """
    )
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            with gr.Row():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="e.g., A professional portrait in business attire, studio lighting",
                    lines=3,
                    value="A professional portrait photo, high quality, detailed face"
                )
            with gr.Row():
                image = gr.Image(
                    label="Upload Face Image",
                    type="pil",
                    height=400
                )
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    placeholder="Things to avoid in the image",
                    value="low quality, blurry, distorted, disfigured",
                    visible=True,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=66,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=5.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=25,
                    )
            with gr.Row():
                button = gr.Button("🎨 Generate Portrait", elem_id="button", variant="primary", scale=1)
        with gr.Column(elem_id="col-right"):
            result = gr.Image(label="Generated Portrait", show_label=True)
            seed_used = gr.Number(label="Seed Used", precision=0)
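    # The inputs list below must stay in the same order as infer()'s positional parameters.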
    button.click(
        fn=infer,
        inputs=[prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps],
        outputs=[result, seed_used]
    )
if __name__ == "__main__":
    Kolors.queue(max_size=10).launch(debug=True, share=False)