Spaces:
Sleeping
Sleeping
import tensorflow as tf | |
from frame_slicer import extract_video_frames | |
import cv2 | |
import os | |
import numpy as np | |
import matplotlib.pyplot as plt | |
# Configuration
import os

# All paths below are built relative to the directory containing this file.
_SCRIPT_DIR = os.path.dirname(__file__)

# Keras model file loaded by fight_detec().
MODEL_PATH = os.path.join(_SCRIPT_DIR, "final_model_2.h5")
# Number of frames requested from extract_video_frames and fed to the model.
N_FRAMES = 30
# Frame size passed to extract_video_frames.
IMG_SIZE = (96, 96)
# Directory where debug images are written when debug=True.
RESULT_PATH = os.path.join(_SCRIPT_DIR, "results")
def _fd_load_model(debug):
    """Load the Keras model at MODEL_PATH; return None (after printing why) on failure."""
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file not found at {MODEL_PATH}")
        return None
    try:
        # compile=False: optimizer state is not needed for inference.
        model = tf.keras.models.load_model(MODEL_PATH, compile=False)
        # Kept inside the try so a failure reading input_shape is reported as a
        # load failure, as in the original implementation.
        if debug:
            print("\nModel loaded successfully. Input shape:", model.input_shape)
        return model
    except Exception as e:
        print(f"Model loading failed: {e}")
        return None


def _fd_to_uint8(frame):
    """Convert one frame to uint8 for saving/display.

    Frames whose maximum is <= 1.0 are treated as normalized [0, 1] floats and
    scaled by 255; anything else is assumed to already be in [0, 255].
    """
    frame = np.asarray(frame)
    if np.max(frame) <= 1.0:
        return (frame * 255).astype(np.uint8)
    return frame.astype(np.uint8)


def _fd_extract_frames(video_path, debug):
    """Extract frames via extract_video_frames; return an ndarray or None.

    In debug mode, also report all-zero frames and save the first frame to
    RESULT_PATH/debug_frame.jpg.
    """
    frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
    if frames is None:
        print(f"Frame extraction returned None for {video_path}")
        return None
    frames = np.asarray(frames)
    if not debug:
        return frames

    # assumes frames is (num_frames, height, width, channels) — TODO confirm
    # against frame_slicer.extract_video_frames
    blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
    if blank_frames > 0:
        print(f"Warning: {blank_frames} blank frames detected")

    # Skip saving when there is nothing useful to look at.
    if frames.shape[0] > 0 and not np.all(frames[0] == 0):
        sample_frame = _fd_to_uint8(frames[0])
        try:
            os.makedirs(RESULT_PATH, exist_ok=True)
            debug_frame_path = os.path.join(RESULT_PATH, 'debug_frame.jpg')
            # cv2.imwrite signals failure by returning False rather than raising.
            if cv2.imwrite(debug_frame_path, cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR)):
                print(f"Debug frame saved to {debug_frame_path}")
            else:
                print(f"Failed to save debug frame: cv2.imwrite returned False for {debug_frame_path}")
        except Exception as e:
            print(f"Failed to save debug frame: {e}")
    else:
        print("Skipping debug frame save (first frame blank or no frames).")
    return frames


def _fd_fit_frame_count(frames, n_frames):
    """Return frames with exactly n_frames entries along axis 0.

    Extra frames are dropped. Missing frames are filled by repeating the last
    frame, or with float32 zeros of shape (n_frames, *IMG_SIZE, 3) when no
    frames were extracted at all.
    """
    count = frames.shape[0]
    if count == n_frames:
        return frames
    if count > n_frames:
        # Previously this path computed a negative padding size and crashed.
        print(f"Warning: Expected {n_frames} frames, got {count}. Truncating...")
        frames = frames[:n_frames]
        print(f"Frames truncated to shape: {frames.shape}")
        return frames
    print(f"Warning: Expected {n_frames} frames, got {count}. Padding...")
    if count == 0:
        frames = np.zeros((n_frames, *IMG_SIZE, 3), dtype=np.float32)
    else:
        padding = np.repeat(frames[-1:], n_frames - count, axis=0)
        frames = np.concatenate((frames, padding), axis=0)
    print(f"Frames padded to shape: {frames.shape}")
    return frames


def _fd_confidence(prediction, threshold):
    """Map the score's distance from threshold to a percentage in [50, 100].

    Linear: 50 + 150 * |prediction - threshold|, capped at 100.
    """
    return min(max(abs(prediction - threshold) * 150 + 50, 0), 100)


