Spaces:
Running
Running
alessandro trinca tornidor
feat: port whisper and faster-whisper support from https://github.com/Thiagohgl/ai-pronunciation-trainer
85b7206
import unittest | |
import torch | |
import torch.package | |
from transformers import MarianTokenizer | |
import models as mo | |
from AIModels import NeuralASR | |
from constants import sample_rate_resample, language_not_implemented | |
class TestModels(unittest.TestCase): | |
def setUp(self): | |
self.language_de = "de" | |
self.language_en = "en" | |
self.tmp_dir = torch.hub.get_dir() | |
self.device = torch.device("cpu") | |
def test_getASRModel_de_silero(self): | |
asr = mo.getASRModel(self.language_de, model_name="silero") | |
self.assertIsInstance(asr, NeuralASR) | |
def test_getASRModel_en_silero(self): | |
asr = mo.getASRModel(self.language_en, model_name="silero") | |
self.assertIsInstance(asr, NeuralASR) | |
def test_getASRModel_language_not_implemented(self): | |
with self.assertRaises(ValueError): | |
lang = "wrong_language" | |
try: | |
mo.getASRModel(lang, model_name="silero") | |
except ValueError as ve: | |
msg = language_not_implemented.format(lang) | |
self.assertEqual(str(ve), msg) | |
raise ve | |
def test_getASRModel_model_not_implemented(self): | |
with self.assertRaises(ValueError): | |
model = "wrong_model" | |
try: | |
mo.getASRModel(self.language_de, model_name=model) | |
except ValueError as ve: | |
msg = f"Model '{model}' not implemented. Supported models: whisper, faster_whisper, silero." | |
self.assertEqual(str(ve), msg) | |
raise ve | |
def test_getTranslationModel_de(self): | |
model, tokenizer = mo.getTranslationModel(self.language_de) | |
self.assertIsInstance(model, torch.nn.Module) | |
self.assertIsInstance(tokenizer, MarianTokenizer) | |
def test_getTranslationModel_not_implemented(self): | |
with self.assertRaises(ValueError): | |
lang = "wrong_language" | |
try: | |
mo.getTranslationModel(lang) | |
except ValueError as ve: | |
self.assertEqual(str(ve), language_not_implemented.format(lang)) | |
raise ve | |
def test_whisper_wrapper_parse_word(self): | |
from whisper_wrapper import parse_word_info | |
inputs_list = [ | |
{'word': ' Hallo,', 'start': 0.0, 'end': 0.32, 'probability': 0.8968557715415955}, | |
{'word': ' wie', 'start': 0.54, 'end': 0.64, 'probability': 0.9032214879989624}, | |
{'word': ' geht', 'start': 0.64, 'end': 0.82, 'probability': 0.9981840252876282}, | |
{'word': ' es', 'start': 0.82, 'end': 1.04, 'probability': 0.9160193800926208}, | |
{'word': ' dir?', 'start': 1.04, 'end': 1.26, 'probability': 0.9904420971870422} | |
] | |
expected_outputs_list = [ | |
{'word': ' Hallo,', 'start_ts': 0.0, 'end_ts': 5120.0}, | |
{'word': ' wie', 'start_ts': 8640.0, 'end_ts': 10240.0}, | |
{'word': ' geht', 'start_ts': 10240.0, 'end_ts': 13120.0}, | |
{'word': ' es', 'start_ts': 13120.0, 'end_ts': 16640.0}, | |
{'word': ' dir?', 'start_ts': 16640.0, 'end_ts': 20160.0} | |
] | |
for word_info, expected_output in zip(inputs_list, expected_outputs_list): | |
output = parse_word_info( | |
word_info=word_info, | |
sample_rate=sample_rate_resample | |
) | |
print(f"output: {output} .") | |
self.assertEqual(output, expected_output) | |
if __name__ == '__main__': | |
unittest.main() |