SivaMallikarjun commited on
Commit
76aeebf
·
verified ·
1 Parent(s): a9d48b2

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +52 -0
model.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import fasttext
4
+
5
+ class SimpleMultilingualClassifier(nn.Module):
6
+ def __init__(self, embedding_files, num_classes, embedding_dim=100):
7
+ super().__init__()
8
+ self.embedding_files = embedding_files
9
+ self.embedding_dim = embedding_dim
10
+ self.linear = nn.Linear(embedding_dim, num_classes)
11
+ self.language_models = {}
12
+ for lang, path in embedding_files.items():
13
+ self.language_models[lang] = fasttext.load_model(path)
14
+
15
+ def get_embedding(self, text, lang):
16
+ if lang in self.language_models:
17
+ return torch.tensor(self.language_models[lang].get_sentence_vector(text))
18
+ else:
19
+ raise ValueError(f"Language '{lang}' not supported.")
20
+
21
+ def forward(self, text, lang):
22
+ embedding = self.get_embedding(text, lang)
23
+ return self.linear(embedding)
24
+
25
+ def predict(self, text, lang, class_labels):
26
+ self.eval()
27
+ with torch.no_grad():
28
+ output = self.forward(text, lang).unsqueeze(0) # Add batch dimension
29
+ probabilities = torch.softmax(output, dim=-1)
30
+ predicted_class_index = torch.argmax(probabilities, dim=-1).item()
31
+ return class_labels[predicted_class_index]
32
+
33
+ # Example usage (you'd need to define your classes and supported languages)
34
+ if __name__ == '__main__':
35
+ embedding_files = {
36
+ 'en': 'fasttext_embeddings/cc.en.100.bin',
37
+ 'fr': 'fasttext_embeddings/cc.fr.100.bin'
38
+ }
39
+ num_classes = 3 # Example number of classes
40
+ class_labels = ["positive", "negative", "neutral"]
41
+ model = SimpleMultilingualClassifier(embedding_files, num_classes)
42
+
43
+ # Dummy prediction
44
+ text_en = "This is a great movie."
45
+ lang_en = 'en'
46
+ prediction_en = model.predict(text_en, lang_en, class_labels)
47
+ print(f"English Prediction: {prediction_en}")
48
+
49
+ text_fr = "C'est un film incroyable."
50
+ lang_fr = 'fr'
51
+ prediction_fr = model.predict(text_fr, lang_fr, class_labels)
52
+ print(f"French Prediction: {prediction_fr}")