asuni committed on
Commit 4bfd07a · verified · 1 Parent(s): a56c1ff

Upload langid.py

Files changed (1)
  1. langid/langid.py +101 -0
langid/langid.py ADDED
@@ -0,0 +1,101 @@
+ import fasttext
+ import re
+ from collections import Counter
+ 
+ # Mapping from fastText label ids to ISO 639-3 language codes.
+ languages = {
+     "0": "sma",
+     "1": "sme",
+     "2": "smj",
+     "3": "fin",
+     "4": "est",
+     "5": "eng"
+ }
+ 
+ 
+ class WordLid:
+     """Character-level language identification based on a word-level fastText model."""
+ 
+     def __init__(self, model_path):
+         self.model = fasttext.load_model(model_path)
+         self.threshold = 0.5
+ 
+     def set_threshold(self, threshold):
+         """Set the margin by which a word's best language must beat the main language."""
+         self.threshold = threshold
+ 
+     def _clean_word(self, word):
+         # Lowercase and strip punctuation, digits and extra whitespace.
+         word = word.lower()
+         word = re.sub(r'[^\w\s]', '', word)
+         word = re.sub(r'\s+', ' ', word)
+         word = re.sub(r'\d', '', word)
+         return word.strip()
+ 
+     def _predict_all_languages(self, word):
+         # Return {label_id: probability} over all languages for a single word.
+         cleaned_word = self._clean_word(word)
+         labels, probabilities = self.model.predict(cleaned_word, k=-1)
+         # Debug output:
+         # print(word)
+         # for l, p in zip(labels, probabilities):
+         #     print(f'{languages[l.replace("__label__", "")]} {p:.4f}')
+         return {label.replace('__label__', ''): prob for label, prob in zip(labels, probabilities)}
+ 
+     def _get_main_language(self, text):
+         # The main language is the one predicted best for the most words.
+         language_counts = Counter()
+         for word in text.split():
+             cleaned = self._clean_word(word)
+             if not cleaned:
+                 continue
+             predictions = self._predict_all_languages(cleaned)
+             if predictions:
+                 language_counts[max(predictions, key=predictions.get)] += 1
+         return language_counts.most_common(1)[0][0] if language_counts else None
+ 
+     def get_lang_array(self, text):
+         # Return one language id per character of `text`.
+         main_language = self._get_main_language(text)
+         if main_language is None:
+             return ['unk'] * len(text)  # fallback when nothing could be predicted
+ 
+         lang_array = [main_language] * len(text)
+         word_start_index = 0
+ 
+         for word in text.split():
+             word_start_index = text.find(word, word_start_index)
+             cleaned_word = self._clean_word(word)
+             if not cleaned_word:
+                 word_start_index += len(word)
+                 continue
+ 
+             predictions = self._predict_all_languages(cleaned_word)
+             if not predictions:
+                 word_start_index += len(word)
+                 continue
+ 
+             best_word_lang = max(predictions, key=predictions.get)
+             main_lang_prob_for_word = predictions.get(main_language, 0.0)  # main language probability for this word
+             best_lang_prob = predictions[best_word_lang]
+ 
+             # Override the main language for this word only if its best language
+             # beats the main language's probability by at least self.threshold.
+             if best_lang_prob >= main_lang_prob_for_word + self.threshold:
+                 for i in range(len(word)):
+                     lang_array[word_start_index + i] = best_word_lang
+ 
+             word_start_index += len(word)
+         return [int(x) for x in lang_array]
+         # return lang_array
+ 
+ 
+ if __name__ == '__main__':
+     model_path = 'lang_id_model_q.bin'
+     identifier = WordLid(model_path)
+ 
+     test_texts = [
+         "Mumenvákki ođđasamos badji ii gávdno vuos sámegillii. Áigumuššan lea goit dubbet dan maiddái sámegielaide, lohká Yle Draama hoavda Jarmo Lampela."
+     ]
+ 
+     for text in test_texts:
+         lang_array = identifier.get_lang_array(text)
+         print(f"\nText: '{text}'")
+         print(f"Language Array: {lang_array}")
+         # for i in range(len(text)):
+         #     print(text[i], lang_array[i])
+         assert len(lang_array) == len(text), "Length mismatch!"
+ 
+     # Example of changing the threshold:
+     identifier.set_threshold(0.8)
+     lang_array = identifier.get_lang_array("Bonjour le monde!")
+     print(f"\nText: 'Bonjour le monde!' (with threshold 0.8)")
+     print(f"Language Array: {lang_array}")
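The ids in the returned array are the integer forms of the keys in the `languages` dict, so they can be mapped back to readable codes. A minimal sketch of doing that per character; the `from langid.langid import ...` path and the variable names are illustrative assumptions, not part of the uploaded file:

# Illustrative only: map WordLid.get_lang_array output back to language codes.
from langid.langid import WordLid, languages  # assumed import path for this repo layout

identifier = WordLid('lang_id_model_q.bin')   # model file used in the __main__ block above
text = "Bonjour le monde!"
ids = identifier.get_lang_array(text)         # one id (or 'unk') per character
codes = [languages.get(str(i), 'unk') for i in ids]
for ch, code in zip(text, codes):
    print(repr(ch), code)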