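"""Kolors Face ID - AI Portrait Generator (Hugging Face Space, ZeroGPU).

Loads the Kwai-Kolors/Kolors pipeline together with the
Kolors-IP-Adapter-FaceID-Plus adapter, extracts a face embedding from an
uploaded photo with insightface, and generates identity-preserving portraits
from a text prompt via a Gradio UI.
"""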
import spaces
import random
import torch
import cv2
import gradio as gr
import numpy as np
import os
import shutil
from huggingface_hub import snapshot_download, login
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter_FaceID import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import UNet2DConditionModel
from diffusers import AutoencoderKL, EulerDiscreteScheduler
from PIL import Image
from insightface.app import FaceAnalysis
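# NOTE: the `kolors` package is not a standard PyPI dependency; this Space
# assumes it is installed from the Kwai-Kolors/Kolors GitHub repository
# (e.g. via an entry in requirements.txt pointing at that repo).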
# Optional: clear the cached CLIP model to force a fresh download.
def clear_cache():
    cache_dir = "/home/user/.cache/huggingface/hub"
    if os.path.exists(cache_dir):
        try:
            # Only remove the CLIP model cache.
            clip_cache = os.path.join(cache_dir, "models--openai--clip-vit-large-patch14-336")
            if os.path.exists(clip_cache):
                shutil.rmtree(clip_cache)
                print("Cleared CLIP cache")
        except Exception as e:
            print(f"Could not clear cache: {e}")

# Uncomment to clear the cache when needed.
# clear_cache()
# Log in to the Hugging Face Hub with a token (expected in the HF_TOKEN env var).
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("Successfully logged in to Hugging Face Hub")
else:
    print("Warning: HF_TOKEN not found. Using public access only.")
# Check GPU availability and pick a matching dtype.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Using device: {device}")
print(f"Using dtype: {dtype}")
# Download model weights (authenticated if a token is available).
# NOTE: local_dir_use_symlinks and resume_download are deprecated no-ops in
# recent huggingface_hub releases; they are kept here for older versions.
try:
    print("Downloading Kolors models...")
    ckpt_dir = snapshot_download(
        repo_id="Kwai-Kolors/Kolors",
        token=HF_TOKEN,
        local_dir_use_symlinks=False,
        resume_download=True
    )
    print("Downloading FaceID models...")
    ckpt_dir_faceid = snapshot_download(
        repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
        token=HF_TOKEN,
        local_dir_use_symlinks=False,
        resume_download=True
    )
except Exception as e:
    print(f"Error downloading models: {e}")
    raise
# Load model components.
print("Loading text encoder...")
text_encoder = ChatGLMModel.from_pretrained(
    f'{ckpt_dir}/text_encoder',
    torch_dtype=dtype,
    token=HF_TOKEN,
    trust_remote_code=True
)
if device == "cuda":
    text_encoder = text_encoder.half().to(device)

print("Loading tokenizer...")
tokenizer = ChatGLMTokenizer.from_pretrained(
    f'{ckpt_dir}/text_encoder',
    token=HF_TOKEN,
    trust_remote_code=True
)

print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(
    f"{ckpt_dir}/vae",
    revision=None,
    torch_dtype=dtype,
    token=HF_TOKEN
)
if device == "cuda":
    vae = vae.half().to(device)

print("Loading scheduler...")
scheduler = EulerDiscreteScheduler.from_pretrained(
    f"{ckpt_dir}/scheduler",
    token=HF_TOKEN
)

print("Loading UNet...")
unet = UNet2DConditionModel.from_pretrained(
    f"{ckpt_dir}/unet",
    revision=None,
    torch_dtype=dtype,
    token=HF_TOKEN
)
if device == "cuda":
    unet = unet.half().to(device)
# Load the CLIP vision encoder, preferring safetensors weights.
print("Loading CLIP model...")
try:
    # First, try the copy bundled in the local FaceID checkpoint directory.
    local_clip_path = f'{ckpt_dir_faceid}/clip-vit-large-patch14-336'
    if os.path.exists(local_clip_path):
        print(f"Trying to load CLIP from local: {local_clip_path}")
        clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            local_clip_path,
            torch_dtype=dtype,
            ignore_mismatched_sizes=True,
            token=HF_TOKEN,
            use_safetensors=True,  # prefer safetensors weights
            local_files_only=True
        )
    else:
        raise FileNotFoundError("Local CLIP not found")
except Exception as e:
    print(f"Local loading failed: {e}")
    try:
        # Fall back to downloading the safetensors weights from the Hub.
        print("Downloading CLIP from OpenAI...")
        clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            'openai/clip-vit-large-patch14-336',
            torch_dtype=dtype,
            ignore_mismatched_sizes=True,
            token=HF_TOKEN,
            use_safetensors=True,  # prefer safetensors weights
            revision="main"
        )
    except Exception as e2:
        print(f"SafeTensors loading failed: {e2}")
        # Last resort: fall back to the pytorch_model.bin weights.
        print("Trying with pytorch format...")
        clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            'openai/clip-vit-large-patch14-336',
            torch_dtype=dtype,
            ignore_mismatched_sizes=True,
            token=HF_TOKEN,
            use_safetensors=False
        )

clip_image_encoder.to(device)
clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)
print("Creating pipeline...") | |
pipe = StableDiffusionXLPipeline( | |
vae=vae, | |
text_encoder=text_encoder, | |
tokenizer=tokenizer, | |
unet=unet, | |
scheduler=scheduler, | |
face_clip_encoder=clip_image_encoder, | |
face_clip_processor=clip_image_processor, | |
force_zeros_for_empty_prompt=False, | |
) | |
print("Models loaded successfully!") | |
class FaceInfoGenerator():
    def __init__(self, root_dir="./.insightface/"):
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if device == "cuda" else ['CPUExecutionProvider']
        self.app = FaceAnalysis(name='antelopev2', root=root_dir, providers=providers)
        self.app.prepare(ctx_id=0, det_size=(640, 640))

    def get_faceinfo_one_img(self, face_image):
        if face_image is None:
            return None
        # insightface expects BGR arrays; PIL delivers RGB.
        face_info = self.app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
        if len(face_info) == 0:
            return None
        # Keep only the largest detected face (by bounding-box area).
        return sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
def face_bbox_to_square(bbox):
    """Expand an (l, t, r, b) bbox to a square with the same center,
    sized by the longer of the two sides."""
    l, t, r, b = bbox
    cent_x = (l + r) / 2
    cent_y = (t + b) / 2
    half = max(r - l, b - t) / 2
    return [cent_x - half, cent_y - half, cent_x + half, cent_y + half]
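# Example: face_bbox_to_square([100, 120, 200, 280]) -> [70.0, 120.0, 230.0, 280.0]
# (a 100x160 box becomes a 160x160 square around the same center).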
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

face_info_generator = FaceInfoGenerator()
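# NOTE: FaceAnalysis looks for the 'antelopev2' model pack under
# <root_dir>/models/antelopev2; depending on the insightface version it may
# need to be downloaded manually rather than fetched automatically.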
@spaces.GPU  # allocate a GPU for this call on ZeroGPU Spaces (`spaces` is imported above)
def infer(prompt,
          image=None,
          negative_prompt="low quality, blurry, distorted",
          seed=66,
          randomize_seed=False,
          guidance_scale=5.0,
          num_inference_steps=50
          ):
    if image is None:
        gr.Warning("Please upload an image with a face.")
        return None, 0

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    global pipe
    pipe = pipe.to(device)

    # Load the IP-Adapter FaceID weights.
    try:
        pipe.load_ip_adapter_faceid_plus(f'{ckpt_dir_faceid}/ipa-faceid-plus.bin', device=device)
        scale = 0.8  # trades face fidelity against prompt adherence
        pipe.set_face_fidelity_scale(scale)
    except Exception as e:
        print(f"Error loading IP adapter: {e}")
        raise gr.Error(f"Failed to load face adapter: {str(e)}")

    # Extract face information.
    face_info = face_info_generator.get_faceinfo_one_img(image)
    if face_info is None:
        raise gr.Error("No face detected in the image. Please provide an image with a clear face.")

    try:
        face_bbox_square = face_bbox_to_square(face_info["bbox"])
        crop_image = image.crop(face_bbox_square)
        crop_image = crop_image.resize((336, 336))
        crop_image = [crop_image]
        face_embeds = torch.from_numpy(np.array([face_info["embedding"]]))
        face_embeds = face_embeds.to(device, dtype=dtype)
    except Exception as e:
        print(f"Error processing face: {e}")
        raise gr.Error(f"Failed to process face: {str(e)}")

    # Generate the image.
    try:
        with torch.no_grad():
            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=1024,
                width=1024,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
                generator=generator,
                face_crop_image=crop_image,
                face_insightface_embeds=face_embeds
            ).images[0]
    except Exception as e:
        print(f"Error during inference: {e}")
        raise gr.Error(f"Failed to generate image: {str(e)}")

    return image, seed
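# Hypothetical programmatic use (bypassing the UI), e.g. for smoke-testing:
#   from PIL import Image
#   out, used_seed = infer("A watercolor portrait, soft light",
#                          image=Image.open("face.jpg").convert("RGB"),
#                          randomize_seed=True)
#   out.save("portrait.png")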
css = """ | |
footer { | |
visibility: hidden; | |
} | |
.container { | |
max-width: 1200px; | |
margin: 0 auto; | |
padding: 20px; | |
} | |
""" | |
# Gradio interface
with gr.Blocks(theme="soft", css=css) as Kolors:
    gr.HTML(
        """
        <div class='container' style='display:flex; justify-content:center; gap:12px;'>
            <a href="https://huggingface.co/spaces/openfree/Best-AI" target="_blank">
                <img src="https://img.shields.io/static/v1?label=OpenFree&message=BEST%20AI%20Services&color=%230000ff&labelColor=%23000080&logo=huggingface&logoColor=%23ffa500&style=for-the-badge" alt="OpenFree badge">
            </a>
            <a href="https://discord.gg/openfreeai" target="_blank">
                <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">
            </a>
        </div>
        <h1 style="text-align: center;">Kolors Face ID - AI Portrait Generator</h1>
        <p style="text-align: center;">Upload a face photo and create stunning AI portraits with text prompts!</p>
        """
    )

    with gr.Row():
        with gr.Column(elem_id="col-left"):
            with gr.Row():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="e.g., A professional portrait in business attire, studio lighting",
                    lines=3,
                    value="A professional portrait photo, high quality, detailed face"
                )
            with gr.Row():
                image = gr.Image(
                    label="Upload Face Image",
                    type="pil",
                    height=400
                )
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    placeholder="Things to avoid in the image",
                    value="low quality, blurry, distorted, disfigured",
                    visible=True,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=66,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=5.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=25,
                    )
            with gr.Row():
                button = gr.Button("🎨 Generate Portrait", elem_id="button", variant="primary", scale=1)

        with gr.Column(elem_id="col-right"):
            result = gr.Image(label="Generated Portrait", show_label=True)
            seed_used = gr.Number(label="Seed Used", precision=0)

    button.click(
        fn=infer,
        inputs=[prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps],
        outputs=[result, seed_used]
    )
if __name__ == "__main__":
    Kolors.queue(max_size=10).launch(debug=True, share=False)