File size: 5,381 Bytes
a15195f
f787cd1
4c7e2b8
f787cd1
61fe542
241ce22
f787cd1
8732281
f787cd1
 
 
 
 
 
 
 
 
07d2b90
447ae70
 
f787cd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aba3792
 
 
61fe542
aba3792
 
4c7e2b8
 
61fe542
4c7e2b8
 
61fe542
 
 
 
 
4c7e2b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24c223a
4c7e2b8
 
 
61fe542
24c223a
4c7e2b8
24c223a
61fe542
 
4c7e2b8
24c223a
4c7e2b8
 
 
 
 
 
 
 
 
24c223a
 
 
 
 
 
 
 
 
4c7e2b8
 
 
 
29b492f
4c7e2b8
 
702e8c4
 
 
 
 
 
8732281
4c7e2b8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from fastapi import FastAPI, File, Form
import datetime
import time
import torch
from typing import Optional

import os
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig
from huggingface_hub import hf_hub_download
from fuzzywuzzy import fuzz
from utils import ffmpeg_read, query_dummy, query_raw, find_different

## config
# Credentials and Hugging Face Hub repo ids come from the environment;
# a missing variable raises KeyError at import time.
API_TOKEN = os.environ["API_TOKEN"]
MODEL_PATH = os.environ["MODEL_PATH"]
PITCH_PATH = os.environ["PITCH_PATH"]

# Download the quantized state-dict checkpoints for the word and pitch models.
QUANTIZED_MODEL_PATH = hf_hub_download(repo_id=MODEL_PATH, filename='quantized_model.pt', token=API_TOKEN)
QUANTIZED_PITCH_MODEL_PATH = hf_hub_download(repo_id=PITCH_PATH, filename='quantized_model.pt', token=API_TOKEN)


## word preprocessor
# `processor_with_lm` and `processor` are both passed to `query_raw` in the scoring endpoint.
processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)

### quantized model
# Build an uninitialised model from the config, apply dynamic int8 quantization
# to its Linear layers (presumably matching how the checkpoint was produced),
# then load the quantized weights into it.
config = AutoConfig.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
dummy_model = Wav2Vec2ForCTC(config)
quantized_model = torch.quantization.quantize_dynamic(dummy_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
# Dynamically quantized modules run on CPU only, so map any stored tensors there.
quantized_model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH, map_location="cpu"))
# A model constructed directly (not via `from_pretrained`) starts in training
# mode; switch to eval so dropout is disabled during inference.
quantized_model.eval()

## pitch preprocessor
processor_pitch = Wav2Vec2Processor.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)

