SamanthaStorm commited on
Commit
8dc9e3c
·
verified ·
1 Parent(s): 90bf0d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -52
app.py CHANGED
@@ -2,14 +2,40 @@ import gradio as gr
2
  import spaces
3
  import torch
4
  import numpy as np
5
- from transformers import pipeline, RobertaForSequenceClassification, RobertaTokenizer
6
- from motif_tagging import detect_motifs
7
  import re
8
  import matplotlib.pyplot as plt
9
  import io
10
  from PIL import Image
11
  from datetime import datetime
12
- from transformers import pipeline as hf_pipeline # prevent name collision with gradio pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def get_emotion_profile(text):
15
  emotions = emotion_pipeline(text)
@@ -87,9 +113,9 @@ THRESHOLDS = {
87
  }
88
 
89
  PATTERN_WEIGHTS = {
90
- "recovery": 0.7,
91
  "control": 1.4,
92
- "gaslighting": 1.50,
93
  "guilt tripping": 0.9,
94
  "dismissiveness": 0.9,
95
  "blame shifting": 0.8,
@@ -382,7 +408,6 @@ THREAT_MOTIFS = [
382
  ]
383
 
384
 
385
- @spaces.GPU
386
  @spaces.GPU
387
  def compute_abuse_score(matched_scores, sentiment):
388
  """
@@ -429,17 +454,15 @@ def compute_abuse_score(matched_scores, sentiment):
429
  base_score *= 1.05 # Reduced
430
 
431
  # Sentiment modifier (more nuanced)
432
- if sentiment == "supportive":
433
- manipulative_patterns = {'guilt tripping', 'gaslighting', 'blame shifting', 'love bombing'}
434
- if any(label in manipulative_patterns for label, score, _ in matched_scores if score > 0.6): # Higher threshold
435
- base_score *= 0.95 # Smaller reduction for strongly manipulative "support"
436
- elif any(label in manipulative_patterns for label, score, _ in matched_scores if score > 0.4): # Moderate threshold
437
- base_score *= 0.9 # Moderate reduction for manipulative "support"
438
- else:
439
- base_score *= 0.8 # Larger reduction for genuine support
440
-
441
- elif sentiment == "undermining":
442
- base_score *= 1.15
443
 
444
  # Reduce minimum score and threshold for activation
445
  if any(score > 0.9 for _, score, _ in matched_scores): # Higher threshold
@@ -449,6 +472,7 @@ def compute_abuse_score(matched_scores, sentiment):
449
 
450
  return min(round(base_score, 1), 100.0)
451
 
 
452
  def analyze_single_message(text, thresholds):
453
  print("⚡ ENTERED analyze_single_message")
454
  stage = 1
@@ -459,19 +483,21 @@ def analyze_single_message(text, thresholds):
459
  sentiment_score = emotion_profile.get("anger", 0) + emotion_profile.get("disgust", 0)
460
 
461
  # Get model scores
462
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
463
  with torch.no_grad():
464
  outputs = model(**inputs)
465
  scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
466
 
467
  # Sentiment override
468
  if emotion_profile.get("neutral", 0) > 0.85 and any(
469
- scores[LABELS.index(l)] > thresholds[l]
470
- for l in ["control", "blame shifting"]
471
  ):
 
 
472
  sentiment = "undermining"
473
  else:
474
- sentiment = "undermining" if sentiment_score > 0.25 else "supportive"
475
 
476
  weapon_flag = detect_weapon_language(text)
477
 
@@ -486,10 +512,10 @@ def analyze_single_message(text, thresholds):
486
  label for label, score in zip(LABELS, scores)
487
  if score > adjusted_thresholds[label]
488
  ]
 
489
 
490
- # Early exit if nothing passed
491
  if not threshold_labels:
492
- return 0.0, [], [], {"label": sentiment}, 1, 0.0, "supportive"
493
 
494
  top_patterns = sorted(
495
  [(label, score) for label, score in zip(LABELS, scores)],
@@ -497,34 +523,25 @@ def analyze_single_message(text, thresholds):
497
  reverse=True
498
  )[:2]
499
 
500
- matched_scores = [
501
- (label, score, PATTERN_WEIGHTS.get(label, 1.0))
502
- for label, score in zip(LABELS, scores)
503
- if score > adjusted_thresholds[label]
504
- ]
505
-
506
-
507
- # Cap subtle insults to avoid excessive abuse score
508
- if (
509
- len(threshold_labels) == 1 and "insults" in threshold_labels
510
- and emotion_profile.get("neutral", 0) > 0.85
511
- ):
512
- abuse_score_raw = min(abuse_score_raw, 40)
513
-
514
  # Abuse score
515
- abuse_score_raw = compute_abuse_score(matched_scores, sentiment)
516
 
517
- # Weapon adjustment
518
  if weapon_flag:
519
- abuse_score_raw = min(abuse_score_raw + 25, 100)
520
  if stage < 2:
521
  stage = 2
522
 
523
- abuse_score = min(abuse_score_raw, 100 if "control" in threshold_labels else 95)
524
 
525
- # Tone tag
526
  tone_tag = get_emotional_tone_tag(emotion_profile, sentiment, threshold_labels, abuse_score)
527
 
 
 
 
 
 
 
 
528
  # Remove recovery tag if tone is fake
529
  if "recovery" in threshold_labels and tone_tag == "forced accountability flip":
530
  threshold_labels.remove("recovery")
@@ -563,7 +580,7 @@ def analyze_single_message(text, thresholds):
563
 
564
  return abuse_score, threshold_labels, top_patterns, {"label": sentiment}, stage, darvo_score, tone_tag
565
 
566
- import spaces
567
 
568
  @spaces.GPU
569
  def analyze_composite(msg1, msg2, msg3, *answers_and_none):
@@ -612,7 +629,7 @@ def analyze_composite(msg1, msg2, msg3, *answers_and_none):
612
  immediate_threats = [detect_threat_motifs(m, THREAT_MOTIFS) for m, _ in active]
613
  flat_threats = [t for sublist in immediate_threats for t in sublist]
614
  threat_risk = "Yes" if flat_threats else "No"
615
- results = [(analyze_single_message(m, THRESHOLDS.copy()), d) for m, d in active]
616
 
617
  abuse_scores = [r[0][0] for r in results]
618
  stages = [r[0][4] for r in results]
@@ -713,14 +730,18 @@ def analyze_composite(msg1, msg2, msg3, *answers_and_none):
713
  f"• Checklist Risk: {checklist_escalation_risk}\n"
714
  f"• Escalation Bump: +{escalation_bump} (from DARVO, tone, intensity, etc.)"
715
  )
716
- # Composite Abuse Score
 
717
  composite_abuse_scores = []
718
- for result, _ in results:
719
- _, _, top_patterns, sentiment, _, _, _ = result
720
- matched_scores = [(label, score, PATTERN_WEIGHTS.get(label, 1.0)) for label, score in top_patterns]
721
- final_score = compute_abuse_score(matched_scores, sentiment["label"])
722
- composite_abuse_scores.append(final_score)
723
- composite_abuse = int(round(sum(composite_abuse_scores) / len(composite_abuse_scores)))
 
 
 
724
 
725
  most_common_stage = max(set(stages), key=stages.count)
726
  stage_text = RISK_STAGE_LABELS[most_common_stage]
@@ -763,7 +784,7 @@ def analyze_composite(msg1, msg2, msg3, *answers_and_none):
763
  pats[0][0] if (pats := r[0][2]) else "none"
764
  for r in results
765
  ]
766
- timeline_image = generate_abuse_score_chart(dates_used, abuse_scores, top_labels)
767
  out += "\n\n" + escalation_text
768
  return out, timeline_image
769
 
 
2
  import spaces
3
  import torch
4
  import numpy as np
 
 
5
  import re
6
  import matplotlib.pyplot as plt
7
  import io
8
  from PIL import Image
9
  from datetime import datetime
10
+ from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
11
+ from motif_tagging import detect_motifs
12
+ from functools import lru_cache
13
+ from torch.nn.functional import sigmoid
14
+
15
+ # ----- Models -----
16
+
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
+ # Emotion model (CPU for stability)
20
+ emotion_pipeline = pipeline(
21
+ "text-classification",
22
+ model="j-hartmann/emotion-english-distilroberta-base",
23
+ top_k=6,
24
+ truncation=True,
25
+ device=-1 # Force CPU usage
26
+ )
27
+
28
+ # Abuse Model
29
+ model_name = "SamanthaStorm/tether-multilabel-v4" # Or your HF Hub path
30
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
31
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
32
+ model.to(device)
33
+
34
+ # DARVO Model
35
+ darvo_model = AutoModelForSequenceClassification.from_pretrained("SamanthaStorm/tether-darvo-regressor-v1")
36
+ darvo_tokenizer = AutoTokenizer.from_pretrained("SamanthaStorm/tether-darvo-regressor-v1", use_fast=False)
37
+ darvo_model.eval()
38
+ darvo_model.to(device)
39
 
40
  def get_emotion_profile(text):
41
  emotions = emotion_pipeline(text)
 
113
  }
114
 
115
  PATTERN_WEIGHTS = {
116
+ "recovery": 0.5,
117
  "control": 1.4,
118
+ "gaslighting": 1.0,
119
  "guilt tripping": 0.9,
120
  "dismissiveness": 0.9,
121
  "blame shifting": 0.8,
 
408
  ]
409
 
410
 
 
411
  @spaces.GPU
412
  def compute_abuse_score(matched_scores, sentiment):
413
  """
 
454
  base_score *= 1.05 # Reduced
455
 
456
  # Sentiment modifier (more nuanced)
457
+ if emotion_profile.get("neutral", 0) > 0.85 and any(
458
+ scores[LABELS.index(l)] > thresholds[l] * 0.8 # Scale down thresholds for neutral sentiment
459
+ for l in ["control", "blame shifting", "insults", "guilt tripping"] # Consider more labels
460
+ ):
461
+ sentiment = "undermining" # Only override if multiple patterns are present with moderate confidence
462
+ elif sentiment_score > 0.35: # Increased threshold
463
+ sentiment = "undermining"
464
+ else:
465
+ sentiment = "supportive"
 
 
466
 
467
  # Reduce minimum score and threshold for activation
468
  if any(score > 0.9 for _, score, _ in matched_scores): # Higher threshold
 
472
 
473
  return min(round(base_score, 1), 100.0)
474
 
475
+ @lru_cache(maxsize=1024) # Cache results for performance
476
  def analyze_single_message(text, thresholds):
477
  print("⚡ ENTERED analyze_single_message")
478
  stage = 1
 
483
  sentiment_score = emotion_profile.get("anger", 0) + emotion_profile.get("disgust", 0)
484
 
485
  # Get model scores
486
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
487
  with torch.no_grad():
488
  outputs = model(**inputs)
489
  scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
490
 
491
  # Sentiment override
492
  if emotion_profile.get("neutral", 0) > 0.85 and any(
493
+ scores[LABELS.index(l)] > thresholds[l] * 0.8 # Scale down thresholds for neutral sentiment
494
+ for l in ["control", "blame shifting", "insults", "guilt tripping"] # Consider more labels
495
  ):
496
+ sentiment = "undermining" # Only override if multiple patterns are present with moderate confidence
497
+ elif sentiment_score > 0.35: # Increased threshold
498
  sentiment = "undermining"
499
  else:
500
+ sentiment = "supportive"
501
 
502
  weapon_flag = detect_weapon_language(text)
503
 
 
512
  label for label, score in zip(LABELS, scores)
513
  if score > adjusted_thresholds[label]
514
  ]
