Spaces:
Runtime error
Runtime error
Update main.py
Browse files
main.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from fastapi import FastAPI
|
2 |
import datetime
|
|
|
3 |
import torch
|
4 |
import os
|
5 |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig
|
@@ -39,4 +40,109 @@ app = FastAPI()
|
|
39 |
|
40 |
@app.get("/")
|
41 |
def read_root():
|
42 |
-
return {"Message": "Application startup complete"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from fastapi import FastAPI
|
2 |
import datetime
|
3 |
+
import time
|
4 |
import torch
|
5 |
import os
|
6 |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig
|
|
|
40 |
|
41 |
@app.get("/")
|
42 |
def read_root():
|
43 |
+
return {"Message": "Application startup complete"}
|
44 |
+
|
45 |
+
@app.post("/naomi_api_score/")
|
46 |
+
async def predict(
|
47 |
+
file: bytes = File(...),
|
48 |
+
word: str = Form(...),
|
49 |
+
pitch: str = Form("None"),
|
50 |
+
temperature: int = Form(...),
|
51 |
+
):
|
52 |
+
""" Transform input audio, get text and pitch from Huggingface api and calculate score by Levenshtein Distance Score
|
53 |
+
Parameters:
|
54 |
+
----------
|
55 |
+
file : bytes
|
56 |
+
input audio file
|
57 |
+
β
|
58 |
+
word : strings
|
59 |
+
true hiragana word to calculate word score
|
60 |
+
β
|
61 |
+
pitch : strings
|
62 |
+
true pitch to calculate pitch score
|
63 |
+
β
|
64 |
+
temperature: integer
|
65 |
+
the difficulty of AI model
|
66 |
+
β
|
67 |
+
Returns:
|
68 |
+
-------
|
69 |
+
timestamp: strings
|
70 |
+
current time Year-Month-Day-Hours:Minutes:Second
|
71 |
+
β
|
72 |
+
running_time : strings
|
73 |
+
running time second
|
74 |
+
β
|
75 |
+
error message : strings
|
76 |
+
error message from api
|
77 |
+
β
|
78 |
+
audio duration: integer
|
79 |
+
durations of source audio
|
80 |
+
β
|
81 |
+
target : integer
|
82 |
+
durations of target audio
|
83 |
+
β
|
84 |
+
method : string
|
85 |
+
method applied to transform source audio
|
86 |
+
β
|
87 |
+
word predict : strings
|
88 |
+
text from api
|
89 |
+
β
|
90 |
+
pitch predict : strings
|
91 |
+
pitch from api
|
92 |
+
β
|
93 |
+
wrong word index: strings (ex: 100)
|
94 |
+
wrong word compare to target word
|
95 |
+
β
|
96 |
+
wrong pitch index: strings (ex: 100)
|
97 |
+
wrong word compare to target word
|
98 |
+
β
|
99 |
+
score: integer
|
100 |
+
Levenshtein Distance Score from pitch and word
|
101 |
+
β
|
102 |
+
"""
|
103 |
+
upload_audio = ffmpeg_read(file, sampling_rate=16000)
|
104 |
+
audio_duration = len(upload_audio) / 16000
|
105 |
+
current_time = datetime.datetime.now().strftime("%Y-%h-%d-%H:%M:%S")
|
106 |
+
start_time = time.time()
|
107 |
+
error_message, score = None, None
|
108 |
+
|
109 |
+
if len(word) != len(pitch):
|
110 |
+
error_message = "Length of word and pitch input is not equal"
|
111 |
+
word_preds = query_raw(upload_audio, word, processor, processor_with_lm, quantized_model, temperature=temperature)
|
112 |
+
if pitch != "None":
|
113 |
+
pitch_preds = query_dummy(upload_audio, processor_pitch, quantized_pitch_model)
|
114 |
+
|
115 |
+
# find best word
|
116 |
+
word_score_list = []
|
117 |
+
for word_predict in word_preds:
|
118 |
+
word_score_list.append(fuzz.ratio(word, word_predict[0]))
|
119 |
+
word_score = max(word_score_list)
|
120 |
+
best_word_predict = word_preds[word_score_list.index(word_score)][0]
|
121 |
+
wrong_word = find_different(word, best_word_predict) # get wrong word
|
122 |
+
|
123 |
+
# find best pitch
|
124 |
+
if pitch != "None":
|
125 |
+
if pitch_preds is not None:
|
126 |
+
best_pitch_predict = pitch_preds.replace(" ", "")
|
127 |
+
if len(best_pitch_predict) < len(best_word_predict):
|
128 |
+
best_pitch_predict = best_pitch_predict + "1" * (len(best_word_predict) - len(best_pitch_predict))
|
129 |
+
else:
|
130 |
+
best_pitch_predict = best_pitch_predict[:len(best_word_predict)] # truncate to max len
|
131 |
+
pitch_score = fuzz.ratio(pitch, best_pitch_predict)
|
132 |
+
score = int((word_score * 2 + pitch_score) / 3)
|
133 |
+
wrong_pitch = find_different(pitch, best_pitch_predict) # get wrong pitch
|
134 |
+
else:
|
135 |
+
score = int(word_score)
|
136 |
+
best_pitch_predict = None
|
137 |
+
wrong_pitch = None
|
138 |
+
β
|
139 |
+
return {"timestamp": current_time,
|
140 |
+
"running_time": f"{round(time.time() - start_time, 4)} s",
|
141 |
+
"error message": error_message,
|
142 |
+
"audio duration": audio_duration,
|
143 |
+
"word predict": best_word_predict,
|
144 |
+
"pitch predict": best_pitch_predict,
|
145 |
+
"wrong word index": wrong_word,
|
146 |
+
"wrong pitch index": wrong_pitch,
|
147 |
+
"score": score
|
148 |
+
}
|