File size: 5,529 Bytes
d6e219c
f1948f2
a9d4250
f1948f2
 
 
a9d4250
f1948f2
a9d4250
 
 
79936aa
 
 
 
 
 
 
 
a9d4250
f1948f2
 
 
 
 
c303ab8
 
293a004
 
 
 
 
 
 
7dec571
293a004
 
 
7dec571
293a004
c303ab8
293a004
 
 
c303ab8
79936aa
 
f1948f2
79936aa
 
 
f1948f2
 
 
79936aa
5d7c4ba
79936aa
 
5d7c4ba
f1948f2
744dcb8
 
79936aa
c303ab8
 
79936aa
5d7c4ba
 
 
 
 
 
 
 
79936aa
 
5d7c4ba
2d95c7b
5d7c4ba
 
 
 
479c580
 
d23b4c6
 
5d7c4ba
f1948f2
 
 
5d7c4ba
 
 
f1948f2
 
 
5d7c4ba
 
 
f1948f2
 
5d7c4ba
f1948f2
5d7c4ba
 
 
f1948f2
 
5d7c4ba
 
5a8477a
83c1ff8
 
 
5d7c4ba
5a8477a
f1948f2
 
5d7c4ba
ea30a69
f1948f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gradio as gr
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer
import numpy as np
import tempfile

# Load the fine-tuned RoBERTa sequence classifier and its tokenizer from the
# same checkpoint id. The model is loaded first, then the tokenizer.
model_name = "SamanthaStorm/abuse-pattern-detector-v2"
model, tokenizer = (
    RobertaForSequenceClassification.from_pretrained(model_name),
    RobertaTokenizer.from_pretrained(model_name),
)

# Label order used by the model's classification head (index i of the model
# output corresponds to LABELS[i]). The model outputs 17 labels:
# - the first 14 are abuse pattern categories
# - the last 3 (suicidal_threat, physical_threat, extreme_control) are
#   Danger Assessment cues
LABELS = [
    "gaslighting", "mockery", "dismissiveness", "control",
    "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
    "contradictory_statements", "manipulation", "deflection", "insults",
    "obscure_formal", "recovery_phase", "suicidal_threat", "physical_threat",
    "extreme_control"
]
# Derived rather than hard-coded so it cannot drift from LABELS (previously
# the literal 17 was assigned twice).
TOTAL_LABELS = len(LABELS)

# Individual decision threshold for each label: a label is flagged when its
# sigmoid score is strictly greater than this value.
THRESHOLDS = {
    "gaslighting": 0.15,
    "mockery": 0.15,
    "dismissiveness": 0.30,
    "control": 0.13,
    "guilt_tripping": 0.15,
    "apology_baiting": 0.15,
    "blame_shifting": 0.15,
    "projection": 0.20,
    "contradictory_statements": 0.15,
    "manipulation": 0.15,
    "deflection": 0.15,
    "insults": 0.20,
    "obscure_formal": 0.20,
    "recovery_phase": 0.15,
    "suicidal_threat": 0.09,
    "physical_threat": 0.50,
    "extreme_control": 0.30
}

# Every label is looked up in THRESHOLDS at analysis time; fail fast at import
# instead of with a KeyError in the middle of a request if the tables diverge.
if set(LABELS) != set(THRESHOLDS):
    raise ValueError(
        "LABELS and THRESHOLDS disagree on: "
        f"{sorted(set(LABELS) ^ set(THRESHOLDS))}"
    )
def analyze_messages(input_text):
    input_text = input_text.strip()
    if not input_text:
        return "Please enter a message for analysis."

    # Tokenize
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # Squeeze out batch dimension
    logits = outputs.logits.squeeze(0)

    # Convert to probabilities
    scores = torch.sigmoid(logits).numpy()
    print("Scores:", scores)
    print("Danger Scores:", scores[14:])  # suicidal, physical, extreme

    pattern_count = 0
    danger_flag_count = 0

    for i, (label, score) in enumerate(zip(LABELS, scores)):
        if score > THRESHOLDS[label]:
            if i < 14:
                pattern_count += 1
            else:
                danger_flag_count += 1

    # Optional debug print
    for i, s in enumerate(scores):
        print(LABELS[i], "=", round(s, 3))

    danger_assessment = (
        "High" if danger_flag_count >= 2 else
        "Moderate" if danger_flag_count == 1 else
        "Low"
    )
    # Treat high-scoring danger cues as abuse patterns as well
    for danger_label in ["suicidal_threat", "physical_threat", "extreme_control"]:
        if scores[LABELS.index(danger_label)] > THRESHOLDS[danger_label]:
            pattern_count += 1
    # Set resources
    if danger_assessment == "High":
        resources = (
            "**Immediate Help:** If you are in immediate danger, please call 911.\n\n"
            "**Crisis Support:** National DV Hotline – Safety Planning: [thehotline.org/plan-for-safety](https://www.thehotline.org/plan-for-safety)\n"
            "**Legal Assistance:** WomensLaw – Legal Help for Survivors: [womenslaw.org](https://www.womenslaw.org)\n"
            "**Specialized Support:** For LGBTQ+, immigrants, and neurodivergent survivors, please consult local services."
        )
    elif danger_assessment == "Moderate":
        resources = (
            "**Safety Planning:** The Hotline – What Is Emotional Abuse?: [thehotline.org/resources](https://www.thehotline.org/resources)\n"
            "**Relationship Health:** One Love Foundation – Digital Relationship Health: [joinonelove.org](https://www.joinonelove.org)\n"
            "**Support Chat:** National Domestic Violence Hotline Chat: [thehotline.org](https://www.thehotline.org)\n"
            "**Specialized Groups:** Look for support groups tailored for LGBTQ+, immigrant, and neurodivergent communities."
        )
    else:
        resources = (
            "**Educational Resources:** Love Is Respect – Healthy Relationships: [loveisrespect.org](https://www.loveisrespect.org)\n"
            "**Therapy Finder:** Psychology Today – Find a Therapist: [psychologytoday.com](https://www.psychologytoday.com)\n"
            "**Relationship Tools:** Relate – Relationship Health Tools: [relate.org.uk](https://www.relate.org.uk)\n"
            "**Community Support:** Consider community-based and online support groups, especially those focused on LGBTQ+, immigrant, and neurodivergent survivors."
        )

    # Output
    result_md = (
        f"**Abuse Pattern Count:** {pattern_count}\n\n"
        f"**Support Resources:**\n{resources}"
    )

    with tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w") as f:
        f.write(result_md)
        report_path = f.name

    return result_md, report_path

# Build the Gradio interface: a title, instructions, an input textbox, a
# Markdown result pane, a downloadable copy of the report, and an "Analyze"
# button. Pressing Enter in the textbox and clicking the button both run
# analyze_messages.
with gr.Blocks() as demo:
    gr.Markdown("# Abuse Pattern Detector - Risk Analysis")
    gr.Markdown("Enter one or more messages (separated by newlines) for analysis.")

    text_input = gr.Textbox(
        label="Input Messages",
        lines=10,
        placeholder="Type your message(s) here...",
    )
    result_output = gr.Markdown(label="Analysis Result")
    download_output = gr.File(label="Download Report (.txt)")
    analyze_btn = gr.Button("Analyze")

    # analyze_messages returns (markdown, report path): one value per output.
    for register_event in (text_input.submit, analyze_btn.click):
        register_event(
            analyze_messages,
            inputs=text_input,
            outputs=[result_output, download_output],
        )

if __name__ == "__main__":
    demo.launch()