import gradio as gr
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer
import numpy as np

# Load model and tokenizer
model_name = "SamanthaStorm/abuse-pattern-detector-v2"
model = RobertaForSequenceClassification.from_pretrained(model_name)
tokenizer = RobertaTokenizer.from_pretrained(model_name)

# Define labels (total 17 labels)
LABELS = [
    "gaslighting", "mockery", "dismissiveness", "control",
    "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
    "contradictory_statements", "manipulation", "deflection", "insults",
    "obscure_formal", "recovery_phase", "suicidal_threat", "physical_threat",
    "extreme_control"
]
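# NOTE: scores are matched to these labels positionally (via zip), so this list
# is assumed to mirror the output order of the model's classification head.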

# Custom thresholds per label (make sure these are exactly as in the original)
THRESHOLDS = {
    "gaslighting": 0.15,
    "mockery": 0.15,
    "dismissiveness": 0.25,  # Keep this as 0.25 (not 0.30)
    "control": 0.13,
    "guilt_tripping": 0.15,
    "apology_baiting": 0.15,
    "blame_shifting": 0.15,
    "projection": 0.20,
    "contradictory_statements": 0.15,
    "manipulation": 0.15,
    "deflection": 0.15,
    "insults": 0.20,
    "obscure_formal": 0.20,
    "recovery_phase": 0.15,
    "suicidal_threat": 0.08,
    "physical_threat": 0.045,
    "extreme_control": 0.30,
}
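# The suicidal_threat (0.08) and physical_threat (0.045) thresholds sit well below
# the others, so these danger cues trigger on comparatively weak model signals.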

# Define label groups using slicing (first 14 are abuse patterns, last 3 are danger cues)
PATTERN_LABELS = LABELS[:14]
DANGER_LABELS = LABELS[14:]

def calculate_abuse_level(scores, thresholds):
    """Return the mean of all scores that clear their per-label thresholds, as a percentage.

    Example: if only 0.20 and 0.40 exceed their thresholds, the result is
    round(np.mean([0.20, 0.40]) * 100, 2) == 30.0.
    """
    triggered_scores = [score for label, score in zip(LABELS, scores) if score > thresholds[label]]
    if not triggered_scores:
        return 0.0
    return round(np.mean(triggered_scores) * 100, 2)

def interpret_abuse_level(score):
    if score > 80:
        return "Extreme / High Risk"
    elif score > 60:
        return "Severe / Harmful Pattern Present"
    elif score > 40:
        return "Likely Abuse"
    elif score > 20:
        return "Mild Concern"
    else:
        return "Very Low / Likely Safe"

def analyze_messages(input_text):
    input_text = input_text.strip()
    if not input_text:
        return "Please enter a message for analysis.", None

    # Tokenize and predict
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
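    # Sigmoid (rather than softmax) because this is a multi-label task:
    # each of the 17 labels gets an independent probability in [0, 1].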
    scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()

    # Count triggered labels using the correct slices
    pattern_count = sum(score > THRESHOLDS[label] for label, score in zip(PATTERN_LABELS, scores[:14]))
    danger_flag_count = sum(score > THRESHOLDS[label] for label, score in zip(DANGER_LABELS, scores[14:]))

    # Abuse level calculation and severity interpretation
    abuse_level = calculate_abuse_level(scores, THRESHOLDS)
    abuse_description = interpret_abuse_level(abuse_level)

    # Resource logic (example logic; adjust as needed)
    if danger_flag_count >= 2:
        resources = "Immediate assistance recommended. Please seek professional help or contact emergency services."
    else:
        resources = "For more information on abuse patterns, consider reaching out to support groups or professional counselors."

    # Output combining counts, severity, and resource suggestion
    result = (
        f"Abuse Patterns Detected: {pattern_count} out of {len(PATTERN_LABELS)}\n"
        f"Danger Flags Detected: {danger_flag_count} out of {len(DANGER_LABELS)}\n"
        f"Abuse Level: {abuse_level}% - {abuse_description}\n"
        f"Resources: {resources}"
    )
    # Convert numpy floats to a plain dict so the JSON output component can serialize it
    return result, {label: round(float(score), 4) for label, score in zip(LABELS, scores)}

iface = gr.Interface(
    fn=analyze_messages,
    inputs=gr.Textbox(lines=10, placeholder="Enter message here..."),
    outputs=[gr.Textbox(label="Analysis"), gr.JSON(label="Label scores")],
    title="Abuse Pattern Detector"
)

iface.launch()
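
# To run locally (assumes torch, transformers, numpy, and gradio are installed):
#   python app.py
# Gradio serves the interface at http://127.0.0.1:7860 by default.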