# Spaces:
# Runtime error
# Runtime error
import datetime
import os
import time
from typing import Optional

import numpy as np
import torch
from fastapi import FastAPI, File, Form
from fuzzywuzzy import fuzz
from huggingface_hub import hf_hub_download
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig

from utils import ffmpeg_read, query_dummy, query_raw, find_different
## config
# All three settings are required; a missing variable raises KeyError at import time.
API_TOKEN = os.environ["API_TOKEN"]
MODEL_PATH = os.environ["MODEL_PATH"]
PITCH_PATH = os.environ["PITCH_PATH"]

# Fetch the serialized quantized weights ('quantized_model.pt') from each Hub repo.
QUANTIZED_MODEL_PATH = hf_hub_download(repo_id=MODEL_PATH, filename='quantized_model.pt', token=API_TOKEN)
QUANTIZED_PITCH_MODEL_PATH = hf_hub_download(repo_id=PITCH_PATH, filename='quantized_model.pt', token=API_TOKEN)

## word preprocessor
processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)

### quantized word model
# Build the architecture from the config alone, apply the same dynamic int8
# quantization to its Linear layers, then load the saved quantized state dict
# into that structure.
config = AutoConfig.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
dummy_model = Wav2Vec2ForCTC(config)
quantized_model = torch.quantization.quantize_dynamic(dummy_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH))

## pitch preprocessor
processor_pitch = Wav2Vec2Processor.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)

### quantized pitch model
# Same procedure as the word model; note that `config` is rebound to the
# pitch repo's configuration from here on.
config = AutoConfig.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
dummy_pitch_model = Wav2Vec2ForCTC(config)
quantized_pitch_model = torch.quantization.quantize_dynamic(dummy_pitch_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_pitch_model.load_state_dict(torch.load(QUANTIZED_PITCH_MODEL_PATH))

app = FastAPI()
# NOTE(review): no route decorator is visible on this handler — confirm that it
# is registered on `app`.
def read_root():
    """Return a static payload confirming that the application has started."""
    return {"Message": "Application startup complete"}
# NOTE(review): no route decorator is visible on this handler — confirm that it
# is registered on `app`.
async def predict(
    file: bytes = File(...),
    word: str = Form(...),
    pitch: Optional[str] = Form(None),
    temperature: int = Form(...),
):
    """Score an uploaded recording against a reference word and optional pitch.

    The audio is decoded at 16 kHz and passed to the word model; among the
    returned candidates, the one most similar to ``word`` (``fuzz.ratio``) is
    kept. When ``pitch`` is supplied the pitch model is queried too, its
    output is padded/truncated to the length of the best word, and the final
    score is the 2:1 weighted mean of the word and pitch similarities.

    Parameters
    ----------
    file : bytes
        Input audio file.
    word : str
        True hiragana word used to calculate the word score.
    pitch : str, optional
        True pitch used to calculate the pitch score. When omitted, only the
        word is scored.
    temperature : int
        The difficulty of the AI model; forwarded to ``query_raw``.

    Returns
    -------
    dict
        timestamp : str
            Current time formatted as Year-Month-Day-Hours:Minutes:Seconds,
            with an abbreviated month name.
        running_time : str
            Seconds elapsed since audio decoding finished, suffixed with " s".
        error_message : str or None
            Set when ``word`` and ``pitch`` differ in length, otherwise None.
        audio_duration : float
            Duration of the decoded audio in seconds.
        word_predict : str
            Best-matching word candidate.
        pitch_predict : str or None
            Pitch prediction aligned to ``word_predict``; None without ``pitch``.
        wrong_word_index : str
            Result of ``find_different(word, word_predict)`` (ex: "100").
        wrong_pitch_index : str or None
            Result of ``find_different(pitch, pitch_predict)`` (ex: "100");
            None without ``pitch``.
        score : int
            Combined similarity score from word and pitch.

    Raises
    ------
    ValueError
        If ``query_raw`` yields no candidates (``max`` of an empty sequence).
    """
    sampling_rate = 16000
    upload_audio = ffmpeg_read(file, sampling_rate=sampling_rate)
    audio_duration = len(upload_audio) / sampling_rate
    # "%b" is the documented abbreviated-month directive ("%h" is only a
    # platform alias for it).
    current_time = datetime.datetime.now().strftime("%Y-%b-%d-%H:%M:%S")
    start_time = time.time()
    error_message, pitch_preds = None, None

    word_preds = query_raw(upload_audio, word, processor, processor_with_lm, quantized_model, temperature=temperature)
    if pitch is not None:
        # NOTE(review): nesting reconstructed — a length mismatch is only
        # reported; the pitch model still runs. Confirm against the original.
        if len(word) != len(pitch):
            error_message = "Length of word and pitch input is not equal"
        pitch_preds = query_dummy(upload_audio, processor_pitch, quantized_pitch_model)

    # find best word: the first candidate with the highest similarity wins
    scored_candidates = [
        (fuzz.ratio(word, candidate[0]), candidate[0]) for candidate in word_preds
    ]
    word_score, best_word_predict = max(scored_candidates, key=lambda pair: pair[0])
    wrong_word = find_different(word, best_word_predict)  # get wrong word

    # find best pitch
    if pitch_preds is not None:
        target_len = len(best_word_predict)
        # Align to the predicted word: truncate when too long, pad with "1"
        # when too short.
        best_pitch_predict = pitch_preds.replace(" ", "")[:target_len].ljust(target_len, "1")
        pitch_score = fuzz.ratio(pitch, best_pitch_predict)
        score = int((word_score * 2 + pitch_score) / 3)
        wrong_pitch = find_different(pitch, best_pitch_predict)  # get wrong pitch
    else:
        score = int(word_score)
        best_pitch_predict = None
        wrong_pitch = None

    return {
        "timestamp": current_time,
        "running_time": f"{round(time.time() - start_time, 4)} s",
        "error_message": error_message,
        "audio_duration": audio_duration,
        "word_predict": best_word_predict,
        "pitch_predict": best_pitch_predict,
        "wrong_word_index": wrong_word,
        "wrong_pitch_index": wrong_pitch,
        "score": score,
    }