SamanthaStorm committed on
Commit 8e4d20e · verified · 1 Parent(s): 2cb59fd

Update app.py

Files changed (1): app.py (+13 -14)
app.py CHANGED
@@ -4,15 +4,14 @@ import numpy as np
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from transformers import RobertaForSequenceClassification, RobertaTokenizer
 
-# Load fine-tuned sentiment model (DistilBERT)
-sentiment_model_name = "SamanthaStorm/tether-sentiment"
-sentiment_model = AutoModelForSequenceClassification.from_pretrained(sentiment_model_name)
-sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
+# Load custom fine-tuned sentiment model
+sentiment_model = AutoModelForSequenceClassification.from_pretrained("SamanthaStorm/tether-sentiment")
+sentiment_tokenizer = AutoTokenizer.from_pretrained("SamanthaStorm/tether-sentiment")
 
-# Load abuse pattern model (RoBERTa)
-abuse_model_name = "SamanthaStorm/abuse-pattern-detector-v2"
-abuse_model = RobertaForSequenceClassification.from_pretrained(abuse_model_name)
-abuse_tokenizer = RobertaTokenizer.from_pretrained(abuse_model_name)
+# Load abuse pattern model
+model_name = "SamanthaStorm/abuse-pattern-detector-v2"
+model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
+tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
 LABELS = [
     "gaslighting", "mockery", "dismissiveness", "control", "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
@@ -80,11 +79,11 @@ def analyze_messages(input_text, risk_flags):
     sentiment_label = sentiment['label']
     sentiment_score = sentiment['score']
 
-    adjusted_thresholds = {k: v * 0.8 for k, v in THRESHOLDS.items()} if sentiment_label == "NEGATIVE" else THRESHOLDS.copy()
+    adjusted_thresholds = {k: v * 0.8 for k, v in THRESHOLDS.items()} if sentiment_label == "undermining" else THRESHOLDS.copy()
 
-    inputs = abuse_tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
+    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
     with torch.no_grad():
-        outputs = abuse_model(**inputs)
+        outputs = model(**inputs)
     scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
 
     pattern_count = sum(score > adjusted_thresholds[label] for label, score in zip(PATTERN_LABELS, scores[:15]))
@@ -101,7 +100,7 @@ def analyze_messages(input_text, risk_flags):
     if non_abusive_score > adjusted_thresholds['non_abusive']:
         return "This message is classified as non-abusive."
 
-    abuse_level = calculate_abuse_level(scores, THRESHOLDS)
+    abuse_level = calculate_abuse_level(scores, adjusted_thresholds)
     abuse_description = interpret_abuse_level(abuse_level)
 
     if danger_flag_count >= 2:
@@ -121,8 +120,8 @@ def analyze_messages(input_text, risk_flags):
         f"Abuse Risk Score: {abuse_level}% – {abuse_description}\n\n"
         f"Most Likely Patterns:\n{top_pattern_explanations}\n\n"
         f"⚠️ Critical Danger Flags Detected: {danger_flag_count} of 3\n"
-        f"Resources: {resources}\n\n"
-        f"Sentiment: {sentiment_label} (Confidence: {sentiment_score*100:.2f}%)"
+        "Resources: " + resources + "\n\n"
+        f"Sentiment: {sentiment_label.title()} (Confidence: {sentiment_score*100:.2f}%)"
     )
 
     if contextual_flags:
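For context, a minimal standalone sketch of the scoring path this commit introduces: loading the abuse-pattern model with trust_remote_code, loosening thresholds when sentiment is "undermining" (previously "NEGATIVE"), and comparing sigmoid scores to the adjusted thresholds. The example text, the two-entry THRESHOLDS dict, and the hard-coded sentiment_label are illustrative stand-ins; the real app.py defines the full threshold dict and derives the sentiment label from the SamanthaStorm/tether-sentiment model.

```python
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer

# Abuse-pattern model, loaded as in the new version of app.py.
model_name = "SamanthaStorm/abuse-pattern-detector-v2"
model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)

text = "You always twist my words."  # hypothetical input message
THRESHOLDS = {"gaslighting": 0.25, "mockery": 0.25}  # illustrative values, not the real dict from app.py
sentiment_label = "undermining"  # in app.py this comes from the custom sentiment model

# Loosen every threshold by 20% when the sentiment model flags the text as "undermining".
adjusted_thresholds = (
    {k: v * 0.8 for k, v in THRESHOLDS.items()}
    if sentiment_label == "undermining"
    else THRESHOLDS.copy()
)

inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
    scores = torch.sigmoid(model(**inputs).logits.squeeze(0)).numpy()

# app.py compares each label's sigmoid score to its adjusted threshold; shown here for the sample labels.
for label, score in zip(list(THRESHOLDS), scores):
    print(f"{label}: {score:.3f} (flagged: {score > adjusted_thresholds[label]})")
```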