Commit fdc3f39 · 1 Parent(s): 27ef047

app/app_utils.py CHANGED (+13 -101)
@@ -6,25 +6,22 @@ import cv2
 from pytorch_grad_cam.utils.image import show_cam_on_image
 import matplotlib.pyplot as plt

-# Importing necessary components
-from app.model import
+# Importing necessary components
+from app.model import load_models
 from app.face_utils import get_box, display_info
 from app.config import DICT_EMO, config_data
 from app.plot import statistics_plot

+# Initialize the face mesh detector
 mp_face_mesh = mp.solutions.face_mesh

-def get_device():
-    if torch.backends.mps.is_available():
-        return torch.device("mps")
-    elif torch.cuda.is_available():
-        return torch.device("cuda")
-    else:
-        return torch.device("cpu")
-
-device = get_device()
+# Set the device directly
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")

+# Load models and GradCAM
+pth_model_static, pth_model_dynamic, cam = load_models()
+
 # Move models to the selected device
 pth_model_static = pth_model_static.to(device)
 pth_model_dynamic = pth_model_dynamic.to(device)
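Note: load_models() is imported from app.model but defined outside this diff. A minimal sketch of what such a helper could return, using stand-in torchvision backbones and an assumed GradCAM target layer (none of these names reflect the repository's actual models):

import torch
import torchvision.models as models
from pytorch_grad_cam import GradCAM

def load_models():
    # Stand-in backbones; the real app/model.py loads the project's own pretrained checkpoints.
    pth_model_static = models.resnet18(weights=None)
    pth_model_dynamic = models.resnet18(weights=None)
    pth_model_static.eval()
    pth_model_dynamic.eval()
    # GradCAM over the last residual block of the static model (assumed target layer).
    cam = GradCAM(model=pth_model_static, target_layers=[pth_model_static.layer4[-1]])
    return pth_model_static, pth_model_dynamic, cam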
@@ -152,7 +149,7 @@ def preprocess_video_and_predict(video):
                     startX, startY, endX, endY = get_box(fl, w, h)
                     cur_face = frame_copy[startY:endY, startX: endX]

-                    if count_face%config_data.FRAME_DOWNSAMPLING == 0:
+                    if count_face % config_data.FRAME_DOWNSAMPLING == 0:
                         cur_face_copy = pth_processing(Image.fromarray(cur_face)).to(device)
                         with torch.no_grad():
                             features = torch.nn.functional.relu(pth_model_static.extract_features(cur_face_copy)).detach().cpu().numpy()
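The only change in this hunk is whitespace around the modulo operator; the check itself sends every FRAME_DOWNSAMPLING-th detected face through the static model. A tiny illustration with an assumed value of 5:

# FRAME_DOWNSAMPLING = 5 is an assumption; the real value comes from app.config.config_data.
FRAME_DOWNSAMPLING = 5
processed = [count_face for count_face in range(12) if count_face % FRAME_DOWNSAMPLING == 0]
print(processed)  # [0, 5, 10]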
@@ -228,8 +225,10 @@ def preprocess_video_and_predict(video):

     return video, path_save_video_face, path_save_video_hm, stat, au_stat

-
-
+def features_to_au_intensities(features):
+    features_np = features.detach().cpu().numpy()[0]
+    au_intensities = (features_np - features_np.min()) / (features_np.max() - features_np.min())
+    return au_intensities[:24]  # Assuming we want 24 AUs

 def au_statistics_plot(frames, au_intensities_list):
     fig, ax = plt.subplots(figsize=(12, 6))
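The new features_to_au_intensities helper min-max scales the static model's feature vector into [0, 1] and keeps the first 24 values as action-unit (AU) intensities. A self-contained check of that normalization, with a random tensor standing in for the real extracted features (the feature size of 512 is an assumption):

import torch

def features_to_au_intensities(features):
    # Same logic as added in this commit: min-max normalize, keep the first 24 values.
    features_np = features.detach().cpu().numpy()[0]
    au_intensities = (features_np - features_np.min()) / (features_np.max() - features_np.min())
    return au_intensities[:24]

fake_features = torch.rand(1, 512)   # stand-in for pth_model_static.extract_features(...)
au = features_to_au_intensities(fake_features)
print(au.shape, au.min(), au.max())  # (24,) with values inside [0, 1]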
@@ -244,90 +243,3 @@ def au_statistics_plot(frames, au_intensities_list):
     ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
     plt.tight_layout()
     return fig
-
-def preprocess_video_and_predict_sleep_quality(video):
-    cap = cv2.VideoCapture(video)
-    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-    fps = np.round(cap.get(cv2.CAP_PROP_FPS))
-
-    path_save_video_original = 'result_original.mp4'
-    path_save_video_face = 'result_face.mp4'
-    path_save_video_sleep = 'result_sleep.mp4'
-
-    vid_writer_original = cv2.VideoWriter(path_save_video_original, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
-    vid_writer_face = cv2.VideoWriter(path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
-    vid_writer_sleep = cv2.VideoWriter(path_save_video_sleep, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
-
-    frames = []
-    sleep_quality_scores = []
-    eye_bags_images = []
-
-    with mp_face_mesh.FaceMesh(
-            max_num_faces=1,
-            refine_landmarks=False,
-            min_detection_confidence=0.5,
-            min_tracking_confidence=0.5) as face_mesh:
-
-        while cap.isOpened():
-            ret, frame = cap.read()
-            if not ret:
-                break
-
-            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            results = face_mesh.process(frame_rgb)
-
-            if results.multi_face_landmarks:
-                for fl in results.multi_face_landmarks:
-                    startX, startY, endX, endY = get_box(fl, w, h)
-                    cur_face = frame_rgb[startY:endY, startX:endX]
-
-                    sleep_quality_score, eye_bags_image = analyze_sleep_quality(cur_face)
-                    sleep_quality_scores.append(sleep_quality_score)
-                    eye_bags_images.append(cv2.resize(eye_bags_image, (224, 224)))
-
-                    sleep_quality_viz = create_sleep_quality_visualization(cur_face, sleep_quality_score)
-
-                    cur_face = cv2.resize(cur_face, (224, 224))
-
-                    vid_writer_face.write(cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR))
-                    vid_writer_sleep.write(sleep_quality_viz)
-
-            vid_writer_original.write(frame)
-            frames.append(len(frames) + 1)
-
-    cap.release()
-    vid_writer_original.release()
-    vid_writer_face.release()
-    vid_writer_sleep.release()
-
-    sleep_stat = sleep_quality_statistics_plot(frames, sleep_quality_scores)
-
-    if eye_bags_images:
-        average_eye_bags_image = np.mean(np.array(eye_bags_images), axis=0).astype(np.uint8)
-    else:
-        average_eye_bags_image = np.zeros((224, 224, 3), dtype=np.uint8)
-
-    return (path_save_video_original, path_save_video_face, path_save_video_sleep,
-            average_eye_bags_image, sleep_stat)
-
-def analyze_sleep_quality(face_image):
-    # Placeholder function - implement your sleep quality analysis here
-    sleep_quality_score = np.random.random()
-    eye_bags_image = cv2.resize(face_image, (224, 224))
-    return sleep_quality_score, eye_bags_image
-
-def create_sleep_quality_visualization(face_image, sleep_quality_score):
-    viz = face_image.copy()
-    cv2.putText(viz, f"Sleep Quality: {sleep_quality_score:.2f}", (10, 30),
-                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
-    return cv2.cvtColor(viz, cv2.COLOR_RGB2BGR)
-
-def sleep_quality_statistics_plot(frames, sleep_quality_scores):
-    # Placeholder function - implement your statistics plotting here
-    fig, ax = plt.subplots()
-    ax.plot(frames, sleep_quality_scores)
-    ax.set_xlabel('Frame')
-    ax.set_ylabel('Sleep Quality Score')
-    ax.set_title('Sleep Quality Over Time')
-    return fig