Tether / app.py
SamanthaStorm's picture
Update app.py
cc9d8f5 verified
raw
history blame
7.27 kB
import gradio as gr
import torch
import numpy as np
from transformers import pipeline, RobertaForSequenceClassification, RobertaTokenizer
from motif_tagging import detect_motifs
import re
import matplotlib.pyplot as plt
import io
from PIL import Image
from datetime import datetime
from transformers import pipeline as hf_pipeline # prevent name collision with gradio pipeline
def get_emotion_profile(text):
    """Return the emotion scores reported for *text*, keyed by label.

    Runs the module-level ``emotion_pipeline`` on ``text``. When the pipeline
    answers with a list whose first element is itself a list, that first
    element is taken as the list of predictions. Each prediction is read as
    a mapping with ``'label'`` and ``'score'`` entries.

    Returns:
        dict mapping the lower-cased label to its score rounded to 3 decimals.
    """
    predictions = emotion_pipeline(text)
    is_nested = isinstance(predictions, list) and isinstance(predictions[0], list)
    if is_nested:
        predictions = predictions[0]

    profile = {}
    for prediction in predictions:
        profile[prediction['label'].lower()] = round(prediction['score'], 3)
    return profile
# Emotion classifier used by get_emotion_profile(); constructed once when the
# module is imported.
emotion_pipeline = hf_pipeline(
    task="text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    # Per the Transformers pipeline docs, top_k=None asks for a score for
    # every label instead of only the best one.
    top_k=None,
    truncation=True,
)
def _resolve_chart_x_axis(dates):
    """Return ``(positions, tick_labels)`` for the chart's x axis.

    When every entry of ``dates`` is a real ``YYYY-MM-DD`` calendar date
    (surrounding whitespace ignored), positions are ``datetime`` objects and
    the tick labels are the normalised date strings. Otherwise positions are
    1-based message indices labelled ``"Message N"``.
    """
    parsed = []
    for raw in dates:
        candidate = raw.strip() if isinstance(raw, str) else ""
        # fullmatch: a date-like prefix followed by junk must not count.
        if not re.fullmatch(r"\d{4}-\d{2}-\d{2}", candidate):
            parsed = None
            break
        try:
            parsed.append(datetime.strptime(candidate, "%Y-%m-%d"))
        except ValueError:
            # Right shape but not a real date (e.g. "2024-13-45").
            parsed = None
            break

    if parsed is not None:
        return parsed, [moment.strftime("%Y-%m-%d") for moment in parsed]

    positions = list(range(1, len(dates) + 1))
    return positions, [f"Message {position}" for position in positions]


def generate_abuse_score_chart(dates, scores, labels):
    """Render the per-message abuse scores as a line chart.

    Args:
        dates: one entry per message; either all ``YYYY-MM-DD`` strings or
            anything else (in which case messages are numbered instead).
        scores: numeric score per message, drawn on a 0-105 y axis and
            annotated as an integer percentage above each point.
        labels: accepted for interface compatibility; not used by the chart.

    Returns:
        A PIL image of the rendered PNG.
    """
    x_positions, tick_labels = _resolve_chart_x_axis(dates)

    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.plot(x_positions, scores, marker='o', linestyle='-', color='darkred', linewidth=2)
        for x, y in zip(x_positions, scores):
            ax.text(x, y + 2, f"{int(y)}%", ha='center', fontsize=8, color='black')
        ax.set_xticks(x_positions)
        ax.set_xticklabels(tick_labels)
        ax.set_xlabel("")
        ax.set_ylabel("Abuse Score (%)")
        ax.set_ylim(0, 105)
        ax.grid(True)
        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # request leaks a figure.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# NOTE(review): mid-file import; kept in place because the statements below
