DHEIVER committed on
Commit
e985675
·
verified ·
1 Parent(s): 6c4853a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import requests
7
+ from segment_anything import sam_model_registry, SamPredictor
8
+
9
def download_sam_model():
    """Ensure the SAM ViT-H checkpoint exists locally and return its path.

    If the checkpoint file is already present in the current working
    directory it is reused. Otherwise it is streamed from the public
    download URL into a temporary ``.part`` file and atomically renamed
    once the transfer has finished, so an interrupted download can never
    be mistaken for a complete checkpoint on the next run.

    Returns:
        The relative path of the checkpoint file.

    Raises:
        requests.RequestException: If the connection fails, times out, or
            the server answers with an HTTP error status.
        OSError: If the checkpoint cannot be written to disk.
    """
    model_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    checkpoint_path = "sam_vit_h_4b8939.pth"

    if os.path.exists(checkpoint_path):
        return checkpoint_path

    partial_path = checkpoint_path + ".part"
    print("Downloading SAM model...")
    try:
        # (connect timeout, read timeout) so a stalled server cannot hang the app.
        with requests.get(model_url, stream=True, timeout=(10, 60)) as response:
            # Fail loudly instead of saving an HTML error page as the checkpoint.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                # The checkpoint is multiple gigabytes; use large chunks.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
        # Publish the file only after the whole body was written.
        os.replace(partial_path, checkpoint_path)
    finally:
        # On success the partial file has been renamed away; on failure
        # remove the truncated leftover.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print("Download complete!")
    return checkpoint_path
22
+
23
def process_video_sam(video_path):
    """Segment every frame of a video with SAM and write an annotated copy.

    Each frame is passed through SAM's automatic mask generator and every
    returned mask is blended 50/50 with green onto the frame. The annotated
    frames are written to ``output_video.mp4`` using the source video's
    frame size and frame rate.

    Args:
        video_path: Filesystem path of the uploaded video, as supplied by
            the Gradio ``Video`` input.

    Returns:
        The path of the annotated output video (``"output_video.mp4"``).

    Raises:
        gr.Error: If no video was supplied, the video cannot be opened, or
            the output file cannot be created.
    """
    # Function-scope import: the file-level import only brings in
    # sam_model_registry and SamPredictor, and SamPredictor offers no
    # automatic ("segment everything") mask generation.
    from segment_anything import SamAutomaticMaskGenerator

    if not video_path:
        raise gr.Error("Please upload a video first.")

    # Download model if needed
    checkpoint_path = download_sam_model()

    # Initialize SAM
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    MODEL_TYPE = "vit_h"

    sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint_path)
    sam.to(device=DEVICE)
    mask_generator = SamAutomaticMaskGenerator(sam)

    # Process video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise gr.Error("Could not open the uploaded video.")

    output_path = "output_video.mp4"
    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 or NaN; `not (fps > 0)` catches both.
        if not (fps > 0):
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        out = cv2.VideoWriter(output_path,
                              cv2.VideoWriter_fourcc(*'mp4v'),
                              fps,
                              (width, height))
        if not out.isOpened():
            raise gr.Error("Could not create the output video file.")

        # Pure green is (0, 255, 0) in both BGR and RGB channel order.
        overlay_color = np.array([0, 255, 0], dtype=np.float32)

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes frames as BGR; SAM expects RGB input.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            masks = mask_generator.generate(rgb_frame)

            annotated_frame = frame.copy()
            for mask in masks:
                # Each record is a dict; "segmentation" is an HxW boolean mask.
                segmentation = mask["segmentation"]
                blended = annotated_frame[segmentation] * 0.5 + overlay_color * 0.5
                annotated_frame[segmentation] = blended.astype(np.uint8)

            out.write(annotated_frame)
    finally:
        # Release the capture and writer even if segmentation fails mid-video.
        cap.release()
        if out is not None:
            out.release()

    return output_path
65
+
66
# Gradio front end: one video upload in, one segmented video out.
# The first three positional arguments of gr.Interface are fn, inputs, outputs.
iface = gr.Interface(
    process_video_sam,
    gr.Video(label="Upload Video"),
    gr.Video(label="Segmented Video"),
    title="Video Segmentation with SAM",
    description="Upload a video to segment objects using Segment Anything Model",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()