import gradio as gr
import torch
import numpy as np
import re
import io
from datetime import datetime

import matplotlib.pyplot as plt
from PIL import Image
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline as hf_pipeline,  # aliased to keep the name distinct from other pipeline objects
)

from motif_tagging import detect_motifs

def get_emotion_profile(text):
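    """Return a dict mapping lowercased emotion labels to scores (rounded to 3 decimals) for `text`."""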
    emotions = emotion_pipeline(text)
    if isinstance(emotions, list) and isinstance(emotions[0], list):
        emotions = emotions[0]
    return {e['label'].lower(): round(e['score'], 3) for e in emotions}

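# Emotion classifier used by get_emotion_profile; top_k=None returns a score for every label.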
emotion_pipeline = hf_pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    top_k=None,
    truncation=True
)

def generate_abuse_score_chart(dates, scores, labels):
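    """Plot abuse scores over time (or by message index when dates are not YYYY-MM-DD) and return the chart as a PIL Image."""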
    if all(re.match(r"\d{4}-\d{2}-\d{2}", d) for d in dates):
        parsed_x = [datetime.strptime(d, "%Y-%m-%d") for d in dates]
        x_labels = [d.strftime("%Y-%m-%d") for d in parsed_x]
    else:
        parsed_x = list(range(1, len(dates) + 1))
        x_labels = [f"Message {i+1}" for i in range(len(dates))]

    fig, ax = plt.subplots(figsize=(8, 3))
    ax.plot(parsed_x, scores, marker='o', linestyle='-', color='darkred', linewidth=2)
    for x, y in zip(parsed_x, scores):
        ax.text(x, y + 2, f"{int(y)}%", ha='center', fontsize=8, color='black')

    ax.set_xticks(parsed_x)
    ax.set_xticklabels(x_labels)
    ax.set_xlabel("")
    ax.set_ylabel("Abuse Score (%)")
    ax.set_ylim(0, 105)
    ax.grid(True)
    plt.tight_layout()

    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(fig)  # release the figure so repeated calls don't accumulate open figures
    buf.seek(0)
    return Image.open(buf)


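# Multi-label abuse-pattern classifier; sigmoid over the logits gives independent per-label probabilities.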
model_name = "SamanthaStorm/tether-multilabel-v3"
model      = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer  = AutoTokenizer.from_pretrained(model_name, use_fast=False)

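# Configuration elided from this snippet: label names, per-label thresholds, pattern weights,
# risk-stage descriptions, escalation checklist questions, and DARVO patterns/motifs.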
LABELS = [...]
THRESHOLDS = {...}
PATTERN_WEIGHTS = {...}
RISK_STAGE_LABELS = {...}
ESCALATION_QUESTIONS = [...]
DARVO_PATTERNS = {...}
DARVO_MOTIFS = [...]

# (Leave the rest of your helper functions unchanged)

def analyze_single_message(text, thresholds):
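    """Score a single message: detect motifs, build an emotion profile, run the multi-label
    abuse classifier, and apply sentiment-adjusted per-label thresholds. The elided remainder
    is expected to return a 7-item tuple consumed by analyze_composite (see the assert there).
    """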
    motif_hits, matched_phrases = detect_motifs(text)
    emotion_profile = get_emotion_profile(text)
    sentiment_score = emotion_profile.get("anger", 0) + emotion_profile.get("disgust", 0)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()

    if emotion_profile.get("neutral", 0) > 0.85 and any(
        scores[LABELS.index(label)] > thresholds[label]
        for label in ["control", "threat", "blame shifting"]
    ):
        sentiment = "undermining"
    else:
        sentiment = "undermining" if sentiment_score > 0.25 else "supportive"

    adjusted_thresholds = {
        k: v + 0.05 if sentiment == "supportive" else v
        for k, v in thresholds.items()
    }

    passed = {
        label: score for label, score in zip(LABELS, scores)
        if score > adjusted_thresholds[label]
    }

    # (Continue unchanged)

def analyze_composite(msg1, date1, msg2, date2, msg3, date3, *answers_and_none):
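    """Gradio handler: score up to three (message, date) pairs plus the escalation checklist,
    then return a formatted report and an abuse-score timeline image."""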
    none_selected_checked = answers_and_none[-1]
    responses_checked = any(answers_and_none[:-1])
    none_selected = not responses_checked and none_selected_checked

    if none_selected:
        escalation_score = None
        risk_level = "unknown"
    else:
        escalation_score = sum(w for (_, w), a in zip(ESCALATION_QUESTIONS, answers_and_none[:-1]) if a)
        risk_level = (
            "High" if escalation_score >= 16 else
            "Moderate" if escalation_score >= 8 else
            "Low"
        )

    messages = [msg1, msg2, msg3]
    dates = [date1, date2, date3]
    active = [(m, d) for m, d in zip(messages, dates) if m.strip()]
    if not active:
        return "Please enter at least one message."

    results = [(analyze_single_message(m, THRESHOLDS.copy()), d) for m, d in active]

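    # Each analyze_single_message result is assumed to be a 7-tuple; indices used below:
    # 0 = abuse score, 2 = ranked (label, score) pairs, 3 = sentiment dict,
    # 4 = risk stage, 5 = DARVO score, 6 = top label.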
    for result, date in results:
        assert len(result) == 7, "Unexpected output from analyze_single_message"

    top_labels = [r[0][6] for r in results]
    top_scores = [r[0][2][0][1] for r in results]
    sentiments = [r[0][3]['label'] for r in results]
    stages = [r[0][4] for r in results]
    darvo_scores = [r[0][5] for r in results]
    dates_used = [r[1] or "Undated" for r in results]
    abuse_scores = [r[0][0] for r in results]

    composite_abuse = int(round(sum(abuse_scores) / len(abuse_scores)))
    top_label = f"{top_labels[0]} – {int(round(top_scores[0] * 100))}%"

    most_common_stage = max(set(stages), key=stages.count)
    stage_text = RISK_STAGE_LABELS[most_common_stage]

    avg_darvo = round(sum(darvo_scores) / len(darvo_scores), 3)
    darvo_blurb = ""
    if avg_darvo > 0.25:
        level = "moderate" if avg_darvo < 0.65 else "high"
        darvo_blurb = f"\n\n🎭 **DARVO Score: {avg_darvo}** → This indicates a **{level} likelihood** of narrative reversal (DARVO), where the speaker may be denying, attacking, or reversing blame."

    out = f"Abuse Intensity: {composite_abuse}%\n"
    out += "πŸ“Š This reflects the strength and severity of detected abuse patterns in the message(s).\n\n"

    if escalation_score is None:
        escalation_text = "πŸ“‰ Escalation Potential: Unknown (Checklist not completed)\n"
        escalation_text += "⚠️ *This section was not completed. Escalation potential is unknown.*\n"
    else:
        escalation_text = f"🧨 **Escalation Potential: {risk_level} ({escalation_score}/{sum(w for _, w in ESCALATION_QUESTIONS)})**\n"
        escalation_text += "This score comes directly from the safety checklist and functions as a standalone escalation risk score.\n"
        escalation_text += "It indicates how many serious risk factors are present based on your answers to the safety checklist.\n"

    out += generate_risk_snippet(composite_abuse, top_label, escalation_score if escalation_score is not None else 0, most_common_stage)
    out += f"\n\n{stage_text}"
    out += darvo_blurb
    out += "\n\n" + escalation_text

    print(f"DEBUG: avg_darvo = {avg_darvo}")
    pattern_labels = [r[0][2][0][0] for r in results]
    timeline_image = generate_abuse_score_chart(dates_used, abuse_scores, pattern_labels)
    return out, timeline_image

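# Gradio UI: three message/date text boxes, the escalation checklist, and a "None of the above" option.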
message_date_pairs = [
    (
        gr.Textbox(label=f"Message {i+1}"),
        gr.Textbox(label=f"Date {i+1} (optional)", placeholder="YYYY-MM-DD")
    )
    for i in range(3)
]
textbox_inputs = [item for pair in message_date_pairs for item in pair]
quiz_boxes = [gr.Checkbox(label=q) for q, _ in ESCALATION_QUESTIONS]
none_box = gr.Checkbox(label="None of the above")

iface = gr.Interface(
    fn=analyze_composite,
    inputs=textbox_inputs + quiz_boxes + [none_box],
    outputs=[
        gr.Textbox(label="Results"),
        gr.Image(label="Risk Stage Timeline", type="pil")
    ],
    title="Abuse Pattern Detector + Escalation Quiz",
    allow_flagging="manual"
)

if __name__ == "__main__":
    iface.launch()