Upload langid.py

langid/langid.py (ADDED, +101 -0)
import fasttext
import re
from collections import Counter

languages = {
    "0": "sma",
    "1": "sme",
    "2": "smj",
    "3": "fin",
    "4": "est",
    "5": "eng"
}
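# Note (an assumption, not stated in this file): the numeric keys above correspond to
# the fastText labels the model predicts ("__label__0" ... "__label__5"), which
# _predict_all_languages strips back to bare ids. The values are ISO 639-3 codes:
# sma = South Sami, sme = North Sami, smj = Lule Sami, fin = Finnish,
# est = Estonian, eng = English.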
class WordLid:
    """Word-level language identification on top of a fastText model."""

    def __init__(self, model_path):
        self.model = fasttext.load_model(model_path)
        self.threshold = 0.5

    def set_threshold(self, threshold):
        self.threshold = threshold

    def _clean_word(self, word):
        # Lowercase and strip punctuation, repeated whitespace and digits.
        word = word.lower()
        word = re.sub(r'[^\w\s]', '', word)
        word = re.sub(r'\s+', ' ', word)
        word = re.sub(r'\d', '', word)
        return word.strip()
    def _predict_all_languages(self, word):
        cleaned_word = self._clean_word(word)
        labels, probabilities = self.model.predict(cleaned_word, k=-1)

        # Debug output: print the probability of every language for this word.
        print(word)
        for l, p in zip(labels, probabilities):
            print(f'{languages[l.replace("__label__", "")]} {p:.4f}')
        return {label.replace('__label__', ''): prob for label, prob in zip(labels, probabilities)}

    def _get_main_language(self, text):
        # Majority vote: the language that wins the most words is the text's main language.
        words = [self._clean_word(word) for word in text.split() if word]
        language_counts = Counter()
        for word in words:
            if not word:
                continue
            predictions = self._predict_all_languages(word)
            if predictions:
                language_counts[max(predictions, key=predictions.get)] += 1
        return language_counts.most_common(1)[0][0] if language_counts else None
    def get_lang_array(self, text):
        """Return one language label per character of `text`."""
        main_language = self._get_main_language(text)
        if main_language is None:
            # Nothing identifiable at all: mark every character as unknown.
            return ['unk'] * len(text)

        lang_array = [main_language] * len(text)
        word_start_index = 0

        for word in text.split():
            word_start_index = text.find(word, word_start_index)
            cleaned_word = self._clean_word(word)
            if not cleaned_word:
                word_start_index += len(word)
                continue

            predictions = self._predict_all_languages(cleaned_word)
            if not predictions:
                word_start_index += len(word)
                continue

            best_word_lang = max(predictions, key=predictions.get)
            main_lang_prob_for_word = predictions.get(main_language, 0.0)  # main-language probability *for this word*
            best_lang_prob = predictions[best_word_lang]

            # Only switch a word away from the main language when its best language
            # beats the main-language probability for this word by at least
            # self.threshold (0.5 by default, adjustable via set_threshold).
            if best_lang_prob >= main_lang_prob_for_word + self.threshold:
                for i in range(len(word)):
                    lang_array[word_start_index + i] = best_word_lang

            word_start_index += len(word)
        return [int(x) for x in lang_array]
        # return lang_array  # alternative: return the raw label strings instead of ints
if __name__ == '__main__':
    model_path = 'lang_id_model_q.bin'
    identifier = WordLid(model_path)

    test_texts = [
        "Mumenvákki ođđasamos badji ii gávdno vuos sámegillii. Áigumuššan lea goit dubbet dan maiddái sámegielaide, lohká Yle Draama hoavda Jarmo Lampela."
    ]

    for text in test_texts:
        lang_array = identifier.get_lang_array(text)
        print(f"\nText: '{text}'")
        print(f"Language Array: {lang_array}")

        # Uncomment to inspect the prediction for each character:
        # for i in range(len(text)):
        #     print(text[i], lang_array[i])
        assert len(lang_array) == len(text), "Length mismatch!"

    # Example of changing the threshold:
    identifier.set_threshold(0.8)
    lang_array = identifier.get_lang_array("Bonjour le monde!")
    print("\nText: 'Bonjour le monde!' (with threshold 0.8)")
    print(f"Language Array: {lang_array}")
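The array returned by get_lang_array holds one integer label id per character. A small helper along these lines (illustrative only, not part of the commit; the name spans_from_lang_array is made up here) sketches how those ids could be mapped back to language codes via the module-level languages dict and grouped into contiguous spans:

from itertools import groupby

def spans_from_lang_array(text, lang_array):
    # Map each per-character label id back to its language code ("sme", "fin", ...),
    # leaving ids without an entry (e.g. the 'unk' fallback) as they are.
    codes = [languages.get(str(lang_id), str(lang_id)) for lang_id in lang_array]
    # Collapse runs of identical codes into (code, substring) spans.
    spans = []
    start = 0
    for code, run in groupby(codes):
        length = len(list(run))
        spans.append((code, text[start:start + length]))
        start += length
    return spans

# Usage, assuming `identifier` and `text` from the __main__ block above:
# for code, chunk in spans_from_lang_array(text, identifier.get_lang_array(text)):
#     print(code, repr(chunk))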