import io
import re
from datetime import datetime

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import pipeline, RobertaForSequenceClassification, RobertaTokenizer
from transformers import pipeline as hf_pipeline  # prevent name collision with gradio pipeline

from motif_tagging import detect_motifs

# Emotion classifier; top_k=None asks the pipeline for a score per label.
emotion_pipeline = hf_pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    top_k=None,
    truncation=True
)

# Strict YYYY-MM-DD shape; calendar validity is checked separately by strptime.
_ISO_DATE_RE = re.compile(r"\d{4}-\d{2}-\d{2}")


def get_emotion_profile(text):
    """Return a ``{label: score}`` mapping of emotion scores for *text*.

    Labels are lower-cased and scores rounded to three decimals. The pipeline
    output may arrive wrapped in an outer list (one entry per input); that
    wrapper is removed before the mapping is built.
    """
    emotions = emotion_pipeline(text)
    # Guard the [0] lookup so an empty result yields {} instead of IndexError.
    if isinstance(emotions, list) and emotions and isinstance(emotions[0], list):
        emotions = emotions[0]
    return {e['label'].lower(): round(e['score'], 3) for e in emotions}


def _parse_chart_dates(dates):
    """Return datetimes for *dates* if every entry is a valid YYYY-MM-DD string, else None."""
    parsed = []
    for raw in dates:
        if not isinstance(raw, str):
            return None
        candidate = raw.strip()
        if not _ISO_DATE_RE.fullmatch(candidate):
            return None
        try:
            parsed.append(datetime.strptime(candidate, "%Y-%m-%d"))
        except ValueError:
            # Right shape but not a real calendar date (e.g. 2024-13-45).
            return None
    return parsed


def generate_abuse_score_chart(dates, scores, labels):
    """Render a line chart of abuse scores and return it as a PIL image.

    Args:
        dates: One entry per score. When every entry is a valid ``YYYY-MM-DD``
            string the x-axis uses real dates; otherwise points are placed at
            1..N and labelled "Message N".
        scores: Numeric scores plotted on a 0-105 y-axis; each point is
            annotated with its integer percentage.
        labels: Accepted for interface compatibility; not used when drawing.

    Returns:
        A ``PIL.Image.Image`` opened from an in-memory PNG of the chart.
    """
    parsed_x = _parse_chart_dates(dates)
    if parsed_x is not None:
        x_labels = [d.strftime("%Y-%m-%d") for d in parsed_x]
    else:
        parsed_x = list(range(1, len(dates) + 1))
        x_labels = [f"Message {i}" for i in parsed_x]

    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.plot(parsed_x, scores, marker='o', linestyle='-', color='darkred', linewidth=2)

        for x, y in zip(parsed_x, scores):
            ax.text(x, y + 2, f"{int(y)}%", ha='center', fontsize=8, color='black')

        ax.set_xticks(parsed_x)
        ax.set_xticklabels(x_labels)
        ax.set_xlabel("")
        ax.set_ylabel("Abuse Score (%)")
        ax.set_ylim(0, 105)
        ax.grid(True)
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running app accumulates one figure per request.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)


model_name = "SamanthaStorm/tether-multilabel-v3"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# NOTE(review): the literals below are placeholders exactly as they appear in
# this file view (a list/set holding Ellipsis); the real definitions are not
# visible here and must be supplied before the app can run.
LABELS = [...]
THRESHOLDS = {...}
PATTERN_WEIGHTS = {...}
RISK_STAGE_LABELS = {...}
ESCALATION_QUESTIONS = [...]
DARVO_PATTERNS = {...}
DARVO_MOTIFS = [...]
# (Leave the rest of your helper functions unchanged)


def analyze_single_message(text, thresholds):
    """Score one message against the multi-label abuse model.

    Args:
        text: The message to analyse.
        thresholds: Mapping of label -> decision threshold. It is not mutated;
            an adjusted copy is built locally.

    NOTE(review): the tail of this function is elided in this file view (see
    the marker at the end of the body). ``analyze_composite`` expects a
    7-item result from it, so the missing part must build and return that.
    """
    motif_hits, matched_phrases = detect_motifs(text)
    emotion_profile = get_emotion_profile(text)
    sentiment_score = emotion_profile.get("anger", 0) + emotion_profile.get("disgust", 0)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Independent per-label probabilities (sigmoid, not softmax).
    scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()

    # A calm-sounding message that still trips one of these labels is treated
    # as undermining rather than supportive.
    flat_tone_labels = ("control", "threat", "blame shifting")
    if emotion_profile.get("neutral", 0) > 0.85 and any(
        scores[LABELS.index(name)] > thresholds[name] for name in flat_tone_labels
    ):
        sentiment = "undermining"
    else:
        sentiment = "undermining" if sentiment_score > 0.25 else "supportive"

    # Supportive tone raises every threshold by 0.05, making labels harder to trigger.
    adjusted_thresholds = {
        k: v + 0.05 if sentiment == "supportive" else v
        for k, v in thresholds.items()
    }

    passed = {
        label: score for label, score in zip(LABELS, scores)
        if score > adjusted_thresholds[label]
    }
    # (Continue unchanged)


