Spaces:
Sleeping
Sleeping
import tensorflow as tf | |
import os | |
import numpy as np | |
import cv2 | |
import logging | |
from datetime import datetime | |
import matplotlib.pyplot as plt | |
from frame_slicer import extract_video_frames # Make sure this module is in PYTHONPATH | |
# ----------------- Configuration ----------------- #
# Resolve resources relative to this file's directory. abspath() makes the
# paths independent of the current working directory even when __file__ is
# relative (e.g. `python script.py` on older interpreters).
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(_BASE_DIR, "training_output", "final_model_2.h5")
# Frame count and frame size requested from extract_video_frames for each clip.
N_FRAMES = 30
IMG_SIZE = (96, 96)
# Directory for debug frames and visualization PNGs.
RESULT_PATH = os.path.join(_BASE_DIR, "results")
# Model scores at or above this value are labelled "FIGHT".
FIGHT_THRESHOLD = 0.61
# Created at import time so later writes into RESULT_PATH need no extra checks.
os.makedirs(RESULT_PATH, exist_ok=True)
# ----------------- Logging Setup ----------------- #
# NOTE(review): basicConfig configures the root logger on import; it does
# nothing if the host application has already configured logging.
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)
# ----------------- Main Detector Class ----------------- #
class FightDetector:
    """Classify short video clips as ``"FIGHT"`` or ``"NORMAL"``.

    The Keras model at ``MODEL_PATH`` is loaded once per instance. Each clip
    is reduced to ``N_FRAMES`` frames of size ``IMG_SIZE`` by
    ``extract_video_frames`` and scored as a batch of one; a score at or
    above ``FIGHT_THRESHOLD`` is reported as ``"FIGHT"``.
    """

    def __init__(self):
        # None when loading failed; predict() reports that instead of crashing.
        self.model = self._load_model()

    def _load_model(self):
        """Load the Keras model from ``MODEL_PATH`` for inference.

        Returns:
            The loaded model, or ``None`` if ``load_model`` raised. The error
            is logged with its traceback rather than propagated, so callers
            should check ``self.model``.
        """
        try:
            # compile=False: only predict() is ever called on this model, so
            # the saved training configuration is not needed.
            model = tf.keras.models.load_model(MODEL_PATH, compile=False)
        except Exception as e:
            logger.error(f"Model loading failed: {e}", exc_info=True)
            return None
        # getattr: a loaded model must not be discarded just because its
        # input shape cannot be reported.
        input_shape = getattr(model, "input_shape", "unknown")
        logger.info(f"Model loaded successfully. Input shape: {input_shape}")
        return model

    def _extract_frames(self, video_path, debug=True):
        """Extract ``N_FRAMES`` frames from *video_path* and sanity-check them.

        Args:
            video_path: Path to the video file.
            debug: When true, also save the first frame to
                ``RESULT_PATH/debug_frame.jpg``.

        Returns:
            The frame array from ``extract_video_frames`` (which may hold
            fewer than ``N_FRAMES`` frames; the caller checks the count), or
            ``None`` if extraction returned ``None``.
        """
        frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
        if frames is None or len(frames) == 0:
            # Nothing to inspect; let predict() report the problem.
            return frames
        # Reducing over axes (1, 2, 3) assumes frames is
        # (n_frames, height, width, channels) - TODO confirm with frame_slicer.
        blank_frames = int(np.all(frames == 0, axis=(1, 2, 3)).sum())
        if blank_frames > 0:
            logger.warning(f"{blank_frames} blank frames detected.")
        if debug:
            self._save_debug_frame(frames[0])
        return frames

    def _save_debug_frame(self, frame):
        """Best-effort write of *frame* to ``RESULT_PATH/debug_frame.jpg``.

        Failures are logged as warnings and never raised.
        """
        # The *255 scaling assumes pixel values in [0, 1] (presumably
        # normalized by frame_slicer - verify). Clipping keeps out-of-range
        # values from wrapping around when cast to uint8.
        sample_frame = np.clip(frame * 255, 0, 255).astype(np.uint8)
        debug_frame_path = os.path.join(RESULT_PATH, 'debug_frame.jpg')
        try:
            # Frames are treated as RGB; cv2.imwrite expects BGR.
            ok = cv2.imwrite(debug_frame_path, cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR))
        except cv2.error as e:
            logger.warning(f"Could not save debug frame: {e}")
            return
        if ok:
            logger.info(f"Debug frame saved to: {debug_frame_path}")
        else:
            logger.warning(f"Could not save debug frame to: {debug_frame_path}")

    def predict(self, video_path, debug=True):
        """Classify the clip at *video_path*.

        Args:
            video_path: Path to the video file.
            debug: When true (the default), save the debug frame and the
                frame-grid visualization under ``RESULT_PATH``.

        Returns:
            ``(message, score)``. On success, *message* is
            ``"FIGHT (xx.x% confidence)"`` or ``"NORMAL (xx.x% confidence)"``
            and *score* is the raw model output as a float. On failure,
            *message* starts with ``"Error:"`` or ``"Prediction error:"`` and
            *score* is ``None``.
        """
        if not os.path.exists(video_path):
            return "Error: Video not found", None
        if self.model is None:
            return "Error: Model loading failed", None
        try:
            frames = self._extract_frames(video_path, debug=debug)
            if frames is None:
                return "Error: Frame extraction failed", None
            if frames.shape[0] != N_FRAMES:
                return f"Error: Expected {N_FRAMES} frames, got {frames.shape[0]}", None
            if np.all(frames == 0):
                return "Error: All frames are blank", None
            # Add a leading batch axis (one clip). The first output of the
            # first sample is taken as the fight score.
            score = float(self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0])
        except Exception as e:
            logger.error(f"Prediction failed: {e}", exc_info=True)
            return f"Prediction error: {str(e)}", None

        result = "FIGHT" if score >= FIGHT_THRESHOLD else "NORMAL"
        # Heuristic display value, not a calibrated probability: 50% at the
        # threshold, plus 1.5 points per 0.01 of distance from it, capped at 100.
        confidence = min(abs(score - FIGHT_THRESHOLD) * 150 + 50, 100)
        logger.info(f"Prediction Score: {score:.4f}")
        logger.info(f"Decision: {result}")
        if debug:
            try:
                self._debug_visualization(frames, score, result, video_path)
            except Exception as e:
                # The prediction already succeeded; a failed debug plot must
                # not turn it into an error.
                logger.warning(f"Debug visualization failed: {e}", exc_info=True)
        return f"{result} ({confidence:.1f}% confidence)", score

    def _debug_visualization(self, frames, score, result, video_path):
        """Save a grid of the first (up to) 10 frames, titled with the decision.

        The PNG is written to
        ``RESULT_PATH/<video basename>_result_<YYYYmmdd_HHMMSS>.png``.
        """
        fig = plt.figure(figsize=(15, 5))
        try:
            for i, frame in enumerate(frames[:10]):
                ax = fig.add_subplot(2, 5, i + 1)
                ax.imshow(frame)
                ax.set_title(f"Frame {i}\nMean: {frame.mean():.2f}")
                ax.axis('off')
            fig.suptitle(f"Prediction: {result} (Score: {score:.4f})")
            fig.tight_layout()
            base_name = os.path.splitext(os.path.basename(video_path))[0]
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            save_path = os.path.join(RESULT_PATH, f"{base_name}_result_{timestamp}.png")
            fig.savefig(save_path)
        finally:
            # Release the figure even if drawing or saving fails, so repeated
            # calls do not accumulate open figures.
            plt.close(fig)
        logger.info(f"Visualization saved to: {save_path}")


