aiqtech committed (verified)
Commit 2432eb0 · 1 Parent(s): 5bcee37

Update app.py

Files changed (1)
  1. app.py +134 -93
app.py CHANGED
@@ -6,6 +6,7 @@ import insightface
  import gradio as gr
  import numpy as np
  import os
+ import shutil
  from huggingface_hub import snapshot_download, login
  from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
  from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter_FaceID import StableDiffusionXLPipeline
@@ -18,6 +19,22 @@ from PIL import Image
  from insightface.app import FaceAnalysis
  from insightface.data import get_image as ins_get_image

+ # Clear the cache (optional)
+ def clear_cache():
+     cache_dir = "/home/user/.cache/huggingface/hub"
+     if os.path.exists(cache_dir):
+         try:
+             # Delete only the CLIP model cache
+             clip_cache = os.path.join(cache_dir, "models--openai--clip-vit-large-patch14-336")
+             if os.path.exists(clip_cache):
+                 shutil.rmtree(clip_cache)
+                 print("Cleared CLIP cache")
+         except Exception as e:
+             print(f"Could not clear cache: {e}")
+ 
+ # Clear the cache (if needed)
+ # clear_cache()
+ 
  # Log in with the Hugging Face token
  HF_TOKEN = os.getenv("HF_TOKEN")
  if HF_TOKEN:
@@ -30,87 +47,120 @@ else:
  device = "cuda" if torch.cuda.is_available() else "cpu"
  dtype = torch.float16 if device == "cuda" else torch.float32

+ print(f"Using device: {device}")
+ print(f"Using dtype: {dtype}")
+ 
  # Download the models (using the token)
  try:
+     print("Downloading Kolors models...")
      ckpt_dir = snapshot_download(
          repo_id="Kwai-Kolors/Kolors",
          token=HF_TOKEN,
-         local_dir_use_symlinks=False
+         local_dir_use_symlinks=False,
+         resume_download=True
      )
+ 
+     print("Downloading FaceID models...")
      ckpt_dir_faceid = snapshot_download(
          repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
          token=HF_TOKEN,
-         local_dir_use_symlinks=False
+         local_dir_use_symlinks=False,
+         resume_download=True
      )
  except Exception as e:
      print(f"Error downloading models: {e}")
      raise

- # Load the models with error handling
+ # Load the models
+ print("Loading text encoder...")
+ text_encoder = ChatGLMModel.from_pretrained(
+     f'{ckpt_dir}/text_encoder',
+     torch_dtype=dtype,
+     token=HF_TOKEN,
+     trust_remote_code=True
+ )
+ if device == "cuda":
+     text_encoder = text_encoder.half().to(device)
+ 
+ print("Loading tokenizer...")
+ tokenizer = ChatGLMTokenizer.from_pretrained(
+     f'{ckpt_dir}/text_encoder',
+     token=HF_TOKEN,
+     trust_remote_code=True
+ )
+ 
+ print("Loading VAE...")
+ vae = AutoencoderKL.from_pretrained(
+     f"{ckpt_dir}/vae",
+     revision=None,
+     torch_dtype=dtype,
+     token=HF_TOKEN
+ )
+ if device == "cuda":
+     vae = vae.half().to(device)
+ 
+ print("Loading scheduler...")
+ scheduler = EulerDiscreteScheduler.from_pretrained(
+     f"{ckpt_dir}/scheduler",
+     token=HF_TOKEN
+ )
+ 
+ print("Loading UNet...")
+ unet = UNet2DConditionModel.from_pretrained(
+     f"{ckpt_dir}/unet",
+     revision=None,
+     torch_dtype=dtype,
+     token=HF_TOKEN
+ )
+ if device == "cuda":
+     unet = unet.half().to(device)
+ 
+ # Load the CLIP model - prefer safetensors
+ print("Loading CLIP model...")
  try:
-     text_encoder = ChatGLMModel.from_pretrained(
-         f'{ckpt_dir}/text_encoder',
-         torch_dtype=dtype,
-         token=HF_TOKEN,
-         trust_remote_code=True
-     )
-     if device == "cuda":
-         text_encoder = text_encoder.half().to(device)
- 
-     tokenizer = ChatGLMTokenizer.from_pretrained(
-         f'{ckpt_dir}/text_encoder',
-         token=HF_TOKEN,
-         trust_remote_code=True
-     )
- 
-     vae = AutoencoderKL.from_pretrained(
-         f"{ckpt_dir}/vae",
-         revision=None,
-         torch_dtype=dtype,
-         token=HF_TOKEN
-     )
-     if device == "cuda":
-         vae = vae.half().to(device)
- 
-     scheduler = EulerDiscreteScheduler.from_pretrained(
-         f"{ckpt_dir}/scheduler",
-         token=HF_TOKEN
-     )
- 
-     unet = UNet2DConditionModel.from_pretrained(
-         f"{ckpt_dir}/unet",
-         revision=None,
-         torch_dtype=dtype,
-         token=HF_TOKEN
-     )
-     if device == "cuda":
-         unet = unet.half().to(device)
- 
-     # Load the CLIP model with a fallback
+     # Try the local FaceID directory first
+     local_clip_path = f'{ckpt_dir_faceid}/clip-vit-large-patch14-336'
+     if os.path.exists(local_clip_path):
+         print(f"Trying to load CLIP from local: {local_clip_path}")
+         clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+             local_clip_path,
+             torch_dtype=dtype,
+             ignore_mismatched_sizes=True,
+             token=HF_TOKEN,
+             use_safetensors=True,  # prefer safetensors
+             local_files_only=True
+         )
+     else:
+         raise FileNotFoundError("Local CLIP not found")
+ except Exception as e:
+     print(f"Local loading failed: {e}")
      try:
