Tether / app.py
SamanthaStorm's picture
Update app.py
f550b78 verified
raw
history blame
6.58 kB
import gradio as gr
import torch
import numpy as np
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import matplotlib.pyplot as plt
import io
from PIL import Image
from datetime import datetime
# --- Load models ---
# Fine-tuned sequence-classification checkpoint; its logits are read
# position-by-position against LABELS in analyze_single_message().
model_name = "SamanthaStorm/tether-multilabel-v2" # UPDATE if needed
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# "sentiment-analysis" is an alias of the "text-classification" pipeline task,
# so the healthy-message check and the sentiment check use exactly the same
# SST-2 model. It is loaded once and bound to both names so the checkpoint is
# not held in memory twice; both names stay available to the rest of the module.
healthy_detector = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
sst_pipeline = healthy_detector
# Pattern names, in the order analyze_single_message() pairs them with the
# classifier's output probabilities.
LABELS = [
    "blame shifting",
    "contradictory statements",
    "control",
    "dismissiveness",
    "gaslighting",
    "guilt tripping",
    "insults",
    "obscure language",
    "projection",
    "recovery phase",
    "threat",
]

# Per-label probability cutoff: a label counts as detected when its
# probability is strictly greater than the value listed here.
THRESHOLDS = {
    "blame shifting": 0.3,
    "contradictory statements": 0.3,
    "control": 0.35,
    "dismissiveness": 0.4,
    "gaslighting": 0.3,
    "guilt tripping": 0.3,
    "insults": 0.3,
    "obscure language": 0.4,
    "projection": 0.4,
    "recovery phase": 0.35,
    "threat": 0.3,
}

# (checkbox label, weight) pairs. The weights of the ticked boxes are summed
# into the escalation score by analyze_composite().
ESCALATION_QUESTIONS = [
    ("Partner has access to firearms or weapons", 4),
    ("Partner threatened to kill you", 3),
    ("Partner threatened you with a weapon", 3),
    ("Partner has ever choked you", 4),
    ("Partner injured or threatened your pet(s)", 3),
    ("Partner has broken your things, punched walls, or thrown objects", 2),
    ("Partner forced you into unwanted sexual acts", 3),
    ("Partner threatened to take away your children", 2),
    ("Violence has increased in frequency or severity", 3),
    ("Partner monitors your calls, GPS, or social media", 2),
]
# --- Functions ---
def is_healthy_message(text, threshold=0.9):
    """Tell whether the sentiment model confidently rates ``text`` as positive.

    Runs ``healthy_detector`` on the text and looks only at its first result.

    Args:
        text: Message to classify.
        threshold: Score the "POSITIVE" prediction must exceed (strictly).

    Returns:
        True when the first prediction is labelled "POSITIVE" and its score
        is above ``threshold``; False otherwise.
    """
    prediction = healthy_detector(text)[0]
    if prediction['label'] != "POSITIVE":
        return False
    return prediction['score'] > threshold
def generate_abuse_score_chart(dates, scores, labels):
    """Render per-message abuse scores as an annotated line chart.

    Args:
        dates: One entry per message. If every entry parses as ``YYYY-MM-DD``
            the x-axis uses those dates; if any entry fails to parse, all
            points are plotted at positions 0, 1, 2, ... instead.
        scores: Abuse score for each message; the y-axis is fixed to 0-105.
        labels: Text drawn above each point together with its integer score.

    Returns:
        A PIL image opened from the chart's PNG bytes.
    """
    try:
        x_values = [datetime.strptime(d, "%Y-%m-%d") for d in dates]
    except (ValueError, TypeError):
        # ValueError: an entry is not in YYYY-MM-DD form (e.g. "Undated").
        # TypeError: an entry is not a string at all.
        x_values = list(range(len(dates)))

    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.plot(x_values, scores, marker='o', linestyle='-', color='darkred', linewidth=2)
        for x, y, label in zip(x_values, scores, labels):
            # Offset the annotation slightly above the marker.
            ax.text(x, y + 2, f"{label}\n{int(y)}%", ha='center', fontsize=8, color='black')
        ax.set_title("Abuse Intensity Over Time")
        ax.set_xlabel("Date")
        ax.set_ylabel("Abuse Score (%)")
        ax.set_ylim(0, 105)
        ax.grid(True)
        fig.tight_layout()

        buf = io.BytesIO()
        # Save through the figure object rather than pyplot's "current figure"
        # so the right figure is written regardless of global pyplot state.
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running app accumulates one figure per request.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def analyze_single_message(text):
    """Score a single message against the abuse-pattern labels.

    A message accepted by ``is_healthy_message`` short-circuits to a zero
    score with stage 4. Otherwise the multi-label model is run, each label's
    sigmoid probability is compared with its entry in ``THRESHOLDS``, and the
    message is tagged with a sentiment and a stage.

    Args:
        text: The message to analyze.

    Returns:
        A dict with keys ``abuse_score`` (int), ``labels`` (detected label
        names), ``sentiment`` ("supportive" or "undermining"), ``stage``
        (int) and ``top_patterns`` (the two highest ``(label, probability)``
        pairs, or an empty list for a healthy message).
    """
    if is_healthy_message(text):
        return {
            "abuse_score": 0,
            "labels": [],
            "sentiment": "supportive",
            "stage": 4,
            "top_patterns": [],
        }

    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**encoded).logits.squeeze(0)
    probs = torch.sigmoid(logits).numpy()

    # Pair every label with its probability once, then keep the ones that
    # clear their own threshold (0.3 when a label has no entry).
    scored = list(zip(LABELS, probs))
    flagged = [
        (label, prob) for label, prob in scored
        if prob > THRESHOLDS.get(label, 0.3)
    ]
    detected_labels = [label for label, _ in flagged]

    # Sum of the detected probabilities, divided by the total number of
    # labels (not just the detected ones), expressed as a percentage.
    abuse_score = (sum(prob for _, prob in flagged) / len(LABELS)) * 100

    if sst_pipeline(text)[0]['label'] == "POSITIVE":
        sentiment = "supportive"
    else:
        sentiment = "undermining"

    detected = set(detected_labels)
    if detected & {"threat", "insults"}:
        stage = 2  # Escalation
    elif detected & {"control", "guilt tripping"}:
        stage = 1  # Tension building
    elif "recovery phase" in detected:
        stage = 3  # Reconciliation
    else:
        stage = 1

    # Two most probable labels overall, whether or not they were detected.
    top_patterns = sorted(scored, key=lambda pair: pair[1], reverse=True)[:2]

    return {
        "abuse_score": int(abuse_score),
        "labels": detected_labels,
        "sentiment": sentiment,
        "stage": stage,
        "top_patterns": top_patterns,
    }