515
+ matched_scores = [(label, score, PATTERN_WEIGHTS.get(label, 1.0)) for label, score in zip(LABELS, scores) if score > adjusted_thresholds[label]]
516
 
 
517
  if not threshold_labels:
518
+ return 0.0, [], [], {"label": sentiment}, 1, 0.0, None
519
 
520
  top_patterns = sorted(
521
  [(label, score) for label, score in zip(LABELS, scores)],
 
523
  reverse=True
524
  )[:2]
525
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  # Abuse score
527
+ abuse_score = compute_abuse_score(matched_scores, sentiment) # Calculate before adjustments
528
 
 
529
  if weapon_flag:
530
+ abuse_score = min(abuse_score + 25, 100) # Apply weapon adjustment directly to abuse_score
531
  if stage < 2:
532
  stage = 2
533
 
534
+ abuse_score = min(abuse_score, 100 if "control" in threshold_labels else 95) # Apply cap after weapon adjustment
535
 
 
536
  tone_tag = get_emotional_tone_tag(emotion_profile, sentiment, threshold_labels, abuse_score)
537
 
538
+
539
+ threshold_labels = [label for label, score in zip(LABELS, scores) if score > adjusted_thresholds[label]]
540
+ matched_scores = [(label, score, PATTERN_WEIGHTS.get(label, 1.0)) for label, score in zip(LABELS, scores) if score > adjusted_thresholds[label]]
541
+
542
+ if not threshold_labels:
543
+ return 0.0, [], [], {"label": sentiment}, 1, 0.0, None
544
+
545
  # Remove recovery tag if tone is fake
