# fight-detection-live-demo / Prediction.py
# Uploaded by Aquibjaved in commit f126c7c ("Upload 3 files"); original size 915 bytes.
import cv2
import numpy as np
from tensorflow.keras.models import load_model
import os
from FeatureExtraction import FeatureExtractor
# Load the trained fight-detection classifier once at import time.
# The filename is first tried relative to the current working directory
# (the original behaviour); if it is not found there, it is looked up next to
# this script, so the demo also works when launched from another directory.
# Keep the spelling "orignal" -- it must match the file name on disk.
_MODEL_FILENAME = 'orignal_model_b32.h5'
_model_path = _MODEL_FILENAME
if not os.path.exists(_model_path):
    _model_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), _MODEL_FILENAME
    )
model = load_model(_model_path)
# NOTE(review): compile() configures training (optimizer/loss/metrics); it
# should not be required for predict(). Kept to preserve existing behaviour --
# confirm nothing else relies on it before removing.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Initialize the feature extractor for 224x224 frames with 3 colour channels.
feature_extractor = FeatureExtractor(img_shape=(224, 224), channels=3)
def predict_fight(frames_buffer, threshold=0.8):
    """Classify a buffer of video frames as fight / no-fight.

    Args:
        frames_buffer: Frames handed directly to
            ``feature_extractor.extract_feature``; the expected count, size
            and colour order are defined by ``FeatureExtractor`` (presumably
            40 frames, per the shape notes below -- confirm).
        threshold: Score above which a fight is reported. Defaults to 0.8,
            the value that was previously hard-coded.

    Returns:
        A NumPy boolean array with the same shape as the model output
        (presumably ``(1, 1)`` for a single sequence -- confirm), True where
        the model score exceeds ``threshold``. A single-element array works
        in an ``if`` test, but it is not a Python ``bool``.
    """
    # Extract per-frame features.
    features = feature_extractor.extract_feature(frames_buffer)
    # Swap the two axes; per the original notes the extractor yields
    # (2048, 40) and the model expects (40, 2048) -- i.e. (timesteps, features).
    features = np.transpose(features, (1, 0))
    # Add a leading batch axis: (40, 2048) -> (1, 40, 2048).
    batch = np.expand_dims(features, axis=0)
    # verbose=0 suppresses the Keras progress bar that would otherwise be
    # printed on every call of this per-window live prediction.
    scores = model.predict(batch, verbose=0)
    return scores > threshold