# Hugging Face Space "koyu008" — app.py (commit fe3311f, 3.76 kB).
import os

import torch
import torch.nn as nn
from fastapi import FastAPI
from huggingface_hub import snapshot_download
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer, DistilBertModel, DistilBertTokenizer
# App and device: prefer CUDA when available, fall back to CPU otherwise.
app = FastAPI()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use a writable local cache directory (Space containers restrict the default
# HF cache path).
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favour of HF_HOME — confirm against the pinned version.
hf_cache_dir = "./hf_cache"
os.makedirs(hf_cache_dir, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = hf_cache_dir
# Download both model repositories at import time (blocking network I/O);
# the returned local paths are consumed by the model classes below.
english_path = snapshot_download("koyu008/English_Toxic_Classifier", cache_dir=hf_cache_dir)
hinglish_path = snapshot_download("koyu008/Hinglish_comment_classifier", cache_dir=hf_cache_dir)
# ----------------------------
# Model classes
# ----------------------------
class ToxicBERT(nn.Module):
    """DistilBERT encoder with a 6-way linear head for English toxicity labels.

    Logits are raised over: toxic, severe_toxic, obscene, threat, insult,
    identity_hate (softmax is applied by the caller).
    """

    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(english_path)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 6)

    def forward(self, input_ids, attention_mask):
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the [CLS] token's hidden state as the sequence representation.
        cls_repr = encoded.last_hidden_state[:, 0]
        return self.classifier(self.dropout(cls_repr))
class HinglishToxicClassifier(nn.Module):
    """XLM-R-style encoder with mean+max pooling and a binary toxicity head.

    Logits are raised over: toxic, non-toxic (softmax applied by the caller).
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained(hinglish_path)
        hidden_size = self.bert.config.hidden_size
        # Backward-compatible alias: the original exposed a `pool` attribute
        # (previously a lambda — replaced with a named static method, PEP 8 E731).
        self.pool = self._pool
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        self.classifier = nn.Linear(hidden_size, 2)

    @staticmethod
    def _pool(hidden):
        """Concatenate mean- and max-pooling over the sequence dimension.

        Input is (batch, seq, hidden); output is (batch, 2*hidden).
        NOTE(review): padding positions are included in the pooling (no
        attention-mask weighting) — left unchanged to match the trained
        checkpoint.
        """
        return torch.cat([
            hidden.mean(dim=1),
            hidden.max(dim=1).values
        ], dim=1)

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = self._pool(hidden)
        x = self.bottleneck(pooled)
        return self.classifier(x)
# ----------------------------
# Load Models & Tokenizers
# ----------------------------
# Load fine-tuned weights from local checkpoint files shipped alongside app.py,
# then switch to eval mode (disables dropout) for inference.
# NOTE(review): torch.load defaults to weights_only=True from torch 2.6 —
# confirm these checkpoints are plain state dicts under the pinned version.
english_model = ToxicBERT().to(device)
english_model.load_state_dict(torch.load("bert_toxic_classifier.pt", map_location=device))
english_model.eval()
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_model = HinglishToxicClassifier().to(device)
hinglish_model.load_state_dict(torch.load("best_hinglish_model.pt", map_location=device))
hinglish_model.eval()
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
# ----------------------------
# API
# ----------------------------
class InputText(BaseModel):
    """Request body for POST /predict: the raw comment text to classify."""
    text: str
def _classify(model, tokenizer, language, classes, text):
    """Tokenize *text*, run *model*, and return class probabilities.

    Returns a dict with the language tag, the class-label order, and the
    softmax probabilities for the single input (batch of one).
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        logits = model(**inputs)
    probs = torch.softmax(logits, dim=1).cpu().numpy().tolist()[0]
    return {
        "language": language,
        "classes": classes,
        "probabilities": probs
    }

@app.post("/predict")
async def predict(input: InputText):
    """Classify a comment's toxicity, routing by detected language.

    English text goes to the 6-label DistilBERT model; everything else
    (including undetectable text) goes to the binary Hinglish model.
    """
    text = input.text
    try:
        lang = detect(text)
    except LangDetectException:
        # detect() raises on empty or feature-less text (e.g. emoji/digits
        # only); fall back to the Hinglish branch instead of returning a 500.
        lang = "unknown"
    if lang == "en":
        return _classify(
            english_model,
            english_tokenizer,
            "english",
            ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
            text,
        )
    return _classify(
        hinglish_model,
        hinglish_tokenizer,
        "hinglish",
        ["toxic", "non-toxic"],
        text,
    )