from flask import Flask, request, jsonify, render_template
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import re
from flask_cors import CORS # Enable CORS
# Initialize Flask app
app = Flask(__name__)
CORS(app) # Allow requests from frontend apps
# Choose your model: 'bert-base-uncased' or 'GroNLP/hateBERT'
MODEL_NAME = 'bert-base-uncased' # Change to 'GroNLP/hateBERT' if needed
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()
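# Note: if MODEL_NAME points at a base checkpoint (e.g. bert-base-uncased) rather
# than weights fine-tuned for this task, the sequence-classification head is newly
# initialized, so transformers will warn that its weights are not pretrained.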
# Two-class labels only
LABELS = ['Safe', 'Cyberbullying']
# Offensive trigger words
TRIGGER_WORDS = [
    "gago", "pokpok", "yawa", "linte", "ulol", "tangina", "bilat", "putang", "tarantado", "bobo",
    "yudipota", "law-ay", "bilatibay", "hayop"
]
# Detect trigger words in input text
def find_triggers(text):
    found = []
    for word in TRIGGER_WORDS:
        if re.search(rf"\b{re.escape(word)}\b", text, re.IGNORECASE):
            found.append(word)
    return found
# Predict function
def predict_text(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = F.softmax(logits, dim=1)
        confidence, predicted_class = torch.max(probs, dim=1)

    # Fallback: if the model outputs more than two classes, map any
    # out-of-range index back to "Safe"
    label_index = predicted_class.item()
    if label_index >= len(LABELS):
        label_index = 0  # default to "Safe"

    label = LABELS[label_index]
    confidence_score = round(confidence.item(), 4)
    triggers = find_triggers(text)

    # Override the model prediction if offensive trigger words are found
    if triggers and label == "Safe":
        label = "Cyberbullying"

    return {
        "label": label,
        "confidence": confidence_score,
        "triggers": triggers
    }
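
# Example (illustrative): predict_text("you are so bobo") returns a dict of the
# form {"label": "Cyberbullying", "confidence": 0.7312, "triggers": ["bobo"]},
# where the confidence is the softmax probability of the predicted class and the
# trigger list comes from find_triggers(), independent of the model.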
# Serve frontend
@app.route('/')
def index():
    return render_template('index.html')  # Ensure templates/index.html exists
# API endpoint
@app.route("/predict", methods=["POST"])
def predict_api():
try:
data = request.get_json()
text = data.get("text", "")
if not text.strip():
return jsonify({"error": "No text provided"}), 400
result = predict_text(text)
return jsonify(result)
except Exception as e:
return jsonify({"error": str(e)}), 500
# Run server
# if __name__ == "__main__":
#     app.run(debug=True)
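# To try it locally (one option, not required by this file): uncomment the block
# above and run `python app.py`, or serve the app with a WSGI server such as
# gunicorn, e.g. `gunicorn app:app`.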