import gradio as gr
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer
import numpy as np

# Load model and tokenizer (trust_remote_code=True in case the model repo ships custom code)
model_name = "SamanthaStorm/abuse-pattern-detector-v2"
model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)
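# Note: from_pretrained returns the model already in eval mode, so no explicit model.eval()
# call is needed; inference below runs on CPU unless the model and inputs are moved to a GPU.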

# Define labels (17 total)
LABELS = [
    "gaslighting", "mockery", "dismissiveness", "control",
    "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
    "contradictory_statements", "manipulation", "deflection", "insults",
    "obscure_formal", "recovery_phase", "suicidal_threat", "physical_threat",
    "extreme_control"
]

# Custom thresholds for each label (make sure these match your original settings)
THRESHOLDS = {
    "gaslighting": 0.25,
    "mockery": 0.15,
    "dismissiveness": 0.30,  # original value, not 0.30
    "control": 0.43,
    "guilt_tripping": 0.19,
    "apology_baiting": 0.45,
    "blame_shifting": 0.23,
    "projection": 0.50,
    "contradictory_statements": 0.25,
    "manipulation": 0.25,
    "deflection": 0.30,
    "insults": 0.34,
    "obscure_formal": 0.25,
    "recovery_phase": 0.25,
    "suicidal_threat": 0.45,
    "physical_threat": 0.31,
    "extreme_control": 0.36,
    "non_abusive": 0.40
}

# Define label groups using slicing (first 14: abuse patterns, last 3: danger cues)
PATTERN_LABELS = LABELS[:14]
DANGER_LABELS = LABELS[14:17]

def calculate_abuse_level(scores, thresholds):
    triggered_scores = [score for label, score in zip(LABELS, scores) if score > thresholds[label]]
    if not triggered_scores:
        return 0.0
    return round(np.mean(triggered_scores) * 100, 2)
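
# Worked example with hypothetical scores: if only gaslighting=0.40 and mockery=0.20 clear
# their thresholds (0.25 and 0.15), the result is round(np.mean([0.40, 0.20]) * 100, 2) == 30.0.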

def interpret_abuse_level(score):
    if score > 80:
        return "Extreme / High Risk"
    elif score > 60:
        return "Severe / Harmful Pattern Present"
    elif score > 40:
        return "Likely Abuse"
    elif score > 20:
        return "Mild Concern"
    else:
        return "Very Low / Likely Safe"

def analyze_messages(input_text):
    input_text = input_text.strip()
    if not input_text:
        return "Please enter a message for analysis.", None

    # Tokenize input and generate model predictions
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()

    # Count the number of triggered abuse pattern and danger flags based on thresholds
    pattern_count = sum(score > THRESHOLDS[label] for label, score in zip(PATTERN_LABELS, scores[:14]))
    danger_flag_count = sum(score > THRESHOLDS[label] for label, score in zip(DANGER_LABELS, scores[14:17]))

    # Build formatted raw score display
    score_lines = [
        f"{label:25}: {score:.3f}" for label, score in zip(PATTERN_LABELS + DANGER_LABELS, scores)
    ]
    raw_score_output = "\n".join(score_lines)
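    # Each line looks like "gaslighting              : 0.412" (hypothetical score; labels are
    # left-justified to 25 characters so the values line up).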

    # Calculate overall abuse level and interpret it
    abuse_level = calculate_abuse_level(scores, THRESHOLDS)
    abuse_description = interpret_abuse_level(abuse_level)

    # Resource logic based on the number of danger cues
    if danger_flag_count >= 2:
        resources = "Immediate assistance recommended. Please seek professional help or contact emergency services."
    else:
        resources = "For more information on abuse patterns, consider reaching out to support groups or professional counselors."

    # Get the top 2 highest-scoring abuse patterns ('non_abusive' is not part of PATTERN_LABELS)
    scored_patterns = [(label, score) for label, score in zip(PATTERN_LABELS, scores[:14])]
    top_patterns = sorted(scored_patterns, key=lambda x: x[1], reverse=True)[:2]
    top_patterns_str = "\n".join([f"• {label.replace('_', ' ').title()}" for label, _ in top_patterns])

    # Format final result
    result = (
        f"Abuse Risk Score: {abuse_level}% – {abuse_description}\n"
        "This message contains signs of emotionally harmful communication that may indicate abusive patterns.\n\n"
        f"Most Likely Patterns:\n{top_patterns_str}\n\n"
        f"⚠️ Critical Danger Flags Detected: {danger_flag_count} of 3\n"
        "The Danger Assessment is a validated tool that helps identify serious risk in intimate partner violence. "
        "It flags communication patterns associated with increased risk of severe harm. "
        "For more info, consider reaching out to support groups or professionals.\n\n"
        f"Resources: {resources}"
    )

    # Return the text summary and the formatted per-label raw scores
    return result, raw_score_output

# Updated Gradio Interface using new component syntax
iface = gr.Interface(
    fn=analyze_messages,
    inputs=gr.Textbox(lines=10, placeholder="Enter message here..."),
    outputs=[
        gr.Textbox(label="Analysis Result"),
    ],
    title="Abuse Pattern Detector"
)

if __name__ == "__main__":
    iface.launch()
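
# Quick local smoke test (commented out; the message below is an invented example and the
# model weights must download successfully for this to run):
# summary, raw_scores = analyze_messages("You never listen, and somehow it's always my fault.")
# print(summary)
# print(raw_scores)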