SamanthaStorm commited on
Commit
c303ab8
·
verified ·
1 Parent(s): 5a8477a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -7
app.py CHANGED
@@ -23,7 +23,26 @@ TOTAL_LABELS = 17
23
  # - First 14 are abuse pattern categories
24
  # - Last 3 are Danger Assessment cues
25
  TOTAL_LABELS = 17
26
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def analyze_messages(input_text):
28
  input_text = input_text.strip()
29
  if not input_text:
@@ -45,13 +64,15 @@ def analyze_messages(input_text):
45
  # Debug printing (remove once you're confident everything works)
46
  print("Scores:", scores)
47
 
48
- # First 14 = pattern scores
49
- pattern_scores = scores[:14]
50
- pattern_count = int(np.sum(pattern_scores > 0.15))
51
 
52
- # Last 3 = danger cues
53
- danger_scores = scores[14:]
54
- danger_flag_count = int(np.sum(danger_scores > 0.20))
 
 
 
55
  # (Optional) Print label-by-label for debugging
56
  for i, s in enumerate(scores):
57
  print(LABELS[i], "=", round(s, 3))
 
23
  # - First 14 are abuse pattern categories
24
  # - Last 3 are Danger Assessment cues
25
  TOTAL_LABELS = 17
26
+ # Individual thresholds for each of the 17 labels
27
+ THRESHOLDS = {
28
+ "gaslighting": 0.15,
29
+ "mockery": 0.15,
30
+ "dismissiveness": 0.15,
31
+ "control": 0.15,
32
+ "guilt_tripping": 0.15,
33
+ "apology_baiting": 0.15,
34
+ "blame_shifting": 0.15,
35
+ "projection": 0.15,
36
+ "contradictory_statements": 0.15,
37
+ "manipulation": 0.15,
38
+ "deflection": 0.15,
39
+ "insults": 0.15,
40
+ "obscure_formal": 0.15,
41
+ "recovery_phase": 0.15,
42
+ "suicidal_threat": 0.10,
43
+ "physical_threat": 0.10,
44
+ "extreme_control": 0.10
45
+ }
46
  def analyze_messages(input_text):
47
  input_text = input_text.strip()
48
  if not input_text:
 
64
  # Debug printing (remove once you're confident everything works)
65
  print("Scores:", scores)
66
 
67
+ pattern_count = 0
68
+ danger_flag_count = 0
 
69
 
70
+ for i, (label, score) in enumerate(zip(LABELS, scores)):
71
+ if score > THRESHOLDS[label]:
72
+ if i < 14:
73
+ pattern_count += 1
74
+ else:
75
+ danger_flag_count += 1
76
  # (Optional) Print label-by-label for debugging
77
  for i, s in enumerate(scores):
78
  print(LABELS[i], "=", round(s, 3))