SamanthaStorm commited on
Commit
1ef0cee
·
verified ·
1 Parent(s): 9387da0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -3
app.py CHANGED
@@ -4,7 +4,38 @@ import numpy as np
4
  from transformers import pipeline, RobertaForSequenceClassification, RobertaTokenizer
5
  from motif_tagging import detect_motifs
6
  import re
7
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  # --- SST Sentiment Model ---
9
  sst_pipeline = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
10
 
@@ -300,7 +331,8 @@ def analyze_composite(msg1, date1, msg2, date2, msg3, date3, *answers_and_none):
300
  out += f"\n\n{stage_text}"
301
  out += darvo_blurb
302
 
303
- return out
 
304
 
305
  message_date_pairs = [(gr.Textbox(label=f"Message {i+1}"), gr.Textbox(label=f"Date {i+1} (optional)", placeholder="e.g. 2025-04-22")) for i in range(3)]
306
  textbox_inputs = [item for pair in message_date_pairs for item in pair] # Flatten for Gradio input
@@ -310,7 +342,10 @@ none_box = gr.Checkbox(label="None of the above")
310
  iface = gr.Interface(
311
  fn=analyze_composite,
312
  inputs=textbox_inputs + quiz_boxes + [none_box],
313
- outputs=gr.Textbox(label="Results"),
 
 
 
314
  title="Abuse Pattern Detector + Escalation Quiz",
315
  allow_flagging="manual"
316
  )
 
4
  from transformers import pipeline, RobertaForSequenceClassification, RobertaTokenizer
5
  from motif_tagging import detect_motifs
6
  import re
7
+ import matplotlib.pyplot as plt
8
+ import io
9
+ from PIL import Image
10
+ from datetime import datetime
11
+
12
+ # --- Timeline Visualization Function ---
13
+ def generate_risk_stage_timeline(dates, stages):
14
+ stage_labels = {
15
+ 1: "Tension",
16
+ 2: "Escalation",
17
+ 3: "Reconciliation",
18
+ 4: "Calm"
19
+ }
20
+ try:
21
+ parsed_dates = [datetime.strptime(d, "%Y-%m-%d") for d in dates]
22
+ except Exception:
23
+ parsed_dates = list(range(len(dates)))
24
+
25
+ fig, ax = plt.subplots(figsize=(6, 2.5))
26
+ ax.step(parsed_dates, stages, where='mid', color='purple', linewidth=2, marker='o')
27
+ ax.set_yticks(list(stage_labels.keys()))
28
+ ax.set_yticklabels([stage_labels[s] for s in stage_labels])
29
+ ax.set_title("Risk Stage Timeline")
30
+ ax.set_xlabel("Date")
31
+ ax.set_ylabel("Risk Stage")
32
+ ax.grid(True)
33
+ plt.tight_layout()
34
+
35
+ buf = io.BytesIO()
36
+ plt.savefig(buf, format='png')
37
+ buf.seek(0)
38
+ return Image.open(buf)
39
  # --- SST Sentiment Model ---
40
  sst_pipeline = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
41
 
 
331
  out += f"\n\n{stage_text}"
332
  out += darvo_blurb
333
 
334
+ timeline_image = generate_risk_stage_timeline(dates_used, stages)
335
+ return out, timeline_image
336
 
337
  message_date_pairs = [(gr.Textbox(label=f"Message {i+1}"), gr.Textbox(label=f"Date {i+1} (optional)", placeholder="e.g. 2025-04-22")) for i in range(3)]
338
  textbox_inputs = [item for pair in message_date_pairs for item in pair] # Flatten for Gradio input
 
342
  iface = gr.Interface(
343
  fn=analyze_composite,
344
  inputs=textbox_inputs + quiz_boxes + [none_box],
345
+ outputs=[
346
+ gr.Textbox(label="Results"),
347
+ gr.Image(label="Risk Stage Timeline", type="pil")
348
+ ],
349
  title="Abuse Pattern Detector + Escalation Quiz",
350
  allow_flagging="manual"
351
  )