# Extracted from a Hugging Face Space (status at capture: Sleeping).
# File size: 3,596 bytes; revision 4e57ca3.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langdetect import detect
import torch
import torch.nn as nn
from transformers import (
DistilBertTokenizer, DistilBertModel,
AutoTokenizer, AutoModel
)
# ==== Model Classes ====
class ToxicBERT(nn.Module):
    """DistilBERT encoder followed by dropout and a 6-unit linear head.

    The head reads the encoder's hidden state at sequence position 0 and
    returns raw (un-normalized) logits.
    """

    def __init__(self):
        super().__init__()
        # Attribute names become state_dict key prefixes, so they must stay
        # exactly "bert", "dropout" and "classifier".
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        encoder_width = self.bert.config.hidden_size
        self.classifier = nn.Linear(encoder_width, 6)

    def forward(self, input_ids, attention_mask):
        """Return the 6 classifier logits for each input sequence."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Index 0 along the second axis: the first token's hidden state.
        first_token = encoded.last_hidden_state[:, 0]
        regularized = self.dropout(first_token)
        return self.classifier(regularized)
class HinglishToxicClassifier(nn.Module):
    """XLM-RoBERTa encoder with mean+max pooling and a 2-unit linear head.

    Pipeline: encoder hidden states -> concatenated mean/max pooling over
    the sequence axis -> Linear/ReLU/Dropout bottleneck -> 2 raw logits.
    """

    def __init__(self):
        super().__init__()
        # Attribute names become state_dict key prefixes; keep them stable.
        self.bert = AutoModel.from_pretrained("xlm-roberta-base")
        hidden_size = self.bert.config.hidden_size
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Linear(hidden_size, 2)

    @staticmethod
    def pool(hidden):
        """Concatenate the mean and the max of ``hidden`` taken over dim 1.

        Defined on the class (rather than as a lambda assigned in
        ``__init__``) so instances hold no un-picklable closure; it is still
        called as ``self.pool(hidden)``.

        NOTE(review): every position along dim 1 contributes, including any
        padding positions; the attention mask is not applied here. Left
        unchanged because the saved weights may have been trained with this
        exact pooling -- confirm before altering.
        """
        return torch.cat(
            [hidden.mean(dim=1), hidden.max(dim=1).values],
            dim=1,
        )

    def forward(self, input_ids, attention_mask):
        """Return the 2 classifier logits for each input sequence."""
        hidden = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        pooled = self.pool(hidden)
        x = self.bottleneck(pooled)
        return self.classifier(x)
# ==== Load Tokenizers ====
# One tokenizer per model, each matching its encoder's pretrained checkpoint.
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

# ==== Load Models from Hugging Face Hub ====
# Everything runs on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# English model: build the architecture, fetch the fine-tuned weights,
# then switch to inference mode on the target device.
eng_url = "https://huggingface.co/koyu008/English_Toxic_Classifier/resolve/main/bert_toxic_classifier.pt"
english_model = ToxicBERT()
_english_weights = torch.hub.load_state_dict_from_url(eng_url, map_location=device)
english_model.load_state_dict(_english_weights)
english_model.eval()
english_model.to(device)

# Hinglish model: same steps with its own checkpoint.
hin_url = "https://huggingface.co/koyu008/HInglish_comment_classifier/resolve/main/best_hinglish_model.pt"
hinglish_model = HinglishToxicClassifier()
_hinglish_weights = torch.hub.load_state_dict_from_url(hin_url, map_location=device)
hinglish_model.load_state_dict(_hinglish_weights)
hinglish_model.eval()
hinglish_model.to(device)
# ==== FastAPI setup ====
app = FastAPI()


class InputText(BaseModel):
    """Request body for the /predict endpoint."""

    # Raw text to classify.
    text: str
@app.post("/predict")
def predict(input: InputText):
    """Classify the toxicity of one piece of text.

    The text is routed to the English model when language detection returns
    "en"; every other result (including a detection failure) is routed to
    the Hinglish model.

    Args:
        input: Request body carrying the text to classify. The parameter
            name shadows the ``input`` builtin but is kept so the endpoint
            signature stays unchanged.

    Returns:
        A dict with "language" ("english" or "hinglish") and "prediction",
        a mapping from each label of the chosen model to its softmax
        probability rounded to 4 decimal places.

    Raises:
        HTTPException: 400 when the text is empty after stripping whitespace.
    """
    text = input.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Input text cannot be empty")

    # Language detection is best-effort: on any detection error fall back to
    # "und", which takes the non-English path below. `except Exception`
    # (not a bare `except`) lets KeyboardInterrupt/SystemExit propagate.
    try:
        lang = detect(text)
    except Exception:
        lang = "und"

    is_english = lang == "en"
    if is_english:
        model = english_model
        tokenizer = english_tokenizer
        labels = ["toxic", "severe toxic", "obscene", "threat", "insult", "identity hate"]
    else:
        model = hinglish_model
        tokenizer = hinglish_tokenizer
        labels = ["not toxic", "toxic"]

    # Tokenization
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)

    with torch.no_grad():
        logits = model(**inputs)

    # NOTE(review): softmax makes the scores sum to 1 across labels. The six
    # English labels look like independent (multi-label) categories, for
    # which a per-label sigmoid would be usual -- confirm against how the
    # checkpoint was trained before changing.
    # squeeze(0) drops only the batch axis (a single text is scored per call).
    probs = torch.softmax(logits, dim=1).squeeze(0).tolist()

    return {
        "language": "english" if is_english else "hinglish",
        "prediction": {label: round(prob, 4) for label, prob in zip(labels, probs)},
    }
|