# depend on it.
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Checkpoint identifier handed to both from_pretrained() calls below.
model_name = "SamanthaStorm/tether-multilabel-v3"
# Both objects are created at import time and used by analyze_single_message().
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
# NOTE(review): every value below is a placeholder containing a literal
# Ellipsis; the real contents are not visible here. Note that `{...}` is a
# *set* literal, while the visible code needs mappings in at least two places.
# LABELS is zipped with the model's scores and searched with .index().
LABELS = [...]
# Passed (as a copy) to analyze_single_message(), which calls .items() on it
# and indexes it by label, so it must be a mapping.
THRESHOLDS = {...}
# Not referenced by the code visible in this view.
PATTERN_WEIGHTS = {...}
# Indexed by the most common stage in analyze_composite().
RISK_STAGE_LABELS = {...}
# Iterated as (question, weight) pairs by analyze_composite() and the UI setup.
ESCALATION_QUESTIONS = [...]
# Not referenced by the code visible in this view.
DARVO_PATTERNS = {...}
# Not referenced by the code visible in this view.
DARVO_MOTIFS = [...]
# (Leave the rest of your helper functions unchanged)
def analyze_single_message(text, thresholds):
    """Score a single message against the abuse-pattern labels.

    Only the first part of this function is visible here; the remainder is
    elided behind the trailing "(Continue unchanged)" marker. All local names
    are therefore kept as-is so the elided tail can still rely on them.

    Args:
        text: the message to analyse.
        thresholds: mapping of label -> cutoff; read, never mutated, in the
            visible portion.

    Visible steps: motif detection, emotion profile, model scores (sigmoid of
    the logits), a coarse sentiment ("undermining" / "supportive"), cutoffs
    raised by 0.05 for supportive messages, and the labels whose score
    exceeds their adjusted cutoff.
    """
    motif_hits, matched_phrases = detect_motifs(text)
    emotion_profile = get_emotion_profile(text)
    sentiment_score = emotion_profile.get("anger", 0) + emotion_profile.get("disgust", 0)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
        scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()

    # Default: anger + disgust decide the sentiment ...
    sentiment = "undermining" if sentiment_score > 0.25 else "supportive"
    # ... but a strongly neutral tone that still trips one of the key patterns
    # is treated as undermining regardless.
    if emotion_profile.get("neutral", 0) > 0.85:
        key_pattern_indices = [
            LABELS.index(name) for name in ["control", "threat", "blame shifting"]
        ]
        if any(
            scores[position] > thresholds[LABELS[position]]
            for position in key_pattern_indices
        ):
            sentiment = "undermining"

    if sentiment == "supportive":
        adjusted_thresholds = {name: cutoff + 0.05 for name, cutoff in thresholds.items()}
    else:
        adjusted_thresholds = {name: cutoff for name, cutoff in thresholds.items()}

    passed = {
        name: value
        for name, value in zip(LABELS, scores)
        if value > adjusted_thresholds[name]
    }
    # (Continue unchanged)
