Update app.py
app.py CHANGED
@@ -1,5 +1,6 @@
 import subprocess
 import re
+from typing import List, Tuple, Optional
 
 # Define the command to be executed
 command = ["python", "setup.py", "build_ext", "--inplace"]
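The command list above is presumably executed with subprocess.run at app startup; the call site sits outside this diff's hunks, so the invocation and error handling below are assumptions, not lines from the commit:

import subprocess

# Build the SAM2 extension modules in place before the Gradio app starts.
command = ["python", "setup.py", "build_ext", "--inplace"]
result = subprocess.run(command, capture_output=True, text=True)
print("BUILD STDOUT:", result.stdout)
if result.returncode != 0:
    print("BUILD STDERR:", result.stderr)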
@@ -43,7 +44,7 @@ def get_video_fps(video_path):
 
     return fps
 
-def
+def clear_points(image):
     # we clean all
     return [
         image, # first_frame_path
@@ -59,10 +60,10 @@ def preprocess_video_in(video_path):
     unique_id = datetime.now().strftime('%Y%m%d%H%M%S')
 
     # Set directory with this ID to store video frames
-
+    extracted_frames_output_dir = f'frames_{unique_id}'
 
     # Create the output directory
-    os.makedirs(
+    os.makedirs(extracted_frames_output_dir, exist_ok=True)
 
     ### Process video frames ###
     # Open the video file
@@ -87,7 +88,7 @@ def preprocess_video_in(video_path):
             break
 
         # Format the frame filename as '00000.jpg'
-        frame_filename = os.path.join(
+        frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')
 
         # Save the frame as a JPEG file
         cv2.imwrite(frame_filename, frame)
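For context, the hunk above patches two lines of a frame-extraction loop whose remaining lines the diff does not show. A sketch of how the pieces plausibly fit together; cap, frame_number, and the loop structure are assumptions inferred from the visible lines:

import os
import cv2

def extract_frames(video_path, extracted_frames_output_dir):
    # Dump every frame of the video as a zero-padded JPEG ('00000.jpg', ...),
    # the naming scheme the SAM2 frame loader expects.
    cap = cv2.VideoCapture(video_path)
    frame_number = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')
        cv2.imwrite(frame_filename, frame)
        frame_number += 1
    cap.release()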
@@ -103,12 +104,11 @@ def preprocess_video_in(video_path):
 
     # scan all the JPEG frame names in this directory
     scanned_frames = [
-        p for p in os.listdir(
+        p for p in os.listdir(extracted_frames_output_dir)
         if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
     ]
     scanned_frames.sort(key=lambda p: int(os.path.splitext(p)[0]))
     print(f"SCANNED_FRAMES: {scanned_frames}")
-
 
     return [
         first_frame, # first_frame_path
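The numeric sort key in the hunk above is what keeps the frame list in true frame order: with zero-padded names a plain string sort would happen to work, but int(os.path.splitext(p)[0]) also stays correct for unpadded names. A quick illustration:

import os

frames = ["10.jpg", "2.jpg", "00001.jpg"]
frames.sort(key=lambda p: int(os.path.splitext(p)[0]))
print(frames)  # ['00001.jpg', '2.jpg', '10.jpg'] (numeric order, not lexicographic)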
@@ -116,7 +116,7 @@ def preprocess_video_in(video_path):
         gr.State([]), # trackings_input_label
         first_frame, # input_first_frame_image
         first_frame, # points_map
-
+        extracted_frames_output_dir, # video_frames_dir
         scanned_frames, # scanned_frames
         None, # stored_inference_state
         None, # stored_frame_names
@@ -195,46 +195,61 @@ def load_model(checkpoint):
     if checkpoint == "tiny":
         sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
         model_cfg = "sam2_hiera_t.yaml"
-        return sam2_checkpoint, model_cfg
+        return [sam2_checkpoint, model_cfg]
     elif checkpoint == "samll":
         sam2_checkpoint = "./checkpoints/sam2_hiera_small.pt"
         model_cfg = "sam2_hiera_s.yaml"
-        return sam2_checkpoint, model_cfg
+        return [sam2_checkpoint, model_cfg]
     elif checkpoint == "base-plus":
         sam2_checkpoint = "./checkpoints/sam2_hiera_base_plus.pt"
         model_cfg = "sam2_hiera_b+.yaml"
-        return sam2_checkpoint, model_cfg
+        return [sam2_checkpoint, model_cfg]
     elif checkpoint == "large":
         sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
         model_cfg = "sam2_hiera_l.yaml"
-        return sam2_checkpoint, model_cfg
+        return [sam2_checkpoint, model_cfg]
 
 
 
-def
-
-
-
+def get_mask_sam_process(
+    input_first_frame_image,
+    checkpoint,
+    tracking_points,
+    trackings_input_label,
+    video_frames_dir, # extracted_frames_output_dir defined in 'preprocess_video_in' function
+    scanned_frames,
+    working_frame: str = None, # current frame being added points
+    progress=gr.Progress(track_tqdm=True)
+):
+
+    # get model and model config paths
+    print(f"USER CHOSEN CHECKPOINT: {checkpoint}")
     sam2_checkpoint, model_cfg = load_model(checkpoint)
+    print("MODEL LOADED")
+
+    # set predictor
     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+    print("PREDICTOR READY")
 
-
     # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
     print(f"STATE FRAME OUTPUT DIRECTORY: {video_frames_dir}")
     video_dir = video_frames_dir
 
     # scan all the JPEG frame names in this directory
     frame_names = scanned_frames
-
+
+    # Init SAM2 inference_state
     inference_state = predictor.init_state(video_path=video_dir)
+    print("NEW INFERENCE_STATE INITIATED")
 
     # segment and track one object
     # predictor.reset_state(inference_state) # if any previous tracking, reset
 
+    ### HANDLING WORKING FRAME
     new_working_frame = None
     # Add new point
-    if working_frame
-        ann_frame_idx = 0 # the frame index we interact with
+    if working_frame is None:
+        ann_frame_idx = 0 # the frame index we interact with, 0 if it is the first frame
         new_working_frame = "frames_output_images/frame_0.jpg"
     else:
         # Use a regular expression to find the integer
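Two notes on the hunk above. First, "samll" looks like a misspelling of "small", but it is left as-is here because it has to match the checkpoint value used by the UI elsewhere in app.py, which this diff does not show. Second, the four if/elif branches of load_model could collapse into a lookup table; a sketch of that equivalent, not the committed implementation:

SAM2_MODELS = {
    "tiny": ("./checkpoints/sam2_hiera_tiny.pt", "sam2_hiera_t.yaml"),
    "samll": ("./checkpoints/sam2_hiera_small.pt", "sam2_hiera_s.yaml"),  # key spelled as in app.py
    "base-plus": ("./checkpoints/sam2_hiera_base_plus.pt", "sam2_hiera_b+.yaml"),
    "large": ("./checkpoints/sam2_hiera_large.pt", "sam2_hiera_l.yaml"),
}

def load_model(checkpoint):
    # Returns the same [checkpoint_path, config_name] list as the committed version.
    sam2_checkpoint, model_cfg = SAM2_MODELS[checkpoint]
    return [sam2_checkpoint, model_cfg]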
@@ -244,6 +259,7 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
         frame_number = int(match.group(1))
         ann_frame_idx = frame_number
         new_working_frame = f"frames_output_images/frame_{ann_frame_idx}.jpg"
+    print(f"NEW_WORKING_FRAME PATH: {new_working_frame}")
 
     ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
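The regular expression this hunk relies on sits outside the changed lines, but the visible match.group(1) usage implies something like the following; the exact pattern is an assumption:

import re

# Pull the frame index out of a working-frame path such as
# 'frames_output_images/frame_12.jpg'.
working_frame = "frames_output_images/frame_12.jpg"
match = re.search(r"frame_(\d+)", working_frame)
if match:
    frame_number = int(match.group(1))
    ann_frame_idx = frame_number  # 12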
@@ -458,7 +474,7 @@ with gr.Blocks() as demo:
 
     # Clear every points clicked and added to the map
     clear_points_btn.click(
-        fn =
+        fn = clear_points,
         inputs = input_first_frame_image, # we get the untouched hidden image
         outputs = [
             first_frame_path,
@@ -480,9 +496,21 @@ with gr.Blocks() as demo:
     """
 
     submit_btn.click(
-        fn =
-        inputs = [
-
+        fn = get_mask_sam_process,
+        inputs = [
+            input_first_frame_image,
+            checkpoint,
+            tracking_points,
+            trackings_input_label,
+            video_frames_dir,
+            scanned_frames,
+            working_frame,
+        ],
+        outputs = [
+            output_result,
+            stored_frame_names,
+            stored_inference_state,
+        ]
     )
 
     propagate_btn.click(
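The .click() wiring added above follows Gradio's usual contract: each component in inputs is passed to fn positionally, and fn's return values fill the outputs components in order, which is why get_mask_sam_process takes seven inputs plus defaults and why preprocess_video_in returns a list. A self-contained toy version of the same pattern (names here are illustrative, not from app.py):

import gradio as gr

def shout(text):
    # One input component in, one output component out.
    return text.upper()

with gr.Blocks() as demo:
    box = gr.Textbox(label="in")
    out = gr.Textbox(label="out")
    submit_btn = gr.Button("Submit")
    submit_btn.click(fn=shout, inputs=box, outputs=out)

# demo.launch()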