Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,17 +1,30 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
-
from transformers import
|
5 |
from transformers import RobertaForSequenceClassification, RobertaTokenizer
|
6 |
from motif_tagging import detect_motifs
|
7 |
import re
|
8 |
|
9 |
-
#
|
10 |
-
|
11 |
-
|
12 |
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
|
16 |
tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
17 |
|
@@ -35,7 +48,13 @@ THRESHOLDS = {
|
|
35 |
"threat": 0.25
|
36 |
}
|
37 |
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
EXPLANATIONS = {
|
41 |
"blame shifting": "Blame-shifting is when one person redirects responsibility onto someone else to avoid accountability.",
|
@@ -51,11 +70,6 @@ EXPLANATIONS = {
|
|
51 |
"threat": "Threats use fear of harm (physical, emotional, or relational) to control or intimidate someone."
|
52 |
}
|
53 |
|
54 |
-
PATTERN_WEIGHTS = {
|
55 |
-
"gaslighting": 1.3, "control": 1.2, "dismissiveness": 0.8, "blame shifting": 0.8,
|
56 |
-
"contradictory statements": 0.75
|
57 |
-
}
|
58 |
-
|
59 |
RISK_SNIPPETS = {
|
60 |
"low": (
|
61 |
"🟢 Risk Level: Low",
|
@@ -82,12 +96,13 @@ def generate_risk_snippet(abuse_score, top_label):
|
|
82 |
else:
|
83 |
risk_level = "low"
|
84 |
title, summary, advice = RISK_SNIPPETS[risk_level]
|
85 |
-
return f"\n\n{title}\n{summary} (Pattern:
|
86 |
|
87 |
-
# --- DARVO Detection
|
88 |
DARVO_PATTERNS = {
|
89 |
"blame shifting", "projection", "dismissiveness", "guilt tripping", "contradictory statements"
|
90 |
}
|
|
|
91 |
DARVO_MOTIFS = [
|
92 |
"i guess i’m the bad guy", "after everything i’ve done", "you always twist everything",
|
93 |
"so now it’s all my fault", "i’m the villain", "i’m always wrong", "you never listen",
|
@@ -125,17 +140,19 @@ def calculate_darvo_score(patterns, sentiment_before, sentiment_after, motifs_fo
|
|
125 |
)
|
126 |
return round(min(darvo_score, 1.0), 3)
|
127 |
|
|
|
128 |
def custom_sentiment(text):
|
129 |
-
|
130 |
with torch.no_grad():
|
131 |
-
outputs = sentiment_model(
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
return {"label": label_map[label_idx], "score": probs[0][label_idx].item()}
|
136 |
|
|
|
137 |
def calculate_abuse_level(scores, thresholds, motif_hits=None, flag_multiplier=1.0):
|
138 |
-
weighted_scores = [score * PATTERN_WEIGHTS.get(label, 1.0)
|
|
|
139 |
base_score = round(np.mean(weighted_scores) * 100, 2) if weighted_scores else 0.0
|
140 |
base_score *= flag_multiplier
|
141 |
return min(base_score, 100.0)
|
@@ -143,12 +160,12 @@ def calculate_abuse_level(scores, thresholds, motif_hits=None, flag_multiplier=1
|
|
143 |
def analyze_single_message(text, thresholds, motif_flags):
|
144 |
motif_hits, matched_phrases = detect_motifs(text)
|
145 |
sentiment = custom_sentiment(text)
|
146 |
-
sentiment_score =
|
147 |
-
|
148 |
-
# TEMP: print sentiment to console for debugging
|
149 |
-
print(f"Sentiment label: {sentiment['label']}, score: {sentiment['score']}")
|
150 |
|
151 |
-
adjusted_thresholds = {
|
|
|
|
|
152 |
|
153 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
154 |
with torch.no_grad():
|
@@ -160,13 +177,16 @@ def analyze_single_message(text, thresholds, motif_flags):
|
|
160 |
pattern_labels_used = list(set(threshold_labels + phrase_labels))
|
161 |
|
162 |
abuse_level = calculate_abuse_level(scores, adjusted_thresholds, motif_hits)
|
163 |
-
top_patterns = sorted([(label, score) for label, score in zip(LABELS, scores)],
|
|
|
|
|
164 |
motif_phrases = [text for _, text in matched_phrases]
|
165 |
contradiction_flag = detect_contradiction(text)
|
166 |
darvo_score = calculate_darvo_score(pattern_labels_used, 0.0, sentiment_score, motif_phrases, contradiction_flag)
|
167 |
|
168 |
return abuse_level, pattern_labels_used, top_patterns, darvo_score, sentiment
|
169 |
|
|
|
170 |
def analyze_composite(msg1, msg2, msg3, flags):
|
171 |
thresholds = THRESHOLDS
|
172 |
messages = [msg1, msg2, msg3]
|
@@ -180,15 +200,17 @@ def analyze_composite(msg1, msg2, msg3, flags):
|
|
180 |
print(f"Message: {m}")
|
181 |
print(f"Sentiment result: {result[4]}")
|
182 |
results.append(result)
|
|
|
183 |
abuse_scores = [r[0] for r in results]
|
184 |
darvo_scores = [r[3] for r in results]
|
185 |
average_darvo = round(sum(darvo_scores) / len(darvo_scores), 3)
|
186 |
base_score = sum(abuse_scores) / len(abuse_scores)
|
|
|
187 |
label_sets = [[label for label, _ in r[2]] for r in results]
|
188 |
label_counts = {label: sum(label in s for s in label_sets) for label in set().union(*label_sets)}
|
189 |
top_label = max(label_counts.items(), key=lambda x: x[1])
|
190 |
top_explanation = EXPLANATIONS.get(top_label[0], "")
|
191 |
-
|
192 |
flag_weights = {
|
193 |
"They've threatened harm": 6,
|
194 |
"They isolate me": 5,
|
@@ -196,6 +218,7 @@ def analyze_composite(msg1, msg2, msg3, flags):
|
|
196 |
"They monitor/follow me": 4,
|
197 |
"I feel unsafe when alone with them": 6
|
198 |
}
|
|
|
199 |
flag_boost = sum(flag_weights.get(f, 3) for f in flags) / len(active_messages)
|
200 |
composite_score = min(base_score + flag_boost, 100)
|
201 |
if len(active_messages) == 1:
|
@@ -203,6 +226,7 @@ def analyze_composite(msg1, msg2, msg3, flags):
|
|
203 |
elif len(active_messages) == 2:
|
204 |
composite_score *= 0.93
|
205 |
composite_score = round(min(composite_score, 100), 2)
|
|
|
206 |
result = f"These messages show a pattern of **{top_label[0]}** and are estimated to be {composite_score}% likely abusive."
|
207 |
if top_explanation:
|
208 |
result += f"\n• {top_explanation}"
|
@@ -212,6 +236,7 @@ def analyze_composite(msg1, msg2, msg3, flags):
|
|
212 |
result += generate_risk_snippet(composite_score, top_label[0])
|
213 |
return result
|
214 |
|
|
|
215 |
textbox_inputs = [
|
216 |
gr.Textbox(label="Message 1"),
|
217 |
gr.Textbox(label="Message 2"),
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
5 |
from transformers import RobertaForSequenceClassification, RobertaTokenizer
|
6 |
from motif_tagging import detect_motifs
|
7 |
import re
|
8 |
|
9 |
+
# --- Sentiment Model: T5-based Emotion Classifier ---
|
10 |
+
sentiment_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-emotion")
|
11 |
+
sentiment_model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-emotion")
|
12 |
|
13 |
+
EMOTION_TO_SENTIMENT = {
|
14 |
+
"joy": "supportive",
|
15 |
+
"love": "supportive",
|
16 |
+
"surprise": "supportive",
|
17 |
+
"neutral": "supportive",
|
18 |
+
"sadness": "undermining",
|
19 |
+
"anger": "undermining",
|
20 |
+
"fear": "undermining",
|
21 |
+
"disgust": "undermining",
|
22 |
+
"shame": "undermining",
|
23 |
+
"guilt": "undermining"
|
24 |
+
}
|
25 |
+
|
26 |
+
# --- Abuse Detection Model ---
|
27 |
+
model_name = "SamanthaStorm/autotrain-jlpi4-mllvp"
|
28 |
model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
|
29 |
tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
30 |
|
|
|
48 |
"threat": 0.25
|
49 |
}
|
50 |
|
51 |
+
PATTERN_WEIGHTS = {
|
52 |
+
"gaslighting": 1.3,
|
53 |
+
"control": 1.2,
|
54 |
+
"dismissiveness": 0.8,
|
55 |
+
"blame shifting": 0.8,
|
56 |
+
"contradictory statements": 0.75
|
57 |
+
}
|
58 |
|
59 |
EXPLANATIONS = {
|
60 |
"blame shifting": "Blame-shifting is when one person redirects responsibility onto someone else to avoid accountability.",
|
|
|
70 |
"threat": "Threats use fear of harm (physical, emotional, or relational) to control or intimidate someone."
|
71 |
}
|
72 |
|
|
|
|
|
|
|
|
|
|
|
73 |
RISK_SNIPPETS = {
|
74 |
"low": (
|
75 |
"🟢 Risk Level: Low",
|
|
|
96 |
else:
|
97 |
risk_level = "low"
|
98 |
title, summary, advice = RISK_SNIPPETS[risk_level]
|
99 |
+
return f"\n\n{title}\n{summary} (Pattern: {top_label})\n💡 {advice}"
|
100 |
|
101 |
+
# --- DARVO Detection ---
|
102 |
DARVO_PATTERNS = {
|
103 |
"blame shifting", "projection", "dismissiveness", "guilt tripping", "contradictory statements"
|
104 |
}
|
105 |
+
|
106 |
DARVO_MOTIFS = [
|
107 |
"i guess i’m the bad guy", "after everything i’ve done", "you always twist everything",
|
108 |
"so now it’s all my fault", "i’m the villain", "i’m always wrong", "you never listen",
|
|
|
140 |
)
|
141 |
return round(min(darvo_score, 1.0), 3)
|
142 |
|
143 |
+
# --- Sentiment Mapping ---
|
144 |
def custom_sentiment(text):
|
145 |
+
input_ids = sentiment_tokenizer(f"emotion: {text}", return_tensors="pt").input_ids
|
146 |
with torch.no_grad():
|
147 |
+
outputs = sentiment_model.generate(input_ids)
|
148 |
+
emotion = sentiment_tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower()
|
149 |
+
sentiment = EMOTION_TO_SENTIMENT.get(emotion, "undermining")
|
150 |
+
return {"label": sentiment, "emotion": emotion}
|
|
|
151 |
|
152 |
+
# --- Abuse Analysis Core ---
|
153 |
def calculate_abuse_level(scores, thresholds, motif_hits=None, flag_multiplier=1.0):
|
154 |
+
weighted_scores = [score * PATTERN_WEIGHTS.get(label, 1.0)
|
155 |
+
for label, score in zip(LABELS, scores) if score > thresholds[label]]
|
156 |
base_score = round(np.mean(weighted_scores) * 100, 2) if weighted_scores else 0.0
|
157 |
base_score *= flag_multiplier
|
158 |
return min(base_score, 100.0)
|
|
|
160 |
def analyze_single_message(text, thresholds, motif_flags):
|
161 |
motif_hits, matched_phrases = detect_motifs(text)
|
162 |
sentiment = custom_sentiment(text)
|
163 |
+
sentiment_score = 0.5 if sentiment["label"] == "undermining" else 0.0
|
164 |
+
print(f"Detected emotion: {sentiment['emotion']} → sentiment: {sentiment['label']}")
|
|
|
|
|
165 |
|
166 |
+
adjusted_thresholds = {
|
167 |
+
k: v * 0.8 for k, v in thresholds.items()
|
168 |
+
} if sentiment["label"] == "undermining" else thresholds.copy()
|
169 |
|
170 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
171 |
with torch.no_grad():
|
|
|
177 |
pattern_labels_used = list(set(threshold_labels + phrase_labels))
|
178 |
|
179 |
abuse_level = calculate_abuse_level(scores, adjusted_thresholds, motif_hits)
|
180 |
+
top_patterns = sorted([(label, score) for label, score in zip(LABELS, scores)],
|
181 |
+
key=lambda x: x[1], reverse=True)[:2]
|
182 |
+
|
183 |
motif_phrases = [text for _, text in matched_phrases]
|
184 |
contradiction_flag = detect_contradiction(text)
|
185 |
darvo_score = calculate_darvo_score(pattern_labels_used, 0.0, sentiment_score, motif_phrases, contradiction_flag)
|
186 |
|
187 |
return abuse_level, pattern_labels_used, top_patterns, darvo_score, sentiment
|
188 |
|
189 |
+
# --- Composite Message Analysis ---
|
190 |
def analyze_composite(msg1, msg2, msg3, flags):
|
191 |
thresholds = THRESHOLDS
|
192 |
messages = [msg1, msg2, msg3]
|
|
|
200 |
print(f"Message: {m}")
|
201 |
print(f"Sentiment result: {result[4]}")
|
202 |
results.append(result)
|
203 |
+
|
204 |
abuse_scores = [r[0] for r in results]
|
205 |
darvo_scores = [r[3] for r in results]
|
206 |
average_darvo = round(sum(darvo_scores) / len(darvo_scores), 3)
|
207 |
base_score = sum(abuse_scores) / len(abuse_scores)
|
208 |
+
|
209 |
label_sets = [[label for label, _ in r[2]] for r in results]
|
210 |
label_counts = {label: sum(label in s for s in label_sets) for label in set().union(*label_sets)}
|
211 |
top_label = max(label_counts.items(), key=lambda x: x[1])
|
212 |
top_explanation = EXPLANATIONS.get(top_label[0], "")
|
213 |
+
|
214 |
flag_weights = {
|
215 |
"They've threatened harm": 6,
|
216 |
"They isolate me": 5,
|
|
|
218 |
"They monitor/follow me": 4,
|
219 |
"I feel unsafe when alone with them": 6
|
220 |
}
|
221 |
+
|
222 |
flag_boost = sum(flag_weights.get(f, 3) for f in flags) / len(active_messages)
|
223 |
composite_score = min(base_score + flag_boost, 100)
|
224 |
if len(active_messages) == 1:
|
|
|
226 |
elif len(active_messages) == 2:
|
227 |
composite_score *= 0.93
|
228 |
composite_score = round(min(composite_score, 100), 2)
|
229 |
+
|
230 |
result = f"These messages show a pattern of **{top_label[0]}** and are estimated to be {composite_score}% likely abusive."
|
231 |
if top_explanation:
|
232 |
result += f"\n• {top_explanation}"
|
|
|
236 |
result += generate_risk_snippet(composite_score, top_label[0])
|
237 |
return result
|
238 |
|
239 |
+
# --- Gradio Interface ---
|
240 |
textbox_inputs = [
|
241 |
gr.Textbox(label="Message 1"),
|
242 |
gr.Textbox(label="Message 2"),
|