a77an committed
Commit 1e8265a · verified · 1 Parent(s): e219d78

Update app.py

Files changed (1)
  1. app.py +94 -94
app.py CHANGED
@@ -1,94 +1,94 @@
- from flask import Flask, request, jsonify, render_template
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
- import torch
- import torch.nn.functional as F
- import re
- from flask_cors import CORS # Enable CORS
-
- # Initialize Flask app
- app = Flask(__name__)
- CORS(app) # Allow requests from frontend apps
-
- # Choose your model: 'bert-base-uncased' or 'GroNLP/hateBERT'
- MODEL_NAME = 'bert-base-uncased' # Change to 'GroNLP/hateBERT' if needed
-
- # Load tokenizer and model
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
- model.eval()
-
- # Two-class labels only
- LABELS = ['Safe', 'Cyberbullying']
-
- # Offensive trigger words
- TRIGGER_WORDS = [
-     "gago", "pokpok", "yawa", "linte", "ulol", "tangina", "bilat", "putang", "tarantado", "bobo",
-     "yudipota", "law-ay", "bilatibay", "hayop"
- ]
-
- # Detect trigger words in input text
- def find_triggers(text):
-     found = []
-     for word in TRIGGER_WORDS:
-         if re.search(rf"\b{re.escape(word)}\b", text, re.IGNORECASE):
-             found.append(word)
-     return found
-
- # Predict function
- def predict_text(text):
-     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
-
-     # Use GPU if available
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     model.to(device)
-     inputs = {key: value.to(device) for key, value in inputs.items()}
-
-     with torch.no_grad():
-         outputs = model(**inputs)
-         logits = outputs.logits
-         probs = F.softmax(logits, dim=1)
-         confidence, predicted_class = torch.max(probs, dim=1)
-
-     # Fallback logic: if model predicts more than 2 classes, default to Safe if out-of-bounds
-     label_index = predicted_class.item()
-     if label_index >= len(LABELS):
-         label_index = 0 # default to "Safe"
-
-     label = LABELS[label_index]
-     confidence_score = round(confidence.item(), 4)
-     triggers = find_triggers(text)
-
-     # Override model prediction if offensive triggers found
-     if triggers and label == "Safe":
-         label = "Cyberbullying"
-
-     return {
-         "label": label,
-         "confidence": confidence_score,
-         "triggers": triggers
-     }
-
- # Serve frontend
- @app.route('/')
- def index():
-     return render_template('index.html') # Ensure templates/index.html exists
-
- # API endpoint
- @app.route("/predict", methods=["POST"])
- def predict_api():
-     try:
-         data = request.get_json()
-         text = data.get("text", "")
-
-         if not text.strip():
-             return jsonify({"error": "No text provided"}), 400
-
-         result = predict_text(text)
-         return jsonify(result)
-
-     except Exception as e:
-         return jsonify({"error": str(e)}), 500
-
- # Run server
- if __name__ == "__main__":
-     app.run(debug=True)
+ from flask import Flask, request, jsonify, render_template
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import torch
+ import torch.nn.functional as F
+ import re
+ from flask_cors import CORS # Enable CORS
+
+ # Initialize Flask app
+ app = Flask(__name__)
+ CORS(app) # Allow requests from frontend apps
+
+ # Choose your model: 'bert-base-uncased' or 'GroNLP/hateBERT'
+ MODEL_NAME = 'bert-base-uncased' # Change to 'GroNLP/hateBERT' if needed
+
+ # Load tokenizer and model
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
+ model.eval()
+
+ # Two-class labels only
+ LABELS = ['Safe', 'Cyberbullying']
+
+ # Offensive trigger words
+ TRIGGER_WORDS = [
+     "gago", "pokpok", "yawa", "linte", "ulol", "tangina", "bilat", "putang", "tarantado", "bobo",
+     "yudipota", "law-ay", "bilatibay", "hayop"
+ ]
+
+ # Detect trigger words in input text
+ def find_triggers(text):
+     found = []
+     for word in TRIGGER_WORDS:
+         if re.search(rf"\b{re.escape(word)}\b", text, re.IGNORECASE):
+             found.append(word)
+     return found
+
+ # Predict function
+ def predict_text(text):
+     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+
+     # Use GPU if available
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+     inputs = {key: value.to(device) for key, value in inputs.items()}
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+         logits = outputs.logits
+         probs = F.softmax(logits, dim=1)
+         confidence, predicted_class = torch.max(probs, dim=1)
+
+     # Fallback logic: if model predicts more than 2 classes, default to Safe if out-of-bounds
+     label_index = predicted_class.item()
+     if label_index >= len(LABELS):
+         label_index = 0 # default to "Safe"
+
+     label = LABELS[label_index]
+     confidence_score = round(confidence.item(), 4)
+     triggers = find_triggers(text)
+
+     # Override model prediction if offensive triggers found
+     if triggers and label == "Safe":
+         label = "Cyberbullying"
+
+     return {
+         "label": label,
+         "confidence": confidence_score,
+         "triggers": triggers
+     }
+
+ # Serve frontend
+ @app.route('/')
+ def index():
+     return render_template('index.html') # Ensure templates/index.html exists
+
+ # API endpoint
+ @app.route("/predict", methods=["POST"])
+ def predict_api():
+     try:
+         data = request.get_json()
+         text = data.get("text", "")
+
+         if not text.strip():
+             return jsonify({"error": "No text provided"}), 400
+
+         result = predict_text(text)
+         return jsonify(result)
+
+     except Exception as e:
+         return jsonify({"error": str(e)}), 500
+
+ # Run server
+ #if __name__ == "__main__":
+ #    app.run(debug=True)
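The only functional change in this commit is the final block: the `if __name__ == "__main__": app.run(debug=True)` guard is commented out, the usual change when the app is started by an external WSGI server (e.g. `gunicorn app:app`) rather than with `python app.py`, as is common on hosting platforms such as Hugging Face Spaces.

For reference, a minimal client sketch for the `/predict` endpoint this file defines. The URL is an assumption (a local run on Flask's default port 5000), not part of the commit; adjust it to the actual deployment:

```python
# Hypothetical client for the /predict endpoint above.
# Assumes the server is reachable at localhost:5000 (Flask's default port);
# swap in the real host/port of your deployment.
import requests

resp = requests.post(
    "http://localhost:5000/predict",
    json={"text": "sample message to classify"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
# Expected shape, per predict_text():
# {"label": "Safe" or "Cyberbullying", "confidence": 0.1234, "triggers": [...]}
```

Per `predict_api`, an empty or missing `text` field returns a 400 with an error message, and any unexpected failure returns a 500.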