def analyze_composite(msg1, date1, msg2, date2, msg3, date3, *answers_and_none):
    """Analyse up to three messages plus the escalation checklist.

    Args:
        msg1, msg2, msg3: message texts; blank or ``None`` entries are skipped.
        date1, date2, date3: optional date strings paired with the messages.
        *answers_and_none: one truthy/falsy value per entry of
            ``ESCALATION_QUESTIONS`` followed by the "None of the above"
            checkbox value as the last element.

    Returns:
        ``(report_text, timeline_image)``. When no message was entered the
        report is a prompt to enter one and the image is ``None``, so the
        result always matches the two outputs of the interface.

    Raises:
        ValueError: if ``analyze_single_message`` returns something that is
            not a 7-item result.
    """
    # Validate the messages first so an empty submission returns immediately.
    messages = [msg1, msg2, msg3]
    dates = [date1, date2, date3]
    active = [(m, d) for m, d in zip(messages, dates) if m and m.strip()]
    if not active:
        return "Please enter at least one message.", None

    # The last positional value is the "None of the above" box; everything
    # before it lines up with ESCALATION_QUESTIONS.
    checklist_answers = answers_and_none[:-1]
    none_selected_checked = bool(answers_and_none) and answers_and_none[-1]
    none_selected = not any(checklist_answers) and none_selected_checked

    if none_selected:
        escalation_score = None
        risk_level = "unknown"
    else:
        escalation_score = sum(
            weight
            for (_, weight), ticked in zip(ESCALATION_QUESTIONS, checklist_answers)
            if ticked
        )
        if escalation_score >= 16:
            risk_level = "High"
        elif escalation_score >= 8:
            risk_level = "Moderate"
        else:
            risk_level = "Low"

    results = [(analyze_single_message(m, THRESHOLDS.copy()), d) for m, d in active]
    for analysis, _ in results:
        # Explicit check instead of `assert`, which disappears under `python -O`.
        if len(analysis) != 7:
            raise ValueError("Unexpected output from analyze_single_message")

    # Positions read from each 7-item result:
    #   [0] abuse score, [2][0] (pattern label, pattern score) of the first
    #   pattern, [4] risk stage, [5] DARVO score, [6] top label.
    top_labels = [r[0][6] for r in results]
    top_scores = [r[0][2][0][1] for r in results]
    stages = [r[0][4] for r in results]
    darvo_scores = [r[0][5] for r in results]
    dates_used = [r[1] or "Undated" for r in results]
    abuse_scores = [r[0][0] for r in results]

    composite_abuse = int(round(sum(abuse_scores) / len(abuse_scores)))
    # Only the first message's top label is shown in the summary.
    top_label = f"{top_labels[0]} – {int(round(top_scores[0] * 100))}%"

    most_common_stage = max(set(stages), key=stages.count)
    stage_text = RISK_STAGE_LABELS[most_common_stage]

    avg_darvo = round(sum(darvo_scores) / len(darvo_scores), 3)
    darvo_blurb = ""
    if avg_darvo > 0.25:
        level = "moderate" if avg_darvo < 0.65 else "high"
        darvo_blurb = f"\n\n🎭 **DARVO Score: {avg_darvo}** → This indicates a **{level} likelihood** of narrative reversal (DARVO), where the speaker may be denying, attacking, or reversing blame."

    out = f"Abuse Intensity: {composite_abuse}%\n"
    out += "📊 This reflects the strength and severity of detected abuse patterns in the message(s).\n\n"

    if escalation_score is None:
        escalation_text = "📉 Escalation Potential: Unknown (Checklist not completed)\n"
        escalation_text += "⚠️ *This section was not completed. Escalation potential is unknown.*\n"
    else:
        max_escalation = sum(weight for _, weight in ESCALATION_QUESTIONS)
        escalation_text = f"🧨 **Escalation Potential: {risk_level} ({escalation_score}/{max_escalation})**\n"
        escalation_text += "This score comes directly from the safety checklist and functions as a standalone escalation risk score.\n"
        escalation_text += "It indicates how many serious risk factors are present based on your answers to the safety checklist.\n"

    # generate_risk_snippet expects a number, so an unknown score is passed as 0.
    out += generate_risk_snippet(
        composite_abuse,
        top_label,
        escalation_score if escalation_score is not None else 0,
        most_common_stage,
    )
    out += f"\n\n{stage_text}"
    out += darvo_blurb
    out += "\n\n" + escalation_text

    print(f"DEBUG: avg_darvo = {avg_darvo}")

    pattern_labels = [r[0][2][0][0] for r in results]
    timeline_image = generate_abuse_score_chart(dates_used, abuse_scores, pattern_labels)
    return out, timeline_image
# Three (message, optional date) textbox pairs, numbered from 1.
message_date_pairs = [
    (
        gr.Textbox(label=f"Message {slot}"),
        gr.Textbox(label=f"Date {slot} (optional)", placeholder="YYYY-MM-DD"),
    )
    for slot in range(1, 4)
]
# Flattened to msg1, date1, msg2, date2, ... so the inputs line up with the
# leading positional parameters of analyze_composite.
textbox_inputs = [box for box_pair in message_date_pairs for box in box_pair]

# One checkbox per escalation question, followed by the "None of the above"
# box that analyze_composite reads as its last positional argument.
quiz_boxes = [gr.Checkbox(label=question) for question, _weight in ESCALATION_QUESTIONS]
none_box = gr.Checkbox(label="None of the above")

iface = gr.Interface(
    fn=analyze_composite,
    inputs=[*textbox_inputs, *quiz_boxes, none_box],
    outputs=[
        gr.Textbox(label="Results"),
        gr.Image(label="Risk Stage Timeline", type="pil"),
    ],
    title="Abuse Pattern Detector + Escalation Quiz",
    allow_flagging="manual",
)

if __name__ == "__main__":
    iface.launch()