File size: 5,590 Bytes
a15195f
f787cd1
4c7e2b8
f787cd1
241ce22
f787cd1
8732281
f787cd1
 
 
 
 
 
 
 
 
07d2b90
f787cd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aba3792
 
 
 
 
4c7e2b8
 
 
 
 
 
 
 
 
 
 
 
 
 
29b492f
4c7e2b8
 
29b492f
4c7e2b8
 
29b492f
4c7e2b8
 
29b492f
4c7e2b8
 
 
 
29b492f
4c7e2b8
 
29b492f
4c7e2b8
 
29b492f
4c7e2b8
 
29b492f
4c7e2b8
 
29b492f
4c7e2b8
 
29b492f
4c7e2b8
 
29b492f
4c7e2b8
 
29b492f
4c7e2b8
 
29b492f
4c7e2b8
 
29b492f
4c7e2b8
 
29b492f
4c7e2b8
 
8f79c5b
 
8732281
4c7e2b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29b492f
4c7e2b8
 
 
 
 
 
 
 
8732281
 
4c7e2b8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from fastapi import FastAPI, File, Form
import datetime
import time
import torch

import os
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig
from huggingface_hub import hf_hub_download
from fuzzywuzzy import fuzz
from utils import ffmpeg_read, query_dummy, query_raw, find_different

## config
# Settings are read from the environment at import time; a missing variable
# raises KeyError and prevents the app from starting.
API_TOKEN = os.environ["API_TOKEN"]    # Hugging Face Hub access token
MODEL_PATH = os.environ["MODEL_PATH"]  # Hub repo id of the word model
PITCH_PATH = os.environ["PITCH_PATH"]  # Hub repo id of the pitch model

# Download the quantized weight files (same filename in both repos); the
# returned values are local file paths later passed to torch.load.
QUANTIZED_MODEL_PATH = hf_hub_download(repo_id=MODEL_PATH, filename='quantized_model.pt', token=API_TOKEN)
QUANTIZED_PITCH_MODEL_PATH = hf_hub_download(repo_id=PITCH_PATH, filename='quantized_model.pt', token=API_TOKEN)


## word preprocessor
# Two processors for the word model repo: an LM-backed one and a plain one.
# Both are handed to query_raw by the scoring endpoint.
processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)

### quantized model
# Rebuild the architecture from its config (weights start uninitialised),
# dynamically quantize its nn.Linear layers to int8, then load the saved
# quantized state dict on top. Presumably the checkpoint was produced by the
# same quantize_dynamic call, so the module layouts match — confirm if the
# export script changes.
# NOTE(review): eval() is never called explicitly; torch's quantize_dynamic is
# believed to switch the module to eval mode — confirm for the pinned torch.
config = AutoConfig.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
dummy_model = Wav2Vec2ForCTC(config)
quantized_model = torch.quantization.quantize_dynamic(dummy_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
# NOTE(review): torch.load unpickles the file, so only trusted repos should be
# used; no map_location is given, so a checkpoint saved with CUDA tensors would
# fail to load on a CPU-only host.
quantized_model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH))

## pitch preprocessor
processor_pitch = Wav2Vec2Processor.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)

