import gradio as gr
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer
import tempfile
# Load model and tokenizer
model_name = "SamanthaStorm/abuse-pattern-detector-v2"
model = RobertaForSequenceClassification.from_pretrained(model_name)
tokenizer = RobertaTokenizer.from_pretrained(model_name)
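# Put the model in eval mode before inference (a standard PyTorch step;
# disables dropout so scores are deterministic).
model.eval()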
# Final label order used by the model
LABELS = [
    "gaslighting", "mockery", "dismissiveness", "control",
    "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
    "contradictory_statements", "manipulation", "deflection", "insults",
    "obscure_formal", "recovery_phase", "suicidal_threat", "physical_threat",
    "extreme_control"
]
# Our model outputs 17 labels:
# - First 14 are abuse pattern categories
# - Last 3 are Danger Assessment cues
TOTAL_LABELS = 17
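# Sanity check (added as a safeguard): the label list must match the
# expected output size of the classification head.
assert len(LABELS) == TOTAL_LABELS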
# Individual thresholds for each of the 17 labels
THRESHOLDS = {
    "gaslighting": 0.15,
    "mockery": 0.15,
    "dismissiveness": 0.30,
    "control": 0.13,
    "guilt_tripping": 0.15,
    "apology_baiting": 0.15,
    "blame_shifting": 0.15,
    "projection": 0.20,
    "contradictory_statements": 0.15,
    "manipulation": 0.15,
    "deflection": 0.15,
    "insults": 0.20,
    "obscure_formal": 0.20,
    "recovery_phase": 0.15,
    "suicidal_threat": 0.09,
    "physical_threat": 0.50,
    "extreme_control": 0.30
}
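# Sanity check (added as a safeguard): every label needs a threshold, and
# every threshold key must be a known label.
assert set(THRESHOLDS) == set(LABELS)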
def analyze_messages(input_text):
    input_text = input_text.strip()
    if not input_text:
        # Return two values to match the two Gradio outputs (text + file)
        return "Please enter a message for analysis.", None
    # Tokenize
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Squeeze out batch dimension
    logits = outputs.logits.squeeze(0)
    # Convert to probabilities
    scores = torch.sigmoid(logits).numpy()
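    # Sigmoid (rather than softmax) because this is a multi-label task:
    # each label gets an independent probability, so several patterns can
    # fire on the same message.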
print("Scores:", scores)
print("Danger Scores:", scores[14:]) # suicidal, physical, extreme
pattern_count = 0
danger_flag_count = 0
    for i, (label, score) in enumerate(zip(LABELS, scores)):
        if score > THRESHOLDS[label]:
            if i < 14:
                pattern_count += 1
            else:
                danger_flag_count += 1
    # Optional debug print
    for i, s in enumerate(scores):
        print(LABELS[i], "=", round(s, 3))
    danger_assessment = (
        "High" if danger_flag_count >= 2 else
        "Moderate" if danger_flag_count == 1 else
        "Low"
    )
    # Treat high-scoring danger cues as abuse patterns as well
    for danger_label in ["suicidal_threat", "physical_threat", "extreme_control"]:
        if scores[LABELS.index(danger_label)] > THRESHOLDS[danger_label]:
            pattern_count += 1
    # Set resources
    if danger_assessment == "High":
        resources = (
            "**Immediate Help:** If you are in immediate danger, please call 911.\n\n"
            "**Crisis Support:** National DV Hotline – Safety Planning: [thehotline.org/plan-for-safety](https://www.thehotline.org/plan-for-safety)\n"
            "**Legal Assistance:** WomensLaw – Legal Help for Survivors: [womenslaw.org](https://www.womenslaw.org)\n"
            "**Specialized Support:** For LGBTQ+, immigrants, and neurodivergent survivors, please consult local services."
        )
    elif danger_assessment == "Moderate":
        resources = (
            "**Safety Planning:** The Hotline – What Is Emotional Abuse?: [thehotline.org/resources](https://www.thehotline.org/resources)\n"
            "**Relationship Health:** One Love Foundation – Digital Relationship Health: [joinonelove.org](https://www.joinonelove.org)\n"
            "**Support Chat:** National Domestic Violence Hotline Chat: [thehotline.org](https://www.thehotline.org)\n"
            "**Specialized Groups:** Look for support groups tailored for LGBTQ+, immigrant, and neurodivergent communities."
        )
    else:
        resources = (
            "**Educational Resources:** Love Is Respect – Healthy Relationships: [loveisrespect.org](https://www.loveisrespect.org)\n"
            "**Therapy Finder:** Psychology Today – Find a Therapist: [psychologytoday.com](https://www.psychologytoday.com)\n"
            "**Relationship Tools:** Relate – Relationship Health Tools: [relate.org.uk](https://www.relate.org.uk)\n"
            "**Community Support:** Consider community-based and online support groups, especially those focused on LGBTQ+, immigrant, and neurodivergent survivors."
        )
    # Output
    result_md = (
        f"**Abuse Pattern Count:** {pattern_count}\n\n"
        f"**Support Resources:**\n{resources}"
    )
    with tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w") as f:
        f.write(result_md)
        report_path = f.name
    return result_md, report_path
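# Quick local check (illustrative only; the example message is invented):
#   md, path = analyze_messages("You never listen. This is all your fault.")
#   print(md)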
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Abuse Pattern Detector - Risk Analysis")
    gr.Markdown("Enter one or more messages (separated by newlines) for analysis.")
    text_input = gr.Textbox(label="Input Messages", lines=10, placeholder="Type your message(s) here...")
    result_output = gr.Markdown(label="Analysis Result")
    download_output = gr.File(label="Download Report (.txt)")
    text_input.submit(analyze_messages, inputs=text_input, outputs=[result_output, download_output])
    analyze_btn = gr.Button("Analyze")
    analyze_btn.click(analyze_messages, inputs=text_input, outputs=[result_output, download_output])
if __name__ == "__main__":
    demo.launch()