# File size: 3,476 Bytes
# 85b7206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import unittest
import torch
import torch.package
from transformers import MarianTokenizer

import models as mo
from AIModels import NeuralASR
from constants import sample_rate_resample, language_not_implemented


class TestModels(unittest.TestCase):
    """Tests for the ``models`` factory functions and the whisper word parser.

    Covers ``mo.getASRModel`` and ``mo.getTranslationModel`` for supported
    and unsupported arguments, and ``whisper_wrapper.parse_word_info`` for
    a fixed list of word dictionaries.
    """

    def setUp(self):
        """Prepare the language codes and torch settings used by the tests."""
        self.language_de = "de"
        self.language_en = "en"
        self.tmp_dir = torch.hub.get_dir()
        self.device = torch.device("cpu")

    def test_getASRModel_de_silero(self):
        """The German silero ASR model is returned as a ``NeuralASR``."""
        asr = mo.getASRModel(self.language_de, model_name="silero")
        self.assertIsInstance(asr, NeuralASR)

    def test_getASRModel_en_silero(self):
        """The English silero ASR model is returned as a ``NeuralASR``."""
        asr = mo.getASRModel(self.language_en, model_name="silero")
        self.assertIsInstance(asr, NeuralASR)

    def test_getASRModel_language_not_implemented(self):
        """An unknown language raises ``ValueError`` with the shared message."""
        lang = "wrong_language"
        # Capture the exception via the context manager instead of
        # catching and re-raising it by hand inside the ``with`` block.
        with self.assertRaises(ValueError) as ctx:
            mo.getASRModel(lang, model_name="silero")
        self.assertEqual(str(ctx.exception), language_not_implemented.format(lang))

    def test_getASRModel_model_not_implemented(self):
        """An unknown model name raises ``ValueError`` listing supported models."""
        model = "wrong_model"
        with self.assertRaises(ValueError) as ctx:
            mo.getASRModel(self.language_de, model_name=model)
        msg = f"Model '{model}' not implemented. Supported models: whisper, faster_whisper, silero."
        self.assertEqual(str(ctx.exception), msg)

    def test_getTranslationModel_de(self):
        """The German translation model is a torch module with a Marian tokenizer."""
        model, tokenizer = mo.getTranslationModel(self.language_de)
        self.assertIsInstance(model, torch.nn.Module)
        self.assertIsInstance(tokenizer, MarianTokenizer)

    def test_getTranslationModel_not_implemented(self):
        """An unknown language raises ``ValueError`` with the shared message."""
        lang = "wrong_language"
        with self.assertRaises(ValueError) as ctx:
            mo.getTranslationModel(lang)
        self.assertEqual(str(ctx.exception), language_not_implemented.format(lang))

    def test_whisper_wrapper_parse_word(self):
        """``parse_word_info`` maps each word dict to word/start_ts/end_ts."""
        # Imported here, as in the original test, so that importing this
        # module does not itself require ``whisper_wrapper``.
        from whisper_wrapper import parse_word_info

        inputs_list = [
            {'word': ' Hallo,', 'start': 0.0, 'end': 0.32, 'probability': 0.8968557715415955},
            {'word': ' wie', 'start': 0.54, 'end': 0.64, 'probability': 0.9032214879989624},
            {'word': ' geht', 'start': 0.64, 'end': 0.82, 'probability': 0.9981840252876282},
            {'word': ' es', 'start': 0.82, 'end': 1.04, 'probability': 0.9160193800926208},
            {'word': ' dir?', 'start': 1.04, 'end': 1.26, 'probability': 0.9904420971870422},
        ]
        # The expected timestamps equal the input seconds multiplied by
        # 16000 (e.g. 0.32 -> 5120.0), so they presume that
        # ``sample_rate_resample`` is 16000 -- confirm against ``constants``.
        expected_outputs_list = [
            {'word': ' Hallo,', 'start_ts': 0.0, 'end_ts': 5120.0},
            {'word': ' wie', 'start_ts': 8640.0, 'end_ts': 10240.0},
            {'word': ' geht', 'start_ts': 10240.0, 'end_ts': 13120.0},
            {'word': ' es', 'start_ts': 13120.0, 'end_ts': 16640.0},
            {'word': ' dir?', 'start_ts': 16640.0, 'end_ts': 20160.0},
        ]
        # ``zip`` stops at the shorter list; guard against fixtures drifting
        # apart so that no case is silently skipped.
        self.assertEqual(len(inputs_list), len(expected_outputs_list))
        for word_info, expected_output in zip(inputs_list, expected_outputs_list):
            # ``subTest`` reports which word failed and keeps checking the rest.
            with self.subTest(word=word_info['word']):
                output = parse_word_info(
                    word_info=word_info,
                    sample_rate=sample_rate_resample
                )
                self.assertEqual(output, expected_output)


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()