yaron123 commited on
Commit
fe55e41
·
1 Parent(s): ffe6ef9
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -124,17 +124,24 @@ function custom(){
124
 
125
  # torch pipes
126
 
127
- image_pipe = FluxPipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to(device)
 
 
 
 
 
128
 
 
129
  image_pipe.enable_model_cpu_offload()
 
130
 
131
  video_pipe = CogVideoXImageToVideoPipeline.from_pretrained(
132
  "THUDM/CogVideoX-5b-I2V",
133
  torch_dtype=torch.bfloat16
134
  ).to(device)
135
-
136
  video_pipe.vae.enable_tiling()
137
  video_pipe.vae.enable_slicing()
 
138
 
139
  # functionality
140
 
@@ -155,7 +162,6 @@ def pipe_generate(img,p1,p2,time,title):
155
  guidance_scale=img_accu,
156
  num_images_per_prompt=1,
157
  num_inference_steps=image_steps,
158
- safety_checker=None,
159
  max_sequence_length=seq,
160
  generator=torch.Generator(device).manual_seed(int(str(random.random()).split(".")[1]))
161
  ).images[0]
@@ -178,7 +184,6 @@ def pipe_generate(img,p1,p2,time,title):
178
 
179
  return video_pipe(
180
  prompt=p1,
181
- safety_checker=None,
182
  negative_prompt=p2.replace("textual content, ",""),
183
  image=img,
184
  num_inference_steps=video_steps,
 
124
 
125
  # torch pipes
126
 
127
+ def disabled_safety_checker(images, clip_input):
128
+ if len(images.shape)==4:
129
+ num_images = images.shape[0]
130
+ return images, [False]*num_images
131
+ else:
132
+ return images, False
133
 
134
+ image_pipe = FluxPipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to(device)
135
  image_pipe.enable_model_cpu_offload()
136
+ image_pipe.safety_checker = disabled_safety_checker
137
 
138
  video_pipe = CogVideoXImageToVideoPipeline.from_pretrained(
139
  "THUDM/CogVideoX-5b-I2V",
140
  torch_dtype=torch.bfloat16
141
  ).to(device)
 
142
  video_pipe.vae.enable_tiling()
143
  video_pipe.vae.enable_slicing()
144
+ video_pipe.safety_checker = disabled_safety_checker
145
 
146
  # functionality
147
 
 
162
  guidance_scale=img_accu,
163
  num_images_per_prompt=1,
164
  num_inference_steps=image_steps,
 
165
  max_sequence_length=seq,
166
  generator=torch.Generator(device).manual_seed(int(str(random.random()).split(".")[1]))
167
  ).images[0]
 
184
 
185
  return video_pipe(
186
  prompt=p1,
 
187
  negative_prompt=p2.replace("textual content, ",""),
188
  image=img,
189
  num_inference_steps=video_steps,