Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,732 Bytes
d6e219c f1948f2 a9d4250 f1948f2 a9d4250 f1948f2 a9d4250 79936aa a9d4250 f1948f2 c303ab8 79936aa f1948f2 79936aa f1948f2 79936aa f1948f2 744dcb8 79936aa c303ab8 79936aa 070ff27 c303ab8 79936aa f1948f2 2d95c7b f1948f2 2d95c7b 5a8477a 83c1ff8 2d95c7b f1948f2 5a8477a f1948f2 ea30a69 f1948f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import logging
import tempfile

import gradio as gr
import numpy as np
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer
# Fetch the fine-tuned RoBERTa classifier and its tokenizer once, at import
# time. Both come from the same checkpoint id so they are used as a matched pair.
model_name = "SamanthaStorm/abuse-pattern-detector-v2"
model = RobertaForSequenceClassification.from_pretrained(
    model_name,
)
tokenizer = RobertaTokenizer.from_pretrained(
    model_name,
)
# Label names in the exact order of the classifier's output logits.
# The model emits 17 scores:
#   - LABELS[:14] are abuse-pattern categories
#   - LABELS[14:] are the 3 Danger Assessment cues
LABELS = [
    "gaslighting", "mockery", "dismissiveness", "control",
    "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
    "contradictory_statements", "manipulation", "deflection", "insults",
    "obscure_formal", "recovery_phase", "suicidal_threat", "physical_threat",
    "extreme_control",
]

# How many leading entries of LABELS are abuse-pattern categories; the
# remaining entries are danger cues.
PATTERN_LABEL_COUNT = 14

# Derived from LABELS (instead of a literal assigned twice) so the count can
# never drift out of sync with the label list.
TOTAL_LABELS = len(LABELS)

# Per-label decision thresholds, compared against each label's sigmoid score.
# Danger cues use a lower bar (0.10) than abuse patterns (0.15), so they are
# flagged more readily.
THRESHOLDS = {
    "gaslighting": 0.15,
    "mockery": 0.15,
    "dismissiveness": 0.15,
    "control": 0.15,
    "guilt_tripping": 0.15,
    "apology_baiting": 0.15,
    "blame_shifting": 0.15,
    "projection": 0.15,
    "contradictory_statements": 0.15,
    "manipulation": 0.15,
    "deflection": 0.15,
    "insults": 0.15,
    "obscure_formal": 0.15,
    "recovery_phase": 0.15,
    "suicidal_threat": 0.10,
    "physical_threat": 0.10,
    "extreme_control": 0.10,
}

# Fail fast at import if a label is added or renamed in one place but not the
# other; otherwise THRESHOLDS[label] lookups would raise KeyError mid-request.
if set(THRESHOLDS) != set(LABELS):
    raise ValueError(
        "THRESHOLDS keys do not match LABELS: "
        f"missing={sorted(set(LABELS) - set(THRESHOLDS))}, "
        f"extra={sorted(set(THRESHOLDS) - set(LABELS))}"
    )
_log = logging.getLogger(__name__)

# Markdown support resources shown for each Danger Assessment level.
_RESOURCES_BY_DANGER = {
    "High": (
        "**Immediate Help:** If you are in immediate danger, please call 911.\n\n"
        "**Crisis Support:** National DV Hotline — Safety Planning: [thehotline.org/plan-for-safety](https://www.thehotline.org/plan-for-safety/)\n"
        "**Legal Assistance:** WomensLaw — Legal Help for Survivors: [womenslaw.org](https://www.womenslaw.org/)\n"
        "**Specialized Support:** For LGBTQ+, immigrants, and neurodivergent survivors, please consult local specialized services or visit RAINN: [rainn.org](https://www.rainn.org/)"
    ),
    "Moderate": (
        "**Safety Planning:** The Hotline — What Is Emotional Abuse?: [thehotline.org/resources](https://www.thehotline.org/resources/what-is-emotional-abuse/)\n"
        "**Relationship Health:** One Love Foundation — Digital Relationship Health: [joinonelove.org](https://www.joinonelove.org/)\n"
        "**Support Chat:** National Domestic Violence Hotline Chat: [thehotline.org](https://www.thehotline.org/)\n"
        "**Specialized Groups:** Look for support groups tailored for LGBTQ+, immigrant, and neurodivergent communities."
    ),
    "Low": (
        "**Educational Resources:** Love Is Respect — Healthy Relationships: [loveisrespect.org](https://www.loveisrespect.org/)\n"
        "**Therapy Finder:** Psychology Today — Find a Therapist: [psychologytoday.com](https://www.psychologytoday.com/us/therapists)\n"
        "**Relationship Tools:** Relate — Relationship Health Tools: [relate.org.uk](https://www.relate.org.uk/)\n"
        "**Community Support:** Consider community-based and online support groups, especially those focused on LGBTQ+, immigrant, and neurodivergent survivors."
    ),
}


