File size: 3,937 Bytes
16b2ba7
2b470ab
 
 
16b2ba7
 
2b470ab
 
2d04c0e
2b470ab
16b2ba7
fe3311f
2b470ab
16b2ba7
 
 
2b470ab
16b2ba7
 
 
2b470ab
 
16b2ba7
2b470ab
 
 
16b2ba7
2b470ab
 
 
 
 
 
 
fe3311f
16b2ba7
2b470ab
 
 
16b2ba7
2b470ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16b2ba7
2b470ab
16b2ba7
2b470ab
 
 
16b2ba7
2b470ab
 
16b2ba7
 
 
 
 
 
 
f3eb85c
 
 
 
 
 
 
 
 
 
2b470ab
16b2ba7
a3af327
2b470ab
16b2ba7
2d04c0e
79819d8
16b2ba7
a3af327
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a3863a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os

import torch
import torch.nn as nn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import snapshot_download
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from pydantic import BaseModel
from transformers import DistilBertTokenizer, DistilBertModel, AutoModel, AutoTokenizer


# Device: use a CUDA GPU when torch reports one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fetch the fine-tuned checkpoint repos from the HF Hub. Each call returns a
# local directory path; the weight file names are joined onto it later.
english_repo = snapshot_download(repo_id="koyu008/English_Toxic_Classifier")
hinglish_repo = snapshot_download(repo_id="koyu008/HInglish_comment_classifier")

# Tokenizers for the two base encoders (DistilBERT for English,
# XLM-RoBERTa for Hinglish).
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")


# English Model
class ToxicBERT(nn.Module):
    """DistilBERT encoder with a dropout + linear head on the first token.

    ``forward`` returns 6 raw logits per input, one per English label.
    """

    def __init__(self):
        super().__init__()
        # Registration order (bert, dropout, classifier) and attribute names
        # define the state_dict keys that the saved checkpoint must match.
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        width = self.bert.config.hidden_size
        self.classifier = nn.Linear(width, 6)

    def forward(self, input_ids, attention_mask):
        """Encode the tokens and classify the position-0 hidden state."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        first_token = encoded.last_hidden_state[:, 0]
        regularized = self.dropout(first_token)
        return self.classifier(regularized)


# Hinglish Model
class HinglishToxicClassifier(nn.Module):
    """XLM-RoBERTa encoder with mean+max pooling and a 2-way linear head.

    Token states are pooled by concatenating their mean and max over
    dimension 1. The pooled vector goes through a Linear -> ReLU -> Dropout
    bottleneck back to the encoder width, then is projected to 2 logits.
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("xlm-roberta-base")
        hidden_size = self.bert.config.hidden_size
        # Attribute names and the Sequential layout must stay as-is: they
        # define the state_dict keys matched by load_state_dict.
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        self.classifier = nn.Linear(hidden_size, 2)

    @staticmethod
    def pool(hidden):
        """Concatenate mean- and max-pooling of ``hidden`` over dimension 1.

        This is a static method instead of a lambda stored on the instance,
        so the module stays picklable (e.g. ``torch.save(model)``).

        NOTE(review): assumes ``hidden`` is (batch, seq, hidden) — confirm.
        ``attention_mask`` is not applied here, so padded positions would
        contribute to both statistics if a batch contains padding. Kept
        unchanged because the trained weights presumably expect this pooling.
        """
        return torch.cat([hidden.mean(dim=1), hidden.max(dim=1).values], dim=1)

    def forward(self, input_ids, attention_mask):
        """Return the 2 classification logits for each input sequence."""
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = self.pool(hidden)
        x = self.bottleneck(pooled)
        return self.classifier(x)


# Instantiate both models on the selected device and load their fine-tuned
# weights, then switch to eval mode (disables dropout for inference).
#
# SECURITY: torch.load unpickles its input, and these checkpoints come from
# third-party Hub repos, so unpickling is restricted with weights_only=True.
# Both files are passed straight to load_state_dict, so they are expected to
# be plain state dicts of tensors, which weights_only loading supports.
english_model = ToxicBERT().to(device)
english_model.load_state_dict(
    torch.load(
        os.path.join(english_repo, "bert_toxic_classifier.pt"),
        map_location=device,
        weights_only=True,
    )
)
english_model.eval()

hinglish_model = HinglishToxicClassifier().to(device)
hinglish_model.load_state_dict(
    torch.load(
        os.path.join(hinglish_repo, "best_hinglish_model.pt"),
        map_location=device,
        weights_only=True,
    )
)
hinglish_model.eval()

# Labels, in the same order as each classifier's output logits.
english_labels = ['toxic', 'severe toxic', 'obscene', 'threat', 'insult', 'identity hate']
hinglish_labels = ['not toxic', 'toxic']

# FastAPI application serving the prediction API.
app = FastAPI()

# NOTE(review): the CORS spec disallows a wildcard origin on credentialed
# requests, yet allow_origins=["*"] is combined with allow_credentials=True
# here. The endpoints defined in this file read no cookies or auth headers,
# so consider allow_credentials=False or an explicit origin list. Confirm how
# Starlette's CORSMiddleware resolves this combination before relying on it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Or restrict to your frontend domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request body for POST /api/predict: the text to classify.
class TextIn(BaseModel):
    # Raw input text; no length or emptiness constraint is enforced here.
    text: str


def _classify(tokenizer, model, text, activation):
    """Run a single text through ``model`` and return its per-label scores.

    Args:
        tokenizer: tokenizer paired with ``model``; its output keys are passed
            to ``model`` as keyword arguments.
        model: classifier returning logits; assumed shape (1, num_labels) for
            a single input — both models in this file end in ``nn.Linear``.
        text: raw input string.
        activation: callable mapping the logits tensor to scores.

    Returns:
        list[float]: one score per label, in logit order.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        logits = model(**inputs)
        # Take the single batch row instead of squeeze(), so the result is
        # always a list even if a head had only one label.
        return activation(logits)[0].cpu().tolist()


@app.post("/api/predict")
def predict(data: TextIn):
    """Detect the language of ``data.text`` and score it for toxicity.

    Text detected as English ("en") is scored by the multi-label English model
    with independent sigmoid probabilities. Any other language, including text
    whose language cannot be detected, is scored by the Hinglish model with a
    softmax over its two classes.
    """
    text = data.text
    try:
        lang = detect(text)
    except LangDetectException:
        # langdetect could not determine a language; fall through to the
        # Hinglish model as before.
        lang = "unknown"

    if lang == "en":
        probs = _classify(english_tokenizer, english_model, text, torch.sigmoid)
        return {"language": "English", "predictions": dict(zip(english_labels, probs))}

    probs = _classify(
        hinglish_tokenizer,
        hinglish_model,
        text,
        lambda logits: torch.softmax(logits, dim=1),
    )
    return {"language": "Hinglish", "predictions": dict(zip(hinglish_labels, probs))}

# Root route: returns a fixed payload confirming the service is up.
@app.get("/")
def root():
    # Static response; touches no model or request state.
    return dict(message="API is running")