+         # Download directly from OpenAI (safetensors version)
+         print("Downloading CLIP from OpenAI...")
          clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-             f'{ckpt_dir_faceid}/clip-vit-large-patch14-336',
+             'openai/clip-vit-large-patch14-336',
              torch_dtype=dtype,
              ignore_mismatched_sizes=True,
-             token=HF_TOKEN
+             token=HF_TOKEN,
+             use_safetensors=True,  # prefer safetensors
+             revision="main"
          )
-     except Exception as e:
-         print(f"Loading CLIP from local failed: {e}, trying alternative source...")
+     except Exception as e2:
+         print(f"SafeTensors loading failed: {e2}")
+         # Last resort: use pytorch_model.bin
+         print("Trying with pytorch format...")
          clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
              'openai/clip-vit-large-patch14-336',
              torch_dtype=dtype,
              ignore_mismatched_sizes=True,
-             token=HF_TOKEN
+             token=HF_TOKEN,
+             use_safetensors=False
          )
- 
-     clip_image_encoder.to(device)
-     clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)
- 
- except Exception as e:
-     print(f"Error loading models: {e}")
-     raise

- # Create the pipeline
+ clip_image_encoder.to(device)
+ clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)
+ 
+ print("Creating pipeline...")
  pipe = StableDiffusionXLPipeline(
      vae=vae,
      text_encoder=text_encoder,
@@ -122,6 +172,8 @@ pipe = StableDiffusionXLPipeline(
      force_zeros_for_empty_prompt=False,
  )

+ print("Models loaded successfully!")
+ 
  class FaceInfoGenerator():
      def __init__(self, root_dir="./.insightface/"):
          providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if device == "cuda" else ['CPUExecutionProvider']
@@ -160,7 +212,7 @@ MAX_SEED = np.iinfo(np.int32).max
  MAX_IMAGE_SIZE = 1024
  face_info_generator = FaceInfoGenerator()

- @spaces.GPU
+ @spaces.GPU(duration=60)
  def infer(prompt,
            image=None,
            negative_prompt="low quality, blurry, distorted",
@@ -170,6 +222,7 @@ def infer(prompt,
            num_inference_steps=50
  ):
      if image is None:
+         gr.Warning("Please upload an image with a face.")
          return None, 0

      if randomize_seed:
@@ -187,35 +240,40 @@ def infer(prompt,
          pipe.set_face_fidelity_scale(scale)
      except Exception as e:
          print(f"Error loading IP adapter: {e}")
-         raise
+         raise gr.Error(f"Failed to load face adapter: {str(e)}")

      # Extract the face info
      face_info = face_info_generator.get_faceinfo_one_img(image)
      if face_info is None:
          raise gr.Error("No face detected in the image. Please provide an image with a clear face.")

-     face_bbox_square = face_bbox_to_square(face_info["bbox"])
-     crop_image = image.crop(face_bbox_square)
-     crop_image = crop_image.resize((336, 336))
-     crop_image = [crop_image]
- 
-     face_embeds = torch.from_numpy(np.array([face_info["embedding"]]))
-     face_embeds = face_embeds.to(device, dtype=dtype)
+     try:
+         face_bbox_square = face_bbox_to_square(face_info["bbox"])
+         crop_image = image.crop(face_bbox_square)
+         crop_image = crop_image.resize((336, 336))
+         crop_image = [crop_image]
+ 
+         face_embeds = torch.from_numpy(np.array([face_info["embedding"]]))
+         face_embeds = face_embeds.to(device, dtype=dtype)
+     except Exception as e:
+         print(f"Error processing face: {e}")
+         raise gr.Error(f"Failed to process face: {str(e)}")

      # Generate the image
      try:
-         image = pipe(
-             prompt=prompt,
-             negative_prompt=negative_prompt,
-             height=1024,
-             width=1024,
-             num_inference_steps=num_inference_steps,
-             guidance_scale=guidance_scale,
-             num_images_per_prompt=1,
-             generator=generator,
-             face_crop_image=crop_image,
-             face_insightface_embeds=face_embeds
-         ).images[0]
+         with torch.no_grad():
+             image = pipe(
+                 prompt=prompt,
+                 negative_prompt=negative_prompt,
+                 height=1024,
+                 width=1024,
+                 num_inference_steps=num_inference_steps,
+                 guidance_scale=guidance_scale,
+                 num_images_per_prompt=1,
+                 generator=generator,
+                 face_crop_image=crop_image,
+                 face_insightface_embeds=face_embeds
+             ).images[0]
      except Exception as e:
          print(f"Error during inference: {e}")
          raise gr.Error(f"Failed to generate image: {str(e)}")
@@ -233,13 +291,6 @@ footer {
  }
  """

- def load_description(fp):
-     if os.path.exists(fp):
-         with open(fp, 'r', encoding='utf-8') as f:
-             content = f.read()
-             return content
-     return ""
- 
  # Gradio Interface
  with gr.Blocks(theme="soft", css=css) as Kolors:
      gr.HTML(
@@ -309,16 +360,6 @@ with gr.Blocks(theme="soft", css=css) as Kolors:
      with gr.Column(elem_id="col-right"):
          result = gr.Image(label="Generated Portrait", show_label=True)
          seed_used = gr.Number(label="Seed Used", precision=0)
- 
-         # Add examples
-         gr.Examples(
-             examples=[
-                 ["A cinematic portrait, dramatic lighting, professional photography", None],
-                 ["An oil painting portrait in Renaissance style, classical art", None],
-                 ["A cyberpunk character portrait, neon lights, futuristic", None],
-             ],
-             inputs=[prompt, image],
-         )

      button.click(
          fn=infer,