PengWeixuanSZU committed
Commit 6e69110 · verified · 1 Parent(s): 7277a9d

delete and rewrite

Files changed (1): app.py (+5, -15)
app.py CHANGED
@@ -56,9 +56,9 @@ random_seed = 42
 video_length = 201
 W = 1024
 H = W
+device = "cuda" #if torch.cuda.is_available() else "cpu"
 
 def get_pipe_image_and_video_predictor():
-    device="cpu"
     vae = AutoencoderKLWan.from_pretrained("./model/vae", torch_dtype=torch.float16)
     transformer = Transformer3DModel.from_pretrained("./model/transformer", torch_dtype=torch.float16)
     scheduler = UniPCMultistepScheduler.from_pretrained("./model/scheduler")
@@ -177,7 +177,7 @@ def preprocess_for_removal(images, masks):
         out_masks.append(msk_resized)
     arr_images = np.stack(out_images)
     arr_masks = np.stack(out_masks)
-    return torch.from_numpy(arr_images).half(), torch.from_numpy(arr_masks).half()
+    return torch.from_numpy(arr_images).half().to(device), torch.from_numpy(arr_masks).half().to(device)
 
 @spaces.GPU(duration=300)
 def inference_and_return_video(dilation_iterations, num_inference_steps, video_state=None):
@@ -189,13 +189,8 @@ def inference_and_return_video(dilation_iterations, num_inference_steps, video_state=None):
     images = np.array(images)
     masks = np.array(masks)
     img_tensor, mask_tensor = preprocess_for_removal(images, masks)
-    print("mask_tensor shape:",mask_tensor.shape)
-    img_tensor=img_tensor.to("cuda")
-    mask_tensor=mask_tensor.to("cuda")
     mask_tensor = mask_tensor[:,:,:,:1]
 
-    pipe.to("cuda")
-
     if mask_tensor.shape[1] < mask_tensor.shape[2]:
         height = 480
         width = 832
@@ -211,7 +206,7 @@ def inference_and_return_video(dilation_iterations, num_inference_steps, video_state=None):
         height=height,
         width=width,
         num_inference_steps=int(num_inference_steps),
-        generator=torch.Generator(device="cuda").manual_seed(random_seed),
+        generator=torch.Generator(device=device).manual_seed(random_seed),
         iterations=int(dilation_iterations)
     ).frames[0]
 
@@ -223,7 +218,7 @@ def inference_and_return_video(dilation_iterations, num_inference_steps, video_state=None):
     clip.write_videofile(video_file, codec='libx264', audio=False, verbose=False, logger=None)
     return video_file
 
-@spaces.GPU(duration=150)
+@spaces.GPU(duration=100)
 def track_video(n_frames, video_state):
 
     input_points = video_state["input_points"]
@@ -247,12 +242,7 @@ def track_video(n_frames, video_state):
     images = [cv2.resize(img, (W_, H_)) for img in images]
     video_state["origin_images"] = images
     images = np.array(images)
-
-    sam2_checkpoint = "./SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
-    config = "sam2_hiera_l.yaml"
-    video_predictor = build_sam2_video_predictor(config, sam2_checkpoint, device="cuda")
-
-    inference_state = video_predictor.init_state(images=images/255, device="cuda")
+    inference_state = video_predictor.init_state(images=images/255, device=device)
    video_state["inference_state"] = inference_state
 
     if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
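The change reads as one pattern: a single module-level device constant replaces the scattered "cuda" literals and per-call .to("cuda") lines, tensors leave preprocess_for_removal already on the device, and CUDA work stays inside @spaces.GPU-decorated functions, which is where ZeroGPU Spaces actually attach a GPU. A minimal sketch of that pattern follows; run_removal and the pipe(...) call signature are illustrative stand-ins, while device, preprocess_for_removal, the generator, and the decorator mirror the diff.

import numpy as np
import torch
import spaces  # Hugging Face ZeroGPU helper, already used by app.py

# One module-level constant instead of scattered "cuda" strings.
device = "cuda"  # ZeroGPU exposes CUDA only inside @spaces.GPU calls

def preprocess_for_removal(images: np.ndarray, masks: np.ndarray):
    # Tensors leave the helper half-precision and already on `device`,
    # so callers no longer need their own .to("cuda") lines.
    img = torch.from_numpy(images).half().to(device)
    msk = torch.from_numpy(masks).half().to(device)
    return img, msk

@spaces.GPU(duration=300)  # the GPU is attached only while this runs
def run_removal(pipe, images, masks, num_inference_steps=30, seed=42):
    img_t, msk_t = preprocess_for_removal(images, masks)
    # Seed the generator on the same device the pipeline runs on.
    gen = torch.Generator(device=device).manual_seed(seed)
    return pipe(img_t, msk_t,
                num_inference_steps=int(num_inference_steps),
                generator=gen)

The commit also drops the explicit pipe.to("cuda") from the inference path; presumably the pipeline is moved to the device once at startup, though the diff does not show where.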
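Similarly, track_video no longer builds a SAM2 predictor on every call: the construction lines are deleted and only video_predictor.init_state(...) remains, which implies the predictor is now created once and shared (plausibly inside get_pipe_image_and_video_predictor, though this diff does not show where it moved). A sketch under that assumption, with the checkpoint and config paths taken verbatim from the deleted lines and init_state(images=..., device=...) matching the call app.py already makes:

import numpy as np
from sam2.build_sam import build_sam2_video_predictor

device = "cuda"

# Built once at startup instead of once per track_video call,
# so the checkpoint is not re-read on every request.
sam2_checkpoint = "./SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
config = "sam2_hiera_l.yaml"
video_predictor = build_sam2_video_predictor(config, sam2_checkpoint, device=device)

def track_video(images: np.ndarray, video_state: dict) -> dict:
    # Per-call work shrinks to building per-video inference state;
    # the heavyweight predictor itself is shared across requests.
    inference_state = video_predictor.init_state(images=images / 255, device=device)
    video_state["inference_state"] = inference_state
    return video_state

Reusing one predictor also likely explains the reduced @spaces.GPU(duration=100) budget for track_video: the per-call model load it previously paid for is gone.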