File size: 5,732 Bytes
d6e219c
f1948f2
a9d4250
f1948f2
 
 
a9d4250
f1948f2
a9d4250
 
 
79936aa
 
 
 
 
 
 
 
a9d4250
f1948f2
 
 
 
 
c303ab8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79936aa
 
f1948f2
79936aa
 
 
f1948f2
 
 
79936aa
 
 
 
 
f1948f2
744dcb8
 
79936aa
 
 
 
c303ab8
 
79936aa
070ff27
c303ab8
 
 
 
 
79936aa
 
 
f1948f2
2d95c7b
 
 
 
 
 
f1948f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d95c7b
 
5a8477a
83c1ff8
 
 
2d95c7b
f1948f2
5a8477a
f1948f2
 
 
ea30a69
f1948f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import gradio as gr
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer
import numpy as np
import tempfile

# Hugging Face Hub id of the fine-tuned RoBERTa sequence classifier used below.
model_name = "SamanthaStorm/abuse-pattern-detector-v2"

# Load the model and its tokenizer from the same checkpoint so token ids match
# the vocabulary the model was trained on (model first, then tokenizer).
model, tokenizer = (
    RobertaForSequenceClassification.from_pretrained(model_name),
    RobertaTokenizer.from_pretrained(model_name),
)

# Label order the model was trained with: score index i corresponds to LABELS[i].
# - The first 14 labels are abuse-pattern categories.
# - The last 3 are Danger Assessment cues.
LABELS = [
    "gaslighting", "mockery", "dismissiveness", "control",
    "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
    "contradictory_statements", "manipulation", "deflection", "insults",
    "obscure_formal", "recovery_phase", "suicidal_threat", "physical_threat",
    "extreme_control",
]

# Number of model outputs (17); derived from LABELS so the two cannot drift apart.
TOTAL_LABELS = len(LABELS)

# Per-label decision thresholds: a label is flagged when its probability is
# strictly greater than its threshold. Danger cues use a lower cut-off (0.10)
# than the pattern categories (0.15).
THRESHOLDS = {
    "gaslighting": 0.15,
    "mockery": 0.15,
    "dismissiveness": 0.15,
    "control": 0.15,
    "guilt_tripping": 0.15,
    "apology_baiting": 0.15,
    "blame_shifting": 0.15,
    "projection": 0.15,
    "contradictory_statements": 0.15,
    "manipulation": 0.15,
    "deflection": 0.15,
    "insults": 0.15,
    "obscure_formal": 0.15,
    "recovery_phase": 0.15,
    "suicidal_threat": 0.10,
    "physical_threat": 0.10,
    "extreme_control": 0.10,
}

# Fail fast at import time rather than with a KeyError mid-request if a label
# is added to one table but not the other.
if set(THRESHOLDS) != set(LABELS):
    raise ValueError(
        "THRESHOLDS keys must match LABELS exactly; "
        f"difference: {sorted(set(THRESHOLDS) ^ set(LABELS))}"
    )
def _count_flags(scores, labels, thresholds, n_pattern_labels=14):
    """Count pattern labels and danger-cue labels whose score exceeds its threshold.

    Args:
        scores: Per-label probabilities, in the same order as ``labels``.
        labels: Label names; the first ``n_pattern_labels`` entries are
            abuse-pattern categories, the rest are danger-assessment cues.
        thresholds: Mapping from label name to its decision threshold.
        n_pattern_labels: How many leading labels count as abuse patterns.

    Returns:
        A ``(pattern_count, danger_flag_count)`` tuple.
    """
    pattern_count = 0
    danger_flag_count = 0
    for i, (label, score) in enumerate(zip(labels, scores)):
        # Strictly greater than: a score equal to the threshold is not flagged.
        if score > thresholds[label]:
            if i < n_pattern_labels:
                pattern_count += 1
            else:
                danger_flag_count += 1
    return pattern_count, danger_flag_count


def _danger_level(danger_flag_count):
    """Map the number of flagged danger cues to "High", "Moderate" or "Low"."""
    if danger_flag_count >= 2:
        return "High"
    if danger_flag_count == 1:
        return "Moderate"
    return "Low"