546
  if "recovery" in threshold_labels and tone_tag == "forced accountability flip":
547
  threshold_labels.remove("recovery")
 
580
 
581
  return abuse_score, threshold_labels, top_patterns, {"label": sentiment}, stage, darvo_score, tone_tag
582
 
583
+
584
 
585
  @spaces.GPU
586
  def analyze_composite(msg1, msg2, msg3, *answers_and_none):
 
629
  immediate_threats = [detect_threat_motifs(m, THREAT_MOTIFS) for m, _ in active]
630
  flat_threats = [t for sublist in immediate_threats for t in sublist]
631
  threat_risk = "Yes" if flat_threats else "No"
632
+ results = [(analyze_single_message(m.lower(), THRESHOLDS.copy()), d) for m, d in active]
633
 
634
  abuse_scores = [r[0][0] for r in results]
635
  stages = [r[0][4] for r in results]
 
730
  f"• Checklist Risk: {checklist_escalation_risk}\n"
731
  f"• Escalation Bump: +{escalation_bump} (from DARVO, tone, intensity, etc.)"
732
  )
733
+
734
+ # Composite Abuse Score (weighted average based on message length)
735
  composite_abuse_scores = []
736
+ message_lengths = [len(m.split()) for m, _ in active]
737
+ total_length = sum(message_lengths)
738
+
739
+ for result, length in zip(results, message_lengths):
740
+ abuse_score = result[0][0]
741
+ weight = length / total_length if total_length > 0 else 1 / len(results) if len(results) > 0 else 1
742
+ composite_abuse_scores.append(abuse_score * weight)
743
+ composite_abuse = int(round(sum(composite_abuse_scores)))
744
+
745
 
746
  most_common_stage = max(set(stages), key=stages.count)
747
  stage_text = RISK_STAGE_LABELS[most_common_stage]
 
784
  pats[0][0] if (pats := r[0][2]) else "none"
785
  for r in results
786
  ]
787
+ timeline_image = generate_abuse_score_chart(dates_used, abuse_scores, top_labels)
788
  out += "\n\n" + escalation_text
789
  return out, timeline_image
790