SamanthaStorm committed on
Commit 2f6ac5d · verified · 1 Parent(s): 0e53c22

Update app.py

Files changed (1)
  app.py +53 -28
app.py CHANGED
@@ -1,17 +1,30 @@
 import gradio as gr
 import torch
 import numpy as np
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from transformers import RobertaForSequenceClassification, RobertaTokenizer
 from motif_tagging import detect_motifs
 import re
 
-# custom fine-tuned sentiment model
-sentiment_model = AutoModelForSequenceClassification.from_pretrained("SamanthaStorm/tether-sentiment")
-sentiment_tokenizer = AutoTokenizer.from_pretrained("SamanthaStorm/tether-sentiment")
+# --- Sentiment Model: T5-based Emotion Classifier ---
+sentiment_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-emotion")
+sentiment_model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-emotion")
 
-# Load abuse pattern model
-model_name ="SamanthaStorm/autotrain-jlpi4-mllvp"
+EMOTION_TO_SENTIMENT = {
+    "joy": "supportive",
+    "love": "supportive",
+    "surprise": "supportive",
+    "neutral": "supportive",
+    "sadness": "undermining",
+    "anger": "undermining",
+    "fear": "undermining",
+    "disgust": "undermining",
+    "shame": "undermining",
+    "guilt": "undermining"
+}
+
+# --- Abuse Detection Model ---
+model_name = "SamanthaStorm/autotrain-jlpi4-mllvp"
 model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
 tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
@@ -35,7 +48,13 @@ THRESHOLDS = {
     "threat": 0.25
 }
 
-PATTERN_LABELS = LABELS
+PATTERN_WEIGHTS = {
+    "gaslighting": 1.3,
+    "control": 1.2,
+    "dismissiveness": 0.8,
+    "blame shifting": 0.8,
+    "contradictory statements": 0.75
+}
 
 EXPLANATIONS = {
     "blame shifting": "Blame-shifting is when one person redirects responsibility onto someone else to avoid accountability.",
@@ -51,11 +70,6 @@ EXPLANATIONS = {
     "threat": "Threats use fear of harm (physical, emotional, or relational) to control or intimidate someone."
 }
 
-PATTERN_WEIGHTS = {
-    "gaslighting": 1.3, "control": 1.2, "dismissiveness": 0.8, "blame shifting": 0.8,
-    "contradictory statements": 0.75
-}
-
 RISK_SNIPPETS = {
     "low": (
         "🟢 Risk Level: Low",
@@ -82,12 +96,13 @@ def generate_risk_snippet(abuse_score, top_label):
     else:
         risk_level = "low"
     title, summary, advice = RISK_SNIPPETS[risk_level]
-    return f"\n\n{title}\n{summary} (Pattern: **{top_label}**)\n💡 {advice}"
+    return f"\n\n{title}\n{summary} (Pattern: {top_label})\n💡 {advice}"
 
-# --- DARVO Detection Tools ---
+# --- DARVO Detection ---
 DARVO_PATTERNS = {
     "blame shifting", "projection", "dismissiveness", "guilt tripping", "contradictory statements"
 }
+
 DARVO_MOTIFS = [
     "i guess i’m the bad guy", "after everything i’ve done", "you always twist everything",
     "so now it’s all my fault", "i’m the villain", "i’m always wrong", "you never listen",
@@ -125,17 +140,19 @@ def calculate_darvo_score(patterns, sentiment_before, sentiment_after, motifs_fo
     )
     return round(min(darvo_score, 1.0), 3)
 
+# --- Sentiment Mapping ---
 def custom_sentiment(text):
-    inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+    input_ids = sentiment_tokenizer(f"emotion: {text}", return_tensors="pt").input_ids
     with torch.no_grad():
-        outputs = sentiment_model(**inputs)
-    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
-    label_idx = torch.argmax(probs).item()
-    label_map = {0: "supportive", 1: "undermining"}
-    return {"label": label_map[label_idx], "score": probs[0][label_idx].item()}
+        outputs = sentiment_model.generate(input_ids)
+    emotion = sentiment_tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower()
+    sentiment = EMOTION_TO_SENTIMENT.get(emotion, "undermining")
+    return {"label": sentiment, "emotion": emotion}
 
+# --- Abuse Analysis Core ---
 def calculate_abuse_level(scores, thresholds, motif_hits=None, flag_multiplier=1.0):
-    weighted_scores = [score * PATTERN_WEIGHTS.get(label, 1.0) for label, score in zip(LABELS, scores) if score > thresholds[label]]
+    weighted_scores = [score * PATTERN_WEIGHTS.get(label, 1.0)
+                       for label, score in zip(LABELS, scores) if score > thresholds[label]]
     base_score = round(np.mean(weighted_scores) * 100, 2) if weighted_scores else 0.0
     base_score *= flag_multiplier
     return min(base_score, 100.0)
@@ -143,12 +160,12 @@ def calculate_abuse_level(scores, thresholds, motif_hits=None, flag_multiplier=1
 def analyze_single_message(text, thresholds, motif_flags):
     motif_hits, matched_phrases = detect_motifs(text)
     sentiment = custom_sentiment(text)
-    sentiment_score = sentiment["score"] if sentiment["label"] == "undermining" else 0.0
-
-    # TEMP: print sentiment to console for debugging
-    print(f"Sentiment label: {sentiment['label']}, score: {sentiment['score']}")
+    sentiment_score = 0.5 if sentiment["label"] == "undermining" else 0.0
+    print(f"Detected emotion: {sentiment['emotion']} → sentiment: {sentiment['label']}")
 
-    adjusted_thresholds = {k: v * 0.8 for k, v in thresholds.items()} if sentiment['label'] == "undermining" else thresholds.copy()
+    adjusted_thresholds = {
+        k: v * 0.8 for k, v in thresholds.items()
+    } if sentiment["label"] == "undermining" else thresholds.copy()
 
     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
     with torch.no_grad():
@@ -160,13 +177,16 @@ def analyze_single_message(text, thresholds, motif_flags):
     pattern_labels_used = list(set(threshold_labels + phrase_labels))
 
     abuse_level = calculate_abuse_level(scores, adjusted_thresholds, motif_hits)
-    top_patterns = sorted([(label, score) for label, score in zip(LABELS, scores)], key=lambda x: x[1], reverse=True)[:2]
+    top_patterns = sorted([(label, score) for label, score in zip(LABELS, scores)],
+                          key=lambda x: x[1], reverse=True)[:2]
+
     motif_phrases = [text for _, text in matched_phrases]
     contradiction_flag = detect_contradiction(text)
     darvo_score = calculate_darvo_score(pattern_labels_used, 0.0, sentiment_score, motif_phrases, contradiction_flag)
 
     return abuse_level, pattern_labels_used, top_patterns, darvo_score, sentiment
 
+# --- Composite Message Analysis ---
 def analyze_composite(msg1, msg2, msg3, flags):
     thresholds = THRESHOLDS
     messages = [msg1, msg2, msg3]
@@ -180,15 +200,17 @@ def analyze_composite(msg1, msg2, msg3, flags):
         print(f"Message: {m}")
         print(f"Sentiment result: {result[4]}")
         results.append(result)
+
     abuse_scores = [r[0] for r in results]
     darvo_scores = [r[3] for r in results]
     average_darvo = round(sum(darvo_scores) / len(darvo_scores), 3)
     base_score = sum(abuse_scores) / len(abuse_scores)
+
     label_sets = [[label for label, _ in r[2]] for r in results]
     label_counts = {label: sum(label in s for s in label_sets) for label in set().union(*label_sets)}
     top_label = max(label_counts.items(), key=lambda x: x[1])
     top_explanation = EXPLANATIONS.get(top_label[0], "")
-    danger_weight = 5
+
     flag_weights = {
         "They've threatened harm": 6,
         "They isolate me": 5,
@@ -196,6 +218,7 @@ def analyze_composite(msg1, msg2, msg3, flags):
         "They monitor/follow me": 4,
         "I feel unsafe when alone with them": 6
     }
+
     flag_boost = sum(flag_weights.get(f, 3) for f in flags) / len(active_messages)
     composite_score = min(base_score + flag_boost, 100)
     if len(active_messages) == 1:
@@ -203,6 +226,7 @@ def analyze_composite(msg1, msg2, msg3, flags):
     elif len(active_messages) == 2:
         composite_score *= 0.93
     composite_score = round(min(composite_score, 100), 2)
+
     result = f"These messages show a pattern of **{top_label[0]}** and are estimated to be {composite_score}% likely abusive."
    if top_explanation:
         result += f"\n• {top_explanation}"
@@ -212,6 +236,7 @@ def analyze_composite(msg1, msg2, msg3, flags):
     result += generate_risk_snippet(composite_score, top_label[0])
     return result
 
+# --- Gradio Interface ---
 textbox_inputs = [
     gr.Textbox(label="Message 1"),
     gr.Textbox(label="Message 2"),
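For reviewers, a minimal standalone sketch of the sentiment path this commit introduces. The model name and the EMOTION_TO_SENTIMENT mapping are taken from the diff above; the __main__ harness and the sample output are illustrative assumptions, not part of app.py.

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Model and mapping as in the diff above.
sentiment_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-emotion")
sentiment_model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-emotion")

EMOTION_TO_SENTIMENT = {
    "joy": "supportive", "love": "supportive", "surprise": "supportive", "neutral": "supportive",
    "sadness": "undermining", "anger": "undermining", "fear": "undermining",
    "disgust": "undermining", "shame": "undermining", "guilt": "undermining",
}

def custom_sentiment(text):
    # T5 is prompted with an "emotion:" prefix and generates an emotion word,
    # which is then collapsed to the binary supportive/undermining label.
    input_ids = sentiment_tokenizer(f"emotion: {text}", return_tensors="pt").input_ids
    with torch.no_grad():
        outputs = sentiment_model.generate(input_ids)
    emotion = sentiment_tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower()
    # Emotions missing from the mapping fall back to "undermining", the conservative default in the diff.
    return {"label": EMOTION_TO_SENTIMENT.get(emotion, "undermining"), "emotion": emotion}

if __name__ == "__main__":
    # Hypothetical smoke test; the actual emotion returned depends on the model.
    print(custom_sentiment("You always twist everything I say."))

Note that because the seq2seq classifier emits a word rather than a probability, the diff also replaces the old softmax confidence with a flat sentiment_score of 0.5 for "undermining" messages before feeding the DARVO calculation.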