"""FastAPI service that scores comment toxicity.

Text that ``langdetect`` identifies as English is scored by a DistilBERT model
with six toxicity labels. All other text, including text whose language cannot
be detected, is routed to an XLM-RoBERTa binary (not toxic / toxic) classifier.
Both models' weights are downloaded from the Hugging Face Hub at import time.
"""

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from langdetect import DetectorFactory, detect
from langdetect.lang_detect_exception import LangDetectException
from pydantic import BaseModel
from transformers import (
    AutoModel,
    AutoTokenizer,
    DistilBertModel,
    DistilBertTokenizer,
)

# langdetect samples internally and is non-deterministic by default; fixing the
# seed guarantees the same text is always routed to the same model.
DetectorFactory.seed = 0

ENGLISH_LABELS = ["toxic", "severe toxic", "obscene", "threat", "insult", "identity hate"]
HINGLISH_LABELS = ["not toxic", "toxic"]


# ==== Model Classes ====
class ToxicBERT(nn.Module):
    """DistilBERT encoder followed by a 6-output linear head on the first token."""

    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 6)

    def forward(self, input_ids, attention_mask):
        """Return raw logits, one column per label in ``ENGLISH_LABELS``."""
        # The first position of the last hidden state is used as the sequence summary.
        output = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]
        return self.classifier(self.dropout(output))


class HinglishToxicClassifier(nn.Module):
    """XLM-RoBERTa encoder with mean+max pooling, a bottleneck, and a 2-way head."""

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("xlm-roberta-base")
        hidden_size = self.bert.config.hidden_size
        # Submodule names and Sequential order must match the checkpoint's
        # state_dict keys (bottleneck.0 is the Linear layer).
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Linear(hidden_size, 2)

    @staticmethod
    def pool(hidden):
        """Concatenate mean- and max-pooling of ``hidden`` over dim 1."""
        # NOTE(review): padding positions are not masked out. This matches how
        # the checkpoint was presumably trained. For the single-text requests
        # this service makes, no padding is added, so it has no effect here.
        return torch.cat([hidden.mean(dim=1), hidden.max(dim=1).values], dim=1)

    def forward(self, input_ids, attention_mask):
        """Return raw logits, one column per label in ``HINGLISH_LABELS``."""
        hidden = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        return self.classifier(self.bottleneck(self.pool(hidden)))


# ==== Load Tokenizers ====
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

# ==== Load Models from Hugging Face Hub ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

eng_url = "https://huggingface.co/koyu008/English_Toxic_Classifier/resolve/main/bert_toxic_classifier.pt"
hin_url = "https://huggingface.co/koyu008/HInglish_comment_classifier/resolve/main/best_hinglish_model.pt"


def _load_model(model, url):
    """Download a state_dict from ``url`` into ``model``; return it in eval mode on ``device``."""
    state_dict = torch.hub.load_state_dict_from_url(url, map_location=device)
    model.load_state_dict(state_dict)
    return model.eval().to(device)


english_model = _load_model(ToxicBERT(), eng_url)
hinglish_model = _load_model(HinglishToxicClassifier(), hin_url)

# ==== FastAPI setup ====
app = FastAPI()


class InputText(BaseModel):
    """Request body for ``/predict``."""

    text: str


def _detect_language(text):
    """Return langdetect's language code for ``text``, or ``"und"`` if detection fails."""
    try:
        return detect(text)
    except LangDetectException:
        # langdetect raises this when it finds no usable language features
        # (e.g. digits or emoji only). Fall back to the non-English model.
        return "und"


@app.post("/predict")
def predict(input: InputText):
    """Score ``input.text`` for toxicity.

    Returns:
        ``{"language": "english" | "hinglish", "prediction": {label: score}}``
        with scores rounded to 4 decimals.

    Raises:
        HTTPException: 400 if the text is empty after stripping whitespace.
    """
    text = input.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Input text cannot be empty")

    is_english = _detect_language(text) == "en"
    if is_english:
        model, tokenizer, labels = english_model, english_tokenizer, ENGLISH_LABELS
    else:
        model, tokenizer, labels = hinglish_model, hinglish_tokenizer, HINGLISH_LABELS

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        # Pass only the arguments forward() accepts rather than relying on the
        # tokenizer's output keys matching them exactly.
        logits = model(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )

    if is_english:
        # The English labels are independent (multi-label) and include no
        # "clean" class. Softmax would force them to sum to 1, so a benign
        # comment could never be reported as non-toxic. Score each label
        # independently instead.
        # NOTE(review): assumes the head was trained with a per-label binary
        # objective (e.g. BCEWithLogitsLoss). Confirm against the training code.
        probs = torch.sigmoid(logits)[0].tolist()
    else:
        # Two mutually exclusive classes: softmax gives P(not toxic), P(toxic).
        probs = torch.softmax(logits, dim=1)[0].tolist()

    return {
        "language": "english" if is_english else "hinglish",
        "prediction": {label: round(prob, 4) for label, prob in zip(labels, probs)},
    }