SamanthaStorm committed
Commit 79936aa · verified · 1 Parent(s): a0ddd67

Update app.py

Files changed (1): app.py (+32 −15)
app.py CHANGED
@@ -9,6 +9,14 @@ model_name = "SamanthaStorm/abuse-pattern-detector-v2"
 model = RobertaForSequenceClassification.from_pretrained(model_name)
 tokenizer = RobertaTokenizer.from_pretrained(model_name)
 
+# Define the final label order your model used
+LABELS = [
+    "gaslighting", "mockery", "dismissiveness", "control",
+    "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
+    "contradictory_statements", "manipulation", "deflection", "insults",
+    "obscure_formal", "recovery_phase", "suicidal_threat", "physical_threat",
+    "extreme_control"
+]
 TOTAL_LABELS = 17
 
 # Our model outputs 17 labels:
@@ -16,27 +24,36 @@ TOTAL_LABELS = 17
 # - Last 3 are Danger Assessment cues
 TOTAL_LABELS = 17
 
-def analyze_messages(text):
-    input_text = text.strip()
+def analyze_messages(input_text):
+    input_text = input_text.strip()
     if not input_text:
-        return "Please enter a message for analysis.", None
-
-    # Tokenize input text
+        return "Please enter a message for analysis."
+
+    # Tokenize
     inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
     with torch.no_grad():
         outputs = model(**inputs)
-
-    # Assume model logits shape is [17] (for a single example)
-    logits = outputs.logits.squeeze()  # shape: [17]
+
+    # Squeeze out batch dimension: shape should be [17]
+    logits = outputs.logits.squeeze(0)
+
+    # Convert logits to probabilities
     scores = torch.sigmoid(logits).numpy()
-
-    # For the first 14 labels (abuse patterns), count how many exceed threshold 0.5
-    abuse_pattern_scores = scores[:14]
-    concerning_pattern_count = int(np.sum(abuse_pattern_scores > 0.5))
-
-    # For the last 3 labels (Danger Assessment cues), count how many exceed threshold 0.5
-    danger_scores = scores[14:17]
+
+    # Debug printing (remove once you're confident everything works)
+    print("Scores:", scores)
+
+    # First 14 = pattern scores
+    pattern_scores = scores[:14]
+    pattern_count = int(np.sum(pattern_scores > 0.5))
+
+    # Last 3 = danger cues
+    danger_scores = scores[14:]
     danger_flag_count = int(np.sum(danger_scores > 0.5))
+
+    # (Optional) Print label-by-label for debugging
+    for i, s in enumerate(scores):
+        print(LABELS[i], "=", round(s, 3))
 
     # Map danger flag count to Danger Assessment Score
     if danger_flag_count >= 2:
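For reference, the updated scoring path reduces to a sigmoid over the 17 logits followed by two threshold counts. Below is a minimal standalone sketch of that step, assuming made-up example logits: the 14-pattern / 3-cue split and the 0.5 cutoff come from the diff above, while the tensor values are purely illustrative.

import numpy as np
import torch

# Illustrative logits for the 17 labels (14 abuse patterns + 3 danger cues);
# these values are made up, not real model output.
example_logits = torch.tensor([
    0.2, -1.3, 2.1, 0.0, -0.5, 1.7, -2.0, 0.4, -0.9, 3.0, -1.1, 0.6, -0.3, 1.2,
    -2.5, 0.9, -1.8,
])

scores = torch.sigmoid(example_logits).numpy()
pattern_count = int(np.sum(scores[:14] > 0.5))      # abuse patterns above threshold
danger_flag_count = int(np.sum(scores[14:] > 0.5))  # danger cues above threshold
print("patterns flagged:", pattern_count, "| danger cues flagged:", danger_flag_count)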