Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,830 Bytes
d6e219c f1948f2 67874ff 1ef0cee 9ea207c 67874ff 795371c a9d4250 9ea207c e185e86 1dbc865 e185e86 9ea207c e185e86 f32b7e3 9ea207c f32b7e3 9ea207c f32b7e3 9ea207c f32b7e3 e4d31a4 9ea207c e4d31a4 9ea207c e4d31a4 9ea207c 83bf881 9ea207c 83bf881 9ea207c 88862a6 9ea207c be824f1 9ea207c 88862a6 9ea207c a20692f 9ea207c 1bd01b0 5a21664 9ea207c 5a21664 9ea207c 1bd01b0 9ea207c 1bd01b0 5a21664 9ea207c 5a21664 9ea207c 5d031d6 9387da0 ea50b1d 57b9f91 eb8fc53 57b9f91 5a21664 9387da0 5a21664 9ea207c 5a21664 91451af c6f9529 e8a72bc c6f9529 e8a72bc 91451af 623e3b6 eab911b dee6092 21421a3 9ea207c 1ef0cee 9ea207c e5c8c09 3a65ac7 e5c8c09 2b8d8a7 1ef0cee 9ea207c 2b8d8a7 9ea207c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import gradio as gr
import torch
import numpy as np
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import matplotlib.pyplot as plt
import io
from PIL import Image
from datetime import datetime
# --- Load models ---
# Multi-label abuse-pattern classifier and its tokenizer.
model_name = "SamanthaStorm/tether-multilabel-v2"  # UPDATE if needed
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# "text-classification" and "sentiment-analysis" are two names for the same
# pipeline task, and both previously loaded this same checkpoint. Load it once
# and expose it under both names so the weights are held in memory only once.
_SST2_MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
sst_pipeline = pipeline("sentiment-analysis", model=_SST2_MODEL_NAME)
healthy_detector = sst_pipeline
# Pattern names, zipped position-by-position with the classifier's output
# probabilities in analyze_single_message.
# NOTE(review): assumes this order matches the model's output heads - confirm.
LABELS = [
    "blame shifting",
    "contradictory statements",
    "control",
    "dismissiveness",
    "gaslighting",
    "guilt tripping",
    "insults",
    "obscure language",
    "projection",
    "recovery phase",
    "threat",
]

# Per-label probability cut-off: a label counts as detected only when its
# probability is strictly greater than the value listed here.
THRESHOLDS = {
    "blame shifting": 0.3,
    "contradictory statements": 0.3,
    "control": 0.35,
    "dismissiveness": 0.4,
    "gaslighting": 0.3,
    "guilt tripping": 0.3,
    "insults": 0.3,
    "obscure language": 0.4,
    "projection": 0.4,
    "recovery phase": 0.35,
    "threat": 0.3,
}

# (checkbox label, weight) pairs. The weights of all ticked questions are
# summed into the escalation score by analyze_composite.
ESCALATION_QUESTIONS = [
    ("Partner has access to firearms or weapons", 4),
    ("Partner threatened to kill you", 3),
    ("Partner threatened you with a weapon", 3),
    ("Partner has ever choked you", 4),
    ("Partner injured or threatened your pet(s)", 3),
    ("Partner has broken your things, punched walls, or thrown objects", 2),
    ("Partner forced you into unwanted sexual acts", 3),
    ("Partner threatened to take away your children", 2),
    ("Violence has increased in frequency or severity", 3),
    ("Partner monitors your calls, GPS, or social media", 2),
]
# --- Functions ---
def is_healthy_message(text, threshold=0.9):
    """Tell whether the sentiment classifier confidently rates ``text`` positive.

    Args:
        text: Message to classify.
        threshold: Score the POSITIVE prediction must strictly exceed.

    Returns:
        A truthy value only when the top prediction from ``healthy_detector``
        is labelled "POSITIVE" and its score is greater than ``threshold``.
    """
    top_prediction = healthy_detector(text)[0]
    if top_prediction['label'] != "POSITIVE":
        return False
    return top_prediction['score'] > threshold
