# Spaces: Sleeping
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from langdetect import detect | |
import torch | |
import torch.nn as nn | |
from transformers import ( | |
DistilBertTokenizer, DistilBertModel, | |
AutoTokenizer, AutoModel | |
) | |
# ==== Model Classes ==== | |
class ToxicBERT(nn.Module):
    """DistilBERT encoder followed by dropout and a 6-way linear head.

    The encoder weights are initialised from ``distilbert-base-uncased``;
    the head maps the encoder's hidden size to six output logits.
    """

    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 6)

    def forward(self, input_ids, attention_mask):
        """Return the six raw logits for each input sequence.

        Args:
            input_ids: Token ids, passed straight through to the encoder.
            attention_mask: Attention mask, passed straight through to the
                encoder.

        Returns:
            The output of the linear head applied to the (dropout-regularised)
            hidden state at index 0 along dimension 1 of
            ``last_hidden_state``.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Index 0 on dim 1 selects the first position of every sequence.
        first_position = encoded.last_hidden_state[:, 0]
        regularised = self.dropout(first_position)
        return self.classifier(regularised)
class HinglishToxicClassifier(nn.Module):
    """XLM-RoBERTa encoder with mean+max pooling and a 2-way linear head.

    The encoder weights are initialised from ``xlm-roberta-base``. The
    token-level hidden states are pooled by concatenating their mean and
    their maximum over dimension 1, squeezed through a
    ``Linear -> ReLU -> Dropout`` bottleneck back to the encoder's hidden
    size, and finally projected to two logits.
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("xlm-roberta-base")
        hidden_size = self.bert.config.hidden_size
        # The pooled vector is [mean ; max], hence twice the hidden size.
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Linear(hidden_size, 2)

    @staticmethod
    def pool(hidden):
        """Concatenate the mean and the max of ``hidden`` taken over dim 1.

        Defined as a method rather than a lambda stored on the instance so
        that the module stays picklable and the pooling can be used without
        constructing the encoder. It holds no parameters, so the state dict
        is unaffected.

        NOTE(review): the attention mask is not consulted, so any padded
        positions contribute to both statistics -- confirm this matches how
        the checkpoint was trained before changing it.

        Args:
            hidden: Tensor with at least two dimensions; reduced over dim 1.

        Returns:
            Tensor whose dim 1 is the concatenation of the mean and the max
            reductions.
        """
        mean_pooled = hidden.mean(dim=1)
        max_pooled = hidden.max(dim=1).values
        return torch.cat([mean_pooled, max_pooled], dim=1)

    def forward(self, input_ids, attention_mask):
        """Return the two raw logits for each input sequence.

        Args:
            input_ids: Token ids, passed straight through to the encoder.
            attention_mask: Attention mask, passed straight through to the
                encoder (it is not used by the pooling step).

        Returns:
            The output of the linear head applied to the bottlenecked,
            pooled encoder hidden states.
        """
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = self.pool(hidden)
        x = self.bottleneck(pooled)
        return self.classifier(x)
# ==== Load Tokenizers ====
# One tokenizer per model, each matching the encoder its classifier wraps.
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

# ==== Load Models from Hugging Face Hub ====
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the checkpoints below are fetched over HTTP from a mutable
# ``main`` revision and deserialised by torch; confirm the source is trusted
# and consider pinning a revision.
english_model = ToxicBERT()
eng_url = "https://huggingface.co/koyu008/English_Toxic_Classifier/resolve/main/bert_toxic_classifier.pt"
english_model.load_state_dict(
    torch.hub.load_state_dict_from_url(eng_url, map_location=device)
)
# Inference only: switch to eval mode, then move the weights to the device.
english_model.eval()
english_model.to(device)

hinglish_model = HinglishToxicClassifier()
hin_url = "https://huggingface.co/koyu008/HInglish_comment_classifier/resolve/main/best_hinglish_model.pt"
hinglish_model.load_state_dict(
    torch.hub.load_state_dict_from_url(hin_url, map_location=device)
)
hinglish_model.eval()
hinglish_model.to(device)

# ==== FastAPI setup ====
app = FastAPI()
class InputText(BaseModel):
    # Input schema consumed by `predict`: a single string field holding the
    # text to classify. (Kept as a comment rather than a docstring so the
    # generated model schema is unchanged.)
    text: str
# NOTE(review): no route decorator is visible on this function; confirm it is
# registered on `app` (e.g. with `@app.post(...)`), otherwise it is never
# served.
def predict(input: InputText):
    """Classify ``input.text`` with the model selected by its detected language.

    Text detected as English (``"en"``) goes to the English model and is
    scored against six labels; everything else -- including text whose
    language could not be detected -- goes to the Hinglish model and is
    scored against two labels.

    Args:
        input: Payload whose ``text`` field holds the text to classify.

    Returns:
        A dict with ``"language"`` (``"english"`` or ``"hinglish"``) and
        ``"prediction"``, a mapping from label to probability rounded to four
        decimal places.

    Raises:
        HTTPException: With status 400 when the text is empty after stripping
            surrounding whitespace.
    """
    text = input.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Input text cannot be empty")

    # Language detection is best-effort: any detector failure falls back to
    # "und", which routes to the Hinglish model below. `except Exception`
    # (rather than a bare `except`) lets KeyboardInterrupt/SystemExit through.
    try:
        lang = detect(text)
    except Exception:
        lang = "und"

    if lang == "en":
        model = english_model
        tokenizer = english_tokenizer
        labels = ["toxic", "severe toxic", "obscene", "threat", "insult", "identity hate"]
    else:
        model = hinglish_model
        tokenizer = hinglish_tokenizer
        labels = ["not toxic", "toxic"]

    # Tokenization
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        # NOTE(review): softmax makes the scores sum to 1 across labels. The
        # six English labels look like independent (multi-label) categories,
        # for which a per-label sigmoid is the usual choice -- confirm against
        # how the checkpoint was trained.
        # squeeze(0) removes only the batch axis of the single-text batch.
        probs = torch.softmax(outputs, dim=1).squeeze(0).tolist()

    return {
        "language": "english" if lang == "en" else "hinglish",
        "prediction": {label: round(prob, 4) for label, prob in zip(labels, probs)},
    }