koyu008's picture
Upload app.py
4e57ca3 verified
raw
history blame
3.6 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langdetect import detect
import torch
import torch.nn as nn
from transformers import (
DistilBertTokenizer, DistilBertModel,
AutoTokenizer, AutoModel
)
# ==== Model Classes ====
class ToxicBERT(nn.Module):
    """DistilBERT encoder followed by dropout and a six-output linear head.

    Only the hidden state at sequence position 0 of each example is fed to
    the head. Attribute names (``bert``, ``dropout``, ``classifier``) define
    the state-dict keys and must stay as they are.
    """

    def __init__(self):
        super().__init__()
        encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.bert = encoder
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(encoder.config.hidden_size, 6)

    def forward(self, input_ids, attention_mask):
        """Return the head's raw logits for the given tokenized batch."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Keep the vector at position 0 of every sequence.
        first_token = encoded.last_hidden_state[:, 0]
        regularized = self.dropout(first_token)
        return self.classifier(regularized)
class HinglishToxicClassifier(nn.Module):
    """XLM-RoBERTa encoder with mean+max pooling and a two-output head.

    The encoder's per-token hidden states are pooled over the sequence axis
    (mean and max, concatenated), reduced back to ``hidden_size`` by a
    Linear/ReLU/Dropout bottleneck, and projected to two logits.

    Attribute names (``bert``, ``bottleneck``, ``classifier``) define the
    state-dict keys and must stay as they are.
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("xlm-roberta-base")
        hidden_size = self.bert.config.hidden_size
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Linear(hidden_size, 2)

    @staticmethod
    def pool(hidden):
        """Concatenate the mean and the max of ``hidden`` taken over dim 1.

        Defined as a method rather than a lambda stored on the instance: a
        lambda attribute makes the whole module unpicklable.

        NOTE(review): the attention mask is not applied, so every position
        along dim 1 (padding included, if any) contributes. This is kept
        as-is because masking would change the outputs for padded batches.
        """
        return torch.cat([hidden.mean(dim=1), hidden.max(dim=1).values], dim=1)

    def forward(self, input_ids, attention_mask):
        """Return the two raw logits for each example in the batch."""
        hidden = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        pooled = self.pool(hidden)
        x = self.bottleneck(pooled)
        return self.classifier(x)
# ==== Load Tokenizers ====
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

# ==== Load Models from Hugging Face Hub ====
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def _restore_for_inference(model, checkpoint_url):
    """Load the checkpoint at ``checkpoint_url`` into ``model`` and ready it.

    The state dict is fetched with tensors mapped to ``device``, loaded into
    the model, and the model is switched to eval mode and moved to
    ``device``. Returns the same model object.
    """
    weights = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=device)
    model.load_state_dict(weights)
    model.eval()
    model.to(device)
    return model


eng_url = "https://huggingface.co/koyu008/English_Toxic_Classifier/resolve/main/bert_toxic_classifier.pt"
english_model = _restore_for_inference(ToxicBERT(), eng_url)

hin_url = "https://huggingface.co/koyu008/HInglish_comment_classifier/resolve/main/best_hinglish_model.pt"
hinglish_model = _restore_for_inference(HinglishToxicClassifier(), hin_url)
# ==== FastAPI setup ====
# Application object; the /predict handler registers itself on it through
# the @app.post decorator.
app = FastAPI()
# Request body schema for the /predict endpoint: one required string field.
class InputText(BaseModel):
    # Text to classify; the handler strips surrounding whitespace and rejects
    # the request if nothing remains.
    text: str
@app.post("/predict")
def predict(input: InputText):
    """Score a piece of text for toxicity.

    The text's language is detected first. Text detected as English is
    scored by the six-label English model; everything else, including text
    whose language cannot be detected, is scored by the two-class Hinglish
    model.

    Args:
        input: Request body carrying the text to classify. The name shadows
            the ``input`` builtin but is kept so the handler's signature is
            unchanged.

    Returns:
        A dict with ``"language"`` (``"english"`` or ``"hinglish"``) and
        ``"prediction"``, a mapping from label to a score rounded to four
        decimal places.

    Raises:
        HTTPException: Status 400 when the text is empty or whitespace only.
    """
    text = input.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Input text cannot be empty")

    # Language detection is best effort: if the detector raises, fall through
    # to the non-English path. ``Exception`` rather than a bare ``except`` so
    # KeyboardInterrupt/SystemExit are not swallowed.
    try:
        lang = detect(text)
    except Exception:
        lang = "und"

    is_english = lang == "en"
    if is_english:
        model = english_model
        tokenizer = english_tokenizer
        labels = ["toxic", "severe toxic", "obscene", "threat", "insult", "identity hate"]
    else:
        model = hinglish_model
        tokenizer = hinglish_tokenizer
        labels = ["not toxic", "toxic"]

    # Tokenization
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)

    with torch.no_grad():
        logits = model(**inputs)

    if is_english:
        # The six English labels are not mutually exclusive and there is no
        # "not toxic" class, so each label gets an independent sigmoid score.
        # A softmax would force the scores to sum to 1 and make benign text
        # look toxic.
        # NOTE(review): assumes the English head was trained as multi-label
        # (e.g. BCE-with-logits) - confirm against the training code.
        scores = torch.sigmoid(logits)
    else:
        # Two mutually exclusive classes: normalise across them.
        scores = torch.softmax(logits, dim=1)

    # Exactly one text was tokenized, so row 0 is the only row.
    probs = scores[0].tolist()

    return {
        "language": "english" if is_english else "hinglish",
        "prediction": {label: round(prob, 4) for label, prob in zip(labels, probs)},
    }