# Spaces: Sleeping
import logging
import os

import torch
import torch.nn as nn
from fastapi import FastAPI, Request
from huggingface_hub import snapshot_download
from langdetect import detect
from pydantic import BaseModel
from transformers import DistilBertTokenizer, DistilBertModel, AutoModel, AutoTokenizer
# Device: use the GPU when CUDA is usable, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download model repos from HF Hub.
# The returned values are joined with the weight file names further down,
# so they are used as local directory paths.
ENGLISH_REPO_ID = "koyu008/English_Toxic_Classifier"
HINGLISH_REPO_ID = "koyu008/HInglish_comment_classifier"
english_repo = snapshot_download(ENGLISH_REPO_ID)
hinglish_repo = snapshot_download(HINGLISH_REPO_ID)

# Tokenizers are loaded from the base checkpoints by name, not from the
# fine-tuned repositories downloaded above.
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
# English Model
class ToxicBERT(nn.Module):
    """DistilBERT encoder followed by dropout and a linear classification head.

    The head emits one raw logit per label; no activation is applied here
    (``predict`` in this file applies a sigmoid to the output).

    Args:
        num_labels: Number of output logits. Defaults to 6, the size of the
            head this file loads weights for.
        dropout: Dropout probability applied before the head. Defaults to 0.3.
    """

    def __init__(self, num_labels: int = 6, dropout: float = 0.3) -> None:
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the classification logits for a tokenized batch.

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Attention mask passed straight to the encoder.

        Returns:
            The output of the linear head, one row per input row.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Keep only position 0 along the second axis of the hidden states
        # (presumably the sequence axis, i.e. the leading special token).
        first_position = encoded.last_hidden_state[:, 0]
        return self.classifier(self.dropout(first_position))
# Hinglish Model
class HinglishToxicClassifier(nn.Module):
    """XLM-RoBERTa encoder with mean+max pooling, a bottleneck and a linear head.

    The head emits raw logits; no activation is applied here (``predict`` in
    this file applies a softmax over the second axis).

    Args:
        num_labels: Number of output logits. Defaults to 2.
        dropout: Dropout probability inside the bottleneck. Defaults to 0.2.
    """

    def __init__(self, num_labels: int = 2, dropout: float = 0.2) -> None:
        super().__init__()
        self.bert = AutoModel.from_pretrained("xlm-roberta-base")
        hidden_size = self.bert.config.hidden_size
        # The pooled vector is the concatenation of two hidden-size vectors,
        # hence the 2 * hidden_size input width.
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.classifier = nn.Linear(hidden_size, num_labels)

    @staticmethod
    def pool(hidden: torch.Tensor) -> torch.Tensor:
        """Concatenate the mean and the max of ``hidden`` taken over axis 1.

        Defined as a method instead of a lambda stored on the instance so the
        module stays picklable; it holds no parameters, so the state dict is
        unaffected.

        NOTE(review): the attention mask is not consulted, so every position
        along axis 1 contributes to both statistics. Left as is so the loaded
        weights see the same features as before.
        """
        return torch.cat([hidden.mean(dim=1), hidden.max(dim=1).values], dim=1)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the classification logits for a tokenized batch.

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Attention mask passed straight to the encoder.

        Returns:
            The output of the linear head, one row per input row.
        """
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        features = self.bottleneck(self.pool(hidden))
        return self.classifier(features)
# Instantiate and load models
# NOTE(review): torch.load unpickles the checkpoint files fetched from the
# Hub. Only load repositories you trust; consider passing weights_only=True
# once it is confirmed that the installed torch version supports it and the
# files hold plain state dicts.
english_weights_path = os.path.join(english_repo, "bert_toxic_classifier.pt")
english_model = ToxicBERT().to(device)
english_model.load_state_dict(torch.load(english_weights_path, map_location=device))
english_model.eval()  # inference mode: disables dropout

hinglish_weights_path = os.path.join(hinglish_repo, "best_hinglish_model.pt")
hinglish_model = HinglishToxicClassifier().to(device)
hinglish_model.load_state_dict(torch.load(hinglish_weights_path, map_location=device))
hinglish_model.eval()  # inference mode: disables dropout

# Labels, zipped position by position with each model's outputs in predict.
english_labels = ['toxic', 'severe toxic', 'obscene', 'threat', 'insult', 'identity hate']
hinglish_labels = ['not toxic', 'toxic']
# FastAPI
app = FastAPI()

from fastapi.middleware.cors import CORSMiddleware

# NOTE(review): wildcard origins are combined with allow_credentials=True.
# FastAPI's CORS documentation advises listing explicit origins when
# credentials are needed; confirm whether credentials are required at all.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Or restrict to your frontend domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class TextIn(BaseModel):
    """Input payload for ``predict``: a single piece of text to classify."""

    # Raw text; predict reads it as data.text.
    text: str
_logger = logging.getLogger(__name__)


def _detect_language(text: str) -> str:
    """Return the code reported by ``detect`` for ``text``, or "unknown" on failure."""
    try:
        return detect(text)
    except Exception as exc:
        # Best-effort: a failed detection must not fail the request, but it
        # should not vanish silently either. The text itself is not logged.
        _logger.warning("Language detection failed (%s); using the fallback branch", exc)
        return "unknown"


def _class_probabilities(text: str, tokenizer, model, activation) -> list:
    """Tokenize ``text``, run ``model`` without gradients and return activated scores."""
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        logits = model(**inputs)
    # squeeze() drops the size-1 batch axis left by the single input text.
    return activation(logits).squeeze().cpu().tolist()


def predict(data: TextIn) -> dict:
    """Classify ``data.text`` with the English or the Hinglish model.

    Text detected as "en" goes to the English model and is scored with a
    sigmoid per label. Everything else, including text whose language could
    not be detected, goes to the Hinglish model and is scored with a softmax.

    NOTE(review): no route decorator is visible on this function in this
    view; confirm it is registered on ``app``.

    Args:
        data: Payload whose ``text`` attribute is the text to classify.

    Returns:
        A dict with "language" ("English" or "Hinglish") and "predictions",
        which maps each label of the chosen model to its score.
    """
    text = data.text
    if _detect_language(text) == "en":
        probs = _class_probabilities(text, english_tokenizer, english_model, torch.sigmoid)
        return {"language": "English", "predictions": dict(zip(english_labels, probs))}

    probs = _class_probabilities(
        text,
        hinglish_tokenizer,
        hinglish_model,
        lambda logits: torch.softmax(logits, dim=1),
    )
    return {"language": "Hinglish", "predictions": dict(zip(hinglish_labels, probs))}
def root() -> dict:
    """Return a fixed status message showing the API process is up.

    NOTE(review): no route decorator is visible on this function in this
    view; confirm it is registered on ``app``.
    """
    return {"message": "API is running"}