SamanthaStorm committed on
Commit 23bb2d2 · verified · 1 Parent(s): d3b5f65

Update app.py

Files changed (1)
  1. app.py +32 -50
app.py CHANGED
@@ -7,7 +7,7 @@ from transformers import pipeline
 # Load sentiment analysis model
 sentiment_analyzer = pipeline("sentiment-analysis")
 
-# Load model and tokenizer with trust_remote_code in case it's needed
+# Load model and tokenizer
 model_name = "SamanthaStorm/abuse-pattern-detector-v2"
 model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
 tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)
@@ -21,11 +21,11 @@ LABELS = [
     "extreme_control"
 ]
 
-# Custom thresholds for each label (make sure these match your original settings)
+# Custom thresholds for each label
 THRESHOLDS = {
     "gaslighting": 0.25,
     "mockery": 0.15,
-    "dismissiveness": 0.30, # original value, not 0.30
+    "dismissiveness": 0.30,
     "control": 0.43,
     "guilt_tripping": 0.19,
     "apology_baiting": 0.45,
@@ -41,30 +41,11 @@ THRESHOLDS = {
     "suicidal_threat": 0.45,
     "physical_threat": 0.20,
     "extreme_control": 0.36
-
 }
 
-# Define label groups using slicing (first 15: abuse patterns, last 3: danger cues)
 PATTERN_LABELS = LABELS[:15]
 DANGER_LABELS = LABELS[15:18]
 
-def calculate_abuse_level(scores, thresholds):
-    triggered_scores = [score for label, score in zip(LABELS, scores) if score > thresholds[label]]
-    if not triggered_scores:
-        return 0.0
-    return round(np.mean(triggered_scores) * 100, 2)
-
-def interpret_abuse_level(score):
-    if score > 80:
-        return "Extreme / High Risk"
-    elif score > 60:
-        return "Severe / Harmful Pattern Present"
-    elif score > 40:
-        return "Likely Abuse"
-    elif score > 20:
-        return "Mild Concern"
-    else:
-        return "Very Low / Likely Safe"
 EXPLANATIONS = {
     "gaslighting": "Gaslighting involves making someone question their own reality or perceptions, often causing them to feel confused or insecure.",
     "blame_shifting": "Blame-shifting is when one person redirects the responsibility for an issue onto someone else, avoiding accountability.",
@@ -84,63 +65,66 @@ EXPLANATIONS = {
     "manipulation": "Manipulation refers to using deceptive tactics to control or influence someone’s emotions, decisions, or behavior to serve the manipulator’s own interests.",
     "non_abusive": "Non-abusive language is communication that is respectful, empathetic, and free of harmful behaviors or manipulation."
 }
+
+def calculate_abuse_level(scores, thresholds):
+    triggered_scores = [score for label, score in zip(LABELS, scores) if score > thresholds[label]]
+    if not triggered_scores:
+        return 0.0
+    return round(np.mean(triggered_scores) * 100, 2)
+
+def interpret_abuse_level(score):
+    if score > 80:
+        return "Extreme / High Risk"
+    elif score > 60:
+        return "Severe / Harmful Pattern Present"
+    elif score > 40:
+        return "Likely Abuse"
+    elif score > 20:
+        return "Mild Concern"
+    else:
+        return "Very Low / Likely Safe"
+
 def analyze_messages(input_text):
     input_text = input_text.strip()
     if not input_text:
-        return "Please enter a message for analysis.", None
-
-    # Sentiment analysis
-    sentiment = sentiment_analyzer(input_text)[0] # Sentiment result
+        return "Please enter a message for analysis."
+
+    sentiment = sentiment_analyzer(input_text)[0]
     sentiment_label = sentiment['label']
     sentiment_score = sentiment['score']
-
-    # Adjust thresholds based on sentiment
+
     adjusted_thresholds = THRESHOLDS.copy()
     if sentiment_label == "NEGATIVE":
-        # Lower thresholds for negative sentiment
-        adjusted_thresholds = {key: val * 0.8 for key, val in THRESHOLDS.items()} # Example adjustment
-
-    # Tokenize input and generate model predictions
+        adjusted_thresholds = {key: val * 0.8 for key, val in THRESHOLDS.items()}
+
     inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
     with torch.no_grad():
         outputs = model(**inputs)
     scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
 
-    # Count the number of triggered abuse pattern and danger flags based on thresholds
     pattern_count = sum(score > adjusted_thresholds[label] for label, score in zip(PATTERN_LABELS, scores[:15]))
     danger_flag_count = sum(score > adjusted_thresholds[label] for label, score in zip(DANGER_LABELS, scores[15:18]))
 
-    # Check if 'non_abusive' label is triggered
     non_abusive_score = scores[LABELS.index('non_abusive')]
     if non_abusive_score > adjusted_thresholds['non_abusive']:
-        # If non-abusive threshold is met, return a non-abusive classification
        return "This message is classified as non-abusive."
 
-    # Build formatted raw score display
-    score_lines = [
-        f"{label:25}: {score:.3f}" for label, score in zip(PATTERN_LABELS + DANGER_LABELS, scores)
-    ]
-    raw_score_output = "\n".join(score_lines)
-
-    # Calculate overall abuse level and interpret it
     abuse_level = calculate_abuse_level(scores, THRESHOLDS)
     abuse_description = interpret_abuse_level(abuse_level)
 
-    # Resource logic based on the number of danger cues
     if danger_flag_count >= 2:
         resources = "Immediate assistance recommended. Please seek professional help or contact emergency services."
     else:
         resources = "For more information on abuse patterns, consider reaching out to support groups or professional counselors."
 
-    # Get top 2 highest scoring abuse patterns (excluding 'non_abusive')
     scored_patterns = [(label, score) for label, score in zip(PATTERN_LABELS, scores[:15])]
     top_patterns = sorted(scored_patterns, key=lambda x: x[1], reverse=True)[:2]
-    top_patterns_str = "\n".join([f"• {label.replace('_', ' ').title()}" for label, _ in top_patterns])
 
-
-    top_pattern_explanations = "\n".join([f"• {label.replace('_', ' ').title()}: {EXPLANATIONS.get(label, 'No explanation available.')}" for label, _ in top_patterns])
+    top_pattern_explanations = "\n".join([
+        f"\u2022 {label.replace('_', ' ').title()}: {EXPLANATIONS.get(label, 'No explanation available.')}"
+        for label, _ in top_patterns
+    ])
 
-    # Format final result
     result = (
         f"Abuse Risk Score: {abuse_level}% – {abuse_description}\n\n"
         f"Most Likely Patterns:\n{top_pattern_explanations}\n\n"
@@ -152,10 +136,8 @@ def analyze_messages(input_text):
         f"Sentiment: {sentiment_label} (Confidence: {sentiment_score*100:.2f}%)"
     )
 
-    # Return both a text summary and a JSON-like dict of scores per label
     return result
 
-# Updated Gradio Interface using new component syntax
 iface = gr.Interface(
     fn=analyze_messages,
     inputs=gr.Textbox(lines=10, placeholder="Enter message here..."),
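
For reference, the snippet below is a minimal standalone sketch of the two helpers this commit relocates below EXPLANATIONS, showing how they turn per-label sigmoid scores into the reported risk percentage. The three labels and the score values are illustrative stand-ins; the real app passes all 18 model outputs and the full THRESHOLDS dict.

import numpy as np

# Illustrative subset of the app's per-label thresholds (stand-ins, not the full dict).
THRESHOLDS = {"gaslighting": 0.25, "mockery": 0.15, "dismissiveness": 0.30}
LABELS = list(THRESHOLDS.keys())

def calculate_abuse_level(scores, thresholds):
    # Keep only the scores that clear their own threshold, then average them.
    triggered_scores = [score for label, score in zip(LABELS, scores) if score > thresholds[label]]
    if not triggered_scores:
        return 0.0
    return round(np.mean(triggered_scores) * 100, 2)

def interpret_abuse_level(score):
    if score > 80:
        return "Extreme / High Risk"
    elif score > 60:
        return "Severe / Harmful Pattern Present"
    elif score > 40:
        return "Likely Abuse"
    elif score > 20:
        return "Mild Concern"
    else:
        return "Very Low / Likely Safe"

# Made-up sigmoid outputs standing in for the classifier's per-label scores.
scores = [0.40, 0.10, 0.55]

# Only gaslighting (0.40 > 0.25) and dismissiveness (0.55 > 0.30) trigger,
# so the level is the mean of those two scores: 47.5%.
level = calculate_abuse_level(scores, THRESHOLDS)
print(f"{level}% – {interpret_abuse_level(level)}")  # 47.5% – Likely Abuse

Note that in analyze_messages a NEGATIVE sentiment result scales every threshold by 0.8 before the pattern and danger counts, but calculate_abuse_level is still called with the base THRESHOLDS, so the reported percentage itself does not change with sentiment.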