def generate_abuse_score_chart(dates, scores, labels):
    """Draw the per-message abuse scores as a line chart and return it as an image.

    Args:
        dates: One string per message, expected in ``YYYY-MM-DD`` form.
        scores: One abuse score per message, plotted on a 0-105 y-axis.
        labels: One caption per message, drawn above its data point.

    Returns:
        A PIL image holding the rendered PNG chart.
    """
    try:
        x_values = [datetime.strptime(d, "%Y-%m-%d") for d in dates]
    except (ValueError, TypeError):
        # Best effort: if any date is missing or malformed, plot every point
        # by its position instead of failing the whole chart.
        x_values = list(range(len(dates)))

    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.plot(x_values, scores, marker='o', linestyle='-', color='darkred', linewidth=2)
        for x, y, label in zip(x_values, scores, labels):
            ax.text(x, y + 2, f"{label}\n{int(y)}%", ha='center', fontsize=8, color='black')
        ax.set_title("Abuse Intensity Over Time")
        ax.set_xlabel("Date")
        ax.set_ylabel("Abuse Score (%)")
        ax.set_ylim(0, 105)
        ax.grid(True)
        fig.tight_layout()

        buf = io.BytesIO()
        # Save through the figure object rather than pyplot's implicit
        # "current figure", so the chart written is always this one.
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running app accumulates one figure per request.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def analyze_single_message(text):
    """Score a single message for abuse patterns.

    Messages that ``is_healthy_message`` accepts skip the abuse model and get a
    fixed result (score 0, no labels, stage 4). Everything else is run through
    the multi-label classifier and the sentiment pipeline.

    Args:
        text: The message to analyse.

    Returns:
        A dict with keys ``abuse_score`` (int percentage), ``labels`` (detected
        pattern names), ``sentiment`` ("supportive" or "undermining"),
        ``stage`` (int) and ``top_patterns`` (up to two ``(label, probability)``
        pairs, highest probability first).
    """
    if is_healthy_message(text):
        return {
            "abuse_score": 0,
            "labels": [],
            "sentiment": "supportive",
            "stage": 4,
            "top_patterns": [],
        }

    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        model_out = model(**encoded)
    # Independent sigmoid per label (multi-label), not a softmax over labels.
    # NOTE(review): assumes one logit per LABELS entry, in that order - confirm.
    probabilities = torch.sigmoid(model_out.logits.squeeze(0)).numpy()

    flagged = []
    flagged_total = 0
    for name, p in zip(LABELS, probabilities):
        if p > THRESHOLDS.get(name, 0.3):
            flagged.append(name)
            flagged_total += p
    # Detected probabilities are averaged over ALL labels, not just the
    # detected ones, then expressed as a percentage.
    abuse_score = (flagged_total / len(LABELS)) * 100

    sst_label = sst_pipeline(text)[0]['label']
    sentiment = "undermining" if sst_label != "POSITIVE" else "supportive"

    flagged_set = set(flagged)
    if flagged_set & {"threat", "insults"}:
        stage = 2  # Escalation
    elif flagged_set & {"control", "guilt tripping"}:
        stage = 1  # Tension building
    elif "recovery phase" in flagged_set:
        stage = 3  # Reconciliation
    else:
        stage = 1

    # Ranked over every label, whether or not it cleared its threshold.
    ranked = list(zip(LABELS, probabilities))
    ranked.sort(key=lambda pair: pair[1], reverse=True)

    return {
        "abuse_score": int(abuse_score),
        "labels": flagged,
        "sentiment": sentiment,
        "stage": stage,
        "top_patterns": ranked[:2],
    }