# ----------------- External Interface ----------------- #
def fight_detec(video_path: str, debug: bool = True):
    """Classify the video at *video_path* as fight or normal.

    Args:
        video_path: Path to a local video file.
        debug: When true (the default), save the debug frame and the
            frame-grid visualization under ``RESULT_PATH``; when false, skip
            both.

    Returns:
        ``(message, score)`` as produced by ``FightDetector.predict``, or
        ``("Error: Model loading failed", None)`` if the model could not be
        loaded.
    """
    # NOTE(review): a new detector is built on every call, so the model is
    # reloaded from disk each time; consider caching it for repeated use.
    detector = FightDetector()
    if detector.model is None:
        return "Error: Model loading failed", None
    return detector.predict(video_path, debug=debug)
# ----------------- Optional CLI Entry ----------------- #
def normalize_input_path(raw: str) -> str:
    """Clean a path typed or pasted at the prompt.

    Strips surrounding whitespace and double quotes (as added by Windows
    "Copy as path"), then removes one pair of surrounding single quotes (as
    added by some shells on drag-and-drop).

    >>> normalize_input_path(' "C:/videos/a.mp4" ')
    'C:/videos/a.mp4'
    """
    cleaned = raw.strip().strip('"')
    if len(cleaned) >= 2 and cleaned[0] == cleaned[-1] == "'":
        cleaned = cleaned[1:-1]
    return cleaned


if __name__ == "__main__":
    path = normalize_input_path(
        input("Enter the local path to the video file to detect fight: ")
    )
    logger.info(f"Loading video: {path}")
    result, score = fight_detec(path)
    logger.info(f"Result: {result}")
    if score is not None:
        logger.info(f"Raw score: {score:.4f}")