Upload 8 files

- .gitattributes +1 -0
- 20words_mean_face.npy +3 -0
- TestVisual.sh +7 -0
- app.py +204 -0
- main.py +369 -0
- mmod_human_face_detector.dat +0 -0
- requirements.txt +21 -0
- shape_predictor_68_face_landmarks.dat +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 preprocessing/shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text
+shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text
20words_mean_face.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dbf68b2044171e1160716df7c53e8bbfaa0ee8c61fb41171d04cb6092bb81422
+size 1168
TestVisual.sh
ADDED
@@ -0,0 +1,7 @@
+python main.py \
+    --config-path ./configs/lrw_resnet18_mstcn.json \
+    --model-path ./train_logs/tcn/2022-06-06T19:09:00/ckpt.best.pth.tar \
+    --data-dir ./video \
+    --label-path ./labels/30VietnameseSort.txt \
+    --save-dir ./result \
+    --test
app.py
ADDED
@@ -0,0 +1,204 @@
+import os
+import sys
+
+os.system('git clone https://github.com/facebookresearch/av_hubert.git')
+os.chdir('/home/user/app/av_hubert')
+os.system('git submodule init')
+os.system('git submodule update')
+os.chdir('/home/user/app/av_hubert/fairseq')
+os.system('pip install ./')
+os.system('pip install scipy')
+os.system('pip install sentencepiece')
+os.system('pip install python_speech_features')
+os.system('pip install scikit-video')
+os.system('pip install transformers')
+os.system('pip install gradio==3.12')
+os.system('pip install numpy==1.23.3')
+
+
+# sys.path.append('/home/user/app/av_hubert')
+sys.path.append('/home/user/app/av_hubert/avhubert')
+
+print(sys.path)
+print(os.listdir())
+print(sys.argv, type(sys.argv))
+sys.argv.append('dummy')
+
+
+import dlib, cv2, os
+import numpy as np
+import skvideo
+import skvideo.io
+from tqdm import tqdm
+from preparation.align_mouth import landmarks_interpolate, crop_patch, write_video_ffmpeg
+from base64 import b64encode
+import torch
+import tempfile
+from argparse import Namespace
+import fairseq
+from fairseq import checkpoint_utils, options, tasks, utils
+from fairseq.dataclass.configs import GenerationConfig
+from huggingface_hub import hf_hub_download
+import gradio as gr
+from pytube import YouTube
+
+# os.chdir('/home/user/app/av_hubert/avhubert')
+
+user_dir = "/home/user/app/av_hubert/avhubert"
+utils.import_user_module(Namespace(user_dir=user_dir))
+data_dir = "/home/user/app/video"
+
+ckpt_path = hf_hub_download('vumichien/AV-HuBERT', 'model.pt')  # required by load_model_ensemble_and_task below
+face_detector_path = "/home/user/app/mmod_human_face_detector.dat"
+face_predictor_path = "/home/user/app/shape_predictor_68_face_landmarks.dat"
+mean_face_path = "/home/user/app/20words_mean_face.npy"
+mouth_roi_path = "/home/user/app/roi.mp4"
+output_video_path = "/home/user/app/video/và/test"
+modalities = ["video"]
+gen_subset = "test"
+gen_cfg = GenerationConfig(beam=20)
+models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
+models = [model.eval().cuda() if torch.cuda.is_available() else model.eval() for model in models]
+saved_cfg.task.modalities = modalities
+saved_cfg.task.data = data_dir
+saved_cfg.task.label_dir = data_dir
+task = tasks.setup_task(saved_cfg.task)
+generator = task.build_generator(models, gen_cfg)
+
+
+def get_youtube(video_url):
+    yt = YouTube(video_url)
+    abs_video_path = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first().download()
+    print("Successfully downloaded video")
+    print(abs_video_path)
+    return abs_video_path
+
+
+def convert_bgr2gray(data):
+    # np.stack(arrays, axis=0): stacks the grayscale frames along a brand-new leading axis
+    return np.stack([cv2.cvtColor(_, cv2.COLOR_BGR2GRAY) for _ in data], axis=0)
+
+
+def save2npz(filename, data=None):
+    """save2npz.
+    :param filename: str, the filename where the data will be saved.
+    :param data: ndarray, arrays to save to the file.
+    """
+    assert data is not None, "data is {}".format(data)
+    if not os.path.exists(os.path.dirname(filename)):
+        os.makedirs(os.path.dirname(filename))
+    np.savez_compressed(filename, data=data)
+
+
+def detect_landmark(image, detector, predictor):
+    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+    face_locations = detector(gray, 1)
+    coords = None
+    for (_, face_location) in enumerate(face_locations):
+        if torch.cuda.is_available():
+            rect = face_location.rect  # the CNN detector wraps its result in an mmod_rectangle
+        else:
+            rect = face_location  # the HOG detector returns a plain rectangle
+        shape = predictor(gray, rect)
+        coords = np.zeros((68, 2), dtype=np.int32)
+        for i in range(0, 68):
+            coords[i] = (shape.part(i).x, shape.part(i).y)
+    return coords
+
+
+def preprocess_video(input_video_path):
+    if torch.cuda.is_available():
+        detector = dlib.cnn_face_detection_model_v1(face_detector_path)
+    else:
+        detector = dlib.get_frontal_face_detector()
+
+    predictor = dlib.shape_predictor(face_predictor_path)
+    STD_SIZE = (256, 256)
+    mean_face_landmarks = np.load(mean_face_path)
+    stablePntsIDs = [33, 36, 39, 42, 45]
+    videogen = skvideo.io.vread(input_video_path)
+    frames = np.array([frame for frame in videogen])
+    landmarks = []
+    for frame in tqdm(frames):
+        landmark = detect_landmark(frame, detector, predictor)
+        landmarks.append(landmark)
+    preprocessed_landmarks = landmarks_interpolate(landmarks)
+    rois = crop_patch(input_video_path, preprocessed_landmarks, mean_face_landmarks, stablePntsIDs, STD_SIZE,
+                      window_margin=12, start_idx=48, stop_idx=68, crop_height=96, crop_width=96)
+    rois_gray = convert_bgr2gray(rois)
+    save2npz(output_video_path, data=rois_gray)
+    write_video_ffmpeg(rois, mouth_roi_path, "/usr/bin/ffmpeg")
+    return mouth_roi_path
+
+
+def predict(process_video):
+    os.chdir('/home/user/app')
+    return os.system('bash TestVisual.sh')  # note: os.system returns the shell exit status, not the script's output
+
+
+# ---- Gradio Layout -----
+youtube_url_in = gr.Textbox(label="Youtube url", lines=1, interactive=True)
+video_in = gr.Video(label="Input Video", mirror_webcam=False, interactive=True)
+video_out = gr.Video(label="Audio Visual Video", mirror_webcam=False, interactive=True)
+demo = gr.Blocks()
+demo.encrypt = False
+text_output = gr.Textbox()
+
+with demo:
+    # gr.Markdown('''
+    # <div>
+    # <h1 style='text-align: center'>Speech Recognition from Visual Lip Movement by Audio-Visual Hidden Unit BERT Model (AV-HuBERT)</h1>
+    # This space uses AV-HuBERT models from <a href='https://github.com/facebookresearch' target='_blank'><b>Meta Research</b></a> to recognize speech from lip movement 🤗
+    # <figure>
+    # <img src="https://huggingface.co/vumichien/AV-HuBERT/resolve/main/lipreading.gif" alt="Audio-Visual Speech Recognition">
+    # <figcaption> Speech Recognition from visual lip movement
+    # </figcaption>
+    # </figure>
+    # </div>
+    # ''')
+    # with gr.Row():
+    #     gr.Markdown('''
+    #     ### Reading lip movement from a YouTube link using AV-HuBERT
+    #     ##### Step 1a. Download a video from YouTube (note: the video should be shorter than 10 seconds, otherwise it will be cut, and the face should be stable for a better result)
+    #     ##### Step 1b. You can also upload a video directly
+    #     ##### Step 2. Generate landmarks surrounding the mouth area
+    #     ##### Step 3. Read the lip movement.
+    #     ''')
+    with gr.Row():
+        gr.Markdown('''
+        ### You can test with the following examples:
+        ''')
+        examples = gr.Examples(examples=
+                               ["https://www.youtube.com/watch?v=ZXVDnuepW2s",
+                                "https://www.youtube.com/watch?v=X8_glJn1B8o",
+                                "https://www.youtube.com/watch?v=80yqL2KzBVw"],
+                               label="Examples", inputs=[youtube_url_in])
+    with gr.Column():
+        youtube_url_in.render()
+        download_youtube_btn = gr.Button("Download Youtube video")
+        download_youtube_btn.click(get_youtube, [youtube_url_in], [video_in])
+        print(video_in)
+    with gr.Row():
+        video_in.render()
+        video_out.render()
+    with gr.Row():
+        detect_landmark_btn = gr.Button("Detect landmarks / crop mouth")
+        detect_landmark_btn.click(preprocess_video, [video_in], [video_out])
+        predict_btn = gr.Button("Predict")
+        predict_btn.click(predict, [video_out], [text_output])
+    with gr.Row():
+        # video_lip = gr.Video(label="Audio Visual Video", mirror_webcam=False)
+        text_output.render()
+
+
+demo.launch(debug=True)
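
Side note on predict() above: os.system('bash TestVisual.sh') returns the shell exit status, so text_output receives an integer rather than the transcript. A minimal sketch of an alternative, assuming main.py is set to print its prediction to stdout (the evaluate() call in this commit passes is_print=False, so that flag would need flipping); predict_stdout is a hypothetical name, not part of this commit:

import subprocess

def predict_stdout(process_video):
    os.chdir('/home/user/app')
    # capture the script's stdout instead of its exit status
    result = subprocess.run(['bash', 'TestVisual.sh'], capture_output=True, text=True)
    return result.stdout  # would contain the "Prediction: ..." line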
main.py
ADDED
@@ -0,0 +1,369 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Imperial College London (Pingchuan Ma)
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+""" TCN for lipreading"""
+
+import os
+import time
+import random
+import argparse                  # module that parses command-line arguments
+import numpy as np
+from tqdm import tqdm            # progress-bar library
+
+import torch                     # PyTorch
+import torch.nn as nn            # class-based API: stores and reuses state via attributes
+import torch.nn.functional as F  # functional API: usable without instantiation
+
+from lipreading.utils import get_save_folder
+from lipreading.utils import load_json, save2npz
+from lipreading.utils import load_model, CheckpointSaver
+from lipreading.utils import get_logger, update_logger_batch
+from lipreading.utils import showLR, calculateNorm2, AverageMeter
+from lipreading.model import Lipreading
+from lipreading.mixup import mixup_data, mixup_criterion
+from lipreading.optim_utils import get_optimizer, CosineScheduler
+from lipreading.dataloaders import get_data_loaders, get_preprocessing_pipelines
+
+from pathlib import Path
+import wandb  # experiment-tracking tool (automatically logs loss and accuracy)
+
+
+# parse and handle command-line arguments
+def load_args(default_config=None):
+    # create a parser instance that receives the argument values
+    parser = argparse.ArgumentParser(description='Pytorch Lipreading ')
+
+    # list of accepted arguments
+    # -- dataset config
+    parser.add_argument('--dataset', default='lrw', help='dataset selection')
+    parser.add_argument('--num-classes', type=int, default=30, help='Number of classes')
+    parser.add_argument('--modality', default='video', choices=['video', 'raw_audio'], help='choose the modality')
+    # -- directory
+    parser.add_argument('--data-dir', default='./datasets/visual', help='Loaded data directory')
+    parser.add_argument('--label-path', type=str, default='./labels/30VietnameseSort.txt', help='Path to txt file with labels')
+    parser.add_argument('--annonation-direc', default=None, help='Loaded data directory')
+    # -- model config
+    parser.add_argument('--backbone-type', type=str, default='resnet', choices=['resnet', 'shufflenet'], help='Architecture used for backbone')
+    parser.add_argument('--relu-type', type=str, default='relu', choices=['relu', 'prelu'], help='what relu to use')
+    parser.add_argument('--width-mult', type=float, default=1.0, help='Width multiplier for mobilenets and shufflenets')
+    # -- TCN config
+    parser.add_argument('--tcn-kernel-size', type=int, nargs="+", help='Kernel to be used for the TCN module')
+    parser.add_argument('--tcn-num-layers', type=int, default=4, help='Number of layers on the TCN module')
+    parser.add_argument('--tcn-dropout', type=float, default=0.2, help='Dropout value for the TCN module')
+    parser.add_argument('--tcn-dwpw', default=False, action='store_true', help='If True, use the depthwise separable convolution in TCN architecture')
+    parser.add_argument('--tcn-width-mult', type=int, default=1, help='TCN width multiplier')
+    # -- train
+    parser.add_argument('--training-mode', default='tcn', help='tcn')
+    parser.add_argument('--batch-size', type=int, default=8, help='Mini-batch size')  # changed from default=32 to default=8 (to prevent OOM)
+    parser.add_argument('--optimizer', type=str, default='adamw', choices=['adam', 'sgd', 'adamw'])
+    parser.add_argument('--lr', default=3e-4, type=float, help='initial learning rate')
+    parser.add_argument('--init-epoch', default=0, type=int, help='epoch to start at')
+    parser.add_argument('--epochs', default=100, type=int, help='number of epochs')  # changed from default=80 to default=10 (for testing)
+    parser.add_argument('--test', default=False, action='store_true', help='training mode')
+    parser.add_argument('--save-dir', type=Path, default=Path('/kaggle/working/result/'))
+    # -- mixup
+    parser.add_argument('--alpha', default=0.4, type=float, help='interpolation strength (uniform=1., ERM=0.)')
+    # -- test
+    parser.add_argument('--model-path', type=str, default=None, help='Pretrained model pathname')
+    parser.add_argument('--allow-size-mismatch', default=False, action='store_true',
+                        help='If True, allows to init from model with mismatching weight tensors. Useful to init from model with diff. number of classes')
+    # -- feature extractor
+    parser.add_argument('--extract-feats', default=False, action='store_true', help='Feature extractor')
+    parser.add_argument('--mouth-patch-path', type=str, default=None, help='Path to the mouth ROIs, assuming the file is saved as numpy.array')
+    parser.add_argument('--mouth-embedding-out-path', type=str, default=None, help='Save mouth embeddings to a specified path')
+    # -- json pathname
+    parser.add_argument('--config-path', type=str, default=None, help='Model configuration with json format')
+    # -- other vars
+    parser.add_argument('--interval', default=50, type=int, help='display interval')
+    parser.add_argument('--workers', default=2, type=int, help='number of data loading workers')  # changed from default=8 to default=2 (half of the 4 GCP cores)
+    # paths
+    parser.add_argument('--logging-dir', type=str, default='/kaggle/working/train_logs', help='path to the directory in which to save the log file')
+
+    # store the received argument values in args (type: Namespace)
+    args = parser.parse_args()
+    return args
+
+
+args = load_args()  # parse and load args
+
+# fix the random seeds for experiment reproducibility
+torch.manual_seed(1)  # fix the random seed in PyTorch, the main framework
+np.random.seed(1)     # fix the random seed in numpy
+random.seed(1)        # fix the random seed in Python's random library
+
+# note: to reproduce experiments exactly, torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False are required
+torch.backends.cudnn.benchmark = True  # enable the built-in cuDNN auto-tuner to find the best algorithm for this hardware (given the tensor sizes and conv operations)
+
+
+# feature extraction
+def extract_feats(model):
+    """
+    :rtype: FloatTensor
+    """
+    model.eval()  # automatically switches off layers that must not be active during evaluation
+    preprocessing_func = get_preprocessing_pipelines()['test']  # test-time preprocessing
+
+    mouth_patch_path = args.mouth_patch_path.replace('.', '')
+    dir_name = os.path.dirname(os.path.abspath(__file__))
+    dir_name = dir_name + mouth_patch_path
+
+    data_paths = [os.path.join(pth, f) for pth, dirs, files in os.walk(dir_name) for f in files]
+
+    npz_files = np.load(data_paths[0])['data']
+
+    data = preprocessing_func(npz_files)  # data: TxHxW
+    # data = preprocessing_func(np.load(args.mouth_patch_path)['data'])  # data: TxHxW
+    return data_paths[0], model(torch.FloatTensor(data)[None, None, :, :, :].cuda(), lengths=[data.shape[0]])
+    # return model(torch.FloatTensor(data)[None, None, :, :, :].cuda(), lengths=[data.shape[0]])
+
+
+# evaluation
+def evaluate(model, dset_loader, criterion, is_print=False):
+    model.eval()  # automatically switches off layers that must not be active during evaluation
+    # running_loss = 0.
+    # running_corrects = 0.
+    prediction = ''
+    # during evaluation/validation, model.eval() and torch.no_grad() are usually used together
+    with torch.no_grad():
+        inferences = []
+        for batch_idx, (input, lengths, labels) in enumerate(tqdm(dset_loader)):
+            # run the model
+            # add one dimension to the input tensor and move it to the GPU
+            logits = model(input.unsqueeze(1).cuda(), lengths=lengths)
+            # _, preds = torch.max(F.softmax(logits, dim=1).data, dim=1)  # apply softmax, then take the argmax
+            # running_corrects += preds.eq(labels.cuda().view_as(preds)).sum().item()  # accumulate accuracy
+
+            # loss = criterion(logits, labels.cuda())  # compute the loss
+            # running_loss += loss.item() * input.size(0)  # loss.item(): the scalar value held by the loss
+            # ------------ print prediction and confidence ------------
+
+            probs = torch.nn.functional.softmax(logits, dim=-1)
+            probs = probs[0].detach().cpu().numpy()
+
+            label_path = args.label_path
+            with Path(label_path).open() as fp:
+                vocab = fp.readlines()
+
+            top = np.argmax(probs)
+            prediction = vocab[top].strip()
+            # confidence = np.round(probs[top], 3)
+            # inferences.append({
+            #     'prediction': prediction,
+            #     'confidence': confidence
+            # })
+
+        if is_print:
+            print()
+            print(f'Prediction: {prediction}')
+            # print(f'Confidence: {confidence}')
+            print()
+        return prediction
+        # ------------ save prediction and confidence to a text file ------------
+        # txt_save_path = str(args.save_dir) + f'/predict.txt'
+        # if not os.path.exists(os.path.dirname(txt_save_path)):  # if the file does not exist
+        #     os.makedirs(os.path.dirname(txt_save_path))  # create the directory
+        # with open(txt_save_path, 'w') as f:
+        #     for inference in inferences:
+        #         prediction = inference['prediction']
+        #         confidence = inference['confidence']
+        #         f.writelines(f'Prediction: {prediction}, Confidence: {confidence}\n')
+
+    # print('Test Dataset {} In Total \t CR: {}'.format(len(dset_loader.dataset), running_corrects/len(dset_loader.dataset)))  # print dataset size and accuracy
+    # return running_corrects/len(dset_loader.dataset), running_loss/len(dset_loader.dataset), inferences  # return accuracy, loss, inferences
+
+
+# model training
+# def train(wandb, model, dset_loader, criterion, epoch, optimizer, logger):
+#     data_time = AverageMeter()   # tracks average and current value
+#     batch_time = AverageMeter()  # tracks average and current value
+#
+#     lr = showLR(optimizer)  # current learning rate
+#
+#     # write logger INFO
+#     logger.info('-' * 10)
+#     logger.info('Epoch {}/{}'.format(epoch, args.epochs - 1))  # log the epoch
+#     logger.info('Current learning rate: {}'.format(lr))        # log the learning rate
+#
+#     model.train()  # train mode
+#     running_loss = 0.
+#     running_corrects = 0.
+#     running_all = 0.
+#
+#     end = time.time()  # current time
+#     for batch_idx, (input, lengths, labels) in enumerate(dset_loader):
+#         # measure data loading time
+#         data_time.update(time.time() - end)  # update average and current value
+#
+#         # --
+#         # compute mixup augmentation
+#         input, labels_a, labels_b, lam = mixup_data(input, labels, args.alpha)
+#         labels_a, labels_b = labels_a.cuda(), labels_b.cuda()  # move the tensors to the GPU
+#
+#         # PyTorch keeps accumulating gradients on every subsequent backward call,
+#         # so the gradients must always be zeroed before starting backpropagation
+#         optimizer.zero_grad()
+#
+#         # run the model
+#         # add one dimension to the input tensor and move it to the GPU
+#         logits = model(input.unsqueeze(1).cuda(), lengths=lengths)
+#
+#         loss_func = mixup_criterion(labels_a, labels_b, lam)  # apply mixup
+#         loss = loss_func(criterion, logits)  # compute the loss
+#
+#         loss.backward()   # compute the gradients
+#         optimizer.step()  # update the parameters using the stored gradients
+#
+#         # measure elapsed time
+#         batch_time.update(time.time() - end)  # update average and current value
+#         end = time.time()  # current time
+#         # -- compute running performance
+#         _, predicted = torch.max(F.softmax(logits, dim=1).data, dim=1)  # apply softmax, then take the argmax
+#         running_loss += loss.item()*input.size(0)  # loss.item(): the scalar value held by the loss
+#         running_corrects += lam * predicted.eq(labels_a.view_as(predicted)).sum().item() + (1 - lam) * predicted.eq(labels_b.view_as(predicted)).sum().item()  # accumulate accuracy
+#         running_all += input.size(0)
+#
+#         # ------------------ log to wandb ------------------
+#         wandb.log({'loss': running_loss, 'acc': running_corrects}, step=epoch)
+#
+#         # -- log intermediate results
+#         if batch_idx % args.interval == 0 or (batch_idx == len(dset_loader)-1):
+#             # write logger INFO
+#             update_logger_batch(args, logger, dset_loader, batch_idx, running_loss, running_corrects, running_all, batch_time, data_time)
+#
+#     return model  # return the model
+
+
+# build the model from the json config
+def get_model_from_json():
+    # check that the json config file exists; if not, raise an AssertionError
+    assert args.config_path.endswith('.json') and os.path.isfile(args.config_path), \
+        "'.json' config path does not exist. Path input: {}".format(args.config_path)  # used to guarantee the value meets the expected condition
+
+    args_loaded = load_json(args.config_path)  # read the json
+    args.backbone_type = args_loaded['backbone_type']  # get backbone_type from the json
+    args.width_mult = args_loaded['width_mult']        # get width_mult from the json
+    args.relu_type = args_loaded['relu_type']          # get relu_type from the json
+
+    # TCN options
+    tcn_options = {'num_layers': args_loaded['tcn_num_layers'],
+                   'kernel_size': args_loaded['tcn_kernel_size'],
+                   'dropout': args_loaded['tcn_dropout'],
+                   'dwpw': args_loaded['tcn_dwpw'],
+                   'width_mult': args_loaded['tcn_width_mult'],
+                   }
+
+    # build the lipreading model
+    model = Lipreading(modality=args.modality,
+                       num_classes=args.num_classes,
+                       tcn_options=tcn_options,
+                       backbone_type=args.backbone_type,
+                       relu_type=args.relu_type,
+                       width_mult=args.width_mult,
+                       extract_feats=args.extract_feats).cuda()
+    calculateNorm2(model)  # check that training is progressing; the parameter L2 norm should generally grow as training proceeds
+    return model  # return the model
+
+
+# main() function
+def main():
+
+    # connect to wandb
+    # wandb.init(project="Lipreading_using_TCN_running")
+    # wandb.config = {
+    #     "learning_rate": args.lr,
+    #     "epochs": args.epochs,
+    #     "batch_size": args.batch_size
+    # }
+
+    # os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
+    # os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # GPU selection
+
+    # -- logging
+    save_path = get_save_folder(args)  # save directory
+    print("Model and log being saved in: {}".format(save_path))  # print the save directory path
+    logger = get_logger(args, save_path)  # create and configure the logger
+    ckpt_saver = CheckpointSaver(save_path)  # set up checkpoint saving
+
+    # -- get model
+    model = get_model_from_json()
+    # -- get dataset iterators
+    dset_loaders = get_data_loaders(args)
+    # -- get loss function
+    criterion = nn.CrossEntropyLoss()
+    # -- get optimizer
+    optimizer = get_optimizer(args, optim_policies=model.parameters())
+    # -- get learning rate scheduler
+    scheduler = CosineScheduler(args.lr, args.epochs)  # cosine scheduler
+
+    if args.model_path:
+        # check that the '.tar' checkpoint exists; if not, raise an AssertionError
+        assert args.model_path.endswith('.tar') and os.path.isfile(args.model_path), \
+            "'.tar' model path does not exist. Path input: {}".format(args.model_path)  # used to guarantee the value meets the expected condition
+        # resume from checkpoint
+        if args.init_epoch > 0:
+            model, optimizer, epoch_idx, ckpt_dict = load_model(args.model_path, model, optimizer)  # load the model
+            args.init_epoch = epoch_idx  # set the epoch
+            ckpt_saver.set_best_from_ckpt(ckpt_dict)  # restore the best checkpoint state
+            logger.info('Model and states have been successfully loaded from {}'.format(args.model_path))  # write logger INFO
+        # init from trained model
+        else:
+            model = load_model(args.model_path, model, allow_size_mismatch=args.allow_size_mismatch)  # load the model
+            logger.info('Model has been successfully loaded from {}'.format(args.model_path))  # write logger INFO
+        # feature extraction
+        if args.mouth_patch_path:
+            filename, embeddings = extract_feats(model)
+            filename = filename.split('/')[-1]
+            save_npz_path = os.path.join(args.mouth_embedding_out_path, filename)
+
+            # the code needs fixing for ExtractEmbedding!
+            save2npz(save_npz_path, data=embeddings.cpu().detach().numpy())  # save the npz file
+            # save2npz(args.mouth_embedding_out_path, data=extract_feats(model).cpu().detach().numpy())  # save the npz file
+            return
+        # if test-time, report performance on the test partition and exit; otherwise, report performance on validation and continue (sanity check for reload)
+        if args.test:
+            predicthi = evaluate(model, dset_loaders['test'], criterion, is_print=False)  # evaluate the model
+
+            # logging_sentence = 'Test-time performance on partition {}: Loss: {:.4f}\tAcc:{:.4f}'.format('test', loss_avg_test, acc_avg_test)
+            # logger.info(logging_sentence)  # write logger INFO
+
+            return predicthi
+
+    # -- fix learning rate after loading the checkpoint (latency)
+    if args.model_path and args.init_epoch > 0:
+        scheduler.adjust_lr(optimizer, args.init_epoch - 1)  # update the learning rate
+
+    epoch = args.init_epoch  # initialize the epoch counter
+    while epoch < args.epochs:
+        # note: this loop relies on the train() function and on the accuracy/loss
+        # returns of evaluate(), both of which are commented out above
+        model = train(wandb, model, dset_loaders['train'], criterion, epoch, optimizer, logger)  # train the model
+        acc_avg_val, loss_avg_val, inferences = evaluate(model, dset_loaders['val'], criterion)  # evaluate the model
+        logger.info('{} Epoch:\t{:2}\tLoss val: {:.4f}\tAcc val:{:.4f}, LR: {}'.format('val', epoch, loss_avg_val, acc_avg_val, showLR(optimizer)))  # write logger INFO
+        # -- save checkpoint
+        save_dict = {
+            'epoch_idx': epoch + 1,
+            'model_state_dict': model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict()
+        }
+        ckpt_saver.save(save_dict, acc_avg_val)  # save the checkpoint
+        scheduler.adjust_lr(optimizer, epoch)  # update the learning rate
+        epoch += 1
+
+    # -- evaluate best-performing epoch on the test partition
+    best_fp = os.path.join(ckpt_saver.save_dir, ckpt_saver.best_fn)  # best checkpoint path
+    _ = load_model(best_fp, model)  # load the model
+    acc_avg_test, loss_avg_test, inferences = evaluate(model, dset_loaders['test'], criterion)  # evaluate the model
+    logger.info('Test time performance of best epoch: {} (loss: {})'.format(acc_avg_test, loss_avg_test))  # write logger INFO
+    torch.cuda.empty_cache()  # clear the GPU cache
+
+
+# run the code below only when this module is executed directly from the interpreter, not when it is imported
+# => the first thing called when main.py is run
+if __name__ == '__main__':  # check whether this script file is being executed directly
+    main()  # call main()
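
Side note on reproducibility: the seeding block in main.py fixes the Python, NumPy, and PyTorch seeds but deliberately keeps torch.backends.cudnn.benchmark = True for speed. A minimal sketch of the fully deterministic setup that the script's own comment describes (an alternative configuration, not what this commit runs; seed_everything is a hypothetical helper name):

import random
import numpy as np
import torch

def seed_everything(seed=1):
    # mirrors the seeding block in main.py, plus the deterministic cuDNN flags
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True  # force deterministic conv kernels
    torch.backends.cudnn.benchmark = False     # disable the cuDNN auto-tuner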
mmod_human_face_detector.dat
ADDED
Binary file (730 kB).
requirements.txt
ADDED
@@ -0,0 +1,21 @@
+torch >= 1.3.0
+numpy >= 1.16.4
+scipy >= 1.3.0
+opencv-python >= 4.1.0
+matplotlib >= 3.0.3
+tqdm >= 4.35.0
+scikit-image >= 0.13.0
+librosa >= 0.7.0
+git+https://github.com/facebookresearch/fairseq.git
+scipy
+sentencepiece
+python_speech_features
+scikit-video
+scikit-image
+opencv-python
+pytube==12.1.0
+ffmpeg-python
+cmake
+dlib
+face-alignment
+torchvision==0.2.0
shape_predictor_68_face_landmarks.dat
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
+size 99693937