Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,190 Bytes
d6e219c f1948f2 a9d4250 f1948f2 a9d4250 f1948f2 a9d4250 79936aa a9d4250 f1948f2 79936aa f1948f2 79936aa f1948f2 79936aa f1948f2 744dcb8 79936aa 2d95c7b 79936aa 2d95c7b 79936aa f1948f2 2d95c7b f1948f2 2d95c7b f1948f2 9ee4a74 f1948f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import gradio as gr
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer
import numpy as np
import tempfile
# Load model and tokenizer
model_name = "SamanthaStorm/abuse-pattern-detector-v2"
model = RobertaForSequenceClassification.from_pretrained(model_name)
tokenizer = RobertaTokenizer.from_pretrained(model_name)

# Label order used by the model's classification head: score i of the model
# output corresponds to LABELS[i].
LABELS = [
    "gaslighting", "mockery", "dismissiveness", "control",
    "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
    "contradictory_statements", "manipulation", "deflection", "insults",
    "obscure_formal", "recovery_phase", "suicidal_threat", "physical_threat",
    "extreme_control",
]

# The model outputs one score per label (17 in total):
# - the first 14 are abuse-pattern categories
# - the last 3 (suicidal_threat, physical_threat, extreme_control) are
#   Danger Assessment cues
# Derived from LABELS so the two can never drift apart.
# NOTE(review): consider checking this against model.config.num_labels at
# startup so a mismatched checkpoint fails fast.
TOTAL_LABELS = len(LABELS)
# Number of leading entries in LABELS that are abuse-pattern categories; the
# remaining entries are Danger Assessment cues.
_NUM_PATTERN_LABELS = 14
# Sigmoid probability above which a pattern / danger cue counts as present.
_PATTERN_THRESHOLD = 0.15
_DANGER_THRESHOLD = 0.20

_EMPTY_INPUT_MESSAGE = "Please enter a message for analysis."

# Support resources shown for each Danger Assessment level (with additional
# niche support). Items are Markdown list entries so they render on separate
# lines; single newlines would otherwise collapse into one paragraph.
_SUPPORT_RESOURCES = {
    "High": (
        "**Immediate Help:** If you are in immediate danger, please call 911.\n\n"
        "- **Crisis Support:** National DV Hotline — Safety Planning: [thehotline.org/plan-for-safety](https://www.thehotline.org/plan-for-safety/)\n"
        "- **Legal Assistance:** WomensLaw — Legal Help for Survivors: [womenslaw.org](https://www.womenslaw.org/)\n"
        "- **Specialized Support:** For LGBTQ+, immigrants, and neurodivergent survivors, please consult local specialized services or visit RAINN: [rainn.org](https://www.rainn.org/)"
    ),
    "Moderate": (
        "- **Safety Planning:** The Hotline — What Is Emotional Abuse?: [thehotline.org/resources](https://www.thehotline.org/resources/what-is-emotional-abuse/)\n"
        "- **Relationship Health:** One Love Foundation — Digital Relationship Health: [joinonelove.org](https://www.joinonelove.org/)\n"
        "- **Support Chat:** National Domestic Violence Hotline Chat: [thehotline.org](https://www.thehotline.org/)\n"
        "- **Specialized Groups:** Look for support groups tailored for LGBTQ+, immigrant, and neurodivergent communities."
    ),
    "Low": (
        "- **Educational Resources:** Love Is Respect — Healthy Relationships: [loveisrespect.org](https://www.loveisrespect.org/)\n"
        "- **Therapy Finder:** Psychology Today — Find a Therapist: [psychologytoday.com](https://www.psychologytoday.com/us/therapists)\n"
        "- **Relationship Tools:** Relate — Relationship Health Tools: [relate.org.uk](https://www.relate.org.uk/)\n"
        "- **Community Support:** Consider community-based and online support groups, especially those focused on LGBTQ+, immigrant, and neurodivergent survivors."
    ),
}


def _score_text(text):
    """Run the classifier on *text* and return one sigmoid probability per label."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # A single string is tokenized as a batch of one; drop that dimension.
    logits = outputs.logits.squeeze(0)
    # Labels are scored independently (multi-label), hence sigmoid, not softmax.
    # .cpu() keeps .numpy() working even if the model is moved to a GPU.
    return torch.sigmoid(logits).cpu().numpy()


def _danger_level(flag_count):
    """Map the number of triggered danger cues to a "High"/"Moderate"/"Low" label."""
    if flag_count >= 2:
        return "High"
    if flag_count == 1:
        return "Moderate"
    return "Low"


def _write_report(text):
    """Write *text* to a new temporary .txt file (UTF-8) and return its path."""
    # delete=False: the file must outlive this call so Gradio can serve it.
    # encoding="utf-8": the report contains non-ASCII characters, which the
    # platform default encoding may not support.
    # NOTE(review): these files are never removed; consider periodic cleanup.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".txt", mode="w", encoding="utf-8"
    ) as f:
        f.write(text)
        return f.name


def analyze_messages(input_text):
    """Score *input_text* for abuse patterns and build a Markdown report.

    Args:
        input_text: Text from the Gradio textbox. ``None`` or whitespace-only
            input is treated as empty.

    Returns:
        A ``(markdown, report_path)`` pair matching the two Gradio outputs.
        For empty input the markdown is a prompt to enter text and
        ``report_path`` is ``None``.
    """
    input_text = (input_text or "").strip()
    if not input_text:
        # The handler is wired to two outputs, so always return two values.
        return _EMPTY_INPUT_MESSAGE, None

    scores = _score_text(input_text)
    pattern_scores = scores[:_NUM_PATTERN_LABELS]
    danger_scores = scores[_NUM_PATTERN_LABELS:]  # suicidal, physical, extreme

    # Debug printing (remove once you're confident everything works)
    print("Scores:", scores)
    print("Danger Scores:", danger_scores)
    for label, score in zip(LABELS, scores):
        print(label, "=", round(float(score), 3))

    pattern_count = int(np.sum(pattern_scores > _PATTERN_THRESHOLD))
    danger_flag_count = int(np.sum(danger_scores > _DANGER_THRESHOLD))
    resources = _SUPPORT_RESOURCES[_danger_level(danger_flag_count)]

    # The report shows only the pattern count and the level-specific resources.
    result_md = (
        f"**Abuse Pattern Count:** {pattern_count}\n\n"
        f"**Support Resources:**\n\n{resources}"
    )
    return result_md, _write_report(result_md)
# Build the Gradio interface: an input textbox, a Markdown result area, a
# downloadable report file, and an "Analyze" button. Both pressing Enter in
# the textbox and clicking the button run analyze_messages.
with gr.Blocks() as demo:
    gr.Markdown("# Abuse Pattern Detector - Risk Analysis")
    gr.Markdown("Enter one or more messages (separated by newlines) for analysis.")
    text_input = gr.Textbox(
        label="Input Messages",
        lines=10,
        placeholder="Type your message(s) here...",
    )
    result_output = gr.Markdown(label="Analysis Result")
    download_output = gr.File(label="Download Report (.txt)")
    analyze_btn = gr.Button("Analyze")

    # Register the same handler on both triggers (submit first, then click).
    for trigger in (text_input.submit, analyze_btn.click):
        trigger(
            analyze_messages,
            inputs=text_input,
            outputs=[result_output, download_output],
        )

if __name__ == "__main__":
    demo.launch()
|