def analyze_composite(msg1, date1, msg2, date2, msg3, date3, *answers_and_none):
    """Analyze up to three messages together with the escalation checklist.

    Args:
        msg1, msg2, msg3: Message texts. Missing or whitespace-only messages
            are skipped.
        date1, date2, date3: Optional date for each message; an empty date is
            shown as "Undated".
        *answers_and_none: One truthy/falsy value per entry of
            ``ESCALATION_QUESTIONS`` (same order), followed by the state of
            the "None of the above" checkbox as the last value.

    Returns:
        A ``(summary_text, timeline_image)`` tuple. When no message was
        entered, ``summary_text`` is a prompt and ``timeline_image`` is None.
    """
    active = [
        (m, d)
        for m, d in zip((msg1, msg2, msg3), (date1, date2, date3))
        if m and m.strip()
    ]
    if not active:
        # The interface declares two outputs, so always return two values.
        return "Please enter at least one message.", None

    answers = answers_and_none[:-1]
    none_checked = answers_and_none[-1]

    if none_checked and not any(answers):
        # NOTE(review): ticking only "None of the above" is reported as
        # "Checklist not completed", while leaving every box empty scores as
        # Low (0). That mapping looks inverted; confirm the intended behavior
        # before changing it.
        escalation_line = "Escalation Potential: Unknown (Checklist not completed)\n"
    else:
        escalation_score = sum(
            weight for (_, weight), answer in zip(ESCALATION_QUESTIONS, answers) if answer
        )
        if escalation_score >= 16:
            risk_level = "High"
        elif escalation_score >= 8:
            risk_level = "Moderate"
        else:
            risk_level = "Low"
        max_score = sum(weight for _, weight in ESCALATION_QUESTIONS)
        escalation_line = f"Escalation Potential: {risk_level} ({escalation_score}/{max_score})\n"

    results = [(analyze_single_message(m), d) for m, d in active]
    abuse_scores = [analysis["abuse_score"] for analysis, _ in results]
    top_labels = [
        analysis["top_patterns"][0][0] if analysis["top_patterns"] else "None"
        for analysis, _ in results
    ]
    dates_used = [d or "Undated" for _, d in results]

    # Overall intensity is the mean of the per-message scores, rounded to a
    # whole percent. `active` is non-empty here, so the division is safe.
    composite_abuse = int(round(sum(abuse_scores) / len(abuse_scores)))

    out = f"Abuse Intensity: {composite_abuse}%\n" + escalation_line
    timeline_image = generate_abuse_score_chart(dates_used, abuse_scores, top_labels)
    return out, timeline_image
# --- Gradio Interface ---
def _message_date_pair(i):
    """Build the message textbox and its optional-date textbox for slot ``i`` (0-based)."""
    message_box = gr.Textbox(label=f"Message {i+1}")
    date_box = gr.Textbox(label=f"Date {i+1} (optional)", placeholder="YYYY-MM-DD")
    return (message_box, date_box)


message_date_pairs = [_message_date_pair(i) for i in range(3)]

# Flattened to message, date, message, date, ... which is the positional
# order of analyze_composite's first six parameters.
textbox_inputs = [box for boxes in message_date_pairs for box in boxes]

# One checkbox per checklist question, followed by the "None of the above"
# box, which analyze_composite reads as its last positional argument.
quiz_boxes = [gr.Checkbox(label=question) for question, _weight in ESCALATION_QUESTIONS]
none_box = gr.Checkbox(label="None of the above")

_interface_inputs = textbox_inputs + quiz_boxes + [none_box]
_interface_outputs = [
    gr.Textbox(label="Results"),
    gr.Image(label="Risk Stage Timeline", type="pil"),
]

iface = gr.Interface(
    fn=analyze_composite,
    inputs=_interface_inputs,
    outputs=_interface_outputs,
    title="Tether Abuse Pattern Detector v2",
    allow_flagging="manual",
)

if __name__ == "__main__":
    iface.launch()