### quantized pitch model
# Same construction as the word model, using the pitch repo's config and weights.
config = AutoConfig.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
dummy_pitch_model = Wav2Vec2ForCTC(config)
quantized_pitch_model = torch.quantization.quantize_dynamic(dummy_pitch_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_pitch_model.load_state_dict(torch.load(QUANTIZED_PITCH_MODEL_PATH, map_location="cpu"))
quantized_pitch_model.eval()

app = FastAPI()


@app.get("/")
def read_root():
    """Health-check endpoint: report that the application has started."""
    status_payload = dict(Message="Application startup complete")
    return status_payload


@app.post("/naomi_api_score/")
async def predict(
        file: bytes = File(...),
        word: str = Form(...),
        pitch: Optional[str] = Form(None),
        temperature: int = Form(...),
):
    """Transcribe an uploaded recording and score it against the expected word and pitch.

    The audio is decoded at 16 kHz. The word model produces candidate
    transcriptions via ``query_raw``. When ``pitch`` is supplied, the pitch
    model also produces a pitch string via ``query_dummy``. Scores are
    Levenshtein-based similarity ratios from ``fuzz.ratio``.

    Parameters
    ----------
    file : bytes
        Input audio file, decoded with ``ffmpeg_read``.
    word : str
        Expected hiragana word used to compute the word score.
    pitch : str, optional
        Expected pitch string used to compute the pitch score. When omitted,
        only the word score is computed.
    temperature : int
        Passed through to ``query_raw`` (the difficulty of the AI model).

    Returns
    -------
    dict
        timestamp : str
            Request time as ``Year-Mon-Day-Hours:Minutes:Seconds``, with the
            month as its abbreviated name (e.g. ``2023-Jan-05-12:00:00``).
        running_time : str
            Wall-clock processing time, e.g. ``"1.2345 s"``.
        error_message : str or None
            Validation message (word/pitch length mismatch, or no word
            prediction), else None.
        audio_duration : float
            Duration of the uploaded audio in seconds.
        word_predict : str or None
            Candidate transcription most similar to ``word``.
        pitch_predict : str or None
            Predicted pitch, padded with "1" or truncated to the length of
            ``word_predict``; None when no pitch was requested.
        wrong_word_index : str or None
            ``find_different`` result for ``word`` vs. the prediction (ex: 100).
        wrong_pitch_index : str or None
            ``find_different`` result for ``pitch`` vs. the predicted pitch (ex: 100).
        score : int or None
            ``(2 * word_score + pitch_score) / 3`` when pitch is given, otherwise
            the word score; None when no word prediction was available.
    """
    # NOTE(review): model inference below is synchronous inside an `async`
    # handler, so it blocks the event loop while it runs; consider a plain
    # `def` endpoint (FastAPI runs those in a threadpool) if concurrency matters.
    start_time = time.time()  # start before decoding so running_time covers the whole request
    # "%b" is the portable spelling of the glibc-only "%h" (abbreviated month name).
    current_time = datetime.datetime.now().strftime("%Y-%b-%d-%H:%M:%S")
    sampling_rate = 16000
    upload_audio = ffmpeg_read(file, sampling_rate=sampling_rate)
    audio_duration = len(upload_audio) / sampling_rate
    error_message, pitch_preds = None, None

    word_preds = query_raw(upload_audio, word, processor, processor_with_lm, quantized_model, temperature=temperature)
    if pitch is not None:
        if len(word) != len(pitch):
            error_message = "Length of word and pitch input is not equal"
        pitch_preds = query_dummy(upload_audio, processor_pitch, quantized_pitch_model)

    if not word_preds:
        # max() over an empty candidate list would raise and surface as a 500.
        return {"timestamp": current_time,
                "running_time": f"{round(time.time() - start_time, 4)} s",
                "error_message": "No word prediction returned by the model",
                "audio_duration": audio_duration,
                "word_predict": None,
                "pitch_predict": None,
                "wrong_word_index": None,
                "wrong_pitch_index": None,
                "score": None,
                }

    # Find the best word: highest similarity wins; ties keep the earliest candidate.
    scored_words = [(fuzz.ratio(word, candidate[0]), candidate[0]) for candidate in word_preds]
    word_score, best_word_predict = max(scored_words, key=lambda item: item[0])
    wrong_word = find_different(word, best_word_predict)

    if pitch_preds is not None:
        # Align the predicted pitch with the predicted word: truncate if
        # longer, pad with "1" if shorter.
        target_len = len(best_word_predict)
        best_pitch_predict = pitch_preds.replace(" ", "")[:target_len].ljust(target_len, "1")
        pitch_score = fuzz.ratio(pitch, best_pitch_predict)
        score = int((word_score * 2 + pitch_score) / 3)
        wrong_pitch = find_different(pitch, best_pitch_predict)
    else:
        score = int(word_score)
        best_pitch_predict = None
        wrong_pitch = None

    return {"timestamp": current_time,
            "running_time": f"{round(time.time() - start_time, 4)} s",
            "error_message": error_message,
            "audio_duration": audio_duration,
            "word_predict": best_word_predict,
            "pitch_predict": best_pitch_predict,
            "wrong_word_index": wrong_word,
            "wrong_pitch_index": wrong_pitch,
            "score": score,
            }