def _score_escalation(answers_and_none):
    """Convert the checklist ticks into ``(escalation_score, risk_level)``."""
    if not answers_and_none:
        return None, "unknown"

    *answers, none_of_the_above = answers_and_none
    if any(answers):
        score = sum(
            weight
            for (_, weight), ticked in zip(ESCALATION_QUESTIONS, answers)
            if ticked
        )
    elif none_of_the_above:
        # The user explicitly stated that no risk factor applies.
        score = 0
    else:
        # Nothing ticked at all, not even "None of the above": the checklist
        # was not filled in, so do not report a reassuring "Low".
        return None, "unknown"

    if score >= 16:
        risk_level = "High"
    elif score >= 8:
        risk_level = "Moderate"
    else:
        risk_level = "Low"
    return score, risk_level


def analyze_composite(msg1, date1, msg2, date2, msg3, date3, *answers_and_none):
    """Analyse up to three messages plus the escalation checklist.

    Args:
        msg1, msg2, msg3: Message texts; blank or ``None`` entries are ignored.
        date1, date2, date3: Optional ``YYYY-MM-DD`` date for each message.
        *answers_and_none: One boolean per entry of ``ESCALATION_QUESTIONS``,
            followed by the boolean of the "None of the above" checkbox.

    Returns:
        A ``(summary_text, timeline_image)`` pair. When no message was
        entered, the pair is the prompt text and ``None``, so both declared
        interface outputs still receive a value.
    """
    active = [
        (m, d)
        for m, d in zip((msg1, msg2, msg3), (date1, date2, date3))
        if m and m.strip()
    ]
    if not active:
        return "Please enter at least one message.", None

    escalation_score, risk_level = _score_escalation(answers_and_none)

    results = [(analyze_single_message(m), d) for m, d in active]
    abuse_scores = [analysis["abuse_score"] for analysis, _ in results]
    top_labels = [
        analysis["top_patterns"][0][0] if analysis["top_patterns"] else "None"
        for analysis, _ in results
    ]
    # Strip stray whitespace so an otherwise valid date still parses.
    dates_used = [(d or "").strip() or "Undated" for _, d in results]
    composite_abuse = int(round(sum(abuse_scores) / len(abuse_scores)))

    out = f"Abuse Intensity: {composite_abuse}%\n"
    if escalation_score is None:
        out += "Escalation Potential: Unknown (Checklist not completed)\n"
    else:
        max_score = sum(w for _, w in ESCALATION_QUESTIONS)
        out += f"Escalation Potential: {risk_level} ({escalation_score}/{max_score})\n"

    timeline_image = generate_abuse_score_chart(dates_used, abuse_scores, top_labels)
    return out, timeline_image
# --- Gradio Interface ---
# Three (message, date) textbox pairs, flattened below into
# msg1, date1, msg2, date2, msg3, date3 to match analyze_composite's
# positional parameters.
message_date_pairs = [
    (
        gr.Textbox(label=f"Message {i+1}"),
        gr.Textbox(label=f"Date {i+1} (optional)", placeholder="YYYY-MM-DD"),
    )
    for i in range(3)
]
textbox_inputs = [item for pair in message_date_pairs for item in pair]

# One checkbox per escalation question, in ESCALATION_QUESTIONS order.
quiz_boxes = [gr.Checkbox(label=q) for q, _ in ESCALATION_QUESTIONS]
none_box = gr.Checkbox(label="None of the above")

iface = gr.Interface(
    fn=analyze_composite,
    # Order matters: analyze_composite reads the LAST extra argument as the
    # "None of the above" box, so none_box must stay at the end.
    inputs=textbox_inputs + quiz_boxes + [none_box],
    outputs=[
        gr.Textbox(label="Results"),
        gr.Image(label="Risk Stage Timeline", type="pil"),
    ],
    title="Tether Abuse Pattern Detector v2",
    # NOTE(review): newer Gradio releases rename this option to
    # `flagging_mode` - confirm against the pinned Gradio version.
    allow_flagging="manual",
)

if __name__ == "__main__":
    iface.launch()