sdafd commited on
Commit
e1427cb
·
verified ·
1 Parent(s): ae0b2a7

Update Prediction.py

Browse files
Files changed (1) hide show
  1. Prediction.py +15 -24
Prediction.py CHANGED
@@ -1,24 +1,15 @@
1
- import cv2
2
- import numpy as np
3
- from tensorflow.keras.models import load_model
4
- import os
5
- from FeatureExtraction import FeatureExtractor
6
-
7
- model = load_model('orignal_model_b32.h5')
8
- model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
9
-
10
-
11
- # Initialize the feature extractor
12
- feature_extractor = FeatureExtractor(img_shape=(224, 224), channels=3)
13
-
14
- def predict_fight(frames_buffer):
15
- # Extract feature
16
- features_sequence = feature_extractor.extract_feature(frames_buffer)
17
-
18
- # Transpose the feature sequence to match the shape
19
- features_sequence = np.transpose(features_sequence, (1, 0)) # From (2048, 40) to (40, 2048)
20
- features_sequence = np.expand_dims(features_sequence, axis=0) # Add batch dimension (1, 40, 2048)
21
-
22
- # Predict
23
- prediction = model.predict(features_sequence)
24
- return prediction > 0.8 # Returning a boolean for fight detection
 
1
+ import numpy as np
2
+ from tensorflow.keras.models import load_model
3
+ from FeatureExtraction import FeatureExtractor
4
+
5
+ model = load_model('orignal_model_b32.h5')
6
+ model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
7
+
8
+ def predict_fight(frames_buffer, threshold, feature_extractor):
9
+ features_sequence = feature_extractor.extract_feature(frames_buffer)
10
+ features_sequence = np.transpose(features_sequence, (1, 0))
11
+ features_sequence = np.expand_dims(features_sequence, axis=0)
12
+ prediction = model.predict(features_sequence)
13
+ fight_prob = prediction[0][0]
14
+ fight_detected = fight_prob > threshold
15
+ return fight_detected, fight_prob