File size: 3,937 Bytes
16b2ba7 2b470ab 16b2ba7 2b470ab 2d04c0e 2b470ab 16b2ba7 fe3311f 2b470ab 16b2ba7 2b470ab 16b2ba7 2b470ab 16b2ba7 2b470ab 16b2ba7 2b470ab fe3311f 16b2ba7 2b470ab 16b2ba7 2b470ab 16b2ba7 2b470ab 16b2ba7 2b470ab 16b2ba7 2b470ab 16b2ba7 f3eb85c 2b470ab 16b2ba7 a3af327 2b470ab 16b2ba7 2d04c0e 79819d8 16b2ba7 a3af327 3a3863a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel, AutoModel, AutoTokenizer
from langdetect import detect
from huggingface_hub import snapshot_download
import os
# Inference device: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fetch both fine-tuned checkpoints from the Hugging Face Hub. Each result is
# later joined with a weights filename, i.e. it is used as a local directory.
english_repo = snapshot_download(repo_id="koyu008/English_Toxic_Classifier")
hinglish_repo = snapshot_download(repo_id="koyu008/HInglish_comment_classifier")

# Tokenizers are loaded from the base checkpoints, not from the downloaded
# repos above.
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
# English model
class ToxicBERT(nn.Module):
    """DistilBERT encoder followed by dropout and a 6-way linear head.

    The attribute names ``bert``, ``dropout`` and ``classifier`` must stay as
    they are: they determine the state-dict keys that the saved weights are
    loaded into.
    """

    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        encoder_width = self.bert.config.hidden_size
        self.classifier = nn.Linear(encoder_width, 6)

    def forward(self, input_ids, attention_mask):
        """Return one row of six raw logits per input sequence.

        Only the hidden state at sequence position 0 is fed to the head.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        first_token = encoded.last_hidden_state[:, 0]
        regularised = self.dropout(first_token)
        return self.classifier(regularised)
# Hinglish model
class HinglishToxicClassifier(nn.Module):
    """XLM-RoBERTa encoder with mean+max pooling, a bottleneck and a 2-way head.

    The attribute names ``bert``, ``bottleneck`` and ``classifier`` must stay
    as they are: they determine the state-dict keys that the saved weights are
    loaded into.
    """

    @staticmethod
    def _mean_max_pool(hidden):
        """Concatenate the mean and the max over dim 1 of ``hidden``."""
        averaged = hidden.mean(dim=1)
        peak = hidden.max(dim=1).values
        return torch.cat([averaged, peak], dim=1)

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("xlm-roberta-base")
        width = self.bert.config.hidden_size
        # Still exposed as the instance attribute ``pool``, as before.
        self.pool = self._mean_max_pool
        # Pooling doubles the feature width; the bottleneck projects it back.
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * width, width),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Linear(width, 2)

    def forward(self, input_ids, attention_mask):
        """Return one row of two raw logits per input sequence."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        token_states = encoded.last_hidden_state
        features = self.bottleneck(self.pool(token_states))
        return self.classifier(features)
# Build both classifiers, load their fine-tuned weights and switch them to
# evaluation mode.
# NOTE(review): torch.load unpickles the checkpoint file; consider passing
# weights_only=True if the installed torch version supports it — confirm.
_english_weights = os.path.join(english_repo, "bert_toxic_classifier.pt")
_hinglish_weights = os.path.join(hinglish_repo, "best_hinglish_model.pt")

english_model = ToxicBERT().to(device)
english_model.load_state_dict(torch.load(_english_weights, map_location=device))
english_model.eval()

hinglish_model = HinglishToxicClassifier().to(device)
hinglish_model.load_state_dict(torch.load(_hinglish_weights, map_location=device))
hinglish_model.eval()

# Output labels, in the order of each model's logits as zipped by the
# prediction endpoint.
english_labels = [
    'toxic',
    'severe toxic',
    'obscene',
    'threat',
    'insult',
    'identity hate',
]
hinglish_labels = ['not toxic', 'toxic']
# FastAPI application and its CORS policy
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the FastAPI CORS documentation; restrict allow_origins to the
# frontend domain or disable credentials once the real clients are known.
_cors_settings = {
    "allow_origins": ["*"],  # Or restrict to your frontend domain
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Request body schema for the prediction endpoint.
class TextIn(BaseModel):
    # The raw text to classify.
    text: str
def _model_logits(text, tokenizer, model):
    """Tokenize ``text`` and return the model's raw logits without tracking gradients."""
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        return model(**inputs)


@app.post("/api/predict")
def predict(data: TextIn):
    """Classify the toxicity of ``data.text``.

    The language is detected first. Text detected as English (``"en"``) goes
    to the English model, whose six logits are passed through a sigmoid
    independently. Everything else, including text whose language cannot be
    detected, goes to the Hinglish model, whose two logits are passed through
    a softmax.

    Args:
        data: Request body carrying the text to classify.

    Returns:
        A dict with ``"language"`` (``"English"`` or ``"Hinglish"``) and
        ``"predictions"``, a mapping from label to probability.
    """
    text = data.text

    # Language detection is best-effort: any failure (e.g. empty or
    # feature-less text) falls back to the non-English path. Catch Exception
    # rather than using a bare except so KeyboardInterrupt and SystemExit
    # still propagate.
    try:
        lang = detect(text)
    except Exception:
        lang = "unknown"

    if lang == "en":
        logits = _model_logits(text, english_tokenizer, english_model)
        # squeeze(0) drops the batch dimension of the single input.
        probs = torch.sigmoid(logits).squeeze(0).cpu().tolist()
        return {"language": "English", "predictions": dict(zip(english_labels, probs))}

    logits = _model_logits(text, hinglish_tokenizer, hinglish_model)
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
    return {"language": "Hinglish", "predictions": dict(zip(hinglish_labels, probs))}
# Health-check endpoint: returns a fixed message confirming the service is up.
@app.get("/")
def root():
    status = {"message": "API is running"}
    return status
|