File size: 3,594 Bytes
4bfd07a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import fasttext
import re
from collections import Counter
# Maps the classifier's label suffix (the text left after removing the
# "__label__" prefix) to a short language code used for display.
# NOTE(review): the codes look like ISO 639-3 identifiers; the numeric order
# must match the label numbering the model was trained with - confirm.
languages = {
    str(index): code
    for index, code in enumerate(("sma", "sme", "smj", "fin", "est", "eng"))
}
class WordLid:
    """Per-character language labelling driven by a word-level fastText model.

    The dominant ("main") language of a text is chosen by majority vote over
    its words. Individual words are then relabelled only when another
    language beats the main language, for that word, by at least
    ``self.threshold``.
    """

    def __init__(self, model_path, verbose=True):
        """Load the fastText model.

        Args:
            model_path: Path handed to ``fasttext.load_model``.
            verbose: When true (the default, matching the historical
                behaviour) every prediction is printed to stdout.
        """
        self.model = fasttext.load_model(model_path)
        # Minimum margin by which a word's best language must exceed the
        # main language's probability before the word is relabelled.
        self.threshold = 0.5
        self.verbose = verbose

    def set_threshold(self, threshold):
        """Set the probability margin required to relabel a single word."""
        self.threshold = threshold

    def _clean_word(self, word):
        """Return ``word`` lower-cased, without punctuation or digits.

        Characters that are neither word characters nor whitespace are
        dropped, whitespace runs collapse to one space, digits are removed
        and the result is stripped. The result may be an empty string.
        """
        word = word.lower()
        word = re.sub(r'[^\w\s]', '', word)
        word = re.sub(r'\s+', ' ', word)
        word = re.sub(r'\d', '', word)
        return word.strip()

    def _predict_all_languages(self, word):
        """Return ``{label_suffix: probability}`` for every model label.

        ``word`` is cleaned with ``_clean_word`` before prediction. Keys are
        the model labels with the ``__label__`` prefix removed. When
        ``self.verbose`` is true the word and each score are printed.
        """
        cleaned_word = self._clean_word(word)
        labels, probabilities = self.model.predict(cleaned_word, k=-1)
        predictions = {
            label.replace('__label__', ''): prob
            for label, prob in zip(labels, probabilities)
        }

        if self.verbose:
            print(word)
            for code, prob in predictions.items():
                # Fall back to the raw label so a label missing from the
                # ``languages`` table cannot crash the debug output.
                print(f'{languages.get(code, code)} {prob:.4f}')
        return predictions

    def _get_main_language(self, text):
        """Return the label most often ranked first across the words.

        Returns ``None`` when no word yields a prediction. Each word is
        sent to the model exactly once.
        """
        language_counts = Counter()
        for word in text.split():
            cleaned_word = self._clean_word(word)
            if not cleaned_word:
                # Nothing left after cleaning (pure punctuation/digits):
                # skipped here exactly as get_lang_array skips it.
                continue
            predictions = self._predict_all_languages(cleaned_word)
            if predictions:
                language_counts[max(predictions, key=predictions.get)] += 1
        return language_counts.most_common(1)[0][0] if language_counts else None

    def get_lang_array(self, text):
        """Label every character of ``text`` with a language id.

        Args:
            text: Input string; words are the whitespace-separated tokens.

        Returns:
            A list with one entry per character of ``text``. Entries are the
            integer form of the model's label suffix. Whitespace and words
            without a usable prediction carry the main language. If no main
            language can be determined, a list of ``'unk'`` strings of the
            same length is returned instead.

        Raises:
            ValueError: If a label suffix cannot be converted with ``int``.
        """
        main_language = self._get_main_language(text)
        if main_language is None:
            return ['unk'] * len(text)

        lang_array = [main_language] * len(text)
        search_from = 0

        for word in text.split():
            # Search from the end of the previous word so repeated words map
            # to their own occurrence.
            word_start = text.find(word, search_from)
            word_end = word_start + len(word)
            search_from = word_end

            cleaned_word = self._clean_word(word)
            if not cleaned_word:
                continue

            predictions = self._predict_all_languages(cleaned_word)
            if not predictions:
                continue

            best_word_lang = max(predictions, key=predictions.get)
            # Probability of the main language *for this word*.
            main_lang_prob_for_word = predictions.get(main_language, 0.0)

            # Relabel only when the best language wins by the configured
            # margin (previously hard-coded to 0.5, which made
            # set_threshold a no-op).
            if predictions[best_word_lang] >= main_lang_prob_for_word + self.threshold:
                lang_array[word_start:word_end] = [best_word_lang] * len(word)

        return [int(x) for x in lang_array]

if __name__ == '__main__':
    # Demo run: label a sample sentence character by character, then repeat
    # with a different threshold. The model path is relative, so presumably
    # the quantised model must sit in the working directory - confirm.
    identifier = WordLid('lang_id_model_q.bin')

    samples = (
        "Mumenvákki ođđasamos badji ii gávdno vuos sámegillii. Áigumuššan lea goit dubbet dan maiddái sámegielaide, lohká Yle Draama hoavda Jarmo Lampela.",
    )

    for sample in samples:
        labels = identifier.get_lang_array(sample)
        print(f"\nText: '{sample}'")
        print(f"Language Array: {labels}")
        # Sanity check: the result is expected to hold one entry per character.
        assert len(labels) == len(sample), "Length mismatch!"

    # Example of changing the threshold:
    identifier.set_threshold(0.8)
    labels = identifier.get_lang_array("Bonjour le monde!")
    print(f"\nText: 'Bonjour le monde!' (with threshold 0.8)")
    print(f"Language Array: {labels}")