Update app.py
app.py CHANGED
@@ -1,5 +1,6 @@
 import gradio as gr
-from transformers import VideoMAEForVideoClassification,
+from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
+import torch
 import cv2  # OpenCV for video processing
 
 # Model ID for video classification (UCF101 subset)
@@ -8,22 +9,25 @@ model_id = "MCG-NJU/videomae-base"
 def analyze_video(video):
     # Extract key frames from the video using OpenCV
     frames = extract_key_frames(video)
-
+
     # Load model and feature extractor manually
     model = VideoMAEForVideoClassification.from_pretrained(model_id)
-
+    processor = VideoMAEImageProcessor.from_pretrained(model_id)
+
+    # Prepare frames for the model
+    inputs = processor(images=frames, return_tensors="pt")
 
-    #
-
+    # Make predictions
+    with torch.no_grad():
+        outputs = model(**inputs)
 
-
+    logits = outputs.logits
+    predictions = torch.argmax(logits, dim=-1)
+
+    # Analyze predictions for insights related to the play
     results = []
-    for frame in frames:
-
-        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        predictions = classifier([frame_rgb])  # Assuming model outputs probabilities
-        # Analyze predictions for insights related to the play
-        result = analyze_predictions_ucf101(predictions)
+    for prediction in predictions:
+        result = analyze_predictions_ucf101(prediction.item())
         results.append(result)
 
     # Aggregate results across frames and provide a final analysis
@@ -40,24 +44,29 @@ def extract_key_frames(video):
     for i in range(frame_count):
         ret, frame = cap.read()
         if ret and i % (fps // 2) == 0:  # Extract a frame every half second
-            frames.append(frame)
+            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # Convert to RGB
 
     cap.release()
     return frames
 
-def analyze_predictions_ucf101(predictions):
-    #
-
-
+def analyze_predictions_ucf101(prediction):
+    # Map prediction to action labels (this mapping is hypothetical)
+    action_labels = {
+        0: "running",
+        1: "sliding",
+        2: "jumping",
+        # Add more labels as necessary
+    }
+    action = action_labels.get(prediction, "unknown")
 
     relevant_actions = ["running", "sliding", "jumping"]
-
-
-
-
-
-
-
+    if action in relevant_actions:
+        if action == "sliding":
+            return "potentially safe"
+        elif action == "running":
+            return "potentially out"
+        else:
+            return "inconclusive"
    else:
        return "inconclusive"
 
@@ -80,6 +89,7 @@ interface = gr.Interface(
     outputs="text",
     title="Baseball Play Analysis (UCF101 Subset Exploration)",
     description="Upload a video of a baseball play (safe/out at a base). This app explores using a video classification model (UCF101 subset) for analysis. Note: The model might not be specifically trained for baseball plays.",
+    share=True
 )
 
 interface.launch()
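A note on the new inference block: VideoMAEImageProcessor treats the whole frame list as a single clip, so outputs.logits has shape (1, num_labels) and the for prediction in predictions loop runs once per video rather than once per frame. VideoMAE checkpoints also expect a fixed clip length (16 frames for the base models), so feeding however many frames the half-second sampling happens to produce can fail with a position-embedding shape mismatch. A minimal sketch of sampling a fixed-length clip instead; sample_clip is a hypothetical helper, not part of this commit:

    import cv2
    import numpy as np

    def sample_clip(video_path, num_frames=16):
        # Read every frame, then keep num_frames evenly spaced RGB frames.
        cap = cv2.VideoCapture(video_path)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        cap.release()
        # VideoMAE's position embeddings assume a fixed clip length,
        # so pick evenly spaced indices across the whole video.
        indices = np.linspace(0, len(frames) - 1, num_frames).astype(int)
        return [frames[i] for i in indices]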
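The action_labels dict is, as its own comment admits, hypothetical: MCG-NJU/videomae-base is a self-supervised pretraining checkpoint, so its class indices do not correspond to meaningful action names. A checkpoint fine-tuned for classification, such as MCG-NJU/videomae-base-finetuned-kinetics, ships its real label names in model.config.id2label. A sketch under that assumption; the substring matching and the safe/out heuristic are illustrative carry-overs from the commit, not checkpoint-specific logic:

    from transformers import VideoMAEForVideoClassification

    # Assumes a classification fine-tuned checkpoint; the base (pretraining)
    # checkpoint's id2label entries are generic placeholders.
    model = VideoMAEForVideoClassification.from_pretrained(
        "MCG-NJU/videomae-base-finetuned-kinetics"
    )

    def analyze_prediction(prediction_id):
        # id2label maps class index -> label string for this checkpoint.
        action = model.config.id2label.get(prediction_id, "unknown")
        if "slid" in action:    # sliding-style actions
            return "potentially safe"
        if "run" in action:     # running-style actions
            return "potentially out"
        return "inconclusive"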
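One caveat on the last hunk: in Gradio, share is a parameter of launch(), not of the gr.Interface(...) constructor, so share=True inside the constructor is at best ignored. The conventional form is:

    # `share` belongs to launch(); it creates a temporary public URL.
    interface.launch(share=True)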