from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langdetect import detect, LangDetectException
import torch
import torch.nn as nn
from transformers import (
    DistilBertTokenizer, DistilBertModel,
    AutoTokenizer, AutoModel
)

# ==== Model Classes ====

class ToxicBERT(nn.Module):
    """DistilBERT encoder with a six-way head for the Jigsaw toxicity labels."""

    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 6)

    def forward(self, input_ids, attention_mask):
        # Use the hidden state of the first ([CLS]) token as the sequence representation.
        cls_state = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
        return self.classifier(self.dropout(cls_state))

class HinglishToxicClassifier(nn.Module):
    """XLM-RoBERTa encoder with mean+max pooling and a binary toxic/not-toxic head."""

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("xlm-roberta-base")
        hidden_size = self.bert.config.hidden_size

        self.bottleneck = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        self.classifier = nn.Linear(hidden_size, 2)

    @staticmethod
    def pool(hidden):
        # Concatenate mean- and max-pooled token states: (batch, 2 * hidden_size).
        return torch.cat([
            hidden.mean(dim=1),
            hidden.max(dim=1).values
        ], dim=1)

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = self.pool(hidden)
        x = self.bottleneck(pooled)
        return self.classifier(x)

# ==== Load Tokenizers ====
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

# ==== Load Models from Hugging Face Hub ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

english_model = ToxicBERT()
eng_url = "https://huggingface.co/koyu008/English_Toxic_Classifier/resolve/main/bert_toxic_classifier.pt"
english_model.load_state_dict(torch.hub.load_state_dict_from_url(eng_url, map_location=device))
english_model.eval().to(device)

hinglish_model = HinglishToxicClassifier()
hin_url = "https://huggingface.co/koyu008/HInglish_comment_classifier/resolve/main/best_hinglish_model.pt"
hinglish_model.load_state_dict(torch.hub.load_state_dict_from_url(hin_url, map_location=device))
hinglish_model.eval().to(device)

# ==== FastAPI setup ====

app = FastAPI()

class InputText(BaseModel):
    text: str

@app.post("/predict")
def predict(payload: InputText):
    text = payload.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Input text cannot be empty")

    # Language detection; on failure, fall back to the multilingual Hinglish model.
    try:
        lang = detect(text)
    except LangDetectException:
        lang = "und"

    if lang == "en":
        model = english_model
        tokenizer = english_tokenizer
        labels = ["toxic", "severe toxic", "obscene", "threat", "insult", "identity hate"]
        # The Jigsaw labels are not mutually exclusive, so score each one
        # independently (assumes the English head was trained with a
        # per-label objective, as is standard for this label set).
        multilabel = True
    else:
        model = hinglish_model
        tokenizer = hinglish_tokenizer
        labels = ["not toxic", "toxic"]
        multilabel = False

    # Tokenization and inference
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        logits = model(**inputs)
        # Per-label sigmoid for the multi-label English head; softmax over the
        # two mutually exclusive Hinglish classes.
        scores = torch.sigmoid(logits) if multilabel else torch.softmax(logits, dim=1)
        probs = scores.squeeze().tolist()

    return {
        "language": "english" if lang == "en" else "hinglish",
        "prediction": {label: round(prob, 4) for label, prob in zip(labels, probs)}
    }
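
# ==== Example usage (illustrative sketch) ====
# Serve with uvicorn (assuming this file is saved as main.py):
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# A minimal smoke test using FastAPI's TestClient; the sample sentences are
# illustrative assumptions, not cases the models are guaranteed to handle.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    client = TestClient(app)
    for sample in ["you are a wonderful person", "tum bahut ache ho"]:
        resp = client.post("/predict", json={"text": sample})
        print(sample, "->", resp.status_code, resp.json())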