a77an committed
Commit 01d0250 · verified · 1 Parent(s): f4b426e

Upload app.py

Files changed (1):
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
+ from flask import Flask, request, jsonify, render_template
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import torch
+ import torch.nn.functional as F
+ import re
+ from flask_cors import CORS  # Enable CORS
+
+ # Initialize Flask app
+ app = Flask(__name__)
+ CORS(app)  # Allow requests from frontend apps
+
+ # Choose your model: 'bert-base-uncased' or 'GroNLP/hateBERT'
+ MODEL_NAME = 'bert-base-uncased'  # Change to 'GroNLP/hateBERT' if needed
+
+ # Load tokenizer and model
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
+ model.eval()
+
+ # Two-class labels only
+ LABELS = ['Safe', 'Cyberbullying']
+
+ # Offensive trigger words
+ TRIGGER_WORDS = [
+     "gago", "pokpok", "yawa", "linte", "ulol", "tangina", "bilat", "putang", "tarantado", "bobo",
+     "yudipota", "law-ay", "bilatibay", "hayop"
+ ]
+
+ # Detect trigger words in input text
+ def find_triggers(text):
+     found = []
+     for word in TRIGGER_WORDS:
+         if re.search(rf"\b{re.escape(word)}\b", text, re.IGNORECASE):
+             found.append(word)
+     return found
+
+ # Predict function
+ def predict_text(text):
+     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+
+     # Use GPU if available
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+     inputs = {key: value.to(device) for key, value in inputs.items()}
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+         logits = outputs.logits
+         probs = F.softmax(logits, dim=1)
+         confidence, predicted_class = torch.max(probs, dim=1)
+
+     # Fallback logic: if the model predicts more than 2 classes, default to Safe when out of bounds
+     label_index = predicted_class.item()
+     if label_index >= len(LABELS):
+         label_index = 0  # default to "Safe"
+
+     label = LABELS[label_index]
+     confidence_score = round(confidence.item(), 4)
+     triggers = find_triggers(text)
+
+     # Override model prediction if offensive triggers found
+     if triggers and label == "Safe":
+         label = "Cyberbullying"
+
+     return {
+         "label": label,
+         "confidence": confidence_score,
+         "triggers": triggers
+     }
+
+ # Serve frontend
+ @app.route('/')
+ def index():
+     return render_template('index.html')  # Ensure templates/index.html exists
+
+ # API endpoint
+ @app.route("/predict", methods=["POST"])
+ def predict_api():
+     try:
+         data = request.get_json()
+         text = data.get("text", "")
+
+         if not text.strip():
+             return jsonify({"error": "No text provided"}), 400
+
+         result = predict_text(text)
+         return jsonify(result)
+
+     except Exception as e:
+         return jsonify({"error": str(e)}), 500
+
+ # Run server
+ if __name__ == "__main__":
+     app.run(debug=True)
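
For reference, the /predict endpoint added here accepts a JSON body with a "text" field and responds with the "label", "confidence", and "triggers" keys returned by predict_text(). A minimal client sketch, assuming the app is running locally on Flask's default port 5000; the requests dependency and the sample input are illustrative and not part of this commit:

# Minimal sketch of calling the /predict endpoint (assumes a local dev server).
import requests

resp = requests.post(
    "http://127.0.0.1:5000/predict",
    json={"text": "sample text to classify"},  # hypothetical input
)
resp.raise_for_status()
result = resp.json()
# Keys per predict_text(): "label", "confidence", "triggers"
print(result["label"], result["confidence"], result["triggers"])

A request with an empty or missing "text" field returns HTTP 400 with {"error": "No text provided"}; any other failure is caught and returned as HTTP 500.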