Update app.py
app.py CHANGED
@@ -1,24 +1,77 @@
 import gradio as gr
 from transformers import pipeline
+import cv2  # OpenCV for video processing

-#
-model_id = "sayakpaul/
+# Model ID for video classification (UCF101 subset)
+model_id = "sayakpaul/videomae-base-finetuned-ucf101-subset"

-def
-
-
-
-
-
-
+def analyze_video(video):
+    # Extract key frames from the video using OpenCV
+    frames = extract_key_frames(video)
+
+    # Analyze key frames using the video classification model
+    results = []
+    classifier = pipeline("video-classification", model=model_id)
+    for frame in frames:
+        predictions = classifier(frame)  # NOTE: the pipeline normally takes a whole clip, not single frames (see note below)
+        # Analyze predictions for insights related to the play
+        result = analyze_predictions_ucf101(predictions)
+        results.append(result)
+
+    # Aggregate results across frames and provide a final analysis
+    final_result = aggregate_results(results)
+
+    return final_result
+
+def extract_key_frames(video):
+    cap = cv2.VideoCapture(video)
+    frames = []
+    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = int(cap.get(cv2.CAP_PROP_FPS))
+
+    for i in range(frame_count):
+        ret, frame = cap.read()
+        if ret and i % max(1, fps // 2) == 0:  # extract a frame every half second; max() guards against fps < 2
+            frames.append(frame)
+
+    cap.release()
+    return frames
+
+def analyze_predictions_ucf101(predictions):
+    # Analyze the model's predictions for insights relevant to baseball plays.
+    # Assumes predictions is a list of {"label": ..., "score": ...} dicts.
+    actions = [pred["label"] for pred in predictions]
+
+    relevant_actions = ["running", "sliding", "jumping"]  # placeholder names; real UCF101 labels are CamelCase classes like "BaseballPitch"
+    runner_actions = [action for action in actions if action in relevant_actions]
+
+    # Treat 'sliding' and 'running' as the key indicators for a safe/out call
+    if "sliding" in runner_actions:
+        return "potentially safe"
+    elif "running" in runner_actions:
+        return "potentially out"
+    else:
+        return "inconclusive"
+
+def aggregate_results(results):
+    # Combine the per-frame verdicts with a simple majority vote
+    safe_count = results.count("potentially safe")
+    out_count = results.count("potentially out")
+
+    if safe_count > out_count:
+        return "Safe"
+    elif out_count > safe_count:
+        return "Out"
+    else:
+        return "Inconclusive"

 # Gradio interface
 interface = gr.Interface(
-    fn=
-    inputs="
+    fn=analyze_video,
+    inputs="video",
     outputs="text",
-    title="
-    description="Upload
+    title="Baseball Play Analysis (UCF101 Subset Exploration)",
+    description="Upload a video of a baseball play (safe/out at a base). This app explores using a video classification model (UCF101 subset) for analysis. Note: the model might not be specifically trained for baseball plays.",
 )

 interface.launch()
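Note on the per-frame call above: the transformers "video-classification" pipeline is designed to be invoked on a whole clip (a file path or URL), sampling frames itself, rather than on individual OpenCV frames. A minimal sketch of that variant, reusing the checkpoint from this commit (the .mp4 path is a placeholder, and a video decoding backend such as decord is assumed to be installed):

    from transformers import pipeline

    # Same checkpoint as in the commit above
    classifier = pipeline(
        "video-classification",
        model="sayakpaul/videomae-base-finetuned-ucf101-subset",
    )

    # The pipeline decodes the file and samples frames on its own;
    # the result is a list of {"label": ..., "score": ...} dicts.
    # "play_at_second.mp4" is a placeholder path, not part of the commit.
    predictions = classifier("play_at_second.mp4")
    for pred in predictions:
        print(pred["label"], round(pred["score"], 3))

Inside analyze_video, this would amount to classifying the uploaded video once instead of looping over extracted frames; the loop is kept in the diff to preserve the commit's per-frame majority-vote aggregation.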