Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,55 +1,23 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
-
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
5 |
-
from transformers import RobertaForSequenceClassification, RobertaTokenizer
|
6 |
from motif_tagging import detect_motifs
|
7 |
-
from abuse_type_mapping import determine_abuse_type
|
8 |
|
9 |
-
#
|
10 |
sentiment_model = AutoModelForSequenceClassification.from_pretrained("SamanthaStorm/tether-sentiment")
|
11 |
sentiment_tokenizer = AutoTokenizer.from_pretrained("SamanthaStorm/tether-sentiment")
|
12 |
-
|
13 |
-
# Load abuse pattern model
|
14 |
model_name = "SamanthaStorm/autotrain-c1un8-p8vzo"
|
15 |
model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
|
16 |
tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
17 |
|
18 |
-
LABELS = [
|
19 |
-
|
20 |
-
"contradictory_statements", "manipulation", "deflection", "insults", "obscure_formal", "recovery_phase", "non_abusive",
|
21 |
-
"suicidal_threat", "physical_threat", "extreme_control"
|
22 |
-
]
|
23 |
-
|
24 |
-
THRESHOLDS = {
|
25 |
-
"gaslighting": 0.25, "mockery": 0.15, "dismissiveness": 0.45, "control": 0.43, "guilt_tripping": 0.15,
|
26 |
-
"apology_baiting": 0.2, "blame_shifting": 0.23, "projection": 0.50, "contradictory_statements": 0.25,
|
27 |
-
"manipulation": 0.25, "deflection": 0.30, "insults": 0.34, "obscure_formal": 0.25, "recovery_phase": 0.25,
|
28 |
-
"non_abusive": 2.0, "suicidal_threat": 0.45, "physical_threat": 0.02, "extreme_control": 0.30
|
29 |
-
}
|
30 |
-
|
31 |
PATTERN_LABELS = LABELS[:15]
|
32 |
DANGER_LABELS = LABELS[15:18]
|
|
|
|
|
33 |
|
34 |
-
EXPLANATIONS = {
|
35 |
-
"gaslighting": "Gaslighting involves making someone question their own reality or perceptions...",
|
36 |
-
"blame_shifting": "Blame-shifting is when one person redirects the responsibility...",
|
37 |
-
"projection": "Projection involves accusing the victim of behaviors the abuser exhibits.",
|
38 |
-
"dismissiveness": "Dismissiveness is belittling or disregarding another person’s feelings.",
|
39 |
-
"mockery": "Mockery ridicules someone in a hurtful, humiliating way.",
|
40 |
-
"recovery_phase": "Recovery phase dismisses someone's emotional healing process.",
|
41 |
-
"insults": "Insults are derogatory remarks aimed at degrading someone.",
|
42 |
-
"apology_baiting": "Apology-baiting manipulates victims into apologizing for abuser's behavior.",
|
43 |
-
"deflection": "Deflection avoids accountability by redirecting blame.",
|
44 |
-
"control": "Control restricts autonomy through manipulation or coercion.",
|
45 |
-
"extreme_control": "Extreme control dominates decisions and behaviors entirely.",
|
46 |
-
"physical_threat": "Physical threats signal risk of bodily harm.",
|
47 |
-
"suicidal_threat": "Suicidal threats manipulate others using self-harm threats.",
|
48 |
-
"guilt_tripping": "Guilt-tripping uses guilt to manipulate someone’s actions.",
|
49 |
-
"manipulation": "Manipulation deceives to influence or control outcomes.",
|
50 |
-
"non_abusive": "Non-abusive language is respectful and free of coercion.",
|
51 |
-
"obscure_formal": "Obscure/formal language manipulates through confusion or superiority."
|
52 |
-
}
|
53 |
|
54 |
def custom_sentiment(text):
|
55 |
inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
@@ -57,21 +25,11 @@ def custom_sentiment(text):
|
|
57 |
outputs = sentiment_model(**inputs)
|
58 |
probs = torch.nn.functional.softmax(outputs.logits, dim=1)
|
59 |
label_idx = torch.argmax(probs).item()
|
60 |
-
|
61 |
label_map = {0: "supportive", 1: "undermining"}
|
62 |
label = label_map[label_idx]
|
63 |
score = probs[0][label_idx].item()
|
64 |
return {"label": label, "score": score}
|
65 |
|
66 |
-
PATTERN_WEIGHTS = {
|
67 |
-
"physical_threat": 1.5,
|
68 |
-
"suicidal_threat": 1.4,
|
69 |
-
"extreme_control": 1.5,
|
70 |
-
"gaslighting": 1.3,
|
71 |
-
"control": 1.2,
|
72 |
-
"dismissiveness": 0.8,
|
73 |
-
"non_abusive": 0.0 # shouldn't contribute to abuse score
|
74 |
-
}
|
75 |
|
76 |
def calculate_abuse_level(scores, thresholds, motif_hits=None):
|
77 |
weighted_scores = []
|
@@ -80,13 +38,12 @@ def calculate_abuse_level(scores, thresholds, motif_hits=None):
|
|
80 |
weight = PATTERN_WEIGHTS.get(label, 1.0)
|
81 |
weighted_scores.append(score * weight)
|
82 |
base_score = round(np.mean(weighted_scores) * 100, 2) if weighted_scores else 0.0
|
83 |
-
|
84 |
motif_hits = motif_hits or []
|
85 |
-
if any(label in motif_hits for label in
|
86 |
base_score = max(base_score, 75.0)
|
87 |
-
|
88 |
return base_score
|
89 |
|
|
|
90 |
def interpret_abuse_level(score):
|
91 |
if score > 80:
|
92 |
return "Extreme / High Risk"
|
@@ -98,124 +55,50 @@ def interpret_abuse_level(score):
|
|
98 |
return "Mild Concern"
|
99 |
return "Very Low / Likely Safe"
|
100 |
|
101 |
-
def analyze_messages(input_text, risk_flags):
|
102 |
-
input_text = input_text.strip()
|
103 |
-
if not input_text:
|
104 |
-
return "Please enter a message for analysis."
|
105 |
-
|
106 |
-
# Normalize the text (example: lower case)
|
107 |
-
normalized_text = input_text.strip().lower()
|
108 |
-
|
109 |
-
motif_flags, matched_phrases = detect_motifs(input_text)
|
110 |
-
risk_flags = list(set(risk_flags + motif_flags)) if risk_flags else motif_flags
|
111 |
-
|
112 |
-
sentiment = custom_sentiment(input_text)
|
113 |
-
sentiment_label = sentiment['label']
|
114 |
-
sentiment_score = sentiment['score']
|
115 |
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
119 |
with torch.no_grad():
|
120 |
outputs = model(**inputs)
|
121 |
scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
|
|
|
|
|
|
|
122 |
|
123 |
-
threshold_labels = [label for label, score in zip(PATTERN_LABELS, scores[:15]) if score > adjusted_thresholds[label]]
|
124 |
-
phrase_labels = [label for label, _ in matched_phrases]
|
125 |
-
pattern_labels_used = list(set(threshold_labels + phrase_labels))
|
126 |
-
|
127 |
-
contextual_flags = risk_flags if risk_flags else []
|
128 |
-
# Note: If there are two or more contextual flags, you might wish to adjust a danger counter
|
129 |
-
# danger_flag_count += 1 <-- Ensure that danger_flag_count is defined before incrementing.
|
130 |
-
abuse_level = calculate_abuse_level(scores, adjusted_thresholds, motif_hits=[label for label, _ in matched_phrases])
|
131 |
-
abuse_description = interpret_abuse_level(abuse_level)
|
132 |
-
|
133 |
-
# Escalate risk if user checks a critical context box
|
134 |
-
if contextual_flags and abuse_level < 15:
|
135 |
-
abuse_level = 15 # bump to at least Mild Concern
|
136 |
-
|
137 |
-
abuse_type, abuser_profile, advice = determine_abuse_type(pattern_labels_used)
|
138 |
-
|
139 |
-
danger_flag_count = sum(score > adjusted_thresholds[label] for label, score in zip(DANGER_LABELS, scores[15:18]))
|
140 |
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
if non_abusive_confident and danger_flag_count == 0 and not matched_phrases:
|
148 |
-
return "This message is classified as non-abusive."
|
149 |
-
# Supportive override logic
|
150 |
-
if (
|
151 |
-
sentiment_label == "supportive"
|
152 |
-
and sentiment_score > 0.95
|
153 |
-
and non_abusive_confident
|
154 |
-
and danger_flag_count == 0
|
155 |
-
and not matched_phrases
|
156 |
-
):
|
157 |
-
return "This message is classified as non-abusive. It appears emotionally supportive and safe."
|
158 |
-
|
159 |
-
scored_patterns = [
|
160 |
-
(label, score) for label, score in zip(PATTERN_LABELS, scores[:15]) if label != "non_abusive"
|
161 |
-
]
|
162 |
-
|
163 |
-
override_labels = {"physical_threat", "suicidal_threat", "extreme_control"}
|
164 |
-
override_matches = [label for label, _ in matched_phrases if label in override_labels]
|
165 |
-
|
166 |
-
if override_matches:
|
167 |
-
top_patterns = [(label, 1.0) for label in override_matches]
|
168 |
-
else:
|
169 |
-
top_patterns = sorted(scored_patterns, key=lambda x: x[1], reverse=True)[:2]
|
170 |
-
|
171 |
-
top_pattern_explanations = "\n".join([
|
172 |
-
f"• {label.replace('_', ' ').title()}: {EXPLANATIONS.get(label, 'No explanation available.')}"
|
173 |
-
for label, _ in top_patterns
|
174 |
])
|
|
|
|
|
|
|
175 |
|
176 |
-
resources = "Immediate assistance recommended. Please seek professional help or contact emergency services." if danger_flag_count >= 2 else "For more information on abuse patterns, consider reaching out to support groups or professional counselors."
|
177 |
-
|
178 |
-
result = f"Abuse Risk Score: {abuse_level}% – {abuse_description}\n\n"
|
179 |
-
if abuse_level >= 15:
|
180 |
-
result += f"Most Likely Patterns:\n{top_pattern_explanations}\n\n"
|
181 |
-
result += f"⚠️ Critical Danger Flags Detected: {danger_flag_count} of 3\n"
|
182 |
-
result += f"Resources: {resources}\n"
|
183 |
-
result += f"🧠 Sentiment: {sentiment_label.title()} (Confidence: {sentiment_score*100:.2f}%)\n"
|
184 |
|
185 |
-
if contextual_flags:
|
186 |
-
result += "\n\n⚠️ You indicated the following:\n" + "\n".join([f"• {flag.replace('_', ' ').title()}" for flag in contextual_flags])
|
187 |
-
|
188 |
-
if high_risk_context:
|
189 |
-
result += "\n\n🚨 These responses suggest a high-risk situation. Consider seeking immediate help or safety planning resources."
|
190 |
-
|
191 |
-
if matched_phrases:
|
192 |
-
result += "\n\n🚨 Detected High-Risk Phrases:\n"
|
193 |
-
for label, phrase in matched_phrases:
|
194 |
-
phrase_clean = phrase.replace('"', "'").strip()
|
195 |
-
result += f"• {label.replace('_', ' ').title()}: “{phrase_clean}”\n"
|
196 |
-
|
197 |
-
if abuse_type:
|
198 |
-
result += f"\n\n🧠 Likely Abuse Type: {abuse_type}"
|
199 |
-
result += f"\n🧠 Abuser Profile: {abuser_profile}"
|
200 |
-
result += f"\n📘 Safety Tip: {advice}"
|
201 |
-
|
202 |
-
return result
|
203 |
-
|
204 |
-
# Updated Interface: Added flagging functionality to allow users to flag mispredictions.
|
205 |
iface = gr.Interface(
|
206 |
-
fn=
|
207 |
inputs=[
|
208 |
-
gr.Textbox(
|
209 |
-
gr.
|
210 |
-
|
211 |
-
|
212 |
-
|
|
|
|
|
|
|
|
|
213 |
],
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
allow_flagging="manual" # This enables the manual flagging button for user feedback.
|
218 |
)
|
219 |
|
220 |
if __name__ == "__main__":
|
221 |
-
iface.
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, RobertaForSequenceClassification, RobertaTokenizer
|
|
|
5 |
from motif_tagging import detect_motifs
|
|
|
6 |
|
7 |
+
# Load models
|
8 |
sentiment_model = AutoModelForSequenceClassification.from_pretrained("SamanthaStorm/tether-sentiment")
|
9 |
sentiment_tokenizer = AutoTokenizer.from_pretrained("SamanthaStorm/tether-sentiment")
|
|
|
|
|
10 |
model_name = "SamanthaStorm/autotrain-c1un8-p8vzo"
|
11 |
model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
|
12 |
tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
13 |
|
14 |
+
LABELS = [...]
|
15 |
+
THRESHOLDS = {...}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
PATTERN_LABELS = LABELS[:15]
|
17 |
DANGER_LABELS = LABELS[15:18]
|
18 |
+
EXPLANATIONS = {...}
|
19 |
+
PATTERN_WEIGHTS = {...}
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
def custom_sentiment(text):
|
23 |
inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
|
|
25 |
outputs = sentiment_model(**inputs)
|
26 |
probs = torch.nn.functional.softmax(outputs.logits, dim=1)
|
27 |
label_idx = torch.argmax(probs).item()
|
|
|
28 |
label_map = {0: "supportive", 1: "undermining"}
|
29 |
label = label_map[label_idx]
|
30 |
score = probs[0][label_idx].item()
|
31 |
return {"label": label, "score": score}
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
def calculate_abuse_level(scores, thresholds, motif_hits=None):
|
35 |
weighted_scores = []
|
|
|
38 |
weight = PATTERN_WEIGHTS.get(label, 1.0)
|
39 |
weighted_scores.append(score * weight)
|
40 |
base_score = round(np.mean(weighted_scores) * 100, 2) if weighted_scores else 0.0
|
|
|
41 |
motif_hits = motif_hits or []
|
42 |
+
if any(label in motif_hits for label in DANGER_LABELS):
|
43 |
base_score = max(base_score, 75.0)
|
|
|
44 |
return base_score
|
45 |
|
46 |
+
|
47 |
def interpret_abuse_level(score):
|
48 |
if score > 80:
|
49 |
return "Extreme / High Risk"
|
|
|
55 |
return "Mild Concern"
|
56 |
return "Very Low / Likely Safe"
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
+
def analyze_single_message(text):
|
60 |
+
if not text.strip():
|
61 |
+
return "No input provided."
|
62 |
+
sentiment = custom_sentiment(text)
|
63 |
+
thresholds = {k: v * 0.8 for k, v in THRESHOLDS.items()} if sentiment['label'] == "undermining" else THRESHOLDS.copy()
|
64 |
+
motif_flags, matched_phrases = detect_motifs(text)
|
65 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
66 |
with torch.no_grad():
|
67 |
outputs = model(**inputs)
|
68 |
scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
|
69 |
+
abuse_score = calculate_abuse_level(scores, thresholds, [label for label, _ in matched_phrases])
|
70 |
+
summary = interpret_abuse_level(abuse_score)
|
71 |
+
return f"Abuse Risk Score: {abuse_score}% — {summary}\nSentiment: {sentiment['label']} ({sentiment['score']*100:.2f}%)"
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
+
def analyze_composite(msg1, msg2, msg3):
|
75 |
+
results = [analyze_single_message(t) for t in [msg1, msg2, msg3]]
|
76 |
+
composite_score = np.mean([
|
77 |
+
float(line.split('%')[0].split()[-1]) if 'Abuse Risk Score:' in line else 0
|
78 |
+
for line in results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
])
|
80 |
+
final_summary = interpret_abuse_level(composite_score)
|
81 |
+
composite_result = f"\n\nComposite Abuse Risk Score: {composite_score:.2f}% — {final_summary}"
|
82 |
+
return results[0], results[1], results[2], composite_result
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
iface = gr.Interface(
|
86 |
+
fn=analyze_composite,
|
87 |
inputs=[
|
88 |
+
gr.Textbox(label="Message 1"),
|
89 |
+
gr.Textbox(label="Message 2"),
|
90 |
+
gr.Textbox(label="Message 3")
|
91 |
+
],
|
92 |
+
outputs=[
|
93 |
+
gr.Textbox(label="Message 1 Result"),
|
94 |
+
gr.Textbox(label="Message 2 Result"),
|
95 |
+
gr.Textbox(label="Message 3 Result"),
|
96 |
+
gr.Textbox(label="Composite Score Summary")
|
97 |
],
|
98 |
+
title="Abuse Pattern Detector (Multi-Message)",
|
99 |
+
live=False,
|
100 |
+
allow_flagging="manual"
|
|
|
101 |
)
|
102 |
|
103 |
if __name__ == "__main__":
|
104 |
+
iface.launch()
|