def analyze_composite(msg1, date1, msg2, date2, msg3, date3, *answers_and_none):
    """Analyse up to three messages plus the escalation checklist.

    Args:
        msg1, msg2, msg3: Message texts; blank or missing messages are skipped.
        date1, date2, date3: Optional date strings paired with the messages.
        *answers_and_none: One boolean per entry in ``ESCALATION_QUESTIONS``
            followed by the "None of the above" checkbox value.

    Returns:
        A ``(report_text, timeline_image)`` tuple matching the two outputs of
        the Gradio interface. When no message was entered the image is
        ``None``.

    Raises:
        ValueError: If ``analyze_single_message`` returns something other
            than a 7-item result.
    """
    none_selected_checked = answers_and_none[-1]
    checklist_answers = answers_and_none[:-1]
    responses_checked = any(checklist_answers)
    # NOTE(review): a ticked "None of the above" is reported below as
    # "checklist not completed", while an untouched checklist scores 0 / Low.
    # Confirm this is the intended behaviour.
    none_selected = not responses_checked and none_selected_checked

    if none_selected:
        escalation_score = None
        risk_level = "unknown"
    else:
        escalation_score = sum(
            w for (_, w), a in zip(ESCALATION_QUESTIONS, checklist_answers) if a
        )
        risk_level = (
            "High" if escalation_score >= 16
            else "Moderate" if escalation_score >= 8
            else "Low"
        )

    messages = [msg1, msg2, msg3]
    dates = [date1, date2, date3]
    # `m and ...` guards against a None textbox value before calling strip().
    active = [(m, d) for m, d in zip(messages, dates) if m and m.strip()]
    if not active:
        # The interface declares two outputs, so return a value for each.
        return "Please enter at least one message.", None

    results = [(analyze_single_message(m, THRESHOLDS.copy()), d) for m, d in active]
    for result, _ in results:
        if len(result) != 7:
            raise ValueError("Unexpected output from analyze_single_message")

    top_labels = [r[0][6] for r in results]
    top_scores = [r[0][2][0][1] for r in results]
    stages = [r[0][4] for r in results]
    darvo_scores = [r[0][5] for r in results]
    dates_used = [r[1] or "Undated" for r in results]
    abuse_scores = [r[0][0] for r in results]

    composite_abuse = int(round(sum(abuse_scores) / len(abuse_scores)))
    # Only the first message's top label/score is shown in the summary.
    top_label = f"{top_labels[0]} – {int(round(top_scores[0] * 100))}%"

    most_common_stage = max(set(stages), key=stages.count)
    stage_text = RISK_STAGE_LABELS[most_common_stage]

    avg_darvo = round(sum(darvo_scores) / len(darvo_scores), 3)
    darvo_blurb = ""
    if avg_darvo > 0.25:
        level = "moderate" if avg_darvo < 0.65 else "high"
        darvo_blurb = f"\n\n🎭 **DARVO Score: {avg_darvo}** → This indicates a **{level} likelihood** of narrative reversal (DARVO), where the speaker may be denying, attacking, or reversing blame."

    out = f"Abuse Intensity: {composite_abuse}%\n"
    out += "📊 This reflects the strength and severity of detected abuse patterns in the message(s).\n\n"

    if escalation_score is None:
        escalation_text = "📉 Escalation Potential: Unknown (Checklist not completed)\n"
        escalation_text += "⚠️ *This section was not completed. Escalation potential is unknown.*\n"
    else:
        escalation_text = f"🧨 **Escalation Potential: {risk_level} ({escalation_score}/{sum(w for _, w in ESCALATION_QUESTIONS)})**\n"
        escalation_text += "This score comes directly from the safety checklist and functions as a standalone escalation risk score.\n"
        escalation_text += "It indicates how many serious risk factors are present based on your answers to the safety checklist.\n"

    # The risk snippet receives 0 when the checklist score is unknown.
    out += generate_risk_snippet(
        composite_abuse,
        top_label,
        escalation_score if escalation_score is not None else 0,
        most_common_stage,
    )
    out += f"\n\n{stage_text}"
    out += darvo_blurb
    out += "\n\n" + escalation_text

    print(f"DEBUG: avg_darvo = {avg_darvo}")

    pattern_labels = [r[0][2][0][0] for r in results]
    timeline_image = generate_abuse_score_chart(dates_used, abuse_scores, pattern_labels)
    return out, timeline_image


# --- Gradio UI wiring -------------------------------------------------------
# Three (message, date) textbox pairs, flattened so the argument order matches
# analyze_composite(msg1, date1, msg2, date2, msg3, date3, ...).
message_date_pairs = [
    (
        gr.Textbox(label=f"Message {i+1}"),
        gr.Textbox(label=f"Date {i+1} (optional)", placeholder="YYYY-MM-DD")
    )
    for i in range(3)
]
textbox_inputs = [item for pair in message_date_pairs for item in pair]

# One checkbox per checklist question; "None of the above" must stay last
# because analyze_composite reads it from the final position.
quiz_boxes = [gr.Checkbox(label=q) for q, _ in ESCALATION_QUESTIONS]
none_box = gr.Checkbox(label="None of the above")

iface = gr.Interface(
    fn=analyze_composite,
    inputs=textbox_inputs + quiz_boxes + [none_box],
    outputs=[
        gr.Textbox(label="Results"),
        gr.Image(label="Risk Stage Timeline", type="pil")
    ],
    title="Abuse Pattern Detector + Escalation Quiz",
    allow_flagging="manual"
)

if __name__ == "__main__":
    iface.launch()