a77an commited on
Commit
00a16c9
·
verified ·
1 Parent(s): 1d71477

Upload run_predictions.py

Browse files
Files changed (1) hide show
  1. run_predictions.py +152 -0
run_predictions.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+ import argparse
4
+ import re
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
+ from sklearn.metrics import classification_report, confusion_matrix
9
+
10
+ # Define model names
11
+ bert_model_name = "bert-base-uncased"
12
+ hatebert_model_name = "GroNLP/hateBERT"
13
+
14
+ class CyberbullyingDetector:
15
+ def __init__(self, model_type="bert"):
16
+ if model_type == "bert":
17
+ self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
18
+ self.model = AutoModelForSequenceClassification.from_pretrained(bert_model_name)
19
+ elif model_type == "hatebert":
20
+ self.tokenizer = AutoTokenizer.from_pretrained(hatebert_model_name)
21
+ self.model = AutoModelForSequenceClassification.from_pretrained(hatebert_model_name)
22
+ else:
23
+ raise ValueError("Invalid model_type. Choose 'bert' or 'hatebert'.")
24
+
25
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ self.model.to(self.device)
27
+
28
+ self.cyberbullying_threshold = 0.7 # Confidence threshold
29
+
30
+ self.trigger_words = [
31
+ 'buang', 'pokpok', 'bogo', 'linte', 'tanga', 'diputa', 'yuta mo', 'gaga',
32
+ 'lagtok ka', 'addict', 'bogok', 'gago', 'law-ay', 'demonyo ka', 'animal ka', 'animal',
33
+ 'bilatibay', 'yudipota', 'pangit', 'tikalon', 'tinikal', 'hambog',
34
+ 'batinggilan', 'biga-on', 'bulay-ug', 'agi', 'agitot', 'alpot', 'hangag'
35
+ ]
36
+
37
+ def find_triggers(self, text):
38
+ text_lower = text.lower()
39
+ return [word for word in self.trigger_words if word in text_lower]
40
+
41
+ def predict(self, text):
42
+ triggers = self.find_triggers(text)
43
+
44
+ inputs = self.tokenizer(
45
+ text,
46
+ return_tensors="pt",
47
+ truncation=True,
48
+ max_length=128,
49
+ padding=True
50
+ ).to(self.device)
51
+
52
+ with torch.no_grad():
53
+ outputs = self.model(**inputs)
54
+
55
+ logits = outputs.logits
56
+ probs = torch.nn.functional.softmax(logits, dim=1)
57
+ pred_class = torch.argmax(probs).item()
58
+ confidence = probs[0][pred_class].item()
59
+
60
+ if (pred_class == 1 and confidence >= self.cyberbullying_threshold) or (len(triggers) > 0):
61
+ label = "Cyberbullying"
62
+ is_cyberbullying = True
63
+ else:
64
+ label = "Safe"
65
+ is_cyberbullying = False
66
+
67
+ return {
68
+ "text": text,
69
+ "label": label,
70
+ "confidence": confidence,
71
+ "language": "hil",
72
+ "triggers": triggers,
73
+ "is_cyberbullying": is_cyberbullying
74
+ }
75
+
76
+ def save_confusion_matrix(y_true, y_pred, filename="confusion_matrix.png"):
77
+ labels = sorted(set(y_true + y_pred))
78
+ cm = confusion_matrix(y_true, y_pred, labels=labels)
79
+ plt.figure(figsize=(6, 4))
80
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
81
+ plt.title("Confusion Matrix")
82
+ plt.xlabel("Predicted")
83
+ plt.ylabel("True")
84
+ plt.tight_layout()
85
+ plt.savefig(filename)
86
+ plt.close()
87
+ print(f"📸 Confusion matrix saved as {filename}")
88
+
89
+ def run_predictions(input_csv=None, output_csv=None, model_type="bert"):
90
+ detector = CyberbullyingDetector(model_type=model_type)
91
+
92
+ if input_csv:
93
+ df = pd.read_csv(input_csv)
94
+ results = [detector.predict(text) for text in df['tweet_text']]
95
+
96
+ output_df = df.copy()
97
+ output_df['predicted_label'] = [r['label'] for r in results]
98
+ output_df['confidence'] = [r['confidence'] for r in results]
99
+ output_df['language'] = [r['language'] for r in results]
100
+ output_df['triggers'] = [', '.join(r['triggers']) for r in results]
101
+ output_df['is_cyberbullying'] = [r['is_cyberbullying'] for r in results]
102
+
103
+ output_df['true_label'] = output_df['cyberbullying_type'].apply(
104
+ lambda x: "Cyberbullying" if pd.notnull(x) and str(x).strip().lower() != "none" else "Safe"
105
+ )
106
+
107
+ if output_csv:
108
+ output_df.to_csv(output_csv, index=False)
109
+ print(f"\n✅ Predictions saved to {output_csv}")
110
+
111
+ print("\n📌 Sample Predictions:")
112
+ print(output_df[['tweet_text', 'predicted_label', 'confidence', 'triggers']].head(10).to_string(index=False))
113
+
114
+ print("\n📊 Prediction Summary:")
115
+ print(output_df['predicted_label'].value_counts())
116
+
117
+ print("\n✅ Ground Truth Summary:")
118
+ print(output_df['true_label'].value_counts())
119
+
120
+ accuracy = (output_df['predicted_label'] == output_df['true_label']).mean()
121
+ print(f"\n🎯 Accuracy: {accuracy:.2%}")
122
+
123
+ print("\n🧾 Classification Report:")
124
+ print(classification_report(output_df['true_label'], output_df['predicted_label'], digits=2, zero_division=0))
125
+
126
+ save_confusion_matrix(output_df['true_label'].tolist(), output_df['predicted_label'].tolist())
127
+
128
+ return output_df
129
+ else:
130
+ # Inference mode
131
+ print("\n🔍 Type a sentence to analyze (or 'exit' to quit):")
132
+ while True:
133
+ text = input(">>> ")
134
+ if text.lower() in ["exit", "quit"]:
135
+ break
136
+ result = detector.predict(text)
137
+ print(result)
138
+
139
+ if __name__ == '__main__':
140
+ parser = argparse.ArgumentParser(description="Cyberbullying Detector")
141
+ parser.add_argument('--input_file', type=str, help="Path to input CSV with 'tweet_text' column")
142
+ parser.add_argument('--output_file', type=str, help="Path to save results CSV")
143
+ parser.add_argument('--model', type=str, default='bert', choices=['bert', 'hatebert'], help="Model to use")
144
+
145
+ args = parser.parse_args()
146
+
147
+ if args.input_file:
148
+ print(f"📥 Running batch predictions from {args.input_file} using {args.model.upper()}...")
149
+ else:
150
+ print(f"🧪 No input file. Running in interactive mode using {args.model.upper()}...")
151
+
152
+ run_predictions(args.input_file, args.output_file, model_type=args.model)