import gradio as gr
import torch
import numpy as np
from transformers import RobertaForSequenceClassification, RobertaTokenizer
from motif_tagging import detect_motifs
import re
import matplotlib.pyplot as plt
import io
from PIL import Image
from datetime import datetime
from transformers import pipeline as hf_pipeline  # prevent name collision with gradio pipeline
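# --- Emotion profiling ---------------------------------------------------
# The j-hartmann DistilRoBERTa emotion model scores seven emotions (anger,
# disgust, fear, joy, neutral, sadness, surprise); get_emotion_profile()
# flattens its output into a {label: score} dict used for sentiment checks.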
# Load the emotion classifier once at import time so every call to
# get_emotion_profile() reuses the same pipeline.
emotion_pipeline = hf_pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    top_k=None,
    truncation=True
)

def get_emotion_profile(text):
    """Return a dict mapping lowercase emotion labels to rounded scores."""
    emotions = emotion_pipeline(text)
    # With top_k=None the pipeline returns a list per input; unwrap the batch dimension.
    if isinstance(emotions, list) and isinstance(emotions[0], list):
        emotions = emotions[0]
    return {e['label'].lower(): round(e['score'], 3) for e in emotions}
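# --- Abuse-score timeline chart -------------------------------------------
# Plots one abuse score per message. Dates are used on the x-axis only when
# every entry looks like YYYY-MM-DD; otherwise messages are numbered in order.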
def generate_abuse_score_chart(dates, scores, labels):
    """Render the abuse-score timeline as a PIL image (`labels` is accepted but not drawn)."""
    if all(re.match(r"\d{4}-\d{2}-\d{2}", d) for d in dates):
        parsed_x = [datetime.strptime(d, "%Y-%m-%d") for d in dates]
        x_labels = [d.strftime("%Y-%m-%d") for d in parsed_x]
    else:
        parsed_x = list(range(1, len(dates) + 1))
        x_labels = [f"Message {i+1}" for i in range(len(dates))]

    fig, ax = plt.subplots(figsize=(8, 3))
    ax.plot(parsed_x, scores, marker='o', linestyle='-', color='darkred', linewidth=2)
    for x, y in zip(parsed_x, scores):
        ax.text(x, y + 2, f"{int(y)}%", ha='center', fontsize=8, color='black')
    ax.set_xticks(parsed_x)
    ax.set_xticklabels(x_labels)
    ax.set_xlabel("")
    ax.set_ylabel("Abuse Score (%)")
    ax.set_ylim(0, 105)
    ax.grid(True)
    plt.tight_layout()

    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(fig)  # free the figure so repeated calls don't leak memory
    buf.seek(0)
    return Image.open(buf)
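# --- Multi-label abuse classifier and configuration ------------------------
# tether-multilabel-v3 is a sequence classifier whose sigmoid outputs are
# compared against per-label thresholds. The constants below hold the label
# set, thresholds, pattern weights, risk-stage labels, escalation checklist
# questions, and DARVO patterns.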
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "SamanthaStorm/tether-multilabel-v3"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

LABELS = [...]
THRESHOLDS = {...}
PATTERN_WEIGHTS = {...}
RISK_STAGE_LABELS = {...}
ESCALATION_QUESTIONS = [...]
DARVO_PATTERNS = {...}
DARVO_MOTIFS = [...]

# (Leave the rest of your helper functions unchanged)
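# --- Single-message analysis -----------------------------------------------
# Tags motifs, builds an emotion profile, classifies the message with the
# multi-label model, and keeps only the labels whose sigmoid scores clear
# their (sentiment-adjusted) thresholds.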
def analyze_single_message(text, thresholds):
    motif_hits, matched_phrases = detect_motifs(text)
    emotion_profile = get_emotion_profile(text)
    # Anger + disgust serves as a rough proxy for an undermining tone.
    sentiment_score = emotion_profile.get("anger", 0) + emotion_profile.get("disgust", 0)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()

    # A flat "neutral" reading can still be undermining if control, threat,
    # or blame-shifting labels clear their thresholds.
    if emotion_profile.get("neutral", 0) > 0.85 and any(
        scores[label_idx] > thresholds[LABELS[label_idx]]
        for label_idx in [LABELS.index(l) for l in ["control", "threat", "blame shifting"]]
    ):
        sentiment = "undermining"
    else:
        sentiment = "undermining" if sentiment_score > 0.25 else "supportive"

    # Raise the bar slightly for supportive-sounding messages to reduce false positives.
    adjusted_thresholds = {
        k: v + 0.05 if sentiment == "supportive" else v
        for k, v in thresholds.items()
    }
    passed = {
        label: score for label, score in zip(LABELS, scores)
        if score > adjusted_thresholds[label]
    }
    # (Continue unchanged)
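# --- Composite analysis (up to three messages + safety checklist) ----------
# Escalation risk comes straight from the checklist: the weights of every
# checked question are summed, then bucketed into Low / Moderate / High.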
def analyze_composite(msg1, date1, msg2, date2, msg3, date3, *answers_and_none):
    # The last checkbox is "None of the above"; everything before it is the checklist.
    none_selected_checked = answers_and_none[-1]
    responses_checked = any(answers_and_none[:-1])
    none_selected = not responses_checked and none_selected_checked

    if none_selected:
        escalation_score = None
        risk_level = "unknown"
    else:
        escalation_score = sum(w for (_, w), a in zip(ESCALATION_QUESTIONS, answers_and_none[:-1]) if a)
        risk_level = (
            "High" if escalation_score >= 16 else
            "Moderate" if escalation_score >= 8 else
            "Low"
        )

    messages = [msg1, msg2, msg3]
    dates = [date1, date2, date3]
    active = [(m, d) for m, d in zip(messages, dates) if m.strip()]
    if not active:
        # The interface declares two outputs (text + image), so return a pair.
        return "Please enter at least one message.", None
    results = [(analyze_single_message(m, THRESHOLDS.copy()), d) for m, d in active]
    for result, date in results:
        assert len(result) == 7, "Unexpected output from analyze_single_message"

    top_labels = [r[0][6] for r in results]
    top_scores = [r[0][2][0][1] for r in results]
    sentiments = [r[0][3]['label'] for r in results]
    stages = [r[0][4] for r in results]
    darvo_scores = [r[0][5] for r in results]
    dates_used = [r[1] or "Undated" for r in results]
    abuse_scores = [r[0][0] for r in results]

    composite_abuse = int(round(sum(abuse_scores) / len(abuse_scores)))
    top_label = f"{top_labels[0]} – {int(round(top_scores[0] * 100))}%"
    most_common_stage = max(set(stages), key=stages.count)
    stage_text = RISK_STAGE_LABELS[most_common_stage]

    avg_darvo = round(sum(darvo_scores) / len(darvo_scores), 3)
    darvo_blurb = ""
    if avg_darvo > 0.25:
        level = "moderate" if avg_darvo < 0.65 else "high"
        darvo_blurb = f"\n\n**DARVO Score: {avg_darvo}** – This indicates a **{level} likelihood** of narrative reversal (DARVO), where the speaker may be denying, attacking, or reversing blame."
out = f"Abuse Intensity: {composite_abuse}%\n" | |
out += "π This reflects the strength and severity of detected abuse patterns in the message(s).\n\n" | |
if escalation_score is None: | |
escalation_text = "π Escalation Potential: Unknown (Checklist not completed)\n" | |
escalation_text += "β οΈ *This section was not completed. Escalation potential is unknown.*\n" | |
else: | |
escalation_text = f"𧨠**Escalation Potential: {risk_level} ({escalation_score}/{sum(w for _, w in ESCALATION_QUESTIONS)})**\n" | |
escalation_text += "This score comes directly from the safety checklist and functions as a standalone escalation risk score.\n" | |
escalation_text += "It indicates how many serious risk factors are present based on your answers to the safety checklist.\n" | |
if top_label is None: | |
top_label = "Unknown β 0%" | |
out += generate_risk_snippet(composite_abuse, top_label, escalation_score if escalation_score is not None else 0, most_common_stage) | |
out += f"\n\n{stage_text}" | |
out += darvo_blurb | |
out += "\n\n" + escalation_text | |
print(f"DEBUG: avg_darvo = {avg_darvo}") | |
pattern_labels = [r[0][2][0][0] for r in results] | |
timeline_image = generate_abuse_score_chart(dates_used, abuse_scores, pattern_labels) | |
return out, timeline_image | |
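# --- Gradio UI --------------------------------------------------------------
# Three message/date text boxes, the escalation checklist, and a
# "None of the above" checkbox feed analyze_composite; results come back as
# a text report plus the timeline chart.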
message_date_pairs = [
    (
        gr.Textbox(label=f"Message {i+1}"),
        gr.Textbox(label=f"Date {i+1} (optional)", placeholder="YYYY-MM-DD")
    )
    for i in range(3)
]
textbox_inputs = [item for pair in message_date_pairs for item in pair]
quiz_boxes = [gr.Checkbox(label=q) for q, _ in ESCALATION_QUESTIONS]
none_box = gr.Checkbox(label="None of the above")

iface = gr.Interface(
    fn=analyze_composite,
    inputs=textbox_inputs + quiz_boxes + [none_box],
    outputs=[
        gr.Textbox(label="Results"),
        gr.Image(label="Risk Stage Timeline", type="pil")
    ],
    title="Abuse Pattern Detector + Escalation Quiz",
    allow_flagging="manual"
)
if __name__ == "__main__":
    iface.launch()