# Page-scrape header (not part of the module): Spaces status "Running"; file size: 3,476 bytes; id 85b7206.
import unittest
import torch
import torch.package
from transformers import MarianTokenizer
import models as mo
from AIModels import NeuralASR
from constants import sample_rate_resample, language_not_implemented
class TestModels(unittest.TestCase):
    """Tests for the model factories in ``models`` and for ``parse_word_info``.

    Covers ``mo.getASRModel`` and ``mo.getTranslationModel`` on both the
    success path and the ``ValueError`` paths, plus the conversion performed
    by ``whisper_wrapper.parse_word_info``.
    """

    def setUp(self):
        """Set the language codes and torch settings shared by the tests."""
        self.language_de = "de"
        self.language_en = "en"
        self.tmp_dir = torch.hub.get_dir()
        self.device = torch.device("cpu")

    def test_getASRModel_de_silero(self):
        """The German silero ASR model is returned as a ``NeuralASR``."""
        asr = mo.getASRModel(self.language_de, model_name="silero")
        self.assertIsInstance(asr, NeuralASR)

    def test_getASRModel_en_silero(self):
        """The English silero ASR model is returned as a ``NeuralASR``."""
        asr = mo.getASRModel(self.language_en, model_name="silero")
        self.assertIsInstance(asr, NeuralASR)

    def test_getASRModel_language_not_implemented(self):
        """An unknown language raises ``ValueError`` with the shared message."""
        lang = "wrong_language"
        # Capture the exception through the context manager instead of a
        # nested try/except/re-raise; the message is checked afterwards.
        with self.assertRaises(ValueError) as ctx:
            mo.getASRModel(lang, model_name="silero")
        self.assertEqual(
            str(ctx.exception), language_not_implemented.format(lang)
        )

    def test_getASRModel_model_not_implemented(self):
        """An unknown model name raises ``ValueError`` listing supported models."""
        model = "wrong_model"
        with self.assertRaises(ValueError) as ctx:
            mo.getASRModel(self.language_de, model_name=model)
        msg = f"Model '{model}' not implemented. Supported models: whisper, faster_whisper, silero."
        self.assertEqual(str(ctx.exception), msg)

    def test_getTranslationModel_de(self):
        """The German translation factory returns a torch module and a Marian tokenizer."""
        model, tokenizer = mo.getTranslationModel(self.language_de)
        self.assertIsInstance(model, torch.nn.Module)
        self.assertIsInstance(tokenizer, MarianTokenizer)

    def test_getTranslationModel_not_implemented(self):
        """An unknown language raises ``ValueError`` with the shared message."""
        lang = "wrong_language"
        with self.assertRaises(ValueError) as ctx:
            mo.getTranslationModel(lang)
        self.assertEqual(
            str(ctx.exception), language_not_implemented.format(lang)
        )

    def test_whisper_wrapper_parse_word(self):
        """``parse_word_info`` maps start/end times to ``start_ts``/``end_ts``.

        The expected values below equal the input seconds multiplied by
        16000 -- presumably the value of ``sample_rate_resample``; confirm
        against ``constants``.
        """
        from whisper_wrapper import parse_word_info

        inputs_list = [
            {'word': ' Hallo,', 'start': 0.0, 'end': 0.32, 'probability': 0.8968557715415955},
            {'word': ' wie', 'start': 0.54, 'end': 0.64, 'probability': 0.9032214879989624},
            {'word': ' geht', 'start': 0.64, 'end': 0.82, 'probability': 0.9981840252876282},
            {'word': ' es', 'start': 0.82, 'end': 1.04, 'probability': 0.9160193800926208},
            {'word': ' dir?', 'start': 1.04, 'end': 1.26, 'probability': 0.9904420971870422},
        ]
        expected_outputs_list = [
            {'word': ' Hallo,', 'start_ts': 0.0, 'end_ts': 5120.0},
            {'word': ' wie', 'start_ts': 8640.0, 'end_ts': 10240.0},
            {'word': ' geht', 'start_ts': 10240.0, 'end_ts': 13120.0},
            {'word': ' es', 'start_ts': 13120.0, 'end_ts': 16640.0},
            {'word': ' dir?', 'start_ts': 16640.0, 'end_ts': 20160.0},
        ]
        # zip() stops at the shorter list, so guard against fixture drift.
        self.assertEqual(len(inputs_list), len(expected_outputs_list))
        for word_info, expected_output in zip(inputs_list, expected_outputs_list):
            # subTest reports every failing word instead of stopping at the first.
            with self.subTest(word=word_info['word']):
                output = parse_word_info(
                    word_info=word_info,
                    sample_rate=sample_rate_resample,
                )
                self.assertEqual(output, expected_output)
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()