aiqtech committed on
Commit c1d518d · verified · 1 Parent(s): fe7b6d8

Update app.py

Files changed (1)
  1. app.py +138 -103
app.py CHANGED
@@ -17,66 +17,70 @@ from diffusers import EulerDiscreteScheduler
 from PIL import Image
 from insightface.app import FaceAnalysis
 
-# Login with HF token
+# ---------------------------
+# Runtime / device settings
+# ---------------------------
 HF_TOKEN = os.getenv("HF_TOKEN")
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
+
 if HF_TOKEN:
     login(token=HF_TOKEN)
     print("Successfully logged in to Hugging Face Hub")
 
-# Download models
 print("Downloading models...")
 ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors", token=HF_TOKEN)
 ckpt_dir_faceid = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus", token=HF_TOKEN)
 
 print("Loading models on CPU first...")
 
-# Fix for ChatGLMTokenizer - monkey patch the _pad method
+# ---------------------------
+# ChatGLM tokenizer pad fix
+# ---------------------------
 original_chatglm_pad = ChatGLMTokenizer._pad if hasattr(ChatGLMTokenizer, '_pad') else None
-
 def fixed_pad(self, *args, **kwargs):
-    # Remove the unexpected 'padding_side' argument if present
     kwargs.pop('padding_side', None)
     if original_chatglm_pad:
         return original_chatglm_pad(self, *args, **kwargs)
     else:
         return super(ChatGLMTokenizer, self)._pad(*args, **kwargs)
-
 ChatGLMTokenizer._pad = fixed_pad
 
-# Load models
+# ---------------------------
+# Load Kolors components
+# NOTE: dtype is fp16 on CUDA, fp32 on CPU to avoid NaNs on CPU
+# ---------------------------
 text_encoder = ChatGLMModel.from_pretrained(
-    f'{ckpt_dir}/text_encoder',
-    torch_dtype=torch.float16,
+    f"{ckpt_dir}/text_encoder",
+    torch_dtype=DTYPE,
     trust_remote_code=True
 )
-
 tokenizer = ChatGLMTokenizer.from_pretrained(
-    f'{ckpt_dir}/text_encoder',
+    f"{ckpt_dir}/text_encoder",
     trust_remote_code=True
 )
-
 vae = AutoencoderKL.from_pretrained(
     f"{ckpt_dir}/vae",
-    torch_dtype=torch.float16
+    torch_dtype=DTYPE
 )
-
 scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
-
 unet = UNet2DConditionModel.from_pretrained(
     f"{ckpt_dir}/unet",
-    torch_dtype=torch.float16
+    torch_dtype=DTYPE
 )
 
-# Load CLIP
+# CLIP image encoder + processor
 clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-    'openai/clip-vit-large-patch14-336',
-    torch_dtype=torch.float16,
+    "openai/clip-vit-large-patch14-336",
+    torch_dtype=DTYPE,
     use_safetensors=True
 )
 
-clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)
-
-# Create pipeline
+# Prefer from_pretrained for config parity
+clip_image_processor = CLIPImageProcessor.from_pretrained(
+    "openai/clip-vit-large-patch14-336"
+)
+
+# Create pipeline (initially on CPU to be safe with memory)
 pipe = StableDiffusionXLPipeline(
     vae=vae,
     text_encoder=text_encoder,
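The `DEVICE`/`DTYPE` gating introduced in this hunk is the standard torch idiom for running fp16 only where CUDA exists. A minimal standalone sketch of just that idiom (assumes only torch is installed; no Kolors checkpoints needed):

import torch

# Same gating as the diff: fp16 only when CUDA is present, fp32 on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

x = torch.randn(2, 3, device=DEVICE, dtype=DTYPE)
print(x.device, x.dtype)  # cuda:0 torch.float16 on a GPU box, cpu torch.float32 otherwise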
@@ -90,25 +94,39 @@ pipe = StableDiffusionXLPipeline(
 
 print("Models loaded successfully!")
 
-class FaceInfoGenerator():
-    def __init__(self, root_dir="./.insightface/"):
+# ---------------------------
+# InsightFace helper (CPU by default; GPU if available)
+# ---------------------------
+class FaceInfoGenerator:
+    def __init__(self, root_dir: str = "./.insightface/"):
+        providers = ["CPUExecutionProvider"]
+        # Try to prefer CUDA provider if available in runtime
+        try:
+            import onnxruntime as ort
+            if "CUDAExecutionProvider" in ort.get_available_providers():
+                providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+        except Exception:
+            pass
+
         self.app = FaceAnalysis(
-            name='antelopev2',
+            name="antelopev2",
             root=root_dir,
-            providers=['CPUExecutionProvider']
+            providers=providers
         )
         self.app.prepare(ctx_id=0, det_size=(640, 640))
 
-    def get_faceinfo_one_img(self, face_image):
+    def get_faceinfo_one_img(self, face_image: Image.Image):
         if face_image is None:
             return None
-
+        # PIL RGB -> OpenCV BGR
         face_info = self.app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
-
         if len(face_info) == 0:
             return None
-        else:
-            face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]
+        # Largest face
+        face_info = sorted(
+            face_info,
+            key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1])
+        )[-1]
         return face_info
 
 def face_bbox_to_square(bbox):
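The provider fallback added to `FaceInfoGenerator.__init__` can be exercised in isolation; `onnxruntime.get_available_providers()` is a real ORT call, and the rest is the same logic as the hunk:

import onnxruntime as ort

# Prefer the CUDA execution provider when the runtime exposes it, else stay on CPU.
providers = ["CPUExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers():
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
print(providers)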
@@ -116,102 +134,116 @@ def face_bbox_to_square(bbox):
     cent_x = (l + r) / 2
     cent_y = (t + b) / 2
     w, h = r - l, b - t
-    r = max(w, h) / 2
-
-    l0 = cent_x - r
-    r0 = cent_x + r
-    t0 = cent_y - r
-    b0 = cent_y + r
-
-    return [l0, t0, r0, b0]
+    rad = max(w, h) / 2
+    return [cent_x - rad, cent_y - rad, cent_x + rad, cent_y + rad]
 
 MAX_SEED = np.iinfo(np.int32).max
 face_info_generator = FaceInfoGenerator()
 
+# ---------------------------
+# Inference function
+# - Uses fp16 autocast only on CUDA
+# - Ensures dtype/device consistency to avoid NaNs
+# ---------------------------
 @spaces.GPU(duration=120)
-def infer(prompt,
-          image=None,
-          negative_prompt="low quality, blurry, distorted",
-          seed=66,
-          randomize_seed=False,
-          guidance_scale=5.0,
-          num_inference_steps=50
-          ):
+def infer(
+    prompt,
+    image=None,
+    negative_prompt="low quality, blurry, distorted",
+    seed=66,
+    randomize_seed=False,
+    guidance_scale=5.0,
+    num_inference_steps=25
+):
     if image is None:
         gr.Warning("Please upload an image with a face.")
         return None, 0
-
-    # Face detection on CPU
+
+    # Detect face (InsightFace)
     face_info = face_info_generator.get_faceinfo_one_img(image)
     if face_info is None:
         raise gr.Error("No face detected. Please upload an image with a clear face.")
-
+
+    # Prepare crop for IP-Adapter FaceID
     face_bbox_square = face_bbox_to_square(face_info["bbox"])
-    crop_image = image.crop(face_bbox_square)
-    crop_image = crop_image.resize((336, 336))
-    crop_image = [crop_image]
+    crop_image = image.crop(face_bbox_square).resize((336, 336))
+    crop_image = [crop_image]  # pipeline expects list
     face_embeds = torch.from_numpy(np.array([face_info["embedding"]]))
-
-    # Move to GPU
-    device = torch.device("cuda")
+
+    # Device move
+    device = torch.device(DEVICE)
     global pipe
-
-    # Move models to GPU
-    pipe.vae = pipe.vae.to(device)
-    pipe.text_encoder = pipe.text_encoder.to(device)
-    pipe.unet = pipe.unet.to(device)
-    pipe.face_clip_encoder = pipe.face_clip_encoder.to(device)
-
-    face_embeds = face_embeds.to(device, dtype=torch.float16)
-
-    # Load IP adapter
-    pipe.load_ip_adapter_faceid_plus(f'{ckpt_dir_faceid}/ipa-faceid-plus.bin', device=device)
+
+    # Move modules to device with proper dtype
+    pipe.vae = pipe.vae.to(device, dtype=DTYPE)
+    pipe.text_encoder = pipe.text_encoder.to(device, dtype=DTYPE)
+    pipe.unet = pipe.unet.to(device, dtype=DTYPE)
+    pipe.face_clip_encoder = pipe.face_clip_encoder.to(device, dtype=DTYPE)
+
+    face_embeds = face_embeds.to(device, dtype=DTYPE)
+
+    # Load IP-Adapter weights (FaceID Plus)
+    pipe.load_ip_adapter_faceid_plus(f"{ckpt_dir_faceid}/ipa-faceid-plus.bin", device=device)
     pipe.set_face_fidelity_scale(0.8)
-
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-
     generator = torch.Generator(device=device).manual_seed(seed)
-
-    # Generate image
+
+    # Inference: autocast only on CUDA
     with torch.no_grad():
-        with torch.autocast(device_type="cuda", dtype=torch.float16):
-            result = pipe(
+        if DEVICE == "cuda":
+            with torch.autocast(device_type="cuda", dtype=torch.float16):
+                images = pipe(
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    height=1024,
+                    width=1024,
+                    num_inference_steps=int(num_inference_steps),
+                    guidance_scale=float(guidance_scale),
+                    num_images_per_prompt=1,
+                    generator=generator,
+                    face_crop_image=crop_image,
+                    face_insightface_embeds=face_embeds
+                ).images
+        else:
+            images = pipe(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
                 height=1024,
                 width=1024,
-                num_inference_steps=num_inference_steps,
-                guidance_scale=guidance_scale,
+                num_inference_steps=int(num_inference_steps),
+                guidance_scale=float(guidance_scale),
                 num_images_per_prompt=1,
                 generator=generator,
                 face_crop_image=crop_image,
                 face_insightface_embeds=face_embeds
-            ).images[0]
-
-    # Move models back to CPU to free GPU memory
-    pipe.vae = pipe.vae.to("cpu")
-    pipe.text_encoder = pipe.text_encoder.to("cpu")
-    pipe.unet = pipe.unet.to("cpu")
-    pipe.face_clip_encoder = pipe.face_clip_encoder.to("cpu")
-    torch.cuda.empty_cache()
-
+            ).images
+
+    result = images[0]
+
+    # Offload back to CPU to free GPU memory
+    try:
+        pipe.vae = pipe.vae.to("cpu")
+        pipe.text_encoder = pipe.text_encoder.to("cpu")
+        pipe.unet = pipe.unet.to("cpu")
+        pipe.face_clip_encoder = pipe.face_clip_encoder.to("cpu")
+        if DEVICE == "cuda":
+            torch.cuda.empty_cache()
+    except Exception:
+        pass
+
     return result, seed
 
+# ---------------------------
+# Gradio UI
+# ---------------------------
 css = """
-footer {
-    visibility: hidden;
-}
-#col-left, #col-right {
-    max-width: 640px;
-    margin: 0 auto;
-}
-.gr-button {
-    max-width: 100%;
-}
+footer { visibility: hidden; }
+#col-left, #col-right { max-width: 640px; margin: 0 auto; }
+.gr-button { max-width: 100%; }
 """
 
-# Gradio interface
 with gr.Blocks(theme="soft", css=css) as Kolors:
     gr.HTML(
         """
@@ -226,10 +258,13 @@ with gr.Blocks(theme="soft", css=css) as Kolors:
             <img src="https://img.shields.io/badge/Discord-OpenFree%20AI-purple?style=for-the-badge&logo=discord" alt="Discord">
         </a>
     </div>
+    <div style='margin-top:8px;font-size:12px;opacity:.7;'>
+        Device: {device}, DType: {dtype}
+    </div>
     </div>
-    """
+    """.format(device=DEVICE.upper(), dtype=str(DTYPE).replace("torch.", ""))
     )
-
+
     with gr.Row():
         with gr.Column(elem_id="col-left"):
             prompt = gr.Textbox(
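One caveat on the `.format()` call added to the banner: `str.format` treats every `{` as a placeholder, so any literal braces elsewhere in the template (an inline `<style>` block, for instance) would need doubling. A minimal illustration with assumed values for `DEVICE`/`DTYPE`:

DEVICE, DTYPE = "cuda", "torch.float16"  # assumed values for illustration
banner = "<div>Device: {device}, DType: {dtype}</div>".format(
    device=DEVICE.upper(),
    dtype=DTYPE.replace("torch.", "")
)
print(banner)  # <div>Device: CUDA, DType: float16</div>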
@@ -239,7 +274,7 @@ with gr.Blocks(theme="soft", css=css) as Kolors:
                 value="A professional portrait photo, high quality"
             )
             image = gr.Image(label="Upload Face Image", type="pil", height=300)
-
+
             with gr.Accordion("Advanced Settings", open=False):
                 negative_prompt = gr.Textbox(
                     label="Negative prompt",
@@ -247,15 +282,15 @@ with gr.Blocks(theme="soft", css=css) as Kolors:
                 )
                 seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=66)
                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-                guidance_scale = gr.Slider(label="Guidance", minimum=1, maximum=10, step=0.5, value=5)
+                guidance_scale = gr.Slider(label="Guidance", minimum=1, maximum=10, step=0.5, value=5.0)
                 num_inference_steps = gr.Slider(label="Steps", minimum=10, maximum=50, step=5, value=25)
-
+
             button = gr.Button("🎨 Generate Portrait", variant="primary")
-
+
         with gr.Column(elem_id="col-right"):
             result = gr.Image(label="Generated Portrait")
             seed_used = gr.Number(label="Seed Used", precision=0)
-
+
     button.click(
         fn=infer,
         inputs=[prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps],
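The `seed` slider and `randomize_seed` checkbox above feed the seeding logic in `infer()`: the run is reproducible because the seed actually used is returned to `seed_used` and can be fed back in. A standalone sketch of that round trip:

import random
import numpy as np
import torch

MAX_SEED = np.iinfo(np.int32).max  # 2147483647
seed, randomize_seed = 66, True
if randomize_seed:
    seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device="cpu").manual_seed(seed)
print(seed, torch.randn(2, generator=generator))  # same seed -> same noise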
@@ -263,4 +298,4 @@ with gr.Blocks(theme="soft", css=css) as Kolors:
     )
 
 if __name__ == "__main__":
-    Kolors.queue(max_size=20).launch(debug=True)
+    Kolors.queue(max_size=20).launch(debug=True)
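The GPU-offload pattern in `infer()` (move modules in, run under `no_grad`, move back, empty the cache) generalizes to any module. A minimal sketch using a bare `nn.Linear` rather than the Kolors pipeline:

import torch
import torch.nn as nn

def run_offloaded(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    module.to(device)
    try:
        with torch.no_grad():
            out = module(x.to(device))
    finally:
        # Mirror the diff: always return the weights to CPU and free cached blocks.
        module.to("cpu")
        if device == "cuda":
            torch.cuda.empty_cache()
    return out.cpu()

print(run_offloaded(nn.Linear(4, 2), torch.randn(1, 4)))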