# kofaceid / app.py
import spaces
import random
import torch
import cv2
import gradio as gr
import numpy as np
import os
from huggingface_hub import snapshot_download, login
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter_FaceID import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from diffusers import AutoencoderKL
from kolors.models.unet_2d_condition import UNet2DConditionModel
from diffusers import EulerDiscreteScheduler
from PIL import Image
from insightface.app import FaceAnalysis
# Log in to the Hugging Face Hub with a token (if provided)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
login(token=HF_TOKEN)
print("Successfully logged in to Hugging Face Hub")
else:
print("Warning: HF_TOKEN not found. Using public access only.")
# Check whether a GPU is available and pick the compute dtype
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
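# On CUDA, float16 halves memory use and speeds up inference; on CPU we fall
# back to float32, where half precision is generally unsupported or slow.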
# Download the model checkpoints (the token enables gated/private access)
try:
    ckpt_dir = snapshot_download(
        repo_id="Kwai-Kolors/Kolors",
        token=HF_TOKEN
    )
    ckpt_dir_faceid = snapshot_download(
        repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
        token=HF_TOKEN
    )
except Exception as e:
print(f"Error downloading models: {e}")
raise
# Load the models, with error handling
try:
text_encoder = ChatGLMModel.from_pretrained(
f'{ckpt_dir}/text_encoder',
torch_dtype=dtype,
token=HF_TOKEN,
trust_remote_code=True
)
if device == "cuda":
text_encoder = text_encoder.half().to(device)
tokenizer = ChatGLMTokenizer.from_pretrained(
f'{ckpt_dir}/text_encoder',
token=HF_TOKEN,
trust_remote_code=True
)
vae = AutoencoderKL.from_pretrained(
f"{ckpt_dir}/vae",
revision=None,
torch_dtype=dtype,
token=HF_TOKEN
)
if device == "cuda":
vae = vae.half().to(device)
scheduler = EulerDiscreteScheduler.from_pretrained(
f"{ckpt_dir}/scheduler",
token=HF_TOKEN
)
unet = UNet2DConditionModel.from_pretrained(
f"{ckpt_dir}/unet",
revision=None,
torch_dtype=dtype,
token=HF_TOKEN
)
if device == "cuda":
unet = unet.half().to(device)
    # Load the CLIP image encoder, falling back to the upstream OpenAI checkpoint
try:
clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
f'{ckpt_dir_faceid}/clip-vit-large-patch14-336',
torch_dtype=dtype,
ignore_mismatched_sizes=True,
token=HF_TOKEN
)
except Exception as e:
print(f"Loading CLIP from local failed: {e}, trying alternative source...")
clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
'openai/clip-vit-large-patch14-336',
torch_dtype=dtype,
ignore_mismatched_sizes=True,
token=HF_TOKEN
)
clip_image_encoder.to(device)
clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)
except Exception as e:
print(f"Error loading models: {e}")
raise
# Assemble the Kolors SDXL FaceID pipeline
pipe = StableDiffusionXLPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
face_clip_encoder=clip_image_encoder,
face_clip_processor=clip_image_processor,
force_zeros_for_empty_prompt=False,
)
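# Note: face_clip_encoder and face_clip_processor are the FaceID-specific inputs
# of this Kolors pipeline; the remaining components mirror a standard SDXL setup.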
class FaceInfoGenerator:
def __init__(self, root_dir="./.insightface/"):
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if device == "cuda" else ['CPUExecutionProvider']
self.app = FaceAnalysis(name='antelopev2', root=root_dir, providers=providers)
self.app.prepare(ctx_id=0, det_size=(640, 640))
def get_faceinfo_one_img(self, face_image):
if face_image is None:
return None
face_info = self.app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
if len(face_info) == 0:
return None
else:
            # Keep only the largest detected face (by bounding-box area)
            face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
return face_info
def face_bbox_to_square(bbox):
    # Expand an (l, t, r, b) bbox into a square box with the same center
l, t, r, b = bbox
cent_x = (l + r) / 2
cent_y = (t + b) / 2
w, h = r - l, b - t
r = max(w, h) / 2
l0 = cent_x - r
r0 = cent_x + r
t0 = cent_y - r
b0 = cent_y + r
return [l0, t0, r0, b0]
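# Worked example (hypothetical values): face_bbox_to_square([10, 20, 50, 100])
# has center (30, 60) and max side 80, so it returns [-10.0, 20.0, 70.0, 100.0].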
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
face_info_generator = FaceInfoGenerator()
@spaces.GPU
def infer(prompt,
image=None,
negative_prompt="low quality, blurry, distorted",
seed=66,
randomize_seed=False,
guidance_scale=5.0,
num_inference_steps=50
):
    if image is None:
        raise gr.Error("Please upload a face image first.")
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device=device).manual_seed(seed)
global pipe
pipe = pipe.to(device)
    # Load the IP-Adapter FaceID weights (reloaded on every call; could be cached after the first load)
try:
pipe.load_ip_adapter_faceid_plus(f'{ckpt_dir_faceid}/ipa-faceid-plus.bin', device=device)
        scale = 0.8  # strength of the identity conditioning
        pipe.set_face_fidelity_scale(scale)
except Exception as e:
print(f"Error loading IP adapter: {e}")
raise
    # Extract face information from the uploaded image
face_info = face_info_generator.get_faceinfo_one_img(image)
if face_info is None:
raise gr.Error("No face detected in the image. Please provide an image with a clear face.")
face_bbox_square = face_bbox_to_square(face_info["bbox"])
crop_image = image.crop(face_bbox_square)
crop_image = crop_image.resize((336, 336))
crop_image = [crop_image]
face_embeds = torch.from_numpy(np.array([face_info["embedding"]]))
face_embeds = face_embeds.to(device, dtype=dtype)
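    # Two conditioning signals go into the pipeline: the 336x336 face crop
    # (encoded by the CLIP image encoder) and the insightface identity embedding.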
    # Generate the image
try:
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=1024,
width=1024,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=1,
generator=generator,
face_crop_image=crop_image,
face_insightface_embeds=face_embeds
).images[0]
except Exception as e:
print(f"Error during inference: {e}")
raise gr.Error(f"Failed to generate image: {str(e)}")
return image, seed
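# Minimal local usage sketch, bypassing the Gradio UI (assumes a local file
# "face.jpg"; the path and filenames are illustrative only):
#   from PIL import Image
#   img = Image.open("face.jpg").convert("RGB")
#   portrait, used_seed = infer("A professional portrait photo", image=img)
#   portrait.save("portrait.png")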
css = """
footer {
visibility: hidden;
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
"""
def load_description(fp):
if os.path.exists(fp):
with open(fp, 'r', encoding='utf-8') as f:
content = f.read()
return content
return ""
# Gradio Interface
with gr.Blocks(theme="soft", css=css) as Kolors:
gr.HTML(
"""
<div class='container' style='display:flex; justify-content:center; gap:12px;'>
<a href="https://huggingface.co/spaces/openfree/Best-AI" target="_blank">
<img src="https://img.shields.io/static/v1?label=OpenFree&message=BEST%20AI%20Services&color=%230000ff&labelColor=%23000080&logo=huggingface&logoColor=%23ffa500&style=for-the-badge" alt="OpenFree badge">
</a>
<a href="https://discord.gg/openfreeai" target="_blank">
<img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">
</a>
</div>
<h1 style="text-align: center;">Kolors Face ID - AI Portrait Generator</h1>
<p style="text-align: center;">Upload a face photo and create stunning AI portraits with text prompts!</p>
"""
)
with gr.Row():
with gr.Column(elem_id="col-left"):
with gr.Row():
prompt = gr.Textbox(
label="Prompt",
placeholder="e.g., A professional portrait in business attire, studio lighting",
lines=3,
value="A professional portrait photo, high quality, detailed face"
)
with gr.Row():
image = gr.Image(
label="Upload Face Image",
type="pil",
height=400
)
with gr.Accordion("Advanced Settings", open=False):
negative_prompt = gr.Textbox(
label="Negative prompt",
placeholder="Things to avoid in the image",
value="low quality, blurry, distorted, disfigured",
visible=True,
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=66,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=10.0,
step=0.1,
value=5.0,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=10,
maximum=50,
step=1,
value=25,
)
with gr.Row():
button = gr.Button("🎨 Generate Portrait", elem_id="button", variant="primary", scale=1)
with gr.Column(elem_id="col-right"):
result = gr.Image(label="Generated Portrait", show_label=True)
seed_used = gr.Number(label="Seed Used", precision=0)
    # Example prompts
gr.Examples(
examples=[
["A cinematic portrait, dramatic lighting, professional photography", None],
["An oil painting portrait in Renaissance style, classical art", None],
["A cyberpunk character portrait, neon lights, futuristic", None],
],
inputs=[prompt, image],
)
button.click(
fn=infer,
inputs=[prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps],
outputs=[result, seed_used]
)
if __name__ == "__main__":
Kolors.queue(max_size=10).launch(debug=True, share=False)