# Hugging Face Spaces: Running on Zero
import gradio as gr | |
import torch | |
import numpy as np | |
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer | |
import matplotlib.pyplot as plt | |
import io | |
from PIL import Image | |
from datetime import datetime | |
# --- Load models ---
# Multi-label abuse-pattern classifier. analyze_single_message applies a sigmoid
# to its logits and pairs them positionally with LABELS.
model_name = "SamanthaStorm/tether-multilabel-v2"  # UPDATE if needed
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# Binary POSITIVE/NEGATIVE sentiment model. The "healthy message" gate and the
# per-message sentiment label both use this same checkpoint, so it is loaded once
# and shared rather than keeping two copies of the weights in memory.
# ("sentiment-analysis" is an alias of the "text-classification" pipeline task.)
healthy_detector = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
sst_pipeline = healthy_detector
# Abuse-pattern label names. They are paired positionally with the classifier's
# outputs (see analyze_single_message), so this order must match the model head.
LABELS = [
    "blame shifting",
    "contradictory statements",
    "control",
    "dismissiveness",
    "gaslighting",
    "guilt tripping",
    "insults",
    "obscure language",
    "projection",
    "recovery phase",
    "threat",
]

# Per-label cut-offs: a label counts as detected when its sigmoid probability is
# strictly greater than the value listed here.
THRESHOLDS = {
    "blame shifting": 0.3,
    "contradictory statements": 0.3,
    "control": 0.35,
    "dismissiveness": 0.4,
    "gaslighting": 0.3,
    "guilt tripping": 0.3,
    "insults": 0.3,
    "obscure language": 0.4,
    "projection": 0.4,
    "recovery phase": 0.35,
    "threat": 0.3,
}
# Escalation checklist as (question text, weight) pairs. Each question becomes a
# UI checkbox, in this order; analyze_composite sums the weights of the ticked
# questions (the maximum possible total is 29).
ESCALATION_QUESTIONS = [
    ("Partner has access to firearms or weapons", 4),
    ("Partner threatened to kill you", 3),
    ("Partner threatened you with a weapon", 3),
    ("Partner has ever choked you", 4),
    ("Partner injured or threatened your pet(s)", 3),
    ("Partner has broken your things, punched walls, or thrown objects", 2),
    ("Partner forced you into unwanted sexual acts", 3),
    ("Partner threatened to take away your children", 2),
    ("Violence has increased in frequency or severity", 3),
    ("Partner monitors your calls, GPS, or social media", 2),
]
# --- Functions --- | |
def is_healthy_message(text, threshold=0.9):
    """Return True when the sentiment model confidently rates *text* as positive.

    Args:
        text: The message to score.
        threshold: The POSITIVE score must be strictly greater than this value.

    Returns:
        bool: True if the top label is "POSITIVE" and its score exceeds *threshold*.
    """
    # Without truncation, inputs longer than the model's maximum sequence length
    # raise inside the pipeline instead of being scored.
    result = healthy_detector(text, truncation=True)[0]
    return result["label"] == "POSITIVE" and result["score"] > threshold
def generate_abuse_score_chart(dates, scores, labels):
    """Render abuse scores over time as a line chart and return it as a PIL image.

    Args:
        dates: One entry per point. If every entry parses as ``YYYY-MM-DD`` the
            points are placed on a date axis. Otherwise they are placed at
            positions 0..n-1, with the raw entries used as tick labels.
        scores: Abuse scores in percent, one per point (the y axis spans 0-105).
        labels: Annotation text drawn above each point.

    Returns:
        PIL.Image.Image: The rendered chart as a PNG-backed image.
    """
    try:
        x_values = [datetime.strptime(d, "%Y-%m-%d") for d in dates]
        categorical = False
    except (TypeError, ValueError):
        # Any missing or unparseable date (e.g. "Undated") falls back to positions.
        x_values = list(range(len(dates)))
        categorical = True

    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.plot(x_values, scores, marker="o", linestyle="-", color="darkred", linewidth=2)
        for x, y, label in zip(x_values, scores, labels):
            ax.text(x, y + 2, f"{label}\n{int(y)}%", ha="center", fontsize=8, color="black")
        if categorical:
            # Show the supplied date strings instead of bare numeric positions.
            ax.set_xticks(x_values)
            ax.set_xticklabels([str(d) for d in dates])
        ax.set_title("Abuse Intensity Over Time")
        ax.set_xlabel("Date")
        ax.set_ylabel("Abuse Score (%)")
        ax.set_ylim(0, 105)
        ax.grid(True)
        # Operate on this figure explicitly rather than pyplot's global "current"
        # figure, which concurrent requests may change.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps every figure alive until closed, so without this a
        # long-running server leaks one figure per request.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def _cycle_stage(detected_labels):
    """Map detected labels to a stage: 1 tension building, 2 escalation, 3 reconciliation."""
    if "threat" in detected_labels or "insults" in detected_labels:
        return 2  # Escalation
    if "control" in detected_labels or "guilt tripping" in detected_labels:
        return 1  # Tension building
    if "recovery phase" in detected_labels:
        return 3  # Reconciliation
    return 1  # No stage-specific pattern detected: default to tension building


