SIGMitch commited on
Commit
3bcdecc
·
1 Parent(s): e5af8e8

run_safety_checker

Browse files
Files changed (1) hide show
  1. latent_consistency_img2img.py +4 -15
latent_consistency_img2img.py CHANGED
@@ -171,21 +171,8 @@ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
171
  return prompt_embeds
172
 
173
  def run_safety_checker(self, image, device, dtype):
174
- if self.safety_checker is None:
175
  has_nsfw_concept = None
176
- else:
177
- if torch.is_tensor(image):
178
- feature_extractor_input = self.image_processor.postprocess(
179
- image, output_type="pil"
180
- )
181
- else:
182
- feature_extractor_input = self.image_processor.numpy_to_pil(image)
183
- safety_checker_input = self.feature_extractor(
184
- feature_extractor_input, return_tensors="pt"
185
- ).to(device)
186
- image, has_nsfw_concept = self.safety_checker(
187
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
188
- )
189
  return image, has_nsfw_concept
190
 
191
  def prepare_latents(
@@ -425,7 +412,9 @@ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
425
  image = self.vae.decode(
426
  denoised / self.vae.config.scaling_factor, return_dict=False
427
  )[0]
428
- image, has_nsfw_concept = None
 
 
429
  else:
430
  image = denoised
431
  has_nsfw_concept = None
 
171
  return prompt_embeds
172
 
173
  def run_safety_checker(self, image, device, dtype):
 
174
  has_nsfw_concept = None
175
+
 
 
 
 
 
 
 
 
 
 
 
 
176
  return image, has_nsfw_concept
177
 
178
  def prepare_latents(
 
412
  image = self.vae.decode(
413
  denoised / self.vae.config.scaling_factor, return_dict=False
414
  )[0]
415
+ image, has_nsfw_concept = self.run_safety_checker(
416
+ image, device, prompt_embeds.dtype
417
+ )
418
  else:
419
  image = denoised
420
  has_nsfw_concept = None