naomi-app-api / main.py
vumichien's picture
Update main.py
702e8c4
raw
history blame
5.38 kB
from fastapi import FastAPI, File, Form
import datetime
import time
import torch
from typing import Optional
import os
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig
from huggingface_hub import hf_hub_download
from fuzzywuzzy import fuzz
from utils import ffmpeg_read, query_dummy, query_raw, find_different
## config
# All three settings are read from the environment at import time; a missing
# variable raises KeyError before the application is created.
API_TOKEN = os.environ["API_TOKEN"]
MODEL_PATH = os.environ["MODEL_PATH"]
PITCH_PATH = os.environ["PITCH_PATH"]
# Fetch 'quantized_model.pt' from the word-model and pitch-model repositories.
# The returned values are later handed to torch.load as checkpoint locations.
QUANTIZED_MODEL_PATH = hf_hub_download(repo_id=MODEL_PATH, filename='quantized_model.pt', token=API_TOKEN)
QUANTIZED_PITCH_MODEL_PATH = hf_hub_download(repo_id=PITCH_PATH, filename='quantized_model.pt', token=API_TOKEN)
## word preprocessor
# Two processors are loaded from the same repository: one with a language
# model attached and one without. Both are passed to query_raw by the endpoint.
processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
### quantized model
# The model is first instantiated from its configuration only, then its
# torch.nn.Linear layers are dynamically quantized to qint8 in place, and only
# after that is the downloaded state dict loaded into it.
# NOTE(review): presumably the checkpoint was saved from a model quantized the
# same way, which is why quantization must happen before load_state_dict.
config = AutoConfig.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
dummy_model = Wav2Vec2ForCTC(config)
quantized_model = torch.quantization.quantize_dynamic(dummy_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH))
## pitch preprocessor
processor_pitch = Wav2Vec2Processor.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
### quantized pitch model
# Same construction as the word model, using the pitch repository. `config` is
# rebound here, so after import it refers to the pitch model's configuration.
config = AutoConfig.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
dummy_pitch_model = Wav2Vec2ForCTC(config)
quantized_pitch_model = torch.quantization.quantize_dynamic(dummy_pitch_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_pitch_model.load_state_dict(torch.load(QUANTIZED_PITCH_MODEL_PATH))
# The FastAPI application object on which the routes below are registered.
app = FastAPI()
@app.get("/")
def read_root():
    """Return a fixed status message for GET requests on the root path."""
    payload = {"Message": "Application startup complete"}
    return payload
@app.post("/naomi_api_score/")
async def predict(
    file: bytes = File(...),
    word: str = Form(...),
    pitch: Optional[str] = Form(None),
    temperature: int = Form(...),
):
    """Score an uploaded recording against a target word and, optionally, a target pitch.

    The audio is decoded at 16 kHz, candidate transcriptions are obtained from
    ``query_raw`` and the candidate with the highest ``fuzz.ratio`` against
    ``word`` is kept. When ``pitch`` is supplied, a pitch string is obtained
    from ``query_dummy``, padded with ``"1"`` or truncated to the length of the
    chosen transcription, and compared with ``pitch`` as well.

    Parameters
    ----------
    file : bytes
        Uploaded audio file.
    word : str
        Target word (hiragana) used for the word score.
    pitch : str, optional
        Target pitch string used for the pitch score. When omitted, only the
        word is scored. A length different from ``word`` is reported in
        ``error_message`` but scoring still proceeds.
    temperature : int
        Forwarded to ``query_raw``; the difficulty of the AI model.

    Returns
    -------
    dict
        timestamp : str
            Request time formatted as Year-MonthAbbreviation-Day-Hours:Minutes:Seconds.
        running_time : str
            Seconds spent after audio decoding, formatted as ``"<seconds> s"``.
        error_message : str or None
            Description of an input or prediction problem, if any.
        audio_duration : float
            Duration of the decoded audio in seconds.
        word_predict : str or None
            Best-matching transcription.
        pitch_predict : str or None
            Length-adjusted predicted pitch, or None when ``pitch`` was not given.
        wrong_word_index : result of ``find_different(word, word_predict)``
            Positions where the transcription differs from ``word``.
        wrong_pitch_index : result of ``find_different(pitch, pitch_predict)``
            Positions where the predicted pitch differs from ``pitch``; None
            when ``pitch`` was not given.
        score : int
            ``(2 * word_score + pitch_score) / 3`` truncated to an integer when
            pitch is scored, otherwise the word score. 0 when the model
            returned no transcription candidates.
    """
    sampling_rate = 16000
    upload_audio = ffmpeg_read(file, sampling_rate=sampling_rate)
    audio_duration = len(upload_audio) / sampling_rate
    # %b is the documented directive for the abbreviated month name; the
    # previous %h is a C-library alias producing the same text.
    current_time = datetime.datetime.now().strftime("%Y-%b-%d-%H:%M:%S")
    # The timer starts after decoding, so running_time covers inference and scoring only.
    start_time = time.time()

    error_message = None
    pitch_preds = None
    best_word_predict = None
    best_pitch_predict = None
    wrong_word = None
    wrong_pitch = None
    score = 0

    # NOTE(review): this coroutine awaits nothing; the calls below look
    # synchronous and would occupy the event loop while they run - confirm
    # whether they should be offloaded to a worker thread.
    word_preds = query_raw(upload_audio, word, processor, processor_with_lm, quantized_model, temperature=temperature)
    if pitch is not None:
        if len(word) != len(pitch):
            error_message = "Length of word and pitch input is not equal"
        pitch_preds = query_dummy(upload_audio, processor_pitch, quantized_pitch_model)

    if word_preds is None or len(word_preds) == 0:
        # Without candidates there is nothing to score; report it through the
        # response instead of failing on max() of an empty sequence.
        error_message = "No word prediction returned by the model"
    else:
        # Keep the first candidate with the highest similarity to the target word.
        scored_candidates = [(fuzz.ratio(word, candidate[0]), candidate[0]) for candidate in word_preds]
        word_score, best_word_predict = max(scored_candidates, key=lambda pair: pair[0])
        wrong_word = find_different(word, best_word_predict)

        if pitch_preds is not None:
            # Force the predicted pitch to the transcription's length:
            # pad with "1" when shorter, cut off the tail when longer.
            target_len = len(best_word_predict)
            best_pitch_predict = pitch_preds.replace(" ", "").ljust(target_len, "1")[:target_len]
            pitch_score = fuzz.ratio(pitch, best_pitch_predict)
            # The word score carries twice the weight of the pitch score.
            score = int((word_score * 2 + pitch_score) / 3)
            wrong_pitch = find_different(pitch, best_pitch_predict)
        else:
            score = int(word_score)

    return {
        "timestamp": current_time,
        "running_time": f"{round(time.time() - start_time, 4)} s",
        "error_message": error_message,
        "audio_duration": audio_duration,
        "word_predict": best_word_predict,
        "pitch_predict": best_pitch_predict,
        "wrong_word_index": wrong_word,
        "wrong_pitch_index": wrong_pitch,
        "score": score,
    }