# fight-detection-live-demo / Prediction.py
# Source metadata (from the hosting page, kept as a comment so the module parses):
#   author: sdafd; last commit e1427cb "Update Prediction.py"; 672 bytes.
import numpy as np
from tensorflow.keras.models import load_model
from FeatureExtraction import FeatureExtractor
# Load the classifier once at import time; predict_fight() uses this shared
# instance unless a different model is supplied.
_MODEL_PATH = 'orignal_model_b32.h5'
model = load_model(_MODEL_PATH)

# Compile with Adam / binary cross-entropy / accuracy.
# NOTE(review): Keras predict() does not need a compiled model; this call only
# matters if the model is later trained or evaluated from this module.
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
def predict_fight(frames_buffer, threshold, feature_extractor, classifier=None):
    """Score a buffer of frames and decide whether it shows a fight.

    Features are obtained from ``feature_extractor.extract_feature``, their two
    axes are swapped, a leading batch axis of size 1 is added, and the result
    is scored by the classifier.

    Args:
        frames_buffer: Frames passed unchanged to
            ``feature_extractor.extract_feature``.
        threshold: Decision threshold. A fight is reported only when the score
            is strictly greater than this value.
        feature_extractor: Object exposing ``extract_feature(frames)`` that
            returns a 2-D array-like. NOTE(review): presumably shaped
            (feature_dim, num_frames) so the transpose yields
            (num_frames, feature_dim) -- confirm against FeatureExtraction.
        classifier: Optional Keras-style model exposing ``predict``. When
            omitted, the module-level ``model`` loaded at import time is used.

    Returns:
        tuple[bool, float]: ``(fight_detected, fight_prob)``, where
        ``fight_prob`` is the classifier's first output for the single sample.
        It is presumably a sigmoid probability (confirm against the model).
    """
    # Only touch the module-level model when no classifier was injected.
    net = model if classifier is None else classifier

    features = feature_extractor.extract_feature(frames_buffer)
    # Swap the two axes, then add a batch dimension of size 1.
    features = np.expand_dims(np.transpose(features, (1, 0)), axis=0)

    # verbose=0: Keras' default verbosity prints a progress bar on every
    # predict() call, which floods stdout when this is called repeatedly
    # (e.g. once per frame window in a live demo).
    prediction = net.predict(features, verbose=0)

    fight_prob = prediction[0][0]
    # Compare on the raw model output so the decision matches the score exactly.
    fight_detected = fight_prob > threshold
    # Return plain Python scalars rather than NumPy scalars (np.bool_/np.float32)
    # so callers can use `is True`, JSON-serialize, etc. without surprises.
    return bool(fight_detected), float(fight_prob)