# Tether / app.py
# SamanthaStorm's picture
# Update app.py
# b11fbe8 verified
# raw
# history blame
# 5.13 kB
import gradio as gr
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer
import numpy as np
import tempfile
# Fetch the fine-tuned RoBERTa sequence classifier and its tokenizer by repo id.
# The model is loaded first, then the tokenizer; both are module-level globals
# used by analyze_messages.
model_name = "SamanthaStorm/abuse-pattern-detector-v2"
model, tokenizer = (
    RobertaForSequenceClassification.from_pretrained(model_name),
    RobertaTokenizer.from_pretrained(model_name),
)
# Classifier output labels. Order matters: scores are paired with these names
# positionally (zip in calculate_abuse_level / analyze_messages), and the first
# 14 entries are the abuse-pattern labels while the rest are danger cues (see
# PATTERN_LABELS / DANGER_LABELS). Presumably this mirrors the model head's
# output order — NOTE(review): confirm against the model config.
LABELS = [
    # Abuse patterns
    "gaslighting",
    "mockery",
    "dismissiveness",
    "control",
    "guilt_tripping",
    "apology_baiting",
    "blame_shifting",
    "projection",
    "contradictory_statements",
    "manipulation",
    "deflection",
    "insults",
    "obscure_formal",
    "recovery_phase",
    # Danger cues
    "suicidal_threat",
    "physical_threat",
    "extreme_control",
]
# Per-label decision thresholds. A label is counted as triggered only when its
# score is strictly greater than the value here; tune individually as needed.
THRESHOLDS = dict(
    gaslighting=0.15,
    mockery=0.15,
    dismissiveness=0.25,
    control=0.03,
    guilt_tripping=0.15,
    apology_baiting=0.15,
    blame_shifting=0.15,
    projection=0.20,
    contradictory_statements=0.15,
    manipulation=0.15,
    deflection=0.15,
    insults=0.25,
    obscure_formal=0.20,
    recovery_phase=0.15,
    suicidal_threat=0.08,
    physical_threat=0.045,
    extreme_control=0.30,
)
# Label categories. LABELS lists the abuse-pattern labels first and the danger
# cues after them, so the two groups are split at a fixed index.
_PATTERN_LABEL_COUNT = 14
PATTERN_LABELS = LABELS[:_PATTERN_LABEL_COUNT]
DANGER_LABELS = LABELS[_PATTERN_LABEL_COUNT:]
def calculate_abuse_level(scores, thresholds, labels=None):
    """Return the mean score of the triggered labels, as a percentage.

    A label is triggered when its score is strictly greater than its
    threshold. Labels that are not triggered do not contribute at all, so the
    result is the average score among the flagged labels only.

    Args:
        scores: Per-label scores, paired positionally with ``labels``. If the
            two differ in length, the extra items are ignored (``zip``).
        thresholds: Mapping from label name to its trigger threshold.
        labels: Label names paired with ``scores``. Defaults to the
            module-level ``LABELS``.

    Returns:
        The mean triggered score times 100, rounded to two decimals, as a
        plain ``float``; ``0.0`` when no label is triggered.

    Raises:
        KeyError: if a paired label has no entry in ``thresholds``.
    """
    if labels is None:
        labels = LABELS
    triggered_scores = [
        score for label, score in zip(labels, scores) if score > thresholds[label]
    ]
    if not triggered_scores:
        return 0.0
    # np.mean yields np.float64; convert first so both branches return the same
    # type and Python's correctly-rounded round() is used.
    return round(float(np.mean(triggered_scores)) * 100, 2)
def interpret_abuse_level(score):
    """Map an abuse-level percentage to a human-readable risk description.

    Bands are checked from the highest cutoff down, and a score must be
    strictly above a cutoff to fall in that band. Anything at or below 20,
    including values that compare false with every cutoff, is reported as
    very low.
    """
    bands = (
        (80, "Extreme / High Risk"),
        (60, "Severe / Harmful Pattern Present"),
        (40, "Likely Abuse"),
        (20, "Mild Concern"),
    )
    for cutoff, description in bands:
        if score > cutoff:
            return description
    return "Very Low / Likely Safe"
