koyu008 commited on
Commit
4e57ca3
·
verified ·
1 Parent(s): ad51919

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from langdetect import detect
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import (
7
+ DistilBertTokenizer, DistilBertModel,
8
+ AutoTokenizer, AutoModel
9
+ )
10
+
11
+ # ==== Model Classes ====
12
+
13
+ class ToxicBERT(nn.Module):
14
+ def __init__(self):
15
+ super().__init__()
16
+ self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
17
+ self.dropout = nn.Dropout(0.3)
18
+ self.classifier = nn.Linear(self.bert.config.hidden_size, 6)
19
+
20
+ def forward(self, input_ids, attention_mask):
21
+ output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
22
+ return self.classifier(self.dropout(output))
23
+
24
+ class HinglishToxicClassifier(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+ self.bert = AutoModel.from_pretrained("xlm-roberta-base")
28
+ hidden_size = self.bert.config.hidden_size
29
+
30
+ self.pool = lambda hidden: torch.cat([
31
+ hidden.mean(dim=1),
32
+ hidden.max(dim=1).values
33
+ ], dim=1)
34
+
35
+ self.bottleneck = nn.Sequential(
36
+ nn.Linear(2 * hidden_size, hidden_size),
37
+ nn.ReLU(),
38
+ nn.Dropout(0.2)
39
+ )
40
+ self.classifier = nn.Linear(hidden_size, 2)
41
+
42
+ def forward(self, input_ids, attention_mask):
43
+ hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
44
+ pooled = self.pool(hidden)
45
+ x = self.bottleneck(pooled)
46
+ return self.classifier(x)
47
+
48
+ # ==== Load Tokenizers ====
49
+ english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
50
+ hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
51
+
52
+ # ==== Load Models from Hugging Face Hub ====
53
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+
55
+ english_model = ToxicBERT()
56
+ eng_url = "https://huggingface.co/koyu008/English_Toxic_Classifier/resolve/main/bert_toxic_classifier.pt"
57
+ english_model.load_state_dict(torch.hub.load_state_dict_from_url(eng_url, map_location=device))
58
+ english_model.eval().to(device)
59
+
60
+ hinglish_model = HinglishToxicClassifier()
61
+ hin_url = "https://huggingface.co/koyu008/HInglish_comment_classifier/resolve/main/best_hinglish_model.pt"
62
+ hinglish_model.load_state_dict(torch.hub.load_state_dict_from_url(hin_url, map_location=device))
63
+ hinglish_model.eval().to(device)
64
+
65
+ # ==== FastAPI setup ====
66
+
67
+ app = FastAPI()
68
+
69
+ class InputText(BaseModel):
70
+ text: str
71
+
72
+ @app.post("/predict")
73
+ def predict(input: InputText):
74
+ text = input.text.strip()
75
+ if not text:
76
+ raise HTTPException(status_code=400, detail="Input text cannot be empty")
77
+
78
+ # Language detection
79
+ try:
80
+ lang = detect(text)
81
+ except:
82
+ lang = "und"
83
+
84
+ if lang == "en":
85
+ model = english_model
86
+ tokenizer = english_tokenizer
87
+ labels = ["toxic", "severe toxic", "obscene", "threat", "insult", "identity hate"]
88
+ else:
89
+ model = hinglish_model
90
+ tokenizer = hinglish_tokenizer
91
+ labels = ["not toxic", "toxic"]
92
+
93
+ # Tokenization
94
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
95
+ with torch.no_grad():
96
+ outputs = model(**inputs)
97
+ probs = torch.softmax(outputs, dim=1).squeeze().tolist()
98
+
99
+ response = {
100
+ "language": "english" if lang == "en" else "hinglish",
101
+ "prediction": {label: float(round(prob, 4)) for label, prob in zip(labels, probs)}
102
+ }
103
+ return response