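"""Cyberbullying detector for Hiligaynon-language tweets.

Classifies text with a BERT or HateBERT sequence-classification model,
backed by a hardcoded trigger-word list. Batch mode reads a CSV with
'tweet_text' and 'cyberbullying_type' columns; with no input file the
script runs an interactive prompt.

Example invocation (hypothetical file names):
    python app.py --input_file tweets.csv --output_file results.csv --model hatebert
"""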

import argparse
import re

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import classification_report, confusion_matrix
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Pretrained checkpoint names
BERT_MODEL_NAME = "bert-base-uncased"
HATEBERT_MODEL_NAME = "GroNLP/hateBERT"

class CyberbullyingDetector:
    def __init__(self, model_type="bert"):
        if model_type == "bert":
            model_name = BERT_MODEL_NAME
        elif model_type == "hatebert":
            model_name = HATEBERT_MODEL_NAME
        else:
            raise ValueError("Invalid model_type. Choose 'bert' or 'hatebert'.")

        # NOTE: these are base checkpoints; the sequence-classification head is
        # randomly initialized unless a fine-tuned checkpoint is supplied.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()  # inference only

        self.cyberbullying_threshold = 0.7  # minimum confidence to flag class 1

        # Hiligaynon insults and slurs that flag a text as cyberbullying
        # regardless of the model's prediction.
        self.trigger_words = [
            'buang', 'pokpok', 'bogo', 'linte', 'tanga', 'diputa', 'yuta mo', 'gaga',
            'lagtok ka', 'addict', 'bogok', 'gago', 'law-ay', 'demonyo ka', 'animal ka', 'animal',
            'bilatibay', 'yudipota', 'pangit', 'tikalon', 'tinikal', 'hambog',
            'batinggilan', 'biga-on', 'bulay-ug', 'agi', 'agitot', 'alpot', 'hangag'
        ]

    def find_triggers(self, text):
        # Word-boundary matching avoids flagging substrings of harmless words
        # (a plain substring search would match 'agi' inside unrelated words).
        text_lower = text.lower()
        return [
            word for word in self.trigger_words
            if re.search(r'\b' + re.escape(word) + r'\b', text_lower)
        ]

    def predict(self, text):
        triggers = self.find_triggers(text)

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=128,
            padding=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=1)
        pred_class = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred_class].item()

        # Flag as cyberbullying if the model is confident in class 1, or if
        # any trigger word appears in the text.
        if (pred_class == 1 and confidence >= self.cyberbullying_threshold) or triggers:
            label = "Cyberbullying"
            is_cyberbullying = True
        else:
            label = "Safe"
            is_cyberbullying = False

        return {
            "text": text,
            "label": label,
            "confidence": confidence,
            "language": "hil",  # hardcoded ISO 639-3 code for Hiligaynon
            "triggers": triggers,
            "is_cyberbullying": is_cyberbullying
        }
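
# Minimal programmatic usage (a sketch; the input text is illustrative, and
# model predictions are only meaningful with a fine-tuned checkpoint):
#
#   detector = CyberbullyingDetector(model_type="hatebert")
#   result = detector.predict("buang ka")
#   print(result["label"], result["confidence"], result["triggers"])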

def save_confusion_matrix(y_true, y_pred, filename="confusion_matrix.png"):
    labels = sorted(set(y_true + y_pred))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    plt.figure(figsize=(6, 4))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()
    print(f"Confusion matrix saved as {filename}")

def run_predictions(input_csv=None, output_csv=None, model_type="bert"):
    detector = CyberbullyingDetector(model_type=model_type)

    if input_csv:
        df = pd.read_csv(input_csv)
        results = [detector.predict(text) for text in df['tweet_text']]

        output_df = df.copy()
        output_df['predicted_label'] = [r['label'] for r in results]
        output_df['confidence'] = [r['confidence'] for r in results]
        output_df['language'] = [r['language'] for r in results]
        output_df['triggers'] = [', '.join(r['triggers']) for r in results]
        output_df['is_cyberbullying'] = [r['is_cyberbullying'] for r in results]

        # Derive the ground-truth label: any non-empty cyberbullying_type
        # other than "none" counts as cyberbullying.
        output_df['true_label'] = output_df['cyberbullying_type'].apply(
            lambda x: "Cyberbullying" if pd.notnull(x) and str(x).strip().lower() != "none" else "Safe"
        )

        if output_csv:
            output_df.to_csv(output_csv, index=False)
            print(f"\nPredictions saved to {output_csv}")

        print("\nSample Predictions:")
        print(output_df[['tweet_text', 'predicted_label', 'confidence', 'triggers']].head(10).to_string(index=False))

        print("\nPrediction Summary:")
        print(output_df['predicted_label'].value_counts())

        print("\nGround Truth Summary:")
        print(output_df['true_label'].value_counts())

        accuracy = (output_df['predicted_label'] == output_df['true_label']).mean()
        print(f"\nAccuracy: {accuracy:.2%}")

        print("\nClassification Report:")
        print(classification_report(output_df['true_label'], output_df['predicted_label'], digits=2, zero_division=0))

        save_confusion_matrix(output_df['true_label'].tolist(), output_df['predicted_label'].tolist())

        return output_df
    else:
        # Interactive mode: classify sentences typed at the prompt.
        print("\nType a sentence to analyze (or 'exit' to quit):")
        while True:
            text = input(">>> ")
            if text.lower() in ["exit", "quit"]:
                break
            result = detector.predict(text)
            print(result)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Cyberbullying Detector")
    parser.add_argument('--input_file', type=str, help="Path to input CSV with 'tweet_text' column")
    parser.add_argument('--output_file', type=str, help="Path to save results CSV")
    parser.add_argument('--model', type=str, default='bert', choices=['bert', 'hatebert'], help="Model to use")
    args = parser.parse_args()

    if args.input_file:
        print(f"Running batch predictions from {args.input_file} using {args.model.upper()}...")
    else:
        print(f"No input file. Running in interactive mode using {args.model.upper()}...")

    run_predictions(args.input_file, args.output_file, model_type=args.model)