naomi-app-api / main.py
vumichien's picture
Update main.py
4c7e2b8
raw
history blame
5.45 kB
import datetime
import os
import time

import torch
from fastapi import FastAPI, File, Form
from fuzzywuzzy import fuzz
from huggingface_hub import hf_hub_download
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig

from utils import ffmpeg_read, query_dummy, query_raw, find_different
## config
# All three settings are read from the environment; a missing variable raises
# KeyError while this module is being imported.
API_TOKEN = os.environ["API_TOKEN"]
# MODEL_PATH / PITCH_PATH are used below as Hugging Face Hub repo ids
# (passed as repo_id= and to from_pretrained).
MODEL_PATH = os.environ["MODEL_PATH"]
PITCH_PATH = os.environ["PITCH_PATH"]
# Fetch 'quantized_model.pt' from each repo; the returned values are the
# locations later handed to torch.load.
QUANTIZED_MODEL_PATH = hf_hub_download(repo_id=MODEL_PATH, filename='quantized_model.pt', token=API_TOKEN)
QUANTIZED_PITCH_MODEL_PATH = hf_hub_download(repo_id=PITCH_PATH, filename='quantized_model.pt', token=API_TOKEN)
## word preprocessor
# Two processors are loaded from the same repo: one with a language model and
# one without. Both are passed to query_raw by the scoring endpoint.
processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
### quantized model
# Build the model from its config only, dynamically quantize its Linear layers
# to int8, then load the saved quantized state dict into it.
config = AutoConfig.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
dummy_model = Wav2Vec2ForCTC(config)
# NOTE(review): inplace=True, so quantized_model presumably refers to the same
# object as dummy_model - confirm before relying on dummy_model elsewhere.
quantized_model = torch.quantization.quantize_dynamic(dummy_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted repo.
quantized_model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH))
## pitch preprocessor
processor_pitch = Wav2Vec2Processor.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
### quantized pitch model
# Same construction as the word model. Note that this rebinds the module-level
# name `config` to the pitch model's configuration.
config = AutoConfig.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
dummy_pitch_model = Wav2Vec2ForCTC(config)
quantized_pitch_model = torch.quantization.quantize_dynamic(dummy_pitch_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_pitch_model.load_state_dict(torch.load(QUANTIZED_PITCH_MODEL_PATH))
# Application object on which the HTTP routes are registered.
app = FastAPI()
@app.get("/")
def read_root():
    """Return a fixed status message for the root path."""
    status = {"Message": "Application startup complete"}
    return status
@app.post("/naomi_api_score/")
async def predict(
    file: bytes = File(...),
    word: str = Form(...),
    pitch: str = Form("None"),
    temperature: int = Form(...),
):
    """Score an uploaded recording against a target word and, optionally, a pitch.

    The audio is decoded at 16 kHz and passed to the word model (and to the
    pitch model when ``pitch`` is supplied). Predictions are compared with the
    targets using ``fuzz.ratio``; when pitch is scored, the final score weights
    the word ratio twice as much as the pitch ratio.

    Parameters
    ----------
    file : bytes
        Content of the uploaded audio file.
    word : str
        Target word (hiragana) the word prediction is compared against.
    pitch : str
        Target pitch string. The literal string ``"None"`` (the default)
        disables pitch scoring.
    temperature : int
        Difficulty setting, forwarded unchanged to ``query_raw``.

    Returns
    -------
    dict
        ``timestamp``: request time formatted as Year-Month-Day-Hours:Minutes:Seconds.
        ``running_time``: elapsed seconds since decoding finished, e.g. ``"0.1234 s"``.
        ``error message``: description of an input/prediction problem, or ``None``.
        ``audio duration``: length of the decoded audio in seconds.
        ``word predict``: best-matching word prediction, or ``None``.
        ``pitch predict``: pitch prediction padded/truncated to the length of
        the predicted word, or ``None`` when pitch was not scored.
        ``wrong word index``: result of ``find_different`` for the word.
        ``wrong pitch index``: result of ``find_different`` for the pitch, or ``None``.
        ``score``: combined integer score, or ``None`` if nothing could be scored.
    """
    # NOTE(review): the model calls below are synchronous inside an async
    # handler - confirm that blocking the event loop is acceptable here.
    sampling_rate = 16000
    upload_audio = ffmpeg_read(file, sampling_rate=sampling_rate)
    audio_duration = len(upload_audio) / sampling_rate
    # "%b" is the portable code for the abbreviated month name; the previous
    # "%h" is a platform-specific alias for it.
    current_time = datetime.datetime.now().strftime("%Y-%b-%d-%H:%M:%S")
    start_time = time.time()

    pitch_requested = pitch != "None"
    error_message = None
    # Only compare lengths when a pitch target was actually supplied; otherwise
    # the "None" placeholder would be measured against the word.
    if pitch_requested and len(word) != len(pitch):
        error_message = "Length of word and pitch input is not equal"

    word_preds = query_raw(upload_audio, word, processor, processor_with_lm, quantized_model, temperature=temperature)
    pitch_preds = None
    if pitch_requested:
        pitch_preds = query_dummy(upload_audio, processor_pitch, quantized_pitch_model)

    # Every response field has a defined value on every path.
    score = None
    best_word_predict = None
    wrong_word = None
    best_pitch_predict = None
    wrong_pitch = None

    if word_preds is None or len(word_preds) == 0:
        # Nothing to score; report it instead of failing on max() of an empty list.
        error_message = "No word prediction returned by the model"
    else:
        # Pick the candidate closest to the target word (first one wins on ties).
        candidates = [prediction[0] for prediction in word_preds]
        ratios = [fuzz.ratio(word, candidate) for candidate in candidates]
        word_score = max(ratios)
        best_word_predict = candidates[ratios.index(word_score)]
        wrong_word = find_different(word, best_word_predict)
        # Word-only score; replaced below when a pitch prediction is available.
        score = int(word_score)

        if pitch_requested and pitch_preds is None:
            # Pitch was requested but the model returned nothing: fall back to
            # the word-only score rather than leaving response fields unbound.
            if error_message is None:
                error_message = "Pitch prediction is not available; score is based on word only"
        elif pitch_requested:
            best_pitch_predict = pitch_preds.replace(" ", "")
            target_len = len(best_word_predict)
            if len(best_pitch_predict) < target_len:
                # Pad with "1" so the pitch has one symbol per predicted character.
                best_pitch_predict += "1" * (target_len - len(best_pitch_predict))
            else:
                # Truncate to the length of the predicted word.
                best_pitch_predict = best_pitch_predict[:target_len]
            pitch_score = fuzz.ratio(pitch, best_pitch_predict)
            score = int((word_score * 2 + pitch_score) / 3)
            wrong_pitch = find_different(pitch, best_pitch_predict)

    return {
        "timestamp": current_time,
        "running_time": f"{round(time.time() - start_time, 4)} s",
        "error message": error_message,
        "audio duration": audio_duration,
        "word predict": best_word_predict,
        "pitch predict": best_pitch_predict,
        "wrong word index": wrong_word,
        "wrong pitch index": wrong_pitch,
        "score": score,
    }