aiqtech committed (verified)
Commit a0ed24e · 1 Parent(s): a70f7b1

Update app.py

Files changed (1): app.py (+198 -75)
app.py CHANGED
@@ -6,8 +6,8 @@ import insightface
 import gradio as gr
 import numpy as np
 import os
-from huggingface_hub import snapshot_download
-from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor
+from huggingface_hub import snapshot_download, login
+from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
 from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter_FaceID import StableDiffusionXLPipeline
 from kolors.models.modeling_chatglm import ChatGLMModel
 from kolors.models.tokenization_chatglm import ChatGLMTokenizer
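Note: the commit below makes the token optional ("Using public access only"), but if the checkpoints ever require authentication it can be worth verifying the token before kicking off the multi-gigabyte snapshot downloads. A minimal sanity-check sketch, not part of this commit (it assumes HF_TOKEN is set as a Space secret; HfApi.whoami is the standard huggingface_hub call for this):

import os
from huggingface_hub import HfApi

token = os.getenv("HF_TOKEN")  # assumed to be set as a Space secret
if token:
    # whoami() raises if the token is invalid or expired
    user = HfApi().whoami(token=token)
    print(f"Token OK, authenticated as: {user['name']}")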
@@ -18,48 +18,132 @@ from PIL import Image
 from insightface.app import FaceAnalysis
 from insightface.data import get_image as ins_get_image
 
-device = "cuda"
-ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
-ckpt_dir_faceid = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus")
-
-text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
-tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
-vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
-scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
-unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
-clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ckpt_dir_faceid}/clip-vit-large-patch14-336', ignore_mismatched_sizes=True)
-clip_image_encoder.to(device)
-clip_image_processor = CLIPImageProcessor(size = 336, crop_size = 336)
+# Log in with a Hugging Face token
+HF_TOKEN = os.getenv("HF_TOKEN")
+if HF_TOKEN:
+    login(token=HF_TOKEN)
+    print("Successfully logged in to Hugging Face Hub")
+else:
+    print("Warning: HF_TOKEN not found. Using public access only.")
+
+# Check whether a GPU is available
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16 if device == "cuda" else torch.float32
+
+# Download the models (using the token)
+try:
+    ckpt_dir = snapshot_download(
+        repo_id="Kwai-Kolors/Kolors",
+        token=HF_TOKEN,
+        local_dir_use_symlinks=False
+    )
+    ckpt_dir_faceid = snapshot_download(
+        repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
+        token=HF_TOKEN,
+        local_dir_use_symlinks=False
+    )
+except Exception as e:
+    print(f"Error downloading models: {e}")
+    raise
+
+# Load the models, with error handling
+try:
+    text_encoder = ChatGLMModel.from_pretrained(
+        f'{ckpt_dir}/text_encoder',
+        torch_dtype=dtype,
+        token=HF_TOKEN,
+        trust_remote_code=True
+    )
+    if device == "cuda":
+        text_encoder = text_encoder.half().to(device)
+
+    tokenizer = ChatGLMTokenizer.from_pretrained(
+        f'{ckpt_dir}/text_encoder',
+        token=HF_TOKEN,
+        trust_remote_code=True
+    )
+
+    vae = AutoencoderKL.from_pretrained(
+        f"{ckpt_dir}/vae",
+        revision=None,
+        torch_dtype=dtype,
+        token=HF_TOKEN
+    )
+    if device == "cuda":
+        vae = vae.half().to(device)
+
+    scheduler = EulerDiscreteScheduler.from_pretrained(
+        f"{ckpt_dir}/scheduler",
+        token=HF_TOKEN
+    )
+
+    unet = UNet2DConditionModel.from_pretrained(
+        f"{ckpt_dir}/unet",
+        revision=None,
+        torch_dtype=dtype,
+        token=HF_TOKEN
+    )
+    if device == "cuda":
+        unet = unet.half().to(device)
+
+    # Load the CLIP model, with a fallback
+    try:
+        clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+            f'{ckpt_dir_faceid}/clip-vit-large-patch14-336',
+            torch_dtype=dtype,
+            ignore_mismatched_sizes=True,
+            token=HF_TOKEN
+        )
+    except Exception as e:
+        print(f"Loading CLIP from local failed: {e}, trying alternative source...")
+        clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+            'openai/clip-vit-large-patch14-336',
+            torch_dtype=dtype,
+            ignore_mismatched_sizes=True,
+            token=HF_TOKEN
+        )
+
+    clip_image_encoder.to(device)
+    clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)
+
+except Exception as e:
+    print(f"Error loading models: {e}")
+    raise
 
+# Create the pipeline
 pipe = StableDiffusionXLPipeline(
-    vae = vae,
-    text_encoder = text_encoder,
-    tokenizer = tokenizer,
-    unet = unet,
-    scheduler = scheduler,
-    face_clip_encoder = clip_image_encoder,
-    face_clip_processor = clip_image_processor,
-    force_zeros_for_empty_prompt = False,
+    vae=vae,
+    text_encoder=text_encoder,
+    tokenizer=tokenizer,
+    unet=unet,
+    scheduler=scheduler,
+    face_clip_encoder=clip_image_encoder,
+    face_clip_processor=clip_image_processor,
+    force_zeros_for_empty_prompt=False,
 )
 
 class FaceInfoGenerator():
-    def __init__(self, root_dir = "./.insightface/"):
-        self.app = FaceAnalysis(name = 'antelopev2', root = root_dir, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
-        self.app.prepare(ctx_id = 0, det_size = (640, 640))
+    def __init__(self, root_dir="./.insightface/"):
+        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if device == "cuda" else ['CPUExecutionProvider']
+        self.app = FaceAnalysis(name='antelopev2', root=root_dir, providers=providers)
+        self.app.prepare(ctx_id=0, det_size=(640, 640))
 
     def get_faceinfo_one_img(self, face_image):
+        if face_image is None:
+            return None
+
         face_info = self.app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
 
         if len(face_info) == 0:
-            face_info = None
+            return None
         else:
-            face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
+            # only use the maximum face
+            face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]
        return face_info
 
 def face_bbox_to_square(bbox):
     ## l, t, r, b to square l, t, r, b
-    l,t,r,b = bbox
+    l, t, r, b = bbox
     cent_x = (l + r) / 2
     cent_y = (t + b) / 2
     w, h = r - l, b - t
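The hunk's trailing context cuts face_bbox_to_square off after the width/height computation. For readers following along, a plausible completion is sketched below; everything past the `w, h` line is an assumption, not code from this commit:

def face_bbox_to_square_sketch(bbox):
    # Assumed completion: expand the detector bbox to a square around its
    # center, so the later 336x336 resize does not distort the face.
    l, t, r, b = bbox
    cent_x, cent_y = (l + r) / 2, (t + b) / 2
    half = max(r - l, b - t) / 2  # square side = longer edge of the bbox
    return [cent_x - half, cent_y - half, cent_x + half, cent_y + half]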
@@ -78,63 +162,86 @@ face_info_generator = FaceInfoGenerator()
 @spaces.GPU
 def infer(prompt,
-          image = None,
-          negative_prompt = "low quality",
-          seed = 66,
-          randomize_seed = False,
-          guidance_scale = 5.0,
-          num_inference_steps = 50
+          image=None,
+          negative_prompt="low quality, blurry, distorted",
+          seed=66,
+          randomize_seed=False,
+          guidance_scale=5.0,
+          num_inference_steps=50
 ):
+    if image is None:
+        return None, 0
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator().manual_seed(seed)
+
+    generator = torch.Generator(device=device).manual_seed(seed)
+
     global pipe
     pipe = pipe.to(device)
-    pipe.load_ip_adapter_faceid_plus(f'{ckpt_dir_faceid}/ipa-faceid-plus.bin', device = device)
-    scale = 0.8
-    pipe.set_face_fidelity_scale(scale)
+
+    # Load the IP Adapter
+    try:
+        pipe.load_ip_adapter_faceid_plus(f'{ckpt_dir_faceid}/ipa-faceid-plus.bin', device=device)
+        scale = 0.8
+        pipe.set_face_fidelity_scale(scale)
+    except Exception as e:
+        print(f"Error loading IP adapter: {e}")
+        raise
 
+    # Extract face info
     face_info = face_info_generator.get_faceinfo_one_img(image)
+    if face_info is None:
+        raise gr.Error("No face detected in the image. Please provide an image with a clear face.")
+
     face_bbox_square = face_bbox_to_square(face_info["bbox"])
     crop_image = image.crop(face_bbox_square)
     crop_image = crop_image.resize((336, 336))
     crop_image = [crop_image]
+
     face_embeds = torch.from_numpy(np.array([face_info["embedding"]]))
-    face_embeds = face_embeds.to(device, dtype = torch.float16)
-
-    image = pipe(
-        prompt = prompt,
-        negative_prompt = negative_prompt,
-        height = 1024,
-        width = 1024,
-        num_inference_steps= num_inference_steps,
-        guidance_scale = guidance_scale,
-        num_images_per_prompt = 1,
-        generator = generator,
-        face_crop_image = crop_image,
-        face_insightface_embeds = face_embeds
-    ).images[0]
-
-    return image, seed
+    face_embeds = face_embeds.to(device, dtype=dtype)
 
+    # Generate the image
+    try:
+        image = pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=1024,
+            width=1024,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            num_images_per_prompt=1,
+            generator=generator,
+            face_crop_image=crop_image,
+            face_insightface_embeds=face_embeds
+        ).images[0]
+    except Exception as e:
+        print(f"Error during inference: {e}")
+        raise gr.Error(f"Failed to generate image: {str(e)}")
 
+    return image, seed
 
 css = """
 footer {
     visibility: hidden;
 }
+.container {
+    max-width: 1200px;
+    margin: 0 auto;
+    padding: 20px;
+}
 """
 
-
-
 def load_description(fp):
-    with open(fp, 'r', encoding='utf-8') as f:
-        content = f.read()
-    return content
+    if os.path.exists(fp):
+        with open(fp, 'r', encoding='utf-8') as f:
+            content = f.read()
+        return content
+    return ""
 
+# Gradio Interface
 with gr.Blocks(theme="soft", css=css) as Kolors:
-
     gr.HTML(
         """
         <div class='container' style='display:flex; justify-content:center; gap:12px;'>
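Since every failure path in the reworked infer() now returns early or raises gr.Error, the function can be smoke-tested without the UI. A sketch, assuming the models above loaded on a CUDA machine and a hypothetical local photo "face.jpg" (both are assumptions):

from PIL import Image

test_image = Image.open("face.jpg").convert("RGB")  # hypothetical test asset
out, used_seed = infer(
    "A professional portrait photo, high quality",
    image=test_image,
    randomize_seed=True,
)
if out is not None:
    out.save("portrait_out.png")
    print(f"Generated with seed {used_seed}")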
@@ -146,7 +253,9 @@ with gr.Blocks(theme="soft", css=css) as Kolors:
             <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">
         </a>
     </div>
-    """
+    <h1 style="text-align: center;">Kolors Face ID - AI Portrait Generator</h1>
+    <p style="text-align: center;">Upload a face photo and create stunning AI portraits with text prompts!</p>
+    """
     )
 
     with gr.Row():
@@ -154,15 +263,21 @@ with gr.Blocks(theme="soft", css=css) as Kolors:
         with gr.Row():
             prompt = gr.Textbox(
                 label="Prompt",
-                placeholder="Enter your prompt",
-                lines=2
+                placeholder="e.g., A professional portrait in business attire, studio lighting",
+                lines=3,
+                value="A professional portrait photo, high quality, detailed face"
             )
         with gr.Row():
-            image = gr.Image(label="Image", type="pil")
+            image = gr.Image(
+                label="Upload Face Image",
+                type="pil",
+                height=400
+            )
         with gr.Accordion("Advanced Settings", open=False):
             negative_prompt = gr.Textbox(
                 label="Negative prompt",
-                placeholder="Enter a negative prompt",
+                placeholder="Things to avoid in the image",
+                value="low quality, blurry, distorted, disfigured",
                 visible=True,
             )
             seed = gr.Slider(
@@ -170,7 +285,7 @@ with gr.Blocks(theme="soft", css=css) as Kolors:
                 minimum=0,
                 maximum=MAX_SEED,
                 step=1,
-                value=0,
+                value=66,
             )
             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
         with gr.Row():
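MAX_SEED is referenced by this slider and inside infer(), but its definition sits outside the visible hunks. In Gradio/diffusers Space templates it is conventionally the 32-bit integer maximum; a sketch of the assumed definition:

import numpy as np

# Assumed definition; the actual line lives outside this diff's hunks.
MAX_SEED = np.iinfo(np.int32).max  # 2147483647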
@@ -189,19 +304,27 @@ with gr.Blocks(theme="soft", css=css) as Kolors:
                 value=25,
             )
         with gr.Row():
-            button = gr.Button("Run", elem_id="button")
+            button = gr.Button("🎨 Generate Portrait", elem_id="button", variant="primary", scale=1)
 
     with gr.Column(elem_id="col-right"):
-        result = gr.Image(label="Result", show_label=False)
-        seed_used = gr.Number(label="Seed Used")
+        result = gr.Image(label="Generated Portrait", show_label=True)
+        seed_used = gr.Number(label="Seed Used", precision=0)
 
+    # Add examples
+    gr.Examples(
+        examples=[
+            ["A cinematic portrait, dramatic lighting, professional photography", None],
+            ["An oil painting portrait in Renaissance style, classical art", None],
+            ["A cyberpunk character portrait, neon lights, futuristic", None],
+        ],
+        inputs=[prompt, image],
+    )
 
     button.click(
-        fn = infer,
-        inputs = [prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps],
-        outputs = [result, seed_used]
+        fn=infer,
+        inputs=[prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps],
+        outputs=[result, seed_used]
     )
 
-
-Kolors.queue().launch(debug=True)
+if __name__ == "__main__":
+    Kolors.queue(max_size=10).launch(debug=True, share=False)
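One quirk in the new gr.Examples block: each row pairs a prompt with None for the image input, so clicking an example fills the prompt but the user must still upload a face. A variant sketch, not in this commit, in which rows also carry sample images; the asset paths are hypothetical placeholders:

gr.Examples(
    examples=[
        # hypothetical bundled sample photos
        ["A cinematic portrait, dramatic lighting", "examples/face_a.jpg"],
        ["An oil painting portrait in Renaissance style", "examples/face_b.jpg"],
    ],
    inputs=[prompt, image],
    cache_examples=False,  # skip GPU inference at startup
)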