DHEIVER committed
Commit ded1ee1 · verified · 1 Parent(s): ceeee63

Update app.py

Files changed (1)
  1. app.py +37 -18
app.py CHANGED
@@ -20,20 +20,16 @@ def download_sam_model():
     print("Download complete!")
     return checkpoint_path

-def process_video_sam(video_path):
-    # Download model if needed
+def process_video_sam(video_path, progress=gr.Progress()):
     checkpoint_path = download_sam_model()

-    # Initialize SAM
     DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    MODEL_TYPE = "vit_h"
-
-    sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint_path)
+    sam = sam_model_registry["vit_h"](checkpoint=checkpoint_path)
     sam.to(device=DEVICE)
     predictor = SamPredictor(sam)

-    # Process video
     cap = cv2.VideoCapture(video_path)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
     fps = cap.get(cv2.CAP_PROP_FPS)
     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
@@ -44,32 +40,55 @@ def process_video_sam(video_path):
                           fps,
                           (width, height))

+    frame_count = 0
     while cap.isOpened():
         ret, frame = cap.read()
         if not ret:
             break

         predictor.set_image(frame)
-        masks = predictor.generate()
+        # Generate automatic prompt points
+        input_point = np.array([[width//2, height//2]])
+        input_label = np.array([1])
+
+        masks, scores, logits = predictor.predict(
+            point_coords=input_point,
+            point_labels=input_label,
+            multimask_output=True
+        )

         annotated_frame = frame.copy()
-        for mask in masks[0]:
-            annotated_frame[mask.mask] = annotated_frame[mask.mask] * 0.5 + np.array([0, 255, 0]) * 0.5
+        for mask in masks:
+            annotated_frame[mask] = annotated_frame[mask] * 0.5 + np.array([0, 255, 0]) * 0.5

         out.write(annotated_frame)

+        frame_count += 1
+        progress(frame_count/total_frames, desc="Processing video...")
+
     cap.release()
     out.release()

     return output_path

-iface = gr.Interface(
-    fn=process_video_sam,
-    inputs=gr.Video(label="Upload Video"),
-    outputs=gr.Video(label="Segmented Video"),
-    title="Video Segmentation with SAM",
-    description="Upload a video to segment objects using Segment Anything Model"
-)
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# Video Segmentation with SAM")
+    gr.Markdown("Upload a video to segment objects using Segment Anything Model")
+
+    with gr.Row():
+        with gr.Column():
+            input_video = gr.Video(label="Input Video")
+            process_btn = gr.Button("Process Video", variant="primary")
+
+        with gr.Column():
+            output_video = gr.Video(label="Segmented Video")
+
+    process_btn.click(
+        fn=process_video_sam,
+        inputs=input_video,
+        outputs=output_video,
+        api_name="segment_video"
+    )

 if __name__ == "__main__":
-    iface.launch()
+    demo.launch()
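
For reference, the per-frame point-prompt pattern this commit introduces can be exercised on a single image roughly as below. This is a minimal sketch, not the committed app.py: the checkpoint path and test-image filename are assumptions, the frame is converted from BGR to RGB before set_image (SamPredictor expects RGB, while the commit passes the OpenCV frame as-is), and only the highest-scoring of the three returned masks is overlaid, whereas the commit blends all of them.

# Minimal, hypothetical sketch of the point-prompt segmentation used in
# process_video_sam, applied to one frame. CHECKPOINT and "frame.jpg" are
# assumptions, not values from the commit.
import cv2
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor

CHECKPOINT = "sam_vit_h_4b8939.pth"  # assumed local path to the vit_h weights

def overlay_center_mask(frame_bgr, predictor):
    h, w = frame_bgr.shape[:2]
    # SamPredictor expects an RGB image; OpenCV frames are BGR.
    predictor.set_image(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[w // 2, h // 2]]),  # one prompt point at the frame center
        point_labels=np.array([1]),                 # 1 marks it as a foreground point
        multimask_output=True,                      # SAM returns 3 candidate masks
    )
    best = masks[np.argmax(scores)]                 # keep only the highest-scoring mask
    out = frame_bgr.copy()
    out[best] = (out[best] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)  # green blend
    return out

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_h"](checkpoint=CHECKPOINT).to(device)
    predictor = SamPredictor(sam)
    frame = cv2.imread("frame.jpg")  # any test image stands in for a decoded video frame
    cv2.imwrite("frame_segmented.jpg", overlay_center_mask(frame, predictor))

Requesting multimask_output=True and keeping the argmax-score mask is a common way to get SAM's best single segment from one ambiguous point prompt; the committed loop instead overlays all three candidates.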