Update app.py
app.py CHANGED
```diff
@@ -56,7 +56,7 @@ random_seed = 42
 video_length = 201
 W = 1024
 H = W
-device = "
+device = "cpu"
 
 def get_pipe_image_and_video_predictor():
     vae = AutoencoderKLWan.from_pretrained("./model/vae", torch_dtype=torch.float16)
@@ -177,7 +177,7 @@ def preprocess_for_removal(images, masks):
         out_masks.append(msk_resized)
     arr_images = np.stack(out_images)
     arr_masks = np.stack(out_masks)
-    return torch.from_numpy(arr_images).half(), torch.from_numpy(arr_masks).half()
+    return torch.from_numpy(arr_images).half().to(device), torch.from_numpy(arr_masks).half().to(device)
 
 @spaces.GPU(duration=300)
 def inference_and_return_video(dilation_iterations, num_inference_steps, video_state=None):
@@ -189,8 +189,7 @@ def inference_and_return_video(dilation_iterations, num_inference_steps, video_s
     images = np.array(images)
     masks = np.array(masks)
     img_tensor, mask_tensor = preprocess_for_removal(images, masks)
-
-    mask_tensor = mask_tensor[:,:,:,:1].to(device)
+    mask_tensor = mask_tensor[:,:,:,:1]
 
     if mask_tensor.shape[1] < mask_tensor.shape[2]:
         height = 480
@@ -207,7 +206,7 @@ def inference_and_return_video(dilation_iterations, num_inference_steps, video_s
         height=height,
         width=width,
         num_inference_steps=int(num_inference_steps),
-        generator=torch.Generator(device=
+        generator=torch.Generator(device="cuda").manual_seed(random_seed),
         iterations=int(dilation_iterations)
     ).frames[0]
 
@@ -403,4 +402,4 @@ with gr.Blocks() as demo:
     clear_btn.click(clear_clicks, inputs=video_state, outputs=image_output)
     track_btn.click(track_video, inputs=[n_frames_slider, video_state], outputs=video_output)
 
-demo.launch()
+demo.launch()
```
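The thread running through these edits is the ZeroGPU execution model: on a "Running on Zero" Space, module-level code executes on CPU-only hardware, and a CUDA device is attached only while a function decorated with `@spaces.GPU` is running. Hence `device = "cpu"` at module scope, CPU tensors returned from `preprocess_for_removal`, and `"cuda"` referenced directly inside `inference_and_return_video`. Below is a minimal sketch of that pattern, not the app's actual code: `preprocess` and `run_on_gpu` are hypothetical stand-ins for `preprocess_for_removal` and `inference_and_return_video`, and the masked-noise composite stands in for the real removal pipeline call.

```python
import numpy as np
import spaces  # Hugging Face ZeroGPU helper, as used in app.py
import torch

# No CUDA device exists at import time on ZeroGPU hardware,
# so module-level state must stay on the CPU.
device = "cpu"
random_seed = 42

def preprocess(images: np.ndarray, masks: np.ndarray):
    # CPU-side preprocessing; .to(device) is a no-op for "cpu" but
    # makes the intended placement explicit, as in the diff above.
    return (
        torch.from_numpy(images).half().to(device),
        torch.from_numpy(masks).half().to(device),
    )

@spaces.GPU(duration=300)  # a GPU is attached only while this function runs
def run_on_gpu(images: np.ndarray, masks: np.ndarray) -> torch.Tensor:
    img_tensor, mask_tensor = preprocess(images, masks)
    # "cuda" may be referenced here (and only here), e.g. for a seeded
    # generator that keeps sampling reproducible across calls.
    generator = torch.Generator(device="cuda").manual_seed(random_seed)
    img_tensor = img_tensor.to("cuda")
    # Binarize the mask and keep a single channel, mirroring the
    # mask_tensor[:,:,:,:1] slice in the diff.
    mask_tensor = (mask_tensor[:, :, :, :1] > 0).half().to("cuda")
    # Stand-in for the real removal pipeline: fill masked-out regions
    # with seeded noise so the generator's effect is visible.
    noise = torch.rand(img_tensor.shape, device="cuda", generator=generator)
    return img_tensor * mask_tensor + noise.half() * (1 - mask_tensor)
```

Seeding the generator from the module-level `random_seed = 42` (visible in the first hunk's context) keeps inference deterministic across runs, even though the generator itself has to live on the GPU and therefore cannot be created until a `@spaces.GPU` function is executing.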