# cyberbully_research/run_predictions.py
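"""Batch and interactive cyberbullying detection for Hiligaynon ('hil') tweets.

With --input_file, runs over a CSV containing a 'tweet_text' column, compares
predictions against the 'cyberbullying_type' column, prints metrics, and saves
a confusion-matrix plot. Without an input file, it drops into an interactive prompt.
"""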
import torch
import pandas as pd
import argparse
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import classification_report, confusion_matrix
# Define model names
bert_model_name = "bert-base-uncased"
hatebert_model_name = "GroNLP/hateBERT"
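# NOTE: these are base checkpoints; AutoModelForSequenceClassification attaches a
# freshly initialized classification head to them, so the softmax scores are only
# meaningful once fine-tuned weights are loaded in their place.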
class CyberbullyingDetector:
def __init__(self, model_type="bert"):
if model_type == "bert":
self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(bert_model_name)
elif model_type == "hatebert":
self.tokenizer = AutoTokenizer.from_pretrained(hatebert_model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(hatebert_model_name)
else:
raise ValueError("Invalid model_type. Choose 'bert' or 'hatebert'.")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
        self.cyberbullying_threshold = 0.7  # Minimum class-1 confidence required to flag text without a trigger word
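        # Rule-based fallback: Hiligaynon/Filipino insult and slur terms. Any match
        # forces a "Cyberbullying" label regardless of the model's confidence.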
self.trigger_words = [
'buang', 'pokpok', 'bogo', 'linte', 'tanga', 'diputa', 'yuta mo', 'gaga',
'lagtok ka', 'addict', 'bogok', 'gago', 'law-ay', 'demonyo ka', 'animal ka', 'animal',
'bilatibay', 'yudipota', 'pangit', 'tikalon', 'tinikal', 'hambog',
'batinggilan', 'biga-on', 'bulay-ug', 'agi', 'agitot', 'alpot', 'hangag'
]
def find_triggers(self, text):
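        # Case-insensitive substring match; this also flags trigger words that
        # appear embedded inside longer words.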
text_lower = text.lower()
return [word for word in self.trigger_words if word in text_lower]
def predict(self, text):
triggers = self.find_triggers(text)
inputs = self.tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=128,
padding=True
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits
probs = torch.nn.functional.softmax(logits, dim=1)
pred_class = torch.argmax(probs).item()
confidence = probs[0][pred_class].item()
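        # Flag when the model is confident in the cyberbullying class OR any trigger
        # word was found (the lexical rule overrides low model confidence).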
if (pred_class == 1 and confidence >= self.cyberbullying_threshold) or (len(triggers) > 0):
label = "Cyberbullying"
is_cyberbullying = True
else:
label = "Safe"
is_cyberbullying = False
return {
"text": text,
"label": label,
"confidence": confidence,
"language": "hil",
"triggers": triggers,
"is_cyberbullying": is_cyberbullying
}
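# Plot and save a confusion matrix for the given true/predicted label lists.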
def save_confusion_matrix(y_true, y_pred, filename="confusion_matrix.png"):
labels = sorted(set(y_true + y_pred))
cm = confusion_matrix(y_true, y_pred, labels=labels)
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.tight_layout()
plt.savefig(filename)
plt.close()
print(f"πŸ“Έ Confusion matrix saved as {filename}")
def run_predictions(input_csv=None, output_csv=None, model_type="bert"):
detector = CyberbullyingDetector(model_type=model_type)
if input_csv:
df = pd.read_csv(input_csv)
        # Fill missing rows and cast to str so predict() never receives NaN
        texts = df['tweet_text'].fillna("").astype(str)
        results = [detector.predict(text) for text in texts]
output_df = df.copy()
output_df['predicted_label'] = [r['label'] for r in results]
output_df['confidence'] = [r['confidence'] for r in results]
output_df['language'] = [r['language'] for r in results]
output_df['triggers'] = [', '.join(r['triggers']) for r in results]
output_df['is_cyberbullying'] = [r['is_cyberbullying'] for r in results]
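        # Ground truth: any non-empty 'cyberbullying_type' other than "none" counts as Cyberbullying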
output_df['true_label'] = output_df['cyberbullying_type'].apply(
lambda x: "Cyberbullying" if pd.notnull(x) and str(x).strip().lower() != "none" else "Safe"
)
if output_csv:
output_df.to_csv(output_csv, index=False)
print(f"\nβœ… Predictions saved to {output_csv}")
print("\nπŸ“Œ Sample Predictions:")
print(output_df[['tweet_text', 'predicted_label', 'confidence', 'triggers']].head(10).to_string(index=False))
print("\nπŸ“Š Prediction Summary:")
print(output_df['predicted_label'].value_counts())
print("\nβœ… Ground Truth Summary:")
print(output_df['true_label'].value_counts())
accuracy = (output_df['predicted_label'] == output_df['true_label']).mean()
print(f"\n🎯 Accuracy: {accuracy:.2%}")
print("\n🧾 Classification Report:")
print(classification_report(output_df['true_label'], output_df['predicted_label'], digits=2, zero_division=0))
save_confusion_matrix(output_df['true_label'].tolist(), output_df['predicted_label'].tolist())
return output_df
else:
# Inference mode
print("\nπŸ” Type a sentence to analyze (or 'exit' to quit):")
while True:
            text = input(">>> ").strip()
if text.lower() in ["exit", "quit"]:
break
result = detector.predict(text)
print(result)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Cyberbullying Detector")
parser.add_argument('--input_file', type=str, help="Path to input CSV with 'tweet_text' column")
parser.add_argument('--output_file', type=str, help="Path to save results CSV")
parser.add_argument('--model', type=str, default='bert', choices=['bert', 'hatebert'], help="Model to use")
args = parser.parse_args()
if args.input_file:
print(f"πŸ“₯ Running batch predictions from {args.input_file} using {args.model.upper()}...")
else:
print(f"πŸ§ͺ No input file. Running in interactive mode using {args.model.upper()}...")
run_predictions(args.input_file, args.output_file, model_type=args.model)