import gradio as gr
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
import torch
import cv2 # OpenCV for video processing

# Model ID: the self-supervised VideoMAE backbone. Note that this checkpoint
# ships no fine-tuned classification head, so VideoMAEForVideoClassification
# will attach a randomly initialised head to it.
model_id = "MCG-NJU/videomae-base"
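
# A hypothetical drop-in alternative: a checkpoint fine-tuned for action
# recognition, e.g. "MCG-NJU/videomae-base-finetuned-kinetics", would return
# meaningful action labels via model.config.id2label.
# model_id = "MCG-NJU/videomae-base-finetuned-kinetics"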

def analyze_video(video):
    # Extract key frames from the video using OpenCV
    frames = extract_key_frames(video)

    # Load model and image processor manually
    model = VideoMAEForVideoClassification.from_pretrained(model_id)
    processor = VideoMAEImageProcessor.from_pretrained(model_id)

    # VideoMAE expects a fixed-length clip (model.config.num_frames, 16 for
    # this checkpoint), so sample the extracted frames evenly down to that size
    num_frames = model.config.num_frames
    if len(frames) < num_frames:
        return "Inconclusive (video too short to sample enough frames)"
    indices = [round(i * (len(frames) - 1) / (num_frames - 1)) for i in range(num_frames)]
    frames = [frames[i] for i in indices]

    # Prepare the clip for the model
    inputs = processor(images=frames, return_tensors="pt")

    # Make predictions
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)

    # Map each predicted class to an insight about the play
    results = []
    for prediction in predictions:
        result = analyze_predictions_ucf101(prediction.item())
        results.append(result)

    # Aggregate the insights and provide a final call
    final_result = aggregate_results(results)
    return final_result
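
# Note: loading the model inside analyze_video re-reads the weights on every
# request. A common alternative (a sketch; not wired in above, to keep the
# original flow) is to load once at module level:
#
#     model = VideoMAEForVideoClassification.from_pretrained(model_id)
#     processor = VideoMAEImageProcessor.from_pretrained(model_id)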

def extract_key_frames(video):
    cap = cv2.VideoCapture(video)
    frames = []
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    step = max(1, fps // 2)  # one frame every half second; guards against fps < 2
    for i in range(frame_count):
        ret, frame = cap.read()
        if not ret:
            break
        if i % step == 0:
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # BGR -> RGB for the processor
    cap.release()
    return frames
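
# Caveat: cv2.CAP_PROP_FRAME_COUNT is only an estimate for some containers;
# the `if not ret: break` guard above handles streams that end early.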

def analyze_predictions_ucf101(prediction):
    # Map prediction indices to action labels (this mapping is hypothetical)
    action_labels = {
        0: "running",
        1: "sliding",
        2: "jumping",
        # Add more labels as necessary
    }
    action = action_labels.get(prediction, "unknown")
    # Only sliding/running carry a safe/out signal; everything else is inconclusive
    if action == "sliding":
        return "potentially safe"
    elif action == "running":
        return "potentially out"
    else:
        return "inconclusive"

def aggregate_results(results):
    # Combine the per-prediction insights with a simple majority vote
    safe_count = results.count("potentially safe")
    out_count = results.count("potentially out")
    if safe_count > out_count:
        return "Safe"
    elif out_count > safe_count:
        return "Out"
    else:
        return "Inconclusive"

# Gradio interface
interface = gr.Interface(
    fn=analyze_video,
    inputs="video",
    outputs="text",
    title="Baseball Play Analysis (UCF101 Subset Exploration)",
    description=(
        "Upload a video of a baseball play (safe/out at a base). This app "
        "explores using a video classification model (UCF101 subset) for "
        "analysis. Note: the model might not be specifically trained for "
        "baseball plays."
    ),
)

interface.launch(share=True)
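
# Note: share=True serves the app through a temporary public *.gradio.live
# link; drop it (or pass share=False) to serve only on localhost.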