def _select_resources(danger_flag_count):
    """Return the Markdown block of support resources for a danger-cue count.

    Two or more danger cues select emergency, crisis and legal resources;
    exactly one selects emotional-abuse information and support; anything
    else selects general healthy-relationship resources. Entries are separated
    by blank lines so Markdown renders each on its own line.
    """
    if danger_flag_count >= 2:
        lines = (
            "**Immediate Help:** Call 911 if in danger.",
            "**Crisis Support:** National DV Hotline – [thehotline.org/plan-for-safety](https://www.thehotline.org/plan-for-safety/)",
            "**Legal Support:** WomensLaw – [womenslaw.org](https://www.womenslaw.org/)",
            "**Specialized Services:** RAINN, StrongHearts, LGBTQ+, immigrant, neurodivergent resources",
        )
    elif danger_flag_count == 1:
        lines = (
            "**Emotional Abuse Info:** [thehotline.org/resources](https://www.thehotline.org/resources/what-is-emotional-abuse/)",
            "**Relationship Education:** [joinonelove.org](https://www.joinonelove.org/)",
            "**Support Chat:** [thehotline.org](https://www.thehotline.org/)",
            "**Community Groups:** LGBTQ+, immigrant, and neurodivergent spaces",
        )
    else:
        lines = (
            "**Healthy Relationships:** [loveisrespect.org](https://www.loveisrespect.org/)",
            "**Find a Therapist:** [psychologytoday.com](https://www.psychologytoday.com/us/therapists)",
            "**Relationship Tools:** [relate.org.uk](https://www.relate.org.uk/)",
            "**Online Peer Support:** (including identity-focused groups)",
        )
    return "\n\n".join(lines)


def _write_report(text):
    """Write *text* to a new UTF-8 ``.txt`` temporary file and return its path.

    The file is created with ``delete=False`` so it outlives this call and can
    be offered for download; nothing here removes it afterwards. UTF-8 is
    explicit because the report contains emoji, which the platform default
    encoding may not be able to represent.
    """
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".txt", mode="w", encoding="utf-8"
    ) as report_file:
        report_file.write(text)
    return report_file.name


def analyze_messages(input_text):
    """Score *input_text* with the classifier and build a Markdown report.

    Args:
        input_text: The message(s) to analyse; may span several lines.
            ``None`` is treated like an empty string.

    Returns:
        A ``(markdown, report_path)`` tuple. For blank input this is a prompt
        string and ``None``; otherwise it is the Markdown summary and the path
        of a ``.txt`` file holding the same text.
    """
    input_text = (input_text or "").strip()
    if not input_text:
        return "Please enter a message for analysis.", None

    # Single forward pass. truncation=True means text beyond the tokenizer's
    # maximum length is cut off, so very long pastes are only partially scored.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Sigmoid per logit: each label gets its own independent score.
    scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()

    # Scores follow LABELS order: pattern labels first, danger cues after.
    split = len(PATTERN_LABELS)
    pattern_count = int(sum(
        score > THRESHOLDS[label]
        for label, score in zip(PATTERN_LABELS, scores[:split])
    ))
    danger_flag_count = int(sum(
        score > THRESHOLDS[label]
        for label, score in zip(DANGER_LABELS, scores[split:])
    ))

    abuse_level = calculate_abuse_level(scores, THRESHOLDS)
    abuse_description = interpret_abuse_level(abuse_level)
    resources = _select_resources(danger_flag_count)

    # Blank-line separators: a single "\n" is only a soft break in Markdown and
    # would merge these lines into one paragraph.
    result_md = "\n\n".join((
        "### 📋 Analysis Summary",
        f"**Abuse Pattern Count:** {pattern_count}",
        f"**Danger Cues Detected:** {danger_flag_count}",
        f"**Abuse Level:** {abuse_level}% ({abuse_description})",
        "### 🛟 Suggested Support Resources",
        resources,
    ))
    return result_md, _write_report(result_md)
# Interface: a text box, a Markdown results pane and a downloadable report.
# Both the textbox's submit event and the Analyze button call analyze_messages,
# whose (markdown, report_path) result fills the two output components.
with gr.Blocks() as demo:
    gr.Markdown("# 🔍 Abuse Pattern Detector")
    gr.Markdown("Paste one or more messages for analysis (multi-line supported).")
    text_input = gr.Textbox(label="Text Message(s)", lines=10, placeholder="Paste messages here...")
    result_output = gr.Markdown()
    file_output = gr.File(label="📥 Download Analysis (.txt)")
    text_input.submit(analyze_messages, inputs=text_input, outputs=[result_output, file_output])
    analyze_btn = gr.Button("Analyze")
    analyze_btn.click(analyze_messages, inputs=text_input, outputs=[result_output, file_output])

if __name__ == "__main__":
    demo.launch()