from flask import Flask, request, jsonify, render_template
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import re
from flask_cors import CORS  # Enable CORS

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Allow requests from frontend apps

# Choose your model: 'bert-base-uncased' or 'GroNLP/hateBERT'
MODEL_NAME = 'bert-base-uncased'  # Change to 'GroNLP/hateBERT' if needed

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()

# Two-class labels only
LABELS = ['Safe', 'Cyberbullying']

# Offensive trigger words
TRIGGER_WORDS = [
    "gago", "pokpok", "yawa", "linte", "ulol", "tangina", "bilat", "putang", "tarantado", "bobo",
    "yudipota", "law-ay", "bilatibay", "hayop"
]

# Detect trigger words in input text
def find_triggers(text):
    found = []
    for word in TRIGGER_WORDS:
        if re.search(rf"\b{re.escape(word)}\b", text, re.IGNORECASE):
            found.append(word)
    return found
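
# Quick sanity check (illustrative only): the word-boundary regex matches whole
# words case-insensitively, so a sentence containing "BOBO" yields that trigger.
#   find_triggers("Ang BOBO mo talaga")  ->  ["bobo"]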

# Predict function
def predict_text(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = F.softmax(logits, dim=1)
        confidence, predicted_class = torch.max(probs, dim=1)

    # Fallback: if the model head has more classes than LABELS, map any
    # out-of-range prediction to "Safe"
    label_index = predicted_class.item()
    if label_index >= len(LABELS):
        label_index = 0  # default to "Safe"

    label = LABELS[label_index]
    confidence_score = round(confidence.item(), 4)
    triggers = find_triggers(text)

    # Override model prediction if offensive triggers found
    if triggers and label == "Safe":
        label = "Cyberbullying"

    return {
        "label": label,
        "confidence": confidence_score,
        "triggers": triggers
    }
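
# Illustrative output shape (the confidence value is made up, not from a real run;
# the label is forced to "Cyberbullying" here because "bobo" is a trigger word):
#   predict_text("Ang bobo mo")
#   -> {"label": "Cyberbullying", "confidence": 0.9731, "triggers": ["bobo"]}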

# Serve frontend
@app.route('/')
def index():
    return render_template('index.html')  # Ensure templates/index.html exists

# API endpoint (route path and method are assumed; adjust to match the frontend)
@app.route('/predict', methods=['POST'])
def predict_api():
    try:
        data = request.get_json()
        text = data.get("text", "")

        if not text.strip():
            return jsonify({"error": "No text provided"}), 400

        result = predict_text(text)
        return jsonify(result)

    except Exception as e:
        return jsonify({"error": str(e)}), 500

# Run server
# if __name__ == "__main__":
#     app.run(debug=True)