vumichien commited on
Commit
4c7e2b8
Β·
1 Parent(s): f08373d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +107 -1
main.py CHANGED
@@ -1,5 +1,6 @@
1
  from fastapi import FastAPI
2
  import datetime
 
3
  import torch
4
  import os
5
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig
@@ -39,4 +40,109 @@ app = FastAPI()
39
 
40
  @app.get("/")
41
  def read_root():
42
- return {"Message": "Application startup complete"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
  import datetime
3
+ import time
4
  import torch
5
  import os
6
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig
 
40
 
41
  @app.get("/")
42
  def read_root():
43
+ return {"Message": "Application startup complete"}
44
+
45
+ @app.post("/naomi_api_score/")
46
+ async def predict(
47
+ file: bytes = File(...),
48
+ word: str = Form(...),
49
+ pitch: str = Form("None"),
50
+ temperature: int = Form(...),
51
+ ):
52
+ """ Transform input audio, get text and pitch from Huggingface api and calculate score by Levenshtein Distance Score
53
+ Parameters:
54
+ ----------
55
+ file : bytes
56
+ input audio file
57
+ ​
58
+ word : strings
59
+ true hiragana word to calculate word score
60
+ ​
61
+ pitch : strings
62
+ true pitch to calculate pitch score
63
+ ​
64
+ temperature: integer
65
+ the difficulty of AI model
66
+ ​
67
+ Returns:
68
+ -------
69
+ timestamp: strings
70
+ current time Year-Month-Day-Hours:Minutes:Second
71
+ ​
72
+ running_time : strings
73
+ running time second
74
+ ​
75
+ error message : strings
76
+ error message from api
77
+ ​
78
+ audio duration: integer
79
+ durations of source audio
80
+ ​
81
+ target : integer
82
+ durations of target audio
83
+ ​
84
+ method : string
85
+ method applied to transform source audio
86
+ ​
87
+ word predict : strings
88
+ text from api
89
+ ​
90
+ pitch predict : strings
91
+ pitch from api
92
+ ​
93
+ wrong word index: strings (ex: 100)
94
+ wrong word compare to target word
95
+ ​
96
+ wrong pitch index: strings (ex: 100)
97
+ wrong word compare to target word
98
+ ​
99
+ score: integer
100
+ Levenshtein Distance Score from pitch and word
101
+ ​
102
+ """
103
+ upload_audio = ffmpeg_read(file, sampling_rate=16000)
104
+ audio_duration = len(upload_audio) / 16000
105
+ current_time = datetime.datetime.now().strftime("%Y-%h-%d-%H:%M:%S")
106
+ start_time = time.time()
107
+ error_message, score = None, None
108
+
109
+ if len(word) != len(pitch):
110
+ error_message = "Length of word and pitch input is not equal"
111
+ word_preds = query_raw(upload_audio, word, processor, processor_with_lm, quantized_model, temperature=temperature)
112
+ if pitch != "None":
113
+ pitch_preds = query_dummy(upload_audio, processor_pitch, quantized_pitch_model)
114
+
115
+ # find best word
116
+ word_score_list = []
117
+ for word_predict in word_preds:
118
+ word_score_list.append(fuzz.ratio(word, word_predict[0]))
119
+ word_score = max(word_score_list)
120
+ best_word_predict = word_preds[word_score_list.index(word_score)][0]
121
+ wrong_word = find_different(word, best_word_predict) # get wrong word
122
+
123
+ # find best pitch
124
+ if pitch != "None":
125
+ if pitch_preds is not None:
126
+ best_pitch_predict = pitch_preds.replace(" ", "")
127
+ if len(best_pitch_predict) < len(best_word_predict):
128
+ best_pitch_predict = best_pitch_predict + "1" * (len(best_word_predict) - len(best_pitch_predict))
129
+ else:
130
+ best_pitch_predict = best_pitch_predict[:len(best_word_predict)] # truncate to max len
131
+ pitch_score = fuzz.ratio(pitch, best_pitch_predict)
132
+ score = int((word_score * 2 + pitch_score) / 3)
133
+ wrong_pitch = find_different(pitch, best_pitch_predict) # get wrong pitch
134
+ else:
135
+ score = int(word_score)
136
+ best_pitch_predict = None
137
+ wrong_pitch = None
138
+ ​
139
+ return {"timestamp": current_time,
140
+ "running_time": f"{round(time.time() - start_time, 4)} s",
141
+ "error message": error_message,
142
+ "audio duration": audio_duration,
143
+ "word predict": best_word_predict,
144
+ "pitch predict": best_pitch_predict,
145
+ "wrong word index": wrong_word,
146
+ "wrong pitch index": wrong_pitch,
147
+ "score": score
148
+ }