Upload 167 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitattributes +16 -0
- .gitignore +52 -0
- Dockerfile +36 -0
- LICENSE +21 -0
- README.md +168 -10
- core/README.md +58 -0
- core/bicep_model/1.data.ipynb +510 -0
- core/bicep_model/2.sklearn.ipynb +413 -0
- core/bicep_model/3.deep_learning.ipynb +1312 -0
- core/bicep_model/4.evaluation.ipynb +0 -0
- core/bicep_model/5.detection.ipynb +619 -0
- core/bicep_model/README.md +36 -0
- core/bicep_model/evaluation.csv +23 -0
- core/bicep_model/model/KNN_model.pkl +3 -0
- core/bicep_model/model/RF_model.pkl +3 -0
- core/bicep_model/model/all_dp.pkl +3 -0
- core/bicep_model/model/all_sklearn.pkl +3 -0
- core/bicep_model/model/bicep_dp.pkl +3 -0
- core/bicep_model/model/input_scaler.pkl +3 -0
- core/bicep_model/test.csv +0 -0
- core/bicep_model/train.csv +3 -0
- core/lunge_model/1.stage.data.ipynb +0 -0
- core/lunge_model/2.stage.sklearn.ipynb +753 -0
- core/lunge_model/3.stage.deep_learning.ipynb +767 -0
- core/lunge_model/4.stage.detection.ipynb +717 -0
- core/lunge_model/5.err.data.ipynb +562 -0
- core/lunge_model/6.err.sklearn.ipynb +777 -0
- core/lunge_model/7.err.deep_learning.ipynb +1366 -0
- core/lunge_model/8.err.evaluation.ipynb +0 -0
- core/lunge_model/9.err.detection.ipynb +714 -0
- core/lunge_model/README.md +43 -0
- core/lunge_model/err.evaluation.csv +23 -0
- core/lunge_model/err.test.csv +0 -0
- core/lunge_model/err.train.csv +3 -0
- core/lunge_model/knee_angle.csv +0 -0
- core/lunge_model/knee_angle_2.csv +0 -0
- core/lunge_model/model/dp/all_models.pkl +3 -0
- core/lunge_model/model/dp/err_lunge_dp.pkl +3 -0
- core/lunge_model/model/dp/stage_lunge_dp.pkl +3 -0
- core/lunge_model/model/input_scaler.pkl +3 -0
- core/lunge_model/model/sklearn/err_LR_model.pkl +3 -0
- core/lunge_model/model/sklearn/err_SGDC_model.pkl +3 -0
- core/lunge_model/model/sklearn/err_all_sklearn.pkl +3 -0
- core/lunge_model/model/sklearn/stage_LR_model.pkl +3 -0
- core/lunge_model/model/sklearn/stage_Ridge_model.pkl +3 -0
- core/lunge_model/model/sklearn/stage_SVC_model.pkl +3 -0
- core/lunge_model/stage.test.csv +0 -0
- core/lunge_model/stage.train.csv +3 -0
- core/plank_model/1.data.ipynb +706 -0
- core/plank_model/2.sklearn.ipynb +762 -0
.gitattributes
CHANGED
@@ -33,3 +33,19 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text

```
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+core/bicep_model/train.csv filter=lfs diff=lfs merge=lfs -text
+core/lunge_model/err.train.csv filter=lfs diff=lfs merge=lfs -text
+core/lunge_model/stage.train.csv filter=lfs diff=lfs merge=lfs -text
+core/plank_model/train.csv filter=lfs diff=lfs merge=lfs -text
+demo/bc_demo.mp4 filter=lfs diff=lfs merge=lfs -text
+demo/lunge_demo.mp4 filter=lfs diff=lfs merge=lfs -text
+demo/plank_demo.mp4 filter=lfs diff=lfs merge=lfs -text
+demo/squat_demo.mp4 filter=lfs diff=lfs merge=lfs -text
+docs/B1809677-Ngô[[:space:]]Hồng[[:space:]]Quốc[[:space:]]Bảo-Wrong[[:space:]]Pose[[:space:]]Dectection[[:space:]]based[[:space:]]on[[:space:]]Machine[[:space:]]Learning.docx filter=lfs diff=lfs merge=lfs -text
+docs/B1809677-Ngô[[:space:]]Hồng[[:space:]]Quốc[[:space:]]Bảo-Wrong[[:space:]]Pose[[:space:]]Dectection[[:space:]]based[[:space:]]on[[:space:]]Machine[[:space:]]Learning.pdf filter=lfs diff=lfs merge=lfs -text
+images/bicep_curl.gif filter=lfs diff=lfs merge=lfs -text
+images/lunge.gif filter=lfs diff=lfs merge=lfs -text
+images/plank.gif filter=lfs diff=lfs merge=lfs -text
+images/squat.gif filter=lfs diff=lfs merge=lfs -text
+images/web_3.png filter=lfs diff=lfs merge=lfs -text
+images/web_4.png filter=lfs diff=lfs merge=lfs -text
```
.gitignore
ADDED
@@ -0,0 +1,52 @@

```
# Files
.DS_Store

# Machine Learning Folders
.ipynb_checkpoints/
env/
data/

# Editor directories and files
.vscode
!.vscode/extensions.json
.idea
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

# Vue clients
package-lock.json
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*

node_modules
dist
dist-ssr
coverage
*.local

/cypress/videos/
/cypress/screenshots/

# Keras Tuner
**/keras_tuner_dir

# Django server
__pycache__/
db.sqlite3
temp/

# References
github/
web/*/static/media/
web/*/static/images/
web/*/static/assets/
web/*/static/css/
*.env
```
Dockerfile
ADDED
@@ -0,0 +1,36 @@

```dockerfile
FROM python:3.8

# Install Node.js and npm
RUN apt-get update && apt-get install -y curl && \
    curl -sL https://deb.nodesource.com/setup_16.x | bash - && \
    apt-get install -y nodejs \
    npm

RUN apt-get update && apt-get install -y rsync
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y

RUN export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

# Set the working directory to /app
WORKDIR /app

# Copy the requirements file into the container
COPY requirements.txt .

# Install any needed packages specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# copy /web folder and install client's dependencies
COPY ./web /app
WORKDIR /app
RUN npm run install:client
RUN npm run build-deploy:client

# Expose port 8000 for the Django server
EXPOSE 8000

# Start the server
CMD ["python", "server/manage.py", "runserver", "0.0.0.0:8000"]
```
LICENSE
ADDED
@@ -0,0 +1,21 @@

MIT License

Copyright (c) 2022 Ngô Hồng Quốc Bảo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,10 +1,168 @@

<div id="top"></div>

<!-- PROJECT LOGO -->
<br />
<div align="center">
  <a href="https://github.com/NgoQuocBao1010/Exercise-Correction">
    <img src="./images/logo.png" alt="Logo" width="60%">
  </a>

  <h2 align="center">Exercise Pose Correction</h2>

  <p align="center">
    Built on the power of MediaPipe's pose detection, this project analyzes, detects and classifies the form of fitness exercises.
  </p>
</div>

<!-- ABOUT THE PROJECT -->

## About The Project

The goal of this project is to develop 4 machine learning models for 4 of the most common home exercises **(Bicep Curl, Plank, Squat and Lunge)**, where each model can detect incorrect movement while a person is performing the corresponding exercise. In addition, a web application that utilizes the trained models is built to analyze workout videos and provide feedback on them.

Here are some detections of the exercises:

- Bicep Curl
  <p align="center"><img src="images/bicep_curl.gif" alt="Bicep Curl detection" width="70%"></p>

- Basic Plank
  <p align="center"><img src="images/plank.gif" alt="Plank detection" width="70%"></p>

- Basic Squat
  <p align="center"><img src="images/squat.gif" alt="Squat detection" width="70%"></p>

- Lunge
  <p align="center"><img src="images/lunge.gif" alt="Lunge detection" width="70%"></p>

- Models' evaluation results and website screenshots are [here](#usage)

<p align="right">(<a href="#top">back to top</a>)</p>

### Built With

1. For data processing and model training

   - [Numpy](https://numpy.org/)
   - [Pandas](https://pandas.pydata.org/)
   - [Sklearn](https://scikit-learn.org/stable/)
   - [Keras](https://keras.io/)

1. For building the website

   - [Vue.js v3](https://vuejs.org/)
   - [Django](https://www.djangoproject.com/)

<p align="right">(<a href="#top">back to top</a>)</p>

## Dataset

Due to the lack of online videos or datasets recording people doing exercises both properly and improperly, the majority of the videos were self-collected: recorded by myself, my friends or my family. Most of the collected videos were later removed for privacy reasons.

For an exercise such as Plank, since there is not much movement during the exercise, I was able to find a dataset in an open database on [Kaggle](https://www.kaggle.com/datasets/niharika41298/yoga-poses-dataset). That dataset covers several well-known yoga poses: the downward dog pose, goddess pose, tree pose, plank pose and warrior pose. It contains 5 folders for the 5 poses, each holding images of people correctly performing the corresponding pose.

For the purpose of this thesis, only the folder containing images of people properly doing a plank is used. Out of the 266 image files in that folder, I handpicked the images that represent a basic plank and discarded the rest; in the end, 30 images were assigned to the proper-form class for the basic plank.

## Getting Started

These are instructions on setting up the project locally.

#### Setting Up Environment

```
Python 3.8.13
Node 17.8.0
NPM 8.5.5
OS: Linux or MacOS
```

```markdown
NOTES
⚠️ Commands/scripts for this project are written for Linux-based OSes. They may not work on Windows machines.
```

### Installation

_If you only want to try the website, look [here](./web/README.md)._

1. Clone the repo and change directory to that folder

   ```sh
   git clone https://github.com/NgoQuocBao1010/Exercise-Correction.git
   cd Exercise-Correction
   ```

1. Install all project dependencies

   ```bash
   pip install -r requirements.txt
   ```

1. Folder **_[core](./core/README.md)_** is the code for data processing and model training.
1. Folder **_[web](./web/README.md)_** is the code for the website.

<p align="right">(<a href="#top">back to top</a>)</p>

<!-- USAGE EXAMPLES -->
<div id="Usage"></div>
<br/>

## Usage

As the introduction indicated, this project serves 2 purposes.

1. Model training **(described in depth [here](core/README.md))**. Below are the evaluation results for each model.

   - [Bicep Curl](core/bicep_model/README.md) - _lean back error_: Confusion Matrix - ROC curve

     | <img align="center" alt="Bicep Curl confusion matrix" src="images/bicep_curl_eval.png" /> | <img align="center" alt="Bicep Curl ROC curve" src="images/bicep_curl_eval_2.png" /> |
     | ------------- | ------------- |

   - [Plank](core/plank_model/README.md) - _all errors_: Confusion Matrix - ROC curve

     | <img align="center" alt="Plank confusion matrix" src="images/plank_eval.png" /> | <img align="center" alt="Plank ROC curve" src="images/plank_eval_2.png" /> |
     | ------------- | ------------- |

   - [Basic Squat](core/squat_model/README.md) - _stage_: Confusion Matrix - ROC curve

     | <img align="center" alt="Squat confusion matrix" src="images/squat_eval.png" /> | <img align="center" alt="Squat ROC curve" src="images/squat_eval_2.png" /> |
     | ------------- | ------------- |

   - [Lunge](core/lunge_model/README.md) - _knee over toe error_: Confusion Matrix - ROC curve

     | <img align="center" alt="Lunge confusion matrix" src="images/lunge_eval.png" /> | <img align="center" alt="Lunge ROC curve" src="images/lunge_eval_2.png" /> |
     | ------------- | ------------- |

1. Website for exercise detection. The website demonstrates all the trained models, so at the moment there is only 1 main feature: analyzing a user's exercise video and giving feedback on it.
   <p align="center"><img src="images/web_1.png" alt="Website screenshot 1" width="70%"></p>
   <p align="center"><img src="images/web_2.png" alt="Website screenshot 2" width="70%"></p>
   <p align="center"><img src="images/web_3.png" alt="Website screenshot 3" width="70%"></p>
   <p align="center"><img src="images/web_4.png" alt="Website screenshot 4" width="70%"></p>

<p align="right">(<a href="#top">back to top</a>)</p>

<!-- CONTRIBUTING -->

## Contributing

Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**.

If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement".
Don't forget to give the project a star! Thanks again!

1. Fork the Project
2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`)
3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`)
4. Push to the Branch (`git push origin feature/AmazingFeature`)
5. Open a Pull Request

<p align="right">(<a href="#top">back to top</a>)</p>

<!-- LICENSE -->

## License

Distributed under the MIT License.

<p align="right">(<a href="#top">back to top</a>)</p>

<!-- ACKNOWLEDGMENTS -->

## Acknowledgments

- Some other projects this one drew inspiration from: [Pose Trainer](https://github.com/stevenzchen/pose-trainer), [Deep Learning Fitness Exercise Correction Keras](https://github.com/Vollkorn01/Deep-Learning-Fitness-Exercise-Correction-Keras) and [Posture](https://github.com/twixupmysleeve/Posture).
- [Logo maker](https://www4.flamingtext.com/) for this project.
- This awesome README template is from [Best README Template](https://github.com/othneildrew/Best-README-Template). ♥

<p align="right">(<a href="#top">back to top</a>)</p>
core/README.md
ADDED
@@ -0,0 +1,58 @@

<h2 align="center">Build Machine Learning Model</h2>

A brief overview of the methodology for building the exercise pose detection models.
To go in depth on each exercise, follow the links below:

- [Bicep Curl](./bicep_model/README.md)
- [Plank](./plank_model/README.md)
- [Basic Squat](./squat_model/README.md)
- [Lunge](./lunge_model/README.md)

### 1. Simple error detection

For some simple errors (for example, the feet placement error in squat), the detection method is to measure the distance or the angle between different joints during the exercise, using the coordinate outputs from MediaPipe Pose.

- **_Distance Calculation_**
  Assume there are 2 points with the coordinates Point 1 (x1, y1) and Point 2 (x2, y2). The distance between the 2 points is:

  ```
  distance = √((x1 - x2)² + (y1 - y2)²)
  ```

- **_Angle Calculation_**
  Assume there are 3 points with the coordinates Point 1 (x1, y1), Point 2 (x2, y2) and Point 3 (x3, y3). The angle formed by the 3 points, with its vertex at Point 2, is:

  ```
  angle_in_radian = arctan2(y3 - y2, x3 - x2) - arctan2(y1 - y2, x1 - x2)
  angle_in_degree = (angle_in_radian * 180) / π
  ```
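
Both formulas translate directly into NumPy. Below is a minimal sketch; the coordinates in the example calls are made up, and folding the angle into the [0, 180] range is a common extra step rather than part of the formula above:

```python
import numpy as np

def distance(p1, p2):
    """Euclidean distance between two (x, y) points."""
    return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)

def angle_3pts(p1, p2, p3):
    """Angle in degrees formed at p2 by the segments toward p1 and p3."""
    rad = np.arctan2(p3[1] - p2[1], p3[0] - p2[0]) - np.arctan2(p1[1] - p2[1], p1[0] - p2[0])
    deg = abs(rad * 180.0 / np.pi)
    # Fold reflex angles into [0, 180]
    return 360.0 - deg if deg > 180.0 else deg

# Example with made-up normalized coordinates (e.g. hip, knee, ankle)
print(distance((0.5, 0.5), (0.6, 0.7)))               # ~0.2236
print(angle_3pts((0.5, 0.3), (0.5, 0.6), (0.7, 0.8)))
```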

### 2. Model Training for Error Detection

#### 1. Pick important landmarks

Each exercise involves different poses and body positions, so it is essential to identify which parts of the body (shoulder, hip, …) contribute to the exercise. The important landmarks identified for each exercise are then used to extract the body parts' positions with MediaPipe while the person is exercising.
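
As a sketch of what that extraction looks like in code: the landmark subset below is illustrative (each exercise defines its own list), and `results` is assumed to be the output of MediaPipe's `Pose.process`:

```python
import mediapipe as mp

mp_pose = mp.solutions.pose

# Illustrative subset; each exercise defines its own list
IMPORTANT_LMS = ["LEFT_SHOULDER", "RIGHT_SHOULDER", "LEFT_HIP", "RIGHT_HIP"]

def extract_keypoints(results):
    """Flatten the chosen landmarks of one frame into a single feature row."""
    landmarks = results.pose_landmarks.landmark
    row = []
    for name in IMPORTANT_LMS:
        lm = landmarks[mp_pose.PoseLandmark[name].value]
        row += [lm.x, lm.y, lm.z, lm.visibility]
    return row
```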

#### 2. Data Processing

<p align="center"><img src="../images/data_processing.png" alt="Data processing pipeline" width="70%" style="background-color:#f5f5f5"></p>

#### 3. Model training

Two training methods are used in this thesis. For each exercise, the models trained with each method are compared and the best one is chosen; a condensed sketch of the first method follows this list.

- Classification with Scikit-learn (Decision Tree/Random Forest (RF), K-Nearest Neighbors (KNN), C-Support Vector (SVC), Logistic Regression (LR) and Stochastic Gradient Descent (SGDC) classifiers).
- Building a Neural Network for classification with Keras.
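
Below is a minimal sketch of the Scikit-learn route, not the thesis's actual training code; it assumes a `train.csv` with a `label` column plus the flattened landmark columns produced by the data-processing step:

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.metrics import classification_report

df = pd.read_csv("train.csv")
X, y = df.drop("label", axis=1), df["label"]

# Scale the features; SVC, LR and SGDC are sensitive to feature magnitudes
X = StandardScaler().fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

models = {
    "RF": RandomForestClassifier(),
    "KNN": KNeighborsClassifier(),
    "SVC": SVC(probability=True),
    "LR": LogisticRegression(max_iter=1000),
    "SGDC": SGDClassifier(),
}

# Train each candidate and report precision/recall/F1 on the held-out split
for name, model in models.items():
    model.fit(X_train, y_train)
    print(name)
    print(classification_report(y_test, model.predict(X_test)))
```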

### 3. Evaluation results of all models

1. Bicep Curl - _lean back error_
   <p align="center"><img src="../images/bicep_curl_eval_3.png" alt="Bicep Curl evaluation" width="70%"></p>

2. Plank - _all errors_
   <p align="center"><img src="../images/plank_eval_3.png" alt="Plank evaluation" width="70%"></p>

3. Basic Squat - _stage_
   <p align="center"><img src="../images/squat_eval_3.png" alt="Squat evaluation" width="70%"></p>

4. Lunge - _knee over toe error_
   <p align="center"><img src="../images/lunge_eval_3.png" alt="Lunge evaluation" width="70%"></p>
core/bicep_model/1.data.ipynb
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [
|
| 8 |
+
{
|
| 9 |
+
"name": "stderr",
|
| 10 |
+
"output_type": "stream",
|
| 11 |
+
"text": [
|
| 12 |
+
"objc[67754]: Class CaptureDelegate is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_videoio.3.4.16.dylib (0x10a8c8860) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x161476480). One of the two will be used. Which one is undefined.\n",
|
| 13 |
+
"objc[67754]: Class CVWindow is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x10567ca68) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x1614764d0). One of the two will be used. Which one is undefined.\n",
|
| 14 |
+
"objc[67754]: Class CVView is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x10567ca90) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x1614764f8). One of the two will be used. Which one is undefined.\n",
|
| 15 |
+
"objc[67754]: Class CVSlider is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x10567cab8) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x161476520). One of the two will be used. Which one is undefined.\n"
|
| 16 |
+
]
|
| 17 |
+
}
|
| 18 |
+
],
|
| 19 |
+
"source": [
|
| 20 |
+
"import mediapipe as mp\n",
|
| 21 |
+
"import cv2\n",
|
| 22 |
+
"import numpy as np\n",
|
| 23 |
+
"import pandas as pd\n",
|
| 24 |
+
"import os, csv\n",
|
| 25 |
+
"import seaborn as sns\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"import warnings\n",
|
| 28 |
+
"warnings.filterwarnings('ignore')\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"# Drawing helpers\n",
|
| 31 |
+
"mp_drawing = mp.solutions.drawing_utils\n",
|
| 32 |
+
"mp_pose = mp.solutions.pose"
|
| 33 |
+
]
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"cell_type": "markdown",
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"source": [
|
| 39 |
+
"### 1. Describe the data gathering process and build dataset from Video\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"The purpose is to gather data to determine the correct standing posture for Bicep Curl exercise\n",
|
| 42 |
+
"There are 2 stages:\n",
|
| 43 |
+
"- Correct: \"C\"\n",
|
| 44 |
+
"- Lean-back-error: \"L\""
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": 2,
|
| 50 |
+
"metadata": {},
|
| 51 |
+
"outputs": [],
|
| 52 |
+
"source": [
|
| 53 |
+
"# Determine important landmarks for plank\n",
|
| 54 |
+
"IMPORTANT_LMS = [\n",
|
| 55 |
+
" \"NOSE\",\n",
|
| 56 |
+
" \"LEFT_SHOULDER\",\n",
|
| 57 |
+
" \"RIGHT_SHOULDER\",\n",
|
| 58 |
+
" \"RIGHT_ELBOW\",\n",
|
| 59 |
+
" \"LEFT_ELBOW\",\n",
|
| 60 |
+
" \"RIGHT_WRIST\",\n",
|
| 61 |
+
" \"LEFT_WRIST\",\n",
|
| 62 |
+
" \"LEFT_HIP\",\n",
|
| 63 |
+
" \"RIGHT_HIP\",\n",
|
| 64 |
+
"]\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"# Generate all columns of the data frame\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"HEADERS = [\"label\"] # Label column\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"for lm in IMPORTANT_LMS:\n",
|
| 71 |
+
" HEADERS += [f\"{lm.lower()}_x\", f\"{lm.lower()}_y\", f\"{lm.lower()}_z\", f\"{lm.lower()}_v\"]"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "markdown",
|
| 76 |
+
"metadata": {},
|
| 77 |
+
"source": [
|
| 78 |
+
"#### 1.2. Set up important functions"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "code",
|
| 83 |
+
"execution_count": 3,
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"outputs": [],
|
| 86 |
+
"source": [
|
| 87 |
+
"def rescale_frame(frame, percent=50):\n",
|
| 88 |
+
" '''\n",
|
| 89 |
+
" Rescale a frame to a certain percentage compare to its original frame\n",
|
| 90 |
+
" '''\n",
|
| 91 |
+
" width = int(frame.shape[1] * percent/ 100)\n",
|
| 92 |
+
" height = int(frame.shape[0] * percent/ 100)\n",
|
| 93 |
+
" dim = (width, height)\n",
|
| 94 |
+
" return cv2.resize(frame, dim, interpolation = cv2.INTER_AREA)\n",
|
| 95 |
+
" \n",
|
| 96 |
+
"\n",
|
| 97 |
+
"def init_csv(dataset_path: str):\n",
|
| 98 |
+
" '''\n",
|
| 99 |
+
" Create a blank csv file with just columns\n",
|
| 100 |
+
" '''\n",
|
| 101 |
+
"\n",
|
| 102 |
+
" # Ignore if file is already exist\n",
|
| 103 |
+
" if os.path.exists(dataset_path):\n",
|
| 104 |
+
" return\n",
|
| 105 |
+
"\n",
|
| 106 |
+
" # Write all the columns to a empty file\n",
|
| 107 |
+
" with open(dataset_path, mode=\"w\", newline=\"\") as f:\n",
|
| 108 |
+
" csv_writer = csv.writer(f, delimiter=\",\", quotechar='\"', quoting=csv.QUOTE_MINIMAL)\n",
|
| 109 |
+
" csv_writer.writerow(HEADERS)\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"def export_landmark_to_csv(dataset_path: str, results, action: str) -> None:\n",
|
| 113 |
+
" '''\n",
|
| 114 |
+
" Export Labeled Data from detected landmark to csv\n",
|
| 115 |
+
" '''\n",
|
| 116 |
+
" landmarks = results.pose_landmarks.landmark\n",
|
| 117 |
+
" keypoints = []\n",
|
| 118 |
+
"\n",
|
| 119 |
+
" try:\n",
|
| 120 |
+
" # Extract coordinate of important landmarks\n",
|
| 121 |
+
" for lm in IMPORTANT_LMS:\n",
|
| 122 |
+
" keypoint = landmarks[mp_pose.PoseLandmark[lm].value]\n",
|
| 123 |
+
" keypoints.append([keypoint.x, keypoint.y, keypoint.z, keypoint.visibility])\n",
|
| 124 |
+
" \n",
|
| 125 |
+
" keypoints = list(np.array(keypoints).flatten())\n",
|
| 126 |
+
"\n",
|
| 127 |
+
" # Insert action as the label (first column)\n",
|
| 128 |
+
" keypoints.insert(0, action)\n",
|
| 129 |
+
"\n",
|
| 130 |
+
" # Append new row to .csv file\n",
|
| 131 |
+
" with open(dataset_path, mode=\"a\", newline=\"\") as f:\n",
|
| 132 |
+
" csv_writer = csv.writer(f, delimiter=\",\", quotechar='\"', quoting=csv.QUOTE_MINIMAL)\n",
|
| 133 |
+
" csv_writer.writerow(keypoints)\n",
|
| 134 |
+
" \n",
|
| 135 |
+
"\n",
|
| 136 |
+
" except Exception as e:\n",
|
| 137 |
+
" print(e)\n",
|
| 138 |
+
" pass\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"def describe_dataset(dataset_path: str):\n",
|
| 142 |
+
" '''\n",
|
| 143 |
+
" Describe dataset\n",
|
| 144 |
+
" '''\n",
|
| 145 |
+
"\n",
|
| 146 |
+
" data = pd.read_csv(dataset_path)\n",
|
| 147 |
+
" print(f\"Headers: {list(data.columns.values)}\")\n",
|
| 148 |
+
" print(f'Number of rows: {data.shape[0]} \\nNumber of columns: {data.shape[1]}\\n')\n",
|
| 149 |
+
" print(f\"Labels: \\n{data['label'].value_counts()}\\n\")\n",
|
| 150 |
+
" print(f\"Missing values: {data.isnull().values.any()}\\n\")\n",
|
| 151 |
+
" \n",
|
| 152 |
+
" duplicate = data[data.duplicated()]\n",
|
| 153 |
+
" print(f\"Duplicate Rows : {len(duplicate.sum(axis=1))}\")\n",
|
| 154 |
+
"\n",
|
| 155 |
+
" return data\n",
|
| 156 |
+
"\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"def remove_duplicate_rows(dataset_path: str):\n",
|
| 159 |
+
" '''\n",
|
| 160 |
+
" Remove duplicated data from the dataset then save it to another files\n",
|
| 161 |
+
" '''\n",
|
| 162 |
+
" \n",
|
| 163 |
+
" df = pd.read_csv(dataset_path)\n",
|
| 164 |
+
" df.drop_duplicates(keep=\"first\", inplace=True)\n",
|
| 165 |
+
" df.to_csv(f\"cleaned_train.csv\", sep=',', encoding='utf-8', index=False)\n",
|
| 166 |
+
" \n",
|
| 167 |
+
"\n",
|
| 168 |
+
"def concat_csv_files_with_same_headers(file_paths: list, saved_path: str):\n",
|
| 169 |
+
" '''\n",
|
| 170 |
+
" Concat different csv files\n",
|
| 171 |
+
" '''\n",
|
| 172 |
+
" all_df = []\n",
|
| 173 |
+
" for path in file_paths:\n",
|
| 174 |
+
" df = pd.read_csv(path, index_col=None, header=0)\n",
|
| 175 |
+
" all_df.append(df)\n",
|
| 176 |
+
" \n",
|
| 177 |
+
" results = pd.concat(all_df, axis=0, ignore_index=True)\n",
|
| 178 |
+
" results.to_csv(saved_path, sep=',', encoding='utf-8', index=False)"
|
| 179 |
+
]
|
| 180 |
+
},
|
| 181 |
+
{
|
| 182 |
+
"cell_type": "markdown",
|
| 183 |
+
"metadata": {},
|
| 184 |
+
"source": [
|
| 185 |
+
"### 2. Extract data from video"
|
| 186 |
+
]
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"cell_type": "code",
|
| 190 |
+
"execution_count": null,
|
| 191 |
+
"metadata": {},
|
| 192 |
+
"outputs": [
|
| 193 |
+
{
|
| 194 |
+
"ename": "",
|
| 195 |
+
"evalue": "",
|
| 196 |
+
"output_type": "error",
|
| 197 |
+
"traceback": [
|
| 198 |
+
"\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
| 199 |
+
]
|
| 200 |
+
}
|
| 201 |
+
],
|
| 202 |
+
"source": [
|
| 203 |
+
"DATASET_PATH = \"train.csv\"\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"cap = cv2.VideoCapture(\"../data/db_curl/stand_posture_11.mp4\")\n",
|
| 206 |
+
"save_counts = 0\n",
|
| 207 |
+
"\n",
|
| 208 |
+
"# init_csv(DATASET_PATH)\n",
|
| 209 |
+
"\n",
|
| 210 |
+
"with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:\n",
|
| 211 |
+
" while cap.isOpened():\n",
|
| 212 |
+
" ret, image = cap.read()\n",
|
| 213 |
+
"\n",
|
| 214 |
+
" if not ret:\n",
|
| 215 |
+
" break\n",
|
| 216 |
+
"\n",
|
| 217 |
+
" # Reduce size of a frame\n",
|
| 218 |
+
" image = rescale_frame(image, 60)\n",
|
| 219 |
+
" image = cv2.flip(image, 1)\n",
|
| 220 |
+
"\n",
|
| 221 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 222 |
+
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
| 223 |
+
" image.flags.writeable = False\n",
|
| 224 |
+
"\n",
|
| 225 |
+
" results = pose.process(image)\n",
|
| 226 |
+
"\n",
|
| 227 |
+
" if not results.pose_landmarks:\n",
|
| 228 |
+
" print(\"Cannot detect pose - No human found\")\n",
|
| 229 |
+
" continue\n",
|
| 230 |
+
"\n",
|
| 231 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 232 |
+
" image.flags.writeable = True\n",
|
| 233 |
+
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
|
| 234 |
+
"\n",
|
| 235 |
+
" # Draw landmarks and connections\n",
|
| 236 |
+
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=4), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2))\n",
|
| 237 |
+
"\n",
|
| 238 |
+
" # Display the saved count\n",
|
| 239 |
+
" cv2.putText(image, f\"Saved: {save_counts}\", (50, 50), cv2.FONT_HERSHEY_COMPLEX, 2, (0, 0, 0), 2, cv2.LINE_AA)\n",
|
| 240 |
+
"\n",
|
| 241 |
+
" cv2.imshow(\"CV2\", image)\n",
|
| 242 |
+
"\n",
|
| 243 |
+
" # Pressed key for action\n",
|
| 244 |
+
" k = cv2.waitKey(1) & 0xFF\n",
|
| 245 |
+
"\n",
|
| 246 |
+
" # Press C to save as correct form\n",
|
| 247 |
+
" if k == ord('c'): \n",
|
| 248 |
+
" export_landmark_to_csv(DATASET_PATH, results, \"C\")\n",
|
| 249 |
+
" save_counts += 1\n",
|
| 250 |
+
" # Press L to save as low back\n",
|
| 251 |
+
" elif k == ord(\"l\"):\n",
|
| 252 |
+
" export_landmark_to_csv(DATASET_PATH, results, \"L\")\n",
|
| 253 |
+
" save_counts += 1\n",
|
| 254 |
+
"\n",
|
| 255 |
+
" # Press q to stop\n",
|
| 256 |
+
" elif k == ord(\"q\"):\n",
|
| 257 |
+
" break\n",
|
| 258 |
+
" else: continue\n",
|
| 259 |
+
"\n",
|
| 260 |
+
" cap.release()\n",
|
| 261 |
+
" cv2.destroyAllWindows()\n",
|
| 262 |
+
"\n",
|
| 263 |
+
" # (Optional)Fix bugs cannot close windows in MacOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)\n",
|
| 264 |
+
" for i in range (1, 5):\n",
|
| 265 |
+
" cv2.waitKey(1)\n",
|
| 266 |
+
" "
|
| 267 |
+
]
|
| 268 |
+
},
|
| 269 |
+
{
|
| 270 |
+
"cell_type": "code",
|
| 271 |
+
"execution_count": null,
|
| 272 |
+
"metadata": {},
|
| 273 |
+
"outputs": [],
|
| 274 |
+
"source": [
|
| 275 |
+
"# csv_files = [os.path.join(\"./\", f) for f in os.listdir(\"./\") if \"csv\" in f]\n",
|
| 276 |
+
"\n",
|
| 277 |
+
"# concat_csv_files_with_same_headers(csv_files, \"train.csv\")\n",
|
| 278 |
+
"\n",
|
| 279 |
+
"df = describe_dataset(\"./train.csv\")"
|
| 280 |
+
]
|
| 281 |
+
},
|
| 282 |
+
{
|
| 283 |
+
"cell_type": "markdown",
|
| 284 |
+
"metadata": {},
|
| 285 |
+
"source": [
|
| 286 |
+
"### 3. Clean Data and Visualize data"
|
| 287 |
+
]
|
| 288 |
+
},
|
| 289 |
+
{
|
| 290 |
+
"cell_type": "code",
|
| 291 |
+
"execution_count": null,
|
| 292 |
+
"metadata": {},
|
| 293 |
+
"outputs": [],
|
| 294 |
+
"source": [
|
| 295 |
+
"remove_duplicate_rows(\"./train.csv\")"
|
| 296 |
+
]
|
| 297 |
+
},
|
| 298 |
+
{
|
| 299 |
+
"cell_type": "code",
|
| 300 |
+
"execution_count": 5,
|
| 301 |
+
"metadata": {},
|
| 302 |
+
"outputs": [
|
| 303 |
+
{
|
| 304 |
+
"name": "stdout",
|
| 305 |
+
"output_type": "stream",
|
| 306 |
+
"text": [
|
| 307 |
+
"Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'right_elbow_x', 'right_elbow_y', 'right_elbow_z', 'right_elbow_v', 'left_elbow_x', 'left_elbow_y', 'left_elbow_z', 'left_elbow_v', 'right_wrist_x', 'right_wrist_y', 'right_wrist_z', 'right_wrist_v', 'left_wrist_x', 'left_wrist_y', 'left_wrist_z', 'left_wrist_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v']\n",
|
| 308 |
+
"Number of rows: 15372 \n",
|
| 309 |
+
"Number of columns: 37\n",
|
| 310 |
+
"\n",
|
| 311 |
+
"Labels: \n",
|
| 312 |
+
"C 8238\n",
|
| 313 |
+
"L 7134\n",
|
| 314 |
+
"Name: label, dtype: int64\n",
|
| 315 |
+
"\n",
|
| 316 |
+
"Missing values: False\n",
|
| 317 |
+
"\n",
|
| 318 |
+
"Duplicate Rows : 0\n"
|
| 319 |
+
]
|
| 320 |
+
},
|
| 321 |
+
{
|
| 322 |
+
"data": {
|
| 323 |
+
"text/plain": [
|
| 324 |
+
"<AxesSubplot:xlabel='label', ylabel='count'>"
|
| 325 |
+
]
|
| 326 |
+
},
|
| 327 |
+
"execution_count": 5,
|
| 328 |
+
"metadata": {},
|
| 329 |
+
"output_type": "execute_result"
|
| 330 |
+
},
|
| 331 |
+
{
|
| 332 |
+
"data": {
|
| 333 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkQAAAGwCAYAAABIC3rIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAvuElEQVR4nO3df3RU9Z3/8ddIyBgwuRJIZph11LBmESSoGzQEq2TldxtTDz1CG3cWDwjYIGkKFMryteKvBNECq9lS5FBD+XHw7I9Y99SOBLdmixASoqmCEXXNCqwZgnUyIZgmGOb7h+WuQyjSSDITPs/HOfcc5nPf85n3h3Mgr/OZe28c4XA4LAAAAINdFu0GAAAAoo1ABAAAjEcgAgAAxiMQAQAA4xGIAACA8QhEAADAeAQiAABgvLhoN9BXnD59Wh9//LESExPlcDii3Q4AALgA4XBYJ06ckMfj0WWX/fl9IALRBfr444/l9Xqj3QYAAOiGI0eO6Kqrrvqz5wlEFygxMVHSF3+hSUlJUe4GAABciJaWFnm9Xvvn+J9DILpAZ74mS0pKIhABANDHfNXlLlxUDQAAjEcgAgAAxiMQAQAA4xGIAACA8QhEAADAeAQiAABgPAIRAAAwHoEIAAAYj0AEAACMRyACAADGIxABAADjEYgAAIDxCEQAAMB4BCIAAGA8AhEAADBeXLQbQKT9Y26NdgtAzBmzvzraLQC4xLFDBAAAjEcgAgAAxiMQAQAA4xGIAACA8QhEAADAeFENRJ9//rn+3//7f0pLS1NCQoKGDRumRx99VKdPn7ZrwuGwVq5cKY/Ho4SEBOXk5OjgwYMR87S3t2vhwoUaMmSIBg4cqLy8PB09ejSiJhgMyufzybIsWZYln8+n5ubm3lgmAACIcVENRE8++aR+/vOfq7S0VPX19Vq9erWeeuopPfvss3bN6tWrtWbNGpWWlqqmpkZut1uTJk3SiRMn7JqioiKVl5drx44d2r17t1pbW5Wbm6vOzk67Jj8/X3V1dfL7/fL7/aqrq5PP5+vV9QIAgNjkCIfD4Wh9eG5urlwulzZt2mSPfec739GAAQO0ZcsWhcNheTweFRUVadmyZZK+2A1yuVx68sknNX/+fIVCIaWkpGjLli2aOXOmJOnjjz+W1+vVyy+/rClTpqi+vl4jR45UVVWVsrKyJElVVVXKzs7Wu+++q+HDh39lry0tLbIsS6FQSElJST3wt/EFnkMEdMVziAB014X+/I7qDtE3vvENvfrqq3rvvfckSb///e+1e/duffOb35QkNTQ0KBAIaPLkyfZ7nE6nxo8frz179kiSamtrderUqYgaj8ejUaNG2TV79+6VZVl2GJKksWPHyrIsu+Zs7e3tamlpiTgAAMClKapPql62bJlCoZCuv/569evXT52dnXriiSf0ve99T5IUCAQkSS6XK+J9LpdLH330kV0THx+vQYMGdak58/5AIKDU1NQun5+ammrXnK2kpESPPPLI11sgAADoE6K6Q/TCCy9o69at2r59u9544w1t3rxZTz/9tDZv3hxR53A4Il6Hw+EuY2c7u+Zc9eebZ/ny5QqFQvZx5MiRC10WAADoY6K6Q/SjH/1IP/7xj/Xd735XkpSRkaGPPvpIJSUlmjVrltxut6QvdniGDh1qv6+pqcneNXK73ero6FAwGIzYJWpqatK4cePsmmPHjnX5/OPHj3fZfTrD6XTK6XRenIUCAICYFtUdos8++0yXXRbZQr9+/ezb7tPS0uR2u1VRUWGf7+joUGVlpR12MjMz1b9//4iaxsZGHThwwK7Jzs5WKBRSdfX/XZi5b98+hUIhuwYAAJgrqjtEd911l5544gldffXVuuGGG/Tmm29qzZo1mj17tqQvvuYqKipScXGx0tPTlZ6eruLiYg0YMED5+fmSJMuyNGfOHC1evFiDBw9WcnKylixZooyMDE2cOFGSNGLECE2dOlVz587Vhg0bJEnz5s1Tbm7uBd1hBgAALm1RDUTPPvusHnroIRUUFKipqUkej0fz58/XT37yE7tm6dKlamtrU0FBgYLBoLKysrRz504lJibaNWvXrlVcXJxmzJihtrY2TZgwQWVlZerXr59ds23bNhUWFtp3o+Xl5am0tLT3FgsAAGJWVJ9D1JfwHCIgengOEYDu6hPPIQIAAIgFBCIAAGA8AhEAADAegQgAABiPQAQAAIxHIAIAAMYjEAEAAOMRiAAAgPEIRAAAwHgEIgAAYDwCEQAAMB6BCAAAGI9ABAAAjEcgAgAAxiMQAQAA4xGIAACA8QhEAADAeAQiAABgPAIRAAAwHoEIAAAYj0AEAACMRyACAADGIxABAADjEYgAAIDxCEQAAMB4BCIAAGA8AhEAADAegQgAABiPQAQAAIxHIAIAAMaLi3YDAGCKqQ+9EO0WgJjjf2xmtFuQxA4RAABAdAPRtddeK4fD0eVYsGCBJCkcDmvlypXyeDxKSEhQTk6ODh48GDFHe3u7Fi5cqCFDhmjgwIHKy8vT0aNHI2qCwaB8Pp8sy5JlWfL5fGpubu6tZQIAgBgX1UBUU1OjxsZG+6ioqJAk3XPPPZKk1atXa82aNSotLVVNTY3cbrcmTZqkEydO2HMUFRWpvLxcO3bs0O7du9Xa2qrc3Fx1dnbaNfn5+aqrq5Pf75ff71ddXZ18Pl/vLhYAAMSsqF5DlJKSEvF61apV+uu//muNHz9e4XBY69at04oVKzR9+nRJ0ubNm+VyubR9+3bNnz9foVBImzZt0pYtWzRx4kRJ0tatW+X1erVr1y5NmTJF9fX18vv9qqqqUlZWliRp48aNys7O1qFDhzR8+PDeXTQAAIg5MXMNUUdHh7Zu3arZs2fL4XCooaFBgUBAkydPtmucTqfGjx+vPXv2SJJqa2t16tSpiBqPx6NRo0bZNXv37pVlWXYYkqSxY8fKsiy75lza29vV0tIScQAAgEtTzASiF198Uc3NzbrvvvskSYFAQJLkcrki6lwul30uEAgoPj5egwYNOm9Nampql89LTU21a86lpKTEvubIsix5vd5urw0AAMS2mAlEmzZt0rRp0+TxeCLGHQ5HxOtwONxl7Gxn15yr/qvmWb58uUKhkH0cOXLkQpYBAAD6oJgIRB999JF27dql+++/3x5zu92S1GUXp6mpyd41crvd6ujoUDAYPG/NsWPHunzm8ePHu+w+fZnT6VRSUlLEAQAALk0xEYief/55paam6lvf+pY9lpaWJrfbbd95Jn1xnVFlZaXGjRsnScrMzFT//v0jahobG3XgwAG7Jjs7W6FQSNXV1XbNvn37FAqF7BoAAGC2qD+p+vTp03r++ec1a9YsxcX9XzsOh0NFRUUqLi5Wenq60tPTVVxcrAEDBig/P1+SZFmW5syZo8WLF2vw4MFKTk7WkiVLlJGRYd91NmLECE2dOlVz587Vhg0bJEnz5s1Tbm4ud5gBAABJMRCIdu3apcOHD2v27Nldzi1dulRtbW0qKChQMBhUVlaWdu7cqcTERLtm7dq1iouL04wZM9TW1qYJEyaor
KxM/fr1s2u2bdumwsJC+260vLw8lZaW9vziAABAn+AIh8PhaDfRF7S0tMiyLIVCoR69nmj/mFt7bG6grxqzv/qri/oAfpcZ0FVP/y6zC/35HRPXEAEAAEQTgQgAABiPQAQAAIxHIAIAAMYjEAEAAOMRiAAAgPEIRAAAwHgEIgAAYDwCEQAAMB6BCAAAGI9ABAAAjEcgAgAAxiMQAQAA4xGIAACA8QhEAADAeAQiAABgPAIRAAAwHoEIAAAYj0AEAACMRyACAADGIxABAADjEYgAAIDxCEQAAMB4BCIAAGA8AhEAADAegQgAABiPQAQAAIxHIAIAAMYjEAEAAOMRiAAAgPEIRAAAwHgEIgAAYLyoB6L//d//1d///d9r8ODBGjBggG666SbV1tba58PhsFauXCmPx6OEhATl5OTo4MGDEXO0t7dr4cKFGjJkiAYOHKi8vDwdPXo0oiYYDMrn88myLFmWJZ/Pp+bm5t5YIgAAiHFRDUTBYFC33Xab+vfvr9/85jd655139NOf/lRXXnmlXbN69WqtWbNGpaWlqqmpkdvt1qRJk3TixAm7pqioSOXl5dqxY4d2796t1tZW5ebmqrOz067Jz89XXV2d/H6//H6/6urq5PP5enO5AAAgRsVF88OffPJJeb1ePf/88/bYtddea/85HA5r3bp1WrFihaZPny5J2rx5s1wul7Zv36758+crFApp06ZN2rJliyZOnChJ2rp1q7xer3bt2qUpU6aovr5efr9fVVVVysrKkiRt3LhR2dnZOnTokIYPH96lt/b2drW3t9uvW1paeuKvAAAAxICo7hC99NJLGjNmjO655x6lpqbq5ptv1saNG+3zDQ0NCgQCmjx5sj3mdDo1fvx47dmzR5JUW1urU6dORdR4PB6NGjXKrtm7d68sy7LDkCSNHTtWlmXZNWcrKSmxv16zLEter/eirh0AAMSOqAaiDz/8UOvXr1d6erpeeeUVPfDAAyosLNQvf/lLSVIgEJAkuVyuiPe5XC77XCAQUHx8vAYNGnTemtTU1C6fn5qaatecbfny5QqFQvZx5MiRr7dYAAAQs6L6ldnp06c1ZswYFRcXS5JuvvlmHTx4UOvXr9c//MM/2HUOhyPifeFwuMvY2c6uOVf9+eZxOp1yOp0XvBYAANB3RXWHaOjQoRo5cmTE2IgRI3T48GFJktvtlqQuuzhNTU32rpHb7VZHR4eCweB5a44dO9bl848fP95l9wkAAJgnqoHotttu06FDhyLG3nvvPV1zzTWSpLS0NLndblVUVNjnOzo6VFlZqXHjxkmSMjMz1b9//4iaxsZGHThwwK7Jzs5WKBRSdXW1XbNv3z6FQiG7BgAAmCuqX5n98Ic/1Lhx41RcXKwZM2aourpazz33nJ577jlJX3zNVVRUpOLiYqWnpys9PV3FxcUaMGCA8vPzJUmWZWnOnDlavHixBg8erOTkZC1ZskQZGRn2XWcjRozQ1KlTNXfuXG3YsEGSNG/ePOXm5p7zDjMAAGCWqAaiW265ReXl5Vq+fLkeffRRpaWlad26dbr33nvtmqVLl6qtrU0FBQUKBoPKysrSzp07lZiYaNesXbtWcXFxmjFjhtra2jRhwgSVlZWpX79+ds22bdtUWFho342Wl5en0tLS3lssAACIWY5wOByOdhN9QUtLiyzLUigUUlJSUo99zv4xt/bY3EBfNWZ/9VcX9QFTH3oh2i0AMcf/2Mwenf9Cf35H/Vd3AAAARBuBCAAAGI9ABAAAjEcgAgAAxiMQAQAA4xGIAACA8QhEAADAeAQiAABgPAIRAAAwHoEIAAAYj0AEAACMRyACAADGIxABAADjEYgAAIDxCEQAAMB4BCIAAGA8AhEAADAegQgAABiPQAQAAIxHIAIAAMYjEAEAAOMRiAAAgPEIRAAAwHgEIgAAYDwCEQAAMB6BCAAAGI9ABAAAjEcgAgAAxiMQAQAA4xGIAACA8QhEAADAeFENRCtXrpTD4Yg43G63fT4cDmvlypXyeDxKSEhQTk6ODh48GDFHe3u7Fi5cqCFDhmjgwIHKy8vT0aNHI2qCwaB8Pp8sy5JlWfL5fGpubu6NJQIAgD4g6jtEN9xwgxobG+3j7bffts+tXr1aa9asUWlpqWpqauR2uzVp0iSdOHHCrikqKlJ5ebl27Nih3bt3q7W1Vbm5uers7LRr8vPzVVdXJ7/fL7/fr7q6Ovl8vl5dJwAAiF1xUW8gLi5iV+iMcDisdevWacWKFZo+fbokafPmzXK5XNq+fbvmz5+vUCikTZs2acuWLZo4caIkaevWrfJ6vdq1a5emTJmi+vp6+f1+VVVVKSsrS5K0ceNGZWdn69ChQxo+fHjvLRYAAMSkqO8Qvf/++/J4PEpLS9N3v/tdffjhh5KkhoYGBQIBTZ482a51Op0aP3689uzZI0mqra3VqVOnImo8Ho9GjRpl1+zdu1eWZdlhSJLGjh0ry7LsmnNpb29XS0tLxAEAAC5NUQ1EWVlZ+uUvf6lXXnlFGzduVCAQ0Lhx4/SHP/xBgUBAkuRyuSLe43K57HOBQEDx8fEaNGjQeWtSU1O7fHZqaqpdcy4lJSX2NUeWZcnr9X6ttQIAgNgV1UA0bdo0fec731FGRoYmTpyoX//615K++GrsDIfDEfGecDjcZexsZ9ecq/6r5lm+fLlCoZB9HDly5ILWBAAA+p6of2X2ZQMHDlRGRobef/99+7qis3dxmpqa7F0jt9utjo4OBYPB89YcO3asy2cdP368y+7TlzmdTiUlJUUcAADg0hRTgai9vV319fUaOnSo0tLS5Ha7VVFRYZ/v6OhQZWWlxo0bJ0nKzMxU//79I2oaGxt14MABuyY7O1uhUEjV1dV2zb59+xQKhewaAABgtqjeZbZkyRLddddduvrqq9XU1KTHH39cLS0tmjVrlhwOh4qKilRcXKz09HSlp6eruLhYAwYMUH5+viTJsizNmTNHixcv1uDBg5WcnKwlS5bYX8FJ0ogRIzR16lTNnTtXGzZskCTNmzdPubm53GEGAAAkRTkQHT16VN/73vf0ySefKCUlRWPHjlVVVZWuueYaSdLSpUvV1tamgoICBYNBZWVlaefOnUpMTLTnWLt2reLi4jRjxgy1tbVpwoQJKisrU79+/eyabdu2qbCw0L4bLS8vT6Wlpb27WAAAELMc4XA4HO0m+oKWlhZZlqVQKNSj1xPtH3Nrj80N9FVj9ld/dVEfMPWhF6LdAhBz/I/N7NH5L/Tnd0xdQwQAABANBCIAAGA8AhEAADAegQgAABiPQAQAAIzXrUB05513qrm5uct4S0uL7rzzzq/bEwAAQK/qViB67bXX1NHR0WX8j3/8o373u9997aYAAAB601/0YMa33nrL/vM777wT8XvGOjs75ff79Vd/9VcXrzsAAIBe8BcFoptuukkOh0MOh+OcX40lJCTo2WefvWjNAQAA9Ia/KBA1NDQoHA5r2LBhqq6uVkpKin0uPj5eqampEb8yAwAAoC/4iwLRmd8xdvr06R5pBgAAIBq6/ctd33vvPb322mtqamrq
EpB+8pOffO3GAAAAeku3AtHGjRv1/e9/X0OGDJHb7ZbD4bDPORwOAhEAAOhTuhWIHn/8cT3xxBNatmzZxe4HAACg13XrOUTBYFD33HPPxe4FAAAgKroViO655x7t3LnzYvcCAAAQFd36yuy6667TQw89pKqqKmVkZKh///4R5wsLCy9KcwAAAL2hW4Houeee0xVXXKHKykpVVlZGnHM4HAQiAADQp3QrEDU0NFzsPgAAAKKmW9cQAQAAXEq6tUM0e/bs857/xS9+0a1mAAAAoqFbgSgYDEa8PnXqlA4cOKDm5uZz/tJXAACAWNatQFReXt5l7PTp0yooKNCwYcO+dlMAAAC96aJdQ3TZZZfphz/8odauXXuxpgQAAOgVF/Wi6v/+7//W559/fjGnBAAA6HHd+sps0aJFEa/D4bAaGxv161//WrNmzboojQEAAPSWbgWiN998M+L1ZZddppSUFP30pz/9yjvQAAAAYk23AtFvf/vbi90HAABA1HQrEJ1x/PhxHTp0SA6HQ3/zN3+jlJSUi9UXAABAr+nWRdUnT57U7NmzNXToUN1xxx26/fbb5fF4NGfOHH322WcXu0cAAIAe1a1AtGjRIlVWVuo//uM/1NzcrObmZv3qV79SZWWlFi9efLF7BAAA6FHd+srs3/7t3/Sv//qvysnJsce++c1vKiEhQTNmzND69esvVn8AAAA9rls7RJ999plcLleX8dTU1G5/ZVZSUiKHw6GioiJ7LBwOa+XKlfJ4PEpISFBOTo4OHjwY8b729nYtXLhQQ4YM0cCBA5WXl6ejR49G1ASDQfl8PlmWJcuy5PP51Nzc3K0+AQDApadbgSg7O1sPP/yw/vjHP9pjbW1teuSRR5Sdnf0Xz1dTU6PnnntOo0ePjhhfvXq11qxZo9LSUtXU1MjtdmvSpEk6ceKEXVNUVKTy8nLt2LFDu3fvVmtrq3Jzc9XZ2WnX5Ofnq66uTn6/X36/X3V1dfL5fN1YOQAAuBR16yuzdevWadq0abrqqqt04403yuFwqK6uTk6nUzt37vyL5mptbdW9996rjRs36vHHH7fHw+Gw1q1bpxUrVmj69OmSpM2bN8vlcmn79u2aP3++QqGQNm3apC1btmjixImSpK1bt8rr9WrXrl2aMmWK6uvr5ff7VVVVpaysLEnSxo0blZ2drUOHDmn48OHd+SsAAACXkG7tEGVkZOj9999XSUmJbrrpJo0ePVqrVq3SBx98oBtuuOEvmmvBggX61re+ZQeaMxoaGhQIBDR58mR7zOl0avz48dqzZ48kqba2VqdOnYqo8Xg8GjVqlF2zd+9eWZZlhyFJGjt2rCzLsmvOpb29XS0tLREHAAC4NHVrh6ikpEQul0tz586NGP/FL36h48ePa9myZRc0z44dO/TGG2+opqamy7lAICBJXa5Vcrlc+uijj+ya+Ph4DRo0qEvNmfcHAgGlpqZ2mT81NdWuOZeSkhI98sgjF7QOAADQt3Vrh2jDhg26/vrru4zfcMMN+vnPf35Bcxw5ckQ/+MEPtHXrVl1++eV/ts7hcES8DofDXcbOdnbNueq/ap7ly5crFArZx5EjR877mQAAoO/qViAKBAIaOnRol/GUlBQ1NjZe0By1tbVqampSZmam4uLiFBcXp8rKSj3zzDOKi4uzd4bO3sVpamqyz7ndbnV0dCgYDJ635tixY10+//jx4+e8U+4Mp9OppKSkiAMAAFyauhWIvF6vXn/99S7jr7/+ujwezwXNMWHCBL399tuqq6uzjzFjxujee+9VXV2dhg0bJrfbrYqKCvs9HR0dqqys1Lhx4yRJmZmZ6t+/f0RNY2OjDhw4YNdkZ2crFAqpurrartm3b59CoZBdAwAAzNata4juv/9+FRUV6dSpU7rzzjslSa+++qqWLl16wU+qTkxM1KhRoyLGBg4cqMGDB9vjRUVFKi4uVnp6utLT01VcXKwBAwYoPz9fkmRZlubMmaPFixdr8ODBSk5O1pIlS5SRkWFfpD1ixAhNnTpVc+fO1YYNGyRJ8+bNU25uLneYAQAASd0MREuXLtWnn36qgoICdXR0SJIuv/xyLVu2TMuXL79ozS1dulRtbW0qKChQMBhUVlaWdu7cqcTERLtm7dq1iouL04wZM9TW1qYJEyaorKxM/fr1s2u2bdumwsJC+260vLw8lZaWXrQ+AQBA3+YIh8Ph7r65tbVV9fX1SkhIUHp6upxO58XsLaa0tLTIsiyFQqEevZ5o/5hbe2xuoK8as7/6q4v6gKkPvRDtFoCY439sZo/Of6E/v7u1Q3TGFVdcoVtuueXrTAEAABB13bqoGgAA4FJCIAIAAMYjEAEAAOMRiAAAgPEIRAAAwHgEIgAAYDwCEQAAMB6BCAAAGI9ABAAAjEcgAgAAxiMQAQAA4xGIAACA8QhEAADAeAQiAABgPAIRAAAwHoEIAAAYj0AEAACMRyACAADGIxABAADjEYgAAIDxCEQAAMB4BCIAAGA8AhEAADAegQgAABiPQAQAAIxHIAIAAMYjEAEAAOMRiAAAgPEIRAAAwHgEIgAAYDwCEQAAMF5UA9H69es1evRoJSUlKSkpSdnZ2frNb35jnw+Hw1q5cqU8Ho8SEhKUk5OjgwcPRszR3t6uhQsXasiQIRo4cKDy8vJ09OjRiJpgMCifzyfLsmRZlnw+n5qbm3tjiQAAoA+IaiC66qqrtGrVKu3fv1/79+/XnXfeqW9/+9t26Fm9erXWrFmj0tJS1dTUyO12a9KkSTpx4oQ9R1FRkcrLy7Vjxw7t3r1bra2tys3NVWdnp12Tn5+vuro6+f1++f1+1dXVyefz9fp6AQBAbHKEw+FwtJv4suTkZD311FOaPXu2PB6PioqKtGzZMklf7Aa5XC49+eSTmj9/vkKhkFJSUrRlyxbNnDlTkvTxxx/L6/Xq5Zdf1pQpU1RfX6+RI0eqqqpKWVlZkqSqqiplZ2fr3Xff1fDhwy+or5aWFlmWpVAopKSkpJ5ZvKT9Y27tsbmBvmrM/upot3BRTH3ohWi3AMQc/2Mze3T+C/35HTPXEHV2dmrHjh06efKksrOz1dDQoEAgoMmTJ9s1TqdT48eP1549eyRJtbW1OnXqVESNx+PRqFGj7Jq9e/fKsiw7DEnS2LFjZVmWXXMu7e3tamlpiTgAAMClKeqB6O2339YVV1whp9OpBx54QOXl5Ro5cqQCgYAkyeVyRdS7XC77XCAQUHx8vAYNGnTemtTU1C6fm5qaatecS0lJiX3NkWVZ8nq9X2udAAAgdkU9EA0fPlx1dXWqqqrS97//fc2aNUvvvPOOfd7hcETUh8PhLmNnO7vmXPVfNc/y5csVCoXs48iRIxe6JAAA0MdEPRDFx8fruuuu05gxY1RSUqIbb7xR//RP/yS32y1JXXZxmpqa7F0jt9utjo4OBYPB89YcO3asy+ceP368y+7TlzmdTvvutzMHAAC4NEU9EJ0tHA6rvb1daWlpcrvdqqiosM91dHSosrJS48aNkyRlZmaqf//+ETWNjY06cOCAXZOdna1QKKTq6v+7KHPfvn0
[base64 PNG output truncated: countplot of the train.csv label distribution]
<Figure size 640x480 with 1 Axes>

df = describe_dataset("./train.csv")
sns.countplot(x='label', data=df, palette="Set1")

### 4. Gather Test Dataset

TEST_DATASET_PATH = "test.csv"

cap = cv2.VideoCapture("../data/db_curl/bc_test_2.mp4")
save_counts = 0

init_csv(TEST_DATASET_PATH)

with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:
    while cap.isOpened():
        ret, image = cap.read()

        if not ret:
            break

        # Reduce the size of the frame
        image = rescale_frame(image, 60)
        image = cv2.flip(image, 1)

        # Recolor image from BGR to RGB for mediapipe
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image.flags.writeable = False

        results = pose.process(image)

        if not results.pose_landmarks:
            print("Cannot detect pose - No human found")
            continue

        # Recolor image back from RGB to BGR for OpenCV rendering
        image.flags.writeable = True
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        # Draw landmarks and connections
        mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=4), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2))

        # Display the saved count
        cv2.putText(image, f"Saved: {save_counts}", (50, 50), cv2.FONT_HERSHEY_COMPLEX, 2, (0, 0, 0), 2, cv2.LINE_AA)

        cv2.imshow("CV2", image)

        # Pressed key for action
        k = cv2.waitKey(10) & 0xFF

        # Press C to save as correct form
        if k == ord('c'):
            export_landmark_to_csv(TEST_DATASET_PATH, results, "C")
            save_counts += 1
        # Press L to save as low back
        elif k == ord("l"):
            export_landmark_to_csv(TEST_DATASET_PATH, results, "L")
            save_counts += 1
        # Press q to stop
        elif k == ord("q"):
            break
        else:
            continue

    cap.release()
    cv2.destroyAllWindows()

    # (Optional) Fix bug where windows cannot be closed on macOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)
    for i in range(1, 5):
        cv2.waitKey(1)
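The labeling loop above leans on two helpers defined earlier in this notebook, init_csv and export_landmark_to_csv. As a reading aid only, here is a minimal sketch of what such an export helper plausibly looks like; the body below is an assumption for illustration (the notebook's real definition is outside this hunk), with only the row layout, one label plus x/y/z/visibility per tracked landmark, taken from the printed headers:

import csv

def export_landmark_to_csv(dataset_path: str, results, action: str) -> None:
    '''Illustrative sketch: flatten detected pose landmarks into one labeled CSV row.'''
    landmarks = results.pose_landmarks.landmark
    keypoints = [action]  # label column first: "C" or "L"

    # IMPORTANT_LMS and mp_pose are assumed from earlier cells of the notebook
    for lm in IMPORTANT_LMS:
        point = landmarks[mp_pose.PoseLandmark[lm].value]
        keypoints += [point.x, point.y, point.z, point.visibility]

    with open(dataset_path, mode="a", newline="") as f:
        csv.writer(f).writerow(keypoints)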
test_df = describe_dataset(TEST_DATASET_PATH)
sns.countplot(y='label', data=test_df, palette="Set1")

Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'right_elbow_x', 'right_elbow_y', 'right_elbow_z', 'right_elbow_v', 'left_elbow_x', 'left_elbow_y', 'left_elbow_z', 'left_elbow_v', 'right_wrist_x', 'right_wrist_y', 'right_wrist_z', 'right_wrist_v', 'left_wrist_x', 'left_wrist_y', 'left_wrist_z', 'left_wrist_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v']
Number of rows: 604
Number of columns: 37

Labels:
C    339
L    265
Name: label, dtype: int64

Missing values: False

Duplicate Rows : 0

<AxesSubplot:xlabel='count', ylabel='label'>
[base64 PNG output omitted: countplot of the test-set label distribution]

[notebook metadata omitted: Python 3.8.13 (conda) kernel, nbformat 4]
core/bicep_model/2.sklearn.ipynb
ADDED
@@ -0,0 +1,413 @@
import mediapipe as mp
import cv2
import pandas as pd
import pickle

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import precision_score, accuracy_score, f1_score, recall_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.calibration import CalibratedClassifierCV

import warnings
warnings.filterwarnings('ignore')

# Drawing helpers
mp_drawing = mp.solutions.drawing_utils
mp_pose = mp.solutions.pose

[stderr omitted: four objc warnings that CaptureDelegate, CVWindow, CVView and CVSlider are each implemented in both the mediapipe and cv2 OpenCV dylibs; "One of the two will be used. Which one is undefined."]
### 1. Train model

#### 1.1. Describe data and split dataset

def rescale_frame(frame, percent=50):
    '''
    Rescale a frame to a certain percentage of its original size
    '''
    width = int(frame.shape[1] * percent / 100)
    height = int(frame.shape[0] * percent / 100)
    dim = (width, height)
    return cv2.resize(frame, dim, interpolation=cv2.INTER_AREA)


def describe_dataset(dataset_path: str):
    '''
    Describe the dataset
    '''
    data = pd.read_csv(dataset_path)
    print(f"Headers: {list(data.columns.values)}")
    print(f'Number of rows: {data.shape[0]} \nNumber of columns: {data.shape[1]}\n')
    print(f"Labels: \n{data['label'].value_counts()}\n")
    print(f"Missing values: {data.isnull().values.any()}\n")

    duplicate = data[data.duplicated()]
    print(f"Duplicate Rows : {len(duplicate)}")

    return data


def round_up_metric_results(results) -> list:
    '''Round metric results such as precision score, recall score, ... to 3 decimals'''
    return list(map(lambda el: round(el, 3), results))
# load dataset
df = describe_dataset("./train.csv")

# Categorizing label
df.loc[df["label"] == "C", "label"] = 0
df.loc[df["label"] == "L", "label"] = 1

Headers: ['label', 'nose_x', ..., 'right_hip_v'] (the same 37 columns listed above)
Number of rows: 15372
Number of columns: 37

Labels:
C    8238
L    7134
Name: label, dtype: int64

Missing values: False

Duplicate Rows : 0
sc = StandardScaler()

# Load the previously fitted scaler; this replaces the fresh (unfitted) StandardScaler
# above, and transform() below requires a fitted scaler
with open("./model/input_scaler.pkl", "rb") as f:
    sc = pickle.load(f)
# Standard scaling of features
x = df.drop("label", axis=1)
x = pd.DataFrame(sc.transform(x))

y = df["label"].astype('int')

X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=1234)
y_train

9465     1
8833     0
6190     0
7645     0
13890    1
         ..
11468    1
7221     1
1318     1
8915     1
11055    1
Name: label, Length: 12297, dtype: int64
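A quick consistency check on the split (illustrative, not from the notebook): scikit-learn rounds the test fraction up, so test_size=0.2 on the 15,372-row dataset yields ceil(15372 * 0.2) = 3075 test rows, which leaves exactly the 12,297 training rows reported by the output above.

# Sanity check of the 80/20 split sizes (assumes the X_train/X_test from above)
assert len(X_train) == 15372 - 3075 == 12297
assert len(X_test) == 3075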
#### 1.2. Train model using Scikit-learn

algorithms = [("LR", LogisticRegression()),
              ("SVC", SVC(probability=True)),
              ("KNN", KNeighborsClassifier()),
              ("DTC", DecisionTreeClassifier()),
              ("SGDC", CalibratedClassifierCV(SGDClassifier())),
              ("NB", GaussianNB()),
              ("RF", RandomForestClassifier()),]

models = {}
final_results = []

for name, model in algorithms:
    trained_model = model.fit(X_train, y_train)
    models[name] = trained_model

    # Evaluate model
    model_results = model.predict(X_test)

    p_score = precision_score(y_test, model_results, average=None, labels=[0, 1])
    a_score = accuracy_score(y_test, model_results)
    r_score = recall_score(y_test, model_results, average=None, labels=[0, 1])
    f1_score_result = f1_score(y_test, model_results, average=None, labels=[0, 1])
    cm = confusion_matrix(y_test, model_results, labels=[0, 1])
    final_results.append((name, round_up_metric_results(p_score), a_score, round_up_metric_results(r_score), round_up_metric_results(f1_score_result), cm))

# Sort results by summed per-class F1 score
final_results.sort(key=lambda k: sum(k[4]), reverse=True)
pd.DataFrame(final_results, columns=["Model", "Precision Score", "Accuracy score", "Recall Score", "F1 score", "Confusion Matrix"])

  Model  Precision Score  Accuracy score    Recall Score        F1 score            Confusion Matrix
0    RF   [0.999, 0.999]        0.999024  [0.999, 0.999]  [0.999, 0.999]      [[1677, 2], [1, 1395]]
1   KNN   [0.997, 0.999]        0.998049  [0.999, 0.996]  [0.998, 0.998]      [[1678, 1], [5, 1391]]
2   SVC   [0.997, 0.995]        0.996098  [0.996, 0.996]  [0.996, 0.996]      [[1672, 7], [5, 1391]]
3   DTC   [0.997, 0.991]        0.994146  [0.992, 0.996]  [0.995, 0.994]     [[1666, 13], [5, 1391]]
4  SGDC   [0.987, 0.974]        0.981463  [0.979, 0.985]   [0.983, 0.98]    [[1643, 36], [21, 1375]]
5    LR   [0.986, 0.975]        0.980813  [0.979, 0.983]  [0.982, 0.979]    [[1644, 35], [24, 1372]]
6    NB   [0.927, 0.842]        0.884878  [0.857, 0.918]   [0.89, 0.879]  [[1439, 240], [114, 1282]]
#### 1.3. Dump models pickle

with open("./model/all_sklearn.pkl", "wb") as f:
    pickle.dump(models, f)

with open("./model/input_scaler.pkl", "wb") as f:
    pickle.dump(sc, f)
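As a usage note, not part of the notebook: once the two pickles above exist, scoring a new sample is a load-scale-predict pipeline. The zero-filled row below is a placeholder for a real 36-value landmark vector, and indexing the dict with "RF" simply picks the best-scoring model from the table above:

import pickle
import pandas as pd

with open("./model/all_sklearn.pkl", "rb") as f:
    models = pickle.load(f)    # dict: model name -> fitted estimator
with open("./model/input_scaler.pkl", "rb") as f:
    scaler = pickle.load(f)    # StandardScaler fitted on the training features

# Placeholder features; a real row holds x/y/z/visibility for the 9 landmarks
row = pd.DataFrame([[0.0] * 36])
scaled = pd.DataFrame(scaler.transform(row))

prediction = models["RF"].predict(scaled)[0]              # 0 = "C", 1 = "L"
confidence = models["RF"].predict_proba(scaled)[0].max()
print(prediction, confidence)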
[notebook metadata omitted: Python 3.8.13 (conda) kernel, nbformat 4]
core/bicep_model/3.deep_learning.ipynb
ADDED
@@ -0,0 +1,1312 @@
import numpy as np
import pandas as pd

# Keras
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.optimizers import Adam
from keras.utils.np_utils import to_categorical
from keras.callbacks import EarlyStopping, TensorBoard
import keras_tuner as kt

# Train-test split
from sklearn.model_selection import train_test_split
# Classification report
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

import pickle

import warnings
warnings.filterwarnings('ignore')
## 1. Important Landmarks and Important functions

# Determine important landmarks for the bicep curl
IMPORTANT_LMS = [
    "NOSE",
    "LEFT_SHOULDER",
    "RIGHT_SHOULDER",
    "RIGHT_ELBOW",
    "LEFT_ELBOW",
    "RIGHT_WRIST",
    "LEFT_WRIST",
    "LEFT_HIP",
    "RIGHT_HIP",
]

# Generate all columns of the data frame
HEADERS = ["label"]  # Label column

for lm in IMPORTANT_LMS:
    HEADERS += [f"{lm.lower()}_x", f"{lm.lower()}_y", f"{lm.lower()}_z", f"{lm.lower()}_v"]
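A quick sanity check on the loop above (illustrative): one label column plus four values (x, y, z, visibility) for each of the nine tracked landmarks gives 1 + 9 * 4 = 37 columns, matching the 37-column datasets described below.

assert len(HEADERS) == 1 + len(IMPORTANT_LMS) * 4 == 37
print(HEADERS[:5])  # ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v']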
def describe_dataset(dataset_path: str):
    '''
    Describe the dataset
    '''
    data = pd.read_csv(dataset_path)
    print(f"Headers: {list(data.columns.values)}")
    print(f'Number of rows: {data.shape[0]} \nNumber of columns: {data.shape[1]}\n')
    print(f"Labels: \n{data['label'].value_counts()}\n")
    print(f"Missing values: {data.isnull().values.any()}\n")

    duplicate = data[data.duplicated()]
    print(f"Duplicate Rows : {len(duplicate)}")

    return data


# Remove duplicate rows (optional)
def remove_duplicate_rows(dataset_path: str):
    '''
    Remove duplicated rows from the dataset, then save the result to another file
    '''
    df = pd.read_csv(dataset_path)
    df.drop_duplicates(keep="first", inplace=True)
    df.to_csv("cleaned_train.csv", sep=',', encoding='utf-8', index=False)


def round_up_metric_results(results) -> list:
    '''Round metric results such as precision score, recall score, ... to 3 decimals'''
    return list(map(lambda el: round(el, 3), results))
## 2. Describe Dataset & Split Data

# load dataset
df = describe_dataset("./train.csv")

# Categorizing label
df.loc[df["label"] == "C", "label"] = 0
df.loc[df["label"] == "L", "label"] = 1

Headers: ['label', 'nose_x', ..., 'right_hip_v'] (the same 37 columns listed above)
Number of rows: 15372
Number of columns: 37

Labels:
C    8238
L    7134
Name: label, dtype: int64

Missing values: False

Duplicate Rows : 0
with open("./model/input_scaler.pkl", "rb") as f:
    sc = pickle.load(f)

# Standard scaling of features
x = df.drop("label", axis=1)
x = pd.DataFrame(sc.transform(x))

y = df["label"]

# Converting the labels to categorical (one-hot)
y_cat = to_categorical(y)

x_train, x_test, y_train, y_test = train_test_split(x.values, y_cat, test_size=0.2, random_state=1234)
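For clarity (illustrative, not a notebook cell): to_categorical turns the integer labels into the one-hot rows that the two-unit softmax output layer expects.

from keras.utils.np_utils import to_categorical

print(to_categorical([0, 1, 1]))
# [[1. 0.]
#  [0. 1.]
#  [0. 1.]]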
## 3. Build Model

### 3.1. Set up

stop_early = EarlyStopping(monitor='val_loss', patience=3)

# Final results
final_models = {}
def describe_model(model):
    '''
    Describe the model architecture
    '''
    print("Describe models architecture")
    for i, layer in enumerate(model.layers):
        number_of_units = layer.units if hasattr(layer, 'units') else 0

        if hasattr(layer, "activation"):
            print(f"Layer-{i + 1}: {number_of_units} units, func: ", layer.activation)
        else:
            print(f"Layer-{i + 1}: {number_of_units} units, func: None")


def get_best_model(tuner):
    '''
    Describe and return the best model found by the keras tuner
    '''
    best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
    best_model = tuner.hypermodel.build(best_hps)

    describe_model(best_model)

    print("\nOther params:")
    ignore_params = ["tuner", "activation", "layer", "epoch"]
    for param, value in best_hps.values.items():
        if not any(word in param for word in ignore_params):
            print(f"{param}: {value}")

    return best_model
### 3.2. Model with 3 layers

def model_3l_builder(hp):
    model = Sequential()
    model.add(Dense(36, input_dim=36, activation="relu"))

    hp_activation = hp.Choice('activation', values=['relu', 'tanh'])
    hp_layer_1 = hp.Int('layer_1', min_value=32, max_value=512, step=32)
    hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

    model.add(Dense(units=hp_layer_1, activation=hp_activation))
    model.add(Dense(2, activation="softmax"))

    model.compile(optimizer=Adam(learning_rate=hp_learning_rate), loss="categorical_crossentropy", metrics=["accuracy"])

    return model
tuner_3l = kt.Hyperband(
    model_3l_builder,
    objective='val_accuracy',
    max_epochs=10,
    directory='keras_tuner_dir',
    project_name='keras_tuner_demo',
)
tuner_3l.search(x_train, y_train, validation_data=(x_test, y_test), epochs=10, callbacks=[stop_early])

Trial 30 Complete [00h 00m 46s]
val_accuracy: 0.9980487823486328

Best val_accuracy So Far: 0.9980487823486328
Total elapsed time: 00h 08m 25s
INFO:tensorflow:Oracle triggered exit
|
| 303 |
+
{
|
| 304 |
+
"cell_type": "code",
|
| 305 |
+
"execution_count": 19,
|
| 306 |
+
"metadata": {},
|
| 307 |
+
"outputs": [
|
| 308 |
+
{
|
| 309 |
+
"name": "stdout",
|
| 310 |
+
"output_type": "stream",
|
| 311 |
+
"text": [
|
| 312 |
+
"Describe models architecture\n",
|
| 313 |
+
"Layer-1: 36 units, func: <function relu at 0x16bf85b80>\n",
|
| 314 |
+
"Layer-2: 448 units, func: <function tanh at 0x16bf85ee0>\n",
|
| 315 |
+
"Layer-3: 2 units, func: <function softmax at 0x16bf85160>\n",
|
| 316 |
+
"\n",
|
| 317 |
+
"Other params:\n",
|
| 318 |
+
"learning_rate: 0.001\n",
|
| 319 |
+
"Epoch 1/100\n",
|
| 320 |
+
" 5/1230 [..............................] - ETA: 19s - loss: 0.6247 - accuracy: 0.6600 "
|
| 321 |
+
]
|
| 322 |
+
},
|
| 323 |
+
{
|
| 324 |
+
"name": "stderr",
|
| 325 |
+
"output_type": "stream",
|
| 326 |
+
"text": [
|
| 327 |
+
"2022-11-23 09:54:28.588878: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n"
|
| 328 |
+
]
|
| 329 |
+
},
|
| 330 |
+
{
|
| 331 |
+
"name": "stdout",
|
| 332 |
+
"output_type": "stream",
|
| 333 |
+
"text": [
|
| 334 |
+
"1230/1230 [==============================] - ETA: 0s - loss: 0.0504 - accuracy: 0.9848"
|
| 335 |
+
]
|
| 336 |
+
},
|
| 337 |
+
{
|
| 338 |
+
"name": "stderr",
|
| 339 |
+
"output_type": "stream",
|
| 340 |
+
"text": [
|
| 341 |
+
"2022-11-23 09:54:39.929268: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n"
|
| 342 |
+
]
|
| 343 |
+
},
|
| 344 |
+
{
|
| 345 |
+
"name": "stdout",
|
| 346 |
+
"output_type": "stream",
|
| 347 |
+
"text": [
|
| 348 |
+
"1230/1230 [==============================] - 14s 11ms/step - loss: 0.0504 - accuracy: 0.9848 - val_loss: 0.0889 - val_accuracy: 0.9717\n",
|
| 349 |
+
"Epoch 2/100\n",
|
| 350 |
+
"1230/1230 [==============================] - 13s 11ms/step - loss: 0.0241 - accuracy: 0.9940 - val_loss: 0.0188 - val_accuracy: 0.9948\n",
|
| 351 |
+
"Epoch 3/100\n",
|
| 352 |
+
"1230/1230 [==============================] - 13s 10ms/step - loss: 0.0187 - accuracy: 0.9946 - val_loss: 0.0127 - val_accuracy: 0.9964\n",
|
| 353 |
+
"Epoch 4/100\n",
|
| 354 |
+
"1230/1230 [==============================] - 13s 10ms/step - loss: 0.0179 - accuracy: 0.9950 - val_loss: 0.0140 - val_accuracy: 0.9958\n",
|
| 355 |
+
"Epoch 5/100\n",
|
| 356 |
+
"1230/1230 [==============================] - 13s 11ms/step - loss: 0.0145 - accuracy: 0.9958 - val_loss: 0.0211 - val_accuracy: 0.9958\n",
|
| 357 |
+
"Epoch 6/100\n",
|
| 358 |
+
"1230/1230 [==============================] - 13s 11ms/step - loss: 0.0143 - accuracy: 0.9958 - val_loss: 0.0093 - val_accuracy: 0.9984\n",
|
| 359 |
+
"Epoch 7/100\n",
|
| 360 |
+
"1230/1230 [==============================] - 13s 11ms/step - loss: 0.0118 - accuracy: 0.9966 - val_loss: 0.0077 - val_accuracy: 0.9984\n",
|
| 361 |
+
"Epoch 8/100\n",
|
| 362 |
+
"1230/1230 [==============================] - 13s 11ms/step - loss: 0.0120 - accuracy: 0.9961 - val_loss: 0.0112 - val_accuracy: 0.9977\n",
|
| 363 |
+
"Epoch 9/100\n",
|
| 364 |
+
"1230/1230 [==============================] - 13s 11ms/step - loss: 0.0121 - accuracy: 0.9959 - val_loss: 0.0073 - val_accuracy: 0.9984\n",
|
| 365 |
+
"Epoch 10/100\n",
|
| 366 |
+
"1230/1230 [==============================] - 13s 11ms/step - loss: 0.0110 - accuracy: 0.9963 - val_loss: 0.0108 - val_accuracy: 0.9971\n",
|
| 367 |
+
"Epoch 11/100\n",
|
| 368 |
+
"1230/1230 [==============================] - 13s 11ms/step - loss: 0.0097 - accuracy: 0.9970 - val_loss: 0.0110 - val_accuracy: 0.9971\n",
|
| 369 |
+
"Epoch 12/100\n",
|
| 370 |
+
"1230/1230 [==============================] - 13s 11ms/step - loss: 0.0098 - accuracy: 0.9972 - val_loss: 0.0107 - val_accuracy: 0.9967\n"
|
| 371 |
+
]
|
| 372 |
+
},
|
| 373 |
+
{
|
| 374 |
+
"data": {
|
| 375 |
+
"text/plain": [
|
| 376 |
+
"<keras.callbacks.History at 0x296950eb0>"
|
| 377 |
+
]
|
| 378 |
+
},
|
| 379 |
+
"execution_count": 19,
|
| 380 |
+
"metadata": {},
|
| 381 |
+
"output_type": "execute_result"
|
| 382 |
+
}
|
| 383 |
+
],
|
| 384 |
+
"source": [
|
| 385 |
+
"model_3l = get_best_model(tuner_3l)\n",
|
| 386 |
+
"model_3l.fit(x_train, y_train, epochs=100, batch_size=10, validation_data=(x_test, y_test), callbacks=[stop_early])"
|
| 387 |
+
]
|
| 388 |
+
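`get_best_model` is defined earlier in this notebook; as a reference, here is a minimal sketch of what such a helper typically does with a finished Keras Tuner search (an assumption, not the author's confirmed implementation):

```python
# Minimal sketch (assumption, not the notebook's verbatim helper): rebuild a
# fresh model from the tuner's best hyperparameters and print its layout.
def get_best_model(tuner):
    best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
    model = tuner.hypermodel.build(best_hps)
    describe_model(model)  # prints the "Describe models architecture" block shown above
    return model
```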
```python
final_models["3_layers"] = model_3l
```

### 3.3. Model with 5 layers
```python
def model_5l_builder(hp):
    model = Sequential()
    model.add(Dense(36, input_dim = 36, activation = "relu"))

    hp_activation = hp.Choice('activation', values=['relu', 'tanh'])
    hp_layer_1 = hp.Int('layer_1', min_value=32, max_value=512, step=32)
    hp_layer_2 = hp.Int('layer_2', min_value=32, max_value=512, step=32)
    hp_layer_3 = hp.Int('layer_3', min_value=32, max_value=512, step=32)
    hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

    model.add(Dense(units=hp_layer_1, activation=hp_activation))
    model.add(Dense(units=hp_layer_2, activation=hp_activation))
    model.add(Dense(units=hp_layer_3, activation=hp_activation))
    model.add(Dense(2, activation = "softmax"))

    model.compile(optimizer=Adam(learning_rate=hp_learning_rate), loss="categorical_crossentropy", metrics = ["accuracy"])

    return model
```
```python
tuner_5l = kt.Hyperband(
    model_5l_builder,
    objective='val_accuracy',
    max_epochs=10,
    directory='keras_tuner_dir',
    project_name='keras_tuner_demo_2'
)
tuner_5l.search(x_train, y_train, validation_data=(x_test, y_test), epochs=10, callbacks=[stop_early, TensorBoard("./keras_tuner_dir/logs")])
```

```
Trial 30 Complete [00h 00m 54s]
val_accuracy: 0.9973983764648438

Best val_accuracy So Far: 0.9986991882324219
Total elapsed time: 00h 11m 12s
INFO:tensorflow:Oracle triggered exit
```
```python
model_5l = get_best_model(tuner_5l)
model_5l.fit(x_train, y_train, epochs=100, batch_size=10, validation_data=(x_test, y_test), callbacks=[stop_early])
```

```
Describe models architecture
Layer-1: 36 units, func: <function relu at 0x16bf85b80>
Layer-2: 160 units, func: <function relu at 0x16bf85b80>
Layer-3: 352 units, func: <function relu at 0x16bf85b80>
Layer-4: 64 units, func: <function relu at 0x16bf85b80>
Layer-5: 2 units, func: <function softmax at 0x16bf85160>

Other params:
learning_rate: 0.001
Epoch 1/100
2022-11-23 10:15:07.538823: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-11-23 10:15:21.397335: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
1230/1230 [==============================] - 16s 13ms/step - loss: 0.0494 - accuracy: 0.9848 - val_loss: 0.0152 - val_accuracy: 0.9958
Epoch 2/100
1230/1230 [==============================] - 16s 13ms/step - loss: 0.0238 - accuracy: 0.9932 - val_loss: 0.0145 - val_accuracy: 0.9954
Epoch 3/100
1230/1230 [==============================] - 16s 13ms/step - loss: 0.0193 - accuracy: 0.9947 - val_loss: 0.0146 - val_accuracy: 0.9971
Epoch 4/100
1230/1230 [==============================] - 16s 13ms/step - loss: 0.0169 - accuracy: 0.9950 - val_loss: 0.0140 - val_accuracy: 0.9964
Epoch 5/100
1230/1230 [==============================] - 15s 13ms/step - loss: 0.0160 - accuracy: 0.9960 - val_loss: 0.0154 - val_accuracy: 0.9964
Epoch 6/100
1230/1230 [==============================] - 16s 13ms/step - loss: 0.0135 - accuracy: 0.9963 - val_loss: 0.0126 - val_accuracy: 0.9961
Epoch 7/100
1230/1230 [==============================] - 16s 13ms/step - loss: 0.0126 - accuracy: 0.9960 - val_loss: 0.0098 - val_accuracy: 0.9971
Epoch 8/100
1230/1230 [==============================] - 16s 13ms/step - loss: 0.0106 - accuracy: 0.9966 - val_loss: 0.0090 - val_accuracy: 0.9971
Epoch 9/100
1230/1230 [==============================] - 16s 13ms/step - loss: 0.0110 - accuracy: 0.9966 - val_loss: 0.0146 - val_accuracy: 0.9974
Epoch 10/100
1230/1230 [==============================] - 16s 13ms/step - loss: 0.0098 - accuracy: 0.9963 - val_loss: 0.0257 - val_accuracy: 0.9922
Epoch 11/100
1230/1230 [==============================] - 16s 13ms/step - loss: 0.0106 - accuracy: 0.9968 - val_loss: 0.0138 - val_accuracy: 0.9961
<keras.callbacks.History at 0x28a166fa0>
```
```python
final_models["5_layers"] = model_5l
```

### 3.4. Model with 7 layers including Dropout
```python
def model_7lD_builder(hp):
    model = Sequential()
    model.add(Dense(36, input_dim = 36, activation = "relu"))

    hp_activation = hp.Choice('activation', values=['relu', 'tanh'])
    hp_layer_1 = hp.Int('layer_1', min_value=32, max_value=512, step=32)
    hp_layer_2 = hp.Int('layer_2', min_value=32, max_value=512, step=32)
    hp_layer_3 = hp.Int('layer_3', min_value=32, max_value=512, step=32)
    hp_dropout_1 = hp.Float('dropout_1', min_value=0.1, max_value=0.5, step=0.1)
    hp_dropout_2 = hp.Float('dropout_2', min_value=0.1, max_value=0.5, step=0.1)
    hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

    model.add(Dense(units=hp_layer_1, activation=hp_activation))
    model.add(Dropout(rate=hp_dropout_1))
    model.add(Dense(units=hp_layer_2, activation=hp_activation))
    model.add(Dropout(rate=hp_dropout_2))
    model.add(Dense(units=hp_layer_3, activation=hp_activation))
    model.add(Dense(2, activation = "softmax"))

    model.compile(optimizer=Adam(learning_rate=hp_learning_rate), loss="categorical_crossentropy", metrics = ["accuracy"])

    return model
```
```python
tuner_7lD = kt.Hyperband(
    model_7lD_builder,
    objective='accuracy',
    max_epochs=10,
    directory='keras_tuner_dir',
    project_name='keras_tuner_demo_3'
)
tuner_7lD.search(x_train, y_train, validation_data=(x_test, y_test), epochs=10, callbacks=[stop_early, TensorBoard("./keras_tuner_dir/logs")])
```

```
Trial 30 Complete [00h 01m 04s]
accuracy: 0.9945515394210815

Best accuracy So Far: 0.9969098567962646
Total elapsed time: 00h 12m 19s
INFO:tensorflow:Oracle triggered exit
```

Note that this tuner's objective is training `accuracy`, unlike the other tuners in this notebook, which select on `val_accuracy`.
```python
model_7lD = get_best_model(tuner_7lD)
model_7lD.fit(x_train, y_train, epochs=100, batch_size=10, validation_data=(x_test, y_test), callbacks=[stop_early])
```

```
Describe models architecture
Layer-1: 36 units, func: <function relu at 0x16bf85b80>
Layer-2: 320 units, func: <function relu at 0x16bf85b80>
Layer-3: 0 units, func: None
Layer-4: 96 units, func: <function relu at 0x16bf85b80>
Layer-5: 0 units, func: None
Layer-6: 448 units, func: <function relu at 0x16bf85b80>
Layer-7: 2 units, func: <function softmax at 0x16bf85160>

Other params:
dropout_1: 0.30000000000000004
dropout_2: 0.30000000000000004
learning_rate: 0.001
Epoch 1/100
2022-11-23 10:37:14.947724: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-11-23 10:37:31.869492: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
1230/1230 [==============================] - 20s 16ms/step - loss: 0.0592 - accuracy: 0.9811 - val_loss: 0.0177 - val_accuracy: 0.9961
Epoch 2/100
1230/1230 [==============================] - 19s 16ms/step - loss: 0.0235 - accuracy: 0.9934 - val_loss: 0.0164 - val_accuracy: 0.9951
Epoch 3/100
1230/1230 [==============================] - 19s 15ms/step - loss: 0.0206 - accuracy: 0.9945 - val_loss: 0.0150 - val_accuracy: 0.9945
Epoch 4/100
1230/1230 [==============================] - 19s 15ms/step - loss: 0.0175 - accuracy: 0.9951 - val_loss: 0.0160 - val_accuracy: 0.9961
Epoch 5/100
1230/1230 [==============================] - 19s 15ms/step - loss: 0.0176 - accuracy: 0.9950 - val_loss: 0.0131 - val_accuracy: 0.9964
Epoch 6/100
1230/1230 [==============================] - 19s 15ms/step - loss: 0.0137 - accuracy: 0.9960 - val_loss: 0.0091 - val_accuracy: 0.9984
Epoch 7/100
1230/1230 [==============================] - 19s 15ms/step - loss: 0.0141 - accuracy: 0.9959 - val_loss: 0.0121 - val_accuracy: 0.9958
Epoch 8/100
1230/1230 [==============================] - 19s 15ms/step - loss: 0.0118 - accuracy: 0.9964 - val_loss: 0.0089 - val_accuracy: 0.9967
Epoch 9/100
1230/1230 [==============================] - 19s 15ms/step - loss: 0.0097 - accuracy: 0.9969 - val_loss: 0.0155 - val_accuracy: 0.9974
Epoch 10/100
1230/1230 [==============================] - 19s 15ms/step - loss: 0.0154 - accuracy: 0.9964 - val_loss: 0.0093 - val_accuracy: 0.9974
Epoch 11/100
1230/1230 [==============================] - 19s 15ms/step - loss: 0.0118 - accuracy: 0.9970 - val_loss: 0.0073 - val_accuracy: 0.9987
Epoch 12/100
1230/1230 [==============================] - 19s 15ms/step - loss: 0.0103 - accuracy: 0.9971 - val_loss: 0.0185 - val_accuracy: 0.9980
Epoch 13/100
1230/1230 [==============================] - 19s 15ms/step - loss: 0.0112 - accuracy: 0.9971 - val_loss: 0.0105 - val_accuracy: 0.9977
Epoch 14/100
1230/1230 [==============================] - 19s 15ms/step - loss: 0.0132 - accuracy: 0.9965 - val_loss: 0.0183 - val_accuracy: 0.9964
<keras.callbacks.History at 0x28fcad400>
```

(The "0 units, func: None" entries are the two Dropout layers, which the describe helper reports without units or activation.)
```python
final_models["7_layers_with_dropout"] = model_7lD
```

### 3.5. Model with 7 layers
```python
def model_7l_builder(hp):
    model = Sequential()
    model.add(Dense(36, input_dim = 36, activation = "relu"))

    hp_activation = hp.Choice('activation', values=['relu', 'tanh'])
    hp_layer_1 = hp.Int('layer_1', min_value=32, max_value=512, step=32)
    hp_layer_2 = hp.Int('layer_2', min_value=32, max_value=512, step=32)
    hp_layer_3 = hp.Int('layer_3', min_value=32, max_value=512, step=32)
    hp_layer_4 = hp.Int('layer_4', min_value=32, max_value=512, step=32)
    hp_layer_5 = hp.Int('layer_5', min_value=32, max_value=512, step=32)
    hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

    model.add(Dense(units=hp_layer_1, activation=hp_activation))
    model.add(Dense(units=hp_layer_2, activation=hp_activation))
    model.add(Dense(units=hp_layer_3, activation=hp_activation))
    model.add(Dense(units=hp_layer_4, activation=hp_activation))
    model.add(Dense(units=hp_layer_5, activation=hp_activation))
    model.add(Dense(2, activation = "softmax"))

    model.compile(optimizer=Adam(learning_rate=hp_learning_rate), loss="categorical_crossentropy", metrics = ["accuracy"])

    return model
```
```python
tuner_7l = kt.Hyperband(
    model_7l_builder,
    objective='val_accuracy',
    max_epochs=10,
    directory='keras_tuner_dir',
    project_name='keras_tuner_demo_6'
)
tuner_7l.search(x_train, y_train, validation_data=(x_test, y_test), epochs=10, callbacks=[stop_early, TensorBoard("./keras_tuner_dir/logs")])
```

```
Trial 30 Complete [00h 00m 51s]
val_accuracy: 0.9973983764648438

Best val_accuracy So Far: 0.9977235794067383
Total elapsed time: 00h 02m 22s
INFO:tensorflow:Oracle triggered exit
```
```python
model_7l = get_best_model(tuner_7l)
model_7l.fit(x_train, y_train, epochs=100, batch_size=10, validation_data=(x_test, y_test), callbacks=[stop_early])
```

```
Describe models architecture
Layer-1: 36 units, func: <function relu at 0x13a058b80>
Layer-2: 192 units, func: <function tanh at 0x13a058ee0>
Layer-3: 320 units, func: <function tanh at 0x13a058ee0>
Layer-4: 448 units, func: <function tanh at 0x13a058ee0>
Layer-5: 224 units, func: <function tanh at 0x13a058ee0>
Layer-6: 448 units, func: <function tanh at 0x13a058ee0>
Layer-7: 2 units, func: <function softmax at 0x13a058160>

Other params:
learning_rate: 0.0001
Epoch 1/100
2022-11-23 14:29:10.795739: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-11-23 14:29:24.355056: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
1230/1230 [==============================] - 16s 13ms/step - loss: 0.0720 - accuracy: 0.9748 - val_loss: 0.0219 - val_accuracy: 0.9941
Epoch 2/100
1230/1230 [==============================] - 15s 12ms/step - loss: 0.0319 - accuracy: 0.9910 - val_loss: 0.0353 - val_accuracy: 0.9893
Epoch 3/100
1230/1230 [==============================] - 15s 12ms/step - loss: 0.0262 - accuracy: 0.9927 - val_loss: 0.0149 - val_accuracy: 0.9958
Epoch 4/100
1230/1230 [==============================] - 15s 12ms/step - loss: 0.0221 - accuracy: 0.9936 - val_loss: 0.0125 - val_accuracy: 0.9964
Epoch 5/100
1230/1230 [==============================] - 15s 12ms/step - loss: 0.0195 - accuracy: 0.9943 - val_loss: 0.0171 - val_accuracy: 0.9951
Epoch 6/100
1230/1230 [==============================] - 16s 13ms/step - loss: 0.0186 - accuracy: 0.9942 - val_loss: 0.0128 - val_accuracy: 0.9971
Epoch 7/100
1230/1230 [==============================] - 16s 13ms/step - loss: 0.0170 - accuracy: 0.9951 - val_loss: 0.0132 - val_accuracy: 0.9974
<keras.callbacks.History at 0x15a0f17c0>
```
```python
final_models["7_layers"] = model_7l
```

### 3.6. Describe final models
```python
for name, model in final_models.items():
    print(f"{name}: ", end="")
    describe_model(model)
    print()
```

```
3_layers: Describe models architecture
Layer-1: 36 units, func: <function relu at 0x17fe29b80>
Layer-2: 448 units, func: <function tanh at 0x17fe29ee0>
Layer-3: 2 units, func: <function softmax at 0x17fe29160>

5_layers: Describe models architecture
Layer-1: 36 units, func: <function relu at 0x17fe29b80>
Layer-2: 160 units, func: <function relu at 0x17fe29b80>
Layer-3: 352 units, func: <function relu at 0x17fe29b80>
Layer-4: 64 units, func: <function relu at 0x17fe29b80>
Layer-5: 2 units, func: <function softmax at 0x17fe29160>

7_layers_with_dropout: Describe models architecture
Layer-1: 36 units, func: <function relu at 0x17fe29b80>
Layer-2: 320 units, func: <function relu at 0x17fe29b80>
Layer-3: 0 units, func: None
Layer-4: 96 units, func: <function relu at 0x17fe29b80>
Layer-5: 0 units, func: None
Layer-6: 448 units, func: <function relu at 0x17fe29b80>
Layer-7: 2 units, func: <function softmax at 0x17fe29160>

7_layers: Describe models architecture
Layer-1: 36 units, func: <function relu at 0x17fe29b80>
Layer-2: 192 units, func: <function tanh at 0x17fe29ee0>
Layer-3: 320 units, func: <function tanh at 0x17fe29ee0>
Layer-4: 448 units, func: <function tanh at 0x17fe29ee0>
Layer-5: 224 units, func: <function tanh at 0x17fe29ee0>
Layer-6: 448 units, func: <function tanh at 0x17fe29ee0>
Layer-7: 2 units, func: <function softmax at 0x17fe29160>
```
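`describe_model` is also defined earlier in the notebook; a minimal sketch consistent with the printout above (an assumption, not the author's exact code — Dropout layers have no `units`/`activation` attributes, which is why they print as `0 units, func: None`):

```python
# Minimal sketch (assumption): print each layer's units and activation in the
# same format as the outputs shown above.
def describe_model(model):
    print("Describe models architecture")
    for i, layer in enumerate(model.layers):
        units = getattr(layer, "units", 0)       # Dropout layers have no units
        func = getattr(layer, "activation", None)
        print(f"Layer-{i + 1}: {units} units, func: {func}")
```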
## 4. Model Evaluation

### 4.1. Train set evaluation
```python
train_set_results = []

for name, model in final_models.items():
    # Evaluate model
    predict_x = model.predict(x_test, verbose=False)
    y_pred_class = np.argmax(predict_x, axis=1)
    y_test_class = np.argmax(y_test, axis=1)

    cm = confusion_matrix(y_test_class, y_pred_class, labels=[0, 1])
    (p_score, r_score, f_score, _) = precision_recall_fscore_support(y_test_class, y_pred_class, labels=[0, 1])

    train_set_results.append(( name, round_up_metric_results(p_score), round_up_metric_results(r_score), round_up_metric_results(f_score), cm ))

train_set_results.sort(key=lambda k: sum(k[3]), reverse=True)
pd.DataFrame(train_set_results, columns=["Model", "Precision Score", "Recall Score", "F1 score", "Confusion Matrix"])
```

Note: despite the section title, this cell scores the models on `x_test`/`y_test`, which appears to be the held-out split of the training data created earlier in the notebook; the separate `./test.csv` file is evaluated in 4.2.

|   | Model | Precision Score | Recall Score | F1 score | Confusion Matrix |
|---|-------|-----------------|--------------|----------|------------------|
| 0 | 7_layers | [0.998, 0.997] | [0.998, 0.997] | [0.998, 0.997] | [[1675, 4], [4, 1392]] |
| 1 | 3_layers | [0.997, 0.996] | [0.997, 0.996] | [0.997, 0.996] | [[1674, 5], [5, 1391]] |
| 2 | 7_layers_with_dropout | [0.998, 0.995] | [0.996, 0.997] | [0.997, 0.996] | [[1672, 7], [4, 1392]] |
| 3 | 5_layers | [0.996, 0.996] | [0.996, 0.996] | [0.996, 0.996] | [[1673, 6], [6, 1390]] |
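`round_up_metric_results` is defined earlier in the notebook as well; a plausible sketch (an assumption) that matches the three-decimal per-class lists in the table above:

```python
# Plausible sketch (assumption): round each per-class metric to 3 decimals.
def round_up_metric_results(results) -> list:
    return [round(el, 3) for el in results]
```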
### 4.2. Test set evaluation
```python
# Load dataset
test_df = describe_dataset("./test.csv")

# Categorizing label
test_df.loc[test_df["label"] == "C", "label"] = 0
test_df.loc[test_df["label"] == "L", "label"] = 1
```

```
Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'right_elbow_x', 'right_elbow_y', 'right_elbow_z', 'right_elbow_v', 'left_elbow_x', 'left_elbow_y', 'left_elbow_z', 'left_elbow_v', 'right_wrist_x', 'right_wrist_y', 'right_wrist_z', 'right_wrist_v', 'left_wrist_x', 'left_wrist_y', 'left_wrist_z', 'left_wrist_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v']
Number of rows: 604
Number of columns: 37

Labels:
C    339
L    265
Name: label, dtype: int64

Missing values: False

Duplicate Rows : 0
```
```python
# Standard Scaling of features
test_x = test_df.drop("label", axis = 1)
test_x = pd.DataFrame(sc.transform(test_x))

test_y = test_df["label"]

# Converting labels to categorical
test_y_cat = to_categorical(test_y)
```
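The detection notebook further below loads `./model/input_scaler.pkl` to reproduce this exact scaling at inference time. A minimal sketch of how that file was presumably produced (an assumption — the dump cell is not shown in this diff; `sc` is the StandardScaler fitted on the training features earlier in this notebook):

```python
# Assumption: persist the fitted StandardScaler so 5.detection.ipynb can
# re-apply the same feature scaling to live keypoints.
import pickle

with open("./model/input_scaler.pkl", "wb") as f:
    pickle.dump(sc, f)
```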
```python
test_set_results = []

for name, model in final_models.items():
    # Evaluate model
    predict_x = model.predict(test_x, verbose=False)
    y_pred_class = np.argmax(predict_x, axis=1)
    y_test_class = np.argmax(test_y_cat, axis=1)

    cm = confusion_matrix(y_test_class, y_pred_class, labels=[0, 1])
    (p_score, r_score, f_score, _) = precision_recall_fscore_support(y_test_class, y_pred_class, labels=[0, 1])

    test_set_results.append(( name, round_up_metric_results(p_score), round_up_metric_results(r_score), round_up_metric_results(f_score), cm ))

test_set_results.sort(key=lambda k: k[1] + k[2] + k[3], reverse=True)
pd.DataFrame(test_set_results, columns=["Model", "Precision Score", "Recall Score", "F1 score", "Confusion Matrix"])
```

(Here `k[1] + k[2] + k[3]` concatenates the three per-class metric lists, so the sort compares them lexicographically rather than by an overall sum as in 4.1.)

```
2022-11-25 15:34:28.174069: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
(repeated three more times at 15:34:28.287586, 15:34:28.423087 and 15:34:28.546100)
```

|   | Model | Precision Score | Recall Score | F1 score | Confusion Matrix |
|---|-------|-----------------|--------------|----------|------------------|
| 0 | 7_layers | [0.944, 1.0] | [1.0, 0.925] | [0.971, 0.961] | [[339, 0], [20, 245]] |
| 1 | 5_layers | [0.926, 1.0] | [1.0, 0.898] | [0.962, 0.946] | [[339, 0], [27, 238]] |
| 2 | 7_layers_with_dropout | [0.909, 0.963] | [0.973, 0.875] | [0.94, 0.917] | [[330, 9], [33, 232]] |
| 3 | 3_layers | [0.896, 0.983] | [0.988, 0.853] | [0.94, 0.913] | [[335, 4], [39, 226]] |
},
|
| 1229 |
+
{
|
| 1230 |
+
"cell_type": "markdown",
|
| 1231 |
+
"metadata": {},
|
| 1232 |
+
"source": [
|
| 1233 |
+
"## 5. Dumped Model"
|
| 1234 |
+
]
|
| 1235 |
+
},
|
| 1236 |
+
```python
# Dump the best model to a pickle file
with open("./model/bicep_dp.pkl", "wb") as f:
    pickle.dump(final_models["7_layers"], f)
```

```
INFO:tensorflow:Assets written to: ram://4145d713-d810-484c-b518-b4ae694e4919/assets
```
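A hypothetical round trip to check the dump (not part of the original notebook; the zero vector is only a placeholder for one scaled 36-feature row):

```python
# Hypothetical usage (illustration only): load the pickled Keras classifier
# back and run a prediction on one scaled 36-feature row.
import pickle
import numpy as np

with open("./model/bicep_dp.pkl", "rb") as f:
    reloaded = pickle.load(f)

x = np.zeros((1, 36), dtype=np.float32)  # placeholder for a real scaled row
probs = reloaded.predict(x)
print(probs, np.argmax(probs, axis=1))   # class 0 = "C", class 1 = "L"
```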
```python
# Dump final results
with open("./model/all_models.pkl", "wb") as f:
    pickle.dump(final_models, f)
```

```
INFO:tensorflow:Assets written to: ram://5ccb0e7f-b3f8-4602-9b3c-2ae89d1a2d69/assets
INFO:tensorflow:Assets written to: ram://5d2d95b4-ff82-487d-bd98-ba859e8eced0/assets
INFO:tensorflow:Assets written to: ram://557449b4-6368-4822-a75f-79675a055ab9/assets
INFO:tensorflow:Assets written to: ram://4857368e-b747-43dd-9b2d-6e986317f4b8/assets
```

The `ram://` messages show TensorFlow serializing each Keras model to an in-memory SavedModel so that it can be pickled.
The notebook closes with an empty code cell and the standard notebook metadata (Python 3.8.13 (conda) kernel, nbformat 4).
core/bicep_model/4.evaluation.ipynb ADDED
The diff for this file is too large to render. See raw diff

core/bicep_model/5.detection.ipynb ADDED
@@ -0,0 +1,619 @@
```python
import mediapipe as mp
import cv2
import numpy as np
import pandas as pd
import datetime

import pickle

import warnings
warnings.filterwarnings('ignore')

# Drawing helpers
mp_drawing = mp.solutions.drawing_utils
mp_pose = mp.solutions.pose
```
### 1. Set up important functions
```python
# Determine important landmarks for bicep curl
IMPORTANT_LMS = [
    "NOSE",
    "LEFT_SHOULDER",
    "RIGHT_SHOULDER",
    "RIGHT_ELBOW",
    "LEFT_ELBOW",
    "RIGHT_WRIST",
    "LEFT_WRIST",
    "LEFT_HIP",
    "RIGHT_HIP",
]

# Generate all columns of the data frame

HEADERS = ["label"]  # Label column

for lm in IMPORTANT_LMS:
    HEADERS += [f"{lm.lower()}_x", f"{lm.lower()}_y", f"{lm.lower()}_z", f"{lm.lower()}_v"]
```
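As a quick orientation (this check is an illustration, not part of the original notebook), the loop above yields 37 column names: 1 label plus 9 landmarks × 4 values, matching the 37-column CSVs used for training.

```python
# Illustration only: verify the generated header layout.
assert len(HEADERS) == 1 + len(IMPORTANT_LMS) * 4 == 37
print(HEADERS[:5])  # ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v']
```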
```python
def rescale_frame(frame, percent=50):
    '''
    Rescale a frame from OpenCV to a certain percentage of its original size
    '''
    width = int(frame.shape[1] * percent / 100)
    height = int(frame.shape[0] * percent / 100)
    dim = (width, height)
    return cv2.resize(frame, dim, interpolation=cv2.INTER_AREA)


def save_frame_as_image(frame, message: str = None):
    '''
    Save a frame as an image to display the error
    '''
    now = datetime.datetime.now()

    if message:
        cv2.putText(frame, message, (50, 150), cv2.FONT_HERSHEY_COMPLEX, 0.4, (0, 0, 0), 1, cv2.LINE_AA)

    print("Saving ...")
    cv2.imwrite(f"../data/logs/bicep_{now}.jpg", frame)


def calculate_angle(point1: list, point2: list, point3: list) -> float:
    '''
    Calculate the angle between 3 points; the result is in degrees
    '''
    point1 = np.array(point1)
    point2 = np.array(point2)
    point3 = np.array(point3)

    # Calculate the angle at point2 from the two direction vectors
    angleInRad = np.arctan2(point3[1] - point2[1], point3[0] - point2[0]) - np.arctan2(point1[1] - point2[1], point1[0] - point2[0])
    angleInDeg = np.abs(angleInRad * 180.0 / np.pi)

    angleInDeg = angleInDeg if angleInDeg <= 180 else 360 - angleInDeg
    return angleInDeg


def extract_important_keypoints(results, important_landmarks: list) -> list:
    '''
    Extract important keypoints from mediapipe pose detection
    '''
    landmarks = results.pose_landmarks.landmark

    data = []
    for lm in important_landmarks:
        keypoint = landmarks[mp_pose.PoseLandmark[lm].value]
        data.append([keypoint.x, keypoint.y, keypoint.z, keypoint.visibility])

    return np.array(data).flatten().tolist()
```
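A quick worked example (not in the original notebook) of how `calculate_angle` behaves at a joint:

```python
# Sanity checks (illustration only): the angle at point2 of an L-shape is 90
# degrees, and a fully straightened joint gives 180 degrees.
assert round(calculate_angle([0, 1], [0, 0], [1, 0])) == 90
assert round(calculate_angle([0, 1], [0, 0], [0, -1])) == 180
```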
### 2. OOP method for analyzing the pose of each arm

To make it easier to analyze both arms at the same time, I chose to implement the bicep counter and error detection calculations with OOP.

*Note: If any joint of an arm appears to have poor visibility according to MediaPipe, that arm will be skipped.*
```python
class BicepPoseAnalysis:
    def __init__(self, side: str, stage_down_threshold: float, stage_up_threshold: float, peak_contraction_threshold: float, loose_upper_arm_angle_threshold: float, visibility_threshold: float):
        # Initialize thresholds
        self.stage_down_threshold = stage_down_threshold
        self.stage_up_threshold = stage_up_threshold
        self.peak_contraction_threshold = peak_contraction_threshold
        self.loose_upper_arm_angle_threshold = loose_upper_arm_angle_threshold
        self.visibility_threshold = visibility_threshold

        self.side = side
        self.counter = 0
        self.stage = "down"
        self.is_visible = True
        self.detected_errors = {
            "LOOSE_UPPER_ARM": 0,
            "PEAK_CONTRACTION": 0,
        }

        # Params for loose upper arm error detection
        self.loose_upper_arm = False

        # Params for peak contraction error detection
        self.peak_contraction_angle = 1000
        self.peak_contraction_frame = None

    def get_joints(self, landmarks) -> bool:
        '''
        Check for joints' visibility then get joints' coordinates
        '''
        side = self.side.upper()

        # Check visibility
        joints_visibility = [ landmarks[mp_pose.PoseLandmark[f"{side}_SHOULDER"].value].visibility, landmarks[mp_pose.PoseLandmark[f"{side}_ELBOW"].value].visibility, landmarks[mp_pose.PoseLandmark[f"{side}_WRIST"].value].visibility ]

        is_visible = all([ vis > self.visibility_threshold for vis in joints_visibility ])
        self.is_visible = is_visible

        if not is_visible:
            return self.is_visible

        # Get joints' coordinates
        self.shoulder = [ landmarks[mp_pose.PoseLandmark[f"{side}_SHOULDER"].value].x, landmarks[mp_pose.PoseLandmark[f"{side}_SHOULDER"].value].y ]
        self.elbow = [ landmarks[mp_pose.PoseLandmark[f"{side}_ELBOW"].value].x, landmarks[mp_pose.PoseLandmark[f"{side}_ELBOW"].value].y ]
        self.wrist = [ landmarks[mp_pose.PoseLandmark[f"{side}_WRIST"].value].x, landmarks[mp_pose.PoseLandmark[f"{side}_WRIST"].value].y ]

        return self.is_visible

    def analyze_pose(self, landmarks, frame):
        '''
        - Bicep Counter
        - Errors Detection
        '''
        self.get_joints(landmarks)

        # Cancel calculation if visibility is poor
        if not self.is_visible:
            return (None, None)

        # * Calculate curl angle for counter
        bicep_curl_angle = int(calculate_angle(self.shoulder, self.elbow, self.wrist))
        if bicep_curl_angle > self.stage_down_threshold:
            self.stage = "down"
        elif bicep_curl_angle < self.stage_up_threshold and self.stage == "down":
            self.stage = "up"
            self.counter += 1

        # * Calculate the angle between the upper arm (shoulder & elbow) and the Y axis
        shoulder_projection = [ self.shoulder[0], 1 ]  # The projection of the shoulder onto the bottom of the frame (y = 1)
        ground_upper_arm_angle = int(calculate_angle(self.elbow, self.shoulder, shoulder_projection))

        # * Evaluation for LOOSE UPPER ARM error
        if ground_upper_arm_angle > self.loose_upper_arm_angle_threshold:
            # Limit the saved frames
            if not self.loose_upper_arm:
                self.loose_upper_arm = True
                # save_frame_as_image(frame, f"Loose upper arm: {ground_upper_arm_angle}")
                self.detected_errors["LOOSE_UPPER_ARM"] += 1
        else:
            self.loose_upper_arm = False

        # * Evaluate PEAK CONTRACTION error
        if self.stage == "up" and bicep_curl_angle < self.peak_contraction_angle:
            # Save peak contraction every rep
            self.peak_contraction_angle = bicep_curl_angle
            self.peak_contraction_frame = frame

        elif self.stage == "down":
            # * Evaluate if the peak is higher than the threshold; if True, mark it as an error and save that frame
            if self.peak_contraction_angle != 1000 and self.peak_contraction_angle >= self.peak_contraction_threshold:
                # save_frame_as_image(self.peak_contraction_frame, f"{self.side} - Peak Contraction: {self.peak_contraction_angle}")
                self.detected_errors["PEAK_CONTRACTION"] += 1

            # Reset params
            self.peak_contraction_angle = 1000
            self.peak_contraction_frame = None

        return (bicep_curl_angle, ground_upper_arm_angle)
```
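To make the down/up state machine concrete, here is a small illustration (not in the original notebook) of the counter branch, using the thresholds from section 3 below; it re-implements only the counting logic from `analyze_pose`:

```python
# Illustration only: one full down -> up transition counts exactly one rep.
arm = BicepPoseAnalysis(side="left", stage_down_threshold=120, stage_up_threshold=90,
                        peak_contraction_threshold=60, loose_upper_arm_angle_threshold=40,
                        visibility_threshold=0.65)
for angle in [170, 150, 100, 85, 60, 100, 150]:  # simulated elbow angles over a rep
    if angle > arm.stage_down_threshold:
        arm.stage = "down"
    elif angle < arm.stage_up_threshold and arm.stage == "down":
        arm.stage = "up"
        arm.counter += 1
print(arm.counter)  # 1
```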
### 3. Bicep Detection
```python
VIDEO_PATH1 = "../data/db_curl/bc_test_1.mp4"
VIDEO_PATH2 = "../data/db_curl/bc_test_2.mp4"
VIDEO_PATH3 = "../data/db_curl/bc_test_3.mp4"
```
```python
# Load input scaler
with open("./model/input_scaler.pkl", "rb") as f:
    input_scaler = pickle.load(f)
```
#### 3.1. Detection with SKLearn model
```python
# Load model
with open("./model/KNN_model.pkl", "rb") as f:
    sklearn_model = pickle.load(f)
```
{
|
| 283 |
+
"cell_type": "code",
|
| 284 |
+
"execution_count": null,
|
| 285 |
+
"metadata": {},
|
| 286 |
+
"outputs": [],
|
| 287 |
+
"source": [
|
| 288 |
+
"cap = cv2.VideoCapture(VIDEO_PATH3)\n",
|
| 289 |
+
"\n",
|
| 290 |
+
"VISIBILITY_THRESHOLD = 0.65\n",
|
| 291 |
+
"\n",
|
| 292 |
+
"# Params for counter\n",
|
| 293 |
+
"STAGE_UP_THRESHOLD = 90\n",
|
| 294 |
+
"STAGE_DOWN_THRESHOLD = 120\n",
|
| 295 |
+
"\n",
|
| 296 |
+
"# Params to catch FULL RANGE OF MOTION error\n",
|
| 297 |
+
"PEAK_CONTRACTION_THRESHOLD = 60\n",
|
| 298 |
+
"\n",
|
| 299 |
+
"# LOOSE UPPER ARM error detection\n",
|
| 300 |
+
"LOOSE_UPPER_ARM = False\n",
|
| 301 |
+
"LOOSE_UPPER_ARM_ANGLE_THRESHOLD = 40\n",
|
| 302 |
+
"\n",
|
| 303 |
+
"# STANDING POSTURE error detection\n",
|
| 304 |
+
"POSTURE_ERROR_THRESHOLD = 0.7\n",
|
| 305 |
+
"posture = \"C\"\n",
|
| 306 |
+
"\n",
|
| 307 |
+
"# Init analysis class\n",
|
| 308 |
+
"left_arm_analysis = BicepPoseAnalysis(side=\"left\", stage_down_threshold=STAGE_DOWN_THRESHOLD, stage_up_threshold=STAGE_UP_THRESHOLD, peak_contraction_threshold=PEAK_CONTRACTION_THRESHOLD, loose_upper_arm_angle_threshold=LOOSE_UPPER_ARM_ANGLE_THRESHOLD, visibility_threshold=VISIBILITY_THRESHOLD)\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"right_arm_analysis = BicepPoseAnalysis(side=\"right\", stage_down_threshold=STAGE_DOWN_THRESHOLD, stage_up_threshold=STAGE_UP_THRESHOLD, peak_contraction_threshold=PEAK_CONTRACTION_THRESHOLD, loose_upper_arm_angle_threshold=LOOSE_UPPER_ARM_ANGLE_THRESHOLD, visibility_threshold=VISIBILITY_THRESHOLD)\n",
|
| 311 |
+
"\n",
|
| 312 |
+
"with mp_pose.Pose(min_detection_confidence=0.8, min_tracking_confidence=0.8) as pose:\n",
|
| 313 |
+
" while cap.isOpened():\n",
|
| 314 |
+
" ret, image = cap.read()\n",
|
| 315 |
+
"\n",
|
| 316 |
+
" if not ret:\n",
|
| 317 |
+
" break\n",
|
| 318 |
+
"\n",
|
| 319 |
+
" # Reduce size of a frame\n",
|
| 320 |
+
" image = rescale_frame(image, 50)\n",
|
| 321 |
+
" # image = cv2.flip(image, 1)\n",
|
| 322 |
+
" \n",
|
| 323 |
+
" video_dimensions = [image.shape[1], image.shape[0]]\n",
|
| 324 |
+
"\n",
|
| 325 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 326 |
+
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
| 327 |
+
" image.flags.writeable = False\n",
|
| 328 |
+
"\n",
|
| 329 |
+
" results = pose.process(image)\n",
|
| 330 |
+
"\n",
|
| 331 |
+
" if not results.pose_landmarks:\n",
|
| 332 |
+
" print(\"No human found\")\n",
|
| 333 |
+
" continue\n",
|
| 334 |
+
"\n",
|
| 335 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 336 |
+
" image.flags.writeable = True\n",
|
| 337 |
+
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
|
| 338 |
+
"\n",
|
| 339 |
+
" # Draw landmarks and connections\n",
|
| 340 |
+
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=2), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=1))\n",
|
| 341 |
+
"\n",
|
| 342 |
+
" # Make detection\n",
|
| 343 |
+
" try:\n",
|
| 344 |
+
" landmarks = results.pose_landmarks.landmark\n",
|
| 345 |
+
" \n",
|
| 346 |
+
" (left_bicep_curl_angle, left_ground_upper_arm_angle) = left_arm_analysis.analyze_pose(landmarks=landmarks, frame=image)\n",
|
| 347 |
+
" (right_bicep_curl_angle, right_ground_upper_arm_angle) = right_arm_analysis.analyze_pose(landmarks=landmarks, frame=image)\n",
|
| 348 |
+
"\n",
|
| 349 |
+
" # Extract keypoints from frame for the input\n",
|
| 350 |
+
" row = extract_important_keypoints(results, IMPORTANT_LMS)\n",
|
| 351 |
+
" X = pd.DataFrame([row], columns=HEADERS[1:])\n",
|
| 352 |
+
" X = pd.DataFrame(input_scaler.transform(X))\n",
|
| 353 |
+
"\n",
|
| 354 |
+
"\n",
|
| 355 |
+
" # Make prediction and its probability\n",
|
| 356 |
+
" predicted_class = sklearn_model.predict(X)[0]\n",
|
| 357 |
+
" prediction_probabilities = sklearn_model.predict_proba(X)[0]\n",
|
| 358 |
+
" class_prediction_probability = round(prediction_probabilities[np.argmax(prediction_probabilities)], 2)\n",
|
| 359 |
+
"\n",
|
| 360 |
+
" if class_prediction_probability >= POSTURE_ERROR_THRESHOLD:\n",
|
| 361 |
+
" posture = predicted_class\n",
|
| 362 |
+
"\n",
|
| 363 |
+
" # Visualization\n",
|
| 364 |
+
" # Status box\n",
|
| 365 |
+
" cv2.rectangle(image, (0, 0), (500, 40), (245, 117, 16), -1)\n",
|
| 366 |
+
"\n",
|
| 367 |
+
" # Display probability\n",
|
| 368 |
+
" cv2.putText(image, \"RIGHT\", (15, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 369 |
+
" cv2.putText(image, str(right_arm_analysis.counter) if right_arm_analysis.is_visible else \"UNK\", (10, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 370 |
+
"\n",
|
| 371 |
+
" # Display Left Counter\n",
|
| 372 |
+
" cv2.putText(image, \"LEFT\", (95, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 373 |
+
" cv2.putText(image, str(left_arm_analysis.counter) if left_arm_analysis.is_visible else \"UNK\", (100, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 374 |
+
"\n",
|
| 375 |
+
" # * Display error\n",
|
| 376 |
+
" # Right arm error\n",
|
| 377 |
+
" cv2.putText(image, \"R_PC\", (165, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 378 |
+
" cv2.putText(image, str(right_arm_analysis.detected_errors[\"PEAK_CONTRACTION\"]), (160, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 379 |
+
" cv2.putText(image, \"R_LUA\", (225, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 380 |
+
" cv2.putText(image, str(right_arm_analysis.detected_errors[\"LOOSE_UPPER_ARM\"]), (220, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 381 |
+
"\n",
|
| 382 |
+
" # Left arm error\n",
|
| 383 |
+
" cv2.putText(image, \"L_PC\", (300, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 384 |
+
" cv2.putText(image, str(left_arm_analysis.detected_errors[\"PEAK_CONTRACTION\"]), (295, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 385 |
+
" cv2.putText(image, \"L_LUA\", (380, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 386 |
+
" cv2.putText(image, str(left_arm_analysis.detected_errors[\"LOOSE_UPPER_ARM\"]), (375, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 387 |
+
"\n",
|
| 388 |
+
" # Lean back error\n",
|
| 389 |
+
" cv2.putText(image, \"LB\", (460, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 390 |
+
" cv2.putText(image, str(f\"{posture}, {predicted_class}, {class_prediction_probability}\"), (440, 30), cv2.FONT_HERSHEY_COMPLEX, 0.3, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 391 |
+
"\n",
|
| 392 |
+
"\n",
|
| 393 |
+
" # * Visualize angles\n",
|
| 394 |
+
" # Visualize LEFT arm calculated angles\n",
|
| 395 |
+
" if left_arm_analysis.is_visible:\n",
|
| 396 |
+
" cv2.putText(image, str(left_bicep_curl_angle), tuple(np.multiply(left_arm_analysis.elbow, video_dimensions).astype(int)), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 397 |
+
" cv2.putText(image, str(left_ground_upper_arm_angle), tuple(np.multiply(left_arm_analysis.shoulder, video_dimensions).astype(int)), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 398 |
+
"\n",
|
| 399 |
+
"\n",
|
| 400 |
+
" # Visualize RIGHT arm calculated angles\n",
|
| 401 |
+
" if right_arm_analysis.is_visible:\n",
|
| 402 |
+
" cv2.putText(image, str(right_bicep_curl_angle), tuple(np.multiply(right_arm_analysis.elbow, video_dimensions).astype(int)), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 0), 1, cv2.LINE_AA)\n",
|
| 403 |
+
" cv2.putText(image, str(right_ground_upper_arm_angle), tuple(np.multiply(right_arm_analysis.shoulder, video_dimensions).astype(int)), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 0), 1, cv2.LINE_AA)\n",
|
| 404 |
+
" \n",
|
| 405 |
+
" except Exception as e:\n",
|
| 406 |
+
" print(f\"Error: {e}\")\n",
|
| 407 |
+
" \n",
|
| 408 |
+
" cv2.imshow(\"CV2\", image)\n",
|
| 409 |
+
"\n",
|
| 410 |
+
" # if left_arm_analysis.loose_upper_arm:\n",
|
| 411 |
+
" # save_frame_as_image(image, \"\")\n",
|
| 412 |
+
" \n",
|
| 413 |
+
" # Press Q to close cv2 window\n",
|
| 414 |
+
" if cv2.waitKey(1) & 0xFF == ord('q'):\n",
|
| 415 |
+
" break\n",
|
| 416 |
+
"\n",
|
| 417 |
+
" cap.release()\n",
|
| 418 |
+
" cv2.destroyAllWindows()\n",
|
| 419 |
+
"\n",
|
| 420 |
+
" # (Optional)Fix bugs cannot close windows in MacOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)\n",
|
| 421 |
+
" for i in range (1, 5):\n",
|
| 422 |
+
" cv2.waitKey(1)\n",
|
| 423 |
+
" \n"
|
| 424 |
+
]
|
| 425 |
+
},
|
| 426 |
+
{
|
| 427 |
+
"cell_type": "markdown",
|
| 428 |
+
"metadata": {},
|
| 429 |
+
"source": [
|
| 430 |
+
"#### 3.2. Detection with Deep Learning model"
|
| 431 |
+
]
|
| 432 |
+
},
|
| 433 |
+
{
|
| 434 |
+
"cell_type": "code",
|
| 435 |
+
"execution_count": null,
|
| 436 |
+
"metadata": {},
|
| 437 |
+
"outputs": [],
|
| 438 |
+
"source": [
|
| 439 |
+
"# Load model\n",
|
| 440 |
+
"with open(\"./model/bicep_model_deep_learning.pkl\", \"rb\") as f:\n",
|
| 441 |
+
" DL_model = pickle.load(f)"
|
| 442 |
+
]
|
| 443 |
+
},
|
| 444 |
+
{
|
| 445 |
+
"cell_type": "code",
|
| 446 |
+
"execution_count": null,
|
| 447 |
+
"metadata": {},
|
| 448 |
+
"outputs": [],
|
| 449 |
+
"source": [
|
| 450 |
+
"cap = cv2.VideoCapture(VIDEO_PATH3)\n",
|
| 451 |
+
"\n",
|
| 452 |
+
"VISIBILITY_THRESHOLD = 0.65\n",
|
| 453 |
+
"\n",
|
| 454 |
+
"# Params for counter\n",
|
| 455 |
+
"STAGE_UP_THRESHOLD = 90\n",
|
| 456 |
+
"STAGE_DOWN_THRESHOLD = 120\n",
|
| 457 |
+
"\n",
|
| 458 |
+
"# Params to catch FULL RANGE OF MOTION error\n",
|
| 459 |
+
"PEAK_CONTRACTION_THRESHOLD = 60\n",
|
| 460 |
+
"\n",
|
| 461 |
+
"# LOOSE UPPER ARM error detection\n",
|
| 462 |
+
"LOOSE_UPPER_ARM = False\n",
|
| 463 |
+
"LOOSE_UPPER_ARM_ANGLE_THRESHOLD = 40\n",
|
| 464 |
+
"\n",
|
| 465 |
+
"# STANDING POSTURE error detection\n",
|
| 466 |
+
"POSTURE_ERROR_THRESHOLD = 0.95\n",
|
| 467 |
+
"posture = 0\n",
|
| 468 |
+
"\n",
|
| 469 |
+
"# Init analysis class\n",
|
| 470 |
+
"left_arm_analysis = BicepPoseAnalysis(side=\"left\", stage_down_threshold=STAGE_DOWN_THRESHOLD, stage_up_threshold=STAGE_UP_THRESHOLD, peak_contraction_threshold=PEAK_CONTRACTION_THRESHOLD, loose_upper_arm_angle_threshold=LOOSE_UPPER_ARM_ANGLE_THRESHOLD, visibility_threshold=VISIBILITY_THRESHOLD)\n",
|
| 471 |
+
"\n",
|
| 472 |
+
"right_arm_analysis = BicepPoseAnalysis(side=\"right\", stage_down_threshold=STAGE_DOWN_THRESHOLD, stage_up_threshold=STAGE_UP_THRESHOLD, peak_contraction_threshold=PEAK_CONTRACTION_THRESHOLD, loose_upper_arm_angle_threshold=LOOSE_UPPER_ARM_ANGLE_THRESHOLD, visibility_threshold=VISIBILITY_THRESHOLD)\n",
|
| 473 |
+
"\n",
|
| 474 |
+
"with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:\n",
|
| 475 |
+
" while cap.isOpened():\n",
|
| 476 |
+
" ret, image = cap.read()\n",
|
| 477 |
+
"\n",
|
| 478 |
+
" if not ret:\n",
|
| 479 |
+
" break\n",
|
| 480 |
+
"\n",
|
| 481 |
+
" # Reduce size of a frame\n",
|
| 482 |
+
" image = rescale_frame(image, 50)\n",
|
| 483 |
+
" # image = cv2.flip(image, 1)\n",
|
| 484 |
+
" \n",
|
| 485 |
+
" video_dimensions = [image.shape[1], image.shape[0]]\n",
|
| 486 |
+
"\n",
|
| 487 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 488 |
+
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
| 489 |
+
" image.flags.writeable = False\n",
|
| 490 |
+
"\n",
|
| 491 |
+
" results = pose.process(image)\n",
|
| 492 |
+
"\n",
|
| 493 |
+
" if not results.pose_landmarks:\n",
|
| 494 |
+
" print(\"No human found\")\n",
|
| 495 |
+
" continue\n",
|
| 496 |
+
"\n",
|
| 497 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 498 |
+
" image.flags.writeable = True\n",
|
| 499 |
+
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
|
| 500 |
+
"\n",
|
| 501 |
+
" # Draw landmarks and connections\n",
|
| 502 |
+
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=2), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=1))\n",
|
| 503 |
+
"\n",
|
| 504 |
+
" # Make detection\n",
|
| 505 |
+
" try:\n",
|
| 506 |
+
" landmarks = results.pose_landmarks.landmark\n",
|
| 507 |
+
" \n",
|
| 508 |
+
" (left_bicep_curl_angle, left_ground_upper_arm_angle) = left_arm_analysis.analyze_pose(landmarks=landmarks, frame=image)\n",
|
| 509 |
+
" (right_bicep_curl_angle, right_ground_upper_arm_angle) = right_arm_analysis.analyze_pose(landmarks=landmarks, frame=image)\n",
|
| 510 |
+
"\n",
|
| 511 |
+
" # Extract keypoints from frame for the input\n",
|
| 512 |
+
" row = extract_important_keypoints(results, IMPORTANT_LMS)\n",
|
| 513 |
+
" X = pd.DataFrame([row, ], columns=HEADERS[1:])\n",
|
| 514 |
+
" X = pd.DataFrame(input_scaler.transform(X))\n",
|
| 515 |
+
"\n",
|
| 516 |
+
" # Make prediction and its probability\n",
|
| 517 |
+
" prediction = DL_model.predict(X)\n",
|
| 518 |
+
" predicted_class = np.argmax(prediction, axis=1)[0]\n",
|
| 519 |
+
" prediction_probability = round(max(prediction.tolist()[0]), 2)\n",
|
| 520 |
+
"\n",
|
| 521 |
+
" if prediction_probability >= POSTURE_ERROR_THRESHOLD:\n",
|
| 522 |
+
" posture = predicted_class\n",
|
| 523 |
+
"\n",
|
| 524 |
+
" # Visualization\n",
|
| 525 |
+
" # Status box\n",
|
| 526 |
+
" cv2.rectangle(image, (0, 0), (500, 40), (245, 117, 16), -1)\n",
|
| 527 |
+
"\n",
|
| 528 |
+
" # Display probability\n",
|
| 529 |
+
" cv2.putText(image, \"RIGHT\", (15, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 530 |
+
" cv2.putText(image, str(right_arm_analysis.counter) if right_arm_analysis.is_visible else \"UNK\", (10, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 531 |
+
"\n",
|
| 532 |
+
" # Display Left Counter\n",
|
| 533 |
+
" cv2.putText(image, \"LEFT\", (95, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 534 |
+
" cv2.putText(image, str(left_arm_analysis.counter) if left_arm_analysis.is_visible else \"UNK\", (100, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 535 |
+
"\n",
|
| 536 |
+
" # * Display error\n",
|
| 537 |
+
" # Right arm error\n",
|
| 538 |
+
" cv2.putText(image, \"R_PC\", (165, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 539 |
+
" cv2.putText(image, str(right_arm_analysis.detected_errors[\"PEAK_CONTRACTION\"]), (160, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 540 |
+
" cv2.putText(image, \"R_LUA\", (225, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 541 |
+
" cv2.putText(image, str(right_arm_analysis.detected_errors[\"LOOSE_UPPER_ARM\"]), (220, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 542 |
+
"\n",
|
| 543 |
+
" # Left arm error\n",
|
| 544 |
+
" cv2.putText(image, \"L_PC\", (300, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 545 |
+
" cv2.putText(image, str(left_arm_analysis.detected_errors[\"PEAK_CONTRACTION\"]), (295, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 546 |
+
" cv2.putText(image, \"L_LUA\", (380, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 547 |
+
" cv2.putText(image, str(left_arm_analysis.detected_errors[\"LOOSE_UPPER_ARM\"]), (375, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 548 |
+
"\n",
|
| 549 |
+
" # Lean back error\n",
|
| 550 |
+
" cv2.putText(image, \"LB\", (460, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 551 |
+
" cv2.putText(image, str(\"C\" if posture == 0 else \"L\") + f\" ,{predicted_class}, {prediction_probability}\", (440, 30), cv2.FONT_HERSHEY_COMPLEX, 0.3, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 552 |
+
"\n",
|
| 553 |
+
"\n",
|
| 554 |
+
" # * Visualize angles\n",
|
| 555 |
+
" # Visualize LEFT arm calculated angles\n",
|
| 556 |
+
" if left_arm_analysis.is_visible:\n",
|
| 557 |
+
" cv2.putText(image, str(left_bicep_curl_angle), tuple(np.multiply(left_arm_analysis.elbow, video_dimensions).astype(int)), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 558 |
+
" cv2.putText(image, str(left_ground_upper_arm_angle), tuple(np.multiply(left_arm_analysis.shoulder, video_dimensions).astype(int)), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 559 |
+
"\n",
|
| 560 |
+
"\n",
|
| 561 |
+
" # Visualize RIGHT arm calculated angles\n",
|
| 562 |
+
" if right_arm_analysis.is_visible:\n",
|
| 563 |
+
" cv2.putText(image, str(right_bicep_curl_angle), tuple(np.multiply(right_arm_analysis.elbow, video_dimensions).astype(int)), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 0), 1, cv2.LINE_AA)\n",
|
| 564 |
+
" cv2.putText(image, str(right_ground_upper_arm_angle), tuple(np.multiply(right_arm_analysis.shoulder, video_dimensions).astype(int)), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 0), 1, cv2.LINE_AA)\n",
|
| 565 |
+
" \n",
|
| 566 |
+
" except Exception as e:\n",
|
| 567 |
+
" print(f\"Error: {e}\")\n",
|
| 568 |
+
" \n",
|
| 569 |
+
" cv2.imshow(\"CV2\", image)\n",
|
| 570 |
+
" \n",
|
| 571 |
+
" # Press Q to close cv2 window\n",
|
| 572 |
+
" if cv2.waitKey(1) & 0xFF == ord('q'):\n",
|
| 573 |
+
" break\n",
|
| 574 |
+
"\n",
|
| 575 |
+
" cap.release()\n",
|
| 576 |
+
" cv2.destroyAllWindows()\n",
|
| 577 |
+
"\n",
|
| 578 |
+
" # (Optional)Fix bugs cannot close windows in MacOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)\n",
|
| 579 |
+
" for i in range (1, 5):\n",
|
| 580 |
+
" cv2.waitKey(1)\n",
|
| 581 |
+
" \n"
|
| 582 |
+
]
|
| 583 |
+
},
|
| 584 |
+
{
|
| 585 |
+
"cell_type": "code",
|
| 586 |
+
"execution_count": null,
|
| 587 |
+
"metadata": {},
|
| 588 |
+
"outputs": [],
|
| 589 |
+
"source": []
|
| 590 |
+
}
|
| 591 |
+
],
|
| 592 |
+
"metadata": {
|
| 593 |
+
"kernelspec": {
|
| 594 |
+
"display_name": "Python 3.8.13 (conda)",
|
| 595 |
+
"language": "python",
|
| 596 |
+
"name": "python3"
|
| 597 |
+
},
|
| 598 |
+
"language_info": {
|
| 599 |
+
"codemirror_mode": {
|
| 600 |
+
"name": "ipython",
|
| 601 |
+
"version": 3
|
| 602 |
+
},
|
| 603 |
+
"file_extension": ".py",
|
| 604 |
+
"mimetype": "text/x-python",
|
| 605 |
+
"name": "python",
|
| 606 |
+
"nbconvert_exporter": "python",
|
| 607 |
+
"pygments_lexer": "ipython3",
|
| 608 |
+
"version": "3.8.13"
|
| 609 |
+
},
|
| 610 |
+
"orig_nbformat": 4,
|
| 611 |
+
"vscode": {
|
| 612 |
+
"interpreter": {
|
| 613 |
+
"hash": "9260f401923fb5c4108c543a7d176de9733d378b3752e49535ad7c43c2271b65"
|
| 614 |
+
}
|
| 615 |
+
}
|
| 616 |
+
},
|
| 617 |
+
"nbformat": 4,
|
| 618 |
+
"nbformat_minor": 2
|
| 619 |
+
}
|
core/bicep_model/README.md
ADDED
@@ -0,0 +1,36 @@
<h2 align="center">BICEP CURL MODEL TRAINING PROCESS</h2>

### 1. Folder structure

```
bicep_model
│ 1.data.ipynb - process collected videos
│ 2.sklearn.ipynb - train models using scikit-learn ML algorithms
│ 3.deep_learning.ipynb - train models using deep learning
│ 4.evaluation.ipynb - evaluate trained models
│ 5.detection.ipynb - detection on test videos
│ train.csv - training dataset converted from videos
│ test.csv - test dataset converted from videos
│ evaluation.csv - models' evaluation results
│
└───model/ - folder containing the best trained models and the input scaler
```

### 2. Important landmarks

Three common bicep curl errors are targeted in this thesis:

- Loose upper arm: while the arm curls upward, the upper arm moves along instead of staying still.
- Weak peak contraction: the forearm does not come up high enough on the way up, so the bicep is never fully contracted.
- Lean too far back: the performer’s torso leans back and forth during the exercise to gain momentum.

In my research and exploration, **_the important MediaPipe Pose landmarks_** for this exercise are: nose, left shoulder, right shoulder, right elbow, left elbow, right wrist, left wrist, right hip and left hip. Each landmark contributes four values (x, y, z and visibility) to the model input, as sketched below.
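A minimal sketch, under stated assumptions, of how one frame's landmarks could be flattened into that feature row: `extract_important_keypoints` is the helper name used in 5.detection.ipynb, but its body here is an illustration rather than the exact implementation.

```python
import mediapipe as mp

mp_pose = mp.solutions.pose

# The nine landmarks listed above; each contributes x, y, z and visibility.
IMPORTANT_LMS = [
    "NOSE",
    "LEFT_SHOULDER", "RIGHT_SHOULDER",
    "LEFT_ELBOW", "RIGHT_ELBOW",
    "LEFT_WRIST", "RIGHT_WRIST",
    "LEFT_HIP", "RIGHT_HIP",
]

def extract_important_keypoints(results, important_landmarks=IMPORTANT_LMS) -> list:
    """Flatten the selected pose landmarks of one frame into a 1-D feature row."""
    landmarks = results.pose_landmarks.landmark
    row = []
    for name in important_landmarks:
        lm = landmarks[mp_pose.PoseLandmark[name].value]
        row += [lm.x, lm.y, lm.z, lm.visibility]
    return row  # length == 4 * len(important_landmarks)
```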

### 3. Error detection method

1. **Loose upper arm**: Detected by calculating the angle between the elbow, the shoulder and the shoulder’s projection on the ground. Through my research, if that angle exceeds 40 degrees, the movement is classified as a “loose upper arm” error (see the sketch after this list).

1. **Weak peak contraction**: Detected by calculating the angle between the wrist, elbow and shoulder while the performer’s arm is coming up. Through my research, if that angle never drops below 60 degrees before the arm comes back down, the movement is classified as a “weak peak contraction” error (see the sketch after this list).

1. **Lean too far back**: Due to its complexity, machine learning is used for this detection. See this [notebook](./4.evaluation.ipynb) for the evaluation process of this model.
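The two geometric checks above reduce to a single three-point angle computation. Below is a minimal sketch, assuming normalized (x, y) image coordinates: `calculate_angle` is the standard arctangent formula, the 40° and 60° thresholds are the ones stated above, and the ground-projection point is an assumed construction (the real per-arm logic lives in `BicepPoseAnalysis` in 5.detection.ipynb).

```python
import numpy as np

def calculate_angle(a, b, c) -> float:
    """Return the angle at vertex b (degrees) formed by points a-b-c."""
    a, b, c = np.array(a), np.array(b), np.array(c)
    radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
    angle = np.abs(radians * 180.0 / np.pi)
    return 360.0 - angle if angle > 180.0 else angle

def is_loose_upper_arm(shoulder, elbow, threshold=40.0) -> bool:
    # Assumed construction: project the shoulder straight down (normalized
    # image coordinates, y grows downward) to get the "ground" reference point.
    ground_point = (shoulder[0], 1.0)
    return calculate_angle(elbow, shoulder, ground_point) > threshold

def is_weak_peak_contraction(peak_curl_angle, threshold=60.0) -> bool:
    # peak_curl_angle: the smallest wrist-elbow-shoulder angle reached while
    # the arm was in the "up" stage of the rep.
    return peak_curl_angle >= threshold
```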
core/bicep_model/evaluation.csv
ADDED
@@ -0,0 +1,23 @@
Model,Precision Score,Recall Score,Accuracy Score,F1 Score,Confusion Matrix
KNN,0.9754011299435028,0.9683363945010297,0.9718543046357616,0.9712296333655557,"[[338 1]
 [ 16 249]]"
7_layers,0.9721448467966574,0.9622641509433962,0.9668874172185431,0.9660655092982752,"[[339 0]
 [ 20 245]]"
5_layers,0.9631147540983607,0.9490566037735849,0.9552980132450332,0.9540120976270039,"[[339 0]
 [ 27 238]]"
SVC,0.9299664562828265,0.9337618968108199,0.9321192052980133,0.9314197094947314,"[[312 27]
 [ 14 251]]"
RF,0.9472295514511873,0.9245283018867925,0.9337748344370861,0.9313285202660452,"[[339 0]
 [ 40 225]]"
7_layers_with_dropout,0.9358732553753301,0.9244615127734179,0.9304635761589404,0.9285834938008851,"[[330 9]
 [ 33 232]]"
3_layers,0.9391653103929318,0.9205153893248734,0.9288079470198676,0.9264113788657968,"[[335 4]
 [ 39 226]]"
LR,0.7926937886241248,0.737774809372739,0.7615894039735099,0.7405498281786941,"[[316 23]
 [121 144]]"
SGDC,0.7124566903151295,0.7150108532309234,0.7152317880794702,0.7129372754904669,"[[243 96]
 [ 76 189]]"
DTC,0.6842549139631369,0.6507819892024267,0.6754966887417219,0.6475785613069935,"[[289 50]
 [146 119]]"
NB,0.7973684210526315,0.5641509433962264,0.6175496688741722,0.48664966831131273,"[[339 0]
 [231 34]]"
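For reference, each Confusion Matrix above reads row = true class, column = predicted class over the two lean-back posture classes, and the diagonal sum over the total reproduces the recorded accuracy: for the KNN row, (338 + 249) / (338 + 1 + 16 + 249) = 587 / 604 ≈ 0.9719, matching its Accuracy Score.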
core/bicep_model/model/KNN_model.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6db6bcacdc92f1af9d455e86936ac0437f4b4c3c352c169e790e3d0bb66454e4
size 3640626
core/bicep_model/model/RF_model.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4c5486bda6f57467de5c94810e30b0d4b0db109ac8eb80655be71c20d88d09a
size 1745827
core/bicep_model/model/all_dp.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9a5b76624aa1b3be26205bbda66aa548cf426521a8897c383f9be4d330673cee
size 6349244
core/bicep_model/model/all_sklearn.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:27a82547b1d960a4e19da93e25d2607f674d76c9be0a569e451cc97a18a67432
size 5533459
core/bicep_model/model/bicep_dp.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d6eef302ce47e9595f503c94c2ea5d7b7b2c9c7366732236c0ccf5ed125f1281
size 5202150
core/bicep_model/model/input_scaler.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5c07942166c3fa393e9fee1fed9228a963fd8cb81111d25b2d6d77aa90d4a4e2
size 1949
core/bicep_model/test.csv
ADDED
The diff for this file is too large to render.
core/bicep_model/train.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12d56e9215045a56103f80a83e1952d70320a25c4cf0a33a898b38aade151e5b
size 10821608
core/lunge_model/1.stage.data.ipynb
ADDED
The diff for this file is too large to render.
core/lunge_model/2.stage.sklearn.ipynb
ADDED
@@ -0,0 +1,753 @@
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [
|
| 8 |
+
{
|
| 9 |
+
"name": "stderr",
|
| 10 |
+
"output_type": "stream",
|
| 11 |
+
"text": [
|
| 12 |
+
"objc[21024]: Class CaptureDelegate is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_videoio.3.4.16.dylib (0x105f48860) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x15aece480). One of the two will be used. Which one is undefined.\n",
|
| 13 |
+
"objc[21024]: Class CVWindow is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x104310a68) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x15aece4d0). One of the two will be used. Which one is undefined.\n",
|
| 14 |
+
"objc[21024]: Class CVView is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x104310a90) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x15aece4f8). One of the two will be used. Which one is undefined.\n",
|
| 15 |
+
"objc[21024]: Class CVSlider is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x104310ab8) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x15aece520). One of the two will be used. Which one is undefined.\n"
|
| 16 |
+
]
|
| 17 |
+
}
|
| 18 |
+
],
|
| 19 |
+
"source": [
|
| 20 |
+
"import mediapipe as mp\n",
|
| 21 |
+
"import cv2\n",
|
| 22 |
+
"import pandas as pd\n",
|
| 23 |
+
"import pickle\n",
|
| 24 |
+
"\n",
|
| 25 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 26 |
+
"from sklearn.preprocessing import StandardScaler\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"import warnings\n",
|
| 29 |
+
"warnings.filterwarnings('ignore')\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"# Drawing helpers\n",
|
| 32 |
+
"mp_drawing = mp.solutions.drawing_utils\n",
|
| 33 |
+
"mp_pose = mp.solutions.pose"
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "markdown",
|
| 38 |
+
"metadata": {},
|
| 39 |
+
"source": [
|
| 40 |
+
"## 1. Train Model"
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"cell_type": "markdown",
|
| 45 |
+
"metadata": {},
|
| 46 |
+
"source": [
|
| 47 |
+
"### 1.1. Describe data"
|
| 48 |
+
]
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"cell_type": "code",
|
| 52 |
+
"execution_count": 2,
|
| 53 |
+
"metadata": {},
|
| 54 |
+
"outputs": [],
|
| 55 |
+
"source": [
|
| 56 |
+
"def rescale_frame(frame, percent=50):\n",
|
| 57 |
+
" '''\n",
|
| 58 |
+
" Rescale a frame to a certain percentage compare to its original frame\n",
|
| 59 |
+
" '''\n",
|
| 60 |
+
" width = int(frame.shape[1] * percent/ 100)\n",
|
| 61 |
+
" height = int(frame.shape[0] * percent/ 100)\n",
|
| 62 |
+
" dim = (width, height)\n",
|
| 63 |
+
" return cv2.resize(frame, dim, interpolation = cv2.INTER_AREA)\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"def describe_dataset(dataset_path: str):\n",
|
| 67 |
+
" '''\n",
|
| 68 |
+
" Describe dataset\n",
|
| 69 |
+
" '''\n",
|
| 70 |
+
"\n",
|
| 71 |
+
" data = pd.read_csv(dataset_path)\n",
|
| 72 |
+
" print(f\"Headers: {list(data.columns.values)}\")\n",
|
| 73 |
+
" print(f'Number of rows: {data.shape[0]} \\nNumber of columns: {data.shape[1]}\\n')\n",
|
| 74 |
+
" print(f\"Labels: \\n{data['label'].value_counts()}\\n\")\n",
|
| 75 |
+
" print(f\"Missing values: {data.isnull().values.any()}\\n\")\n",
|
| 76 |
+
" \n",
|
| 77 |
+
" duplicate = data[data.duplicated()]\n",
|
| 78 |
+
" print(f\"Duplicate Rows : {len(duplicate.sum(axis=1))}\")\n",
|
| 79 |
+
"\n",
|
| 80 |
+
" return data"
|
| 81 |
+
]
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"cell_type": "code",
|
| 85 |
+
"execution_count": 4,
|
| 86 |
+
"metadata": {},
|
| 87 |
+
"outputs": [
|
| 88 |
+
{
|
| 89 |
+
"name": "stdout",
|
| 90 |
+
"output_type": "stream",
|
| 91 |
+
"text": [
|
| 92 |
+
"Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v', 'left_knee_x', 'left_knee_y', 'left_knee_z', 'left_knee_v', 'right_knee_x', 'right_knee_y', 'right_knee_z', 'right_knee_v', 'left_ankle_x', 'left_ankle_y', 'left_ankle_z', 'left_ankle_v', 'right_ankle_x', 'right_ankle_y', 'right_ankle_z', 'right_ankle_v', 'left_heel_x', 'left_heel_y', 'left_heel_z', 'left_heel_v', 'right_heel_x', 'right_heel_y', 'right_heel_z', 'right_heel_v', 'left_foot_index_x', 'left_foot_index_y', 'left_foot_index_z', 'left_foot_index_v', 'right_foot_index_x', 'right_foot_index_y', 'right_foot_index_z', 'right_foot_index_v']\n",
|
| 93 |
+
"Number of rows: 24244 \n",
|
| 94 |
+
"Number of columns: 53\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"Labels: \n",
|
| 97 |
+
"D 8232\n",
|
| 98 |
+
"M 8148\n",
|
| 99 |
+
"I 7864\n",
|
| 100 |
+
"Name: label, dtype: int64\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"Missing values: False\n",
|
| 103 |
+
"\n",
|
| 104 |
+
"Duplicate Rows : 151\n"
|
| 105 |
+
]
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"data": {
|
| 109 |
+
"text/html": [
|
| 110 |
+
"<div>\n",
|
| 111 |
+
"<style scoped>\n",
|
| 112 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 113 |
+
" vertical-align: middle;\n",
|
| 114 |
+
" }\n",
|
| 115 |
+
"\n",
|
| 116 |
+
" .dataframe tbody tr th {\n",
|
| 117 |
+
" vertical-align: top;\n",
|
| 118 |
+
" }\n",
|
| 119 |
+
"\n",
|
| 120 |
+
" .dataframe thead th {\n",
|
| 121 |
+
" text-align: right;\n",
|
| 122 |
+
" }\n",
|
| 123 |
+
"</style>\n",
|
| 124 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 125 |
+
" <thead>\n",
|
| 126 |
+
" <tr style=\"text-align: right;\">\n",
|
| 127 |
+
" <th></th>\n",
|
| 128 |
+
" <th>label</th>\n",
|
| 129 |
+
" <th>nose_x</th>\n",
|
| 130 |
+
" <th>nose_y</th>\n",
|
| 131 |
+
" <th>nose_z</th>\n",
|
| 132 |
+
" <th>nose_v</th>\n",
|
| 133 |
+
" <th>left_shoulder_x</th>\n",
|
| 134 |
+
" <th>left_shoulder_y</th>\n",
|
| 135 |
+
" <th>left_shoulder_z</th>\n",
|
| 136 |
+
" <th>left_shoulder_v</th>\n",
|
| 137 |
+
" <th>right_shoulder_x</th>\n",
|
| 138 |
+
" <th>...</th>\n",
|
| 139 |
+
" <th>right_heel_z</th>\n",
|
| 140 |
+
" <th>right_heel_v</th>\n",
|
| 141 |
+
" <th>left_foot_index_x</th>\n",
|
| 142 |
+
" <th>left_foot_index_y</th>\n",
|
| 143 |
+
" <th>left_foot_index_z</th>\n",
|
| 144 |
+
" <th>left_foot_index_v</th>\n",
|
| 145 |
+
" <th>right_foot_index_x</th>\n",
|
| 146 |
+
" <th>right_foot_index_y</th>\n",
|
| 147 |
+
" <th>right_foot_index_z</th>\n",
|
| 148 |
+
" <th>right_foot_index_v</th>\n",
|
| 149 |
+
" </tr>\n",
|
| 150 |
+
" </thead>\n",
|
| 151 |
+
" <tbody>\n",
|
| 152 |
+
" <tr>\n",
|
| 153 |
+
" <th>0</th>\n",
|
| 154 |
+
" <td>M</td>\n",
|
| 155 |
+
" <td>0.496085</td>\n",
|
| 156 |
+
" <td>0.286904</td>\n",
|
| 157 |
+
" <td>-0.219098</td>\n",
|
| 158 |
+
" <td>0.999996</td>\n",
|
| 159 |
+
" <td>0.500287</td>\n",
|
| 160 |
+
" <td>0.360987</td>\n",
|
| 161 |
+
" <td>0.019479</td>\n",
|
| 162 |
+
" <td>0.999978</td>\n",
|
| 163 |
+
" <td>0.436462</td>\n",
|
| 164 |
+
" <td>...</td>\n",
|
| 165 |
+
" <td>-0.268695</td>\n",
|
| 166 |
+
" <td>0.996758</td>\n",
|
| 167 |
+
" <td>0.370391</td>\n",
|
| 168 |
+
" <td>0.893386</td>\n",
|
| 169 |
+
" <td>0.505172</td>\n",
|
| 170 |
+
" <td>0.931761</td>\n",
|
| 171 |
+
" <td>0.566927</td>\n",
|
| 172 |
+
" <td>1.005949</td>\n",
|
| 173 |
+
" <td>-0.382462</td>\n",
|
| 174 |
+
" <td>0.998906</td>\n",
|
| 175 |
+
" </tr>\n",
|
| 176 |
+
" <tr>\n",
|
| 177 |
+
" <th>1</th>\n",
|
| 178 |
+
" <td>M</td>\n",
|
| 179 |
+
" <td>0.496126</td>\n",
|
| 180 |
+
" <td>0.286918</td>\n",
|
| 181 |
+
" <td>-0.217849</td>\n",
|
| 182 |
+
" <td>0.999996</td>\n",
|
| 183 |
+
" <td>0.500281</td>\n",
|
| 184 |
+
" <td>0.360954</td>\n",
|
| 185 |
+
" <td>0.019995</td>\n",
|
| 186 |
+
" <td>0.999977</td>\n",
|
| 187 |
+
" <td>0.436466</td>\n",
|
| 188 |
+
" <td>...</td>\n",
|
| 189 |
+
" <td>-0.271191</td>\n",
|
| 190 |
+
" <td>0.996724</td>\n",
|
| 191 |
+
" <td>0.370344</td>\n",
|
| 192 |
+
" <td>0.893290</td>\n",
|
| 193 |
+
" <td>0.505325</td>\n",
|
| 194 |
+
" <td>0.931969</td>\n",
|
| 195 |
+
" <td>0.567040</td>\n",
|
| 196 |
+
" <td>1.005795</td>\n",
|
| 197 |
+
" <td>-0.384848</td>\n",
|
| 198 |
+
" <td>0.998902</td>\n",
|
| 199 |
+
" </tr>\n",
|
| 200 |
+
" <tr>\n",
|
| 201 |
+
" <th>2</th>\n",
|
| 202 |
+
" <td>M</td>\n",
|
| 203 |
+
" <td>0.496144</td>\n",
|
| 204 |
+
" <td>0.286921</td>\n",
|
| 205 |
+
" <td>-0.217039</td>\n",
|
| 206 |
+
" <td>0.999996</td>\n",
|
| 207 |
+
" <td>0.500279</td>\n",
|
| 208 |
+
" <td>0.360923</td>\n",
|
| 209 |
+
" <td>0.020068</td>\n",
|
| 210 |
+
" <td>0.999977</td>\n",
|
| 211 |
+
" <td>0.436469</td>\n",
|
| 212 |
+
" <td>...</td>\n",
|
| 213 |
+
" <td>-0.271365</td>\n",
|
| 214 |
+
" <td>0.996699</td>\n",
|
| 215 |
+
" <td>0.370316</td>\n",
|
| 216 |
+
" <td>0.893275</td>\n",
|
| 217 |
+
" <td>0.504931</td>\n",
|
| 218 |
+
" <td>0.931633</td>\n",
|
| 219 |
+
" <td>0.567040</td>\n",
|
| 220 |
+
" <td>1.005774</td>\n",
|
| 221 |
+
" <td>-0.384872</td>\n",
|
| 222 |
+
" <td>0.998894</td>\n",
|
| 223 |
+
" </tr>\n",
|
| 224 |
+
" </tbody>\n",
|
| 225 |
+
"</table>\n",
|
| 226 |
+
"<p>3 rows × 53 columns</p>\n",
|
| 227 |
+
"</div>"
|
| 228 |
+
],
|
| 229 |
+
"text/plain": [
|
| 230 |
+
" label nose_x nose_y nose_z nose_v left_shoulder_x \\\n",
|
| 231 |
+
"0 M 0.496085 0.286904 -0.219098 0.999996 0.500287 \n",
|
| 232 |
+
"1 M 0.496126 0.286918 -0.217849 0.999996 0.500281 \n",
|
| 233 |
+
"2 M 0.496144 0.286921 -0.217039 0.999996 0.500279 \n",
|
| 234 |
+
"\n",
|
| 235 |
+
" left_shoulder_y left_shoulder_z left_shoulder_v right_shoulder_x ... \\\n",
|
| 236 |
+
"0 0.360987 0.019479 0.999978 0.436462 ... \n",
|
| 237 |
+
"1 0.360954 0.019995 0.999977 0.436466 ... \n",
|
| 238 |
+
"2 0.360923 0.020068 0.999977 0.436469 ... \n",
|
| 239 |
+
"\n",
|
| 240 |
+
" right_heel_z right_heel_v left_foot_index_x left_foot_index_y \\\n",
|
| 241 |
+
"0 -0.268695 0.996758 0.370391 0.893386 \n",
|
| 242 |
+
"1 -0.271191 0.996724 0.370344 0.893290 \n",
|
| 243 |
+
"2 -0.271365 0.996699 0.370316 0.893275 \n",
|
| 244 |
+
"\n",
|
| 245 |
+
" left_foot_index_z left_foot_index_v right_foot_index_x \\\n",
|
| 246 |
+
"0 0.505172 0.931761 0.566927 \n",
|
| 247 |
+
"1 0.505325 0.931969 0.567040 \n",
|
| 248 |
+
"2 0.504931 0.931633 0.567040 \n",
|
| 249 |
+
"\n",
|
| 250 |
+
" right_foot_index_y right_foot_index_z right_foot_index_v \n",
|
| 251 |
+
"0 1.005949 -0.382462 0.998906 \n",
|
| 252 |
+
"1 1.005795 -0.384848 0.998902 \n",
|
| 253 |
+
"2 1.005774 -0.384872 0.998894 \n",
|
| 254 |
+
"\n",
|
| 255 |
+
"[3 rows x 53 columns]"
|
| 256 |
+
]
|
| 257 |
+
},
|
| 258 |
+
"execution_count": 4,
|
| 259 |
+
"metadata": {},
|
| 260 |
+
"output_type": "execute_result"
|
| 261 |
+
}
|
| 262 |
+
],
|
| 263 |
+
"source": [
|
| 264 |
+
"df = describe_dataset(\"./stage.train.csv\")\n",
|
| 265 |
+
"df.head(3)"
|
| 266 |
+
]
|
| 267 |
+
},
|
| 268 |
+
{
|
| 269 |
+
"cell_type": "markdown",
|
| 270 |
+
"metadata": {},
|
| 271 |
+
"source": [
|
| 272 |
+
"### 1.2. Train and evaluate model with train set"
|
| 273 |
+
]
|
| 274 |
+
},
|
| 275 |
+
{
|
| 276 |
+
"cell_type": "code",
|
| 277 |
+
"execution_count": 5,
|
| 278 |
+
"metadata": {},
|
| 279 |
+
"outputs": [],
|
| 280 |
+
"source": [
|
| 281 |
+
"from sklearn.linear_model import LogisticRegression, RidgeClassifier, SGDClassifier\n",
|
| 282 |
+
"from sklearn.svm import SVC\n",
|
| 283 |
+
"from sklearn.neighbors import KNeighborsClassifier\n",
|
| 284 |
+
"from sklearn.tree import DecisionTreeClassifier\n",
|
| 285 |
+
"from sklearn.ensemble import RandomForestClassifier\n",
|
| 286 |
+
"\n",
|
| 287 |
+
"from sklearn.metrics import precision_score, accuracy_score, f1_score, recall_score"
|
| 288 |
+
]
|
| 289 |
+
},
|
| 290 |
+
{
|
| 291 |
+
"cell_type": "code",
|
| 292 |
+
"execution_count": 6,
|
| 293 |
+
"metadata": {},
|
| 294 |
+
"outputs": [],
|
| 295 |
+
"source": [
|
| 296 |
+
"# Extract features and class\n",
|
| 297 |
+
"X = df.drop(\"label\", axis=1) # features\n",
|
| 298 |
+
"y = df[\"label\"]\n",
|
| 299 |
+
"\n",
|
| 300 |
+
"# Standard Scaler\n",
|
| 301 |
+
"sc = StandardScaler()\n",
|
| 302 |
+
"X = pd.DataFrame(sc.fit_transform(X))"
|
| 303 |
+
]
|
| 304 |
+
},
|
| 305 |
+
{
|
| 306 |
+
"cell_type": "code",
|
| 307 |
+
"execution_count": 8,
|
| 308 |
+
"metadata": {},
|
| 309 |
+
"outputs": [
|
| 310 |
+
{
|
| 311 |
+
"data": {
|
| 312 |
+
"text/plain": [
|
| 313 |
+
"8474 M\n",
|
| 314 |
+
"4197 I\n",
|
| 315 |
+
"9705 I\n",
|
| 316 |
+
"Name: label, dtype: object"
|
| 317 |
+
]
|
| 318 |
+
},
|
| 319 |
+
"execution_count": 8,
|
| 320 |
+
"metadata": {},
|
| 321 |
+
"output_type": "execute_result"
|
| 322 |
+
}
|
| 323 |
+
],
|
| 324 |
+
"source": [
|
| 325 |
+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234)\n",
|
| 326 |
+
"\n",
|
| 327 |
+
"y_train.head(3)"
|
| 328 |
+
]
|
| 329 |
+
},
|
| 330 |
+
{
|
| 331 |
+
"cell_type": "code",
|
| 332 |
+
"execution_count": 15,
|
| 333 |
+
"metadata": {},
|
| 334 |
+
"outputs": [],
|
| 335 |
+
"source": [
|
| 336 |
+
"algorithms =[(\"LR\", LogisticRegression()),\n",
|
| 337 |
+
" (\"SVC\", SVC(probability=True)),\n",
|
| 338 |
+
" ('KNN',KNeighborsClassifier()),\n",
|
| 339 |
+
" (\"DTC\", DecisionTreeClassifier()),\n",
|
| 340 |
+
" (\"SGDC\", SGDClassifier()),\n",
|
| 341 |
+
" (\"Ridge\", RidgeClassifier()),\n",
|
| 342 |
+
" ('RF', RandomForestClassifier()),]\n",
|
| 343 |
+
"\n",
|
| 344 |
+
"models = {}\n",
|
| 345 |
+
"final_results = []\n",
|
| 346 |
+
"\n",
|
| 347 |
+
"for name, model in algorithms:\n",
|
| 348 |
+
" trained_model = model.fit(X_train, y_train)\n",
|
| 349 |
+
" models[name] = trained_model\n",
|
| 350 |
+
"\n",
|
| 351 |
+
" # Evaluate model\n",
|
| 352 |
+
" model_results = model.predict(X_test)\n",
|
| 353 |
+
"\n",
|
| 354 |
+
" p_score = precision_score(y_test, model_results, average=\"macro\")\n",
|
| 355 |
+
" a_score = accuracy_score(y_test, model_results)\n",
|
| 356 |
+
" r_score = recall_score(y_test, model_results, average=\"macro\")\n",
|
| 357 |
+
" f1_score_result = f1_score(y_test, model_results, average=None, labels=[\"I\", \"M\", \"D\"])\n",
|
| 358 |
+
" final_results.append(( name, p_score, a_score, r_score, f1_score_result ))\n",
|
| 359 |
+
"\n",
|
| 360 |
+
"\n",
|
| 361 |
+
"final_results.sort(key=lambda k: k[4][0] + k[4][1], reverse=True)"
|
| 362 |
+
]
|
| 363 |
+
},
|
| 364 |
+
{
|
| 365 |
+
"cell_type": "code",
|
| 366 |
+
"execution_count": 16,
|
| 367 |
+
"metadata": {},
|
| 368 |
+
"outputs": [
|
| 369 |
+
{
|
| 370 |
+
"data": {
|
| 371 |
+
"text/html": [
|
| 372 |
+
"<div>\n",
|
| 373 |
+
"<style scoped>\n",
|
| 374 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 375 |
+
" vertical-align: middle;\n",
|
| 376 |
+
" }\n",
|
| 377 |
+
"\n",
|
| 378 |
+
" .dataframe tbody tr th {\n",
|
| 379 |
+
" vertical-align: top;\n",
|
| 380 |
+
" }\n",
|
| 381 |
+
"\n",
|
| 382 |
+
" .dataframe thead th {\n",
|
| 383 |
+
" text-align: right;\n",
|
| 384 |
+
" }\n",
|
| 385 |
+
"</style>\n",
|
| 386 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 387 |
+
" <thead>\n",
|
| 388 |
+
" <tr style=\"text-align: right;\">\n",
|
| 389 |
+
" <th></th>\n",
|
| 390 |
+
" <th>Model</th>\n",
|
| 391 |
+
" <th>Precision Score</th>\n",
|
| 392 |
+
" <th>Accuracy score</th>\n",
|
| 393 |
+
" <th>Recall Score</th>\n",
|
| 394 |
+
" <th>F1 score</th>\n",
|
| 395 |
+
" </tr>\n",
|
| 396 |
+
" </thead>\n",
|
| 397 |
+
" <tbody>\n",
|
| 398 |
+
" <tr>\n",
|
| 399 |
+
" <th>0</th>\n",
|
| 400 |
+
" <td>KNN</td>\n",
|
| 401 |
+
" <td>0.995486</td>\n",
|
| 402 |
+
" <td>0.995463</td>\n",
|
| 403 |
+
" <td>0.995497</td>\n",
|
| 404 |
+
" <td>[0.998108448928121, 0.9936189608021876, 0.9947...</td>\n",
|
| 405 |
+
" </tr>\n",
|
| 406 |
+
" <tr>\n",
|
| 407 |
+
" <th>1</th>\n",
|
| 408 |
+
" <td>SVC</td>\n",
|
| 409 |
+
" <td>0.992812</td>\n",
|
| 410 |
+
" <td>0.992782</td>\n",
|
| 411 |
+
" <td>0.992862</td>\n",
|
| 412 |
+
" <td>[0.9977952755905511, 0.9893390191897654, 0.991...</td>\n",
|
| 413 |
+
" </tr>\n",
|
| 414 |
+
" <tr>\n",
|
| 415 |
+
" <th>2</th>\n",
|
| 416 |
+
" <td>RF</td>\n",
|
| 417 |
+
" <td>0.993596</td>\n",
|
| 418 |
+
" <td>0.993607</td>\n",
|
| 419 |
+
" <td>0.993679</td>\n",
|
| 420 |
+
" <td>[0.9949717159019483, 0.9905516610789394, 0.995...</td>\n",
|
| 421 |
+
" </tr>\n",
|
| 422 |
+
" <tr>\n",
|
| 423 |
+
" <th>3</th>\n",
|
| 424 |
+
" <td>LR</td>\n",
|
| 425 |
+
" <td>0.989931</td>\n",
|
| 426 |
+
" <td>0.989895</td>\n",
|
| 427 |
+
" <td>0.990009</td>\n",
|
| 428 |
+
" <td>[0.9959080893925086, 0.9850381679389313, 0.988...</td>\n",
|
| 429 |
+
" </tr>\n",
|
| 430 |
+
" <tr>\n",
|
| 431 |
+
" <th>4</th>\n",
|
| 432 |
+
" <td>DTC</td>\n",
|
| 433 |
+
" <td>0.990106</td>\n",
|
| 434 |
+
" <td>0.990101</td>\n",
|
| 435 |
+
" <td>0.990208</td>\n",
|
| 436 |
+
" <td>[0.9943431803896919, 0.9856576136710405, 0.990...</td>\n",
|
| 437 |
+
" </tr>\n",
|
| 438 |
+
" <tr>\n",
|
| 439 |
+
" <th>5</th>\n",
|
| 440 |
+
" <td>SGDC</td>\n",
|
| 441 |
+
" <td>0.986400</td>\n",
|
| 442 |
+
" <td>0.986389</td>\n",
|
| 443 |
+
" <td>0.986541</td>\n",
|
| 444 |
+
" <td>[0.990894819466248, 0.9797794117647058, 0.9886...</td>\n",
|
| 445 |
+
" </tr>\n",
|
| 446 |
+
" <tr>\n",
|
| 447 |
+
" <th>6</th>\n",
|
| 448 |
+
" <td>Ridge</td>\n",
|
| 449 |
+
" <td>0.970926</td>\n",
|
| 450 |
+
" <td>0.970097</td>\n",
|
| 451 |
+
" <td>0.969980</td>\n",
|
| 452 |
+
" <td>[0.9709677419354839, 0.9567827130852341, 0.982...</td>\n",
|
| 453 |
+
" </tr>\n",
|
| 454 |
+
" </tbody>\n",
|
| 455 |
+
"</table>\n",
|
| 456 |
+
"</div>"
|
| 457 |
+
],
|
| 458 |
+
"text/plain": [
|
| 459 |
+
" Model Precision Score Accuracy score Recall Score \\\n",
|
| 460 |
+
"0 KNN 0.995486 0.995463 0.995497 \n",
|
| 461 |
+
"1 SVC 0.992812 0.992782 0.992862 \n",
|
| 462 |
+
"2 RF 0.993596 0.993607 0.993679 \n",
|
| 463 |
+
"3 LR 0.989931 0.989895 0.990009 \n",
|
| 464 |
+
"4 DTC 0.990106 0.990101 0.990208 \n",
|
| 465 |
+
"5 SGDC 0.986400 0.986389 0.986541 \n",
|
| 466 |
+
"6 Ridge 0.970926 0.970097 0.969980 \n",
|
| 467 |
+
"\n",
|
| 468 |
+
" F1 score \n",
|
| 469 |
+
"0 [0.998108448928121, 0.9936189608021876, 0.9947... \n",
|
| 470 |
+
"1 [0.9977952755905511, 0.9893390191897654, 0.991... \n",
|
| 471 |
+
"2 [0.9949717159019483, 0.9905516610789394, 0.995... \n",
|
| 472 |
+
"3 [0.9959080893925086, 0.9850381679389313, 0.988... \n",
|
| 473 |
+
"4 [0.9943431803896919, 0.9856576136710405, 0.990... \n",
|
| 474 |
+
"5 [0.990894819466248, 0.9797794117647058, 0.9886... \n",
|
| 475 |
+
"6 [0.9709677419354839, 0.9567827130852341, 0.982... "
|
| 476 |
+
]
|
| 477 |
+
},
|
| 478 |
+
"execution_count": 16,
|
| 479 |
+
"metadata": {},
|
| 480 |
+
"output_type": "execute_result"
|
| 481 |
+
}
|
| 482 |
+
],
|
| 483 |
+
"source": [
|
| 484 |
+
"pd.DataFrame(final_results, columns=[\"Model\", \"Precision Score\", \"Accuracy score\", \"Recall Score\", \"F1 score\"])"
|
| 485 |
+
]
|
| 486 |
+
},
|
| 487 |
+
{
|
| 488 |
+
"cell_type": "markdown",
|
| 489 |
+
"metadata": {},
|
| 490 |
+
"source": [
|
| 491 |
+
"### 1.3. Evaluate models with test set"
|
| 492 |
+
]
|
| 493 |
+
},
|
| 494 |
+
{
|
| 495 |
+
"cell_type": "code",
|
| 496 |
+
"execution_count": 17,
|
| 497 |
+
"metadata": {},
|
| 498 |
+
"outputs": [
|
| 499 |
+
{
|
| 500 |
+
"name": "stdout",
|
| 501 |
+
"output_type": "stream",
|
| 502 |
+
"text": [
|
| 503 |
+
"Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v', 'left_knee_x', 'left_knee_y', 'left_knee_z', 'left_knee_v', 'right_knee_x', 'right_knee_y', 'right_knee_z', 'right_knee_v', 'left_ankle_x', 'left_ankle_y', 'left_ankle_z', 'left_ankle_v', 'right_ankle_x', 'right_ankle_y', 'right_ankle_z', 'right_ankle_v', 'left_heel_x', 'left_heel_y', 'left_heel_z', 'left_heel_v', 'right_heel_x', 'right_heel_y', 'right_heel_z', 'right_heel_v', 'left_foot_index_x', 'left_foot_index_y', 'left_foot_index_z', 'left_foot_index_v', 'right_foot_index_x', 'right_foot_index_y', 'right_foot_index_z', 'right_foot_index_v']\n",
|
| 504 |
+
"Number of rows: 1205 \n",
|
| 505 |
+
"Number of columns: 53\n",
|
| 506 |
+
"\n",
|
| 507 |
+
"Labels: \n",
|
| 508 |
+
"D 416\n",
|
| 509 |
+
"I 402\n",
|
| 510 |
+
"M 387\n",
|
| 511 |
+
"Name: label, dtype: int64\n",
|
| 512 |
+
"\n",
|
| 513 |
+
"Missing values: False\n",
|
| 514 |
+
"\n",
|
| 515 |
+
"Duplicate Rows : 20\n"
|
| 516 |
+
]
|
| 517 |
+
}
|
| 518 |
+
],
|
| 519 |
+
"source": [
|
| 520 |
+
"test_df = describe_dataset(\"./stage.test.csv\")\n",
|
| 521 |
+
"test_df = test_df.sample(frac=1).reset_index(drop=True)\n",
|
| 522 |
+
"\n",
|
| 523 |
+
"test_x = test_df.drop(\"label\", axis=1)\n",
|
| 524 |
+
"test_y = test_df[\"label\"]\n",
|
| 525 |
+
"\n",
|
| 526 |
+
"test_x = pd.DataFrame(sc.transform(test_x))"
|
| 527 |
+
]
|
| 528 |
+
},
|
| 529 |
+
{
|
| 530 |
+
"cell_type": "code",
|
| 531 |
+
"execution_count": 18,
|
| 532 |
+
"metadata": {},
|
| 533 |
+
"outputs": [
|
| 534 |
+
{
|
| 535 |
+
"data": {
|
| 536 |
+
"text/html": [
|
| 537 |
+
"<div>\n",
|
| 538 |
+
"<style scoped>\n",
|
| 539 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 540 |
+
" vertical-align: middle;\n",
|
| 541 |
+
" }\n",
|
| 542 |
+
"\n",
|
| 543 |
+
" .dataframe tbody tr th {\n",
|
| 544 |
+
" vertical-align: top;\n",
|
| 545 |
+
" }\n",
|
| 546 |
+
"\n",
|
| 547 |
+
" .dataframe thead th {\n",
|
| 548 |
+
" text-align: right;\n",
|
| 549 |
+
" }\n",
|
| 550 |
+
"</style>\n",
|
| 551 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 552 |
+
" <thead>\n",
|
| 553 |
+
" <tr style=\"text-align: right;\">\n",
|
| 554 |
+
" <th></th>\n",
|
| 555 |
+
" <th>Model</th>\n",
|
| 556 |
+
" <th>Precision Score</th>\n",
|
| 557 |
+
" <th>Accuracy score</th>\n",
|
| 558 |
+
" <th>Recall Score</th>\n",
|
| 559 |
+
" <th>F1 score</th>\n",
|
| 560 |
+
" </tr>\n",
|
| 561 |
+
" </thead>\n",
|
| 562 |
+
" <tbody>\n",
|
| 563 |
+
" <tr>\n",
|
| 564 |
+
" <th>0</th>\n",
|
| 565 |
+
" <td>Ridge</td>\n",
|
| 566 |
+
" <td>0.953408</td>\n",
|
| 567 |
+
" <td>0.951037</td>\n",
|
| 568 |
+
" <td>0.949563</td>\n",
|
| 569 |
+
" <td>[0.9763387297633873, 0.9199457259158751, 0.954...</td>\n",
|
| 570 |
+
" </tr>\n",
|
| 571 |
+
" <tr>\n",
|
| 572 |
+
" <th>1</th>\n",
|
| 573 |
+
" <td>SVC</td>\n",
|
| 574 |
+
" <td>0.955640</td>\n",
|
| 575 |
+
" <td>0.951867</td>\n",
|
| 576 |
+
" <td>0.950163</td>\n",
|
| 577 |
+
" <td>[0.9492325855962219, 0.9194444444444445, 0.982...</td>\n",
|
| 578 |
+
" </tr>\n",
|
| 579 |
+
" <tr>\n",
|
| 580 |
+
" <th>2</th>\n",
|
| 581 |
+
" <td>LR</td>\n",
|
| 582 |
+
" <td>0.952856</td>\n",
|
| 583 |
+
" <td>0.948548</td>\n",
|
| 584 |
+
" <td>0.946658</td>\n",
|
| 585 |
+
" <td>[0.950354609929078, 0.9131652661064426, 0.9764...</td>\n",
|
| 586 |
+
" </tr>\n",
|
| 587 |
+
" <tr>\n",
|
| 588 |
+
" <th>3</th>\n",
|
| 589 |
+
" <td>KNN</td>\n",
|
| 590 |
+
" <td>0.919799</td>\n",
|
| 591 |
+
" <td>0.915353</td>\n",
|
| 592 |
+
" <td>0.916588</td>\n",
|
| 593 |
+
" <td>[0.9745454545454546, 0.875609756097561, 0.8941...</td>\n",
|
| 594 |
+
" </tr>\n",
|
| 595 |
+
" <tr>\n",
|
| 596 |
+
" <th>4</th>\n",
|
| 597 |
+
" <td>SGDC</td>\n",
|
| 598 |
+
" <td>0.928256</td>\n",
|
| 599 |
+
" <td>0.912863</td>\n",
|
| 600 |
+
" <td>0.909741</td>\n",
|
| 601 |
+
" <td>[0.8903654485049834, 0.8444444444444444, 0.992...</td>\n",
|
| 602 |
+
" </tr>\n",
|
| 603 |
+
" <tr>\n",
|
| 604 |
+
" <th>5</th>\n",
|
| 605 |
+
" <td>RF</td>\n",
|
| 606 |
+
" <td>0.870546</td>\n",
|
| 607 |
+
" <td>0.862241</td>\n",
|
| 608 |
+
" <td>0.861943</td>\n",
|
| 609 |
+
" <td>[0.90744920993228, 0.7849740932642487, 0.88829...</td>\n",
|
| 610 |
+
" </tr>\n",
|
| 611 |
+
" <tr>\n",
|
| 612 |
+
" <th>6</th>\n",
|
| 613 |
+
" <td>DTC</td>\n",
|
| 614 |
+
" <td>0.859608</td>\n",
|
| 615 |
+
" <td>0.857261</td>\n",
|
| 616 |
+
" <td>0.855094</td>\n",
|
| 617 |
+
" <td>[0.8983240223463688, 0.7577464788732395, 0.899...</td>\n",
|
| 618 |
+
" </tr>\n",
|
| 619 |
+
" </tbody>\n",
|
| 620 |
+
"</table>\n",
|
| 621 |
+
"</div>"
|
| 622 |
+
],
|
| 623 |
+
"text/plain": [
|
| 624 |
+
" Model Precision Score Accuracy score Recall Score \\\n",
|
| 625 |
+
"0 Ridge 0.953408 0.951037 0.949563 \n",
|
| 626 |
+
"1 SVC 0.955640 0.951867 0.950163 \n",
|
| 627 |
+
"2 LR 0.952856 0.948548 0.946658 \n",
|
| 628 |
+
"3 KNN 0.919799 0.915353 0.916588 \n",
|
| 629 |
+
"4 SGDC 0.928256 0.912863 0.909741 \n",
|
| 630 |
+
"5 RF 0.870546 0.862241 0.861943 \n",
|
| 631 |
+
"6 DTC 0.859608 0.857261 0.855094 \n",
|
| 632 |
+
"\n",
|
| 633 |
+
" F1 score \n",
|
| 634 |
+
"0 [0.9763387297633873, 0.9199457259158751, 0.954... \n",
|
| 635 |
+
"1 [0.9492325855962219, 0.9194444444444445, 0.982... \n",
|
| 636 |
+
"2 [0.950354609929078, 0.9131652661064426, 0.9764... \n",
|
| 637 |
+
"3 [0.9745454545454546, 0.875609756097561, 0.8941... \n",
|
| 638 |
+
"4 [0.8903654485049834, 0.8444444444444444, 0.992... \n",
|
| 639 |
+
"5 [0.90744920993228, 0.7849740932642487, 0.88829... \n",
|
| 640 |
+
"6 [0.8983240223463688, 0.7577464788732395, 0.899... "
|
| 641 |
+
]
|
| 642 |
+
},
|
| 643 |
+
"execution_count": 18,
|
| 644 |
+
"metadata": {},
|
| 645 |
+
"output_type": "execute_result"
|
| 646 |
+
}
|
| 647 |
+
],
|
| 648 |
+
"source": [
|
| 649 |
+
"testset_final_results = []\n",
|
| 650 |
+
"\n",
|
| 651 |
+
"for name, model in models.items():\n",
|
| 652 |
+
" # Evaluate model\n",
|
| 653 |
+
" model_results = model.predict(test_x)\n",
|
| 654 |
+
"\n",
|
| 655 |
+
" p_score = precision_score(test_y, model_results, average=\"macro\")\n",
|
| 656 |
+
" a_score = accuracy_score(test_y, model_results)\n",
|
| 657 |
+
" r_score = recall_score(test_y, model_results, average=\"macro\")\n",
|
| 658 |
+
" f1_score_result = f1_score(test_y, model_results, average=None, labels=[\"I\", \"M\", \"D\"])\n",
|
| 659 |
+
" testset_final_results.append(( name, p_score, a_score, r_score, f1_score_result ))\n",
|
| 660 |
+
"\n",
|
| 661 |
+
"\n",
|
| 662 |
+
"testset_final_results.sort(key=lambda k: k[4][0] + k[4][1], reverse=True)\n",
|
| 663 |
+
"pd.DataFrame(testset_final_results, columns=[\"Model\", \"Precision Score\", \"Accuracy score\", \"Recall Score\", \"F1 score\"])"
|
| 664 |
+
]
|
| 665 |
+
},
|
| 666 |
+
{
|
| 667 |
+
"cell_type": "markdown",
|
| 668 |
+
"metadata": {},
|
| 669 |
+
"source": [
|
| 670 |
+
"## 2. Dumped Model\n",
|
| 671 |
+
"\n",
|
| 672 |
+
"The best models are in order:\n",
|
| 673 |
+
"- Ridge\n",
|
| 674 |
+
"- SVC\n",
|
| 675 |
+
"- LR"
|
| 676 |
+
]
|
| 677 |
+
},
|
| 678 |
+
{
|
| 679 |
+
"cell_type": "code",
|
| 680 |
+
"execution_count": 13,
|
| 681 |
+
"metadata": {},
|
| 682 |
+
"outputs": [],
|
| 683 |
+
"source": [
|
| 684 |
+
"with open(\"./model/sklearn/stage_LR_model.pkl\", \"wb\") as f:\n",
|
| 685 |
+
" pickle.dump(models[\"LR\"], f)"
|
| 686 |
+
]
|
| 687 |
+
},
|
| 688 |
+
{
|
| 689 |
+
"cell_type": "code",
|
| 690 |
+
"execution_count": null,
|
| 691 |
+
"metadata": {},
|
| 692 |
+
"outputs": [],
|
| 693 |
+
"source": [
|
| 694 |
+
"with open(\"./model/sklearn/stage_SVC_model.pkl\", \"wb\") as f:\n",
|
| 695 |
+
" pickle.dump(models[\"SVC\"], f)"
|
| 696 |
+
]
|
| 697 |
+
},
|
| 698 |
+
{
|
| 699 |
+
"cell_type": "code",
|
| 700 |
+
"execution_count": 14,
|
| 701 |
+
"metadata": {},
|
| 702 |
+
"outputs": [],
|
| 703 |
+
"source": [
|
| 704 |
+
"with open(\"./model/sklearn/stage_Ridge_model.pkl\", \"wb\") as f:\n",
|
| 705 |
+
" pickle.dump(models[\"Ridge\"], f)"
|
| 706 |
+
]
|
| 707 |
+
},
|
| 708 |
+
{
|
| 709 |
+
"cell_type": "code",
|
| 710 |
+
"execution_count": 27,
|
| 711 |
+
"metadata": {},
|
| 712 |
+
"outputs": [],
|
| 713 |
+
"source": [
|
| 714 |
+
"with open(\"./model/input_scaler.pkl\", \"wb\") as f:\n",
|
| 715 |
+
" pickle.dump(sc, f)"
|
| 716 |
+
]
|
| 717 |
+
},
|
| 718 |
+
{
|
| 719 |
+
"cell_type": "code",
|
| 720 |
+
"execution_count": null,
|
| 721 |
+
"metadata": {},
|
| 722 |
+
"outputs": [],
|
| 723 |
+
"source": []
|
| 724 |
+
}
|
| 725 |
+
],
|
| 726 |
+
"metadata": {
|
| 727 |
+
"kernelspec": {
|
| 728 |
+
"display_name": "Python 3.8.13 (conda)",
|
| 729 |
+
"language": "python",
|
| 730 |
+
"name": "python3"
|
| 731 |
+
},
|
| 732 |
+
"language_info": {
|
| 733 |
+
"codemirror_mode": {
|
| 734 |
+
"name": "ipython",
|
| 735 |
+
"version": 3
|
| 736 |
+
},
|
| 737 |
+
"file_extension": ".py",
|
| 738 |
+
"mimetype": "text/x-python",
|
| 739 |
+
"name": "python",
|
| 740 |
+
"nbconvert_exporter": "python",
|
| 741 |
+
"pygments_lexer": "ipython3",
|
| 742 |
+
"version": "3.8.13"
|
| 743 |
+
},
|
| 744 |
+
"orig_nbformat": 4,
|
| 745 |
+
"vscode": {
|
| 746 |
+
"interpreter": {
|
| 747 |
+
"hash": "9260f401923fb5c4108c543a7d176de9733d378b3752e49535ad7c43c2271b65"
|
| 748 |
+
}
|
| 749 |
+
}
|
| 750 |
+
},
|
| 751 |
+
"nbformat": 4,
|
| 752 |
+
"nbformat_minor": 2
|
| 753 |
+
}
|
core/lunge_model/3.stage.deep_learning.ipynb
ADDED
@@ -0,0 +1,767 @@
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 15,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"# Data visualization\n",
|
| 10 |
+
"import numpy as np\n",
|
| 11 |
+
"import pandas as pd \n",
|
| 12 |
+
"# Keras\n",
|
| 13 |
+
"from keras.models import Sequential\n",
|
| 14 |
+
"from keras.layers import Dense\n",
|
| 15 |
+
"from keras.layers import Dropout\n",
|
| 16 |
+
"from keras.optimizers import Adam\n",
|
| 17 |
+
"from keras.utils.np_utils import to_categorical\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"# Train-Test\n",
|
| 20 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 21 |
+
"# Classification Report\n",
|
| 22 |
+
"from sklearn.metrics import classification_report, confusion_matrix\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"import pickle\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"import warnings\n",
|
| 27 |
+
"warnings.filterwarnings('ignore')"
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"cell_type": "markdown",
|
| 32 |
+
"metadata": {},
|
| 33 |
+
"source": [
|
| 34 |
+
"## 1. Describe dataset and Train Model"
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "markdown",
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"source": [
|
| 41 |
+
"### 1.1. Describe dataset"
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "code",
|
| 46 |
+
"execution_count": 3,
|
| 47 |
+
"metadata": {},
|
| 48 |
+
"outputs": [],
|
| 49 |
+
"source": [
|
| 50 |
+
"# Determine important landmarks for lunge\n",
|
| 51 |
+
"IMPORTANT_LMS = [\n",
|
| 52 |
+
" \"NOSE\",\n",
|
| 53 |
+
" \"LEFT_SHOULDER\",\n",
|
| 54 |
+
" \"RIGHT_SHOULDER\",\n",
|
| 55 |
+
" \"LEFT_HIP\",\n",
|
| 56 |
+
" \"RIGHT_HIP\",\n",
|
| 57 |
+
" \"LEFT_KNEE\",\n",
|
| 58 |
+
" \"RIGHT_KNEE\",\n",
|
| 59 |
+
" \"LEFT_ANKLE\",\n",
|
| 60 |
+
" \"RIGHT_ANKLE\",\n",
|
| 61 |
+
" \"LEFT_HEEL\",\n",
|
| 62 |
+
" \"RIGHT_HEEL\",\n",
|
| 63 |
+
" \"LEFT_FOOT_INDEX\",\n",
|
| 64 |
+
" \"RIGHT_FOOT_INDEX\",\n",
|
| 65 |
+
"]\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"# Generate all columns of the data frame\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"HEADERS = [\"label\"] # Label column\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"for lm in IMPORTANT_LMS:\n",
|
| 72 |
+
" HEADERS += [f\"{lm.lower()}_x\", f\"{lm.lower()}_y\", f\"{lm.lower()}_z\", f\"{lm.lower()}_v\"]"
|
| 73 |
+
]
|
| 74 |
+
},
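A quick arithmetic check on the header layout above: 13 landmarks × 4 values (x, y, z, visibility) plus the label column gives the 53 columns that `describe_dataset` reports below. A minimal sketch, assuming `HEADERS` and `IMPORTANT_LMS` from the cell above:

```python
# 13 landmarks * 4 features + 1 label column = 53
assert len(HEADERS) == len(IMPORTANT_LMS) * 4 + 1 == 53
```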
|
| 75 |
+
{
|
| 76 |
+
"cell_type": "code",
|
| 77 |
+
"execution_count": 4,
|
| 78 |
+
"metadata": {},
|
| 79 |
+
"outputs": [
|
| 80 |
+
{
|
| 81 |
+
"name": "stdout",
|
| 82 |
+
"output_type": "stream",
|
| 83 |
+
"text": [
|
| 84 |
+
"Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v', 'left_knee_x', 'left_knee_y', 'left_knee_z', 'left_knee_v', 'right_knee_x', 'right_knee_y', 'right_knee_z', 'right_knee_v', 'left_ankle_x', 'left_ankle_y', 'left_ankle_z', 'left_ankle_v', 'right_ankle_x', 'right_ankle_y', 'right_ankle_z', 'right_ankle_v', 'left_heel_x', 'left_heel_y', 'left_heel_z', 'left_heel_v', 'right_heel_x', 'right_heel_y', 'right_heel_z', 'right_heel_v', 'left_foot_index_x', 'left_foot_index_y', 'left_foot_index_z', 'left_foot_index_v', 'right_foot_index_x', 'right_foot_index_y', 'right_foot_index_z', 'right_foot_index_v']\n",
|
| 85 |
+
"Number of rows: 17040 \n",
|
| 86 |
+
"Number of columns: 53\n",
|
| 87 |
+
"\n",
|
| 88 |
+
"Labels: \n",
|
| 89 |
+
"M 6171\n",
|
| 90 |
+
"D 5735\n",
|
| 91 |
+
"I 5134\n",
|
| 92 |
+
"Name: label, dtype: int64\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"Missing values: False\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"Duplicate Rows : 0\n"
|
| 97 |
+
]
|
| 98 |
+
}
|
| 99 |
+
],
|
| 100 |
+
"source": [
|
| 101 |
+
"def describe_dataset(dataset_path: str):\n",
|
| 102 |
+
" '''\n",
|
| 103 |
+
" Describe dataset\n",
|
| 104 |
+
" '''\n",
|
| 105 |
+
"\n",
|
| 106 |
+
" data = pd.read_csv(dataset_path)\n",
|
| 107 |
+
" print(f\"Headers: {list(data.columns.values)}\")\n",
|
| 108 |
+
" print(f'Number of rows: {data.shape[0]} \\nNumber of columns: {data.shape[1]}\\n')\n",
|
| 109 |
+
" print(f\"Labels: \\n{data['label'].value_counts()}\\n\")\n",
|
| 110 |
+
" print(f\"Missing values: {data.isnull().values.any()}\\n\")\n",
|
| 111 |
+
" \n",
|
| 112 |
+
" duplicate = data[data.duplicated()]\n",
|
| 113 |
+
" print(f\"Duplicate Rows : {len(duplicate.sum(axis=1))}\")\n",
|
| 114 |
+
"\n",
|
| 115 |
+
" return data\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"# Remove duplicate rows (optional)\n",
|
| 119 |
+
"def remove_duplicate_rows(dataset_path: str):\n",
|
| 120 |
+
" '''\n",
|
| 121 |
+
" Remove duplicated data from the dataset then save it to another files\n",
|
| 122 |
+
" '''\n",
|
| 123 |
+
" \n",
|
| 124 |
+
" df = pd.read_csv(dataset_path)\n",
|
| 125 |
+
" df.drop_duplicates(keep=\"first\", inplace=True)\n",
|
| 126 |
+
" df.to_csv(f\"cleaned_dataset.csv\", sep=',', encoding='utf-8', index=False)\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"\n",
|
| 129 |
+
"df = describe_dataset(\"./dataset.csv\")"
|
| 130 |
+
]
|
| 131 |
+
},
|
| 132 |
+
{
|
| 133 |
+
"cell_type": "markdown",
|
| 134 |
+
"metadata": {},
|
| 135 |
+
"source": [
|
| 136 |
+
"### 1.2. Preprocess data"
|
| 137 |
+
]
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"cell_type": "code",
|
| 141 |
+
"execution_count": 7,
|
| 142 |
+
"metadata": {},
|
| 143 |
+
"outputs": [
|
| 144 |
+
{
|
| 145 |
+
"name": "stdout",
|
| 146 |
+
"output_type": "stream",
|
| 147 |
+
"text": [
|
| 148 |
+
"Number of rows: 17040 \n",
|
| 149 |
+
"Number of columns: 53\n",
|
| 150 |
+
"\n",
|
| 151 |
+
"Labels: \n",
|
| 152 |
+
"1 6171\n",
|
| 153 |
+
"2 5735\n",
|
| 154 |
+
"0 5134\n",
|
| 155 |
+
"Name: label, dtype: int64\n",
|
| 156 |
+
"\n"
|
| 157 |
+
]
|
| 158 |
+
}
|
| 159 |
+
],
|
| 160 |
+
"source": [
|
| 161 |
+
"# load dataset\n",
|
| 162 |
+
"df = pd.read_csv(\"./dataset.csv\")\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"# Categorizing label\n",
|
| 165 |
+
"df.loc[df[\"label\"] == \"I\", \"label\"] = 0\n",
|
| 166 |
+
"df.loc[df[\"label\"] == \"M\", \"label\"] = 1\n",
|
| 167 |
+
"df.loc[df[\"label\"] == \"D\", \"label\"] = 2\n",
|
| 168 |
+
"\n",
|
| 169 |
+
"print(f'Number of rows: {df.shape[0]} \\nNumber of columns: {df.shape[1]}\\n')\n",
|
| 170 |
+
"print(f\"Labels: \\n{df['label'].value_counts()}\\n\")"
|
| 171 |
+
]
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"cell_type": "code",
|
| 175 |
+
"execution_count": 8,
|
| 176 |
+
"metadata": {},
|
| 177 |
+
"outputs": [],
|
| 178 |
+
"source": [
|
| 179 |
+
"# Standard Scaling of features\n",
|
| 180 |
+
"with open(\"./model/input_scaler.pkl\", \"rb\") as f2:\n",
|
| 181 |
+
" input_scaler = pickle.load(f2)\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"x = df.drop(\"label\", axis = 1)\n",
|
| 184 |
+
"x = pd.DataFrame(input_scaler.transform(x))\n",
|
| 185 |
+
"\n",
|
| 186 |
+
"y = df[\"label\"]\n",
|
| 187 |
+
"\n",
|
| 188 |
+
"# # Converting prediction to categorical\n",
|
| 189 |
+
"y_cat = to_categorical(y)"
|
| 190 |
+
]
|
| 191 |
+
},
|
| 192 |
+
{
|
| 193 |
+
"cell_type": "code",
|
| 194 |
+
"execution_count": 6,
|
| 195 |
+
"metadata": {},
|
| 196 |
+
"outputs": [],
|
| 197 |
+
"source": [
|
| 198 |
+
"x_train, x_test, y_train, y_test = train_test_split(x.values, y_cat, test_size=0.2)"
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"cell_type": "markdown",
|
| 203 |
+
"metadata": {},
|
| 204 |
+
"source": [
|
| 205 |
+
"### 1.3. Train model"
|
| 206 |
+
]
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"cell_type": "code",
|
| 210 |
+
"execution_count": 7,
|
| 211 |
+
"metadata": {},
|
| 212 |
+
"outputs": [
|
| 213 |
+
{
|
| 214 |
+
"name": "stdout",
|
| 215 |
+
"output_type": "stream",
|
| 216 |
+
"text": [
|
| 217 |
+
"Metal device set to: Apple M1\n",
|
| 218 |
+
"\n",
|
| 219 |
+
"systemMemory: 16.00 GB\n",
|
| 220 |
+
"maxCacheSize: 5.33 GB\n",
|
| 221 |
+
"\n",
|
| 222 |
+
"Model: \"sequential\"\n",
|
| 223 |
+
"_________________________________________________________________\n",
|
| 224 |
+
" Layer (type) Output Shape Param # \n",
|
| 225 |
+
"=================================================================\n",
|
| 226 |
+
" dense (Dense) (None, 52) 2756 \n",
|
| 227 |
+
" \n",
|
| 228 |
+
" dropout (Dropout) (None, 52) 0 \n",
|
| 229 |
+
" \n",
|
| 230 |
+
" dense_1 (Dense) (None, 52) 2756 \n",
|
| 231 |
+
" \n",
|
| 232 |
+
" dropout_1 (Dropout) (None, 52) 0 \n",
|
| 233 |
+
" \n",
|
| 234 |
+
" dense_2 (Dense) (None, 14) 742 \n",
|
| 235 |
+
" \n",
|
| 236 |
+
" dense_3 (Dense) (None, 3) 45 \n",
|
| 237 |
+
" \n",
|
| 238 |
+
"=================================================================\n",
|
| 239 |
+
"Total params: 6,299\n",
|
| 240 |
+
"Trainable params: 6,299\n",
|
| 241 |
+
"Non-trainable params: 0\n",
|
| 242 |
+
"_________________________________________________________________\n"
|
| 243 |
+
]
|
| 244 |
+
},
|
| 245 |
+
{
|
| 246 |
+
"name": "stderr",
|
| 247 |
+
"output_type": "stream",
|
| 248 |
+
"text": [
|
| 249 |
+
"2022-11-14 10:38:36.786157: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.\n",
|
| 250 |
+
"2022-11-14 10:38:36.786920: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)\n"
|
| 251 |
+
]
|
| 252 |
+
}
|
| 253 |
+
],
|
| 254 |
+
"source": [
|
| 255 |
+
"model = Sequential()\n",
|
| 256 |
+
"model.add(Dense(52, input_dim = 52, activation = \"relu\"))\n",
|
| 257 |
+
"model.add(Dropout(0.5))\n",
|
| 258 |
+
"model.add(Dense(52, activation = \"relu\"))\n",
|
| 259 |
+
"model.add(Dropout(0.5))\n",
|
| 260 |
+
"model.add(Dense(14, activation = \"relu\"))\n",
|
| 261 |
+
"model.add(Dense(3, activation = \"softmax\"))\n",
|
| 262 |
+
"model.compile(Adam(lr = 0.01), \"categorical_crossentropy\", metrics = [\"accuracy\"])\n",
|
| 263 |
+
"model.summary()"
|
| 264 |
+
]
|
| 265 |
+
},
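The parameter counts in the summary above follow directly from `inputs × units + units` per Dense layer (the Dropout layers add none); a short sketch verifying the totals:

```python
# Dense parameters = inputs * units + units (bias terms)
dense   = 52 * 52 + 52  # 2756
dense_1 = 52 * 52 + 52  # 2756
dense_2 = 52 * 14 + 14  # 742
dense_3 = 14 * 3 + 3    # 45
assert dense + dense_1 + dense_2 + dense_3 == 6299  # "Total params: 6,299"
```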
|
| 266 |
+
{
|
| 267 |
+
"cell_type": "code",
|
| 268 |
+
"execution_count": 8,
|
| 269 |
+
"metadata": {},
|
| 270 |
+
"outputs": [
|
| 271 |
+
{
|
| 272 |
+
"name": "stdout",
|
| 273 |
+
"output_type": "stream",
|
| 274 |
+
"text": [
|
| 275 |
+
"Epoch 1/100\n"
|
| 276 |
+
]
|
| 277 |
+
},
|
| 278 |
+
{
|
| 279 |
+
"name": "stderr",
|
| 280 |
+
"output_type": "stream",
|
| 281 |
+
"text": [
|
| 282 |
+
"2022-11-14 10:38:48.571461: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n",
|
| 283 |
+
"2022-11-14 10:38:48.827690: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n"
|
| 284 |
+
]
|
| 285 |
+
},
|
| 286 |
+
{
|
| 287 |
+
"name": "stdout",
|
| 288 |
+
"output_type": "stream",
|
| 289 |
+
"text": [
|
| 290 |
+
"1364/1364 [==============================] - ETA: 0s - loss: 0.1082 - accuracy: 0.9696"
|
| 291 |
+
]
|
| 292 |
+
},
|
| 293 |
+
{
|
| 294 |
+
"name": "stderr",
|
| 295 |
+
"output_type": "stream",
|
| 296 |
+
"text": [
|
| 297 |
+
"2022-11-14 10:39:02.176190: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n"
|
| 298 |
+
]
|
| 299 |
+
},
|
| 300 |
+
{
|
| 301 |
+
"name": "stdout",
|
| 302 |
+
"output_type": "stream",
|
| 303 |
+
"text": [
|
| 304 |
+
"1364/1364 [==============================] - 15s 11ms/step - loss: 0.1082 - accuracy: 0.9696 - val_loss: 0.0339 - val_accuracy: 0.9947\n",
|
| 305 |
+
"Epoch 2/100\n",
|
| 306 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0599 - accuracy: 0.9872 - val_loss: 0.0595 - val_accuracy: 0.9815\n",
|
| 307 |
+
"Epoch 3/100\n",
|
| 308 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0527 - accuracy: 0.9885 - val_loss: 0.0530 - val_accuracy: 0.9944\n",
|
| 309 |
+
"Epoch 4/100\n",
|
| 310 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0450 - accuracy: 0.9905 - val_loss: 0.0356 - val_accuracy: 0.9938\n",
|
| 311 |
+
"Epoch 5/100\n",
|
| 312 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.1253 - accuracy: 0.9747 - val_loss: 0.0501 - val_accuracy: 0.9921\n",
|
| 313 |
+
"Epoch 6/100\n",
|
| 314 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0682 - accuracy: 0.9884 - val_loss: 0.0445 - val_accuracy: 0.9865\n",
|
| 315 |
+
"Epoch 7/100\n",
|
| 316 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0534 - accuracy: 0.9900 - val_loss: 0.0400 - val_accuracy: 0.9933\n",
|
| 317 |
+
"Epoch 8/100\n",
|
| 318 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0627 - accuracy: 0.9913 - val_loss: 0.0396 - val_accuracy: 0.9950\n",
|
| 319 |
+
"Epoch 9/100\n",
|
| 320 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0556 - accuracy: 0.9905 - val_loss: 0.0397 - val_accuracy: 0.9953\n",
|
| 321 |
+
"Epoch 10/100\n",
|
| 322 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0656 - accuracy: 0.9900 - val_loss: 0.0657 - val_accuracy: 0.9795\n",
|
| 323 |
+
"Epoch 11/100\n",
|
| 324 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0661 - accuracy: 0.9862 - val_loss: 0.0454 - val_accuracy: 0.9933\n",
|
| 325 |
+
"Epoch 12/100\n",
|
| 326 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0638 - accuracy: 0.9894 - val_loss: 0.0404 - val_accuracy: 0.9930\n",
|
| 327 |
+
"Epoch 13/100\n",
|
| 328 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0643 - accuracy: 0.9911 - val_loss: 0.1060 - val_accuracy: 0.9668\n",
|
| 329 |
+
"Epoch 14/100\n",
|
| 330 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.1065 - accuracy: 0.9883 - val_loss: 0.0645 - val_accuracy: 0.9844\n",
|
| 331 |
+
"Epoch 15/100\n",
|
| 332 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0914 - accuracy: 0.9858 - val_loss: 0.0930 - val_accuracy: 0.9683\n",
|
| 333 |
+
"Epoch 16/100\n",
|
| 334 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0463 - accuracy: 0.9916 - val_loss: 0.0557 - val_accuracy: 0.9886\n",
|
| 335 |
+
"Epoch 17/100\n",
|
| 336 |
+
"1364/1364 [==============================] - 14s 11ms/step - loss: 0.1247 - accuracy: 0.9880 - val_loss: 0.0620 - val_accuracy: 0.9886\n",
|
| 337 |
+
"Epoch 18/100\n",
|
| 338 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.1441 - accuracy: 0.9807 - val_loss: 0.0483 - val_accuracy: 0.9935\n",
|
| 339 |
+
"Epoch 19/100\n",
|
| 340 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.1044 - accuracy: 0.9886 - val_loss: 0.0375 - val_accuracy: 0.9944\n",
|
| 341 |
+
"Epoch 20/100\n",
|
| 342 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0604 - accuracy: 0.9900 - val_loss: 0.0382 - val_accuracy: 0.9933\n",
|
| 343 |
+
"Epoch 21/100\n",
|
| 344 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0758 - accuracy: 0.9916 - val_loss: 0.0461 - val_accuracy: 0.9933\n",
|
| 345 |
+
"Epoch 22/100\n",
|
| 346 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0413 - accuracy: 0.9916 - val_loss: 0.1001 - val_accuracy: 0.9704\n",
|
| 347 |
+
"Epoch 23/100\n",
|
| 348 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0492 - accuracy: 0.9934 - val_loss: 0.0482 - val_accuracy: 0.9959\n",
|
| 349 |
+
"Epoch 24/100\n",
|
| 350 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0689 - accuracy: 0.9897 - val_loss: 0.0374 - val_accuracy: 0.9950\n",
|
| 351 |
+
"Epoch 25/100\n",
|
| 352 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0663 - accuracy: 0.9911 - val_loss: 0.0341 - val_accuracy: 0.9944\n",
|
| 353 |
+
"Epoch 26/100\n",
|
| 354 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0983 - accuracy: 0.9904 - val_loss: 0.0470 - val_accuracy: 0.9933\n",
|
| 355 |
+
"Epoch 27/100\n",
|
| 356 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0699 - accuracy: 0.9906 - val_loss: 0.0455 - val_accuracy: 0.9921\n",
|
| 357 |
+
"Epoch 28/100\n",
|
| 358 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0521 - accuracy: 0.9908 - val_loss: 0.0438 - val_accuracy: 0.9915\n",
|
| 359 |
+
"Epoch 29/100\n",
|
| 360 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0594 - accuracy: 0.9904 - val_loss: 0.0490 - val_accuracy: 0.9930\n",
|
| 361 |
+
"Epoch 30/100\n",
|
| 362 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.1044 - accuracy: 0.9903 - val_loss: 0.0439 - val_accuracy: 0.9933\n",
|
| 363 |
+
"Epoch 31/100\n",
|
| 364 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0414 - accuracy: 0.9930 - val_loss: 0.0664 - val_accuracy: 0.9806\n",
|
| 365 |
+
"Epoch 32/100\n",
|
| 366 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0507 - accuracy: 0.9927 - val_loss: 0.1468 - val_accuracy: 0.9563\n",
|
| 367 |
+
"Epoch 33/100\n",
|
| 368 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0745 - accuracy: 0.9907 - val_loss: 0.0730 - val_accuracy: 0.9742\n",
|
| 369 |
+
"Epoch 34/100\n",
|
| 370 |
+
"1364/1364 [==============================] - 14s 11ms/step - loss: 0.0595 - accuracy: 0.9896 - val_loss: 0.0525 - val_accuracy: 0.9933\n",
|
| 371 |
+
"Epoch 35/100\n",
|
| 372 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0849 - accuracy: 0.9884 - val_loss: 0.0396 - val_accuracy: 0.9941\n",
|
| 373 |
+
"Epoch 36/100\n",
|
| 374 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0738 - accuracy: 0.9882 - val_loss: 0.0384 - val_accuracy: 0.9950\n",
|
| 375 |
+
"Epoch 37/100\n",
|
| 376 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.1039 - accuracy: 0.9892 - val_loss: 0.0417 - val_accuracy: 0.9941\n",
|
| 377 |
+
"Epoch 38/100\n",
|
| 378 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.1345 - accuracy: 0.9883 - val_loss: 0.0380 - val_accuracy: 0.9921\n",
|
| 379 |
+
"Epoch 39/100\n",
|
| 380 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0483 - accuracy: 0.9930 - val_loss: 0.0325 - val_accuracy: 0.9959\n",
|
| 381 |
+
"Epoch 40/100\n",
|
| 382 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0427 - accuracy: 0.9931 - val_loss: 0.0399 - val_accuracy: 0.9947\n",
|
| 383 |
+
"Epoch 41/100\n",
|
| 384 |
+
"1364/1364 [==============================] - 14s 11ms/step - loss: 0.0430 - accuracy: 0.9924 - val_loss: 0.0555 - val_accuracy: 0.9950\n",
|
| 385 |
+
"Epoch 42/100\n",
|
| 386 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0536 - accuracy: 0.9917 - val_loss: 0.0533 - val_accuracy: 0.9933\n",
|
| 387 |
+
"Epoch 43/100\n",
|
| 388 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.1049 - accuracy: 0.9912 - val_loss: 0.0374 - val_accuracy: 0.9935\n",
|
| 389 |
+
"Epoch 44/100\n",
|
| 390 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0831 - accuracy: 0.9930 - val_loss: 0.0339 - val_accuracy: 0.9950\n",
|
| 391 |
+
"Epoch 45/100\n",
|
| 392 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0529 - accuracy: 0.9920 - val_loss: 0.0365 - val_accuracy: 0.9953\n",
|
| 393 |
+
"Epoch 46/100\n",
|
| 394 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0634 - accuracy: 0.9909 - val_loss: 0.0412 - val_accuracy: 0.9947\n",
|
| 395 |
+
"Epoch 47/100\n",
|
| 396 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0626 - accuracy: 0.9917 - val_loss: 0.0390 - val_accuracy: 0.9956\n",
|
| 397 |
+
"Epoch 48/100\n",
|
| 398 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.1139 - accuracy: 0.9907 - val_loss: 0.0426 - val_accuracy: 0.9944\n",
|
| 399 |
+
"Epoch 49/100\n",
|
| 400 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0622 - accuracy: 0.9902 - val_loss: 0.0400 - val_accuracy: 0.9938\n",
|
| 401 |
+
"Epoch 50/100\n",
|
| 402 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0857 - accuracy: 0.9900 - val_loss: 0.0980 - val_accuracy: 0.9862\n",
|
| 403 |
+
"Epoch 51/100\n",
|
| 404 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.1315 - accuracy: 0.9900 - val_loss: 0.0488 - val_accuracy: 0.9944\n",
|
| 405 |
+
"Epoch 52/100\n",
|
| 406 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0508 - accuracy: 0.9925 - val_loss: 0.0433 - val_accuracy: 0.9944\n",
|
| 407 |
+
"Epoch 53/100\n",
|
| 408 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0558 - accuracy: 0.9919 - val_loss: 0.0418 - val_accuracy: 0.9938\n",
|
| 409 |
+
"Epoch 54/100\n",
|
| 410 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0391 - accuracy: 0.9937 - val_loss: 0.0398 - val_accuracy: 0.9947\n",
|
| 411 |
+
"Epoch 55/100\n",
|
| 412 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0447 - accuracy: 0.9938 - val_loss: 0.0511 - val_accuracy: 0.9953\n",
|
| 413 |
+
"Epoch 56/100\n",
|
| 414 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0377 - accuracy: 0.9937 - val_loss: 0.0454 - val_accuracy: 0.9933\n",
|
| 415 |
+
"Epoch 57/100\n",
|
| 416 |
+
"1364/1364 [==============================] - 15s 11ms/step - loss: 0.0467 - accuracy: 0.9927 - val_loss: 0.0542 - val_accuracy: 0.9912\n",
|
| 417 |
+
"Epoch 58/100\n",
|
| 418 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0450 - accuracy: 0.9925 - val_loss: 0.0441 - val_accuracy: 0.9938\n",
|
| 419 |
+
"Epoch 59/100\n",
|
| 420 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0541 - accuracy: 0.9938 - val_loss: 0.0586 - val_accuracy: 0.9839\n",
|
| 421 |
+
"Epoch 60/100\n",
|
| 422 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0363 - accuracy: 0.9936 - val_loss: 0.0418 - val_accuracy: 0.9944\n",
|
| 423 |
+
"Epoch 61/100\n",
|
| 424 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0612 - accuracy: 0.9919 - val_loss: 0.0478 - val_accuracy: 0.9918\n",
|
| 425 |
+
"Epoch 62/100\n",
|
| 426 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0545 - accuracy: 0.9922 - val_loss: 0.0776 - val_accuracy: 0.9850\n",
|
| 427 |
+
"Epoch 63/100\n",
|
| 428 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0663 - accuracy: 0.9910 - val_loss: 0.0660 - val_accuracy: 0.9933\n",
|
| 429 |
+
"Epoch 64/100\n",
|
| 430 |
+
"1364/1364 [==============================] - 16s 11ms/step - loss: 0.0597 - accuracy: 0.9920 - val_loss: 0.0792 - val_accuracy: 0.9850\n",
|
| 431 |
+
"Epoch 65/100\n",
|
| 432 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0395 - accuracy: 0.9927 - val_loss: 0.1932 - val_accuracy: 0.9935\n",
|
| 433 |
+
"Epoch 66/100\n",
|
| 434 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.1077 - accuracy: 0.9911 - val_loss: 0.0754 - val_accuracy: 0.9806\n",
|
| 435 |
+
"Epoch 67/100\n",
|
| 436 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0505 - accuracy: 0.9917 - val_loss: 0.1537 - val_accuracy: 0.9824\n",
|
| 437 |
+
"Epoch 68/100\n",
|
| 438 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.1343 - accuracy: 0.9897 - val_loss: 0.0956 - val_accuracy: 0.9921\n",
|
| 439 |
+
"Epoch 69/100\n",
|
| 440 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0766 - accuracy: 0.9871 - val_loss: 0.0841 - val_accuracy: 0.9880\n",
|
| 441 |
+
"Epoch 70/100\n",
|
| 442 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.1537 - accuracy: 0.9734 - val_loss: 0.0734 - val_accuracy: 0.9933\n",
|
| 443 |
+
"Epoch 71/100\n",
|
| 444 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0720 - accuracy: 0.9884 - val_loss: 0.0573 - val_accuracy: 0.9935\n",
|
| 445 |
+
"Epoch 72/100\n",
|
| 446 |
+
"1364/1364 [==============================] - 15s 11ms/step - loss: 0.0545 - accuracy: 0.9928 - val_loss: 0.0505 - val_accuracy: 0.9927\n",
|
| 447 |
+
"Epoch 73/100\n",
|
| 448 |
+
"1364/1364 [==============================] - 15s 11ms/step - loss: 0.0748 - accuracy: 0.9923 - val_loss: 0.0453 - val_accuracy: 0.9950\n",
|
| 449 |
+
"Epoch 74/100\n",
|
| 450 |
+
"1364/1364 [==============================] - 15s 11ms/step - loss: 0.0661 - accuracy: 0.9891 - val_loss: 0.0652 - val_accuracy: 0.9941\n",
|
| 451 |
+
"Epoch 75/100\n",
|
| 452 |
+
"1364/1364 [==============================] - 15s 11ms/step - loss: 0.0676 - accuracy: 0.9893 - val_loss: 0.0467 - val_accuracy: 0.9935\n",
|
| 453 |
+
"Epoch 76/100\n",
|
| 454 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0601 - accuracy: 0.9891 - val_loss: 0.1189 - val_accuracy: 0.9944\n",
|
| 455 |
+
"Epoch 77/100\n",
|
| 456 |
+
"1364/1364 [==============================] - 14s 11ms/step - loss: 0.0404 - accuracy: 0.9940 - val_loss: 0.0557 - val_accuracy: 0.9935\n",
|
| 457 |
+
"Epoch 78/100\n",
|
| 458 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0686 - accuracy: 0.9932 - val_loss: 0.0495 - val_accuracy: 0.9941\n",
|
| 459 |
+
"Epoch 79/100\n",
|
| 460 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0409 - accuracy: 0.9943 - val_loss: 0.0618 - val_accuracy: 0.9950\n",
|
| 461 |
+
"Epoch 80/100\n",
|
| 462 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0613 - accuracy: 0.9928 - val_loss: 0.0438 - val_accuracy: 0.9944\n",
|
| 463 |
+
"Epoch 81/100\n",
|
| 464 |
+
"1364/1364 [==============================] - 15s 11ms/step - loss: 0.0818 - accuracy: 0.9893 - val_loss: 0.0639 - val_accuracy: 0.9935\n",
|
| 465 |
+
"Epoch 82/100\n",
|
| 466 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0558 - accuracy: 0.9898 - val_loss: 0.0750 - val_accuracy: 0.9933\n",
|
| 467 |
+
"Epoch 83/100\n",
|
| 468 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0568 - accuracy: 0.9919 - val_loss: 0.0522 - val_accuracy: 0.9950\n",
|
| 469 |
+
"Epoch 84/100\n",
|
| 470 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0707 - accuracy: 0.9927 - val_loss: 0.0336 - val_accuracy: 0.9950\n",
|
| 471 |
+
"Epoch 85/100\n",
|
| 472 |
+
"1364/1364 [==============================] - 15s 11ms/step - loss: 0.0615 - accuracy: 0.9929 - val_loss: 0.0433 - val_accuracy: 0.9953\n",
|
| 473 |
+
"Epoch 86/100\n",
|
| 474 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0375 - accuracy: 0.9943 - val_loss: 0.1049 - val_accuracy: 0.9947\n",
|
| 475 |
+
"Epoch 87/100\n",
|
| 476 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0372 - accuracy: 0.9942 - val_loss: 0.0892 - val_accuracy: 0.9950\n",
|
| 477 |
+
"Epoch 88/100\n",
|
| 478 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0641 - accuracy: 0.9927 - val_loss: 0.0985 - val_accuracy: 0.9935\n",
|
| 479 |
+
"Epoch 89/100\n",
|
| 480 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0512 - accuracy: 0.9913 - val_loss: 0.0580 - val_accuracy: 0.9950\n",
|
| 481 |
+
"Epoch 90/100\n",
|
| 482 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0864 - accuracy: 0.9921 - val_loss: 0.1047 - val_accuracy: 0.9950\n",
|
| 483 |
+
"Epoch 91/100\n",
|
| 484 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0407 - accuracy: 0.9934 - val_loss: 0.0739 - val_accuracy: 0.9944\n",
|
| 485 |
+
"Epoch 92/100\n",
|
| 486 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0419 - accuracy: 0.9930 - val_loss: 0.0651 - val_accuracy: 0.9853\n",
|
| 487 |
+
"Epoch 93/100\n",
|
| 488 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0989 - accuracy: 0.9897 - val_loss: 0.0819 - val_accuracy: 0.9938\n",
|
| 489 |
+
"Epoch 94/100\n",
|
| 490 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0427 - accuracy: 0.9927 - val_loss: 0.0784 - val_accuracy: 0.9947\n",
|
| 491 |
+
"Epoch 95/100\n",
|
| 492 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0463 - accuracy: 0.9938 - val_loss: 0.0942 - val_accuracy: 0.9933\n",
|
| 493 |
+
"Epoch 96/100\n",
|
| 494 |
+
"1364/1364 [==============================] - 13s 10ms/step - loss: 0.0598 - accuracy: 0.9908 - val_loss: 0.1309 - val_accuracy: 0.9909\n",
|
| 495 |
+
"Epoch 97/100\n",
|
| 496 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0717 - accuracy: 0.9916 - val_loss: 0.0960 - val_accuracy: 0.9944\n",
|
| 497 |
+
"Epoch 98/100\n",
|
| 498 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0370 - accuracy: 0.9938 - val_loss: 0.1445 - val_accuracy: 0.9947\n",
|
| 499 |
+
"Epoch 99/100\n",
|
| 500 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0679 - accuracy: 0.9929 - val_loss: 0.1177 - val_accuracy: 0.9938\n",
|
| 501 |
+
"Epoch 100/100\n",
|
| 502 |
+
"1364/1364 [==============================] - 14s 10ms/step - loss: 0.0516 - accuracy: 0.9915 - val_loss: 0.1030 - val_accuracy: 0.9950\n"
|
| 503 |
+
]
|
| 504 |
+
},
|
| 505 |
+
{
|
| 506 |
+
"data": {
|
| 507 |
+
"text/plain": [
|
| 508 |
+
"<keras.callbacks.History at 0x15b5fc310>"
|
| 509 |
+
]
|
| 510 |
+
},
|
| 511 |
+
"execution_count": 8,
|
| 512 |
+
"metadata": {},
|
| 513 |
+
"output_type": "execute_result"
|
| 514 |
+
}
|
| 515 |
+
],
|
| 516 |
+
"source": [
|
| 517 |
+
"model.fit(x_train, y_train, epochs=100, batch_size=10, validation_data=(x_test, y_test))"
|
| 518 |
+
]
|
| 519 |
+
},
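The validation loss in the log above oscillates rather than improving steadily after the first few epochs, so the final weights are not necessarily the best ones seen. If this run were repeated, an early-stopping callback is one standard way to keep the best-validation-loss weights; this is a suggested variation, not part of the original notebook:

```python
from keras.callbacks import EarlyStopping

# Stop when val_loss has not improved for 10 epochs, keeping the best weights
early_stop = EarlyStopping(monitor="val_loss", patience=10,
                           restore_best_weights=True)
model.fit(x_train, y_train, epochs=100, batch_size=10,
          validation_data=(x_test, y_test), callbacks=[early_stop])
```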
|
| 520 |
+
{
|
| 521 |
+
"cell_type": "markdown",
|
| 522 |
+
"metadata": {},
|
| 523 |
+
"source": [
|
| 524 |
+
"## 2. Model Evaluation"
|
| 525 |
+
]
|
| 526 |
+
},
|
| 527 |
+
{
|
| 528 |
+
"cell_type": "markdown",
|
| 529 |
+
"metadata": {},
|
| 530 |
+
"source": [
|
| 531 |
+
"### 2.1. Train set evaluation"
|
| 532 |
+
]
|
| 533 |
+
},
|
| 534 |
+
{
|
| 535 |
+
"cell_type": "code",
|
| 536 |
+
"execution_count": 9,
|
| 537 |
+
"metadata": {},
|
| 538 |
+
"outputs": [
|
| 539 |
+
{
|
| 540 |
+
"name": "stdout",
|
| 541 |
+
"output_type": "stream",
|
| 542 |
+
"text": [
|
| 543 |
+
" 1/107 [..............................] - ETA: 30s"
|
| 544 |
+
]
|
| 545 |
+
},
|
| 546 |
+
{
|
| 547 |
+
"name": "stderr",
|
| 548 |
+
"output_type": "stream",
|
| 549 |
+
"text": [
|
| 550 |
+
"2022-11-14 11:04:43.450300: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n"
|
| 551 |
+
]
|
| 552 |
+
},
|
| 553 |
+
{
|
| 554 |
+
"name": "stdout",
|
| 555 |
+
"output_type": "stream",
|
| 556 |
+
"text": [
|
| 557 |
+
"107/107 [==============================] - 1s 3ms/step\n",
|
| 558 |
+
"107/107 [==============================] - 0s 2ms/step\n"
|
| 559 |
+
]
|
| 560 |
+
},
|
| 561 |
+
{
|
| 562 |
+
"data": {
|
| 563 |
+
"text/plain": [
|
| 564 |
+
"array([[1037, 7, 1],\n",
|
| 565 |
+
" [ 0, 1208, 3],\n",
|
| 566 |
+
" [ 0, 6, 1146]])"
|
| 567 |
+
]
|
| 568 |
+
},
|
| 569 |
+
"execution_count": 9,
|
| 570 |
+
"metadata": {},
|
| 571 |
+
"output_type": "execute_result"
|
| 572 |
+
}
|
| 573 |
+
],
|
| 574 |
+
"source": [
|
| 575 |
+
"predict_x = model.predict(x_test) \n",
|
| 576 |
+
"y_pred_class = np.argmax(predict_x, axis=1)\n",
|
| 577 |
+
"\n",
|
| 578 |
+
"y_pred = model.predict(x_test)\n",
|
| 579 |
+
"y_test_class = np.argmax(y_test, axis=1)\n",
|
| 580 |
+
"\n",
|
| 581 |
+
"confusion_matrix(y_test_class, y_pred_class)"
|
| 582 |
+
]
|
| 583 |
+
},
|
| 584 |
+
{
|
| 585 |
+
"cell_type": "code",
|
| 586 |
+
"execution_count": 10,
|
| 587 |
+
"metadata": {},
|
| 588 |
+
"outputs": [
|
| 589 |
+
{
|
| 590 |
+
"name": "stdout",
|
| 591 |
+
"output_type": "stream",
|
| 592 |
+
"text": [
|
| 593 |
+
" precision recall f1-score support\n",
|
| 594 |
+
"\n",
|
| 595 |
+
" 0 1.00 0.99 1.00 1045\n",
|
| 596 |
+
" 1 0.99 1.00 0.99 1211\n",
|
| 597 |
+
" 2 1.00 0.99 1.00 1152\n",
|
| 598 |
+
"\n",
|
| 599 |
+
" accuracy 1.00 3408\n",
|
| 600 |
+
" macro avg 1.00 0.99 1.00 3408\n",
|
| 601 |
+
"weighted avg 1.00 1.00 1.00 3408\n",
|
| 602 |
+
"\n"
|
| 603 |
+
]
|
| 604 |
+
}
|
| 605 |
+
],
|
| 606 |
+
"source": [
|
| 607 |
+
"print(classification_report(y_test_class, y_pred_class))"
|
| 608 |
+
]
|
| 609 |
+
},
|
| 610 |
+
{
|
| 611 |
+
"cell_type": "markdown",
|
| 612 |
+
"metadata": {},
|
| 613 |
+
"source": [
|
| 614 |
+
"### 2.2. Test set evaluation"
|
| 615 |
+
]
|
| 616 |
+
},
|
| 617 |
+
{
|
| 618 |
+
"cell_type": "code",
|
| 619 |
+
"execution_count": 5,
|
| 620 |
+
"metadata": {},
|
| 621 |
+
"outputs": [],
|
| 622 |
+
"source": [
|
| 623 |
+
"test_df = pd.read_csv(\"./test.csv\")\n",
|
| 624 |
+
"\n",
|
| 625 |
+
"# Categorizing label\n",
|
| 626 |
+
"test_df.loc[test_df[\"label\"] == \"I\", \"label\"] = 0\n",
|
| 627 |
+
"test_df.loc[test_df[\"label\"] == \"M\", \"label\"] = 1\n",
|
| 628 |
+
"test_df.loc[test_df[\"label\"] == \"D\", \"label\"] = 2"
|
| 629 |
+
]
|
| 630 |
+
},
|
| 631 |
+
{
|
| 632 |
+
"cell_type": "code",
|
| 633 |
+
"execution_count": 13,
|
| 634 |
+
"metadata": {},
|
| 635 |
+
"outputs": [],
|
| 636 |
+
"source": [
|
| 637 |
+
"# Standard Scaling of features\n",
|
| 638 |
+
"test_x = test_df.drop(\"label\", axis = 1)\n",
|
| 639 |
+
"test_x = pd.DataFrame(input_scaler.transform(test_x))\n",
|
| 640 |
+
"\n",
|
| 641 |
+
"test_y = test_df[\"label\"]\n",
|
| 642 |
+
"\n",
|
| 643 |
+
"# # Converting prediction to categorical\n",
|
| 644 |
+
"test_y_cat = to_categorical(test_y)"
|
| 645 |
+
]
|
| 646 |
+
},
|
| 647 |
+
{
|
| 648 |
+
"cell_type": "code",
|
| 649 |
+
"execution_count": 16,
|
| 650 |
+
"metadata": {},
|
| 651 |
+
"outputs": [
|
| 652 |
+
{
|
| 653 |
+
"name": "stdout",
|
| 654 |
+
"output_type": "stream",
|
| 655 |
+
"text": [
|
| 656 |
+
"26/26 [==============================] - 0s 2ms/step\n"
|
| 657 |
+
]
|
| 658 |
+
},
|
| 659 |
+
{
|
| 660 |
+
"data": {
|
| 661 |
+
"text/plain": [
|
| 662 |
+
"array([[267, 0, 0],\n",
|
| 663 |
+
" [ 0, 263, 0],\n",
|
| 664 |
+
" [ 0, 0, 299]])"
|
| 665 |
+
]
|
| 666 |
+
},
|
| 667 |
+
"execution_count": 16,
|
| 668 |
+
"metadata": {},
|
| 669 |
+
"output_type": "execute_result"
|
| 670 |
+
}
|
| 671 |
+
],
|
| 672 |
+
"source": [
|
| 673 |
+
"predict_x = model.predict(test_x) \n",
|
| 674 |
+
"y_pred_class = np.argmax(predict_x, axis=1)\n",
|
| 675 |
+
"y_test_class = np.argmax(test_y_cat, axis=1)\n",
|
| 676 |
+
"\n",
|
| 677 |
+
"confusion_matrix(y_test_class, y_pred_class)"
|
| 678 |
+
]
|
| 679 |
+
},
|
| 680 |
+
{
|
| 681 |
+
"cell_type": "code",
|
| 682 |
+
"execution_count": 17,
|
| 683 |
+
"metadata": {},
|
| 684 |
+
"outputs": [
|
| 685 |
+
{
|
| 686 |
+
"name": "stdout",
|
| 687 |
+
"output_type": "stream",
|
| 688 |
+
"text": [
|
| 689 |
+
" precision recall f1-score support\n",
|
| 690 |
+
"\n",
|
| 691 |
+
" 0 1.00 1.00 1.00 267\n",
|
| 692 |
+
" 1 1.00 1.00 1.00 263\n",
|
| 693 |
+
" 2 1.00 1.00 1.00 299\n",
|
| 694 |
+
"\n",
|
| 695 |
+
" accuracy 1.00 829\n",
|
| 696 |
+
" macro avg 1.00 1.00 1.00 829\n",
|
| 697 |
+
"weighted avg 1.00 1.00 1.00 829\n",
|
| 698 |
+
"\n"
|
| 699 |
+
]
|
| 700 |
+
}
|
| 701 |
+
],
|
| 702 |
+
"source": [
|
| 703 |
+
"print(classification_report(y_test_class, y_pred_class))"
|
| 704 |
+
]
|
| 705 |
+
},
|
| 706 |
+
{
|
| 707 |
+
"cell_type": "markdown",
|
| 708 |
+
"metadata": {},
|
| 709 |
+
"source": [
|
| 710 |
+
"## 3. Dump Model"
|
| 711 |
+
]
|
| 712 |
+
},
|
| 713 |
+
{
|
| 714 |
+
"cell_type": "code",
|
| 715 |
+
"execution_count": 11,
|
| 716 |
+
"metadata": {},
|
| 717 |
+
"outputs": [
|
| 718 |
+
{
|
| 719 |
+
"name": "stdout",
|
| 720 |
+
"output_type": "stream",
|
| 721 |
+
"text": [
|
| 722 |
+
"INFO:tensorflow:Assets written to: ram://6d931dd9-8715-41d1-81d4-8f3b853c1109/assets\n"
|
| 723 |
+
]
|
| 724 |
+
}
|
| 725 |
+
],
|
| 726 |
+
"source": [
|
| 727 |
+
"# Dump the best model to a pickle file\n",
|
| 728 |
+
"with open(\"./model/lunge_model_deep_learning.pkl\", \"wb\") as f:\n",
|
| 729 |
+
" pickle.dump(model, f)"
|
| 730 |
+
]
|
| 731 |
+
},
|
| 732 |
+
{
|
| 733 |
+
"cell_type": "code",
|
| 734 |
+
"execution_count": null,
|
| 735 |
+
"metadata": {},
|
| 736 |
+
"outputs": [],
|
| 737 |
+
"source": []
|
| 738 |
+
}
|
| 739 |
+
],
|
| 740 |
+
"metadata": {
|
| 741 |
+
"kernelspec": {
|
| 742 |
+
"display_name": "Python 3.8.13 (conda)",
|
| 743 |
+
"language": "python",
|
| 744 |
+
"name": "python3"
|
| 745 |
+
},
|
| 746 |
+
"language_info": {
|
| 747 |
+
"codemirror_mode": {
|
| 748 |
+
"name": "ipython",
|
| 749 |
+
"version": 3
|
| 750 |
+
},
|
| 751 |
+
"file_extension": ".py",
|
| 752 |
+
"mimetype": "text/x-python",
|
| 753 |
+
"name": "python",
|
| 754 |
+
"nbconvert_exporter": "python",
|
| 755 |
+
"pygments_lexer": "ipython3",
|
| 756 |
+
"version": "3.8.13"
|
| 757 |
+
},
|
| 758 |
+
"orig_nbformat": 4,
|
| 759 |
+
"vscode": {
|
| 760 |
+
"interpreter": {
|
| 761 |
+
"hash": "9260f401923fb5c4108c543a7d176de9733d378b3752e49535ad7c43c2271b65"
|
| 762 |
+
}
|
| 763 |
+
}
|
| 764 |
+
},
|
| 765 |
+
"nbformat": 4,
|
| 766 |
+
"nbformat_minor": 2
|
| 767 |
+
}
|
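The notebook above pickles the whole Keras model. A minimal reload sketch; note that pickled Keras models are sensitive to the Keras/TensorFlow version, and `model.save(...)` with `keras.models.load_model(...)` is the more portable alternative. `scaled_row` is a hypothetical, already-scaled 52-feature vector, not a name from the original code:

```python
import pickle
import numpy as np

with open("./model/lunge_model_deep_learning.pkl", "rb") as f:
    dl_model = pickle.load(f)

# scaled_row: hypothetical 52-feature vector already passed through input_scaler
probs = dl_model.predict(np.array([scaled_row]))
stage = int(np.argmax(probs, axis=1)[0])  # 0 = "I", 1 = "M", 2 = "D"
```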
core/lunge_model/4.stage.detection.ipynb
ADDED
@@ -0,0 +1,717 @@
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [
|
| 8 |
+
{
|
| 9 |
+
"name": "stderr",
|
| 10 |
+
"output_type": "stream",
|
| 11 |
+
"text": [
|
| 12 |
+
"objc[95030]: Class CaptureDelegate is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_videoio.3.4.16.dylib (0x111288860) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x167476480). One of the two will be used. Which one is undefined.\n",
|
| 13 |
+
"objc[95030]: Class CVWindow is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x110c50a68) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x1674764d0). One of the two will be used. Which one is undefined.\n",
|
| 14 |
+
"objc[95030]: Class CVView is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x110c50a90) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x1674764f8). One of the two will be used. Which one is undefined.\n",
|
| 15 |
+
"objc[95030]: Class CVSlider is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x110c50ab8) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x167476520). One of the two will be used. Which one is undefined.\n"
|
| 16 |
+
]
|
| 17 |
+
}
|
| 18 |
+
],
|
| 19 |
+
"source": [
|
| 20 |
+
"import mediapipe as mp\n",
|
| 21 |
+
"import cv2\n",
|
| 22 |
+
"import numpy as np\n",
|
| 23 |
+
"import pandas as pd\n",
|
| 24 |
+
"import traceback\n",
|
| 25 |
+
"import pickle\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"import warnings\n",
|
| 28 |
+
"warnings.filterwarnings('ignore')\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"# Drawing helpers\n",
|
| 31 |
+
"mp_drawing = mp.solutions.drawing_utils\n",
|
| 32 |
+
"mp_pose = mp.solutions.pose"
|
| 33 |
+
]
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"cell_type": "markdown",
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"source": [
|
| 39 |
+
"### 1. Reconstruct input structure"
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": 2,
|
| 45 |
+
"metadata": {},
|
| 46 |
+
"outputs": [],
|
| 47 |
+
"source": [
|
| 48 |
+
"# Determine important landmarks for lunge\n",
|
| 49 |
+
"IMPORTANT_LMS = [\n",
|
| 50 |
+
" \"NOSE\",\n",
|
| 51 |
+
" \"LEFT_SHOULDER\",\n",
|
| 52 |
+
" \"RIGHT_SHOULDER\",\n",
|
| 53 |
+
" \"LEFT_HIP\",\n",
|
| 54 |
+
" \"RIGHT_HIP\",\n",
|
| 55 |
+
" \"LEFT_KNEE\",\n",
|
| 56 |
+
" \"RIGHT_KNEE\",\n",
|
| 57 |
+
" \"LEFT_ANKLE\",\n",
|
| 58 |
+
" \"RIGHT_ANKLE\",\n",
|
| 59 |
+
" \"LEFT_HEEL\",\n",
|
| 60 |
+
" \"RIGHT_HEEL\",\n",
|
| 61 |
+
" \"LEFT_FOOT_INDEX\",\n",
|
| 62 |
+
" \"RIGHT_FOOT_INDEX\",\n",
|
| 63 |
+
"]\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"# Generate all columns of the data frame\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"HEADERS = [\"label\"] # Label column\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"for lm in IMPORTANT_LMS:\n",
|
| 70 |
+
" HEADERS += [f\"{lm.lower()}_x\", f\"{lm.lower()}_y\", f\"{lm.lower()}_z\", f\"{lm.lower()}_v\"]"
|
| 71 |
+
]
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"cell_type": "markdown",
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"source": [
|
| 77 |
+
"### 2. Set up important functions"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": 3,
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"outputs": [
|
| 85 |
+
{
|
| 86 |
+
"data": {
|
| 87 |
+
"text/plain": [
|
| 88 |
+
"Ellipsis"
|
| 89 |
+
]
|
| 90 |
+
},
|
| 91 |
+
"execution_count": 3,
|
| 92 |
+
"metadata": {},
|
| 93 |
+
"output_type": "execute_result"
|
| 94 |
+
}
|
| 95 |
+
],
|
| 96 |
+
"source": [
|
| 97 |
+
"def extract_important_keypoints(results) -> list:\n",
|
| 98 |
+
" '''\n",
|
| 99 |
+
" Extract important keypoints from mediapipe pose detection\n",
|
| 100 |
+
" '''\n",
|
| 101 |
+
" landmarks = results.pose_landmarks.landmark\n",
|
| 102 |
+
"\n",
|
| 103 |
+
" data = []\n",
|
| 104 |
+
" for lm in IMPORTANT_LMS:\n",
|
| 105 |
+
" keypoint = landmarks[mp_pose.PoseLandmark[lm].value]\n",
|
| 106 |
+
" data.append([keypoint.x, keypoint.y, keypoint.z, keypoint.visibility])\n",
|
| 107 |
+
" \n",
|
| 108 |
+
" return np.array(data).flatten().tolist()\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"def rescale_frame(frame, percent=50):\n",
|
| 112 |
+
" '''\n",
|
| 113 |
+
" Rescale a frame to a certain percentage compare to its original frame\n",
|
| 114 |
+
" '''\n",
|
| 115 |
+
" width = int(frame.shape[1] * percent/ 100)\n",
|
| 116 |
+
" height = int(frame.shape[0] * percent/ 100)\n",
|
| 117 |
+
" dim = (width, height)\n",
|
| 118 |
+
" return cv2.resize(frame, dim, interpolation =cv2.INTER_AREA)\n",
|
| 119 |
+
"\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"def calculate_angle(point1: list, point2: list, point3: list) -> float:\n",
|
| 122 |
+
" '''\n",
|
| 123 |
+
" Calculate the angle between 3 points\n",
|
| 124 |
+
" Unit of the angle will be in Degree\n",
|
| 125 |
+
" '''\n",
|
| 126 |
+
" point1 = np.array(point1)\n",
|
| 127 |
+
" point2 = np.array(point2)\n",
|
| 128 |
+
" point3 = np.array(point3)\n",
|
| 129 |
+
"\n",
|
| 130 |
+
" # Calculate algo\n",
|
| 131 |
+
" angleInRad = np.arctan2(point3[1] - point2[1], point3[0] - point2[0]) - np.arctan2(point1[1] - point2[1], point1[0] - point2[0])\n",
|
| 132 |
+
" angleInDeg = np.abs(angleInRad * 180.0 / np.pi)\n",
|
| 133 |
+
"\n",
|
| 134 |
+
" angleInDeg = angleInDeg if angleInDeg <= 180 else 360 - angleInDeg\n",
|
| 135 |
+
" return angleInDeg\n",
|
| 136 |
+
" \n",
|
| 137 |
+
"\n",
|
| 138 |
+
"def analyze_knee_angle(\n",
|
| 139 |
+
" mp_results, stage: str, angle_thresholds: list, draw_to_image: tuple = None\n",
|
| 140 |
+
"):\n",
|
| 141 |
+
" \"\"\"\n",
|
| 142 |
+
" Calculate angle of each knee while performer at the DOWN position\n",
|
| 143 |
+
"\n",
|
| 144 |
+
" Return result explanation:\n",
|
| 145 |
+
" error: True if at least 1 error\n",
|
| 146 |
+
" right\n",
|
| 147 |
+
" error: True if an error is on the right knee\n",
|
| 148 |
+
" angle: Right knee angle\n",
|
| 149 |
+
" left\n",
|
| 150 |
+
" error: True if an error is on the left knee\n",
|
| 151 |
+
" angle: Left knee angle\n",
|
| 152 |
+
" \"\"\"\n",
|
| 153 |
+
" results = {\n",
|
| 154 |
+
" \"error\": None,\n",
|
| 155 |
+
" \"right\": {\"error\": None, \"angle\": None},\n",
|
| 156 |
+
" \"left\": {\"error\": None, \"angle\": None},\n",
|
| 157 |
+
" }\n",
|
| 158 |
+
"\n",
|
| 159 |
+
" landmarks = mp_results.pose_landmarks.landmark\n",
|
| 160 |
+
"\n",
|
| 161 |
+
" # Calculate right knee angle\n",
|
| 162 |
+
" right_hip = [\n",
|
| 163 |
+
" landmarks[mp_pose.PoseLandmark.RIGHT_HIP.value].x,\n",
|
| 164 |
+
" landmarks[mp_pose.PoseLandmark.RIGHT_HIP.value].y,\n",
|
| 165 |
+
" ]\n",
|
| 166 |
+
" right_knee = [\n",
|
| 167 |
+
" landmarks[mp_pose.PoseLandmark.RIGHT_KNEE.value].x,\n",
|
| 168 |
+
" landmarks[mp_pose.PoseLandmark.RIGHT_KNEE.value].y,\n",
|
| 169 |
+
" ]\n",
|
| 170 |
+
" right_ankle = [\n",
|
| 171 |
+
" landmarks[mp_pose.PoseLandmark.RIGHT_ANKLE.value].x,\n",
|
| 172 |
+
" landmarks[mp_pose.PoseLandmark.RIGHT_ANKLE.value].y,\n",
|
| 173 |
+
" ]\n",
|
| 174 |
+
" results[\"right\"][\"angle\"] = calculate_angle(right_hip, right_knee, right_ankle)\n",
|
| 175 |
+
"\n",
|
| 176 |
+
" # Calculate left knee angle\n",
|
| 177 |
+
" left_hip = [\n",
|
| 178 |
+
" landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].x,\n",
|
| 179 |
+
" landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].y,\n",
|
| 180 |
+
" ]\n",
|
| 181 |
+
" left_knee = [\n",
|
| 182 |
+
" landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].x,\n",
|
| 183 |
+
" landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].y,\n",
|
| 184 |
+
" ]\n",
|
| 185 |
+
" left_ankle = [\n",
|
| 186 |
+
" landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].x,\n",
|
| 187 |
+
" landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].y,\n",
|
| 188 |
+
" ]\n",
|
| 189 |
+
" results[\"left\"][\"angle\"] = calculate_angle(left_hip, left_knee, left_ankle)\n",
|
| 190 |
+
"\n",
|
| 191 |
+
" # Draw to image\n",
|
| 192 |
+
" if draw_to_image is not None and stage != \"down\":\n",
|
| 193 |
+
" (image, video_dimensions) = draw_to_image\n",
|
| 194 |
+
"\n",
|
| 195 |
+
" # Visualize angles\n",
|
| 196 |
+
" cv2.putText(\n",
|
| 197 |
+
" image,\n",
|
| 198 |
+
" str(int(results[\"right\"][\"angle\"])),\n",
|
| 199 |
+
" tuple(np.multiply(right_knee, video_dimensions).astype(int)),\n",
|
| 200 |
+
" cv2.FONT_HERSHEY_COMPLEX,\n",
|
| 201 |
+
" 0.5,\n",
|
| 202 |
+
" (255, 255, 255),\n",
|
| 203 |
+
" 1,\n",
|
| 204 |
+
" cv2.LINE_AA,\n",
|
| 205 |
+
" )\n",
|
| 206 |
+
" cv2.putText(\n",
|
| 207 |
+
" image,\n",
|
| 208 |
+
" str(int(results[\"left\"][\"angle\"])),\n",
|
| 209 |
+
" tuple(np.multiply(left_knee, video_dimensions).astype(int)),\n",
|
| 210 |
+
" cv2.FONT_HERSHEY_COMPLEX,\n",
|
| 211 |
+
" 0.5,\n",
|
| 212 |
+
" (255, 255, 255),\n",
|
| 213 |
+
" 1,\n",
|
| 214 |
+
" cv2.LINE_AA,\n",
|
| 215 |
+
" )\n",
|
| 216 |
+
"\n",
|
| 217 |
+
" if stage != \"down\":\n",
|
| 218 |
+
" return results\n",
|
| 219 |
+
"\n",
|
| 220 |
+
" # Evaluation\n",
|
| 221 |
+
" results[\"error\"] = False\n",
|
| 222 |
+
"\n",
|
| 223 |
+
" if angle_thresholds[0] <= results[\"right\"][\"angle\"] <= angle_thresholds[1]:\n",
|
| 224 |
+
" results[\"right\"][\"error\"] = False\n",
|
| 225 |
+
" else:\n",
|
| 226 |
+
" results[\"right\"][\"error\"] = True\n",
|
| 227 |
+
" results[\"error\"] = True\n",
|
| 228 |
+
"\n",
|
| 229 |
+
" if angle_thresholds[0] <= results[\"left\"][\"angle\"] <= angle_thresholds[1]:\n",
|
| 230 |
+
" results[\"left\"][\"error\"] = False\n",
|
| 231 |
+
" else:\n",
|
| 232 |
+
" results[\"left\"][\"error\"] = True\n",
|
| 233 |
+
" results[\"error\"] = True\n",
|
| 234 |
+
"\n",
|
| 235 |
+
" # Draw to image\n",
|
| 236 |
+
" if draw_to_image is not None:\n",
|
| 237 |
+
" (image, video_dimensions) = draw_to_image\n",
|
| 238 |
+
"\n",
|
| 239 |
+
" right_color = (255, 255, 255) if not results[\"right\"][\"error\"] else (0, 0, 255)\n",
|
| 240 |
+
" left_color = (255, 255, 255) if not results[\"left\"][\"error\"] else (0, 0, 255)\n",
|
| 241 |
+
"\n",
|
| 242 |
+
" right_font_scale = 0.5 if not results[\"right\"][\"error\"] else 1\n",
|
| 243 |
+
" left_font_scale = 0.5 if not results[\"left\"][\"error\"] else 1\n",
|
| 244 |
+
"\n",
|
| 245 |
+
" right_thickness = 1 if not results[\"right\"][\"error\"] else 2\n",
|
| 246 |
+
" left_thickness = 1 if not results[\"left\"][\"error\"] else 2\n",
|
| 247 |
+
"\n",
|
| 248 |
+
" # Visualize angles\n",
|
| 249 |
+
" cv2.putText(\n",
|
| 250 |
+
" image,\n",
|
| 251 |
+
" str(int(results[\"right\"][\"angle\"])),\n",
|
| 252 |
+
" tuple(np.multiply(right_knee, video_dimensions).astype(int)),\n",
|
| 253 |
+
" cv2.FONT_HERSHEY_COMPLEX,\n",
|
| 254 |
+
" right_font_scale,\n",
|
| 255 |
+
" right_color,\n",
|
| 256 |
+
" right_thickness,\n",
|
| 257 |
+
" cv2.LINE_AA,\n",
|
| 258 |
+
" )\n",
|
| 259 |
+
" cv2.putText(\n",
|
| 260 |
+
" image,\n",
|
| 261 |
+
" str(int(results[\"left\"][\"angle\"])),\n",
|
| 262 |
+
" tuple(np.multiply(left_knee, video_dimensions).astype(int)),\n",
|
| 263 |
+
" cv2.FONT_HERSHEY_COMPLEX,\n",
|
| 264 |
+
" left_font_scale,\n",
|
| 265 |
+
" left_color,\n",
|
| 266 |
+
" left_thickness,\n",
|
| 267 |
+
" cv2.LINE_AA,\n",
|
| 268 |
+
" )\n",
|
| 269 |
+
"\n",
|
| 270 |
+
" return results\n",
|
| 271 |
+
"\n",
|
| 272 |
+
"\n"
|
| 273 |
+
]
|
| 274 |
+
},
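A quick sanity check on the arctan2-based `calculate_angle` defined in the cell above helps confirm the geometry. The coordinates below are made-up illustration points, not MediaPipe output (MediaPipe's y axis grows downward, but that does not change the magnitude of the angle):

import numpy as np

def calculate_angle(point1, point2, point3):
    # Same formula as in the notebook cell above
    point1, point2, point3 = np.array(point1), np.array(point2), np.array(point3)
    angle_in_rad = np.arctan2(point3[1] - point2[1], point3[0] - point2[0]) \
        - np.arctan2(point1[1] - point2[1], point1[0] - point2[0])
    angle_in_deg = np.abs(angle_in_rad * 180.0 / np.pi)
    return angle_in_deg if angle_in_deg <= 180 else 360 - angle_in_deg

print(calculate_angle([0, 0], [0, 1], [1, 1]))  # hip above knee, ankle to the side -> 90.0
print(calculate_angle([0, 0], [0, 1], [0, 2]))  # hip, knee, ankle collinear -> 180.0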
|
| 275 |
+
{
|
| 276 |
+
"cell_type": "code",
|
| 277 |
+
"execution_count": 4,
|
| 278 |
+
"metadata": {},
|
| 279 |
+
"outputs": [],
|
| 280 |
+
"source": [
|
| 281 |
+
"VIDEO_PATH1 = \"../data/lunge/lunge_test.mp4\""
|
| 282 |
+
]
|
| 283 |
+
},
|
| 284 |
+
{
|
| 285 |
+
"cell_type": "code",
|
| 286 |
+
"execution_count": 5,
|
| 287 |
+
"metadata": {},
|
| 288 |
+
"outputs": [],
|
| 289 |
+
"source": [
|
| 290 |
+
"with open(\"./model/input_scaler.pkl\", \"rb\") as f:\n",
|
| 291 |
+
" input_scaler = pickle.load(f)"
|
| 292 |
+
]
|
| 293 |
+
},
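`extract_important_keypoints` and `HEADERS` are defined in an earlier cell of this notebook that is not part of this hunk. For orientation, here is a minimal sketch of what such a helper plausibly looks like, assuming the same flattened x/y/z/visibility layout that the training columns use (compare the HEADERS generation in 5.err.data.ipynb below); `IMPORTANT_LMS` and `mp_pose` are assumed from those earlier cells:

import numpy as np

def extract_important_keypoints(results, important_landmarks=IMPORTANT_LMS):
    """Flatten the selected pose landmarks into one feature row: x, y, z, visibility each."""
    landmarks = results.pose_landmarks.landmark
    row = []
    for lm in important_landmarks:
        keypoint = landmarks[mp_pose.PoseLandmark[lm].value]
        row += [keypoint.x, keypoint.y, keypoint.z, keypoint.visibility]
    return np.array(row)  # shape: (len(important_landmarks) * 4,)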
|
| 294 |
+
{
|
| 295 |
+
"cell_type": "markdown",
|
| 296 |
+
"metadata": {},
|
| 297 |
+
"source": [
|
| 298 |
+
"### 3. Detection with Sklearn model"
|
| 299 |
+
]
|
| 300 |
+
},
|
| 301 |
+
{
|
| 302 |
+
"cell_type": "code",
|
| 303 |
+
"execution_count": 6,
|
| 304 |
+
"metadata": {},
|
| 305 |
+
"outputs": [],
|
| 306 |
+
"source": [
|
| 307 |
+
"# Load model\n",
|
| 308 |
+
"with open(\"./model/KNN_model.pkl\", \"rb\") as f:\n",
|
| 309 |
+
" sklearn_model = pickle.load(f)"
|
| 310 |
+
]
|
| 311 |
+
},
|
| 312 |
+
{
|
| 313 |
+
"cell_type": "code",
|
| 314 |
+
"execution_count": 8,
|
| 315 |
+
"metadata": {},
|
| 316 |
+
"outputs": [],
|
| 317 |
+
"source": [
|
| 318 |
+
"cap = cv2.VideoCapture(VIDEO_PATH1)\n",
|
| 319 |
+
"current_stage = \"\"\n",
|
| 320 |
+
"counter = 0\n",
|
| 321 |
+
"prediction_probability_threshold = 0.8\n",
|
| 322 |
+
"ANGLE_THRESHOLDS = [60, 135]\n",
|
| 323 |
+
"\n",
|
| 324 |
+
"with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:\n",
|
| 325 |
+
" while cap.isOpened():\n",
|
| 326 |
+
" ret, image = cap.read()\n",
|
| 327 |
+
"\n",
|
| 328 |
+
" if not ret:\n",
|
| 329 |
+
" break\n",
|
| 330 |
+
"\n",
|
| 331 |
+
" # Reduce size of a frame\n",
|
| 332 |
+
" image = rescale_frame(image, 50)\n",
|
| 333 |
+
" video_dimensions = [image.shape[1], image.shape[0]]\n",
|
| 334 |
+
"\n",
|
| 335 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 336 |
+
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
| 337 |
+
" image.flags.writeable = False\n",
|
| 338 |
+
"\n",
|
| 339 |
+
" results = pose.process(image)\n",
|
| 340 |
+
"\n",
|
| 341 |
+
" if not results.pose_landmarks:\n",
|
| 342 |
+
" print(\"No human found\")\n",
|
| 343 |
+
" continue\n",
|
| 344 |
+
"\n",
|
| 345 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 346 |
+
" image.flags.writeable = True\n",
|
| 347 |
+
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
|
| 348 |
+
"\n",
|
| 349 |
+
" # Draw landmarks and connections\n",
|
| 350 |
+
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=2), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=1))\n",
|
| 351 |
+
"\n",
|
| 352 |
+
" # Make detection\n",
|
| 353 |
+
" try:\n",
|
| 354 |
+
" # Extract keypoints from frame for the input\n",
|
| 355 |
+
" row = extract_important_keypoints(results)\n",
|
| 356 |
+
" X = pd.DataFrame([row], columns=HEADERS[1:])\n",
|
| 357 |
+
" X = pd.DataFrame(input_scaler.transform(X))\n",
|
| 358 |
+
"\n",
|
| 359 |
+
" # Make prediction and its probability\n",
|
| 360 |
+
" predicted_class = sklearn_model.predict(X)[0]\n",
|
| 361 |
+
" prediction_probabilities = sklearn_model.predict_proba(X)[0]\n",
|
| 362 |
+
" prediction_probability = round(prediction_probabilities[prediction_probabilities.argmax()], 2)\n",
|
| 363 |
+
"\n",
|
| 364 |
+
" # Evaluate model prediction\n",
|
| 365 |
+
" if predicted_class == \"I\" and prediction_probability >= prediction_probability_threshold:\n",
|
| 366 |
+
" current_stage = \"init\"\n",
|
| 367 |
+
" elif predicted_class == \"M\" and prediction_probability >= prediction_probability_threshold: \n",
|
| 368 |
+
" current_stage = \"mid\"\n",
|
| 369 |
+
" elif predicted_class == \"D\" and prediction_probability >= prediction_probability_threshold:\n",
|
| 370 |
+
" if current_stage == \"mid\":\n",
|
| 371 |
+
" counter += 1\n",
|
| 372 |
+
" \n",
|
| 373 |
+
" current_stage = \"down\"\n",
|
| 374 |
+
" \n",
|
| 375 |
+
" # Error detection\n",
|
| 376 |
+
" analyze_knee_angle(mp_results=results, stage=current_stage, angle_thresholds=ANGLE_THRESHOLDS, draw_to_image=(image, video_dimensions))\n",
|
| 377 |
+
" \n",
|
| 378 |
+
" # Visualization\n",
|
| 379 |
+
" # Status box\n",
|
| 380 |
+
" cv2.rectangle(image, (0, 0), (400, 60), (245, 117, 16), -1)\n",
|
| 381 |
+
"\n",
|
| 382 |
+
" # Display probability\n",
|
| 383 |
+
" cv2.putText(image, \"PROB\", (15, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 384 |
+
" cv2.putText(image, str(prediction_probability), (10, 40), cv2.FONT_HERSHEY_COMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 385 |
+
"\n",
|
| 386 |
+
" # Display class\n",
|
| 387 |
+
" cv2.putText(image, \"CLASS\", (95, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 388 |
+
" cv2.putText(image, current_stage, (90, 40), cv2.FONT_HERSHEY_COMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 389 |
+
"\n",
|
| 390 |
+
" # Display probability\n",
|
| 391 |
+
" cv2.putText(image, \"COUNTER\", (255, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 392 |
+
" cv2.putText(image, str(counter), (250, 40), cv2.FONT_HERSHEY_COMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 393 |
+
"\n",
|
| 394 |
+
" except Exception as e:\n",
|
| 395 |
+
" print(f\"Error: {e}\")\n",
|
| 396 |
+
" traceback.print_exc()\n",
|
| 397 |
+
" \n",
|
| 398 |
+
" cv2.imshow(\"CV2\", image)\n",
|
| 399 |
+
" \n",
|
| 400 |
+
" # Press Q to close cv2 window\n",
|
| 401 |
+
" if cv2.waitKey(1) & 0xFF == ord('q'):\n",
|
| 402 |
+
" break\n",
|
| 403 |
+
"\n",
|
| 404 |
+
" cap.release()\n",
|
| 405 |
+
" cv2.destroyAllWindows()\n",
|
| 406 |
+
"\n",
|
| 407 |
+
" # (Optional)Fix bugs cannot close windows in MacOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)\n",
|
| 408 |
+
" for i in range (1, 5):\n",
|
| 409 |
+
" cv2.waitKey(1)\n",
|
| 410 |
+
" "
|
| 411 |
+
]
|
| 412 |
+
},
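Two details of the loop above are easy to miss: the counter only increments on a mid -> down transition, so a noisy single-frame prediction cannot inflate the rep count, and ANGLE_THRESHOLDS = [60, 135] means a down-stage knee angle outside 60-135 degrees is flagged as an error by `analyze_knee_angle`. Isolated from the video loop, the stage machine reduces to the following standalone sketch (the prediction tuples are made up):

def update_stage(predicted_class, probability, current_stage, counter, threshold=0.8):
    """Advance the init/mid/down stage machine; credit a rep only on mid -> down."""
    if probability < threshold:
        return current_stage, counter  # low-confidence frame: keep the previous stage
    if predicted_class == "I":
        return "init", counter
    if predicted_class == "M":
        return "mid", counter
    if predicted_class == "D":
        return "down", counter + 1 if current_stage == "mid" else counter
    return current_stage, counter

stage, count = "", 0
for cls, prob in [("I", 0.9), ("M", 0.95), ("D", 0.85), ("M", 0.9), ("D", 0.9)]:
    stage, count = update_stage(cls, prob, stage, count)
print(count)  # 2 completed lunges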
|
| 413 |
+
{
|
| 414 |
+
"cell_type": "markdown",
|
| 415 |
+
"metadata": {},
|
| 416 |
+
"source": [
|
| 417 |
+
"### 4. Detection with Deep learning model"
|
| 418 |
+
]
|
| 419 |
+
},
|
| 420 |
+
{
|
| 421 |
+
"cell_type": "code",
|
| 422 |
+
"execution_count": 12,
|
| 423 |
+
"metadata": {},
|
| 424 |
+
"outputs": [
|
| 425 |
+
{
|
| 426 |
+
"name": "stdout",
|
| 427 |
+
"output_type": "stream",
|
| 428 |
+
"text": [
|
| 429 |
+
"Metal device set to: Apple M1\n",
|
| 430 |
+
"\n",
|
| 431 |
+
"systemMemory: 16.00 GB\n",
|
| 432 |
+
"maxCacheSize: 5.33 GB\n",
|
| 433 |
+
"\n"
|
| 434 |
+
]
|
| 435 |
+
},
|
| 436 |
+
{
|
| 437 |
+
"name": "stderr",
|
| 438 |
+
"output_type": "stream",
|
| 439 |
+
"text": [
|
| 440 |
+
"2022-11-14 11:15:32.563689: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.\n",
|
| 441 |
+
"2022-11-14 11:15:32.564066: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)\n"
|
| 442 |
+
]
|
| 443 |
+
}
|
| 444 |
+
],
|
| 445 |
+
"source": [
|
| 446 |
+
"# Load model\n",
|
| 447 |
+
"with open(\"./model/lunge_model_deep_learning.pkl\", \"rb\") as f:\n",
|
| 448 |
+
" deep_learning_model = pickle.load(f)"
|
| 449 |
+
]
|
| 450 |
+
},
|
| 451 |
+
{
|
| 452 |
+
"cell_type": "code",
|
| 453 |
+
"execution_count": 43,
|
| 454 |
+
"metadata": {},
|
| 455 |
+
"outputs": [
|
| 456 |
+
{
|
| 457 |
+
"name": "stdout",
|
| 458 |
+
"output_type": "stream",
|
| 459 |
+
"text": [
|
| 460 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 461 |
+
"1/1 [==============================] - 0s 11ms/step\n",
|
| 462 |
+
"1/1 [==============================] - 0s 11ms/step\n",
|
| 463 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 464 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 465 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 466 |
+
"1/1 [==============================] - 0s 11ms/step\n",
|
| 467 |
+
"1/1 [==============================] - 0s 12ms/step\n",
|
| 468 |
+
"1/1 [==============================] - 0s 11ms/step\n",
|
| 469 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 470 |
+
"1/1 [==============================] - 0s 11ms/step\n",
|
| 471 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 472 |
+
"1/1 [==============================] - 0s 12ms/step\n",
|
| 473 |
+
"1/1 [==============================] - 0s 11ms/step\n",
|
| 474 |
+
"1/1 [==============================] - 0s 12ms/step\n",
|
| 475 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 476 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 477 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 478 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 479 |
+
"1/1 [==============================] - 0s 12ms/step\n",
|
| 480 |
+
"1/1 [==============================] - 0s 11ms/step\n",
|
| 481 |
+
"1/1 [==============================] - 0s 11ms/step\n",
|
| 482 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 483 |
+
"1/1 [==============================] - 0s 11ms/step\n",
|
| 484 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 485 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 486 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 487 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 488 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 489 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 490 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 491 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 492 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 493 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 494 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 495 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 496 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 497 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 498 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 499 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 500 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 501 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 502 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 503 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 504 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 505 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 506 |
+
"1/1 [==============================] - 0s 11ms/step\n",
|
| 507 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 508 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 509 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 510 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 511 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 512 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 513 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 514 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 515 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 516 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 517 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 518 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 519 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 520 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 521 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 522 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 523 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 524 |
+
"1/1 [==============================] - 0s 9ms/step\n",
|
| 525 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 526 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 527 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 528 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 529 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 530 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 531 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 532 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 533 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 534 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 535 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 536 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 537 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 538 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 539 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 540 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 541 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 542 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 543 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 544 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 545 |
+
"1/1 [==============================] - 0s 11ms/step\n",
|
| 546 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 547 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 548 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 549 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 550 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 551 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 552 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 553 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 554 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 555 |
+
"1/1 [==============================] - 0s 11ms/step\n",
|
| 556 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 557 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 558 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 559 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 560 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 561 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 562 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 563 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 564 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 565 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 566 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 567 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 568 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 569 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 570 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 571 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 572 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 573 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 574 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 575 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 576 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 577 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 578 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 579 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 580 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 581 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 582 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 583 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 584 |
+
"1/1 [==============================] - 0s 10ms/step\n",
|
| 585 |
+
"1/1 [==============================] - 0s 10ms/step\n"
|
| 586 |
+
]
|
| 587 |
+
}
|
| 588 |
+
],
|
| 589 |
+
"source": [
|
| 590 |
+
"cap = cv2.VideoCapture(VIDEO_PATH1)\n",
|
| 591 |
+
"current_stage = \"\"\n",
|
| 592 |
+
"counter = 0\n",
|
| 593 |
+
"prediction_probability_threshold = 0.6\n",
|
| 594 |
+
"ANGLE_THRESHOLDS = [60, 135]\n",
|
| 595 |
+
"\n",
|
| 596 |
+
"with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:\n",
|
| 597 |
+
" while cap.isOpened():\n",
|
| 598 |
+
" ret, image = cap.read()\n",
|
| 599 |
+
"\n",
|
| 600 |
+
" if not ret:\n",
|
| 601 |
+
" break\n",
|
| 602 |
+
"\n",
|
| 603 |
+
" # Reduce size of a frame\n",
|
| 604 |
+
" image = rescale_frame(image, 50)\n",
|
| 605 |
+
" image = cv2.flip(image, 1)\n",
|
| 606 |
+
" video_dimensions = [image.shape[1], image.shape[0]]\n",
|
| 607 |
+
"\n",
|
| 608 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 609 |
+
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
| 610 |
+
" image.flags.writeable = False\n",
|
| 611 |
+
"\n",
|
| 612 |
+
" results = pose.process(image)\n",
|
| 613 |
+
"\n",
|
| 614 |
+
" if not results.pose_landmarks:\n",
|
| 615 |
+
" print(\"No human found\")\n",
|
| 616 |
+
" continue\n",
|
| 617 |
+
"\n",
|
| 618 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 619 |
+
" image.flags.writeable = True\n",
|
| 620 |
+
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
|
| 621 |
+
"\n",
|
| 622 |
+
" # Draw landmarks and connections\n",
|
| 623 |
+
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=2), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=1))\n",
|
| 624 |
+
"\n",
|
| 625 |
+
" # Make detection\n",
|
| 626 |
+
" try:\n",
|
| 627 |
+
" # Extract keypoints from frame for the input\n",
|
| 628 |
+
" row = extract_important_keypoints(results)\n",
|
| 629 |
+
" X = pd.DataFrame([row, ], columns=HEADERS[1:])\n",
|
| 630 |
+
" X = pd.DataFrame(input_scaler.transform(X))\n",
|
| 631 |
+
" \n",
|
| 632 |
+
"\n",
|
| 633 |
+
" # Make prediction and its probability\n",
|
| 634 |
+
" prediction = deep_learning_model.predict(X)\n",
|
| 635 |
+
" predicted_class = np.argmax(prediction, axis=1)[0]\n",
|
| 636 |
+
" prediction_probability = max(prediction.tolist()[0])\n",
|
| 637 |
+
"\n",
|
| 638 |
+
" # Evaluate model prediction\n",
|
| 639 |
+
" if predicted_class == 0 and prediction_probability >= prediction_probability_threshold:\n",
|
| 640 |
+
" current_stage = \"init\"\n",
|
| 641 |
+
" elif predicted_class == 1 and prediction_probability >= prediction_probability_threshold: \n",
|
| 642 |
+
" current_stage = \"mid\"\n",
|
| 643 |
+
" elif predicted_class == 2 and prediction_probability >= prediction_probability_threshold:\n",
|
| 644 |
+
" if current_stage == \"mid\":\n",
|
| 645 |
+
" counter += 1\n",
|
| 646 |
+
"\n",
|
| 647 |
+
" current_stage = \"down\"\n",
|
| 648 |
+
" \n",
|
| 649 |
+
" analyze_knee_angle(mp_results=results, stage=current_stage, angle_thresholds=ANGLE_THRESHOLDS, draw_to_image=(image, video_dimensions))\n",
|
| 650 |
+
"\n",
|
| 651 |
+
" # Visualization\n",
|
| 652 |
+
" # Status box\n",
|
| 653 |
+
" cv2.rectangle(image, (0, 0), (550, 60), (245, 117, 16), -1)\n",
|
| 654 |
+
"\n",
|
| 655 |
+
" # # Display class\n",
|
| 656 |
+
" cv2.putText(image, \"DETECTION\", (95, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 657 |
+
" cv2.putText(image, current_stage, (90, 40), cv2.FONT_HERSHEY_COMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 658 |
+
"\n",
|
| 659 |
+
" # # Display probability\n",
|
| 660 |
+
" cv2.putText(image, \"PROB\", (15, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 661 |
+
" cv2.putText(image, str(round(prediction_probability, 2)), (10, 40), cv2.FONT_HERSHEY_COMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 662 |
+
"\n",
|
| 663 |
+
" # # Display class\n",
|
| 664 |
+
" cv2.putText(image, \"CLASS\", (225, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 665 |
+
" cv2.putText(image, str(predicted_class), (220, 40), cv2.FONT_HERSHEY_COMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 666 |
+
"\n",
|
| 667 |
+
" # # Display class\n",
|
| 668 |
+
" cv2.putText(image, \"COUNTER\", (350, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 669 |
+
" cv2.putText(image, str(counter), (345, 40), cv2.FONT_HERSHEY_COMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)\n",
|
| 670 |
+
"\n",
|
| 671 |
+
" except Exception as e:\n",
|
| 672 |
+
" print(f\"Error: {e}\")\n",
|
| 673 |
+
" \n",
|
| 674 |
+
" cv2.imshow(\"CV2\", image)\n",
|
| 675 |
+
" \n",
|
| 676 |
+
" # Press Q to close cv2 window\n",
|
| 677 |
+
" if cv2.waitKey(1) & 0xFF == ord('q'):\n",
|
| 678 |
+
" break\n",
|
| 679 |
+
"\n",
|
| 680 |
+
" cap.release()\n",
|
| 681 |
+
" cv2.destroyAllWindows()\n",
|
| 682 |
+
"\n",
|
| 683 |
+
" # (Optional)Fix bugs cannot close windows in MacOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)\n",
|
| 684 |
+
" for i in range (1, 5):\n",
|
| 685 |
+
" cv2.waitKey(1)\n",
|
| 686 |
+
" "
|
| 687 |
+
]
|
| 688 |
+
}
|
| 689 |
+
],
|
| 690 |
+
"metadata": {
|
| 691 |
+
"kernelspec": {
|
| 692 |
+
"display_name": "Python 3.8.13 (conda)",
|
| 693 |
+
"language": "python",
|
| 694 |
+
"name": "python3"
|
| 695 |
+
},
|
| 696 |
+
"language_info": {
|
| 697 |
+
"codemirror_mode": {
|
| 698 |
+
"name": "ipython",
|
| 699 |
+
"version": 3
|
| 700 |
+
},
|
| 701 |
+
"file_extension": ".py",
|
| 702 |
+
"mimetype": "text/x-python",
|
| 703 |
+
"name": "python",
|
| 704 |
+
"nbconvert_exporter": "python",
|
| 705 |
+
"pygments_lexer": "ipython3",
|
| 706 |
+
"version": "3.8.13"
|
| 707 |
+
},
|
| 708 |
+
"orig_nbformat": 4,
|
| 709 |
+
"vscode": {
|
| 710 |
+
"interpreter": {
|
| 711 |
+
"hash": "9260f401923fb5c4108c543a7d176de9733d378b3752e49535ad7c43c2271b65"
|
| 712 |
+
}
|
| 713 |
+
}
|
| 714 |
+
},
|
| 715 |
+
"nbformat": 4,
|
| 716 |
+
"nbformat_minor": 2
|
| 717 |
+
}
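Compared with the sklearn loop, the deep-learning variant above differs in two ways: the Keras model returns a softmax probability vector, so the class is an integer index (0 = init, 1 = mid, 2 = down, per the branches in the cell) rather than an "I"/"M"/"D" string, and every per-frame `predict` call prints a progress line, which is what fills the long output cell. A hedged sketch of the same mapping with the progress output silenced:

import numpy as np

CLASS_NAMES = ["init", "mid", "down"]  # index order inferred from the 0/1/2 branches above

def classify_frame(model, X, threshold=0.6):
    """Map the Keras softmax output to a stage label, or None when below threshold."""
    probs = model.predict(X, verbose=0)[0]  # verbose=0 suppresses the "1/1 [==...]" lines
    idx = int(np.argmax(probs))
    if probs[idx] < threshold:
        return None, float(probs[idx])
    return CLASS_NAMES[idx], float(probs[idx])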
|
core/lunge_model/5.err.data.ipynb
ADDED
|
@@ -0,0 +1,562 @@
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [
|
| 8 |
+
{
|
| 9 |
+
"name": "stderr",
|
| 10 |
+
"output_type": "stream",
|
| 11 |
+
"text": [
|
| 12 |
+
"objc[68038]: Class CaptureDelegate is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_videoio.3.4.16.dylib (0x10c548860) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x17b5da480). One of the two will be used. Which one is undefined.\n",
|
| 13 |
+
"objc[68038]: Class CVWindow is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x107304a68) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x17b5da4d0). One of the two will be used. Which one is undefined.\n",
|
| 14 |
+
"objc[68038]: Class CVView is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x107304a90) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x17b5da4f8). One of the two will be used. Which one is undefined.\n",
|
| 15 |
+
"objc[68038]: Class CVSlider is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x107304ab8) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x17b5da520). One of the two will be used. Which one is undefined.\n"
|
| 16 |
+
]
|
| 17 |
+
}
|
| 18 |
+
],
|
| 19 |
+
"source": [
|
| 20 |
+
"import mediapipe as mp\n",
|
| 21 |
+
"import cv2\n",
|
| 22 |
+
"import numpy as np\n",
|
| 23 |
+
"import pandas as pd\n",
|
| 24 |
+
"import csv\n",
|
| 25 |
+
"import os\n",
|
| 26 |
+
"import seaborn as sns\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"import warnings\n",
|
| 29 |
+
"warnings.filterwarnings('ignore')\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"# Drawing helpers\n",
|
| 32 |
+
"mp_drawing = mp.solutions.drawing_utils\n",
|
| 33 |
+
"mp_pose = mp.solutions.pose"
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "markdown",
|
| 38 |
+
"metadata": {},
|
| 39 |
+
"source": [
|
| 40 |
+
"### 1. Build dataset from collected video"
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"cell_type": "markdown",
|
| 45 |
+
"metadata": {},
|
| 46 |
+
"source": [
|
| 47 |
+
"### 1.1. Determine important landmarks and set up important functions\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"The error that I try to tackle is **KNEE OVER TOE** error when lunge is at down stage.\n",
|
| 50 |
+
"- \"C\": Correct Form\n",
|
| 51 |
+
"- \"L\": Incorrect Form"
|
| 52 |
+
]
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"cell_type": "code",
|
| 56 |
+
"execution_count": 2,
|
| 57 |
+
"metadata": {},
|
| 58 |
+
"outputs": [],
|
| 59 |
+
"source": [
|
| 60 |
+
"# Determine important landmarks for lunge\n",
|
| 61 |
+
"IMPORTANT_LMS = [\n",
|
| 62 |
+
" \"NOSE\",\n",
|
| 63 |
+
" \"LEFT_SHOULDER\",\n",
|
| 64 |
+
" \"RIGHT_SHOULDER\",\n",
|
| 65 |
+
" \"LEFT_HIP\",\n",
|
| 66 |
+
" \"RIGHT_HIP\",\n",
|
| 67 |
+
" \"LEFT_KNEE\",\n",
|
| 68 |
+
" \"RIGHT_KNEE\",\n",
|
| 69 |
+
" \"LEFT_ANKLE\",\n",
|
| 70 |
+
" \"RIGHT_ANKLE\",\n",
|
| 71 |
+
" \"LEFT_HEEL\",\n",
|
| 72 |
+
" \"RIGHT_HEEL\",\n",
|
| 73 |
+
" \"LEFT_FOOT_INDEX\",\n",
|
| 74 |
+
" \"RIGHT_FOOT_INDEX\",\n",
|
| 75 |
+
"]\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"# Generate all columns of the data frame\n",
|
| 78 |
+
"HEADERS = [\"label\"] # Label column\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"for lm in IMPORTANT_LMS:\n",
|
| 81 |
+
" HEADERS += [f\"{lm.lower()}_x\", f\"{lm.lower()}_y\", f\"{lm.lower()}_z\", f\"{lm.lower()}_v\"]"
|
| 82 |
+
]
|
| 83 |
+
},
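Each landmark contributes four columns (x, y, z, visibility), so the generated frame has 13 * 4 = 52 feature columns plus the label, i.e. 53 columns in total, which matches the dataset summaries later in this notebook. A quick check on the cell above:

print(len(HEADERS))   # 53
print(HEADERS[:5])    # ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v']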
|
| 84 |
+
{
|
| 85 |
+
"cell_type": "code",
|
| 86 |
+
"execution_count": 3,
|
| 87 |
+
"metadata": {},
|
| 88 |
+
"outputs": [],
|
| 89 |
+
"source": [
|
| 90 |
+
"def rescale_frame(frame, percent=50):\n",
|
| 91 |
+
" '''\n",
|
| 92 |
+
" Rescale a frame to a certain percentage compare to its original frame\n",
|
| 93 |
+
" '''\n",
|
| 94 |
+
" width = int(frame.shape[1] * percent/ 100)\n",
|
| 95 |
+
" height = int(frame.shape[0] * percent/ 100)\n",
|
| 96 |
+
" dim = (width, height)\n",
|
| 97 |
+
" return cv2.resize(frame, dim, interpolation = cv2.INTER_AREA)\n",
|
| 98 |
+
" \n",
|
| 99 |
+
"\n",
|
| 100 |
+
"def init_csv(dataset_path: str):\n",
|
| 101 |
+
" '''\n",
|
| 102 |
+
" Create a blank csv file with just columns\n",
|
| 103 |
+
" '''\n",
|
| 104 |
+
"\n",
|
| 105 |
+
" # Ignore if file is already exist\n",
|
| 106 |
+
" if os.path.exists(dataset_path):\n",
|
| 107 |
+
" return\n",
|
| 108 |
+
"\n",
|
| 109 |
+
" # Write all the columns to a empty file\n",
|
| 110 |
+
" with open(dataset_path, mode=\"w\", newline=\"\") as f:\n",
|
| 111 |
+
" csv_writer = csv.writer(f, delimiter=\",\", quotechar='\"', quoting=csv.QUOTE_MINIMAL)\n",
|
| 112 |
+
" csv_writer.writerow(HEADERS)\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"\n",
|
| 115 |
+
"def export_landmark_to_csv(dataset_path: str, results, action: str) -> None:\n",
|
| 116 |
+
" '''\n",
|
| 117 |
+
" Export Labeled Data from detected landmark to csv\n",
|
| 118 |
+
" '''\n",
|
| 119 |
+
" landmarks = results.pose_landmarks.landmark\n",
|
| 120 |
+
" keypoints = []\n",
|
| 121 |
+
"\n",
|
| 122 |
+
" try:\n",
|
| 123 |
+
" # Extract coordinate of important landmarks\n",
|
| 124 |
+
" for lm in IMPORTANT_LMS:\n",
|
| 125 |
+
" keypoint = landmarks[mp_pose.PoseLandmark[lm].value]\n",
|
| 126 |
+
" keypoints.append([keypoint.x, keypoint.y, keypoint.z, keypoint.visibility])\n",
|
| 127 |
+
" \n",
|
| 128 |
+
" keypoints = list(np.array(keypoints).flatten())\n",
|
| 129 |
+
"\n",
|
| 130 |
+
" # Insert action as the label (first column)\n",
|
| 131 |
+
" keypoints.insert(0, action)\n",
|
| 132 |
+
"\n",
|
| 133 |
+
" # Append new row to .csv file\n",
|
| 134 |
+
" with open(dataset_path, mode=\"a\", newline=\"\") as f:\n",
|
| 135 |
+
" csv_writer = csv.writer(f, delimiter=\",\", quotechar='\"', quoting=csv.QUOTE_MINIMAL)\n",
|
| 136 |
+
" csv_writer.writerow(keypoints)\n",
|
| 137 |
+
" \n",
|
| 138 |
+
"\n",
|
| 139 |
+
" except Exception as e:\n",
|
| 140 |
+
" print(e)\n",
|
| 141 |
+
" pass\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"def describe_dataset(dataset_path: str):\n",
|
| 145 |
+
" '''\n",
|
| 146 |
+
" Describe dataset\n",
|
| 147 |
+
" '''\n",
|
| 148 |
+
"\n",
|
| 149 |
+
" data = pd.read_csv(dataset_path)\n",
|
| 150 |
+
" print(f\"Headers: {list(data.columns.values)}\")\n",
|
| 151 |
+
" print(f'Number of rows: {data.shape[0]} \\nNumber of columns: {data.shape[1]}\\n')\n",
|
| 152 |
+
" print(f\"Labels: \\n{data['label'].value_counts()}\\n\")\n",
|
| 153 |
+
" print(f\"Missing values: {data.isnull().values.any()}\\n\")\n",
|
| 154 |
+
" \n",
|
| 155 |
+
" duplicate = data[data.duplicated()]\n",
|
| 156 |
+
" print(f\"Duplicate Rows : {len(duplicate.sum(axis=1))}\")\n",
|
| 157 |
+
"\n",
|
| 158 |
+
" return data\n",
|
| 159 |
+
"\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"def remove_duplicate_rows(dataset_path: str):\n",
|
| 162 |
+
" '''\n",
|
| 163 |
+
" Remove duplicated data from the dataset then save it to another files\n",
|
| 164 |
+
" '''\n",
|
| 165 |
+
" \n",
|
| 166 |
+
" df = pd.read_csv(dataset_path)\n",
|
| 167 |
+
" df.drop_duplicates(keep=\"first\", inplace=True)\n",
|
| 168 |
+
" df.to_csv(f\"cleaned_dataset.csv\", sep=',', encoding='utf-8', index=False)\n",
|
| 169 |
+
" \n",
|
| 170 |
+
"\n",
|
| 171 |
+
"def concat_csv_files_with_same_headers(file_paths: list, saved_path: str):\n",
|
| 172 |
+
" '''\n",
|
| 173 |
+
" Concat different csv files\n",
|
| 174 |
+
" '''\n",
|
| 175 |
+
" all_df = []\n",
|
| 176 |
+
" for path in file_paths:\n",
|
| 177 |
+
" df = pd.read_csv(path, index_col=None, header=0)\n",
|
| 178 |
+
" all_df.append(df)\n",
|
| 179 |
+
" \n",
|
| 180 |
+
" results = pd.concat(all_df, axis=0, ignore_index=True)\n",
|
| 181 |
+
" results.to_csv(saved_path, sep=',', encoding='utf-8', index=False)\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"\n",
|
| 184 |
+
"def calculate_angle(point1: list, point2: list, point3: list) -> float:\n",
|
| 185 |
+
" '''\n",
|
| 186 |
+
" Calculate the angle between 3 points\n",
|
| 187 |
+
" Unit of the angle will be in Degree\n",
|
| 188 |
+
" '''\n",
|
| 189 |
+
" point1 = np.array(point1)\n",
|
| 190 |
+
" point2 = np.array(point2)\n",
|
| 191 |
+
" point3 = np.array(point3)\n",
|
| 192 |
+
"\n",
|
| 193 |
+
" # Calculate algo\n",
|
| 194 |
+
" angleInRad = np.arctan2(point3[1] - point2[1], point3[0] - point2[0]) - np.arctan2(point1[1] - point2[1], point1[0] - point2[0])\n",
|
| 195 |
+
" angleInDeg = np.abs(angleInRad * 180.0 / np.pi)\n",
|
| 196 |
+
"\n",
|
| 197 |
+
" angleInDeg = angleInDeg if angleInDeg <= 180 else 360 - angleInDeg\n",
|
| 198 |
+
" return angleInDeg"
|
| 199 |
+
]
|
| 200 |
+
},
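Outside of the interactive capture loop, the helper functions above compose into a small dataset workflow; a minimal sketch, mirroring how the cells below use them:

DATASET_PATH = "err.train.csv"

init_csv(DATASET_PATH)  # writes the 53-column header once, skips if the file exists

# Inside the capture loop, one row per labeled frame:
#   export_landmark_to_csv(DATASET_PATH, results, "C")  # correct form
#   export_landmark_to_csv(DATASET_PATH, results, "L")  # knee-over-toe error

df = describe_dataset(DATASET_PATH)   # shape, label balance, missing values, duplicates
remove_duplicate_rows(DATASET_PATH)   # saves a deduplicated copy to cleaned_dataset.csv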
|
| 201 |
+
{
|
| 202 |
+
"cell_type": "markdown",
|
| 203 |
+
"metadata": {},
|
| 204 |
+
"source": [
|
| 205 |
+
"### 1.2. Extract data for train set"
|
| 206 |
+
]
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"cell_type": "code",
|
| 210 |
+
"execution_count": 22,
|
| 211 |
+
"metadata": {},
|
| 212 |
+
"outputs": [
|
| 213 |
+
{
|
| 214 |
+
"name": "stderr",
|
| 215 |
+
"output_type": "stream",
|
| 216 |
+
"text": [
|
| 217 |
+
"OpenCV: Couldn't read video stream from file \"../data/lunge/lunge_9.mp4\"\n"
|
| 218 |
+
]
|
| 219 |
+
}
|
| 220 |
+
],
|
| 221 |
+
"source": [
|
| 222 |
+
"DATASET_PATH = \"err.train.csv\"\n",
|
| 223 |
+
"\n",
|
| 224 |
+
"cap = cv2.VideoCapture(\"../data/lunge/lunge_9.mp4\")\n",
|
| 225 |
+
"save_counts = 0\n",
|
| 226 |
+
"\n",
|
| 227 |
+
"init_csv(DATASET_PATH)\n",
|
| 228 |
+
"\n",
|
| 229 |
+
"with mp_pose.Pose(min_detection_confidence=0.8, min_tracking_confidence=0.9) as pose:\n",
|
| 230 |
+
" while cap.isOpened():\n",
|
| 231 |
+
" ret, image = cap.read()\n",
|
| 232 |
+
"\n",
|
| 233 |
+
" if not ret:\n",
|
| 234 |
+
" break\n",
|
| 235 |
+
"\n",
|
| 236 |
+
" # Reduce size of a frame\n",
|
| 237 |
+
" image = rescale_frame(image, 60)\n",
|
| 238 |
+
" # image = cv2.flip(image, 1)\n",
|
| 239 |
+
"\n",
|
| 240 |
+
" video_dimensions = [image.shape[1], image.shape[0]]\n",
|
| 241 |
+
"\n",
|
| 242 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 243 |
+
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
| 244 |
+
" image.flags.writeable = False\n",
|
| 245 |
+
"\n",
|
| 246 |
+
" results = pose.process(image)\n",
|
| 247 |
+
" \n",
|
| 248 |
+
" if not results.pose_landmarks: continue\n",
|
| 249 |
+
"\n",
|
| 250 |
+
" landmarks = results.pose_landmarks.landmark\n",
|
| 251 |
+
"\n",
|
| 252 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 253 |
+
" image.flags.writeable = True\n",
|
| 254 |
+
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
|
| 255 |
+
"\n",
|
| 256 |
+
" # Draw landmarks and connections\n",
|
| 257 |
+
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=4), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2))\n",
|
| 258 |
+
"\n",
|
| 259 |
+
" # Display the saved count\n",
|
| 260 |
+
" cv2.putText(image, f\"Saved: {save_counts}\", (50, 50), cv2.FONT_HERSHEY_COMPLEX, 2, (0, 0, 0), 2, cv2.LINE_AA)\n",
|
| 261 |
+
"\n",
|
| 262 |
+
" cv2.imshow(\"CV2\", image)\n",
|
| 263 |
+
"\n",
|
| 264 |
+
" # Pressed key for action\n",
|
| 265 |
+
" k = cv2.waitKey(1) & 0xFF\n",
|
| 266 |
+
"\n",
|
| 267 |
+
" # * Press I to save as INIT stage\n",
|
| 268 |
+
" if k == ord('c'): \n",
|
| 269 |
+
" export_landmark_to_csv(DATASET_PATH, results, \"C\")\n",
|
| 270 |
+
" save_counts += 1\n",
|
| 271 |
+
" # * Press M to save as MID stage\n",
|
| 272 |
+
" elif k == ord(\"l\"):\n",
|
| 273 |
+
" export_landmark_to_csv(DATASET_PATH, results, \"L\")\n",
|
| 274 |
+
" save_counts += 1\n",
|
| 275 |
+
"\n",
|
| 276 |
+
" # Press q to stop\n",
|
| 277 |
+
" elif k == ord(\"q\"):\n",
|
| 278 |
+
" break\n",
|
| 279 |
+
" else: continue\n",
|
| 280 |
+
"\n",
|
| 281 |
+
" cap.release()\n",
|
| 282 |
+
" cv2.destroyAllWindows()\n",
|
| 283 |
+
"\n",
|
| 284 |
+
" # (Optional)Fix bugs cannot close windows in MacOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)\n",
|
| 285 |
+
" for i in range (1, 5):\n",
|
| 286 |
+
" cv2.waitKey(1)\n",
|
| 287 |
+
" "
|
| 288 |
+
]
|
| 289 |
+
},
|
| 290 |
+
{
|
| 291 |
+
"cell_type": "code",
|
| 292 |
+
"execution_count": 23,
|
| 293 |
+
"metadata": {},
|
| 294 |
+
"outputs": [
|
| 295 |
+
{
|
| 296 |
+
"name": "stdout",
|
| 297 |
+
"output_type": "stream",
|
| 298 |
+
"text": [
|
| 299 |
+
"Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v', 'left_knee_x', 'left_knee_y', 'left_knee_z', 'left_knee_v', 'right_knee_x', 'right_knee_y', 'right_knee_z', 'right_knee_v', 'left_ankle_x', 'left_ankle_y', 'left_ankle_z', 'left_ankle_v', 'right_ankle_x', 'right_ankle_y', 'right_ankle_z', 'right_ankle_v', 'left_heel_x', 'left_heel_y', 'left_heel_z', 'left_heel_v', 'right_heel_x', 'right_heel_y', 'right_heel_z', 'right_heel_v', 'left_foot_index_x', 'left_foot_index_y', 'left_foot_index_z', 'left_foot_index_v', 'right_foot_index_x', 'right_foot_index_y', 'right_foot_index_z', 'right_foot_index_v']\n",
|
| 300 |
+
"Number of rows: 17907 \n",
|
| 301 |
+
"Number of columns: 53\n",
|
| 302 |
+
"\n",
|
| 303 |
+
"Labels: \n",
|
| 304 |
+
"L 9114\n",
|
| 305 |
+
"C 8793\n",
|
| 306 |
+
"Name: label, dtype: int64\n",
|
| 307 |
+
"\n",
|
| 308 |
+
"Missing values: False\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"Duplicate Rows : 0\n"
|
| 311 |
+
]
|
| 312 |
+
}
|
| 313 |
+
],
|
| 314 |
+
"source": [
|
| 315 |
+
"df = describe_dataset(DATASET_PATH)"
|
| 316 |
+
]
|
| 317 |
+
},
|
| 318 |
+
{
|
| 319 |
+
"cell_type": "markdown",
|
| 320 |
+
"metadata": {},
|
| 321 |
+
"source": [
|
| 322 |
+
"### 1.3. Extract data for test set"
|
| 323 |
+
]
|
| 324 |
+
},
|
| 325 |
+
{
|
| 326 |
+
"cell_type": "code",
|
| 327 |
+
"execution_count": 13,
|
| 328 |
+
"metadata": {},
|
| 329 |
+
"outputs": [],
|
| 330 |
+
"source": [
|
| 331 |
+
"TEST_DATASET_PATH = \"err.test.csv\"\n",
|
| 332 |
+
"\n",
|
| 333 |
+
"cap = cv2.VideoCapture(\"../data/lunge/lunge_test_5.mp4\")\n",
|
| 334 |
+
"save_counts = 0\n",
|
| 335 |
+
"\n",
|
| 336 |
+
"init_csv(TEST_DATASET_PATH)\n",
|
| 337 |
+
"\n",
|
| 338 |
+
"with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.8) as pose:\n",
|
| 339 |
+
" while cap.isOpened():\n",
|
| 340 |
+
" ret, image = cap.read()\n",
|
| 341 |
+
"\n",
|
| 342 |
+
" if not ret:\n",
|
| 343 |
+
" break\n",
|
| 344 |
+
"\n",
|
| 345 |
+
" # Reduce size of a frame\n",
|
| 346 |
+
" image = rescale_frame(image, 60)\n",
|
| 347 |
+
" image = cv2.flip(image, 1)\n",
|
| 348 |
+
"\n",
|
| 349 |
+
" video_dimensions = [image.shape[1], image.shape[0]]\n",
|
| 350 |
+
"\n",
|
| 351 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 352 |
+
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
| 353 |
+
" image.flags.writeable = False\n",
|
| 354 |
+
"\n",
|
| 355 |
+
" results = pose.process(image)\n",
|
| 356 |
+
" \n",
|
| 357 |
+
" if not results.pose_landmarks: continue\n",
|
| 358 |
+
"\n",
|
| 359 |
+
" landmarks = results.pose_landmarks.landmark\n",
|
| 360 |
+
"\n",
|
| 361 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 362 |
+
" image.flags.writeable = True\n",
|
| 363 |
+
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
|
| 364 |
+
"\n",
|
| 365 |
+
" # Draw landmarks and connections\n",
|
| 366 |
+
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=4), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2))\n",
|
| 367 |
+
"\n",
|
| 368 |
+
" # Display the saved count\n",
|
| 369 |
+
" cv2.putText(image, f\"Saved: {save_counts}\", (50, 50), cv2.FONT_HERSHEY_COMPLEX, 2, (0, 0, 0), 2, cv2.LINE_AA)\n",
|
| 370 |
+
"\n",
|
| 371 |
+
" cv2.imshow(\"CV2\", image)\n",
|
| 372 |
+
"\n",
|
| 373 |
+
" # Pressed key for action\n",
|
| 374 |
+
" k = cv2.waitKey(10) & 0xFF\n",
|
| 375 |
+
"\n",
|
| 376 |
+
" # * Press C to save as Correct stage\n",
|
| 377 |
+
" if k == ord('c'): \n",
|
| 378 |
+
" export_landmark_to_csv(TEST_DATASET_PATH, results, \"C\")\n",
|
| 379 |
+
" save_counts += 1\n",
|
| 380 |
+
" # * Press L to save as Incorrect stage\n",
|
| 381 |
+
" elif k == ord(\"l\"):\n",
|
| 382 |
+
" export_landmark_to_csv(TEST_DATASET_PATH, results, \"L\")\n",
|
| 383 |
+
" save_counts += 1\n",
|
| 384 |
+
"\n",
|
| 385 |
+
" # Press q to stop\n",
|
| 386 |
+
" elif k == ord(\"q\"):\n",
|
| 387 |
+
" break\n",
|
| 388 |
+
" else: continue\n",
|
| 389 |
+
"\n",
|
| 390 |
+
" cap.release()\n",
|
| 391 |
+
" cv2.destroyAllWindows()\n",
|
| 392 |
+
"\n",
|
| 393 |
+
" # (Optional)Fix bugs cannot close windows in MacOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)\n",
|
| 394 |
+
" for i in range (1, 5):\n",
|
| 395 |
+
" cv2.waitKey(1)\n",
|
| 396 |
+
" "
|
| 397 |
+
]
|
| 398 |
+
},
|
| 399 |
+
{
|
| 400 |
+
"cell_type": "code",
|
| 401 |
+
"execution_count": 14,
|
| 402 |
+
"metadata": {},
|
| 403 |
+
"outputs": [
|
| 404 |
+
{
|
| 405 |
+
"name": "stdout",
|
| 406 |
+
"output_type": "stream",
|
| 407 |
+
"text": [
|
| 408 |
+
"Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v', 'left_knee_x', 'left_knee_y', 'left_knee_z', 'left_knee_v', 'right_knee_x', 'right_knee_y', 'right_knee_z', 'right_knee_v', 'left_ankle_x', 'left_ankle_y', 'left_ankle_z', 'left_ankle_v', 'right_ankle_x', 'right_ankle_y', 'right_ankle_z', 'right_ankle_v', 'left_heel_x', 'left_heel_y', 'left_heel_z', 'left_heel_v', 'right_heel_x', 'right_heel_y', 'right_heel_z', 'right_heel_v', 'left_foot_index_x', 'left_foot_index_y', 'left_foot_index_z', 'left_foot_index_v', 'right_foot_index_x', 'right_foot_index_y', 'right_foot_index_z', 'right_foot_index_v']\n",
|
| 409 |
+
"Number of rows: 783 \n",
|
| 410 |
+
"Number of columns: 53\n",
|
| 411 |
+
"\n",
|
| 412 |
+
"Labels: \n",
|
| 413 |
+
"L 406\n",
|
| 414 |
+
"C 377\n",
|
| 415 |
+
"Name: label, dtype: int64\n",
|
| 416 |
+
"\n",
|
| 417 |
+
"Missing values: False\n",
|
| 418 |
+
"\n",
|
| 419 |
+
"Duplicate Rows : 0\n"
|
| 420 |
+
]
|
| 421 |
+
}
|
| 422 |
+
],
|
| 423 |
+
"source": [
|
| 424 |
+
"test_df = describe_dataset(TEST_DATASET_PATH)"
|
| 425 |
+
]
|
| 426 |
+
},
|
| 427 |
+
{
|
| 428 |
+
"cell_type": "markdown",
|
| 429 |
+
"metadata": {},
|
| 430 |
+
"source": [
|
| 431 |
+
"## 3. Data Visualization"
|
| 432 |
+
]
|
| 433 |
+
},
|
| 434 |
+
{
|
| 435 |
+
"cell_type": "markdown",
|
| 436 |
+
"metadata": {},
|
| 437 |
+
"source": [
|
| 438 |
+
"### 3.1. Train set"
|
| 439 |
+
]
|
| 440 |
+
},
|
| 441 |
+
{
|
| 442 |
+
"cell_type": "code",
|
| 443 |
+
"execution_count": 6,
|
| 444 |
+
"metadata": {},
|
| 445 |
+
"outputs": [
|
| 446 |
+
{
|
| 447 |
+
"name": "stdout",
|
| 448 |
+
"output_type": "stream",
|
| 449 |
+
"text": [
|
| 450 |
+
"Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v', 'left_knee_x', 'left_knee_y', 'left_knee_z', 'left_knee_v', 'right_knee_x', 'right_knee_y', 'right_knee_z', 'right_knee_v', 'left_ankle_x', 'left_ankle_y', 'left_ankle_z', 'left_ankle_v', 'right_ankle_x', 'right_ankle_y', 'right_ankle_z', 'right_ankle_v', 'left_heel_x', 'left_heel_y', 'left_heel_z', 'left_heel_v', 'right_heel_x', 'right_heel_y', 'right_heel_z', 'right_heel_v', 'left_foot_index_x', 'left_foot_index_y', 'left_foot_index_z', 'left_foot_index_v', 'right_foot_index_x', 'right_foot_index_y', 'right_foot_index_z', 'right_foot_index_v']\n",
|
| 451 |
+
"Number of rows: 17907 \n",
|
| 452 |
+
"Number of columns: 53\n",
|
| 453 |
+
"\n",
|
| 454 |
+
"Labels: \n",
|
| 455 |
+
"L 9114\n",
|
| 456 |
+
"C 8793\n",
|
| 457 |
+
"Name: label, dtype: int64\n",
|
| 458 |
+
"\n",
|
| 459 |
+
"Missing values: False\n",
|
| 460 |
+
"\n",
|
| 461 |
+
"Duplicate Rows : 0\n"
|
| 462 |
+
]
|
| 463 |
+
},
|
| 464 |
+
{
|
| 465 |
+
"data": {
|
| 466 |
+
"text/plain": [
|
| 467 |
+
"<AxesSubplot:xlabel='label', ylabel='count'>"
|
| 468 |
+
]
|
| 469 |
+
},
|
| 470 |
+
"execution_count": 6,
|
| 471 |
+
"metadata": {},
|
| 472 |
+
"output_type": "execute_result"
|
| 473 |
+
},
|
| 474 |
+
{
|
| 475 |
+
"data": {
|
| 476 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkQAAAGwCAYAAABIC3rIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiMklEQVR4nO3dfVSUdf7/8dcEMqLCJAqMrKPhiVULLcNC7EZOKloRdTwna3E5eTS1xSRS0zyuZfYV0ko9yUbqVrqpq2dv3NzdliQ3OZk3KEmpobYbm7oxYhsM3hCgzu+P1uu3I2aGMAN+no9zOMe55s3M++oc83mumQGb1+v1CgAAwGDXBHoBAACAQCOIAACA8QgiAABgPIIIAAAYjyACAADGI4gAAIDxCCIAAGC84EAv0FacO3dOX331lcLCwmSz2QK9DgAAuAxer1cnTpxQTEyMrrnm+68DEUSX6auvvpLL5Qr0GgAAoAmOHDmi7t27f+/9BNFlCgsLk/Tdf9Dw8PAAbwMAAC5HTU2NXC6X9e/49yGILtP5l8nCw8MJIgAA2pgfersLb6oGAADGI4gAAIDxCCIAAGA8gggAABiPIAIAAMYjiAAAgPEIIgAAYDyCCAAAGI8gAgAAxiOIAACA8QgiAABgPIIIAAAYjyACAADGI4gAAIDxCCIAAGC84EAvAACmGDlnfaBXAFqdghceDvQKkrhCBAAAQBABAAAQRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4/C6zVmb3wNsCvQLQ6gzcXRzoFQBc5bhCBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4wU0iM6cOaNf/vKXio2NVWhoqHr16qV58+bp3Llz1ozX69XcuXMVExOj0NBQJScna//+/T6PU1dXpylTpqhr167q2LGj0tLSdPToUZ+ZqqoqZWRkyOFwyOFwKCMjQ9XV1f44TQAA0MoFNIgWLFig119/XXl5eSorK9PChQv10ksvaenSpdbMwoULtWjRIuXl5WnXrl1yOp0aPny4Tpw4Yc1kZ2drw4YNWrdunbZu3aqTJ08qNTVVZ8+etWbS09NVWlqqgoICFRQUqLS0VBkZGX49XwAA0DoFB/LJt2/frgceeED33XefJOm6667Tb3/7W+3evVvSd1eHlixZotmzZ2vUqFGSpFWrVik6Olpr167VpEmT5PF49MYbb+jtt9/WsGHDJEmrV6+Wy+XS+++/rxEjRqisrEwFBQXasWOHEhMTJUkrVqxQUlKSDh48qN69ewfg7AEAQGsR0CtEd9xxhzZv3qxDhw5Jkj755BNt3bpV9957rySpvLxcbrdbKSkp1vfY7XYNGTJE27ZtkySVlJSooaHBZyYmJkbx8fHWzPbt2+VwOKwYkqRBgwbJ4XBYMxeqq6tTTU2NzxcAALg6BfQK0cyZM+XxeNSnTx8FBQXp7Nmzmj9/vn72s59JktxutyQpOjra5/uio6P15ZdfWjMhISHq3Llzo5nz3+92uxUVFdXo+aOioqyZC+Xm5ur555+/shMEAABtQkCvEK1fv16rV6/W2rVr9fHHH2vVqlV6+eWXtWrVKp85m83mc9vr9TY6dqELZy42f6nHmTVrljwej/V15MiRyz0tAADQxgT0CtHTTz+tZ555Ro888ogkqV+/fvryyy+Vm5urRx99VE6nU9J3V3i6detmfV9lZaV11cjpdKq+vl5VVVU+V4kqKys1ePBga+bYsWONnv/48eONrj6dZ7fbZbfbm+dEAQBAqxbQK0SnT5/WNdf4rhAUFGR97D42NlZOp1OFhYXW/fX19SoqKrJiJyEhQe3atfOZqaio0L59+6yZpKQkeTweFRcXWzM7d+6Ux+OxZgAAgLkCeoXo/vvv1/z589WjRw/deOON2rNnjxYtWqRx48ZJ+u5lruzsbOXk5CguLk5xcXHKyclRhw4dlJ6eLklyOBwaP368pk2bpi5duigiIkLTp09Xv379rE+d9e3bVyNHjtSECRO0bNkySdLEiROVmprKJ8wAAEBgg2jp0qWaM2eOMjMzVVlZqZiYGE2aNEnPPvusNTNjxgzV1tYqMzNTVVVVSkxM1KZNmxQWFmbNLF68WMHBwRo9erRqa2s1dOhQrVy5UkFBQdbMmjVrlJWVZX0aLS0tTXl5ef47WQAA0GrZvF6vN9BLtAU1NTVyOBzyeDwKDw9vsefZPfC2FntsoK0auLv4h4fagJFz1gd6BaDVKXjh4RZ9/Mv995vfZQYAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4wU8iP7973/r5z//ubp06aIOHTro5ptvVklJiXW/1+vV3LlzFRMTo9DQUCUnJ2v//v0+j1FXV6cpU6aoa9eu6tixo9LS0nT06FGfmaqqKmVkZMjhcMjhcCgjI0PV1dX+OEUAANDKBTSIqqqqdPvtt6tdu3b629/+ps8++0yvvPKKrr32Wmtm4cKFWrRokfLy8rRr1y45nU4NHz5cJ06csGays7O1YcMGrVu3Tlu3btXJkyeVmpqqs2fPWjPp6ekqLS1VQUGBCgoKVFpaqoyMDH+eLgAAaKWCA/nkCxYskMvl0ltvvWUdu+6666w/e71eLVmyRLNnz9aoUaMkSatWrVJ0dLTWrl2rSZMmyePx6I033tDbb7+tYcOGSZJWr14tl8ul999/XyNGjFBZWZkKCgq0Y8cOJSYmSpJWrFihpKQkHTx4UL179/bfSQMAgFYnoFeINm7cqIEDB+qhhx5SVFSUBgwYoBUrVlj3l5eXy+12KyUlxTpmt9s1ZMgQbdu2TZJUUlKihoYGn5mYmBjFx8dbM9u3b5fD4bBiSJIGDRokh8NhzVyorq5ONTU1Pl8AAODqFNAg+uKLL5Sfn6+4u
...<remainder of base64 PNG omitted: countplot of the label distribution in err.train.csv>",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "df = describe_dataset(\"./err.train.csv\")\n",
    "sns.countplot(x='label', data=df, palette=\"Set1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2. Test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:xlabel='count', ylabel='label'>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "<base64 PNG omitted: horizontal countplot of the label distribution in the test set>",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "sns.countplot(y='label', data=test_df, palette=\"Set1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 (conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "9260f401923fb5c4108c543a7d176de9733d378b3752e49535ad7c43c2271b65"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
core/lunge_model/6.err.sklearn.ipynb
ADDED
@@ -0,0 +1,777 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "objc[58344]: Class CaptureDelegate is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_videoio.3.4.16.dylib (0x10ae08860) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x15eece480). One of the two will be used. Which one is undefined.\n",
      "objc[58344]: Class CVWindow is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x105baca68) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x15eece4d0). One of the two will be used. Which one is undefined.\n",
      "objc[58344]: Class CVView is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x105baca90) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x15eece4f8). One of the two will be used. Which one is undefined.\n",
      "objc[58344]: Class CVSlider is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x105bacab8) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x15eece520). One of the two will be used. Which one is undefined.\n"
     ]
    }
   ],
   "source": [
    "import mediapipe as mp\n",
    "import cv2\n",
    "import pandas as pd\n",
    "import pickle\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import LogisticRegression, SGDClassifier\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.naive_bayes import GaussianNB\n",
    "from sklearn.metrics import precision_score, accuracy_score, f1_score, recall_score, confusion_matrix\n",
    "from sklearn.calibration import CalibratedClassifierCV\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Drawing helpers\n",
    "mp_drawing = mp.solutions.drawing_utils\n",
    "mp_pose = mp.solutions.pose"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Set up important functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rescale_frame(frame, percent=50):\n",
    "    '''\n",
    "    Rescale a frame to a certain percentage of its original size\n",
    "    '''\n",
    "    width = int(frame.shape[1] * percent / 100)\n",
    "    height = int(frame.shape[0] * percent / 100)\n",
    "    dim = (width, height)\n",
    "    return cv2.resize(frame, dim, interpolation=cv2.INTER_AREA)\n",
    "\n",
    "\n",
    "def describe_dataset(dataset_path: str):\n",
    "    '''\n",
    "    Describe a dataset: headers, shape, label counts, missing values and duplicates\n",
    "    '''\n",
    "    data = pd.read_csv(dataset_path)\n",
    "    print(f\"Headers: {list(data.columns.values)}\")\n",
    "    print(f'Number of rows: {data.shape[0]} \\nNumber of columns: {data.shape[1]}\\n')\n",
    "    print(f\"Labels: \\n{data['label'].value_counts()}\\n\")\n",
    "    print(f\"Missing values: {data.isnull().values.any()}\\n\")\n",
    "\n",
    "    duplicate = data[data.duplicated()]\n",
    "    print(f\"Duplicate Rows : {len(duplicate)}\")\n",
    "\n",
    "    return data\n",
    "\n",
    "\n",
    "def round_up_metric_results(results) -> list:\n",
    "    '''Round metric results (precision score, recall score, ...) to 3 decimal places'''\n",
    "    return list(map(lambda el: round(el, 3), results))"
   ]
  },
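  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick usage sketch of `round_up_metric_results`: it rounds every metric in a list to 3 decimal places, which keeps the evaluation tables below readable. The input values here are made up purely for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative values only; any list of floats works\n",
    "example_scores = [0.98765432, 1.0, 0.9512345]\n",
    "round_up_metric_results(example_scores)  # -> [0.988, 1.0, 0.951]"
   ]
  },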
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Describe and process data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAIN_SET_PATH = \"./err.train.csv\"\n",
    "TEST_SET_PATH = \"./err.test.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v', 'left_knee_x', 'left_knee_y', 'left_knee_z', 'left_knee_v', 'right_knee_x', 'right_knee_y', 'right_knee_z', 'right_knee_v', 'left_ankle_x', 'left_ankle_y', 'left_ankle_z', 'left_ankle_v', 'right_ankle_x', 'right_ankle_y', 'right_ankle_z', 'right_ankle_v', 'left_heel_x', 'left_heel_y', 'left_heel_z', 'left_heel_v', 'right_heel_x', 'right_heel_y', 'right_heel_z', 'right_heel_v', 'left_foot_index_x', 'left_foot_index_y', 'left_foot_index_z', 'left_foot_index_v', 'right_foot_index_x', 'right_foot_index_y', 'right_foot_index_z', 'right_foot_index_v']\n",
      "Number of rows: 17907 \n",
      "Number of columns: 53\n",
      "\n",
      "Labels: \n",
      "L    9114\n",
      "C    8793\n",
      "Name: label, dtype: int64\n",
      "\n",
      "Missing values: False\n",
      "\n",
      "Duplicate Rows : 0\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<HTML table omitted; identical to the text/plain rendering below>"
      ],
      "text/plain": [
       "       label    nose_x    nose_y    nose_z    nose_v  left_shoulder_x  \\\n",
       "17904      1  0.647438  0.442268  0.004114  0.999985         0.615798   \n",
       "17905      1  0.649652  0.419057  0.008783  0.999983         0.617577   \n",
       "17906      1  0.653556  0.400394  0.014852  0.999980         0.620734   \n",
       "\n",
       "       left_shoulder_y  left_shoulder_z  left_shoulder_v  right_shoulder_x  \\\n",
       "17904         0.517170         0.151706         0.999579          0.631354   \n",
       "17905         0.503514         0.158545         0.999529          0.631972   \n",
       "17906         0.486522         0.169807         0.999556          0.631171   \n",
       "\n",
       "       ...  right_heel_z  right_heel_v  left_foot_index_x  left_foot_index_y  \\\n",
       "17904  ...     -0.034228      0.979719           0.701826           0.880516   \n",
       "17905  ...     -0.061176      0.980431           0.704606           0.880248   \n",
       "17906  ...     -0.138678      0.979078           0.705475           0.878981   \n",
       "\n",
       "       left_foot_index_z  left_foot_index_v  right_foot_index_x  \\\n",
       "17904           0.134222           0.979319            0.504880   \n",
       "17905           0.071476           0.979932            0.504513   \n",
       "17906           0.003690           0.979199            0.504067   \n",
       "\n",
       "       right_foot_index_y  right_foot_index_z  right_foot_index_v  \n",
       "17904            0.881748           -0.027911            0.986165  \n",
       "17905            0.881766           -0.088832            0.986975  \n",
       "17906            0.882642           -0.183304            0.986824  \n",
       "\n",
       "[3 rows x 53 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = describe_dataset(TRAIN_SET_PATH)\n",
    "# Categorizing label\n",
    "df.loc[df[\"label\"] == \"L\", \"label\"] = 0\n",
    "df.loc[df[\"label\"] == \"C\", \"label\"] = 1\n",
    "\n",
    "df.tail(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"./model/input_scaler.pkl\", \"rb\") as f:\n",
    "    sc = pickle.load(f)"
   ]
  },
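  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scaler is loaded from disk rather than refit, so these features are scaled exactly the way they were when `input_scaler.pkl` was created. A minimal sketch of how such a scaler could be produced, assuming it is a scikit-learn `StandardScaler` fitted on the training features (the actual object type is whatever was pickled originally, and the output path below is a hypothetical demo path):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch: StandardScaler is an assumption, and the path\n",
    "# below is a demo path so the real input_scaler.pkl is not overwritten.\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "demo_scaler = StandardScaler()\n",
    "demo_scaler.fit(df.drop(\"label\", axis=1))\n",
    "\n",
    "with open(\"./model/input_scaler_demo.pkl\", \"wb\") as f:\n",
    "    pickle.dump(demo_scaler, f)"
   ]
  },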
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract features and class\n",
    "X = df.drop(\"label\", axis=1)\n",
    "y = df[\"label\"].astype(\"int\")\n",
    "\n",
    "X = pd.DataFrame(sc.transform(X))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10827    0\n",
       "11395    0\n",
       "3742     1\n",
       "Name: label, dtype: int64"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234)\n",
    "y_test.head(3)"
   ]
  },
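  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split above is purely random. The two classes are close to balanced here (9,114 'L' vs 8,793 'C' rows), so this works, but passing `stratify=y` is a cheap guarantee that both folds keep the same class ratio. A minimal sketch of the alternative; the `_s` names exist only for this demo:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same split, but stratified on the label so both folds keep the class ratio\n",
    "X_train_s, X_test_s, y_train_s, y_test_s = train_test_split(\n",
    "    X, y, test_size=0.2, random_state=1234, stratify=y\n",
    ")\n",
    "y_test_s.value_counts(normalize=True)  # proportions mirror the full y"
   ]
  },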
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Train & Evaluate Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1. Train and evaluate model with train set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<HTML table omitted; identical to the text/plain rendering below>"
      ],
      "text/plain": [
       "  Model Precision Score  Accuracy score    Recall Score        F1 score  \\\n",
       "0   SVC    [1.0, 0.999]        0.999721    [0.999, 1.0]      [1.0, 1.0]   \n",
       "1   KNN    [1.0, 0.998]        0.999162    [0.998, 1.0]  [0.999, 0.999]   \n",
       "2    RF  [0.999, 0.999]        0.999162  [0.999, 0.999]  [0.999, 0.999]   \n",
       "3   DTC  [0.997, 0.997]        0.997208  [0.997, 0.997]  [0.997, 0.997]   \n",
       "4    LR  [0.992, 0.987]        0.989391  [0.986, 0.993]   [0.989, 0.99]   \n",
       "5  SGDC  [0.992, 0.988]        0.989950  [0.987, 0.993]   [0.989, 0.99]   \n",
       "6    NB  [0.963, 0.952]        0.957286  [0.947, 0.967]  [0.955, 0.959]   \n",
       "\n",
       "           Confusion Matrix  \n",
       "0    [[1713, 1], [0, 1868]]  \n",
       "1    [[1711, 3], [0, 1868]]  \n",
       "2    [[1712, 2], [1, 1867]]  \n",
       "3    [[1709, 5], [5, 1863]]  \n",
       "4  [[1690, 24], [14, 1854]]  \n",
       "5  [[1692, 22], [14, 1854]]  \n",
       "6  [[1623, 91], [62, 1806]]  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "algorithms = [(\"LR\", LogisticRegression()),\n",
    "              (\"SVC\", SVC(probability=True)),\n",
    "              (\"KNN\", KNeighborsClassifier()),\n",
    "              (\"DTC\", DecisionTreeClassifier()),\n",
    "              (\"SGDC\", CalibratedClassifierCV(SGDClassifier())),\n",
    "              (\"NB\", GaussianNB()),\n",
    "              (\"RF\", RandomForestClassifier())]\n",
    "\n",
    "models = {}\n",
    "final_results = []\n",
    "\n",
    "for name, model in algorithms:\n",
    "    trained_model = model.fit(X_train, y_train)\n",
    "    models[name] = trained_model\n",
    "\n",
    "    # Evaluate model\n",
    "    model_results = model.predict(X_test)\n",
    "\n",
    "    p_score = precision_score(y_test, model_results, average=None, labels=[1, 0])\n",
    "    a_score = accuracy_score(y_test, model_results)\n",
    "    r_score = recall_score(y_test, model_results, average=None, labels=[1, 0])\n",
    "    f1_score_result = f1_score(y_test, model_results, average=None, labels=[1, 0])\n",
    "    cm = confusion_matrix(y_test, model_results, labels=[1, 0])\n",
    "    final_results.append((name, round_up_metric_results(p_score), a_score, round_up_metric_results(r_score), round_up_metric_results(f1_score_result), cm))\n",
    "\n",
    "# Sort results by summed per-class F1 score\n",
    "final_results.sort(key=lambda k: sum(k[4]), reverse=True)\n",
    "pd.DataFrame(final_results, columns=[\"Model\", \"Precision Score\", \"Accuracy score\", \"Recall Score\", \"F1 score\", \"Confusion Matrix\"])"
   ]
  },
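  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-class numbers above can be cross-checked with `sklearn.metrics.classification_report`, which computes precision, recall and F1 in one call. A minimal sketch for one of the trained models; label 1 is 'C' and label 0 is 'L', following the encoding used earlier:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "\n",
    "# Cross-check the hand-rolled metrics for a single model\n",
    "print(classification_report(y_test, models[\"LR\"].predict(X_test),\n",
    "                            labels=[1, 0], target_names=[\"C\", \"L\"]))"
   ]
  },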
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2. Test set evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v', 'left_knee_x', 'left_knee_y', 'left_knee_z', 'left_knee_v', 'right_knee_x', 'right_knee_y', 'right_knee_z', 'right_knee_v', 'left_ankle_x', 'left_ankle_y', 'left_ankle_z', 'left_ankle_v', 'right_ankle_x', 'right_ankle_y', 'right_ankle_z', 'right_ankle_v', 'left_heel_x', 'left_heel_y', 'left_heel_z', 'left_heel_v', 'right_heel_x', 'right_heel_y', 'right_heel_z', 'right_heel_v', 'left_foot_index_x', 'left_foot_index_y', 'left_foot_index_z', 'left_foot_index_v', 'right_foot_index_x', 'right_foot_index_y', 'right_foot_index_z', 'right_foot_index_v']\n",
      "Number of rows: 1107 \n",
      "Number of columns: 53\n",
      "\n",
      "Labels: \n",
      "L    561\n",
      "C    546\n",
      "Name: label, dtype: int64\n",
      "\n",
      "Missing values: False\n",
      "\n",
      "Duplicate Rows : 0\n"
     ]
    }
   ],
   "source": [
    "test_df = describe_dataset(TEST_SET_PATH)\n",
    "test_df = test_df.sample(frac=1).reset_index(drop=True)\n",
    "\n",
    "# Categorizing label\n",
    "test_df.loc[test_df[\"label\"] == \"L\", \"label\"] = 0\n",
    "test_df.loc[test_df[\"label\"] == \"C\", \"label\"] = 1\n",
    "\n",
    "test_x = test_df.drop(\"label\", axis=1)\n",
    "test_y = test_df[\"label\"].astype(\"int\")\n",
    "\n",
    "test_x = pd.DataFrame(sc.transform(test_x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<HTML table omitted; identical to the text/plain rendering below>"
      ],
      "text/plain": [
       "  Model Precision Score  Accuracy score    Recall Score        F1 score  \\\n",
       "0    LR  [0.948, 0.998]        0.971996  [0.998, 0.947]  [0.972, 0.972]   \n",
       "1  SGDC  [0.922, 0.998]        0.957543  [0.998, 0.918]  [0.959, 0.956]   \n",
       "2   DTC   [0.95, 0.889]        0.916893  [0.877, 0.955]  [0.912, 0.921]   \n",
       "3    RF  [0.786, 0.921]        0.841915  [0.934, 0.752]  [0.854, 0.828]   \n",
       "4    NB   [0.79, 0.751]        0.768744  [0.723, 0.813]  [0.755, 0.781]   \n",
       "5   KNN  [0.737, 0.799]        0.765131  [0.815, 0.717]  [0.774, 0.756]   \n",
       "6   SVC  [0.659, 0.842]        0.719964  [0.894, 0.551]  [0.759, 0.666]   \n",
       "\n",
       "           Confusion Matrix  \n",
       "0     [[545, 1], [30, 531]]  \n",
       "1     [[545, 1], [46, 515]]  \n",
       "2    [[479, 67], [25, 536]]  \n",
       "3   [[510, 36], [139, 422]]  \n",
       "4  [[395, 151], [105, 456]]  \n",
       "5  [[445, 101], [159, 402]]  \n",
       "6   [[488, 58], [252, 309]]  "
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "testset_final_results = []\n",
    "\n",
    "for name, model in models.items():\n",
    "    # Evaluate model\n",
    "    model_results = model.predict(test_x)\n",
    "\n",
    "    p_score = precision_score(test_y, model_results, average=None, labels=[1, 0])\n",
    "    a_score = accuracy_score(test_y, model_results)\n",
    "    r_score = recall_score(test_y, model_results, average=None, labels=[1, 0])\n",
    "    f1_score_result = f1_score(test_y, model_results, average=None, labels=[1, 0])\n",
    "    cm = confusion_matrix(test_y, model_results, labels=[1, 0])\n",
    "    testset_final_results.append((name, round_up_metric_results(p_score), a_score, round_up_metric_results(r_score), round_up_metric_results(f1_score_result), cm))\n",
    "\n",
    "# Sort results by summed per-class F1 score\n",
    "testset_final_results.sort(key=lambda k: sum(k[4]), reverse=True)\n",
    "pd.DataFrame(testset_final_results, columns=[\"Model\", \"Precision Score\", \"Accuracy score\", \"Recall Score\", \"F1 score\", \"Confusion Matrix\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Dump Models\n",
    "\n",
    "According to the evaluation above, LR and SGDC generalize best from the train set to the separate test set, so they would be chosen for further evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"./model/sklearn/err_all_sklearn.pkl\", \"wb\") as f:\n",
    "    pickle.dump(models, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"./model/sklearn/err_SGDC_model.pkl\", \"wb\") as f:\n",
    "    pickle.dump(models[\"SGDC\"], f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"./model/sklearn/err_LR_model.pkl\", \"wb\") as f:\n",
    "    pickle.dump(models[\"LR\"], f)"
   ]
  },
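  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For completeness, a dumped model is loaded back the same way the scaler was loaded above. The only requirement at inference time is that features pass through the same `input_scaler` before `predict`. A minimal sketch using the already-scaled test features:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the dumped LR model and sanity-check it on a few scaled rows\n",
    "with open(\"./model/sklearn/err_LR_model.pkl\", \"rb\") as f:\n",
    "    reloaded_lr = pickle.load(f)\n",
    "\n",
    "reloaded_lr.predict(test_x[:5])"
   ]
  },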
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 (conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "9260f401923fb5c4108c543a7d176de9733d378b3752e49535ad7c43c2271b65"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
core/lunge_model/7.err.deep_learning.ipynb
ADDED
@@ -0,0 +1,1366 @@
```python
# Data visualization
import numpy as np
import pandas as pd

# Keras
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.optimizers import Adam
from keras.utils.np_utils import to_categorical
from keras.callbacks import EarlyStopping
import keras_tuner as kt

# Train-Test
from sklearn.model_selection import train_test_split
# Classification Report
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

import pickle

import warnings
warnings.filterwarnings('ignore')
```

## 1. Set up important landmarks and functions

```python
# Determine important landmarks for lunge
IMPORTANT_LMS = [
    "NOSE",
    "LEFT_SHOULDER",
    "RIGHT_SHOULDER",
    "LEFT_HIP",
    "RIGHT_HIP",
    "LEFT_KNEE",
    "RIGHT_KNEE",
    "LEFT_ANKLE",
    "RIGHT_ANKLE",
    "LEFT_HEEL",
    "RIGHT_HEEL",
    "LEFT_FOOT_INDEX",
    "RIGHT_FOOT_INDEX",
]

# Generate all columns of the data frame: one label column,
# then x/y/z/visibility for each important landmark
HEADERS = ["label"]  # Label column

for lm in IMPORTANT_LMS:
    HEADERS += [f"{lm.lower()}_x", f"{lm.lower()}_y", f"{lm.lower()}_z", f"{lm.lower()}_v"]

TRAIN_SET_PATH = "./err.train.csv"
TEST_SET_PATH = "./err.test.csv"
```

```python
def describe_dataset(dataset_path: str):
    '''
    Describe the dataset: headers, shape, label distribution, missing values and duplicates
    '''
    data = pd.read_csv(dataset_path)
    print(f"Headers: {list(data.columns.values)}")
    print(f'Number of rows: {data.shape[0]} \nNumber of columns: {data.shape[1]}\n')
    print(f"Labels: \n{data['label'].value_counts()}\n")
    print(f"Missing values: {data.isnull().values.any()}\n")

    duplicate = data[data.duplicated()]
    print(f"Duplicate Rows : {len(duplicate)}")

    return data


# Remove duplicate rows (optional)
def remove_duplicate_rows(dataset_path: str):
    '''
    Remove duplicated rows from the dataset, then save the result to another file
    '''
    df = pd.read_csv(dataset_path)
    df.drop_duplicates(keep="first", inplace=True)
    df.to_csv("cleaned_dataset.csv", sep=',', encoding='utf-8', index=False)


def round_up_metric_results(results) -> list:
    '''Round metric results such as precision score and recall score to 3 decimal places'''
    return list(map(lambda el: round(el, 3), results))
```

## 2. Describe and process data

```python
# Load dataset
df = describe_dataset(TRAIN_SET_PATH)

# Categorizing label
df.loc[df["label"] == "L", "label"] = 0
df.loc[df["label"] == "C", "label"] = 1
```

Output:
```
Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v', 'left_knee_x', 'left_knee_y', 'left_knee_z', 'left_knee_v', 'right_knee_x', 'right_knee_y', 'right_knee_z', 'right_knee_v', 'left_ankle_x', 'left_ankle_y', 'left_ankle_z', 'left_ankle_v', 'right_ankle_x', 'right_ankle_y', 'right_ankle_z', 'right_ankle_v', 'left_heel_x', 'left_heel_y', 'left_heel_z', 'left_heel_v', 'right_heel_x', 'right_heel_y', 'right_heel_z', 'right_heel_v', 'left_foot_index_x', 'left_foot_index_y', 'left_foot_index_z', 'left_foot_index_v', 'right_foot_index_x', 'right_foot_index_y', 'right_foot_index_z', 'right_foot_index_v']
Number of rows: 17907 
Number of columns: 53

Labels: 
L    9114
C    8793
Name: label, dtype: int64

Missing values: False

Duplicate Rows : 0
```

```python
# Standard scaling of features, reusing the scaler fitted for this dataset
with open("./model/input_scaler.pkl", "rb") as f2:
    input_scaler = pickle.load(f2)

x = df.drop("label", axis=1)
x = pd.DataFrame(input_scaler.transform(x))

y = df["label"]

# Converting labels to categorical (one-hot) vectors
y_cat = to_categorical(y)
```
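As a quick illustration (not a cell from the original notebook): `to_categorical` turns the 0/1 labels produced above into one-hot rows, which is why every model below ends in `Dense(2, activation='softmax')`.

```python
from keras.utils.np_utils import to_categorical

# Labels 0 ("L") and 1 ("C") become one-hot rows of length 2
print(to_categorical([0, 1, 1]))
# [[1. 0.]
#  [0. 1.]
#  [0. 1.]]
```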
```python
# Hold out 20% of the train CSV as a validation split
x_train, x_test, y_train, y_test = train_test_split(x.values, y_cat, test_size=0.2)
```

## 3. Train model

### 3.1. Set up

```python
# Stop training once the loss has not improved for 3 consecutive epochs
stop_early = EarlyStopping(monitor='loss', patience=3)

# Final Results
final_models = {}
```

```python
def describe_model(model):
    '''
    Describe model architecture: units and activation function of every layer
    '''
    print("Describe models architecture")
    for i, layer in enumerate(model.layers):
        number_of_units = layer.units if hasattr(layer, 'units') else 0

        if hasattr(layer, "activation"):
            print(f"Layer-{i + 1}: {number_of_units} units, func: ", layer.activation)
        else:
            print(f"Layer-{i + 1}: {number_of_units} units, func: None")


def get_best_model(tuner):
    '''
    Describe and return the best model found by the keras tuner search
    '''
    best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
    best_model = tuner.hypermodel.build(best_hps)

    describe_model(best_model)

    for h_param in ["learning_rate"]:
        print(f"{h_param}: {tuner.get_best_hyperparameters()[0].get(h_param)}")

    return best_model
```

### 3.2. Model with 3 layers

```python
def model_builder(hp):
    model = Sequential()
    model.add(Dense(52, input_dim=52, activation="relu"))

    hp_activation = hp.Choice('activation', values=['relu', 'tanh'])
    hp_layer_1 = hp.Int('layer_1', min_value=32, max_value=512, step=32)
    hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

    model.add(Dense(units=hp_layer_1, activation=hp_activation))
    model.add(Dense(2, activation='softmax'))

    model.compile(optimizer=Adam(learning_rate=hp_learning_rate), loss="categorical_crossentropy", metrics=["accuracy"])

    return model
```

```python
tuner = kt.Hyperband(
    model_builder,
    objective='accuracy',
    max_epochs=10,
    directory='keras_tuner_dir',
    project_name='keras_tuner_demo'
)
```

Output:
```
INFO:tensorflow:Reloading Oracle from existing project keras_tuner_dir/keras_tuner_demo/oracle.json
```

```python
tuner.search(x_train, y_train, epochs=10, callbacks=[stop_early])
```

Output:
```
Trial 30 Complete [00h 00m 38s]
accuracy: 0.9995812177658081

Best accuracy So Far: 1.0000001192092896
Total elapsed time: 00h 08m 17s
INFO:tensorflow:Oracle triggered exit
```
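Not part of the original cells: once a search like the one above has finished, keras-tuner can also print a compact ranking of the trials it ran. A minimal sketch, assuming the `tuner` object from the cell above:

```python
# Summarize the top trials of the finished Hyperband search (standard keras-tuner API)
tuner.results_summary(num_trials=3)
```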
```python
model = get_best_model(tuner)
model.fit(x_train, y_train, epochs=100, batch_size=10, validation_data=(x_test, y_test), callbacks=[stop_early])
```

Output:
```
Describe models architecture
Layer-1: 52 units, func: <function relu at 0x155f86a60>
Layer-2: 192 units, func: <function relu at 0x155f86a60>
Layer-3: 2 units, func: <function softmax at 0x155f86040>
learning_rate: 0.001
Epoch 1/100
1433/1433 [==============================] - 14s 9ms/step - loss: 0.0322 - accuracy: 0.9883 - val_loss: 0.0022 - val_accuracy: 0.9992
Epoch 2/100
1433/1433 [==============================] - 14s 10ms/step - loss: 0.0048 - accuracy: 0.9984 - val_loss: 0.0103 - val_accuracy: 0.9964
Epoch 3/100
1433/1433 [==============================] - 14s 10ms/step - loss: 0.0044 - accuracy: 0.9986 - val_loss: 0.0018 - val_accuracy: 0.9994
Epoch 4/100
1433/1433 [==============================] - 13s 9ms/step - loss: 0.0024 - accuracy: 0.9992 - val_loss: 8.9034e-04 - val_accuracy: 0.9997
Epoch 5/100
1433/1433 [==============================] - 13s 9ms/step - loss: 0.0026 - accuracy: 0.9991 - val_loss: 6.6072e-04 - val_accuracy: 0.9997
Epoch 6/100
1433/1433 [==============================] - 13s 9ms/step - loss: 0.0016 - accuracy: 0.9995 - val_loss: 0.0011 - val_accuracy: 0.9997
Epoch 7/100
1433/1433 [==============================] - 13s 9ms/step - loss: 0.0029 - accuracy: 0.9994 - val_loss: 0.0013 - val_accuracy: 0.9997
Epoch 8/100
1433/1433 [==============================] - 13s 9ms/step - loss: 0.0012 - accuracy: 0.9996 - val_loss: 6.2126e-04 - val_accuracy: 0.9997
Epoch 9/100
1433/1433 [==============================] - 13s 9ms/step - loss: 0.0014 - accuracy: 0.9995 - val_loss: 3.3005e-04 - val_accuracy: 0.9997
Epoch 10/100
1433/1433 [==============================] - 13s 9ms/step - loss: 0.0020 - accuracy: 0.9993 - val_loss: 3.4855e-04 - val_accuracy: 0.9997
Epoch 11/100
1433/1433 [==============================] - 13s 9ms/step - loss: 4.9058e-04 - accuracy: 0.9999 - val_loss: 7.0838e-04 - val_accuracy: 0.9997
Epoch 12/100
1433/1433 [==============================] - 14s 10ms/step - loss: 3.1028e-05 - accuracy: 1.0000 - val_loss: 9.7843e-05 - val_accuracy: 1.0000
Epoch 13/100
1433/1433 [==============================] - 14s 10ms/step - loss: 0.0023 - accuracy: 0.9994 - val_loss: 0.0037 - val_accuracy: 0.9986
Epoch 14/100
1433/1433 [==============================] - 14s 10ms/step - loss: 6.7057e-04 - accuracy: 0.9998 - val_loss: 4.5239e-04 - val_accuracy: 0.9997
Epoch 15/100
1433/1433 [==============================] - 13s 9ms/step - loss: 3.3223e-05 - accuracy: 1.0000 - val_loss: 3.2696e-04 - val_accuracy: 0.9997
Epoch 16/100
1433/1433 [==============================] - 13s 9ms/step - loss: 0.0035 - accuracy: 0.9992 - val_loss: 1.6425e-04 - val_accuracy: 1.0000
Epoch 17/100
1433/1433 [==============================] - 14s 10ms/step - loss: 0.0015 - accuracy: 0.9995 - val_loss: 0.0011 - val_accuracy: 0.9997
<keras.callbacks.History at 0x159960100>
```

```python
final_models["3_layers"] = model
```

### 3.3. Model with 5 layers

```python
def model_builder_5(hp):
    model = Sequential()
    model.add(Dense(52, input_dim=52, activation="relu"))

    hp_activation = hp.Choice('activation', values=['relu', 'tanh'])
    hp_layer_1 = hp.Int('layer_1', min_value=32, max_value=512, step=32)
    hp_layer_2 = hp.Int('layer_2', min_value=32, max_value=512, step=32)
    hp_layer_3 = hp.Int('layer_3', min_value=32, max_value=512, step=32)
    hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

    model.add(Dense(units=hp_layer_1, activation=hp_activation))
    model.add(Dense(units=hp_layer_2, activation=hp_activation))
    model.add(Dense(units=hp_layer_3, activation=hp_activation))
    model.add(Dense(2, activation='softmax'))

    model.compile(optimizer=Adam(learning_rate=hp_learning_rate), loss="categorical_crossentropy", metrics=["accuracy"])

    return model
```

```python
tuner = kt.Hyperband(
    model_builder_5,
    objective='accuracy',
    max_epochs=10,
    directory='keras_tuner_dir',
    project_name='keras_tuner_demo_2'
)
tuner.search(x_train, y_train, epochs=10, callbacks=[stop_early])
```

Output:
```
Trial 30 Complete [00h 00m 44s]
accuracy: 0.9998604655265808

Best accuracy So Far: 1.0000001192092896
Total elapsed time: 00h 08m 57s
INFO:tensorflow:Oracle triggered exit
```

```python
model_5 = get_best_model(tuner)
model_5.fit(x_train, y_train, epochs=100, batch_size=10, validation_data=(x_test, y_test), callbacks=[stop_early])
```

Output:
```
Describe models architecture
Layer-1: 52 units, func: <function relu at 0x155f86a60>
Layer-2: 480 units, func: <function tanh at 0x155f86dc0>
Layer-3: 480 units, func: <function tanh at 0x155f86dc0>
Layer-4: 192 units, func: <function tanh at 0x155f86dc0>
Layer-5: 2 units, func: <function softmax at 0x155f86040>
learning_rate: 0.0001
Epoch 1/100
1433/1433 [==============================] - 16s 11ms/step - loss: 0.0428 - accuracy: 0.9872 - val_loss: 0.0248 - val_accuracy: 0.9913
Epoch 2/100
1433/1433 [==============================] - 16s 11ms/step - loss: 0.0055 - accuracy: 0.9986 - val_loss: 0.0016 - val_accuracy: 0.9992
Epoch 3/100
1433/1433 [==============================] - 16s 11ms/step - loss: 0.0052 - accuracy: 0.9988 - val_loss: 0.0036 - val_accuracy: 0.9992
Epoch 4/100
1433/1433 [==============================] - 16s 11ms/step - loss: 0.0024 - accuracy: 0.9994 - val_loss: 0.0017 - val_accuracy: 0.9997
Epoch 5/100
1433/1433 [==============================] - 16s 11ms/step - loss: 0.0048 - accuracy: 0.9982 - val_loss: 0.0019 - val_accuracy: 0.9994
Epoch 6/100
1433/1433 [==============================] - 16s 11ms/step - loss: 0.0018 - accuracy: 0.9994 - val_loss: 8.5989e-04 - val_accuracy: 0.9997
Epoch 7/100
1433/1433 [==============================] - 16s 11ms/step - loss: 0.0023 - accuracy: 0.9994 - val_loss: 4.7530e-04 - val_accuracy: 0.9997
Epoch 8/100
1433/1433 [==============================] - 16s 11ms/step - loss: 0.0022 - accuracy: 0.9991 - val_loss: 3.3401e-04 - val_accuracy: 0.9997
Epoch 9/100
1433/1433 [==============================] - 16s 11ms/step - loss: 0.0018 - accuracy: 0.9993 - val_loss: 3.5316e-04 - val_accuracy: 0.9997
Epoch 10/100
1433/1433 [==============================] - 18s 13ms/step - loss: 0.0010 - accuracy: 0.9997 - val_loss: 5.9313e-04 - val_accuracy: 0.9997
Epoch 11/100
1433/1433 [==============================] - 16s 11ms/step - loss: 0.0017 - accuracy: 0.9995 - val_loss: 5.0026e-04 - val_accuracy: 0.9997
Epoch 12/100
1433/1433 [==============================] - 17s 12ms/step - loss: 0.0016 - accuracy: 0.9994 - val_loss: 2.7773e-04 - val_accuracy: 0.9997
Epoch 13/100
1433/1433 [==============================] - 16s 11ms/step - loss: 3.9777e-04 - accuracy: 0.9999 - val_loss: 2.6341e-04 - val_accuracy: 0.9997
Epoch 14/100
1433/1433 [==============================] - 18s 12ms/step - loss: 0.0021 - accuracy: 0.9997 - val_loss: 3.1135e-04 - val_accuracy: 1.0000
Epoch 15/100
1433/1433 [==============================] - 17s 12ms/step - loss: 2.4923e-04 - accuracy: 0.9999 - val_loss: 3.7806e-04 - val_accuracy: 0.9997
Epoch 16/100
1433/1433 [==============================] - 16s 11ms/step - loss: 4.5137e-04 - accuracy: 0.9999 - val_loss: 1.0070e-04 - val_accuracy: 1.0000
Epoch 17/100
1433/1433 [==============================] - 17s 12ms/step - loss: 8.3379e-06 - accuracy: 1.0000 - val_loss: 6.9653e-04 - val_accuracy: 0.9997
Epoch 18/100
1433/1433 [==============================] - 17s 12ms/step - loss: 0.0018 - accuracy: 0.9994 - val_loss: 1.4246e-04 - val_accuracy: 1.0000
Epoch 19/100
1433/1433 [==============================] - 16s 11ms/step - loss: 2.8542e-05 - accuracy: 1.0000 - val_loss: 1.0260e-04 - val_accuracy: 1.0000
Epoch 20/100
1433/1433 [==============================] - 17s 12ms/step - loss: 0.0016 - accuracy: 0.9994 - val_loss: 2.7252e-04 - val_accuracy: 1.0000
Epoch 21/100
1433/1433 [==============================] - 16s 11ms/step - loss: 3.2176e-05 - accuracy: 1.0000 - val_loss: 1.3285e-04 - val_accuracy: 1.0000
Epoch 22/100
1433/1433 [==============================] - 18s 12ms/step - loss: 1.2290e-05 - accuracy: 1.0000 - val_loss: 1.1766e-04 - val_accuracy: 1.0000
<keras.callbacks.History at 0x2f26db9d0>
```

```python
final_models["5_layers"] = model_5
```

### 3.4. Model with 7 layers (including Dropout layers)

```python
def model_builder_dropout_5(hp):
    # 5 Dense layers plus 2 Dropout layers, 7 layers in total
    model = Sequential()
    model.add(Dense(52, input_dim=52, activation="relu"))

    hp_activation = hp.Choice('activation', values=['relu', 'tanh'])
    hp_layer_1 = hp.Int('layer_1', min_value=32, max_value=512, step=32)
    hp_layer_2 = hp.Int('layer_2', min_value=32, max_value=512, step=32)
    hp_layer_3 = hp.Int('layer_3', min_value=32, max_value=512, step=32)
    hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

    model.add(Dense(units=hp_layer_1, activation=hp_activation))
    model.add(Dropout(0.5))
    model.add(Dense(units=hp_layer_2, activation=hp_activation))
    model.add(Dropout(0.5))
    model.add(Dense(units=hp_layer_3, activation=hp_activation))
    model.add(Dense(2, activation='softmax'))

    model.compile(optimizer=Adam(learning_rate=hp_learning_rate), loss="categorical_crossentropy", metrics=["accuracy"])

    return model
```

```python
tuner = kt.Hyperband(
    model_builder_dropout_5,
    objective='accuracy',
    max_epochs=10,
    directory='keras_tuner_dir',
    project_name='keras_tuner_demo_3'
)
tuner.search(x_train, y_train, epochs=10, callbacks=[stop_early])
```

Output:
```
Trial 30 Complete [00h 00m 53s]
accuracy: 0.9993019700050354

Best accuracy So Far: 0.9997208118438721
Total elapsed time: 00h 11m 19s
INFO:tensorflow:Oracle triggered exit
```

```python
model_5_with_dropout = get_best_model(tuner)
model_5_with_dropout.fit(x_train, y_train, epochs=100, batch_size=10, validation_data=(x_test, y_test), callbacks=[stop_early])
```

Output:
```
Describe models architecture
Layer-1: 52 units, func: <function relu at 0x155f86a60>
Layer-2: 288 units, func: <function relu at 0x155f86a60>
Layer-3: 0 units, func: None
Layer-4: 224 units, func: <function relu at 0x155f86a60>
Layer-5: 0 units, func: None
Layer-6: 320 units, func: <function relu at 0x155f86a60>
Layer-7: 2 units, func: <function softmax at 0x155f86040>
learning_rate: 0.001
Epoch 1/100
1433/1433 [==============================] - 19s 13ms/step - loss: 0.0443 - accuracy: 0.9823 - val_loss: 0.0023 - val_accuracy: 0.9994
Epoch 2/100
1433/1433 [==============================] - 18s 12ms/step - loss: 0.0090 - accuracy: 0.9976 - val_loss: 0.0046 - val_accuracy: 0.9975
Epoch 3/100
1433/1433 [==============================] - 18s 13ms/step - loss: 0.0048 - accuracy: 0.9988 - val_loss: 0.0011 - val_accuracy: 0.9994
Epoch 4/100
1433/1433 [==============================] - 17s 12ms/step - loss: 0.0073 - accuracy: 0.9983 - val_loss: 0.0015 - val_accuracy: 0.9994
Epoch 5/100
1433/1433 [==============================] - 17s 12ms/step - loss: 0.0041 - accuracy: 0.9987 - val_loss: 8.3813e-04 - val_accuracy: 0.9997
Epoch 6/100
1433/1433 [==============================] - 17s 12ms/step - loss: 0.0053 - accuracy: 0.9987 - val_loss: 0.0047 - val_accuracy: 0.9986
Epoch 7/100
1433/1433 [==============================] - 17s 12ms/step - loss: 0.0044 - accuracy: 0.9984 - val_loss: 9.9630e-04 - val_accuracy: 0.9997
Epoch 8/100
1433/1433 [==============================] - 18s 13ms/step - loss: 0.0025 - accuracy: 0.9994 - val_loss: 0.0013 - val_accuracy: 0.9994
Epoch 9/100
1433/1433 [==============================] - 18s 13ms/step - loss: 0.0025 - accuracy: 0.9993 - val_loss: 0.0074 - val_accuracy: 0.9989
Epoch 10/100
1433/1433 [==============================] - 17s 12ms/step - loss: 0.0019 - accuracy: 0.9998 - val_loss: 0.0018 - val_accuracy: 0.9997
Epoch 11/100
1433/1433 [==============================] - 18s 12ms/step - loss: 0.0038 - accuracy: 0.9994 - val_loss: 3.8607e-04 - val_accuracy: 0.9997
Epoch 12/100
1433/1433 [==============================] - 17s 12ms/step - loss: 0.0031 - accuracy: 0.9992 - val_loss: 0.0011 - val_accuracy: 0.9994
Epoch 13/100
1433/1433 [==============================] - 18s 12ms/step - loss: 0.0048 - accuracy: 0.9990 - val_loss: 1.1490e-04 - val_accuracy: 1.0000
<keras.callbacks.History at 0x2f45e9280>
```

```python
final_models["7_layers_with_dropout"] = model_5_with_dropout
```

Output:
```
dict_keys(['7_layers_with_dropout', '5_layers', '3_layers'])
```

### 3.5. Model with 7 layers

```python
def model_builder_7(hp):
    model = Sequential()
    model.add(Dense(52, input_dim=52, activation="relu"))

    hp_activation = hp.Choice('activation', values=['relu', 'tanh'])
    hp_layer_1 = hp.Int('layer_1', min_value=32, max_value=512, step=32)
    hp_layer_2 = hp.Int('layer_2', min_value=32, max_value=512, step=32)
    hp_layer_3 = hp.Int('layer_3', min_value=32, max_value=512, step=32)
    hp_layer_4 = hp.Int('layer_4', min_value=32, max_value=512, step=32)
    hp_layer_5 = hp.Int('layer_5', min_value=32, max_value=512, step=32)
    hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

    model.add(Dense(units=hp_layer_1, activation=hp_activation))
    model.add(Dense(units=hp_layer_2, activation=hp_activation))
    model.add(Dense(units=hp_layer_3, activation=hp_activation))
    model.add(Dense(units=hp_layer_4, activation=hp_activation))
    model.add(Dense(units=hp_layer_5, activation=hp_activation))
    model.add(Dense(2, activation='softmax'))

    model.compile(optimizer=Adam(learning_rate=hp_learning_rate), loss="categorical_crossentropy", metrics=["accuracy"])

    return model
```

```python
tuner = kt.Hyperband(
    model_builder_7,
    objective='accuracy',
    max_epochs=10,
    directory='keras_tuner_dir',
    project_name='keras_tuner_demo_4'
)
tuner.search(x_train, y_train, epochs=10, callbacks=[stop_early])
```

Output:
```
Trial 30 Complete [00h 00m 52s]
accuracy: 0.9996510148048401

Best accuracy So Far: 1.0000001192092896
Total elapsed time: 00h 12m 14s
INFO:tensorflow:Oracle triggered exit
```

```python
model_7 = get_best_model(tuner)
model_7.fit(x_train, y_train, epochs=100, batch_size=10, validation_data=(x_test, y_test), callbacks=[stop_early])
```

Output:
```
Describe models architecture
Layer-1: 52 units, func: <function relu at 0x155f86a60>
Layer-2: 32 units, func: <function relu at 0x155f86a60>
Layer-3: 416 units, func: <function relu at 0x155f86a60>
Layer-4: 192 units, func: <function relu at 0x155f86a60>
Layer-5: 224 units, func: <function relu at 0x155f86a60>
Layer-6: 416 units, func: <function relu at 0x155f86a60>
Layer-7: 2 units, func: <function softmax at 0x155f86040>
learning_rate: 0.0001
Epoch 1/100
1433/1433 [==============================] - 22s 15ms/step - loss: 0.0618 - accuracy: 0.9774 - val_loss: 0.0034 - val_accuracy: 0.9986
Epoch 2/100
1433/1433 [==============================] - 22s 15ms/step - loss: 0.0059 - accuracy: 0.9985 - val_loss: 0.0092 - val_accuracy: 0.9969
Epoch 3/100
1433/1433 [==============================] - 22s 15ms/step - loss: 0.0052 - accuracy: 0.9984 - val_loss: 0.0018 - val_accuracy: 0.9997
Epoch 4/100
1433/1433 [==============================] - 21s 15ms/step - loss: 0.0028 - accuracy: 0.9993 - val_loss: 0.0017 - val_accuracy: 0.9992
Epoch 5/100
1433/1433 [==============================] - 21s 14ms/step - loss: 0.0031 - accuracy: 0.9992 - val_loss: 3.8267e-04 - val_accuracy: 1.0000
Epoch 6/100
1433/1433 [==============================] - 21s 15ms/step - loss: 0.0030 - accuracy: 0.9994 - val_loss: 2.8054e-04 - val_accuracy: 1.0000
Epoch 7/100
1433/1433 [==============================] - 21s 15ms/step - loss: 0.0022 - accuracy: 0.9995 - val_loss: 0.0038 - val_accuracy: 0.9983
Epoch 8/100
1433/1433 [==============================] - 21s 15ms/step - loss: 0.0013 - accuracy: 0.9995 - val_loss: 1.5269e-04 - val_accuracy: 1.0000
Epoch 9/100
1433/1433 [==============================] - 21s 15ms/step - loss: 0.0014 - accuracy: 0.9998 - val_loss: 4.1306e-04 - val_accuracy: 0.9997
Epoch 10/100
1433/1433 [==============================] - 21s 14ms/step - loss: 2.9369e-04 - accuracy: 0.9999 - val_loss: 1.3631e-05 - val_accuracy: 1.0000
Epoch 11/100
1433/1433 [==============================] - 21s 15ms/step - loss: 1.4331e-05 - accuracy: 1.0000 - val_loss: 6.5710e-06 - val_accuracy: 1.0000
Epoch 12/100
1433/1433 [==============================] - 21s 15ms/step - loss: 0.0015 - accuracy: 0.9995 - val_loss: 1.8750e-04 - val_accuracy: 1.0000
Epoch 13/100
1433/1433 [==============================] - 21s 15ms/step - loss: 3.1710e-04 - accuracy: 0.9998 - val_loss: 4.7737e-05 - val_accuracy: 1.0000
Epoch 14/100
1433/1433 [==============================] - 21s 15ms/step - loss: 1.3424e-05 - accuracy: 1.0000 - val_loss: 1.4000e-05 - val_accuracy: 1.0000
Epoch 15/100
1433/1433 [==============================] - 21s 15ms/step - loss: 3.3917e-06 - accuracy: 1.0000 - val_loss: 1.4799e-05 - val_accuracy: 1.0000
Epoch 16/100
1433/1433 [==============================] - 21s 15ms/step - loss: 1.1946e-06 - accuracy: 1.0000 - val_loss: 7.0157e-06 - val_accuracy: 1.0000
Epoch 17/100
1433/1433 [==============================] - 21s 15ms/step - loss: 6.1554e-07 - accuracy: 1.0000 - val_loss: 2.0066e-05 - val_accuracy: 1.0000
Epoch 18/100
1433/1433 [==============================] - 21s 14ms/step - loss: 2.6470e-07 - accuracy: 1.0000 - val_loss: 5.7550e-05 - val_accuracy: 1.0000
Epoch 19/100
1433/1433 [==============================] - 21s 14ms/step - loss: 1.2669e-07 - accuracy: 1.0000 - val_loss: 4.0596e-06 - val_accuracy: 1.0000
Epoch 20/100
1433/1433 [==============================] - 22s 16ms/step - loss: 6.8169e-08 - accuracy: 1.0000 - val_loss: 1.9726e-05 - val_accuracy: 1.0000
Epoch 21/100
1433/1433 [==============================] - 21s 15ms/step - loss: 4.3585e-08 - accuracy: 1.0000 - val_loss: 9.4645e-05 - val_accuracy: 1.0000
Epoch 22/100
1433/1433 [==============================] - 21s 15ms/step - loss: 0.0042 - accuracy: 0.9994 - val_loss: 0.0031 - val_accuracy: 0.9992
Epoch 23/100
1433/1433 [==============================] - 21s 14ms/step - loss: 0.0024 - accuracy: 0.9994 - val_loss: 2.0752e-05 - val_accuracy: 1.0000
Epoch 24/100
1433/1433 [==============================] - 21s 15ms/step - loss: 1.0676e-04 - accuracy: 0.9999 - val_loss: 1.2278e-05 - val_accuracy: 1.0000
<keras.callbacks.History at 0x326e838b0>
```

```python
final_models["7_layers"] = model_7
```

### 3.6. Final Models Description

```python
for name, model in final_models.items():
    print(f"{name}: ", end="")
    describe_model(model)
    print()
```

Output:
```
7_layers_with_dropout: Describe models architecture
Layer-1: 52 units, func: <function relu at 0x15db71b80>
Layer-2: 288 units, func: <function relu at 0x15db71b80>
Layer-3: 0 units, func: None
Layer-4: 224 units, func: <function relu at 0x15db71b80>
Layer-5: 0 units, func: None
Layer-6: 320 units, func: <function relu at 0x15db71b80>
Layer-7: 2 units, func: <function softmax at 0x15db71160>

5_layers: Describe models architecture
Layer-1: 52 units, func: <function relu at 0x15db71b80>
Layer-2: 480 units, func: <function tanh at 0x15db71ee0>
Layer-3: 480 units, func: <function tanh at 0x15db71ee0>
Layer-4: 192 units, func: <function tanh at 0x15db71ee0>
Layer-5: 2 units, func: <function softmax at 0x15db71160>

3_layers: Describe models architecture
Layer-1: 52 units, func: <function relu at 0x15db71b80>
Layer-2: 192 units, func: <function relu at 0x15db71b80>
Layer-3: 2 units, func: <function softmax at 0x15db71160>

7_layers: Describe models architecture
Layer-1: 52 units, func: <function relu at 0x15db71b80>
Layer-2: 32 units, func: <function relu at 0x15db71b80>
Layer-3: 416 units, func: <function relu at 0x15db71b80>
Layer-4: 192 units, func: <function relu at 0x15db71b80>
Layer-5: 224 units, func: <function relu at 0x15db71b80>
Layer-6: 416 units, func: <function relu at 0x15db71b80>
Layer-7: 2 units, func: <function softmax at 0x15db71160>
```

## 4. Model Evaluation

### 4.1. Train set

```python
train_set_results = []

for name, model in final_models.items():
    # Evaluate each model on the held-out 20% split of the train CSV
    predict_x = model.predict(x_test, verbose=False)
    y_pred_class = np.argmax(predict_x, axis=1)
    y_test_class = np.argmax(y_test, axis=1)

    cm = confusion_matrix(y_test_class, y_pred_class, labels=[0, 1])
    (p_score, r_score, f_score, _) = precision_recall_fscore_support(y_test_class, y_pred_class, labels=[0, 1])

    train_set_results.append(( name, round_up_metric_results(p_score), round_up_metric_results(r_score), round_up_metric_results(f_score), cm ))

# Sort by the sum of the per-class F1 scores, best first
train_set_results.sort(key=lambda k: sum(k[3]), reverse=True)
pd.DataFrame(train_set_results, columns=["Model", "Precision Score", "Recall Score", "F1 score", "Confusion Matrix"])
```

Output:

| | Model | Precision Score | Recall Score | F1 score | Confusion Matrix |
|---|---|---|---|---|---|
| 0 | 7_layers_with_dropout | [1.0, 0.999] | [0.999, 1.0] | [1.0, 1.0] | [[1805, 1], [0, 1776]] |
| 1 | 3_layers | [1.0, 0.999] | [0.999, 1.0] | [1.0, 1.0] | [[1805, 1], [0, 1776]] |
| 2 | 7_layers | [1.0, 1.0] | [1.0, 1.0] | [1.0, 1.0] | [[1806, 0], [0, 1776]] |
| 3 | 5_layers | [1.0, 0.994] | [0.994, 1.0] | [0.997, 0.997] | [[1795, 11], [0, 1776]] |

### 4.2. Test set evaluation

```python
test_df = pd.read_csv(TEST_SET_PATH)

# Categorizing label
test_df.loc[test_df["label"] == "L", "label"] = 0
test_df.loc[test_df["label"] == "C", "label"] = 1
```

```python
# Standard scaling of features with the same scaler used for the train set
test_x = test_df.drop("label", axis=1)
test_x = pd.DataFrame(input_scaler.transform(test_x))

test_y = test_df["label"]

# Converting labels to categorical (one-hot) vectors
test_y_cat = to_categorical(test_y)
```

```python
test_set_results = []

for name, model in final_models.items():
    # Evaluate model
    predict_x = model.predict(test_x, verbose=False)
    y_pred_class = np.argmax(predict_x, axis=1)
    y_test_class = np.argmax(test_y_cat, axis=1)

    cm = confusion_matrix(y_test_class, y_pred_class, labels=[0, 1])
    (p_score, r_score, f_score, _) = precision_recall_fscore_support(y_test_class, y_pred_class, labels=[0, 1])

    test_set_results.append(( name, round_up_metric_results(p_score), round_up_metric_results(r_score), round_up_metric_results(f_score), cm ))

test_set_results.sort(key=lambda k: sum(k[3]), reverse=True)
pd.DataFrame(test_set_results, columns=["Model", "Precision Score", "Recall Score", "F1 score", "Confusion Matrix"])
```

Output:

| | Model | Precision Score | Recall Score | F1 score | Confusion Matrix |
|---|---|---|---|---|---|
| 0 | 3_layers | [0.998, 0.873] | [0.859, 0.998] | [0.923, 0.932] | [[482, 79], [1, 545]] |
| 1 | 7_layers_with_dropout | [0.995, 0.786] | [0.736, 0.996] | [0.846, 0.879] | [[413, 148], [2, 544]] |
| 2 | 5_layers | [0.963, 0.755] | [0.693, 0.973] | [0.806, 0.85] | [[389, 172], [15, 531]] |
| 3 | 7_layers | [0.984, 0.687] | [0.561, 0.991] | [0.715, 0.812] | [[315, 246], [5, 541]] |
| 1284 |
+
{
|
| 1285 |
+
"cell_type": "markdown",
|
| 1286 |
+
"metadata": {},
|
| 1287 |
+
"source": [
|
| 1288 |
+
"## 5. Dump Model"
|
| 1289 |
+
]
|
| 1290 |
+
},
|
| 1291 |
+
{
|
| 1292 |
+
"cell_type": "code",
|
| 1293 |
+
"execution_count": 173,
|
| 1294 |
+
"metadata": {},
|
| 1295 |
+
"outputs": [
|
| 1296 |
+
{
|
| 1297 |
+
"name": "stdout",
|
| 1298 |
+
"output_type": "stream",
|
| 1299 |
+
"text": [
|
| 1300 |
+
"INFO:tensorflow:Assets written to: ram://4444bb6e-de2c-4bed-806c-db9d955522ad/assets\n"
|
| 1301 |
+
]
|
| 1302 |
+
}
|
| 1303 |
+
],
|
| 1304 |
+
"source": [
|
| 1305 |
+
"# Dump the best model to a pickle file\n",
|
| 1306 |
+
"with open(\"./model/dp/err_lunge_dp.pkl\", \"wb\") as f:\n",
|
| 1307 |
+
" pickle.dump(final_models[\"3_layers\"], f)"
|
| 1308 |
+
]
|
| 1309 |
+
},
|
| 1310 |
+
{
|
| 1311 |
+
"cell_type": "code",
|
| 1312 |
+
"execution_count": 174,
|
| 1313 |
+
"metadata": {},
|
| 1314 |
+
"outputs": [
|
| 1315 |
+
{
|
| 1316 |
+
"name": "stdout",
|
| 1317 |
+
"output_type": "stream",
|
| 1318 |
+
"text": [
|
| 1319 |
+
"INFO:tensorflow:Assets written to: ram://0f5761b9-5d62-4bf6-8de5-34355293f17a/assets\n",
|
| 1320 |
+
"INFO:tensorflow:Assets written to: ram://ac2add31-fb1f-4baa-b925-718896fea315/assets\n",
|
| 1321 |
+
"INFO:tensorflow:Assets written to: ram://20b6d7c9-7ddf-42a1-aaba-c31e12762a10/assets\n",
|
| 1322 |
+
"INFO:tensorflow:Assets written to: ram://ee8c66ec-0ee3-4a5c-bc03-f12225018c47/assets\n"
|
| 1323 |
+
]
|
| 1324 |
+
}
|
| 1325 |
+
],
|
| 1326 |
+
"source": [
|
| 1327 |
+
"with open(\"./model/dp/all_models.pkl\", \"wb\") as f:\n",
|
| 1328 |
+
" pickle.dump(final_models, f)"
|
| 1329 |
+
]
|
| 1330 |
+
},
|
| 1331 |
+
{
|
| 1332 |
+
"cell_type": "code",
|
| 1333 |
+
"execution_count": null,
|
| 1334 |
+
"metadata": {},
|
| 1335 |
+
"outputs": [],
|
| 1336 |
+
"source": []
|
| 1337 |
+
}
|
| 1338 |
+
],
|
| 1339 |
+
"metadata": {
|
| 1340 |
+
"kernelspec": {
|
| 1341 |
+
"display_name": "Python 3.8.13 (conda)",
|
| 1342 |
+
"language": "python",
|
| 1343 |
+
"name": "python3"
|
| 1344 |
+
},
|
| 1345 |
+
"language_info": {
|
| 1346 |
+
"codemirror_mode": {
|
| 1347 |
+
"name": "ipython",
|
| 1348 |
+
"version": 3
|
| 1349 |
+
},
|
| 1350 |
+
"file_extension": ".py",
|
| 1351 |
+
"mimetype": "text/x-python",
|
| 1352 |
+
"name": "python",
|
| 1353 |
+
"nbconvert_exporter": "python",
|
| 1354 |
+
"pygments_lexer": "ipython3",
|
| 1355 |
+
"version": "3.8.13"
|
| 1356 |
+
},
|
| 1357 |
+
"orig_nbformat": 4,
|
| 1358 |
+
"vscode": {
|
| 1359 |
+
"interpreter": {
|
| 1360 |
+
"hash": "9260f401923fb5c4108c543a7d176de9733d378b3752e49535ad7c43c2271b65"
|
| 1361 |
+
}
|
| 1362 |
+
}
|
| 1363 |
+
},
|
| 1364 |
+
"nbformat": 4,
|
| 1365 |
+
"nbformat_minor": 2
|
| 1366 |
+
}
|
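As a usage note, a minimal sketch of reloading the pickle dumped above for inference (assuming the paths in this notebook, a compatible TensorFlow version, and the `err.test.csv` layout from this folder; the 0/1 → "L"/"C" mapping follows the categorizing cell):

```python
import pickle
import numpy as np
import pandas as pd

# Reload the pickled Keras model and the fitted input scaler.
with open("./model/dp/err_lunge_dp.pkl", "rb") as f:
    model = pickle.load(f)
with open("./model/input_scaler.pkl", "rb") as f:
    input_scaler = pickle.load(f)

# Take one row of keypoint features and scale it the same way as the training data.
sample = pd.read_csv("err.test.csv").drop("label", axis=1).iloc[[0]]
scaled = pd.DataFrame(input_scaler.transform(sample))

# argmax over the softmax output gives the predicted class: 0 = "L", 1 = "C".
probs = model.predict(scaled, verbose=False)
print("class:", np.argmax(probs, axis=1)[0], "probability:", round(float(probs.max()), 2))
```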
core/lunge_model/8.err.evaluation.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
core/lunge_model/9.err.detection.ipynb
ADDED
|
@@ -0,0 +1,714 @@
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [
|
| 8 |
+
{
|
| 9 |
+
"name": "stderr",
|
| 10 |
+
"output_type": "stream",
|
| 11 |
+
"text": [
|
| 12 |
+
"objc[21249]: Class CaptureDelegate is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_videoio.3.4.16.dylib (0x111188860) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x15eece480). One of the two will be used. Which one is undefined.\n",
|
| 13 |
+
"objc[21249]: Class CVWindow is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x107550a68) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x15eece4d0). One of the two will be used. Which one is undefined.\n",
|
| 14 |
+
"objc[21249]: Class CVView is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x107550a90) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x15eece4f8). One of the two will be used. Which one is undefined.\n",
|
| 15 |
+
"objc[21249]: Class CVSlider is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x107550ab8) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x15eece520). One of the two will be used. Which one is undefined.\n"
|
| 16 |
+
]
|
| 17 |
+
}
|
| 18 |
+
],
|
| 19 |
+
"source": [
|
| 20 |
+
"import mediapipe as mp\n",
|
| 21 |
+
"import cv2\n",
|
| 22 |
+
"import numpy as np\n",
|
| 23 |
+
"import pandas as pd\n",
|
| 24 |
+
"import traceback\n",
|
| 25 |
+
"import pickle\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"import warnings\n",
|
| 28 |
+
"warnings.filterwarnings('ignore')\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"# Drawing helpers\n",
|
| 31 |
+
"mp_drawing = mp.solutions.drawing_utils\n",
|
| 32 |
+
"mp_pose = mp.solutions.pose"
|
| 33 |
+
]
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"cell_type": "markdown",
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"source": [
|
| 39 |
+
"## 1. Setup important landmarks and functions"
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": 2,
|
| 45 |
+
"metadata": {},
|
| 46 |
+
"outputs": [],
|
| 47 |
+
"source": [
|
| 48 |
+
"# Determine important landmarks for lunge\n",
|
| 49 |
+
"IMPORTANT_LMS = [\n",
|
| 50 |
+
" \"NOSE\",\n",
|
| 51 |
+
" \"LEFT_SHOULDER\",\n",
|
| 52 |
+
" \"RIGHT_SHOULDER\",\n",
|
| 53 |
+
" \"LEFT_HIP\",\n",
|
| 54 |
+
" \"RIGHT_HIP\",\n",
|
| 55 |
+
" \"LEFT_KNEE\",\n",
|
| 56 |
+
" \"RIGHT_KNEE\",\n",
|
| 57 |
+
" \"LEFT_ANKLE\",\n",
|
| 58 |
+
" \"RIGHT_ANKLE\",\n",
|
| 59 |
+
" \"LEFT_HEEL\",\n",
|
| 60 |
+
" \"RIGHT_HEEL\",\n",
|
| 61 |
+
" \"LEFT_FOOT_INDEX\",\n",
|
| 62 |
+
" \"RIGHT_FOOT_INDEX\",\n",
|
| 63 |
+
"]\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"# Generate all columns of the data frame\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"HEADERS = [\"label\"] # Label column\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"for lm in IMPORTANT_LMS:\n",
|
| 70 |
+
" HEADERS += [f\"{lm.lower()}_x\", f\"{lm.lower()}_y\", f\"{lm.lower()}_z\", f\"{lm.lower()}_v\"]"
|
| 71 |
+
]
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"cell_type": "code",
|
| 75 |
+
"execution_count": 3,
|
| 76 |
+
"metadata": {},
|
| 77 |
+
"outputs": [],
|
| 78 |
+
"source": [
|
| 79 |
+
"def extract_important_keypoints(results) -> list:\n",
|
| 80 |
+
" '''\n",
|
| 81 |
+
" Extract important keypoints from mediapipe pose detection\n",
|
| 82 |
+
" '''\n",
|
| 83 |
+
" landmarks = results.pose_landmarks.landmark\n",
|
| 84 |
+
"\n",
|
| 85 |
+
" data = []\n",
|
| 86 |
+
" for lm in IMPORTANT_LMS:\n",
|
| 87 |
+
" keypoint = landmarks[mp_pose.PoseLandmark[lm].value]\n",
|
| 88 |
+
" data.append([keypoint.x, keypoint.y, keypoint.z, keypoint.visibility])\n",
|
| 89 |
+
" \n",
|
| 90 |
+
" return np.array(data).flatten().tolist()\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"def rescale_frame(frame, percent=50):\n",
|
| 94 |
+
" '''\n",
|
| 95 |
+
" Rescale a frame to a certain percentage compare to its original frame\n",
|
| 96 |
+
" '''\n",
|
| 97 |
+
" width = int(frame.shape[1] * percent/ 100)\n",
|
| 98 |
+
" height = int(frame.shape[0] * percent/ 100)\n",
|
| 99 |
+
" dim = (width, height)\n",
|
| 100 |
+
" return cv2.resize(frame, dim, interpolation =cv2.INTER_AREA)\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"def calculate_angle(point1: list, point2: list, point3: list) -> float:\n",
|
| 104 |
+
" '''\n",
|
| 105 |
+
" Calculate the angle between 3 points\n",
|
| 106 |
+
" Unit of the angle will be in Degree\n",
|
| 107 |
+
" '''\n",
|
| 108 |
+
" point1 = np.array(point1)\n",
|
| 109 |
+
" point2 = np.array(point2)\n",
|
| 110 |
+
" point3 = np.array(point3)\n",
|
| 111 |
+
"\n",
|
| 112 |
+
" # Calculate algo\n",
|
| 113 |
+
" angleInRad = np.arctan2(point3[1] - point2[1], point3[0] - point2[0]) - np.arctan2(point1[1] - point2[1], point1[0] - point2[0])\n",
|
| 114 |
+
" angleInDeg = np.abs(angleInRad * 180.0 / np.pi)\n",
|
| 115 |
+
"\n",
|
| 116 |
+
" angleInDeg = angleInDeg if angleInDeg <= 180 else 360 - angleInDeg\n",
|
| 117 |
+
" return angleInDeg\n",
|
| 118 |
+
" \n",
|
| 119 |
+
"\n",
|
| 120 |
+
"def analyze_knee_angle(\n",
|
| 121 |
+
" mp_results, stage: str, angle_thresholds: list, draw_to_image: tuple = None\n",
|
| 122 |
+
"):\n",
|
| 123 |
+
" \"\"\"\n",
|
| 124 |
+
" Calculate angle of each knee while performer at the DOWN position\n",
|
| 125 |
+
"\n",
|
| 126 |
+
" Return result explanation:\n",
|
| 127 |
+
" error: True if at least 1 error\n",
|
| 128 |
+
" right\n",
|
| 129 |
+
" error: True if an error is on the right knee\n",
|
| 130 |
+
" angle: Right knee angle\n",
|
| 131 |
+
" left\n",
|
| 132 |
+
" error: True if an error is on the left knee\n",
|
| 133 |
+
" angle: Left knee angle\n",
|
| 134 |
+
" \"\"\"\n",
|
| 135 |
+
" results = {\n",
|
| 136 |
+
" \"error\": None,\n",
|
| 137 |
+
" \"right\": {\"error\": None, \"angle\": None},\n",
|
| 138 |
+
" \"left\": {\"error\": None, \"angle\": None},\n",
|
| 139 |
+
" }\n",
|
| 140 |
+
"\n",
|
| 141 |
+
" landmarks = mp_results.pose_landmarks.landmark\n",
|
| 142 |
+
"\n",
|
| 143 |
+
" # Calculate right knee angle\n",
|
| 144 |
+
" right_hip = [\n",
|
| 145 |
+
" landmarks[mp_pose.PoseLandmark.RIGHT_HIP.value].x,\n",
|
| 146 |
+
" landmarks[mp_pose.PoseLandmark.RIGHT_HIP.value].y,\n",
|
| 147 |
+
" ]\n",
|
| 148 |
+
" right_knee = [\n",
|
| 149 |
+
" landmarks[mp_pose.PoseLandmark.RIGHT_KNEE.value].x,\n",
|
| 150 |
+
" landmarks[mp_pose.PoseLandmark.RIGHT_KNEE.value].y,\n",
|
| 151 |
+
" ]\n",
|
| 152 |
+
" right_ankle = [\n",
|
| 153 |
+
" landmarks[mp_pose.PoseLandmark.RIGHT_ANKLE.value].x,\n",
|
| 154 |
+
" landmarks[mp_pose.PoseLandmark.RIGHT_ANKLE.value].y,\n",
|
| 155 |
+
" ]\n",
|
| 156 |
+
" results[\"right\"][\"angle\"] = calculate_angle(right_hip, right_knee, right_ankle)\n",
|
| 157 |
+
"\n",
|
| 158 |
+
" # Calculate left knee angle\n",
|
| 159 |
+
" left_hip = [\n",
|
| 160 |
+
" landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].x,\n",
|
| 161 |
+
" landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].y,\n",
|
| 162 |
+
" ]\n",
|
| 163 |
+
" left_knee = [\n",
|
| 164 |
+
" landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].x,\n",
|
| 165 |
+
" landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].y,\n",
|
| 166 |
+
" ]\n",
|
| 167 |
+
" left_ankle = [\n",
|
| 168 |
+
" landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].x,\n",
|
| 169 |
+
" landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].y,\n",
|
| 170 |
+
" ]\n",
|
| 171 |
+
" results[\"left\"][\"angle\"] = calculate_angle(left_hip, left_knee, left_ankle)\n",
|
| 172 |
+
"\n",
|
| 173 |
+
" # Draw to image\n",
|
| 174 |
+
" if draw_to_image is not None and stage != \"down\":\n",
|
| 175 |
+
" (image, video_dimensions) = draw_to_image\n",
|
| 176 |
+
"\n",
|
| 177 |
+
" # Visualize angles\n",
|
| 178 |
+
" cv2.putText(\n",
|
| 179 |
+
" image,\n",
|
| 180 |
+
" str(int(results[\"right\"][\"angle\"])),\n",
|
| 181 |
+
" tuple(np.multiply(right_knee, video_dimensions).astype(int)),\n",
|
| 182 |
+
" cv2.FONT_HERSHEY_COMPLEX,\n",
|
| 183 |
+
" 0.5,\n",
|
| 184 |
+
" (255, 255, 255),\n",
|
| 185 |
+
" 1,\n",
|
| 186 |
+
" cv2.LINE_AA,\n",
|
| 187 |
+
" )\n",
|
| 188 |
+
" cv2.putText(\n",
|
| 189 |
+
" image,\n",
|
| 190 |
+
" str(int(results[\"left\"][\"angle\"])),\n",
|
| 191 |
+
" tuple(np.multiply(left_knee, video_dimensions).astype(int)),\n",
|
| 192 |
+
" cv2.FONT_HERSHEY_COMPLEX,\n",
|
| 193 |
+
" 0.5,\n",
|
| 194 |
+
" (255, 255, 255),\n",
|
| 195 |
+
" 1,\n",
|
| 196 |
+
" cv2.LINE_AA,\n",
|
| 197 |
+
" )\n",
|
| 198 |
+
"\n",
|
| 199 |
+
" if stage != \"down\":\n",
|
| 200 |
+
" return results\n",
|
| 201 |
+
"\n",
|
| 202 |
+
" # Evaluation\n",
|
| 203 |
+
" results[\"error\"] = False\n",
|
| 204 |
+
"\n",
|
| 205 |
+
" if angle_thresholds[0] <= results[\"right\"][\"angle\"] <= angle_thresholds[1]:\n",
|
| 206 |
+
" results[\"right\"][\"error\"] = False\n",
|
| 207 |
+
" else:\n",
|
| 208 |
+
" results[\"right\"][\"error\"] = True\n",
|
| 209 |
+
" results[\"error\"] = True\n",
|
| 210 |
+
"\n",
|
| 211 |
+
" if angle_thresholds[0] <= results[\"left\"][\"angle\"] <= angle_thresholds[1]:\n",
|
| 212 |
+
" results[\"left\"][\"error\"] = False\n",
|
| 213 |
+
" else:\n",
|
| 214 |
+
" results[\"left\"][\"error\"] = True\n",
|
| 215 |
+
" results[\"error\"] = True\n",
|
| 216 |
+
"\n",
|
| 217 |
+
" # Draw to image\n",
|
| 218 |
+
" if draw_to_image is not None:\n",
|
| 219 |
+
" (image, video_dimensions) = draw_to_image\n",
|
| 220 |
+
"\n",
|
| 221 |
+
" right_color = (255, 255, 255) if not results[\"right\"][\"error\"] else (0, 0, 255)\n",
|
| 222 |
+
" left_color = (255, 255, 255) if not results[\"left\"][\"error\"] else (0, 0, 255)\n",
|
| 223 |
+
"\n",
|
| 224 |
+
" right_font_scale = 0.5 if not results[\"right\"][\"error\"] else 1\n",
|
| 225 |
+
" left_font_scale = 0.5 if not results[\"left\"][\"error\"] else 1\n",
|
| 226 |
+
"\n",
|
| 227 |
+
" right_thickness = 1 if not results[\"right\"][\"error\"] else 2\n",
|
| 228 |
+
" left_thickness = 1 if not results[\"left\"][\"error\"] else 2\n",
|
| 229 |
+
"\n",
|
| 230 |
+
" # Visualize angles\n",
|
| 231 |
+
" cv2.putText(\n",
|
| 232 |
+
" image,\n",
|
| 233 |
+
" str(int(results[\"right\"][\"angle\"])),\n",
|
| 234 |
+
" tuple(np.multiply(right_knee, video_dimensions).astype(int)),\n",
|
| 235 |
+
" cv2.FONT_HERSHEY_COMPLEX,\n",
|
| 236 |
+
" right_font_scale,\n",
|
| 237 |
+
" right_color,\n",
|
| 238 |
+
" right_thickness,\n",
|
| 239 |
+
" cv2.LINE_AA,\n",
|
| 240 |
+
" )\n",
|
| 241 |
+
" cv2.putText(\n",
|
| 242 |
+
" image,\n",
|
| 243 |
+
" str(int(results[\"left\"][\"angle\"])),\n",
|
| 244 |
+
" tuple(np.multiply(left_knee, video_dimensions).astype(int)),\n",
|
| 245 |
+
" cv2.FONT_HERSHEY_COMPLEX,\n",
|
| 246 |
+
" left_font_scale,\n",
|
| 247 |
+
" left_color,\n",
|
| 248 |
+
" left_thickness,\n",
|
| 249 |
+
" cv2.LINE_AA,\n",
|
| 250 |
+
" )\n",
|
| 251 |
+
"\n",
|
| 252 |
+
" return results"
|
| 253 |
+
]
|
| 254 |
+
},
|
| 255 |
+
{
|
| 256 |
+
"cell_type": "markdown",
|
| 257 |
+
"metadata": {},
|
| 258 |
+
"source": [
|
| 259 |
+
"## 2. Constants"
|
| 260 |
+
]
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"cell_type": "code",
|
| 264 |
+
"execution_count": 4,
|
| 265 |
+
"metadata": {},
|
| 266 |
+
"outputs": [],
|
| 267 |
+
"source": [
|
| 268 |
+
"VIDEO_PATH1 = \"../data/lunge/lunge_test_3.mp4\"\n",
|
| 269 |
+
"VIDEO_PATH2 = \"../data/lunge/lunge_test_5.mp4\""
|
| 270 |
+
]
|
| 271 |
+
},
|
| 272 |
+
{
|
| 273 |
+
"cell_type": "code",
|
| 274 |
+
"execution_count": 5,
|
| 275 |
+
"metadata": {},
|
| 276 |
+
"outputs": [],
|
| 277 |
+
"source": [
|
| 278 |
+
"with open(\"./model/input_scaler.pkl\", \"rb\") as f:\n",
|
| 279 |
+
" input_scaler = pickle.load(f)"
|
| 280 |
+
]
|
| 281 |
+
},
|
| 282 |
+
{
|
| 283 |
+
"cell_type": "markdown",
|
| 284 |
+
"metadata": {},
|
| 285 |
+
"source": [
|
| 286 |
+
"## 3. Detection with Sklearn Models"
|
| 287 |
+
]
|
| 288 |
+
},
|
| 289 |
+
{
|
| 290 |
+
"cell_type": "code",
|
| 291 |
+
"execution_count": 6,
|
| 292 |
+
"metadata": {},
|
| 293 |
+
"outputs": [],
|
| 294 |
+
"source": [
|
| 295 |
+
"# Load model\n",
|
| 296 |
+
"with open(\"./model/sklearn/stage_SVC_model.pkl\", \"rb\") as f:\n",
|
| 297 |
+
" stage_sklearn_model = pickle.load(f)\n",
|
| 298 |
+
"\n",
|
| 299 |
+
"with open(\"./model/sklearn/err_LR_model.pkl\", \"rb\") as f:\n",
|
| 300 |
+
" err_sklearn_model = pickle.load(f)"
|
| 301 |
+
]
|
| 302 |
+
},
|
| 303 |
+
{
|
| 304 |
+
"cell_type": "code",
|
| 305 |
+
"execution_count": 7,
|
| 306 |
+
"metadata": {},
|
| 307 |
+
"outputs": [
|
| 308 |
+
{
|
| 309 |
+
"name": "stderr",
|
| 310 |
+
"output_type": "stream",
|
| 311 |
+
"text": [
|
| 312 |
+
"INFO: Created TensorFlow Lite XNNPACK delegate for CPU.\n"
|
| 313 |
+
]
|
| 314 |
+
},
|
| 315 |
+
{
|
| 316 |
+
"name": "stdout",
|
| 317 |
+
"output_type": "stream",
|
| 318 |
+
"text": [
|
| 319 |
+
"No human found\n",
|
| 320 |
+
"No human found\n",
|
| 321 |
+
"No human found\n",
|
| 322 |
+
"No human found\n",
|
| 323 |
+
"No human found\n",
|
| 324 |
+
"No human found\n",
|
| 325 |
+
"No human found\n",
|
| 326 |
+
"No human found\n",
|
| 327 |
+
"No human found\n",
|
| 328 |
+
"No human found\n",
|
| 329 |
+
"No human found\n",
|
| 330 |
+
"No human found\n",
|
| 331 |
+
"No human found\n",
|
| 332 |
+
"No human found\n",
|
| 333 |
+
"No human found\n",
|
| 334 |
+
"No human found\n",
|
| 335 |
+
"No human found\n",
|
| 336 |
+
"No human found\n",
|
| 337 |
+
"No human found\n",
|
| 338 |
+
"No human found\n",
|
| 339 |
+
"No human found\n",
|
| 340 |
+
"No human found\n",
|
| 341 |
+
"No human found\n",
|
| 342 |
+
"No human found\n",
|
| 343 |
+
"No human found\n",
|
| 344 |
+
"No human found\n",
|
| 345 |
+
"No human found\n",
|
| 346 |
+
"No human found\n",
|
| 347 |
+
"No human found\n",
|
| 348 |
+
"No human found\n",
|
| 349 |
+
"No human found\n",
|
| 350 |
+
"No human found\n",
|
| 351 |
+
"No human found\n",
|
| 352 |
+
"No human found\n",
|
| 353 |
+
"No human found\n",
|
| 354 |
+
"No human found\n",
|
| 355 |
+
"No human found\n",
|
| 356 |
+
"No human found\n",
|
| 357 |
+
"No human found\n",
|
| 358 |
+
"No human found\n",
|
| 359 |
+
"No human found\n"
|
| 360 |
+
]
|
| 361 |
+
}
|
| 362 |
+
],
|
| 363 |
+
"source": [
|
| 364 |
+
"cap = cv2.VideoCapture(VIDEO_PATH1)\n",
|
| 365 |
+
"current_stage = \"\"\n",
|
| 366 |
+
"counter = 0\n",
|
| 367 |
+
"\n",
|
| 368 |
+
"prediction_probability_threshold = 0.8\n",
|
| 369 |
+
"ANGLE_THRESHOLDS = [60, 135]\n",
|
| 370 |
+
"\n",
|
| 371 |
+
"knee_over_toe = False\n",
|
| 372 |
+
"\n",
|
| 373 |
+
"with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:\n",
|
| 374 |
+
" while cap.isOpened():\n",
|
| 375 |
+
" ret, image = cap.read()\n",
|
| 376 |
+
"\n",
|
| 377 |
+
" if not ret:\n",
|
| 378 |
+
" break\n",
|
| 379 |
+
"\n",
|
| 380 |
+
" # Reduce size of a frame\n",
|
| 381 |
+
" image = rescale_frame(image, 50)\n",
|
| 382 |
+
" video_dimensions = [image.shape[1], image.shape[0]]\n",
|
| 383 |
+
"\n",
|
| 384 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 385 |
+
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
| 386 |
+
" image.flags.writeable = False\n",
|
| 387 |
+
"\n",
|
| 388 |
+
" results = pose.process(image)\n",
|
| 389 |
+
"\n",
|
| 390 |
+
" if not results.pose_landmarks:\n",
|
| 391 |
+
" print(\"No human found\")\n",
|
| 392 |
+
" continue\n",
|
| 393 |
+
"\n",
|
| 394 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 395 |
+
" image.flags.writeable = True\n",
|
| 396 |
+
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
|
| 397 |
+
"\n",
|
| 398 |
+
" # Draw landmarks and connections\n",
|
| 399 |
+
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=2), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=1))\n",
|
| 400 |
+
"\n",
|
| 401 |
+
" # Make detection\n",
|
| 402 |
+
" try:\n",
|
| 403 |
+
" # Extract keypoints from frame for the input\n",
|
| 404 |
+
" row = extract_important_keypoints(results)\n",
|
| 405 |
+
" X = pd.DataFrame([row], columns=HEADERS[1:])\n",
|
| 406 |
+
" X = pd.DataFrame(input_scaler.transform(X))\n",
|
| 407 |
+
"\n",
|
| 408 |
+
" # Make prediction and its probability\n",
|
| 409 |
+
" stage_predicted_class = stage_sklearn_model.predict(X)[0]\n",
|
| 410 |
+
" stage_prediction_probabilities = stage_sklearn_model.predict_proba(X)[0]\n",
|
| 411 |
+
" stage_prediction_probability = round(stage_prediction_probabilities[stage_prediction_probabilities.argmax()], 2)\n",
|
| 412 |
+
"\n",
|
| 413 |
+
" # Evaluate model prediction\n",
|
| 414 |
+
" if stage_predicted_class == \"I\" and stage_prediction_probability >= prediction_probability_threshold:\n",
|
| 415 |
+
" current_stage = \"init\"\n",
|
| 416 |
+
" elif stage_predicted_class == \"M\" and stage_prediction_probability >= prediction_probability_threshold: \n",
|
| 417 |
+
" current_stage = \"mid\"\n",
|
| 418 |
+
" elif stage_predicted_class == \"D\" and stage_prediction_probability >= prediction_probability_threshold:\n",
|
| 419 |
+
" if current_stage in [\"mid\", \"init\"]:\n",
|
| 420 |
+
" counter += 1\n",
|
| 421 |
+
" \n",
|
| 422 |
+
" current_stage = \"down\"\n",
|
| 423 |
+
" \n",
|
| 424 |
+
" # Error detection\n",
|
| 425 |
+
" # Knee angle\n",
|
| 426 |
+
" analyze_knee_angle(mp_results=results, stage=current_stage, angle_thresholds=ANGLE_THRESHOLDS, draw_to_image=(image, video_dimensions))\n",
|
| 427 |
+
"\n",
|
| 428 |
+
" # Knee over toe\n",
|
| 429 |
+
" err_predicted_class = err_prediction_probabilities = err_prediction_probability = None\n",
|
| 430 |
+
" if current_stage == \"down\":\n",
|
| 431 |
+
" err_predicted_class = err_sklearn_model.predict(X)[0]\n",
|
| 432 |
+
" err_prediction_probabilities = err_sklearn_model.predict_proba(X)[0]\n",
|
| 433 |
+
" err_prediction_probability = round(err_prediction_probabilities[err_prediction_probabilities.argmax()], 2)\n",
|
| 434 |
+
" \n",
|
| 435 |
+
" \n",
|
| 436 |
+
" # Visualization\n",
|
| 437 |
+
" # Status box\n",
|
| 438 |
+
" cv2.rectangle(image, (0, 0), (800, 45), (245, 117, 16), -1)\n",
|
| 439 |
+
"\n",
|
| 440 |
+
" # Display stage prediction\n",
|
| 441 |
+
" cv2.putText(image, \"STAGE\", (15, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 442 |
+
" cv2.putText(image, str(stage_prediction_probability), (10, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 443 |
+
" cv2.putText(image, current_stage, (50, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 444 |
+
"\n",
|
| 445 |
+
" # Display error prediction\n",
|
| 446 |
+
" cv2.putText(image, \"K_O_T\", (200, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 447 |
+
" cv2.putText(image, str(err_prediction_probability), (195, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 448 |
+
" cv2.putText(image, str(err_predicted_class), (245, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 449 |
+
"\n",
|
| 450 |
+
" # Display Counter\n",
|
| 451 |
+
" cv2.putText(image, \"COUNTER\", (110, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 452 |
+
" cv2.putText(image, str(counter), (110, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 453 |
+
"\n",
|
| 454 |
+
" except Exception as e:\n",
|
| 455 |
+
" print(f\"Error: {e}\")\n",
|
| 456 |
+
" traceback.print_exc()\n",
|
| 457 |
+
" break\n",
|
| 458 |
+
" \n",
|
| 459 |
+
" cv2.imshow(\"CV2\", image)\n",
|
| 460 |
+
" \n",
|
| 461 |
+
" # Press Q to close cv2 window\n",
|
| 462 |
+
" if cv2.waitKey(1) & 0xFF == ord('q'):\n",
|
| 463 |
+
" break\n",
|
| 464 |
+
"\n",
|
| 465 |
+
" cap.release()\n",
|
| 466 |
+
" cv2.destroyAllWindows()\n",
|
| 467 |
+
"\n",
|
| 468 |
+
" # (Optional)Fix bugs cannot close windows in MacOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)\n",
|
| 469 |
+
" for i in range (1, 5):\n",
|
| 470 |
+
" cv2.waitKey(1)\n",
|
| 471 |
+
" "
|
| 472 |
+
]
|
| 473 |
+
},
|
| 474 |
+
{
|
| 475 |
+
"cell_type": "markdown",
|
| 476 |
+
"metadata": {},
|
| 477 |
+
"source": [
|
| 478 |
+
"## 4. Detection with deep learning model"
|
| 479 |
+
]
|
| 480 |
+
},
|
| 481 |
+
{
|
| 482 |
+
"cell_type": "code",
|
| 483 |
+
"execution_count": 30,
|
| 484 |
+
"metadata": {},
|
| 485 |
+
"outputs": [],
|
| 486 |
+
"source": [
|
| 487 |
+
"# Load model\n",
|
| 488 |
+
"with open(\"./model/dp/err_lunge_dp.pkl\", \"rb\") as f:\n",
|
| 489 |
+
" err_deep_learning_model = pickle.load(f)"
|
| 490 |
+
]
|
| 491 |
+
},
|
| 492 |
+
{
|
| 493 |
+
"cell_type": "code",
|
| 494 |
+
"execution_count": 31,
|
| 495 |
+
"metadata": {},
|
| 496 |
+
"outputs": [
|
| 497 |
+
{
|
| 498 |
+
"name": "stdout",
|
| 499 |
+
"output_type": "stream",
|
| 500 |
+
"text": [
|
| 501 |
+
"No human found\n",
|
| 502 |
+
"No human found\n",
|
| 503 |
+
"No human found\n",
|
| 504 |
+
"No human found\n",
|
| 505 |
+
"No human found\n",
|
| 506 |
+
"No human found\n",
|
| 507 |
+
"No human found\n",
|
| 508 |
+
"No human found\n",
|
| 509 |
+
"No human found\n",
|
| 510 |
+
"No human found\n",
|
| 511 |
+
"No human found\n",
|
| 512 |
+
"No human found\n",
|
| 513 |
+
"No human found\n",
|
| 514 |
+
"No human found\n",
|
| 515 |
+
"No human found\n",
|
| 516 |
+
"No human found\n",
|
| 517 |
+
"No human found\n",
|
| 518 |
+
"No human found\n",
|
| 519 |
+
"No human found\n",
|
| 520 |
+
"No human found\n",
|
| 521 |
+
"No human found\n",
|
| 522 |
+
"No human found\n",
|
| 523 |
+
"No human found\n",
|
| 524 |
+
"No human found\n",
|
| 525 |
+
"No human found\n",
|
| 526 |
+
"No human found\n",
|
| 527 |
+
"No human found\n",
|
| 528 |
+
"No human found\n",
|
| 529 |
+
"No human found\n",
|
| 530 |
+
"No human found\n",
|
| 531 |
+
"No human found\n",
|
| 532 |
+
"No human found\n",
|
| 533 |
+
"No human found\n",
|
| 534 |
+
"No human found\n",
|
| 535 |
+
"No human found\n",
|
| 536 |
+
"No human found\n",
|
| 537 |
+
"No human found\n",
|
| 538 |
+
"No human found\n",
|
| 539 |
+
"No human found\n",
|
| 540 |
+
"No human found\n",
|
| 541 |
+
"No human found\n"
|
| 542 |
+
]
|
| 543 |
+
},
|
| 544 |
+
{
|
| 545 |
+
"name": "stderr",
|
| 546 |
+
"output_type": "stream",
|
| 547 |
+
"text": [
|
| 548 |
+
"2022-11-22 15:17:20.031325: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n"
|
| 549 |
+
]
|
| 550 |
+
},
|
| 551 |
+
{
|
| 552 |
+
"ename": "",
|
| 553 |
+
"evalue": "",
|
| 554 |
+
"output_type": "error",
|
| 555 |
+
"traceback": [
|
| 556 |
+
"\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
| 557 |
+
]
|
| 558 |
+
}
|
| 559 |
+
],
|
| 560 |
+
"source": [
|
| 561 |
+
"cap = cv2.VideoCapture(VIDEO_PATH1)\n",
|
| 562 |
+
"current_stage = \"\"\n",
|
| 563 |
+
"counter = 0\n",
|
| 564 |
+
"\n",
|
| 565 |
+
"prediction_probability_threshold = 0.8\n",
|
| 566 |
+
"ANGLE_THRESHOLDS = [60, 135]\n",
|
| 567 |
+
"\n",
|
| 568 |
+
"knee_over_toe = False\n",
|
| 569 |
+
"\n",
|
| 570 |
+
"with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:\n",
|
| 571 |
+
" while cap.isOpened():\n",
|
| 572 |
+
" ret, image = cap.read()\n",
|
| 573 |
+
"\n",
|
| 574 |
+
" if not ret:\n",
|
| 575 |
+
" break\n",
|
| 576 |
+
"\n",
|
| 577 |
+
" # Reduce size of a frame\n",
|
| 578 |
+
" image = rescale_frame(image, 50)\n",
|
| 579 |
+
" video_dimensions = [image.shape[1], image.shape[0]]\n",
|
| 580 |
+
"\n",
|
| 581 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 582 |
+
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
| 583 |
+
" image.flags.writeable = False\n",
|
| 584 |
+
"\n",
|
| 585 |
+
" results = pose.process(image)\n",
|
| 586 |
+
"\n",
|
| 587 |
+
" if not results.pose_landmarks:\n",
|
| 588 |
+
" print(\"No human found\")\n",
|
| 589 |
+
" continue\n",
|
| 590 |
+
"\n",
|
| 591 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 592 |
+
" image.flags.writeable = True\n",
|
| 593 |
+
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
|
| 594 |
+
"\n",
|
| 595 |
+
" # Draw landmarks and connections\n",
|
| 596 |
+
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=2), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=1))\n",
|
| 597 |
+
"\n",
|
| 598 |
+
" # Make detection\n",
|
| 599 |
+
" try:\n",
|
| 600 |
+
" # Extract keypoints from frame for the input\n",
|
| 601 |
+
" row = extract_important_keypoints(results)\n",
|
| 602 |
+
" X = pd.DataFrame([row], columns=HEADERS[1:])\n",
|
| 603 |
+
" X = pd.DataFrame(input_scaler.transform(X))\n",
|
| 604 |
+
"\n",
|
| 605 |
+
" # Make prediction and its probability\n",
|
| 606 |
+
" stage_predicted_class = stage_sklearn_model.predict(X)[0]\n",
|
| 607 |
+
" stage_prediction_probabilities = stage_sklearn_model.predict_proba(X)[0]\n",
|
| 608 |
+
" stage_prediction_probability = round(stage_prediction_probabilities[stage_prediction_probabilities.argmax()], 2)\n",
|
| 609 |
+
"\n",
|
| 610 |
+
" # Evaluate model prediction\n",
|
| 611 |
+
" if stage_predicted_class == \"I\" and stage_prediction_probability >= prediction_probability_threshold:\n",
|
| 612 |
+
" current_stage = \"init\"\n",
|
| 613 |
+
" elif stage_predicted_class == \"M\" and stage_prediction_probability >= prediction_probability_threshold: \n",
|
| 614 |
+
" current_stage = \"mid\"\n",
|
| 615 |
+
" elif stage_predicted_class == \"D\" and stage_prediction_probability >= prediction_probability_threshold:\n",
|
| 616 |
+
" if current_stage == \"mid\":\n",
|
| 617 |
+
" counter += 1\n",
|
| 618 |
+
" \n",
|
| 619 |
+
" current_stage = \"down\"\n",
|
| 620 |
+
" \n",
|
| 621 |
+
" # Error detection\n",
|
| 622 |
+
" # Knee angle\n",
|
| 623 |
+
" analyze_knee_angle(mp_results=results, stage=current_stage, angle_thresholds=ANGLE_THRESHOLDS, draw_to_image=(image, video_dimensions))\n",
|
| 624 |
+
"\n",
|
| 625 |
+
" # Knee over toe\n",
|
| 626 |
+
" err_predicted_class = err_prediction_probabilities = err_prediction_probability = None\n",
|
| 627 |
+
" if current_stage == \"down\":\n",
|
| 628 |
+
" err_prediction = err_deep_learning_model.predict(X, verbose=False)\n",
|
| 629 |
+
" err_predicted_class = np.argmax(err_prediction, axis=1)[0]\n",
|
| 630 |
+
" err_prediction_probability = round(max(err_prediction.tolist()[0]), 2)\n",
|
| 631 |
+
"\n",
|
| 632 |
+
" err_predicted_class = \"C\" if err_predicted_class == 1 else \"L\"\n",
|
| 633 |
+
" \n",
|
| 634 |
+
" \n",
|
| 635 |
+
" # Visualization\n",
|
| 636 |
+
" # Status box\n",
|
| 637 |
+
" cv2.rectangle(image, (0, 0), (800, 45), (245, 117, 16), -1)\n",
|
| 638 |
+
"\n",
|
| 639 |
+
" # Display stage prediction\n",
|
| 640 |
+
" cv2.putText(image, \"STAGE\", (15, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 641 |
+
" cv2.putText(image, str(stage_prediction_probability), (10, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 642 |
+
" cv2.putText(image, current_stage, (50, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 643 |
+
"\n",
|
| 644 |
+
" # Display error prediction\n",
|
| 645 |
+
" cv2.putText(image, \"K_O_T\", (200, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 646 |
+
" cv2.putText(image, str(err_prediction_probability), (195, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 647 |
+
" cv2.putText(image, str(err_predicted_class), (245, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 648 |
+
"\n",
|
| 649 |
+
" # Display Counter\n",
|
| 650 |
+
" cv2.putText(image, \"COUNTER\", (110, 12), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n",
|
| 651 |
+
" cv2.putText(image, str(counter), (110, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)\n",
|
| 652 |
+
"\n",
|
| 653 |
+
" except Exception as e:\n",
|
| 654 |
+
" print(f\"Error: {e}\")\n",
|
| 655 |
+
" traceback.print_exc()\n",
|
| 656 |
+
" break\n",
|
| 657 |
+
" \n",
|
| 658 |
+
" cv2.imshow(\"CV2\", image)\n",
|
| 659 |
+
" \n",
|
| 660 |
+
" # Press Q to close cv2 window\n",
|
| 661 |
+
" if cv2.waitKey(1) & 0xFF == ord('q'):\n",
|
| 662 |
+
" break\n",
|
| 663 |
+
"\n",
|
| 664 |
+
" cap.release()\n",
|
| 665 |
+
" cv2.destroyAllWindows()\n",
|
| 666 |
+
"\n",
|
| 667 |
+
" # (Optional)Fix bugs cannot close windows in MacOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)\n",
|
| 668 |
+
" for i in range (1, 5):\n",
|
| 669 |
+
" cv2.waitKey(1)\n",
|
| 670 |
+
" "
|
| 671 |
+
]
|
| 672 |
+
},
|
| 673 |
+
{
|
| 674 |
+
"cell_type": "markdown",
|
| 675 |
+
"metadata": {},
|
| 676 |
+
"source": [
|
| 677 |
+
"## 5. Conclusion\n",
|
| 678 |
+
"\n",
|
| 679 |
+
"- For stage detection:\n",
|
| 680 |
+
" - Best Sklearn model: KNN\n",
|
| 681 |
+
"- For error detection:\n",
|
| 682 |
+
" - Best Sklearn model: LR\n",
|
| 683 |
+
" - Both models are correct most of the time"
|
| 684 |
+
]
|
| 685 |
+
}
|
| 686 |
+
],
|
| 687 |
+
"metadata": {
|
| 688 |
+
"kernelspec": {
|
| 689 |
+
"display_name": "Python 3.8.13 (conda)",
|
| 690 |
+
"language": "python",
|
| 691 |
+
"name": "python3"
|
| 692 |
+
},
|
| 693 |
+
"language_info": {
|
| 694 |
+
"codemirror_mode": {
|
| 695 |
+
"name": "ipython",
|
| 696 |
+
"version": 3
|
| 697 |
+
},
|
| 698 |
+
"file_extension": ".py",
|
| 699 |
+
"mimetype": "text/x-python",
|
| 700 |
+
"name": "python",
|
| 701 |
+
"nbconvert_exporter": "python",
|
| 702 |
+
"pygments_lexer": "ipython3",
|
| 703 |
+
"version": "3.8.13"
|
| 704 |
+
},
|
| 705 |
+
"orig_nbformat": 4,
|
| 706 |
+
"vscode": {
|
| 707 |
+
"interpreter": {
|
| 708 |
+
"hash": "9260f401923fb5c4108c543a7d176de9733d378b3752e49535ad7c43c2271b65"
|
| 709 |
+
}
|
| 710 |
+
}
|
| 711 |
+
},
|
| 712 |
+
"nbformat": 4,
|
| 713 |
+
"nbformat_minor": 2
|
| 714 |
+
}
|
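The detection loops above embed a small state machine: a repetition is counted only when the predicted stage transitions from init/mid into down, and error checks run only in the down stage. A minimal sketch of that counting logic in isolation (class labels I/M/D and the 0.8 probability threshold as in the notebook):

```python
PREDICTION_PROBABILITY_THRESHOLD = 0.8  # same threshold as the notebook

def update_stage(predicted_class: str, probability: float, current_stage: str, counter: int):
    """Advance the lunge stage state machine and count completed repetitions."""
    if probability < PREDICTION_PROBABILITY_THRESHOLD:
        return current_stage, counter  # low confidence: keep the previous stage

    if predicted_class == "I":
        return "init", counter
    if predicted_class == "M":
        return "mid", counter
    if predicted_class == "D":
        if current_stage in ["init", "mid"]:
            counter += 1  # count a rep only on the transition into "down"
        return "down", counter
    return current_stage, counter

# Example: I -> M -> D counts exactly one repetition.
stage, reps = "", 0
for cls in ["I", "M", "D", "D", "I"]:
    stage, reps = update_stage(cls, 0.95, stage, reps)
print(reps)  # 1
```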
core/lunge_model/README.md
ADDED
|
@@ -0,0 +1,43 @@
|
| 1 |
+
<h2 align="center">LUNGE MODEL TRAINING PROCESS</h2>
|
| 2 |
+
|
| 3 |
+
### 1. Folder structure
|
| 4 |
+
|
| 5 |
+
```
|
| 6 |
+
lunge_model
|
| 7 |
+
│ 1.stage.data.ipynb - process collected videos for lunge stage
|
| 8 |
+
| 2.stage.sklearn.ipynb - train models using Sklearn ML algo for lunge stage
|
| 9 |
+
│ 3.stage.deep_learning.ipynb - train models using Deep Learning for lunge stage
|
| 10 |
+
│ 4.stage.detection.ipynb - detection on test videos for lunge stage
|
| 11 |
+
│ 5.err.data.ipynb - process collected videos for lunge error
|
| 12 |
+
| 6.err.sklearn.ipynb - train models using Sklearn ML algo for lunge error
|
| 13 |
+
│ 7.err.deep_learning.ipynb - train models using Deep Learning for lunge error
|
| 14 |
+
│ 8.err.evaluation.ipynb - evaluate trained models for lunge error
|
| 15 |
+
│ 9.err.detection.ipynb - detection on test videos for lunge error
|
| 16 |
+
| stage.train.csv - train dataset for lunge stage after converted from videos
|
| 17 |
+
| stage.test.csv - test dataset for lunge stage after converted from videos
|
| 18 |
+
| err.train.csv - train dataset for lunge error after converted from videos
|
| 19 |
+
| err.test.csv - test dataset for lunge error after converted from videos
|
| 20 |
+
| err.evaluation.csv - models' evaluation results
|
| 21 |
+
│
|
| 22 |
+
└───model/ - folder contains best trained models and input scaler
|
| 23 |
+
│ │
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
### 2. Important landmarks
|
| 27 |
+
|
| 28 |
+
There are 2 popular errors of lunge that will be targeted in this thesis:
|
| 29 |
+
|
| 30 |
+
- Knee angle: at the down position of a lunge, the angle of the front/back knee falls outside the proper range.
|
| 31 |
+
- Knee over toe: at the down position, the performer's knee extends too far forward, passing over the toe.
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
In my research and exploration, **_the important MediaPipe Pose landmarks_** for this exercise are: nose, left shoulder, right shoulder, left hip, right hip, left knee, right knee, left ankle, right ankle, left heel, right heel, left foot index and right foot index.
|
| 35 |
+
|
| 36 |
+
### 3. Error detection method
|
| 37 |
+
|
| 38 |
+
1. **Knee angle**: Can be detected by calculating the angle of the left and right knee. To precisely choose the correct lower and upper thresholds for this error, videos of contributors performing the correct form of the exercise were analyzed. As the graph below shows, the angle of the left/right knee during the low position of a lunge should be between 60 and 135 degrees.
|
| 39 |
+
|
| 40 |
+
- Analyzed result:
|
| 41 |
+
<p align="center"><img src="../../images/lunge_eval_4.png" alt="Logo" width="70%"></p>
|
| 42 |
+
|
| 43 |
+
2. **Knee over toe**: Due to its complexity, machine learning is used for this detection. See this [notebook](./8.err.evaluation.ipynb) for the evaluation process of these models.
|
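A minimal sketch of the knee-angle rule described in point 1 (the [60, 135] thresholds and the arctan2-based angle helper mirror the detection notebooks in this folder):

```python
import numpy as np

def calculate_angle(point1, point2, point3) -> float:
    """Angle at point2 formed by point1-point2-point3, in degrees."""
    point1, point2, point3 = np.array(point1), np.array(point2), np.array(point3)
    rad = np.arctan2(point3[1] - point2[1], point3[0] - point2[0]) \
        - np.arctan2(point1[1] - point2[1], point1[0] - point2[0])
    deg = np.abs(rad * 180.0 / np.pi)
    return deg if deg <= 180 else 360 - deg

ANGLE_THRESHOLDS = [60, 135]  # proper knee-angle range at the down position

def knee_angle_error(hip, knee, ankle) -> bool:
    """True if the knee angle falls outside the proper range."""
    angle = calculate_angle(hip, knee, ankle)
    return not (ANGLE_THRESHOLDS[0] <= angle <= ANGLE_THRESHOLDS[1])

# Example with hypothetical normalized (x, y) landmark coordinates:
print(knee_angle_error(hip=(0.5, 0.5), knee=(0.55, 0.7), ankle=(0.5, 0.9)))
```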
core/lunge_model/err.evaluation.csv
ADDED
|
@@ -0,0 +1,23 @@
|
| 1 |
+
Model,Precision Score,Accuracy Score,Recall Score,F1 Score,Confusion Matrix
|
| 2 |
+
LR,0.9733139405601312,0.971996386630533,0.971996386630533,0.9719871073494147,"[[545 1]
|
| 3 |
+
[ 30 531]]"
|
| 4 |
+
SGDC,0.9606281199356282,0.957542908762421,0.957542908762421,0.9574961020824345,"[[545 1]
|
| 5 |
+
[ 46 515]]"
|
| 6 |
+
3_layers,0.9365072351551672,0.9277326106594399,0.9277326106594399,0.9274418797827825,"[[545 1]
|
| 7 |
+
[ 79 482]]"
|
| 8 |
+
DTC,0.9192261367058114,0.916892502258356,0.916892502258356,0.9167297143761817,"[[479 67]
|
| 9 |
+
[ 25 536]]"
|
| 10 |
+
7_layers_with_dropout,0.8920702972612289,0.8644986449864499,0.8644986449864499,0.8623537930008395,"[[544 2]
|
| 11 |
+
[148 413]]"
|
| 12 |
+
RF,0.854529379025228,0.8419150858175248,0.8419150858175248,0.8407383790928666,"[[510 36]
|
| 13 |
+
[139 422]]"
|
| 14 |
+
5_layers,0.8605088616622826,0.8310749774164409,0.8310749774164409,0.8279503932280722,"[[531 15]
|
| 15 |
+
[172 389]]"
|
| 16 |
+
NB,0.7703551608827456,0.7687443541102078,0.7687443541102078,0.7682132184217674,"[[395 151]
|
| 17 |
+
[105 456]]"
|
| 18 |
+
KNN,0.7684029707859914,0.7651309846431797,0.7651309846431797,0.7646522633908915,"[[445 101]
|
| 19 |
+
[159 402]]"
|
| 20 |
+
7_layers,0.8379096785329353,0.7732610659439928,0.7732610659439928,0.7627451882949882,"[[541 5]
|
| 21 |
+
[246 315]]"
|
| 22 |
+
SVC,0.751947122240204,0.7199638663053297,0.7199638663053297,0.7118153246676994,"[[488 58]
|
| 23 |
+
[252 309]]"
|
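A small sketch of reloading and ranking this table with pandas (assuming the file is read from this folder; the multi-line Confusion Matrix cells are quoted, so `read_csv` parses them as single fields):

```python
import pandas as pd

eval_df = pd.read_csv("err.evaluation.csv")
ranked = eval_df.sort_values("F1 Score", ascending=False)
print(ranked[["Model", "Precision Score", "Recall Score", "F1 Score"]].head())
```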
core/lunge_model/err.test.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
core/lunge_model/err.train.csv
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c151b009fb505fc6125f95477225c66ad1456f9af92d7b28c3397ca47e0aea57
|
| 3 |
+
size 17699600
|
core/lunge_model/knee_angle.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
core/lunge_model/knee_angle_2.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
core/lunge_model/model/dp/all_models.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:555a7759e72bbdbf2aa7d57c937a92e71dd680bb464b5b4a0a253a57da96d3f6
|
| 3 |
+
size 9759164
|
core/lunge_model/model/dp/err_lunge_dp.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bd70744c9ae0f6079733d7a10cfd0b3b675d5390f5ac0a3d3a4ffca321957062
|
| 3 |
+
size 276710
|
core/lunge_model/model/dp/stage_lunge_dp.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fa609cba78e381e49124c7d7894e584685c24f62c1922516793313b8c3615321
|
| 3 |
+
size 245990
|
core/lunge_model/model/input_scaler.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d457c53018c4dc91702e4da9b5dcb5d9a06100e2a631666645c356b7efaf87f3
|
| 3 |
+
size 2605
|
core/lunge_model/model/sklearn/err_LR_model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1dc70bcce708f2e8473b6a5c5162653ab0253a8634632055637c70d03dd5174e
|
| 3 |
+
size 1127
|
core/lunge_model/model/sklearn/err_SGDC_model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:84f0449a5091092bdfa88d6df74090e9ee4bb1d9239ed25a829e2f5ac27d250f
|
| 3 |
+
size 5317
|
core/lunge_model/model/sklearn/err_all_sklearn.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37806a8a7d8cb0c0b006839668f749ab026c631fd305424c538f7a168b709ebb
|
| 3 |
+
size 6937168
|
core/lunge_model/model/sklearn/stage_LR_model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:13ed109d45934b90013cf4843508e2551797794089b6b5eb70d1efaf3f0e0836
|
| 3 |
+
size 1974
|
core/lunge_model/model/sklearn/stage_Ridge_model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2c0a230eaeb9cbfdd296b27ec3e188a9c1d94dbd0f5e941c396a118aa15ad378
|
| 3 |
+
size 2010
|
core/lunge_model/model/sklearn/stage_SVC_model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d81a1fb1d60f1f6a2ffb855861e2864e3cb38450c5dc84a7198e68bbabcdf71e
|
| 3 |
+
size 370598
|
core/lunge_model/stage.test.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
core/lunge_model/stage.train.csv
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fe9c1a432b8ca18f116d4d76785c95976fe783ec164c67c7d7fb14477a0c3752
|
| 3 |
+
size 24142502
|
core/plank_model/1.data.ipynb
ADDED
|
@@ -0,0 +1,706 @@
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [
|
| 8 |
+
{
|
| 9 |
+
"name": "stderr",
|
| 10 |
+
"output_type": "stream",
|
| 11 |
+
"text": [
|
| 12 |
+
"objc[67861]: Class CaptureDelegate is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_videoio.3.4.16.dylib (0x104544860) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x289c9e480). One of the two will be used. Which one is undefined.\n",
|
| 13 |
+
"objc[67861]: Class CVWindow is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x103324a68) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x289c9e4d0). One of the two will be used. Which one is undefined.\n",
|
| 14 |
+
"objc[67861]: Class CVView is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x103324a90) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x289c9e4f8). One of the two will be used. Which one is undefined.\n",
|
| 15 |
+
"objc[67861]: Class CVSlider is implemented in both /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/mediapipe/.dylibs/libopencv_highgui.3.4.16.dylib (0x103324ab8) and /Users/fuixlabsdev1/Programming/PP/graduation-thesis/env/lib/python3.8/site-packages/cv2/cv2.abi3.so (0x289c9e520). One of the two will be used. Which one is undefined.\n"
|
| 16 |
+
]
|
| 17 |
+
}
|
| 18 |
+
],
|
| 19 |
+
"source": [
|
| 20 |
+
"import mediapipe as mp\n",
|
| 21 |
+
"import cv2\n",
|
| 22 |
+
"import numpy as np\n",
|
| 23 |
+
"import pandas as pd\n",
|
| 24 |
+
"import seaborn as sns\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"import warnings\n",
|
| 27 |
+
"warnings.filterwarnings('ignore')\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"# Drawing helpers\n",
|
| 30 |
+
"mp_drawing = mp.solutions.drawing_utils\n",
|
| 31 |
+
"mp_pose = mp.solutions.pose"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "markdown",
|
| 36 |
+
"metadata": {},
|
| 37 |
+
"source": [
|
| 38 |
+
"### 1. Make some pose detections (Test if Everything works correctly)"
|
| 39 |
+
]
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"cell_type": "code",
|
| 43 |
+
"execution_count": null,
|
| 44 |
+
"metadata": {},
|
| 45 |
+
"outputs": [],
|
| 46 |
+
"source": [
|
| 47 |
+
"cap = cv2.VideoCapture(\"../data/plank/20221011_202153.mp4\")\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:\n",
|
| 50 |
+
" while cap.isOpened():\n",
|
| 51 |
+
" ret, image = cap.read()\n",
|
| 52 |
+
" image = cv2.flip(image, 1)\n",
|
| 53 |
+
"\n",
|
| 54 |
+
" if not ret:\n",
|
| 55 |
+
" break\n",
|
| 56 |
+
"\n",
|
| 57 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 58 |
+
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
| 59 |
+
" image.flags.writeable = False\n",
|
| 60 |
+
"\n",
|
| 61 |
+
" results = pose.process(image)\n",
|
| 62 |
+
"\n",
|
| 63 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 64 |
+
" image.flags.writeable = True\n",
|
| 65 |
+
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
|
| 66 |
+
"\n",
|
| 67 |
+
" # Draw landmarks and connections\n",
|
| 68 |
+
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=4), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2))\n",
|
| 69 |
+
"\n",
|
| 70 |
+
" cv2.imshow(\"CV2\", image)\n",
|
| 71 |
+
"\n",
|
| 72 |
+
" # Press Q to close cv2 window\n",
|
| 73 |
+
" if cv2.waitKey(1) & 0xFF == ord('q'):\n",
|
| 74 |
+
" break\n",
|
| 75 |
+
"\n",
|
| 76 |
+
" cap.release()\n",
|
| 77 |
+
" cv2.destroyAllWindows()\n",
|
| 78 |
+
"\n",
|
| 79 |
+
" # (Optional)Fix bugs cannot close windows in MacOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)\n",
|
| 80 |
+
" for i in range (1, 5):\n",
|
| 81 |
+
" cv2.waitKey(1)"
|
| 82 |
+
]
|
| 83 |
+
},
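For reference, once `pose.process` has run, individual landmark coordinates can be read straight off the `results` object rather than only drawn. A minimal standalone sketch (the image path is illustrative, not part of this repo):

    import cv2
    import mediapipe as mp

    mp_pose = mp.solutions.pose

    frame = cv2.imread("sample_plank_frame.jpg")  # hypothetical input frame
    with mp_pose.Pose(static_image_mode=True) as pose:
        results = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if results.pose_landmarks:
            # Coordinates are normalized to [0, 1] relative to frame width/height
            nose = results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE.value]
            print(nose.x, nose.y, nose.z, nose.visibility)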
|
| 84 |
+
{
|
| 85 |
+
"cell_type": "markdown",
|
| 86 |
+
"metadata": {},
|
| 87 |
+
"source": [
|
| 88 |
+
"### 2. Build dataset from collected videos and picture from Kaggle to .csv file for dataset"
|
| 89 |
+
]
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"cell_type": "code",
|
| 93 |
+
"execution_count": 2,
|
| 94 |
+
"metadata": {},
|
| 95 |
+
"outputs": [],
|
| 96 |
+
"source": [
|
| 97 |
+
"import csv\n",
|
| 98 |
+
"import os"
|
| 99 |
+
]
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"cell_type": "markdown",
|
| 103 |
+
"metadata": {},
|
| 104 |
+
"source": [
|
| 105 |
+
"#### 2.1 Determine important keypoints and set up important functions"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"cell_type": "markdown",
|
| 110 |
+
"metadata": {},
|
| 111 |
+
"source": [
|
| 112 |
+
"There are 3 stages that I try to classify for Plank Exercise Correction:\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"- Correct: \"C\"\n",
|
| 115 |
+
"- Back is too low: \"L\"\n",
|
| 116 |
+
"- Back is too high: \"H\""
|
| 117 |
+
]
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"cell_type": "code",
|
| 121 |
+
"execution_count": 3,
|
| 122 |
+
"metadata": {},
|
| 123 |
+
"outputs": [],
|
| 124 |
+
"source": [
|
| 125 |
+
"# Determine important landmarks for plank\n",
|
| 126 |
+
"IMPORTANT_LMS = [\n",
|
| 127 |
+
" \"NOSE\",\n",
|
| 128 |
+
" \"LEFT_SHOULDER\",\n",
|
| 129 |
+
" \"RIGHT_SHOULDER\",\n",
|
| 130 |
+
" \"LEFT_ELBOW\",\n",
|
| 131 |
+
" \"RIGHT_ELBOW\",\n",
|
| 132 |
+
" \"LEFT_WRIST\",\n",
|
| 133 |
+
" \"RIGHT_WRIST\",\n",
|
| 134 |
+
" \"LEFT_HIP\",\n",
|
| 135 |
+
" \"RIGHT_HIP\",\n",
|
| 136 |
+
" \"LEFT_KNEE\",\n",
|
| 137 |
+
" \"RIGHT_KNEE\",\n",
|
| 138 |
+
" \"LEFT_ANKLE\",\n",
|
| 139 |
+
" \"RIGHT_ANKLE\",\n",
|
| 140 |
+
" \"LEFT_HEEL\",\n",
|
| 141 |
+
" \"RIGHT_HEEL\",\n",
|
| 142 |
+
" \"LEFT_FOOT_INDEX\",\n",
|
| 143 |
+
" \"RIGHT_FOOT_INDEX\",\n",
|
| 144 |
+
"]\n",
|
| 145 |
+
"\n",
|
| 146 |
+
"# Generate all columns of the data frame\n",
|
| 147 |
+
"\n",
|
| 148 |
+
"HEADERS = [\"label\"] # Label column\n",
|
| 149 |
+
"\n",
|
| 150 |
+
"for lm in IMPORTANT_LMS:\n",
|
| 151 |
+
" HEADERS += [f\"{lm.lower()}_x\", f\"{lm.lower()}_y\", f\"{lm.lower()}_z\", f\"{lm.lower()}_v\"]"
|
| 152 |
+
]
|
| 153 |
+
},
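A quick sanity check on the schema generated above: one label column plus x, y, z and visibility for each of the 17 landmarks, which matches the 69 columns reported by `describe_dataset` further down.

    # 17 landmarks x 4 values each, plus the label column, gives 69 headers
    assert len(IMPORTANT_LMS) == 17
    assert len(HEADERS) == 1 + 4 * len(IMPORTANT_LMS) == 69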
|
| 154 |
+
{
|
| 155 |
+
"cell_type": "markdown",
|
| 156 |
+
"metadata": {},
|
| 157 |
+
"source": [
|
| 158 |
+
"*Set up important functions*"
|
| 159 |
+
]
|
| 160 |
+
},
|
| 161 |
+
{
|
| 162 |
+
"cell_type": "code",
|
| 163 |
+
"execution_count": 4,
|
| 164 |
+
"metadata": {},
|
| 165 |
+
"outputs": [],
|
| 166 |
+
"source": [
|
| 167 |
+
"def rescale_frame(frame, percent=50):\n",
|
| 168 |
+
" '''\n",
|
| 169 |
+
" Rescale a frame to a certain percentage compare to its original frame\n",
|
| 170 |
+
" '''\n",
|
| 171 |
+
" width = int(frame.shape[1] * percent/ 100)\n",
|
| 172 |
+
" height = int(frame.shape[0] * percent/ 100)\n",
|
| 173 |
+
" dim = (width, height)\n",
|
| 174 |
+
" return cv2.resize(frame, dim, interpolation = cv2.INTER_AREA)\n",
|
| 175 |
+
" \n",
|
| 176 |
+
"\n",
|
| 177 |
+
"def init_csv(dataset_path: str):\n",
|
| 178 |
+
" '''\n",
|
| 179 |
+
" Create a blank csv file with just columns\n",
|
| 180 |
+
" '''\n",
|
| 181 |
+
"\n",
|
| 182 |
+
" # Ignore if file is already exist\n",
|
| 183 |
+
" if os.path.exists(dataset_path):\n",
|
| 184 |
+
" return\n",
|
| 185 |
+
"\n",
|
| 186 |
+
" # Write all the columns to a empaty file\n",
|
| 187 |
+
" with open(dataset_path, mode=\"w\", newline=\"\") as f:\n",
|
| 188 |
+
" csv_writer = csv.writer(f, delimiter=\",\", quotechar='\"', quoting=csv.QUOTE_MINIMAL)\n",
|
| 189 |
+
" csv_writer.writerow(HEADERS)\n",
|
| 190 |
+
"\n",
|
| 191 |
+
"\n",
|
| 192 |
+
"def export_landmark_to_csv(dataset_path: str, results, action: str) -> None:\n",
|
| 193 |
+
" '''\n",
|
| 194 |
+
" Export Labeled Data from detected landmark to csv\n",
|
| 195 |
+
" '''\n",
|
| 196 |
+
" landmarks = results.pose_landmarks.landmark\n",
|
| 197 |
+
" keypoints = []\n",
|
| 198 |
+
"\n",
|
| 199 |
+
" try:\n",
|
| 200 |
+
" # Extract coordinate of important landmarks\n",
|
| 201 |
+
" for lm in IMPORTANT_LMS:\n",
|
| 202 |
+
" keypoint = landmarks[mp_pose.PoseLandmark[lm].value]\n",
|
| 203 |
+
" keypoints.append([keypoint.x, keypoint.y, keypoint.z, keypoint.visibility])\n",
|
| 204 |
+
" \n",
|
| 205 |
+
" keypoints = list(np.array(keypoints).flatten())\n",
|
| 206 |
+
"\n",
|
| 207 |
+
" # Insert action as the label (first column)\n",
|
| 208 |
+
" keypoints.insert(0, action)\n",
|
| 209 |
+
"\n",
|
| 210 |
+
" # Append new row to .csv file\n",
|
| 211 |
+
" with open(dataset_path, mode=\"a\", newline=\"\") as f:\n",
|
| 212 |
+
" csv_writer = csv.writer(f, delimiter=\",\", quotechar='\"', quoting=csv.QUOTE_MINIMAL)\n",
|
| 213 |
+
" csv_writer.writerow(keypoints)\n",
|
| 214 |
+
" \n",
|
| 215 |
+
"\n",
|
| 216 |
+
" except Exception as e:\n",
|
| 217 |
+
" print(e)\n",
|
| 218 |
+
" pass\n",
|
| 219 |
+
"\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"def describe_dataset(dataset_path: str):\n",
|
| 222 |
+
" '''\n",
|
| 223 |
+
" Describe dataset\n",
|
| 224 |
+
" '''\n",
|
| 225 |
+
"\n",
|
| 226 |
+
" data = pd.read_csv(dataset_path)\n",
|
| 227 |
+
" print(f\"Headers: {list(data.columns.values)}\")\n",
|
| 228 |
+
" print(f'Number of rows: {data.shape[0]} \\nNumber of columns: {data.shape[1]}\\n')\n",
|
| 229 |
+
" print(f\"Labels: \\n{data['label'].value_counts()}\\n\")\n",
|
| 230 |
+
" print(f\"Missing values: {data.isnull().values.any()}\\n\")\n",
|
| 231 |
+
" \n",
|
| 232 |
+
" duplicate = data[data.duplicated()]\n",
|
| 233 |
+
" print(f\"Duplicate Rows : {len(duplicate.sum(axis=1))}\")\n",
|
| 234 |
+
"\n",
|
| 235 |
+
" return data\n",
|
| 236 |
+
"\n",
|
| 237 |
+
"\n",
|
| 238 |
+
"def remove_duplicate_rows(dataset_path: str):\n",
|
| 239 |
+
" '''\n",
|
| 240 |
+
" Remove duplicated data from the dataset then save it to another files\n",
|
| 241 |
+
" '''\n",
|
| 242 |
+
" \n",
|
| 243 |
+
" df = pd.read_csv(dataset_path)\n",
|
| 244 |
+
" df.drop_duplicates(keep=\"first\", inplace=True)\n",
|
| 245 |
+
" df.to_csv(f\"cleaned_train.csv\", sep=',', encoding='utf-8', index=False)\n",
|
| 246 |
+
" \n",
|
| 247 |
+
"\n",
|
| 248 |
+
"def concat_csv_files_with_same_headers(file_paths: list, saved_path: str):\n",
|
| 249 |
+
" '''\n",
|
| 250 |
+
" Concat different csv files\n",
|
| 251 |
+
" '''\n",
|
| 252 |
+
" all_df = []\n",
|
| 253 |
+
" for path in file_paths:\n",
|
| 254 |
+
" df = pd.read_csv(path, index_col=None, header=0)\n",
|
| 255 |
+
" all_df.append(df)\n",
|
| 256 |
+
" \n",
|
| 257 |
+
" results = pd.concat(all_df, axis=0, ignore_index=True)\n",
|
| 258 |
+
" results.to_csv(saved_path, sep=',', encoding='utf-8', index=False)"
|
| 259 |
+
]
|
| 260 |
+
},
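Taken together, these helpers support a collect-then-inspect flow. A short usage sketch (paths are illustrative):

    init_csv("train.csv")               # write the header row once
    # ... run a labeling loop that calls export_landmark_to_csv on keypresses ...
    df = describe_dataset("train.csv")  # summarize what was collected
    remove_duplicate_rows("train.csv")  # saves the deduplicated copy as cleaned_train.csv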
|
| 261 |
+
{
|
| 262 |
+
"cell_type": "markdown",
|
| 263 |
+
"metadata": {},
|
| 264 |
+
"source": [
|
| 265 |
+
"#### 2.2 Extract from video\n"
|
| 266 |
+
]
|
| 267 |
+
},
|
| 268 |
+
{
|
| 269 |
+
"cell_type": "code",
|
| 270 |
+
"execution_count": 6,
|
| 271 |
+
"metadata": {},
|
| 272 |
+
"outputs": [
|
| 273 |
+
{
|
| 274 |
+
"name": "stderr",
|
| 275 |
+
"output_type": "stream",
|
| 276 |
+
"text": [
|
| 277 |
+
"INFO: Created TensorFlow Lite XNNPACK delegate for CPU.\n"
|
| 278 |
+
]
|
| 279 |
+
}
|
| 280 |
+
],
|
| 281 |
+
"source": [
|
| 282 |
+
"DATASET_PATH = \"train.csv\"\n",
|
| 283 |
+
"\n",
|
| 284 |
+
"cap = cv2.VideoCapture(\"../data/plank/bad/plank_bad_high_4.mp4\")\n",
|
| 285 |
+
"save_counts = 0\n",
|
| 286 |
+
"\n",
|
| 287 |
+
"# init_csv(DATASET_PATH)\n",
|
| 288 |
+
"\n",
|
| 289 |
+
"with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:\n",
|
| 290 |
+
" while cap.isOpened():\n",
|
| 291 |
+
" ret, image = cap.read()\n",
|
| 292 |
+
"\n",
|
| 293 |
+
" if not ret:\n",
|
| 294 |
+
" break\n",
|
| 295 |
+
"\n",
|
| 296 |
+
" # Reduce size of a frame\n",
|
| 297 |
+
" image = rescale_frame(image, 60)\n",
|
| 298 |
+
" image = cv2.flip(image, 1)\n",
|
| 299 |
+
"\n",
|
| 300 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 301 |
+
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
| 302 |
+
" image.flags.writeable = False\n",
|
| 303 |
+
"\n",
|
| 304 |
+
" results = pose.process(image)\n",
|
| 305 |
+
"\n",
|
| 306 |
+
" if not results.pose_landmarks: continue\n",
|
| 307 |
+
"\n",
|
| 308 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 309 |
+
" image.flags.writeable = True\n",
|
| 310 |
+
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
|
| 311 |
+
"\n",
|
| 312 |
+
" # Draw landmarks and connections\n",
|
| 313 |
+
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=4), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2))\n",
|
| 314 |
+
"\n",
|
| 315 |
+
" # Display the saved count\n",
|
| 316 |
+
" cv2.putText(image, f\"Saved: {save_counts}\", (50, 50), cv2.FONT_HERSHEY_COMPLEX, 2, (0, 0, 0), 2, cv2.LINE_AA)\n",
|
| 317 |
+
"\n",
|
| 318 |
+
" cv2.imshow(\"CV2\", image)\n",
|
| 319 |
+
"\n",
|
| 320 |
+
" # Pressed key for action\n",
|
| 321 |
+
" k = cv2.waitKey(1) & 0xFF\n",
|
| 322 |
+
"\n",
|
| 323 |
+
" # Press C to save as correct form\n",
|
| 324 |
+
" if k == ord('c'): \n",
|
| 325 |
+
" export_landmark_to_csv(DATASET_PATH, results, \"C\")\n",
|
| 326 |
+
" save_counts += 1\n",
|
| 327 |
+
" # Press L to save as low back\n",
|
| 328 |
+
" elif k == ord(\"l\"):\n",
|
| 329 |
+
" export_landmark_to_csv(DATASET_PATH, results, \"L\")\n",
|
| 330 |
+
" save_counts += 1\n",
|
| 331 |
+
" # Press L to save as high back\n",
|
| 332 |
+
" elif k == ord(\"h\"):\n",
|
| 333 |
+
" export_landmark_to_csv(DATASET_PATH, results, \"H\")\n",
|
| 334 |
+
" save_counts += 1\n",
|
| 335 |
+
"\n",
|
| 336 |
+
" # Press q to stop\n",
|
| 337 |
+
" elif k == ord(\"q\"):\n",
|
| 338 |
+
" break\n",
|
| 339 |
+
" else: continue\n",
|
| 340 |
+
"\n",
|
| 341 |
+
" cap.release()\n",
|
| 342 |
+
" cv2.destroyAllWindows()\n",
|
| 343 |
+
"\n",
|
| 344 |
+
" # (Optional)Fix bugs cannot close windows in MacOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)\n",
|
| 345 |
+
" for i in range (1, 5):\n",
|
| 346 |
+
" cv2.waitKey(1)\n",
|
| 347 |
+
" "
|
| 348 |
+
]
|
| 349 |
+
},
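The key-to-label branches in the loop above can also be written as a small lookup table. A sketch with the same behavior, meant to replace the if/elif chain inside the while loop:

    KEY_TO_LABEL = {ord("c"): "C", ord("l"): "L", ord("h"): "H"}

    k = cv2.waitKey(1) & 0xFF
    if k in KEY_TO_LABEL:
        export_landmark_to_csv(DATASET_PATH, results, KEY_TO_LABEL[k])
        save_counts += 1
    elif k == ord("q"):
        break  # only valid inside the capture loop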
|
| 350 |
+
{
|
| 351 |
+
"cell_type": "code",
|
| 352 |
+
"execution_count": 7,
|
| 353 |
+
"metadata": {},
|
| 354 |
+
"outputs": [
|
| 355 |
+
{
|
| 356 |
+
"name": "stdout",
|
| 357 |
+
"output_type": "stream",
|
| 358 |
+
"text": [
|
| 359 |
+
"Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'left_elbow_x', 'left_elbow_y', 'left_elbow_z', 'left_elbow_v', 'right_elbow_x', 'right_elbow_y', 'right_elbow_z', 'right_elbow_v', 'left_wrist_x', 'left_wrist_y', 'left_wrist_z', 'left_wrist_v', 'right_wrist_x', 'right_wrist_y', 'right_wrist_z', 'right_wrist_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v', 'left_knee_x', 'left_knee_y', 'left_knee_z', 'left_knee_v', 'right_knee_x', 'right_knee_y', 'right_knee_z', 'right_knee_v', 'left_ankle_x', 'left_ankle_y', 'left_ankle_z', 'left_ankle_v', 'right_ankle_x', 'right_ankle_y', 'right_ankle_z', 'right_ankle_v', 'left_heel_x', 'left_heel_y', 'left_heel_z', 'left_heel_v', 'right_heel_x', 'right_heel_y', 'right_heel_z', 'right_heel_v', 'left_foot_index_x', 'left_foot_index_y', 'left_foot_index_z', 'left_foot_index_v', 'right_foot_index_x', 'right_foot_index_y', 'right_foot_index_z', 'right_foot_index_v']\n",
|
| 360 |
+
"Number of rows: 28520 \n",
|
| 361 |
+
"Number of columns: 69\n",
|
| 362 |
+
"\n",
|
| 363 |
+
"Labels: \n",
|
| 364 |
+
"C 9904\n",
|
| 365 |
+
"L 9546\n",
|
| 366 |
+
"H 9070\n",
|
| 367 |
+
"Name: label, dtype: int64\n",
|
| 368 |
+
"\n",
|
| 369 |
+
"Missing values: False\n",
|
| 370 |
+
"\n",
|
| 371 |
+
"Duplicate Rows : 0\n"
|
| 372 |
+
]
|
| 373 |
+
}
|
| 374 |
+
],
|
| 375 |
+
"source": [
|
| 376 |
+
"# csv_files = [os.path.join(\"./\", f) for f in os.listdir(\"./\") if \"csv\" in f]\n",
|
| 377 |
+
"\n",
|
| 378 |
+
"# concat_csv_files_with_same_headers(csv_files, \"train.csv\")\n",
|
| 379 |
+
"\n",
|
| 380 |
+
"# remove_duplicate_rows(DATASET_PATH)\n",
|
| 381 |
+
"\n",
|
| 382 |
+
"df = describe_dataset(DATASET_PATH)\n"
|
| 383 |
+
]
|
| 384 |
+
},
|
| 385 |
+
{
|
| 386 |
+
"cell_type": "markdown",
|
| 387 |
+
"metadata": {},
|
| 388 |
+
"source": [
|
| 389 |
+
"#### 2.3. Extract from Kaggle dataset ([download here](https://www.kaggle.com/datasets/niharika41298/yoga-poses-dataset))"
|
| 390 |
+
]
|
| 391 |
+
},
|
| 392 |
+
{
|
| 393 |
+
"cell_type": "code",
|
| 394 |
+
"execution_count": null,
|
| 395 |
+
"metadata": {},
|
| 396 |
+
"outputs": [],
|
| 397 |
+
"source": [
|
| 398 |
+
"FOLDER_PATH = \"../data/kaggle/TRAIN/plank\"\n",
|
| 399 |
+
"picture_files = [os.path.join(FOLDER_PATH, f) for f in os.listdir(FOLDER_PATH) if os.path.isfile(os.path.join(FOLDER_PATH, f))]\n",
|
| 400 |
+
"print(f\"Total pictures: {len(picture_files)}\")\n",
|
| 401 |
+
"\n",
|
| 402 |
+
"DATASET_PATH = \"./kaggle.csv\"\n",
|
| 403 |
+
"saved_counts = 0\n",
|
| 404 |
+
"\n",
|
| 405 |
+
"init_csv(DATASET_PATH)\n",
|
| 406 |
+
"\n",
|
| 407 |
+
"with mp_pose.Pose(min_detection_confidence=0.7, min_tracking_confidence=0.5) as pose:\n",
|
| 408 |
+
" index = 0\n",
|
| 409 |
+
" \n",
|
| 410 |
+
" while True:\n",
|
| 411 |
+
" if index == len(picture_files):\n",
|
| 412 |
+
" break\n",
|
| 413 |
+
" \n",
|
| 414 |
+
" file_path = picture_files[index]\n",
|
| 415 |
+
"\n",
|
| 416 |
+
" image = cv2.imread(file_path)\n",
|
| 417 |
+
"\n",
|
| 418 |
+
" # Flip image horizontally for more data\n",
|
| 419 |
+
" image = cv2.flip(image, 1)\n",
|
| 420 |
+
"\n",
|
| 421 |
+
" # get dimensions of image\n",
|
| 422 |
+
" dimensions = image.shape\n",
|
| 423 |
+
" \n",
|
| 424 |
+
" # height, width, number of channels in image\n",
|
| 425 |
+
" height = image.shape[0]\n",
|
| 426 |
+
" width = image.shape[1]\n",
|
| 427 |
+
" channels = image.shape[2]\n",
|
| 428 |
+
"\n",
|
| 429 |
+
" # Reduce size of a frame\n",
|
| 430 |
+
" if width > 1000:\n",
|
| 431 |
+
" image = rescale_frame(image, 60)\n",
|
| 432 |
+
"\n",
|
| 433 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 434 |
+
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
| 435 |
+
" image.flags.writeable = False\n",
|
| 436 |
+
"\n",
|
| 437 |
+
" results = pose.process(image)\n",
|
| 438 |
+
"\n",
|
| 439 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 440 |
+
" image.flags.writeable = True\n",
|
| 441 |
+
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
|
| 442 |
+
"\n",
|
| 443 |
+
" # Draw landmarks and connections\n",
|
| 444 |
+
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=4), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2))\n",
|
| 445 |
+
"\n",
|
| 446 |
+
" # Display the saved count\n",
|
| 447 |
+
" cv2.putText(image, f\"Saved: {saved_counts}\", (20, 20), cv2.FONT_HERSHEY_COMPLEX, 2, (0, 0, 0), 2, cv2.LINE_AA)\n",
|
| 448 |
+
" \n",
|
| 449 |
+
" cv2.imshow(\"CV2\", image)\n",
|
| 450 |
+
"\n",
|
| 451 |
+
" k = cv2.waitKey(1) & 0xFF\n",
|
| 452 |
+
"\n",
|
| 453 |
+
" if k == ord('d'): \n",
|
| 454 |
+
" index += 1\n",
|
| 455 |
+
"\n",
|
| 456 |
+
" elif k == ord(\"s\"):\n",
|
| 457 |
+
" export_landmark_to_csv(DATASET_PATH, results, \"C\")\n",
|
| 458 |
+
" saved_counts += 1\n",
|
| 459 |
+
"\n",
|
| 460 |
+
" elif k == ord(\"f\"):\n",
|
| 461 |
+
" index += 1\n",
|
| 462 |
+
" os.remove(file_path)\n",
|
| 463 |
+
"\n",
|
| 464 |
+
" elif k == ord(\"q\"):\n",
|
| 465 |
+
" break\n",
|
| 466 |
+
"\n",
|
| 467 |
+
" else:\n",
|
| 468 |
+
" continue\n",
|
| 469 |
+
"\n",
|
| 470 |
+
" # # Press Q to close cv2 window\n",
|
| 471 |
+
" # if cv2.waitKey(1) & 0xFF == ord('d'):\n",
|
| 472 |
+
" # index += 1\n",
|
| 473 |
+
"\n",
|
| 474 |
+
" # # Press Q to close cv2 window\n",
|
| 475 |
+
" # if cv2.waitKey(1) & 0xFF == ord('q'):\n",
|
| 476 |
+
" # break\n",
|
| 477 |
+
"\n",
|
| 478 |
+
" # Close cv2 window\n",
|
| 479 |
+
" cv2.destroyAllWindows()\n",
|
| 480 |
+
"\n",
|
| 481 |
+
" # (Optional)Fix bugs cannot close windows in MacOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)\n",
|
| 482 |
+
" for i in range (1, 5):\n",
|
| 483 |
+
" cv2.waitKey(1)"
|
| 484 |
+
]
|
| 485 |
+
},
|
| 486 |
+
{
|
| 487 |
+
"cell_type": "markdown",
|
| 488 |
+
"metadata": {},
|
| 489 |
+
"source": [
|
| 490 |
+
"### 3. Refine Data & Data Visualization"
|
| 491 |
+
]
|
| 492 |
+
},
|
| 493 |
+
{
|
| 494 |
+
"cell_type": "code",
|
| 495 |
+
"execution_count": 9,
|
| 496 |
+
"metadata": {},
|
| 497 |
+
"outputs": [
|
| 498 |
+
{
|
| 499 |
+
"data": {
|
| 500 |
+
"text/plain": [
|
| 501 |
+
"<AxesSubplot:xlabel='label', ylabel='count'>"
|
| 502 |
+
]
|
| 503 |
+
},
|
| 504 |
+
"execution_count": 9,
|
| 505 |
+
"metadata": {},
|
| 506 |
+
"output_type": "execute_result"
|
| 507 |
+
},
|
| 508 |
+
{
|
| 509 |
+
"data": {
|
| 510 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAk0AAAGwCAYAAAC0HlECAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAnDklEQVR4nO3df3DU9Z3H8ddCyBIgfCWEZMkZNIw5RIOWCxKCVahAoDWmDjOiF7uHAwV6QdIUKJThqIhHoqjAaK4InBWOH4a5tmm5trcleiVX5EcgNVUwor3mBGxC6LnZAMYkhu/90fKdLkH4GGL2G/J8zGTG/e57k/eXyZjnfHc38di2bQsAAABX1CvSCwAAAHQHRBMAAIABogkAAMAA0QQAAGCAaAIAADBANAEAABggmgAAAAxERXqB68mFCxf0xz/+UbGxsfJ4PJFeBwAAGLBtW2fPnlVSUpJ69frs60lEUyf64x//qOTk5EivAQAAOuDkyZO68cYbP/N+oqkTxcbGSvrzP/rAgQMjvA0AADDR2Nio5ORk5+f4ZyGaOtHFp+QGDhxINAEA0M1c7aU1vBAcAADAANEEAABggGgCAAAwQDQBAAAYIJoAAAAMEE0AAAAGiCYAAAADRBMAAICBiEbTf//3f+uBBx5QUlKSPB6PfvrTn4bdb9u2Vq5cqaSkJMXExGjixIk6duxY2Exzc7MWLFig+Ph49e/fXzk5OTp16lTYTDAYlN/vl2VZsixLfr9fDQ0NYTMnTpzQAw88oP79+ys+Pl75+flqaWn5Ik4bAAB0QxGNpvPnz+vOO+9UcXHxZe9fs2aN1q5dq+LiYh0+fFg+n09TpkzR2bNnnZmCggKVlpaqpKRE+/bt07lz55Sdna22tjZnJjc3V1VVVQoEAgoEAqqqqpLf73fub2tr0/3336/z589r3759Kikp0Y9//GMtWrToizt5AADQvdguIckuLS11bl+4cMH2+Xz2008/7Rz75JNPbMuy7Jdeesm2bdtuaGiw+/TpY5eUlDgzH374od2rVy87EAjYtm3b77zzji3JPnjwoDNz4MABW5L97rvv2rZt27/85S/tXr162R9++KEz8+qrr9per9cOhUKfufMnn3xih0Ih5+PkyZO2pCs+BgAAuEsoFDL6+e3a1zTV1NSorq5OWVlZzjGv16sJEyZo//79kqTKykq1traGzSQlJSktLc2ZOXDggCzLUkZGhjMzbtw4WZYVNpOWlqakpCRnZurUqWpublZlZeVn7lhUVOQ85WdZlpKTkzvn5AEAgOu4Nprq6uokSYmJiWHHExMTnfvq6uoUHR2tQYMGXXEmISGh3edPSEgIm7n06wwaNEjR0dHOzOUsW7ZMoVDI+Th58uTnPEsAANBdREV6gau59C8O27Z91b9CfOnM5eY7MnMpr9crr9d7xV0AAMD1wbVXmnw+nyS1u9JTX1/vXBXy+XxqaWlRMBi84szp06fbff4zZ86EzVz6dYLBoFpbW9tdgQIAAD2Ta680paSkyOfzqaysTKNHj5YktbS0qLy8XM8884wkKT09XX369FFZWZlmzJghSaqtrdXRo0e1Zs0aSVJmZqZCoZAqKio0duxYSdKhQ4cUCoU0fvx4Z2b16tWqra3V0KFDJUl79uyR1+tVenp6l573kTFju/Trwd3GHKmI9AoAgL+IaDSdO3dOv//9753bNTU1qqqqUlxcnIYNG6aCggIVFhYqNTVVqampKiwsVL9+/ZSbmytJsixLs2fP1qJFizR48GDFxcVp8eLFGjVqlCZPnixJGjlypKZNm6Y5c+Zo48aNkqS5c+cqOztbI0aMkCRlZWXptttuk9/v17PPPquPPvpIixcv1pw5czRw4MAu/lcBAABuFNFoOnLkiL7yla84txcuXChJmjlzprZs2aIlS5aoqalJeXl5CgaDysjI0J49exQbG+s8Zt26dYqKitKMGTPU1NSkSZMmacuWLerdu7czs2PHDuXn5zvvssvJyQn73VC9e/fWL37xC+Xl5enuu+9WTEyMcnNz9dxzz33R/wQAAKCb8Ni2bUd6ietFY2OjLMtSKBTq8BUqnp7DX+PpOQD44pn+/HbtC8EBAADchGgCAAAwQDQBAAAYIJoAAAAMEE0AAAAGiCYAAAADRBMAAIABogkAAMCAa//2HAD3mLZiV6RXgIsEnno40isAEcGVJgAAAANEEwAAgAGiCQAAwADRBAAAYIBoAgAAMEA0AQAAGCCaAAAADBBNAAAABogmAAAAA0QTAACAAaIJAADAANEEAABggGgCAAAwQDQBAAAYIJoAAAAMEE0AAAAGoiK9AAAAn1futkcivQJcZKe/pEu+DleaAAAADBBNAAAABogmAAAAA0QTAACAAaIJAADAANEEAABggGgCAAAwQDQBAAAYIJoAAAAMEE0AAAAGiCYAAAADRBMAAIABogkAAMAA0QQAAGCAaAIAADBANAEAABggmgAAAAwQTQAAAAaIJgAAAANEEwAAgAGiCQAAwADRBAAAYIBoAgAAMEA0AQAAGCCaAAAADBBNAAAABogmAAAAA0QTAACAAaIJAADAANEEAABggGgCAAAwQDQBAAAYIJoAAAAMEE0AAAAGXB1Nn376qf7pn/5JKSkpiomJ0fDhw7Vq1SpduHDBmbFtWytXrlRSUpJiYmI0ceJEHTt2LOzzNDc3a8GCBYqPj1f//v2Vk5OjU6dOhc0Eg0H5/X5ZliXLsuT3+9XQ0NAVpwkAALoBV0fTM888o5deeknFxcWqrq7WmjVr9Oyzz+rFF190ZtasWaO1a9equLhYhw8fls/n05QpU3T27FlnpqCgQKWlpSopKdG+fft07tw5ZWdnq62tzZnJzc1VVVWVAoGAAoGAqqqq5Pf7u/R8AQCAe0VFeoErOXDggL7+9a/r/vvvlyTdfPPNevXVV3XkyBFJf77KtH79ei1fvlzTp0+XJG3dulWJiYnauXOn5s2bp1AopJdfflnbtm3T5MmTJUnbt29XcnKyXnvtNU2dOlXV1dUKBAI6ePCgMjIyJEmbN29WZmamjh8/rhEjRlx2v+bmZjU3Nzu3Gxsbv7B/CwAAEFmuvtL05S9/Wa+//rree+89SdLvfvc77du3T1/72tckSTU1Naqrq1NWVpbzGK/XqwkTJmj//v2SpMrKSrW2tobNJCUlKS0tzZk5cOCALMtygkmSxo0bJ8uynJnLKSoqcp7OsyxLycnJnXfyAADAVVx9pWnp0qUKhUK69dZb1bt3b7W1tWn16tX6+7//e0lSXV2dJCkxMTHscYmJifrggw+cmejoaA0aNKjdzMXH19XVKSEhod3XT0hIcGYuZ9myZVq4cKFzu7GxkXACAOA65epo2rVrl7Zv366dO3fq9ttvV1VVlQoKCpSUlKSZM2c6cx6PJ+xxtm23O3apS2cuN3+1z+P1euX1ek1PBwAAdGOujqbvfve7+t73vqdHHnlEkjRq1Ch98MEHKioq0syZM+Xz+ST9+UrR0KFDncfV19c7V598Pp9aWloUDAbDrjbV19dr/Pjxz
szp06fbff0zZ860u4oFAAB6Jle/punjjz9Wr17hK/bu3dv5lQMpKSny+XwqKytz7m9paVF5ebkTROnp6erTp0/YTG1trY4ePerMZGZmKhQKqaKiwpk5dOiQQqGQMwMAAHo2V19peuCBB7R69WoNGzZMt99+u958802tXbtWs2bNkvTnp9QKCgpUWFio1NRUpaamqrCwUP369VNubq4kybIszZ49W4sWLdLgwYMVFxenxYsXa9SoUc676UaOHKlp06Zpzpw52rhxoyRp7ty5ys7O/sx3zgEAgJ7F1dH04osvasWKFcrLy1N9fb2SkpI0b948ff/733dmlixZoqamJuXl5SkYDCojI0N79uxRbGysM7Nu3TpFRUVpxowZampq0qRJk7Rlyxb17t3bmdmxY4fy8/Odd9nl5OSouLi4604WAAC4mse2bTvSS1wvGhsbZVmWQqGQBg4c2KHPcWTM2E7eCt3ZmCMVVx/qAtNW7Ir0CnCRwFMPR3oF5W57JNIrwEV2+kuu6fGmP79d/ZomAAAAtyCaAAAADBBNAAAABogmAAAAA0QTAACAAaIJAADAANEEAABggGgCAAAwQDQBAAAYIJoAAAAMEE0AAAAGiCYAAAADRBMAAIABogkAAMAA0QQAAGCAaAIAADBANAEAABggmgAAAAwQTQAAAAaIJgAAAANEEwAAgAGiCQAAwADRBAAAYIBoAgAAMEA0AQAAGCCaAAAADBBNAAAABogmAAAAA0QTAACAAaIJAADAANEEAABggGgCAAAwQDQBAAAYIJoAAAAMEE0AAAAGiCYAAAADRBMAAIABogkAAMAA0QQAAGCAaAIAADBANAEAABggmgAAAAwQTQAAAAaIJgAAAANEEwAAgAGiCQAAwADRBAAAYIBoAgAAMEA0AQAAGCCaAAAADBBNAAAABogmAAAAA0QTAACAAaIJAADAANEEAABggGgCAAAwQDQBAAAYIJoAAAAMEE0AAAAGiCYAAAADro+mDz/8UN/4xjc0ePBg9evXT1/60pdUWVnp3G/btlauXKmkpCTFxMRo4sSJOnbsWNjnaG5u1oIFCxQfH6/+/fsrJydHp06dCpsJBoPy+/2yLEuWZcnv96uhoaErThEAAHQDro6mYDCou+++W3369NF//ud/6p133tHzzz+vG264wZlZs2aN1q5dq+LiYh0+fFg+n09TpkzR2bNnnZmCggKVlpaqpKRE+/bt07lz55Sdna22tjZnJjc3V1VVVQoEAgoEAqqqqpLf7+/K0wUAAC4WFekFruSZZ55RcnKyXnnlFefYzTff7Py3bdtav369li9frunTp0uStm7dqsTERO3cuVPz5s1TKBTSyy+/rG3btmny5MmSpO3btys5OVmvvfaapk6dqurqagUCAR08eFAZGRmSpM2bNyszM1PHjx/XiBEjLrtfc3OzmpubnduNjY2d/U8AAABcwtVXmnbv3q0xY8booYceUkJCgkaPHq3Nmzc799fU1Kiurk5ZWVnOMa/XqwkTJmj//v2SpMrKSrW2tobNJCUlKS0tzZk5cOCALMtygkmSxo0bJ8uynJnLKSoqcp7OsyxLycnJnXbuAADAXVwdTX/4wx+0YcMGpaam6le/+pW+9a1vKT8/X//2b/8mSaqrq5MkJSYmhj0uMTHRua+urk7R0dEaNGjQFWcSEhLaff2EhARn5nKWLVumUCjkfJw8ebLjJwsAAFzN1U/PXbhwQWPGjFFhYaEkafTo0Tp27Jg2bNigf/iHf3DmPB5P2ONs22537FKXzlxu/mqfx+v1yuv1Gp0LAADo3lx9pWno0KG67bbbwo6NHDlSJ06ckCT5fD5Janc1qL6+3rn65PP51NLSomAweMWZ06dPt/v6Z86caXcVCwAA9Eyujqa7775bx48fDzv23nvv6aabbpIkpaSkyOfzqayszLm/paVF5eXlGj9+vCQpPT1dffr0CZupra3V0aNHnZnMzEyFQiFVVFQ4M4cOHVIoFHJmAABAz+bqp+e+853vaPz48SosLNSMGTNUUVGhTZs2adOmTZL+/JRaQUGBCgsLlZqaqtTUVBUWFqpfv37Kzc2VJFmWpdmzZ2vRokUaPHiw4uLitHjxYo0aNcp5N93IkSM1bdo0zZkzRxs3bpQkzZ07V9nZ2Z/5zjkAANCzuDqa7rrrLpWWlmrZsmVatWqVUlJStH79ej366KPOzJIlS9TU1KS8vDwFg0FlZGRoz549io2NdWbWrVunqKgozZgxQ01NTZo0aZK2bNmi3r17OzM7duxQfn6+8y67nJwcFRcXd93JAgAAV/PYtm1HeonrRWNjoyzLUigU0sCBAzv0OY6MGdvJW6E7G3Ok4upDXWDail2RXgEuEnjq4UivoNxtj0R6BbjITn/JNT3e9Oe3q1/TBAAA4BZEEwAAgAGiCQAAwADRBAAAYKBD0XTfffepoaGh3fHGxkbdd99917oTAACA63Qomvbu3auWlpZ2xz/55BP95je/uealAAAA3OZz/Z6mt956y/nvd955J+zPl7S1tSkQCOhv/uZvOm87AAAAl/hc0fSlL31JHo9HHo/nsk/DxcTE6MUXX+y05QAAANzic0VTTU2NbNvW8OHDVVFRoSFDhjj3RUdHKyEhIey3bAMAAFwvPlc0XfxDuRcuXPhClgEAAHCrDv/tuffee0979+5VfX19u4j6/ve/f82LAQAAuEmHomnz5s36x3/8R8XHx8vn88nj8Tj3eTweogkAAFx3OhRN//zP/6zVq1dr6dKlnb0PAACAK3Xo9zQFg0E99NBDnb0LAACAa3Uomh566CHt2bOns3cBAABwrQ49PXfLLbdoxYoVOnjwoEaNGqU+ffqE3Z+fn98pywEAALhFh6Jp06ZNGjBggMrLy1VeXh52n8fjIZoAAMB1p0PRVFNT09l7AAAAuFqHXtMEAADQ03ToStOsWbOueP8Pf/jDDi0DAADgVh2KpmAwGHa7tbVVR48eVUNDw2X/kC8AAEB316FoKi0tbXfswoULysvL0/Dhw695KQAAALfptNc09erVS9/5zne0bt26zvqUAAAArtGpLwT/n//5H3366aed+SkBAABcoUNPzy1cuDDstm3bqq2t1S9+8QvNnDmzUxYDAABwkw5F05tvvhl2u1evXhoyZIief/75q76zDgAAoDvqUDT9+te/7uw9AAAAXK1D0XTRmTNndPz4cXk8Hv3t3/6thgwZ0ll7AQAAuEqHXgh+/vx5zZo1S0OHDtW9996re+65R0lJSZo9e7Y+/vjjzt4RAAAg4joUTQsXLlR5ebn+4z/+Qw0NDWpoaNDPfvYzlZeXa9GiRZ29IwAAQMR16Om5H//4x/rRj36kiRMnOse+9rWvKSYmRjNmzNCGDRs6az8AAABX6NCVpo8//liJiYntjickJPD0HAAAuC51KJoyMzP1xBNP6JNPPnGONTU16cknn1RmZmanLQcAAOAWHXp6bv369frqV7+qG2+8UXfeeac8Ho+qqqrk9Xq1Z8+ezt4RAAAg4joUTaNGjdL777+v7du3691335Vt23rk
kUf06KOPKiYmprN3BAAAiLgORVNRUZESExM1Z86csOM//OEPdebMGS1durRTlgMAAHCLDr2maePGjbr11lvbHb/99tv10ksvXfNSAAAAbtOhaKqrq9PQoUPbHR8yZIhqa2uveSkAAAC36VA0JScn64033mh3/I033lBSUtI1LwUAAOA2HXpN0ze/+U0VFBSotbVV9913nyTp9ddf15IlS/iN4AAA4LrUoWhasmSJPvroI+Xl5amlpUWS1LdvXy1dulTLli3r1AUBAADcoEPR5PF49Mwzz2jFihWqrq5WTEyMUlNT5fV6O3s/AAAAV+hQNF00YMAA3XXXXZ21CwAAgGt16IXgAAAAPQ3RBAAAYIBoAgAAMEA0AQAAGCCaAAAADBBNAAAABogmAAAAA0QTAACAAaIJAADAANEEAABggGgCAAAwQDQBAAAYIJoAAAAMEE0AAAAGiCYAAAADRBMAAIABogkAAMBAt4qmoqIieTweFRQUOMds29bKlSuVlJSkmJgYTZw4UceOHQt7XHNzsxYsWKD4+Hj1799fOTk5OnXqVNhMMBiU3++XZVmyLEt+v18NDQ1dcFYAAKA76DbRdPjwYW3atEl33HFH2PE1a9Zo7dq1Ki4u1uHDh+Xz+TRlyhSdPXvWmSkoKFBpaalKSkq0b98+nTt3TtnZ2Wpra3NmcnNzVVVVpUAgoEAgoKqqKvn9/i47PwAA4G7dIprOnTunRx99VJs3b9agQYOc47Zta/369Vq+fLmmT5+utLQ0bd26VR9//LF27twpSQqFQnr55Zf1/PPPa/LkyRo9erS2b9+ut99+W6+99pokqbq6WoFAQP/6r/+qzMxMZWZmavPmzfr5z3+u48ePR+ScAQCAu3SLaJo/f77uv/9+TZ48Oex4TU2N6urqlJWV5Rzzer2aMGGC9u/fL0mqrKxUa2tr2ExSUpLS0tKcmQMHDsiyLGVkZDgz48aNk2VZzszlNDc3q7GxMewDAABcn6IivcDVlJSU6Le//a0OHz7c7r66ujpJUmJiYtjxxMREffDBB85MdHR02BWqizMXH19XV6eEhIR2nz8hIcGZuZyioiI9+eSTn++EAABAt+TqK00nT57Ut7/9bW3fvl19+/b9zDmPxxN227btdscudenM5eav9nmWLVumUCjkfJw8efKKXxMAAHRfro6myspK1dfXKz09XVFRUYqKilJ5ebleeOEFRUVFOVeYLr0aVF9f79zn8/nU0tKiYDB4xZnTp0+3+/pnzpxpdxXrr3m9Xg0cODDsAwAAXJ9cHU2TJk3S22+/raqqKudjzJgxevTRR1VVVaXhw4fL5/OprKzMeUxLS4vKy8s1fvx4SVJ6err69OkTNlNbW6ujR486M5mZmQqFQqqoqHBmDh06pFAo5MwAAICezdWvaYqNjVVaWlrYsf79+2vw4MHO8YKCAhUWFio1NVWpqakqLCxUv379lJubK0myLEuzZ8/WokWLNHjwYMXFxWnx4sUaNWqU88LykSNHatq0aZozZ442btwoSZo7d66ys7M1YsSILjxjAADgVq6OJhNLlixRU1OT8vLyFAwGlZGRoT179ig2NtaZWbdunaKiojRjxgw1NTVp0qRJ2rJli3r37u3M7NixQ/n5+c677HJyclRcXNzl5wMAANzJY9u2HeklrheNjY2yLEuhUKjDr286MmZsJ2+F7mzMkYqrD3WBaSt2RXoFuEjgqYcjvYJytz0S6RXgIjv9Jdf0eNOf365+TRMAAIBbEE0AAAAGiCYAAAADRBMAAIABogkAAMAA0QQAAGCAaAIAADBANAEAABggmgAAAAwQTQAAAAaIJgAAAANEEwAAgAGiCQAAwADRBAAAYIBoAgAAMEA0AQAAGCCaAAAADBBNAAAABogmAAAAA0QTAACAAaIJAADAANEEAABggGgCAAAwQDQBAAAYIJoAAAAMEE0AAAAGiCYAAAADRBMAAIABogkAAMAA0QQAAGCAaAIAADBANAEAABggmgAAAAwQTQAAAAaIJgAAAANEEwAAgAGiCQAAwADRBAAAYIBoAgAAMEA0AQAAGCCaAAAADBBNAAAABogmAAAAA0QTAACAAaIJAADAANEEAABggGgCAAAwQDQBAAAYIJoAAAAMEE0AAAAGiCYAAAADRBMAAIABogkAAMAA0QQAAGCAaAIAADBANAEAABggmgAAAAwQTQAAAAaIJgAAAANEEwAAgAFXR1NRUZHuuusuxcbGKiEhQQ8++KCOHz8eNmPbtlauXKmkpCTFxMRo4sSJOnbsWNhMc3OzFixYoPj4ePXv3185OTk6depU2EwwGJTf75dlWbIsS36/Xw0NDV/0KQIAgG7C1dFUXl6u+fPn6+DBgyorK9Onn36qrKwsnT9/3plZs2aN1q5dq+LiYh0+fFg+n09TpkzR2bNnnZmCggKVlpaqpKRE+/bt07lz55Sdna22tjZnJjc3V1VVVQoEAgoEAqqqqpLf7+/S8wUAAO4VFekFriQQCITdfuWVV5SQkKDKykrde++9sm1b69ev1/LlyzV9+nRJ0tatW5WYmKidO3dq3rx5CoVCevnll7Vt2zZNnjxZkrR9+3YlJyfrtdde09SpU1VdXa1AIKCDBw8qIyNDkrR582ZlZmbq+PHjGjFiRNeeOAAAcB1XX2m6VCgUkiTFxcVJkmpqalRXV6esrCxnxuv1asKECdq/f78kqbKyUq2trWEzSUlJSktLc2YOHDggy7KcYJKkcePGybIsZ+Zympub1djYGPYBAACuT90mmmzb1sKFC/XlL39ZaWlpkqS6ujpJUmJiYthsYmKic19dXZ2io6M1aNCgK84kJCS0+5oJCQnOzOUUFRU5r4GyLEvJyckdP0EAAOBq3SaaHn/8cb311lt69dVX293n8XjCbtu23e7YpS6dudz81T7PsmXLFAqFnI+TJ09e7TQAAEA31S2iacGCBdq9e7d+/etf68Ybb3SO+3w+SWp3Nai+vt65+uTz+dTS0qJgMHjFmdOnT7f7umfOnGl3Feuveb1eDRw4MOwDAABcn1wdTbZt6/HHH9dPfvIT/dd//ZdSUlLC7k9JSZHP51NZWZlzrKWlReXl5Ro/frwkKT09XX369Ambqa2t1dGjR52ZzMxMhUIhVVRUODOHDh1SKBRyZgAAQM/m6nfPzZ8/Xzt37tTPfvYzxcbGOleULMtSTEyMPB6PCgoKVFhYqNTUVKWmpqqwsFD9+vVTbm6uMzt79mwtWrRIgwcPVlxcnBYvXqxRo0Y576YbOXKkpk2bpjlz5mjjxo2SpLlz5yo7O5t3zgEAAEkuj6YNGzZIkiZOnBh2/JVXXtFjjz0mSVqyZImampqUl5enYDCojIwM7dmzR7Gxsc78unXrFBUVpRkzZqipqUmTJk3Sli1b1Lt3b2dmx44dys/Pd95ll5OTo+Li4i/2BAEAQLfhsW3bjvQS14vGxkZZlqVQKNTh1zcdGTO2k7dCdzbmSMXVh7rAtBW7Ir0CXCTw1MORXkG52x6J9ApwkZ3+kmt6vOnPb1e/pgkAAMAtiCYAAAADRBM
AAIABogkAAMAA0QQAAGCAaAIAADBANAEAABggmgAAAAwQTQAAAAaIJgAAAANEEwAAgAGiCQAAwADRBAAAYIBoAgAAMEA0AQAAGCCaAAAADBBNAAAABogmAAAAA0QTAACAAaIJAADAANEEAABggGgCAAAwQDQBAAAYIJoAAAAMEE0AAAAGiCYAAAADRBMAAIABogkAAMAA0QQAAGCAaAIAADBANAEAABggmgAAAAwQTQAAAAaIJgAAAANEEwAAgAGiCQAAwADRBAAAYIBoAgAAMEA0AQAAGCCaAAAADBBNAAAABogmAAAAA0QTAACAAaIJAADAANEEAABggGgCAAAwQDQBAAAYIJoAAAAMEE0AAAAGiCYAAAADRBMAAIABogkAAMAA0QQAAGCAaAIAADBANAEAABggmgAAAAwQTQAAAAaIJgAAAANEEwAAgAGiCQAAwADRdIkf/OAHSklJUd++fZWenq7f/OY3kV4JAAC4ANH0V3bt2qWCggItX75cb775pu655x599atf1YkTJyK9GgAAiDCi6a+sXbtWs2fP1je/+U2NHDlS69evV3JysjZs2BDp1QAAQIRFRXoBt2hpaVFlZaW+973vhR3PysrS/v37L/uY5uZmNTc3O7dDoZAkqbGxscN7nGtr6/Bjcf25lu+lzvRp88eRXgEu4obvy9am1kivABe51u/Ji4+3bfuKc0TTX/zpT39SW1ubEhMTw44nJiaqrq7uso8pKirSk08+2e54cnLyF7IjeiDLivQGQDvWs7MivQIQ5kfzftIpn+fs2bOyrvD/XaLpEh6PJ+y2bdvtjl20bNkyLVy40Ll94cIFffTRRxo8ePBnPgZX19jYqOTkZJ08eVIDBw6M9DqAJL4v4T58T3Ye27Z19uxZJSUlXXGOaPqL+Ph49e7du91Vpfr6+nZXny7yer3yer1hx2644YYvasUeZ+DAgfyPAK7D9yXchu/JznGlK0wX8ULwv4iOjlZ6errKysrCjpeVlWn8+PER2goAALgFV5r+ysKFC+X3+zVmzBhlZmZq06ZNOnHihL71rW9FejUAABBhRNNfefjhh/V///d/WrVqlWpra5WWlqZf/vKXuummmyK9Wo/i9Xr1xBNPtHvqE4gkvi/hNnxPdj2PfbX31wEAAIDXNAEAAJggmgAAAAwQTQAAAAaIJgAAAANEE1ylrq5OCxYs0PDhw+X1epWcnKwHHnhAr7/+eqRXQw/12GOP6cEHH4z0GoCkz/5+3Lt3rzwejxoaGrp8p56EXzkA1/jf//1f3X333brhhhu0Zs0a3XHHHWptbdWvfvUrzZ8/X++++26kVwQA9GBEE1wjLy9PHo9HFRUV6t+/v3P89ttv16xZ/IFQAEBk8fQcXOGjjz5SIBDQ/Pnzw4LpIv6mHwAg0rjSBFf4/e9/L9u2deutt0Z6FQBwtZ///OcaMGBA2LG2trYIbdOzEE1whYu/mN7j8UR4EwBwt6985SvasGFD2LFDhw7pG9/4RoQ26jmIJrhCamqqPB6PqqureacSAFxB//79dcstt4QdO3XqVIS26Vl4TRNcIS4uTlOnTtW//Mu/6Pz58+3u5220AIBI40oTXOMHP/iBxo8fr7Fjx2rVqlW644479Omnn6qsrEwbNmxQdXV1pFdEDxUKhVRVVRV2LC4uTsOGDYvMQgAigmiCa6SkpOi3v/2tVq9erUWLFqm2tlZDhgxRenp6u+fvga60d+9ejR49OuzYzJkztWXLlsgsBCAiPPbFV+ACAADgM/GaJgAAAANEEwAAgAGiCQAAwADRBAAAYIBoAgAAMEA0AQAAGCCaAAAADBBNAAAABogmAD3GxIkTVVBQYDS7d+9eeTyea/67hzfffLPWr19/TZ8DgDsQTQAAAAaIJgAAAANEE4Aeafv27RozZoxiY2Pl8/mUm5ur+vr6dnNvvPGG7rzzTvXt21cZGRl6++23w+7fv3+/7r33XsXExCg5OVn5+fk6f/58V50GgC5ENAHokVpaWvTUU0/pd7/7nX7605+qpqZGjz32WLu57373u3ruued0+PBhJSQkKCcnR62trZKkt99+W1OnTtX06dP11ltvadeuXdq3b58ef/zxLj4bAF0hKtILAEAkzJo1y/nv4cOH64UXXtDYsWN17tw5DRgwwLnviSee0JQpUyRJW7du1Y033qjS0lLNmDFDzz77rHJzc50Xl6empuqFF17QhAkTtGHDBvXt27dLzwnAF4srTQB6pDfffFNf//rXddNNNyk2NlYTJ06UJJ04cSJsLjMz0/nvuLg4jRgxQtXV1ZKkyspKbdmyRQMGDHA+pk6dqgsXLqimpqbLzgVA1+BKE4Ae5/z588rKylJWVpa2b9+uIUOG6MSJE5o6dapaWlqu+niPxyNJunDhgubNm6f8/Px2M8OGDev0vQFEFtEEoMd599139ac//UlPP/20kpOTJUlHjhy57OzBgwedAAoGg3rvvfd06623SpL+7u/+TseOHdMtt9zSNYsDiCiengPQ4wwbNkzR0dF68cUX9Yc//EG7d+/WU089ddnZVatW6fXXX9fRo0f12GOPKT4+Xg8++KAkaenSpTpw4IDmz5+vqqoqvf/++9q9e7cWLFjQhWcDoKsQTQB6nCFDhmjLli3693//d9122216+umn9dxzz1129umnn9a3v/1tpaenq7a2Vrt371Z0dLQk6Y477lB5ebnef/993XPPPRo9erRWrFihoUOHduXpAOgiHtu27UgvAQAA4HZcaQIAADBANAEAABggmgAAAAwQTQAAAAaIJgAAAANEEwAAgAGiCQAAwADRBAAAYIBoAgAAMEA0AQAAGCCaAAAADPw/jUJkVPp5SykAAAAASUVORK5CYII=",
|
| 511 |
+
"text/plain": [
|
| 512 |
+
"<Figure size 640x480 with 1 Axes>"
|
| 513 |
+
]
|
| 514 |
+
},
|
| 515 |
+
"metadata": {},
|
| 516 |
+
"output_type": "display_data"
|
| 517 |
+
}
|
| 518 |
+
],
|
| 519 |
+
"source": [
|
| 520 |
+
"sns.countplot(x='label', data=df, palette=\"Set1\") "
|
| 521 |
+
]
|
| 522 |
+
},
|
| 523 |
+
{
|
| 524 |
+
"cell_type": "markdown",
|
| 525 |
+
"metadata": {},
|
| 526 |
+
"source": [
|
| 527 |
+
"### 5. Extract data for test set"
|
| 528 |
+
]
|
| 529 |
+
},
|
| 530 |
+
{
|
| 531 |
+
"cell_type": "code",
|
| 532 |
+
"execution_count": 5,
|
| 533 |
+
"metadata": {},
|
| 534 |
+
"outputs": [
|
| 535 |
+
{
|
| 536 |
+
"name": "stderr",
|
| 537 |
+
"output_type": "stream",
|
| 538 |
+
"text": [
|
| 539 |
+
"INFO: Created TensorFlow Lite XNNPACK delegate for CPU.\n"
|
| 540 |
+
]
|
| 541 |
+
}
|
| 542 |
+
],
|
| 543 |
+
"source": [
|
| 544 |
+
"TEST_DATASET_PATH = \"test.csv\"\n",
|
| 545 |
+
"\n",
|
| 546 |
+
"cap = cv2.VideoCapture(\"../data/plank/plank_test_4.mp4\")\n",
|
| 547 |
+
"save_counts = 0\n",
|
| 548 |
+
"\n",
|
| 549 |
+
"# init_csv(TEST_DATASET_PATH)\n",
|
| 550 |
+
"\n",
|
| 551 |
+
"with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:\n",
|
| 552 |
+
" while cap.isOpened():\n",
|
| 553 |
+
" ret, image = cap.read()\n",
|
| 554 |
+
"\n",
|
| 555 |
+
" if not ret:\n",
|
| 556 |
+
" break\n",
|
| 557 |
+
"\n",
|
| 558 |
+
" # Reduce size of a frame\n",
|
| 559 |
+
" image = rescale_frame(image, 60)\n",
|
| 560 |
+
" image = cv2.flip(image, 1)\n",
|
| 561 |
+
"\n",
|
| 562 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 563 |
+
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
| 564 |
+
" image.flags.writeable = False\n",
|
| 565 |
+
"\n",
|
| 566 |
+
" results = pose.process(image)\n",
|
| 567 |
+
"\n",
|
| 568 |
+
" if not results.pose_landmarks: continue\n",
|
| 569 |
+
"\n",
|
| 570 |
+
" # Recolor image from BGR to RGB for mediapipe\n",
|
| 571 |
+
" image.flags.writeable = True\n",
|
| 572 |
+
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
|
| 573 |
+
"\n",
|
| 574 |
+
" # Draw landmarks and connections\n",
|
| 575 |
+
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(244, 117, 66), thickness=2, circle_radius=4), mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2))\n",
|
| 576 |
+
"\n",
|
| 577 |
+
" # Display the saved count\n",
|
| 578 |
+
" cv2.putText(image, f\"Saved: {save_counts}\", (50, 50), cv2.FONT_HERSHEY_COMPLEX, 2, (0, 0, 0), 2, cv2.LINE_AA)\n",
|
| 579 |
+
"\n",
|
| 580 |
+
" cv2.imshow(\"CV2\", image)\n",
|
| 581 |
+
"\n",
|
| 582 |
+
" # Pressed key for action\n",
|
| 583 |
+
" k = cv2.waitKey(1) & 0xFF\n",
|
| 584 |
+
"\n",
|
| 585 |
+
" # Press C to save as correct form\n",
|
| 586 |
+
" if k == ord('c'): \n",
|
| 587 |
+
" export_landmark_to_csv(TEST_DATASET_PATH, results, \"C\")\n",
|
| 588 |
+
" save_counts += 1\n",
|
| 589 |
+
" # Press L to save as low back\n",
|
| 590 |
+
" elif k == ord(\"l\"):\n",
|
| 591 |
+
" export_landmark_to_csv(TEST_DATASET_PATH, results, \"L\")\n",
|
| 592 |
+
" save_counts += 1\n",
|
| 593 |
+
" # Press L to save as high back\n",
|
| 594 |
+
" elif k == ord(\"h\"):\n",
|
| 595 |
+
" export_landmark_to_csv(TEST_DATASET_PATH, results, \"H\")\n",
|
| 596 |
+
" save_counts += 1\n",
|
| 597 |
+
"\n",
|
| 598 |
+
" # Press q to stop\n",
|
| 599 |
+
" elif k == ord(\"q\"):\n",
|
| 600 |
+
" break\n",
|
| 601 |
+
" else: continue\n",
|
| 602 |
+
"\n",
|
| 603 |
+
" cap.release()\n",
|
| 604 |
+
" cv2.destroyAllWindows()\n",
|
| 605 |
+
"\n",
|
| 606 |
+
" # (Optional)Fix bugs cannot close windows in MacOS (https://stackoverflow.com/questions/6116564/destroywindow-does-not-close-window-on-mac-using-python-and-opencv)\n",
|
| 607 |
+
" for i in range (1, 5):\n",
|
| 608 |
+
" cv2.waitKey(1)\n",
|
| 609 |
+
" "
|
| 610 |
+
]
|
| 611 |
+
},
|
| 612 |
+
{
|
| 613 |
+
"cell_type": "code",
|
| 614 |
+
"execution_count": 6,
|
| 615 |
+
"metadata": {},
|
| 616 |
+
"outputs": [
|
| 617 |
+
{
|
| 618 |
+
"name": "stdout",
|
| 619 |
+
"output_type": "stream",
|
| 620 |
+
"text": [
|
| 621 |
+
"Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'left_elbow_x', 'left_elbow_y', 'left_elbow_z', 'left_elbow_v', 'right_elbow_x', 'right_elbow_y', 'right_elbow_z', 'right_elbow_v', 'left_wrist_x', 'left_wrist_y', 'left_wrist_z', 'left_wrist_v', 'right_wrist_x', 'right_wrist_y', 'right_wrist_z', 'right_wrist_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v', 'left_knee_x', 'left_knee_y', 'left_knee_z', 'left_knee_v', 'right_knee_x', 'right_knee_y', 'right_knee_z', 'right_knee_v', 'left_ankle_x', 'left_ankle_y', 'left_ankle_z', 'left_ankle_v', 'right_ankle_x', 'right_ankle_y', 'right_ankle_z', 'right_ankle_v', 'left_heel_x', 'left_heel_y', 'left_heel_z', 'left_heel_v', 'right_heel_x', 'right_heel_y', 'right_heel_z', 'right_heel_v', 'left_foot_index_x', 'left_foot_index_y', 'left_foot_index_z', 'left_foot_index_v', 'right_foot_index_x', 'right_foot_index_y', 'right_foot_index_z', 'right_foot_index_v']\n",
|
| 622 |
+
"Number of rows: 710 \n",
|
| 623 |
+
"Number of columns: 69\n",
|
| 624 |
+
"\n",
|
| 625 |
+
"Labels: \n",
|
| 626 |
+
"H 241\n",
|
| 627 |
+
"L 235\n",
|
| 628 |
+
"C 234\n",
|
| 629 |
+
"Name: label, dtype: int64\n",
|
| 630 |
+
"\n",
|
| 631 |
+
"Missing values: False\n",
|
| 632 |
+
"\n",
|
| 633 |
+
"Duplicate Rows : 0\n"
|
| 634 |
+
]
|
| 635 |
+
}
|
| 636 |
+
],
|
| 637 |
+
"source": [
|
| 638 |
+
"test_df = describe_dataset(TEST_DATASET_PATH)"
|
| 639 |
+
]
|
| 640 |
+
},
|
| 641 |
+
{
|
| 642 |
+
"cell_type": "code",
|
| 643 |
+
"execution_count": 7,
|
| 644 |
+
"metadata": {},
|
| 645 |
+
"outputs": [
|
| 646 |
+
{
|
| 647 |
+
"data": {
|
| 648 |
+
"text/plain": [
|
| 649 |
+
"<AxesSubplot:xlabel='count', ylabel='label'>"
|
| 650 |
+
]
|
| 651 |
+
},
|
| 652 |
+
"execution_count": 7,
|
| 653 |
+
"metadata": {},
|
| 654 |
+
"output_type": "execute_result"
|
| 655 |
+
},
|
| 656 |
+
{
|
| 657 |
+
"data": {
|
| 658 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjIAAAGwCAYAAACzXI8XAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYd0lEQVR4nO3df5DVdf3o8ddBYEFiV1Bg2VyI/C0iGWKi9hUxUUuD8TZDjRVk5SiCw8DcjOka1vci4ow6Ovgrx5s41eA0qTlRmD+A8gdpigYKRCqBCmIqLIItAp/7R5e9rYAt6+6efa2Px8zOcD6f8+N1zpvP7HM+e3ZPqSiKIgAAEupU7gEAAJpLyAAAaQkZACAtIQMApCVkAIC0hAwAkJaQAQDS6lzuAVrCrl274vXXX4+ePXtGqVQq9zgAQBMURRFbtmyJmpqa6NSpeedWOkTIvP7661FbW1vuMQCAZli3bl0ceuihzbpthwiZnj17RsS/XojKysoyTwMANEVdXV3U1tY2fB9vjg4RMrt/nFRZWSlkACCZj/K2EG/2BQDSEjIAQFpCBgBIS8gAAGkJGQAgLSEDAKQlZACAtIQMAJCWkAEA0uoQf9l3t2dPPyM+ccAB5R4DANqdE//8VLlHaBXOyAAAaQkZACAtIQMApCVkAIC0hAwAkJaQAQDSEjIAQFpCBgBIS8gAAGkJGQAgLSEDAKQlZACAtIQMAJCWkAEA0hIyAEBaQgYASEvIAABpCRkAIC0hAwCkJWQAgLSEDACQlpABANISMgBAWkIGAEhLyAAAaQkZACAtIQMApCVkAIC0hAwAkJaQAQDSEjIAQFpCBgBIS8gAAGkJGQAgLSEDAKQlZACAtIQMAJCWkAEA0hIyAEBaQgYASEvIAABpCRkAIC0hAwCkJWQAgLSEDACQVrsImQ0bNsTkyZPj05/+dFRUVERtbW2cf/758cgjj5R7NACgHetc7gHWrFkTp556ahx00EFx7bXXxvHHHx/vv/9+PPjgg3HZZZfFypUryz0iANBOlT1kJk6cGKVSKZ566qno0aNHw/bBgwfHRRddVMbJAID2rqwh8/bbb8eCBQti5syZjSJmt4MOOmivt6uvr4/6+vqGy3V1da01IgDQjpX1PTJ/+9vfoiiKOProo/frdrNmzYqqqqqGr9ra2laaEABoz8oaMkVRREREqVTar9tNnz49Nm/e3PC1bt261hgPAGjnyhoyRxxxRJRKpVixYsV+3a6ioiIqKysbfQEAHz9lDZnevXvH2WefHTfffHNs3bp1j/2bNm1q+6EAgDTK/ndkbrnllti5c2ecdNJJ8atf/SpWr14dK1asiJtuuilGjBhR7vEAgHas7L9+PWjQoHj22Wdj5syZMW3atFi/fn306dMnhg0bFrfeemu5xwMA2rFSsfsdt4nV1dVFVVVVLPzMZ+MTBxxQ7nEAoN058c9PlXuEPez+/r158+Zmv9+17D9aAgBoLiEDAKQlZACAtIQMAJCWkAEA0hIyAEBaQgYASEvIAABpCRkAIC0hAwCkJWQAgLSEDACQlpABANISMgBAWkIGAEhLyAAAaQkZACAtIQMApCVkAIC0hAwAkJaQAQDSEjIAQFpCBgBIS8gAAGkJGQAgLSEDAKQlZACAtIQMAJCWkAEA0hIyAEBaQgYASEvIAABpCRkAIC0hAwCkJWQAgLSEDACQlpABANISMgBAWkIGAEhLyAAAaQkZACAtIQMApNW53AO0pM8uXhiVlZXlHgMAaCPOyAAAaQkZACAtIQMApCVkAIC0hAwAkJaQAQDSEjIAQFpCBgBIS8gAAGkJGQAgLSEDAKQlZACAtIQMAJCWkAEA0hIyAEBaQgYASEvIAABpCRkAIC0hAwCkJWQAgLSEDACQlpABANISMgBAWkIGAEirc7kHaEkX/O9fReeKA8s9BgC0Wwv+e1y5R2hRzsgAAGkJGQAgLSEDAKQlZACAtIQMAJCWkAEA0hIyAEBaQgYASEvIAABpCRkAIC0hAwCkJWQAgLSEDACQlpABANISMgBAWkIGAEhLyAAAaQkZACAtIQMApCVkAIC0hAwAkJaQAQDSEjIAQFpCBgBIS8gAAGkJGQAgLSEDAKQlZACAtIQMAJCWkAEA0hIyAEBaQgYASEvIAABpCRkAIC0hAwCkJWQAgLQ6N/WKN910U5Pv9PLLL2/WMAAA+6PJIXPDDTc06XqlUknIAABtoskh88orr7TmHAAA++0jvUdm+/btsWrVqtixY0dLzQMA0GTNCplt27bFt7/97TjwwANj8ODBsXbt2oj413tjrrnmmhYdEABgX5oVMtOnT4/nn38+Fi1aFN26dWvY/oUvfCHuueeeFhsOAODDNPk9Mv/u/vvvj3vuuSdOPvnkKJVKDduPPfbYeOmll1psOACAD9OsMzJvvvlm9O3bd4/tW7dubRQ2AACtqVkhM3z48Jg/f37D5d3xcscdd8SIESOaNciECRNi7NixzbotAPDx1KwfLc2aNSvOOeecePHFF2PHjh1x4403xgsvvBBPPvlkLF68uKVnBADYq2adkTnllFPi8ccfj23btsVhhx0Wv//976Nfv37x5JNPxrBhw1p6RgCAvWrWGZmIiCFDhsTcuXNbcpYmq6+vj/r6+obLdXV1ZZkDACivZofMzp0747777osVK1ZEqVSKY445JsaMGROdOzf7Lpts1qxZ8aMf/ajVHwcAaN+aVR3Lly+PMWPGxIYNG+Koo46KiIi//vWv0adPn3jggQdiyJAhLTrkB02fPj2mTp3acLmuri5qa2tb9TEBgPanWSHzne98JwYPHhx//vOfo1evXhER8c4778SECRPi4osvjieffLJFh/ygioqKqKioaNXHAADav2aFzPPPP98oYiIievXqFTNnzozhw4e32HAAAB+mWSFz1FFHxRtvvBGDBw9utH3jxo1x+OGHN3uYzZs3x3PPPddoW+/evWPAgAHNvk8AoONqcsj8+28GXX311XH55ZfHVVddFSeffHJERCxZsiR+/OMfx+zZs5s9zKJFi+KEE05otG38+PFx1113Nfs+AYCOq1QURdGUK3bq1KnRxw/svtnubf9+eefOnS0954eqq6uLqqqqOPN//p/oXHFgmz42AGSy4L/HlXuEBru/f2/evDkqKyubdR9NPiOzcOHCZj0AAEBraXLInH766a05BwDAfvtIf71u27ZtsXbt2ti+fXuj7ccff/xHGgoAoCmaFTJvvvlmfOtb34rf/e53e93f1u+RAQA+npr1oZFTpkyJd955J5YsWRLdu3ePBQsWxNy5c+OII46IBx54oKVnBADYq2adkXn00Ufj17/+dQwfPjw6deoUAwcOjLPOOisqKytj1qxZ8aUvfaml5wQA2EOzzshs3bo1+vbtGxH/+oN1b775ZkT86xOxn3322ZabDgDgQzQrZI466qhYtWpVRER85jOfidtvvz1ee+21uO2226J///4tOiAAwL4060dLU6ZMifXr10dExIwZM+Lss8+On/3sZ9G1a9eYO3duiw4IALAvzQqZC
y+8sOHfJ5xwQqxZsyZWrlwZAwYMiEMOOaTFhgMA+DBNDpmpU6c2+U6vv/76Zg0DALA/mhwyS5cubdL1/v3zmAAAWpPPWgIA0mrWby0BALQHQgYASEvIAABpCRkAIC0hAwCkJWQAgLSEDACQlpABANISMgBAWkIGAEhLyAAAaQkZACAtIQMApCVkAIC0hAwAkJaQAQDSEjIAQFpCBgBIS8gAAGkJGQAgLSEDAKQlZACAtIQMAJCWkAEA0hIyAEBaQgYASEvIAABpCRkAIC0hAwCk1bncA7Ske//X/4jKyspyjwEAtBFnZACAtIQMAJCWkAEA0hIyAEBaQgYASEvIAABpCRkAIC0hAwCkJWQAgLSEDACQlpABANISMgBAWkIGAEhLyAAAaQkZACAtIQMApCVkAIC0hAwAkJaQAQDSEjIAQFpCBgBIS8gAAGkJGQAgLSEDAKTVudwDtKRvz/tWdOnepdxjAECH8YtvzCv3CB/KGRkAIC0hAwCkJWQAgLSEDACQlpABANISMgBAWkIGAEhLyAAAaQkZACAtIQMApCVkAIC0hAwAkJaQAQDSEjIAQFpCBgBIS8gAAGkJGQAgLSEDAKQlZACAtIQMAJCWkAEA0hIyAEBaQgYASEvIAABpCRkAIC0hAwCkJWQAgLSEDACQlpABANISMgBAWkIGAEhLyAAAaQkZACAtIQMApCVkAIC0hAwAkJaQAQDSEjIAQFpCBgBIS8gAAGkJGQAgLSEDAKQlZACAtIQMAJCWkAEA0ipryEyYMCHGjh27x/ZFixZFqVSKTZs2tflMAEAezsgAAGl1LvcAzVFfXx/19fUNl+vq6so4DQBQLinPyMyaNSuqqqoavmpra8s9EgBQBmU/I/Ob3/wmPvGJTzTatnPnzg+9zfTp02Pq1KkNl+vq6sQMAHwMlT1kzjjjjLj11lsbbfvTn/4UX//61/d5m4qKiqioqGjt0QCAdq7sIdOjR484/PDDG2179dVXyzQNAJBJyvfIAABECBkAIDEhAwCkVdb3yNx111173T5y5MgoiqJthwEA0nFGBgBIS8gAAGkJGQAgLSEDAKQlZACAtIQMAJCWkAEA0hIyAEBaQgYASEvIAABpCRkAIC0hAwCkJWQAgLSEDACQlpABANISMgBAWkIGAEhLyAAAaQkZACAtIQMApCVkAIC0hAwAkJaQAQDSEjIAQFpCBgBIS8gAAGkJGQAgLSEDAKQlZACAtIQMAJCWkAEA0hIyAEBaQgYASEvIAABpCRkAIC0hAwCkJWQAgLSEDACQlpABANISMgBAWkIGAEirc7kHaEl3fvWnUVlZWe4xAIA24owMAJCWkAEA0hIyAEBaQgYASEvIAABpCRkAIC0hAwCkJWQAgLSEDACQlpABANISMgBAWh3is5aKooiIiLq6ujJPAgA01e7v27u/jzdHhwiZt956KyIiamtryzwJALC/tmzZElVVVc26bYcImd69e0dExNq1a5v9QvDR1dXVRW1tbaxbt86nkJeRdWgfrEP7YB3ah32tQ1EUsWXLlqipqWn2fXeIkOnU6V9v9amqqvIftR2orKy0Du2AdWgfrEP7YB3ah72tw0c9AeHNvgBAWkIGAEirQ4RMRUVFzJgxIyoqKso9yseadWgfrEP7YB3aB+vQPrTmOpSKj/I7TwAAZdQhzsgAAB9PQgYASEvIAABpCRkAIK30IXPLLbfEoEGDolu3bjFs2LD44x//WO6ROrSrrroqSqVSo6/q6uqG/UVRxFVXXRU1NTXRvXv3GDlyZLzwwgtlnLhj+MMf/hDnn39+1NTURKlUivvvv7/R/qa87vX19TF58uQ45JBDokePHvHlL385Xn311TZ8Fvn9p3WYMGHCHsfHySef3Og61uGjmTVrVgwfPjx69uwZffv2jbFjx8aqVasaXcfx0Pqasg5tdTykDpl77rknpkyZEj/4wQ9i6dKl8fnPfz7OPffcWLt2bblH69AGDx4c69evb/hatmxZw75rr702rr/++pgzZ048/fTTUV1dHWeddVZs2bKljBPnt3Xr1hg6dGjMmTNnr/ub8rpPmTIl7rvvvpg3b1489thj8e6778Z5550XO3fubKunkd5/WoeIiHPOOafR8fHb3/620X7r8NEsXrw4LrvssliyZEk89NBDsWPHjhg9enRs3bq14TqOh9bXlHWIaKPjoUjspJNOKi655JJG244++uji+9//fpkm6vhmzJhRDB06dK/7du3aVVRXVxfXXHNNw7Z//vOfRVVVVXHbbbe10YQdX0QU9913X8PlprzumzZtKrp06VLMmzev4TqvvfZa0alTp2LBggVtNntH8sF1KIqiGD9+fDFmzJh93sY6tLyNGzcWEVEsXry4KArHQ7l8cB2Kou2Oh7RnZLZv3x7PPPNMjB49utH20aNHxxNPPFGmqT4eVq9eHTU1NTFo0KD46le/Gi+//HJERLzyyiuxYcOGRmtSUVERp59+ujVpRU153Z955pl4//33G12npqYmjjvuOGvTwhYtWhR9+/aNI488Mr773e/Gxo0bG/ZZh5a3efPmiPj/Hx7seCiPD67Dbm1xPKQNmX/84x+xc+fO6NevX6Pt/fr1iw0bNpRpqo7vc5/7XNx9993x4IMPxh133BEbNmyIU045Jd56662G192atK2mvO4bNmyIrl27Rq9evfZ5HT66c889N37+85/Ho48+Gtddd108/fTTMWrUqKivr48I69DSiqKIqVOnxmmnnRbHHXdcRDgeymFv6xDRdsdD+k+/LpVKjS4XRbHHNlrOueee2/DvIUOGxIgRI+Kwww6LuXPnNryJy5qUR3Ned2vTssaNG9fw7+OOOy5OPPHEGDhwYMyfPz8uuOCCfd7OOjTPpEmT4i9/+Us89thje+xzPLSdfa1DWx0Pac/IHHLIIXHAAQfsUW0bN27co8RpPT169IghQ4bE6tWrG357yZq0raa87tXV1bF9+/Z455139nkdWl7//v1j4MCBsXr16oiwDi1p8uTJ8cADD8TChQvj0EMPbdjueGhb+1qHvWmt4yFtyHTt2jWGDRsWDz30UKPtDz30UJxyyillmurjp76+PlasWBH9+/ePQYMGRXV1daM12b59eyxevNiatKKmvO7Dhg2LLl26NLrO+vXrY/ny5damFb311luxbt266N+/f0RYh5ZQFEVMmjQp7r333nj00Udj0KBBjfY7HtrGf1qHvWm146HJbwtuh+bNm1d06dKluPPOO4sXX3yxmDJlStGjR49izZo15R6tw5o2bVqxaNGi4uWXXy6WLFlSnHfeeUXPnj0bXvNrrrmmqKqqKu69995i2bJlxde+9rWif//+RV1dXZknz23Lli3F0qVLi6VLlxYRUVx//fXF0qVLi7///e9FUTTtdb/kkkuKQw89tHj44YeLZ599thg1alQxdOjQYseOHeV6Wul82Dps2bKlmDZtWvHEE08Ur7zySrFw4cJixIgRxSc/+Unr0IIuvfTSoqqqqli0aFGxfv36hq9t27Y1XMfx0Pr+0zq05fGQOmSKoihuvvnmYuDA
gUXXrl2Lz372s41+9YuWN27cuKJ///5Fly5dipqamuKCCy4oXnjhhYb9u3btKmbMmFFUV1cXFRUVxX/9138Vy5YtK+PEHcPChQuLiNjja/z48UVRNO11f++994pJkyYVvXv3Lrp3716cd955xdq1a8vwbPL6sHXYtm1bMXr06KJPnz5Fly5digEDBhTjx4/f4zW2Dh/N3l7/iCh++tOfNlzH8dD6/tM6tOXxUPp/AwEApJP2PTIAAEIGAEhLyAAAaQkZACAtIQMApCVkAIC0hAwAkJaQAQDSEjIAQFpCBuiw1qxZE6VSKZ577rlyjwK0EiEDAKQlZIBWs2vXrpg9e3YcfvjhUVFREQMGDIiZM2dGRMSyZcti1KhR0b179zj44IPj4osvjnfffbfhtiNHjowpU6Y0ur+xY8fGhAkTGi5/6lOfiquvvjouuuii6NmzZwwYMCB+8pOfNOwfNGhQRESccMIJUSqVYuTIka32XIHyEDJAq5k+fXrMnj07rrzyynjxxRfjF7/4RfTr1y+2bdsW55xzTvTq1Suefvrp+OUvfxkPP/xwTJo0ab8f47rrrosTTzwxli5dGhMnToxLL700Vq5cGRERTz31VEREPPzww7F+/fq49957W/T5AeXXudwDAB3Tli1b4sYbb4w5c+bE+PHjIyLisMMOi9NOOy3uuOOOeO+99+Luu++OHj16RETEnDlz4vzzz4/Zs2dHv379mvw4X/ziF2PixIkREXHFFVfEDTfcEIsWLYqjjz46+vTpExERBx98cFRXV7fwMwTaA2dkgFaxYsWKqK+vjzPPPHOv+4YOHdoQMRERp556auzatStWrVq1X49z/PHHN/y7VCpFdXV1bNy4sfmDA6kIGaBVdO/efZ/7iqKIUqm01327t3fq1CmKomi07/3339/j+l26dNnj9rt27drfcYGkhAzQKo444ojo3r17PPLII3vsO/bYY+O5556LrVu3Nmx7/PHHo1OnTnHkkUdGRESfPn1i/fr1Dft37twZy5cv368Zunbt2nBboGMSMkCr6NatW1xxxRXxve99L+6+++546aWXYsmSJXHnnXfGhRdeGN26dYvx48fH8uXLY+HChTF58uT4xje+0fD+mFGjRsX8+fNj/vz5sXLlypg4cWJs2rRpv2bo27dvdO/ePRYsWBBvvPFGbN68uRWeKVBOQgZoNVdeeWVMmzYtfvjDH8YxxxwT48aNi40bN8aBBx4YDz74YLz99tsxfPjw+MpXvhJnnnlmzJkzp+G2F110UYwfPz6++c1vxumnnx6DBg2KM844Y78ev3PnznHTTTfF7bffHjU1NTFmzJiWfopAmZWKD/4QGgAgCWdkAIC0hAwAkJaQAQDSEjIAQFpCBgBIS8gAAGkJGQAgLSEDAKQlZACAtIQMAJCWkAEA0vq/E3SBUUeHU6kAAAAASUVORK5CYII=",
|
| 659 |
+
"text/plain": [
|
| 660 |
+
"<Figure size 640x480 with 1 Axes>"
|
| 661 |
+
]
|
| 662 |
+
},
|
| 663 |
+
"metadata": {},
|
| 664 |
+
"output_type": "display_data"
|
| 665 |
+
}
|
| 666 |
+
],
|
| 667 |
+
"source": [
|
| 668 |
+
"sns.countplot(y='label', data=test_df, palette=\"Set1\") "
|
| 669 |
+
]
|
| 670 |
+
},
|
| 671 |
+
{
|
| 672 |
+
"cell_type": "code",
|
| 673 |
+
"execution_count": null,
|
| 674 |
+
"metadata": {},
|
| 675 |
+
"outputs": [],
|
| 676 |
+
"source": []
|
| 677 |
+
}
|
| 678 |
+
],
|
| 679 |
+
"metadata": {
|
| 680 |
+
"kernelspec": {
|
| 681 |
+
"display_name": "Python 3.8.13 (conda)",
|
| 682 |
+
"language": "python",
|
| 683 |
+
"name": "python3"
|
| 684 |
+
},
|
| 685 |
+
"language_info": {
|
| 686 |
+
"codemirror_mode": {
|
| 687 |
+
"name": "ipython",
|
| 688 |
+
"version": 3
|
| 689 |
+
},
|
| 690 |
+
"file_extension": ".py",
|
| 691 |
+
"mimetype": "text/x-python",
|
| 692 |
+
"name": "python",
|
| 693 |
+
"nbconvert_exporter": "python",
|
| 694 |
+
"pygments_lexer": "ipython3",
|
| 695 |
+
"version": "3.8.13"
|
| 696 |
+
},
|
| 697 |
+
"orig_nbformat": 4,
|
| 698 |
+
"vscode": {
|
| 699 |
+
"interpreter": {
|
| 700 |
+
"hash": "9260f401923fb5c4108c543a7d176de9733d378b3752e49535ad7c43c2271b65"
|
| 701 |
+
}
|
| 702 |
+
}
|
| 703 |
+
},
|
| 704 |
+
"nbformat": 4,
|
| 705 |
+
"nbformat_minor": 2
|
| 706 |
+
}
|
core/plank_model/2.sklearn.ipynb
ADDED
|
@@ -0,0 +1,762 @@
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 53,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import mediapipe as mp\n",
|
| 10 |
+
"import cv2\n",
|
| 11 |
+
"import pandas as pd\n",
|
| 12 |
+
"import pickle\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 15 |
+
"from sklearn.preprocessing import StandardScaler\n",
|
| 16 |
+
"from sklearn.calibration import CalibratedClassifierCV\n",
|
| 17 |
+
"from sklearn.linear_model import LogisticRegression, SGDClassifier\n",
|
| 18 |
+
"from sklearn.svm import SVC\n",
|
| 19 |
+
"from sklearn.neighbors import KNeighborsClassifier\n",
|
| 20 |
+
"from sklearn.tree import DecisionTreeClassifier\n",
|
| 21 |
+
"from sklearn.ensemble import RandomForestClassifier\n",
|
| 22 |
+
"from sklearn.naive_bayes import GaussianNB\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"from sklearn.metrics import precision_score, accuracy_score, f1_score, recall_score, confusion_matrix\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"import warnings\n",
|
| 27 |
+
"warnings.filterwarnings('ignore')\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"# Drawing helpers\n",
|
| 30 |
+
"mp_drawing = mp.solutions.drawing_utils\n",
|
| 31 |
+
"mp_pose = mp.solutions.pose"
|
| 32 |
+
]
|
| 33 |
+
},
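Given these imports, the training step presumably fits several candidate classifiers on the scaled features and compares them. A minimal sketch of that pattern (x_train/x_test/y_train/y_test are assumed to come from train_test_split on the scaled dataset, not shown here):

    models = {
        "LR": LogisticRegression(),
        "SVC": SVC(probability=True),
        "KNN": KNeighborsClassifier(),
        "DTC": DecisionTreeClassifier(),
        "RF": RandomForestClassifier(),
        "NB": GaussianNB(),
        "SGDC": CalibratedClassifierCV(SGDClassifier(), cv=3),
    }

    for name, model in models.items():
        model.fit(x_train, y_train)
        y_pred = model.predict(x_test)
        print(name, round(accuracy_score(y_test, y_pred), 3))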
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "markdown",
|
| 36 |
+
"metadata": {},
|
| 37 |
+
"source": [
|
| 38 |
+
"### 1. Train Model"
|
| 39 |
+
]
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"cell_type": "markdown",
|
| 43 |
+
"metadata": {},
|
| 44 |
+
"source": [
|
| 45 |
+
"#### 1.1. Describe data and split dataset"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": 54,
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"def rescale_frame(frame, percent=50):\n",
|
| 55 |
+
" '''\n",
|
| 56 |
+
" Rescale a frame to a certain percentage compare to its original frame\n",
|
| 57 |
+
" '''\n",
|
| 58 |
+
" width = int(frame.shape[1] * percent/ 100)\n",
|
| 59 |
+
" height = int(frame.shape[0] * percent/ 100)\n",
|
| 60 |
+
" dim = (width, height)\n",
|
| 61 |
+
" return cv2.resize(frame, dim, interpolation = cv2.INTER_AREA)\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"def describe_dataset(dataset_path: str):\n",
|
| 65 |
+
" '''\n",
|
| 66 |
+
" Describe dataset\n",
|
| 67 |
+
" '''\n",
|
| 68 |
+
"\n",
|
| 69 |
+
" data = pd.read_csv(dataset_path)\n",
|
| 70 |
+
" print(f\"Headers: {list(data.columns.values)}\")\n",
|
| 71 |
+
" print(f'Number of rows: {data.shape[0]} \\nNumber of columns: {data.shape[1]}\\n')\n",
|
| 72 |
+
" print(f\"Labels: \\n{data['label'].value_counts()}\\n\")\n",
|
| 73 |
+
" print(f\"Missing values: {data.isnull().values.any()}\\n\")\n",
|
| 74 |
+
" \n",
|
| 75 |
+
" duplicate = data[data.duplicated()]\n",
|
| 76 |
+
" print(f\"Duplicate Rows : {len(duplicate.sum(axis=1))}\")\n",
|
| 77 |
+
"\n",
|
| 78 |
+
" return data\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"def round_up_metric_results(results) -> list:\n",
|
| 82 |
+
" '''Round up metrics results such as precision score, recall score, ...'''\n",
|
| 83 |
+
" return list(map(lambda el: round(el, 3), results))"
|
| 84 |
+
]
|
| 85 |
+
},
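round_up_metric_results pairs with the per-class sklearn metrics imported above. A hypothetical usage (y_test and y_pred assumed to come from a fitted model):

    p = precision_score(y_test, y_pred, average=None, labels=["C", "H", "L"])
    r = recall_score(y_test, y_pred, average=None, labels=["C", "H", "L"])
    print(round_up_metric_results(p))  # one rounded score per class
    print(round_up_metric_results(r))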
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'left_elbow_x', 'left_elbow_y', 'left_elbow_z', 'left_elbow_v', 'right_elbow_x', 'right_elbow_y', 'right_elbow_z', 'right_elbow_v', 'left_wrist_x', 'left_wrist_y', 'left_wrist_z', 'left_wrist_v', 'right_wrist_x', 'right_wrist_y', 'right_wrist_z', 'right_wrist_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v', 'left_knee_x', 'left_knee_y', 'left_knee_z', 'left_knee_v', 'right_knee_x', 'right_knee_y', 'right_knee_z', 'right_knee_v', 'left_ankle_x', 'left_ankle_y', 'left_ankle_z', 'left_ankle_v', 'right_ankle_x', 'right_ankle_y', 'right_ankle_z', 'right_ankle_v', 'left_heel_x', 'left_heel_y', 'left_heel_z', 'left_heel_v', 'right_heel_x', 'right_heel_y', 'right_heel_z', 'right_heel_v', 'left_foot_index_x', 'left_foot_index_y', 'left_foot_index_z', 'left_foot_index_v', 'right_foot_index_x', 'right_foot_index_y', 'right_foot_index_z', 'right_foot_index_v']\n",
"Number of rows: 28520 \n",
"Number of columns: 69\n",
"\n",
"Labels: \n",
"C    9904\n",
"L    9546\n",
"H    9070\n",
"Name: label, dtype: int64\n",
"\n",
"Missing values: False\n",
"\n",
"Duplicate Rows : 0\n"
]
},
{
"data": {
"text/plain": [
"       label    nose_x    nose_y    nose_z    nose_v  left_shoulder_x  \\\n",
"28517      1  0.735630  0.543294  0.007467  0.999246         0.695831   \n",
"28518      1  0.775572  0.517579  0.012821  0.999378         0.704168   \n",
"28519      1  0.790600  0.498958  0.007789  0.999467         0.710651   \n",
"\n",
"       left_shoulder_y  left_shoulder_z  left_shoulder_v  right_shoulder_x  \\\n",
"28517         0.417349         0.155194         0.995723          0.720067   \n",
"28518         0.404210         0.162908         0.995909          0.730823   \n",
"28519         0.394019         0.164441         0.996123          0.736771   \n",
"\n",
"       ...  right_heel_z  right_heel_v  left_foot_index_x  left_foot_index_y  \\\n",
"28517  ...      0.086010      0.966131           0.226601           0.598075   \n",
"28518  ...      0.070911      0.967070           0.238810           0.610591   \n",
"28519  ...      0.085872      0.967943           0.238197           0.609329   \n",
"\n",
"       left_foot_index_z  left_foot_index_v  right_foot_index_x  \\\n",
"28517           0.219305           0.470830            0.220079   \n",
"28518           0.198591           0.496140            0.228907   \n",
"28519           0.233198           0.510583            0.227823   \n",
"\n",
"       right_foot_index_y  right_foot_index_z  right_foot_index_v  \n",
"28517            0.614120            0.026265            0.934942  \n",
"28518            0.625559            0.018591            0.938905  \n",
"28519            0.626068            0.036127            0.940917  \n",
"\n",
"[3 rows x 69 columns]"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = describe_dataset(\"./train.csv\")\n",
"\n",
"# Encode labels: C (correct) -> 0, H -> 1, L -> 2\n",
"df.loc[df[\"label\"] == \"C\", \"label\"] = 0\n",
"df.loc[df[\"label\"] == \"H\", \"label\"] = 1\n",
"df.loc[df[\"label\"] == \"L\", \"label\"] = 2\n",
"df.tail(3)"
]
},
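{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three `.loc` assignments above can equivalently be written as a single mapping; a minimal sketch (the `label_map` name is illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: one-step label encoding equivalent to the .loc assignments above.\n",
"label_map = {\"C\": 0, \"H\": 1, \"L\": 2}\n",
"df[\"label\"] = df[\"label\"].map(label_map)"
]
},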
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"# Extract features and class\n",
"X = df.drop(\"label\", axis=1)\n",
"y = df[\"label\"].astype(\"int\")"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"# Standardize features (zero mean, unit variance)\n",
"sc = StandardScaler()\n",
"X = pd.DataFrame(sc.fit_transform(X))"
]
},
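{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that fitting the scaler on the full dataset before splitting lets statistics from the held-out rows leak into training. A stricter variant, sketched below under the assumption that `X_raw` holds the unscaled feature frame (it is not a variable defined above), fits the scaler on the training rows only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: fit the scaler on training rows only to avoid leakage.\n",
"# X_raw is an assumed, unscaled copy of the feature frame.\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X_raw, y, test_size=0.2, random_state=1234)\n",
"scaler = StandardScaler()\n",
"X_tr = pd.DataFrame(scaler.fit_transform(X_tr), columns=X_raw.columns)\n",
"X_te = pd.DataFrame(scaler.transform(X_te), columns=X_raw.columns)"
]
},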
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1469    0\n",
"292     0\n",
"1568    0\n",
"Name: label, dtype: int64"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234)\n",
"y_test.head(3)"
]
},
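{
"cell_type": "markdown",
"metadata": {},
"source": [
"The classes are close to balanced, so a plain random split works here; a stratified split, sketched below with the same names and seed, would keep the class ratios identical in both parts."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: stratified variant of the split above.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X, y, test_size=0.2, random_state=1234, stratify=y\n",
")\n",
"y_test.value_counts()"
]
},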
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1.2. Train models with Scikit-learn and evaluate on the held-out split"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"algorithms = [(\"LR\", LogisticRegression()),\n",
"              (\"SVC\", SVC(probability=True)),\n",
"              (\"KNN\", KNeighborsClassifier()),\n",
"              (\"DTC\", DecisionTreeClassifier()),\n",
"              (\"SGDC\", CalibratedClassifierCV(SGDClassifier())),\n",
"              (\"NB\", GaussianNB()),\n",
"              (\"RF\", RandomForestClassifier())]\n",
"\n",
"models = {}\n",
"final_results = []\n",
"\n",
"for name, model in algorithms:\n",
"    trained_model = model.fit(X_train, y_train)\n",
"    models[name] = trained_model\n",
"\n",
"    # Evaluate model on the held-out split\n",
"    model_results = model.predict(X_test)\n",
"\n",
"    p_score = precision_score(y_test, model_results, average=None, labels=[0, 1, 2])\n",
"    a_score = accuracy_score(y_test, model_results)\n",
"    r_score = recall_score(y_test, model_results, average=None, labels=[0, 1, 2])\n",
"    f1_score_result = f1_score(y_test, model_results, average=None, labels=[0, 1, 2])\n",
"    cm = confusion_matrix(y_test, model_results, labels=[0, 1, 2])\n",
"    final_results.append((name, round_up_metric_results(p_score), a_score, round_up_metric_results(r_score), round_up_metric_results(f1_score_result), cm))"
]
},
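{
"cell_type": "markdown",
"metadata": {},
"source": [
"Scores this high on a single split can be double-checked with k-fold cross-validation; a minimal sketch over the same `algorithms` list:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: 5-fold cross-validated accuracy for each classifier.\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"for name, model in algorithms:\n",
"    scores = cross_val_score(model, X_train, y_train, cv=5, scoring=\"accuracy\")\n",
"    print(f\"{name}: mean={scores.mean():.4f} +/- {scores.std():.4f}\")"
]
},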
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"  Model        Precision Score  Accuracy score           Recall Score  \\\n",
"0   KNN      [0.999, 1.0, 1.0]        0.999825      [1.0, 1.0, 0.999]   \n",
"1    LR    [0.999, 1.0, 0.999]        0.999649    [0.999, 1.0, 0.999]   \n",
"2   SVC    [0.998, 1.0, 0.999]        0.999299    [0.999, 1.0, 0.998]   \n",
"3    RF      [0.998, 1.0, 1.0]        0.999474    [1.0, 0.999, 0.999]   \n",
"4  SGDC  [0.999, 0.998, 0.999]        0.998597    [0.997, 1.0, 0.999]   \n",
"5   DTC    [0.994, 1.0, 0.999]        0.997721  [0.999, 0.998, 0.995]   \n",
"6    NB  [0.816, 0.931, 0.941]        0.892532  [0.883, 0.951, 0.847]   \n",
"\n",
"                F1 score                                   Confusion Matrix  \n",
"0        [1.0, 1.0, 1.0]         [[1915, 0, 0], [0, 1844, 0], [1, 0, 1944]]  \n",
"1    [0.999, 1.0, 0.999]         [[1914, 0, 1], [0, 1844, 0], [1, 0, 1944]]  \n",
"2    [0.999, 1.0, 0.999]         [[1914, 0, 1], [0, 1844, 0], [3, 0, 1942]]  \n",
"3    [0.999, 1.0, 0.999]         [[1915, 0, 0], [1, 1843, 0], [2, 0, 1943]]  \n",
"4  [0.998, 0.999, 0.999]         [[1909, 4, 2], [0, 1844, 0], [2, 0, 1943]]  \n",
"5  [0.997, 0.999, 0.997]         [[1914, 0, 1], [3, 1841, 0], [9, 0, 1936]]  \n",
"6    [0.848, 0.941, 0.892]  [[1690, 122, 103], [91, 1753, 0], [290, 7, 1648]]  "
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sort models by summed per-class F1 score, best first\n",
"final_results.sort(key=lambda k: sum(k[4]), reverse=True)\n",
"\n",
"pd.DataFrame(final_results, columns=[\"Model\", \"Precision Score\", \"Accuracy score\", \"Recall Score\", \"F1 score\", \"Confusion Matrix\"])"
]
},
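{
"cell_type": "markdown",
"metadata": {},
"source": [
"Summing the per-class F1 scores orders the models the same way as macro-averaged F1; the macro value can also be computed directly, as sketched below for one model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: macro-averaged F1 for a single model (an equivalent ranking criterion).\n",
"macro_f1 = f1_score(y_test, models[\"KNN\"].predict(X_test), average=\"macro\")\n",
"print(round(macro_f1, 3))"
]
},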
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1.3. Test set evaluation"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Headers: ['label', 'nose_x', 'nose_y', 'nose_z', 'nose_v', 'left_shoulder_x', 'left_shoulder_y', 'left_shoulder_z', 'left_shoulder_v', 'right_shoulder_x', 'right_shoulder_y', 'right_shoulder_z', 'right_shoulder_v', 'left_elbow_x', 'left_elbow_y', 'left_elbow_z', 'left_elbow_v', 'right_elbow_x', 'right_elbow_y', 'right_elbow_z', 'right_elbow_v', 'left_wrist_x', 'left_wrist_y', 'left_wrist_z', 'left_wrist_v', 'right_wrist_x', 'right_wrist_y', 'right_wrist_z', 'right_wrist_v', 'left_hip_x', 'left_hip_y', 'left_hip_z', 'left_hip_v', 'right_hip_x', 'right_hip_y', 'right_hip_z', 'right_hip_v', 'left_knee_x', 'left_knee_y', 'left_knee_z', 'left_knee_v', 'right_knee_x', 'right_knee_y', 'right_knee_z', 'right_knee_v', 'left_ankle_x', 'left_ankle_y', 'left_ankle_z', 'left_ankle_v', 'right_ankle_x', 'right_ankle_y', 'right_ankle_z', 'right_ankle_v', 'left_heel_x', 'left_heel_y', 'left_heel_z', 'left_heel_v', 'right_heel_x', 'right_heel_y', 'right_heel_z', 'right_heel_v', 'left_foot_index_x', 'left_foot_index_y', 'left_foot_index_z', 'left_foot_index_v', 'right_foot_index_x', 'right_foot_index_y', 'right_foot_index_z', 'right_foot_index_v']\n",
"Number of rows: 710 \n",
"Number of columns: 69\n",
"\n",
"Labels: \n",
"H    241\n",
"L    235\n",
"C    234\n",
"Name: label, dtype: int64\n",
"\n",
"Missing values: False\n",
"\n",
"Duplicate Rows : 0\n"
]
}
],
"source": [
"test_df = describe_dataset(\"./test.csv\")\n",
"test_df = test_df.sample(frac=1).reset_index(drop=True)\n",
"\n",
"test_df.loc[test_df[\"label\"] == \"C\", \"label\"] = 0\n",
"test_df.loc[test_df[\"label\"] == \"H\", \"label\"] = 1\n",
"test_df.loc[test_df[\"label\"] == \"L\", \"label\"] = 2\n",
"\n",
"test_x = test_df.drop(\"label\", axis=1)\n",
"test_y = test_df[\"label\"].astype(\"int\")\n",
"\n",
"# Transform only: reuse the scaler fitted on the training data\n",
"test_x = pd.DataFrame(sc.transform(test_x))"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"  Model        Precision Score  Accuracy score           Recall Score  \\\n",
"0    LR      [0.987, 1.0, 1.0]        0.995775    [1.0, 0.996, 0.991]   \n",
"1   SVC      [0.963, 1.0, 1.0]        0.987324     [1.0, 0.992, 0.97]   \n",
"2  SGDC  [0.974, 0.975, 0.996]        0.981690  [0.974, 0.983, 0.987]   \n",
"3   KNN    [0.869, 0.996, 1.0]        0.949296   [0.996, 0.992, 0.86]   \n",
"4    RF      [0.765, 1.0, 1.0]        0.898592      [1.0, 1.0, 0.694]   \n",
"5    NB  [0.892, 0.737, 0.945]        0.842254  [0.632, 0.942, 0.949]   \n",
"6   DTC     [0.69, 1.0, 0.625]        0.767606  [0.543, 0.988, 0.766]   \n",
"\n",
"                F1 score                             Confusion Matrix  \n",
"0  [0.994, 0.998, 0.996]      [[234, 0, 0], [1, 240, 0], [2, 0, 233]]  \n",
"1  [0.981, 0.996, 0.985]      [[234, 0, 0], [2, 239, 0], [7, 0, 228]]  \n",
"2  [0.974, 0.979, 0.991]      [[228, 6, 0], [3, 237, 1], [3, 0, 232]]  \n",
"3  [0.928, 0.994, 0.924]     [[233, 1, 0], [2, 239, 0], [33, 0, 202]]  \n",
"4    [0.867, 1.0, 0.819]     [[234, 0, 0], [0, 241, 0], [72, 0, 163]]  \n",
"5   [0.74, 0.827, 0.947]   [[148, 73, 13], [14, 227, 0], [4, 8, 223]]  \n",
"6  [0.608, 0.994, 0.688]   [[127, 0, 107], [2, 238, 1], [55, 0, 180]]  "
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"testset_final_results = []\n",
"\n",
"for name, model in models.items():\n",
"    # Evaluate each trained model on the separate test set\n",
"    model_results = model.predict(test_x)\n",
"\n",
"    p_score = precision_score(test_y, model_results, average=None, labels=[0, 1, 2])\n",
"    a_score = accuracy_score(test_y, model_results)\n",
"    r_score = recall_score(test_y, model_results, average=None, labels=[0, 1, 2])\n",
"    f1_score_result = f1_score(test_y, model_results, average=None, labels=[0, 1, 2])\n",
"    cm = confusion_matrix(test_y, model_results, labels=[0, 1, 2])\n",
"    testset_final_results.append((name, round_up_metric_results(p_score), a_score, round_up_metric_results(r_score), round_up_metric_results(f1_score_result), cm))\n",
"\n",
"\n",
"testset_final_results.sort(key=lambda k: sum(k[4]), reverse=True)\n",
"pd.DataFrame(testset_final_results, columns=[\"Model\", \"Precision Score\", \"Accuracy score\", \"Recall Score\", \"F1 score\", \"Confusion Matrix\"])"
]
},
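{
"cell_type": "markdown",
"metadata": {},
"source": [
"RF and DTC drop sharply from their near-perfect scores on the held-out split, a sign of overfitting, while LR and SVC hold up. A confusion-matrix plot makes the per-class errors easier to read; a minimal sketch (assumes scikit-learn >= 1.0 for `from_predictions`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: plot one model's test-set confusion matrix.\n",
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"\n",
"ConfusionMatrixDisplay.from_predictions(\n",
"    test_y, models[\"LR\"].predict(test_x), display_labels=[\"C\", \"H\", \"L\"]\n",
")"
]
},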
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1.4. Dump models and input scaler using pickle\n",
"\n",
"According to the evaluations, several models perform well; the two best on the test set are LR and SVC, so both are exported below along with the full model dictionary."
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"with open(\"./model/all_sklearn.pkl\", \"wb\") as f:\n",
"    pickle.dump(models, f)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"with open(\"./model/LR_model.pkl\", \"wb\") as f:\n",
"    pickle.dump(models[\"LR\"], f)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"with open(\"./model/SVC_model.pkl\", \"wb\") as f:\n",
"    pickle.dump(models[\"SVC\"], f)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Dump input scaler\n",
"with open(\"./model/input_scaler.pkl\", \"wb\") as f:\n",
"    pickle.dump(sc, f)"
]
},
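{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to verify the dumps is to load the scaler and one model back and run a prediction; a minimal sketch using the first row of `test_df` (the `loaded_*` names below are illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: reload the pickled scaler and LR model and sanity-check one prediction.\n",
"with open(\"./model/input_scaler.pkl\", \"rb\") as f:\n",
"    loaded_scaler = pickle.load(f)\n",
"\n",
"with open(\"./model/LR_model.pkl\", \"rb\") as f:\n",
"    loaded_lr = pickle.load(f)\n",
"\n",
"sample = loaded_scaler.transform(test_df.drop(\"label\", axis=1).iloc[[0]])\n",
"print(loaded_lr.predict(sample))"
]
}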
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.13 (conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "9260f401923fb5c4108c543a7d176de9733d378b3752e49535ad7c43c2271b65"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}