def analyze_single_message(text):
    """Score one message for abuse patterns.

    Args:
        text: The message to analyze.

    Returns:
        dict with keys:
            "abuse_score": int; the summed probability of the detected labels,
                divided by the total number of LABELS, as a percentage.
            "labels": names of labels whose probability exceeds THRESHOLDS.
            "sentiment": "supportive" or "undermining".
            "stage": int stage from _cycle_stage; 4 is returned for messages
                that pass the healthy-message gate.
            "top_patterns": the two highest-probability (label, prob) pairs.
    """
    # Confidently positive messages short-circuit without running the classifier.
    if is_healthy_message(text):
        return {
            "abuse_score": 0,
            "labels": [],
            "sentiment": "supportive",
            "stage": 4,
            "top_patterns": [],
        }

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits.squeeze(0)
    # Independent per-label sigmoid probabilities (multi-label). .cpu() is a
    # no-op on CPU tensors and keeps .numpy() working if the model is on a GPU.
    probs = torch.sigmoid(logits).cpu().numpy()

    scored = list(zip(LABELS, probs))
    detected = [(label, prob) for label, prob in scored if prob > THRESHOLDS.get(label, 0.3)]
    detected_labels = [label for label, _ in detected]
    # Normalized by the total label count, not by the number detected.
    abuse_score = sum(prob for _, prob in detected) / len(LABELS) * 100

    # Truncate like the classifier above so long messages do not raise here.
    sentiment_result = sst_pipeline(text, truncation=True)[0]
    sentiment = "supportive" if sentiment_result["label"] == "POSITIVE" else "undermining"

    top_patterns = sorted(scored, key=lambda pair: pair[1], reverse=True)[:2]

    return {
        "abuse_score": int(abuse_score),
        "labels": detected_labels,
        "sentiment": sentiment,
        "stage": _cycle_stage(detected_labels),
        "top_patterns": top_patterns,
    }
def analyze_composite(msg1, date1, msg2, date2, msg3, date3, *answers_and_none): | |
none_selected_checked = answers_and_none[-1] | |
responses_checked = any(answers_and_none[:-1]) | |
none_selected = not responses_checked and none_selected_checked | |
if none_selected: | |
escalation_score = None | |
risk_level = "unknown" | |
else: | |
escalation_score = sum(w for (_, w), a in zip(ESCALATION_QUESTIONS, answers_and_none[:-1]) if a) | |
risk_level = ( | |
"High" if escalation_score >= 16 else | |
"Moderate" if escalation_score >= 8 else | |
"Low" | |
) | |
messages = [msg1, msg2, msg3] | |
dates = [date1, date2, date3] | |
active = [(m, d) for m, d in zip(messages, dates) if m.strip()] | |
if not active: | |
return "Please enter at least one message." | |
results = [(analyze_single_message(m), d) for m, d in active] | |
abuse_scores = [r[0]["abuse_score"] for r in results] | |
top_labels = [r[0]["top_patterns"][0][0] if r[0]["top_patterns"] else "None" for r in results] | |
dates_used = [r[1] or "Undated" for r in results] | |
composite_abuse = int(round(sum(abuse_scores) / len(abuse_scores))) | |
- most_common_stage = max( | |
- set(r[0]["stage"] for r in results), | |
- key==lambda x: [r[0]["stage"] for r in results].count | |
- ) | |
+ most_common_stage = max( | |
+ set(r[0]["stage"] for r in results), | |
+ key=lambda x: [r[0]["stage"] for r in results].count(x) | |
+ ) | |
out = f"Abuse Intensity: {composite_abuse}%\n" | |
if escalation_score is None: | |
out += "Escalation Potential: Unknown (Checklist not completed)\n" | |
else: | |
out += f"Escalation Potential: {risk_level} ({escalation_score}/{sum(w for _, w in ESCALATION_QUESTIONS)})\n" | |
timeline_image = generate_abuse_score_chart(dates_used, abuse_scores, top_labels) | |
return out, timeline_image | |
# --- Gradio Interface ---
# One (message, date) textbox pair per slot. Flattened, they arrive at
# analyze_composite as msg1, date1, msg2, date2, msg3, date3.
message_date_pairs = [
    (
        gr.Textbox(label=f"Message {slot}"),
        gr.Textbox(label=f"Date {slot} (optional)", placeholder="YYYY-MM-DD"),
    )
    for slot in range(1, 4)
]
textbox_inputs = [component for pair in message_date_pairs for component in pair]

# One checkbox per escalation question, then "None of the above" last, matching
# the *answers_and_none layout that analyze_composite unpacks.
quiz_boxes = [gr.Checkbox(label=question) for question, _weight in ESCALATION_QUESTIONS]
none_box = gr.Checkbox(label="None of the above")

iface = gr.Interface(
    fn=analyze_composite,
    inputs=[*textbox_inputs, *quiz_boxes, none_box],
    outputs=[
        gr.Textbox(label="Results"),
        gr.Image(label="Risk Stage Timeline", type="pil"),
    ],
    title="Tether Abuse Pattern Detector v2",
    allow_flagging="manual",
)

if __name__ == "__main__":
    iface.launch()