def _resources_for(danger_assessment):
    """Return the Markdown support-resource text for a danger level."""
    if danger_assessment == "High":
        return (
            "**Immediate Help:** If you are in immediate danger, please call 911.\n\n"
            "**Crisis Support:** National DV Hotline – Safety Planning: [thehotline.org/plan-for-safety](https://www.thehotline.org/plan-for-safety/)\n"
            "**Legal Assistance:** WomensLaw – Legal Help for Survivors: [womenslaw.org](https://www.womenslaw.org/)\n"
            "**Specialized Support:** For LGBTQ+, immigrants, and neurodivergent survivors, please consult local specialized services or visit RAINN: [rainn.org](https://www.rainn.org/)"
        )
    if danger_assessment == "Moderate":
        return (
            "**Safety Planning:** The Hotline – What Is Emotional Abuse?: [thehotline.org/resources](https://www.thehotline.org/resources/what-is-emotional-abuse/)\n"
            "**Relationship Health:** One Love Foundation – Digital Relationship Health: [joinonelove.org](https://www.joinonelove.org/)\n"
            "**Support Chat:** National Domestic Violence Hotline Chat: [thehotline.org](https://www.thehotline.org/)\n"
            "**Specialized Groups:** Look for support groups tailored for LGBTQ+, immigrant, and neurodivergent communities."
        )
    # Low risk
    return (
        "**Educational Resources:** Love Is Respect – Healthy Relationships: [loveisrespect.org](https://www.loveisrespect.org/)\n"
        "**Therapy Finder:** Psychology Today – Find a Therapist: [psychologytoday.com](https://www.psychologytoday.com/us/therapists)\n"
        "**Relationship Tools:** Relate – Relationship Health Tools: [relate.org.uk](https://www.relate.org.uk/)\n"
        "**Community Support:** Consider community-based and online support groups, especially those focused on LGBTQ+, immigrant, and neurodivergent survivors."
    )


def analyze_messages(input_text):
    """Score ``input_text`` with the classifier and build a Markdown report.

    The whole text is tokenized as one sequence (``truncation=True``), so
    several newline-separated messages are scored together, not one by one.

    Args:
        input_text: Raw text from the UI textbox; ``None`` is treated as empty.

    Returns:
        A ``(result_md, report_path)`` pair: the Markdown shown in the UI and
        the path of a ``.txt`` copy offered for download. For empty input the
        path is ``None``, so both bound Gradio outputs always get a value.
    """
    input_text = (input_text or "").strip()
    if not input_text:
        return "Please enter a message for analysis.", None

    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # Drop the batch dimension of this single-text batch; presumably one logit
    # per label remains (17) — TODO confirm against the model config.
    logits = outputs.logits.squeeze(0)

    # Element-wise sigmoid: each label gets an independent probability
    # (multi-label), rather than a softmax distribution over labels.
    scores = torch.sigmoid(logits).numpy()

    # Debug output (remove once you're confident everything works).
    print("Scores:", scores)
    print("Danger Scores:", scores[14:])  # suicidal, physical, extreme
    for label, score in zip(LABELS, scores):
        print(label, "=", round(score, 3))

    pattern_count, danger_flag_count = _count_flags(scores, LABELS, THRESHOLDS)
    danger_assessment = _danger_level(danger_flag_count)
    resources = _resources_for(danger_assessment)

    # Output shows only the pattern count plus resources chosen by danger level.
    result_md = (
        f"**Abuse Pattern Count:** {pattern_count}\n\n"
        f"**Support Resources:**\n{resources}"
    )

    # delete=False keeps the file after closing so Gradio can serve it for
    # download; these files are not cleaned up here. UTF-8 is explicit because
    # the report contains non-ASCII characters (e.g. "–").
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".txt", mode="w", encoding="utf-8"
    ) as f:
        f.write(result_md)
        report_path = f.name

    return result_md, report_path

# Build the Gradio interface. Components appear on the page in the order
# they are created inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("# Abuse Pattern Detector - Risk Analysis")
    gr.Markdown("Enter one or more messages (separated by newlines) for analysis.")

    text_input = gr.Textbox(
        label="Input Messages",
        lines=10,
        placeholder="Type your message(s) here...",
    )
    result_output = gr.Markdown(label="Analysis Result")
    download_output = gr.File(label="Download Report (.txt)")

    # The textbox's submit event and the button's click event share one
    # handler and the same input/output wiring.
    handler_io = {"inputs": text_input, "outputs": [result_output, download_output]}
    text_input.submit(analyze_messages, **handler_io)
    analyze_btn = gr.Button("Analyze")
    analyze_btn.click(analyze_messages, **handler_io)

if __name__ == "__main__":
    demo.launch()