### quantized pitch model
# Same rebuild -> quantize -> load procedure as the word model. This rebinds
# the module-level name `config`, which afterwards holds the pitch config.
config = AutoConfig.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
dummy_pitch_model = Wav2Vec2ForCTC(config)
quantized_pitch_model = torch.quantization.quantize_dynamic(dummy_pitch_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_pitch_model.load_state_dict(torch.load(QUANTIZED_PITCH_MODEL_PATH))

# ASGI application; the route handlers below register themselves on it.
app = FastAPI()

@app.get("/")
def read_root():
    """Liveness endpoint that returns a fixed startup-complete message."""
    payload = {}
    payload["Message"] = "Application startup complete"
    return payload

@app.post("/naomi_api_score/")
async def predict(
    file: bytes = File(...),
    word: str = Form(...),
    pitch: str = Form("None"),
    temperature: int = Form(...),
):
    """Score a spoken word (and optionally its pitch) against expected values.

    The uploaded audio is decoded at 16 kHz and run through the word model
    and, when ``pitch`` is supplied, the pitch model (both loaded at import
    time). Predictions are compared with the expected strings using
    fuzzywuzzy's ``fuzz.ratio`` similarity.

    Parameters
    ----------
    file : bytes
        Raw bytes of the uploaded audio file, decoded with ``ffmpeg_read``.
    word : str
        Expected hiragana word used for the word score.
    pitch : str, default "None"
        Expected pitch pattern used for the pitch score. The literal string
        ``"None"`` (the default) disables pitch scoring.
    temperature : int
        Forwarded unchanged to ``query_raw``; described as the model's
        difficulty — confirm exact semantics in ``utils.query_raw``.

    Returns
    -------
    dict
        ``"timestamp"`` (str): request time as ``%Y-%b-%d-%H:%M:%S``.
        ``"running_time"`` (str): seconds elapsed after decoding, e.g. ``"0.1234 s"``.
        ``"error message"`` (str | None): input/prediction problem, if any.
        ``"audio duration"`` (float): decoded audio length in seconds.
        ``"word predict"`` (str): candidate transcription closest to ``word``.
        ``"pitch predict"`` (str | None): pitch prediction padded with ``"1"``
            or truncated to the length of ``"word predict"``; None when pitch
            is not scored.
        ``"wrong word index"``: ``find_different(word, word predict)``
            (earlier docs give ``"100"`` as an example).
        ``"wrong pitch index"``: ``find_different(pitch, pitch predict)``, or None.
        ``"score"`` (int | None): the word score alone, or
            ``int((2 * word_score + pitch_score) / 3)`` when pitch is scored;
            None when the pitch model produced no prediction.
        ``"debug"`` (float): sum of absolute sample values of the decoded audio.
    """
    sampling_rate = 16000
    upload_audio = ffmpeg_read(file, sampling_rate=sampling_rate)
    # Cast to a built-in float: numpy scalars such as np.float32 are not
    # serialisable by FastAPI's JSON encoder.
    debug = float(np.sum(np.abs(upload_audio)))
    audio_duration = len(upload_audio) / sampling_rate
    # "%b" (abbreviated month name) is the portable spelling of glibc's "%h".
    current_time = datetime.datetime.now().strftime("%Y-%b-%d-%H:%M:%S")
    start_time = time.perf_counter()
    error_message, score = None, None
    best_pitch_predict, wrong_pitch = None, None

    pitch_requested = pitch != "None"
    # Only compare lengths when a pitch was actually supplied; otherwise the
    # "None" placeholder (length 4) would flag nearly every word.
    if pitch_requested and len(word) != len(pitch):
        error_message = "Length of word and pitch input is not equal"

    word_preds = query_raw(upload_audio, word, processor, processor_with_lm, quantized_model, temperature=temperature)
    pitch_preds = (
        query_dummy(upload_audio, processor_pitch, quantized_pitch_model)
        if pitch_requested
        else None
    )

    # Pick the candidate most similar to the expected word; max() keeps the
    # first candidate on ties, matching the previous list.index() lookup.
    scored_words = [(fuzz.ratio(word, pred[0]), pred[0]) for pred in word_preds]
    word_score, best_word_predict = max(scored_words, key=lambda item: item[0])
    wrong_word = find_different(word, best_word_predict)

    if not pitch_requested:
        score = int(word_score)
    elif pitch_preds is None:
        # Previously the pitch result names were left unbound on this path,
        # crashing the response with UnboundLocalError.
        if error_message is None:
            error_message = "Pitch prediction unavailable"
    else:
        best_pitch_predict = pitch_preds.replace(" ", "")
        target_len = len(best_word_predict)
        # Align the pitch string with the predicted word: pad with "1" when
        # too short, truncate when too long.
        if len(best_pitch_predict) < target_len:
            best_pitch_predict += "1" * (target_len - len(best_pitch_predict))
        else:
            best_pitch_predict = best_pitch_predict[:target_len]
        pitch_score = fuzz.ratio(pitch, best_pitch_predict)
        # Word accuracy is weighted twice as heavily as pitch accuracy.
        score = int((word_score * 2 + pitch_score) / 3)
        wrong_pitch = find_different(pitch, best_pitch_predict)

    return {"timestamp": current_time,
            "running_time": f"{round(time.perf_counter() - start_time, 4)} s",
            "error message": error_message,
            "audio duration": audio_duration,
            "word predict": best_word_predict,
            "pitch predict": best_pitch_predict,
            "wrong word index": wrong_word,
            "wrong pitch index": wrong_pitch,
            "score": score,
            "debug": debug,
            }