def _fd_debug_visualization(frames, score, result, video_path):
    """Print the decision and save a grid of up to 10 frames to RESULT_PATH.

    Plotting errors are printed, never raised; the figure is always closed.
    """
    print("\n--- Debug Visualization ---")
    print(f"Prediction Score: {score:.4f}")
    print(f"Decision: {result}")
    fig = None
    try:
        # Local import so a missing matplotlib only disables this plot.
        import matplotlib.pyplot as plt
        fig = plt.figure(figsize=(15, 5))
        for i in range(min(10, len(frames))):
            ax = fig.add_subplot(2, 5, i + 1)
            ax.imshow(_fd_to_uint8(frames[i]))
            # The mean is taken from the unconverted frame.
            ax.set_title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
            ax.axis('off')
        fig.suptitle(f"Video: {os.path.basename(video_path)}\nPrediction: {result} (Raw Score: {score:.4f})")
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        os.makedirs(RESULT_PATH, exist_ok=True)
        base_name = os.path.splitext(os.path.basename(video_path))[0]
        save_path = os.path.join(RESULT_PATH, f"{base_name}_prediction_result.png")
        fig.savefig(save_path)
        print(f"Debug visualization saved to: {save_path}")
    except ImportError:
        print("Matplotlib not found. Skipping debug visualization plot.")
    except Exception as e:
        print(f"Error during debug visualization: {e}")
    finally:
        # Release the figure even when savefig (or anything before it) fails.
        if fig is not None:
            plt.close(fig)
    print("--- End Debug Visualization ---")


def fight_detec(video_path: str, debug: bool = True, *, threshold: float = 0.61):
    """Detect a fight in a video and return the result string and raw score.

    Loads the Keras model at MODEL_PATH, samples N_FRAMES frames of IMG_SIZE
    from the video (truncating or padding to exactly N_FRAMES), and compares
    the model's score with ``threshold``.

    Args:
        video_path: Path to the video file.
        debug: When True, print diagnostics and write debug images
            (``debug_frame.jpg`` and ``<video>_prediction_result.png``)
            under RESULT_PATH.
        threshold: Scores >= this value are labelled "FIGHT", others
            "NORMAL". The default is the value the original code hard-coded;
            how it was chosen is not documented here.

    Returns:
        A ``(result_string, score)`` tuple. On success, result_string looks
        like ``"FIGHT (72.5% confidence)"`` and score is the raw model output
        as a float. On failure, result_string starts with ``"Error:"`` or
        ``"Prediction error:"`` and score is None.
    """
    # NOTE(review): the model is reloaded from disk on every call. If this runs
    # per request, consider caching it (and check that concurrent predict()
    # calls on a shared model are safe first).
    model = _fd_load_model(debug)
    if model is None:
        return "Error: Model loading failed", None

    if not os.path.exists(video_path):
        print(f"Error: Video not found at {video_path}")
        return "Error: Video not found", None

    try:
        frames = _fd_extract_frames(video_path, debug)
        if frames is None:
            return "Error: Frame extraction failed", None

        frames = _fd_fit_frame_count(frames, N_FRAMES)
        # Padding (e.g. zero-filling an empty extraction) can leave nothing to classify.
        if np.all(frames == 0):
            print("Error: All frames are blank after processing/padding.")
            return "Error: All frames are blank", None

        # Add a batch axis; [0][0] takes the first output value for the single sample.
        prediction = model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
        result = "FIGHT" if prediction >= threshold else "NORMAL"
        confidence = _fd_confidence(prediction, threshold)
        result_string = f"{result} ({confidence:.1f}% confidence)"

        if debug:
            print(f"Raw Prediction Score: {prediction:.4f}")
            _fd_debug_visualization(frames, prediction, result_string, video_path)
        return result_string, float(prediction)
    except Exception as e:
        # Top-level boundary: report the failure to the caller as a string.
        print(f"Prediction error: {str(e)}")
        if debug:
            import traceback
            traceback.print_exc()
        return f"Prediction error: {str(e)}", None
# # Example usage (commented out for library use)
# if __name__ == "__main__":
#     # Example of how to call the function
#     test_video = input("Enter the local path to the video file: ").strip('"')
#     if os.path.exists(test_video):
#         print(f"[INFO] Processing video: {test_video}")
#         result, score = fight_detec(test_video, debug=True)  # Enable debug for local testing
#         print(f"\nFinal Result: {result}")
#         if score is not None:
#             print(f"Raw Score: {score:.4f}")
#     else:
#         print(f"Error: File not found - {test_video}")