Spaces:
Runtime error
Runtime error
Update visualization.py
Browse files- visualization.py +5 -4
visualization.py
CHANGED
|
@@ -216,7 +216,6 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
|
|
| 216 |
plt.close()
|
| 217 |
return fig
|
| 218 |
|
| 219 |
-
|
| 220 |
def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
|
| 221 |
frame_count = int(t * video_fps)
|
| 222 |
|
|
@@ -230,14 +229,16 @@ def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_f
|
|
| 230 |
combined_mse[1] = mse_posture_norm
|
| 231 |
combined_mse[2] = mse_voice_norm
|
| 232 |
|
| 233 |
-
fig, ax = plt.subplots(figsize=(video_width /
|
| 234 |
ax.imshow(combined_mse, aspect='auto', cmap='hot', vmin=0, vmax=1, extent=[0, total_frames, 0, 3])
|
| 235 |
ax.set_yticks([0.5, 1.5, 2.5])
|
| 236 |
-
ax.set_yticklabels(['Voice', 'Posture', 'Face'])
|
| 237 |
ax.set_xticks([])
|
| 238 |
|
| 239 |
-
ax.axvline(x=frame_count, color='black', linewidth=
|
| 240 |
|
|
|
|
|
|
|
| 241 |
canvas = FigureCanvas(fig)
|
| 242 |
canvas.draw()
|
| 243 |
heatmap_img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
|
|
|
|
| 216 |
plt.close()
|
| 217 |
return fig
|
| 218 |
|
|
|
|
| 219 |
def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
|
| 220 |
frame_count = int(t * video_fps)
|
| 221 |
|
|
|
|
| 229 |
combined_mse[1] = mse_posture_norm
|
| 230 |
combined_mse[2] = mse_voice_norm
|
| 231 |
|
| 232 |
+
fig, ax = plt.subplots(figsize=(video_width / 25, 0.2)) # Much thinner height, wider width
|
| 233 |
ax.imshow(combined_mse, aspect='auto', cmap='hot', vmin=0, vmax=1, extent=[0, total_frames, 0, 3])
|
| 234 |
ax.set_yticks([0.5, 1.5, 2.5])
|
| 235 |
+
ax.set_yticklabels(['Voice', 'Posture', 'Face'], fontsize=6) # Smaller font size
|
| 236 |
ax.set_xticks([])
|
| 237 |
|
| 238 |
+
ax.axvline(x=frame_count, color='black', linewidth=2) # Thinner line for smaller heatmap
|
| 239 |
|
| 240 |
+
plt.tight_layout(pad=0.1) # Reduce padding around the plot
|
| 241 |
+
|
| 242 |
canvas = FigureCanvas(fig)
|
| 243 |
canvas.draw()
|
| 244 |
heatmap_img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
|