def _danger_level(danger_flag_count):
    """Return "High" for >= 2 tripped danger cues, "Moderate" for 1, else "Low"."""
    if danger_flag_count >= 2:
        return "High"
    if danger_flag_count == 1:
        return "Moderate"
    return "Low"


def _write_report(report_text):
    """Write *report_text* to a new ``.txt`` temp file and return its path.

    The file is created with ``delete=False`` because the returned path is
    handed to a Gradio ``File`` output, which reads it after the handler has
    returned. Nothing here removes the file afterwards.
    """
    # Explicit UTF-8 so the non-ASCII punctuation in the resource text does
    # not depend on the platform's default encoding.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    ) as report_file:
        report_file.write(report_text)
        return report_file.name


def analyze_messages(input_text):
    """Score *input_text* with the classifier and build a Markdown report.

    The whole input is classified as one sequence (tokenizer truncation is
    enabled). Each sigmoid score is compared with its threshold in
    ``THRESHOLDS``:
    - tripped labels among the first 14 are counted as abuse patterns;
    - tripped labels among the last 3 (danger cues) set the Danger Assessment
      level, which selects the support resources shown.

    Args:
        input_text: Text from the Gradio textbox; ``None`` is treated as empty.

    Returns:
        A ``(markdown, report_path)`` tuple, one value per Gradio output.
        For blank input the Markdown is a prompt and ``report_path`` is
        ``None``, so no file is offered.

    Raises:
        ValueError: If the model does not return exactly one score per label.
    """
    text = (input_text or "").strip()
    if not text:
        # Still return one value per wired output (result Markdown, file).
        return "Please enter a message for analysis.", None

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Batch of one: drop the batch axis, then map each logit to an
    # independent per-label probability with a sigmoid.
    scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()

    if scores.ndim != 1 or scores.shape[0] != len(LABELS):
        # Without this check, zip() below would silently drop scores or labels.
        raise ValueError(
            f"expected {len(LABELS)} scores, got shape {tuple(scores.shape)}"
        )

    for label, score in zip(LABELS, scores):
        _log.debug("%s = %.3f", label, score)

    num_pattern_labels = 14  # LABELS[:14] are patterns, LABELS[14:] danger cues
    pattern_count = 0
    danger_flag_count = 0
    for i, (label, score) in enumerate(zip(LABELS, scores)):
        if score <= THRESHOLDS[label]:
            continue
        if i < num_pattern_labels:
            pattern_count += 1
        else:
            danger_flag_count += 1

    resources = _RESOURCES_BY_DANGER[_danger_level(danger_flag_count)]
    result_md = (
        f"**Abuse Pattern Count:** {pattern_count}\n\n"
        f"**Support Resources:**\n{resources}"
    )
    return result_md, _write_report(result_md)
# Assemble the Gradio UI: a text box, a Markdown result pane, a downloadable
# report, and an "Analyze" button. Both pressing Enter in the text box and
# clicking the button run analyze_messages with the same inputs and outputs.
with gr.Blocks() as demo:
    gr.Markdown("# Abuse Pattern Detector - Risk Analysis")
    gr.Markdown("Enter one or more messages (separated by newlines) for analysis.")

    text_input = gr.Textbox(
        label="Input Messages",
        lines=10,
        placeholder="Type your message(s) here...",
    )
    result_output = gr.Markdown(label="Analysis Result")
    download_output = gr.File(label="Download Report (.txt)")

    # Shared wiring for both triggers: one text input in, (Markdown, file) out.
    event_wiring = {
        "inputs": text_input,
        "outputs": [result_output, download_output],
    }
    text_input.submit(analyze_messages, **event_wiring)

    analyze_btn = gr.Button("Analyze")
    analyze_btn.click(analyze_messages, **event_wiring)

if __name__ == "__main__":
    demo.launch()
|