"""FastAPI service that scores comment toxicity.

Text that ``langdetect`` identifies as English is scored by a DistilBERT-based
six-label classifier; every other input is routed to a Hinglish classifier.
Model repositories are downloaded from the Hugging Face Hub at import time.
"""

import os
from typing import List

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from huggingface_hub import snapshot_download
from langdetect import DetectorFactory, detect
from langdetect.lang_detect_exception import LangDetectException
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer, DistilBertModel, DistilBertTokenizer

# ----------------------------
# App and device
# ----------------------------
app = FastAPI()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# langdetect is non-deterministic unless seeded; a fixed seed makes the
# English/Hinglish routing (and hence the response) reproducible per input.
DetectorFactory.seed = 0

# Create a local, writable cache directory for Hub downloads.
hf_cache_dir = "./hf_cache"
os.makedirs(hf_cache_dir, exist_ok=True)
# NOTE: transformers resolves its default cache location when it is imported
# (already done above), so this variable most likely does not redirect
# downloads in this process. Every download below passes ``cache_dir`` itself.
os.environ["TRANSFORMERS_CACHE"] = hf_cache_dir

# Download model repositories to a local path.
english_path = snapshot_download("koyu008/English_Toxic_Classifier", cache_dir=hf_cache_dir)
hinglish_path = snapshot_download("koyu008/Hinglish_comment_classifier", cache_dir=hf_cache_dir)

ENGLISH_CLASSES = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
HINGLISH_CLASSES = ["toxic", "non-toxic"]


# ----------------------------
# Model classes
# ----------------------------
class ToxicBERT(nn.Module):
    """DistilBERT encoder with a linear head on the first-token embedding.

    Produces six logits, one per entry of ``ENGLISH_CLASSES``.
    """

    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(english_path)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 6)

    def forward(self, input_ids, attention_mask):
        # Hidden state at position 0 (the tokenizer's leading [CLS] token).
        output = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]
        return self.classifier(self.dropout(output))


class HinglishToxicClassifier(nn.Module):
    """Encoder + mean/max pooling + bottleneck MLP, producing two logits.

    Logit order follows ``HINGLISH_CLASSES``.
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained(hinglish_path)
        hidden_size = self.bert.config.hidden_size
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Linear(hidden_size, 2)

    @staticmethod
    def pool(hidden):
        """Concatenate mean- and max-pooling over dim 1 (the token axis).

        Padding positions are not masked out. That is harmless for the single,
        unpadded sequence the endpoint sends per request.
        """
        return torch.cat([hidden.mean(dim=1), hidden.max(dim=1).values], dim=1)

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = self.pool(hidden)
        x = self.bottleneck(pooled)
        return self.classifier(x)


# ----------------------------
# Load Models & Tokenizers
# ----------------------------
# weights_only=True: these files are plain state dicts, so refuse arbitrary pickles.
english_model = ToxicBERT().to(device)
english_model.load_state_dict(
    torch.load("bert_toxic_classifier.pt", map_location=device, weights_only=True)
)
english_model.eval()
english_tokenizer = DistilBertTokenizer.from_pretrained(
    "distilbert-base-uncased", cache_dir=hf_cache_dir
)

hinglish_model = HinglishToxicClassifier().to(device)
hinglish_model.load_state_dict(
    torch.load("best_hinglish_model.pt", map_location=device, weights_only=True)
)
hinglish_model.eval()
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", cache_dir=hf_cache_dir)


# ----------------------------
# API
# ----------------------------
class InputText(BaseModel):
    text: str


def _softmax_scores(model: nn.Module, tokenizer, text: str) -> List[float]:
    """Tokenize *text*, run *model* without gradients, return per-class softmax scores."""
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        logits = model(
            input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"]
        )
    # Batch of one: take row 0 as a plain list of Python floats.
    return torch.softmax(logits, dim=1)[0].cpu().tolist()


@app.post("/predict")
def predict(input: InputText):
    """Route *input.text* to the English or Hinglish model and return its scores.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs the blocking
    tokenization and inference in its worker threadpool instead of stalling the
    event loop.

    Returns:
        ``{"language", "classes", "probabilities"}``, where ``probabilities[i]``
        is the score for ``classes[i]``.

    Raises:
        HTTPException: 422 if the text is empty or whitespace-only.
    """
    text = input.text
    if not text.strip():
        raise HTTPException(status_code=422, detail="text must not be empty")

    try:
        lang = detect(text)
    except LangDetectException:
        # No detectable features (e.g. digits/punctuation only). Treat it as
        # non-English, which is the routing default below.
        lang = None

    if lang == "en":
        # NOTE(review): these six labels look multi-label; if the model was
        # trained with a per-label sigmoid (BCE loss), softmax is the wrong
        # activation here — confirm against the training code.
        return {
            "language": "english",
            "classes": ENGLISH_CLASSES,
            "probabilities": _softmax_scores(english_model, english_tokenizer, text),
        }

    return {
        "language": "hinglish",
        "classes": HINGLISH_CLASSES,
        "probabilities": _softmax_scores(hinglish_model, hinglish_tokenizer, text),
    }