Update app.py
app.py CHANGED
@@ -9,6 +9,14 @@ model_name = "SamanthaStorm/abuse-pattern-detector-v2"
 model = RobertaForSequenceClassification.from_pretrained(model_name)
 tokenizer = RobertaTokenizer.from_pretrained(model_name)
 
+# Define the final label order your model used
+LABELS = [
+    "gaslighting", "mockery", "dismissiveness", "control",
+    "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
+    "contradictory_statements", "manipulation", "deflection", "insults",
+    "obscure_formal", "recovery_phase", "suicidal_threat", "physical_threat",
+    "extreme_control"
+]
 TOTAL_LABELS = 17
 
 # Our model outputs 17 labels:
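Note: TOTAL_LABELS stays hard-coded at 17 further down, so a quick consistency check between the new LABELS list, the constant, and the model's classification head can catch drift early. This is a suggested addition, not part of this commit:

    # Suggested sanity check (not in this commit): keep the label names, the constant,
    # and the model head in sync.
    assert len(LABELS) == TOTAL_LABELS == model.config.num_labels, (
        f"Label mismatch: {len(LABELS)} names, TOTAL_LABELS={TOTAL_LABELS}, "
        f"model head outputs {model.config.num_labels}"
    )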
@@ -16,27 +24,36 @@ TOTAL_LABELS = 17
 # - Last 3 are Danger Assessment cues
 TOTAL_LABELS = 17
 
-def analyze_messages(
-    input_text =
+def analyze_messages(input_text):
+    input_text = input_text.strip()
     if not input_text:
-        return "Please enter a message for analysis."
-
-    # Tokenize
+        return "Please enter a message for analysis."
+
+    # Tokenize
     inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
     with torch.no_grad():
         outputs = model(**inputs)
-
-    #
-    logits = outputs.logits.squeeze()
+
+    # Squeeze out batch dimension: shape should be [17]
+    logits = outputs.logits.squeeze(0)
+
+    # Convert logits to probabilities
     scores = torch.sigmoid(logits).numpy()
-
-    #
-
-
-
-
-
+
+    # Debug printing (remove once you're confident everything works)
+    print("Scores:", scores)
+
+    # First 14 = pattern scores
+    pattern_scores = scores[:14]
+    pattern_count = int(np.sum(pattern_scores > 0.5))
+
+    # Last 3 = danger cues
+    danger_scores = scores[14:]
     danger_flag_count = int(np.sum(danger_scores > 0.5))
+
+    # (Optional) Print label-by-label for debugging
+    for i, s in enumerate(scores):
+        print(LABELS[i], "=", round(s, 3))
 
     # Map danger flag count to Danger Assessment Score
     if danger_flag_count >= 2:
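For reference, the scoring path this hunk settles on can be exercised outside the Space. The sketch below is a standalone approximation, assuming the transformers and torch packages, the model name from the hunk header, and an illustrative input message and 0.5 threshold:

    # Standalone sketch of the multi-label scoring above; the sample text is illustrative.
    import numpy as np
    import torch
    from transformers import RobertaForSequenceClassification, RobertaTokenizer

    model_name = "SamanthaStorm/abuse-pattern-detector-v2"
    model = RobertaForSequenceClassification.from_pretrained(model_name)
    tokenizer = RobertaTokenizer.from_pretrained(model_name)

    text = "You never listen, and somehow everything is my fault."  # illustrative input
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits.squeeze(0)      # shape [17]
    scores = torch.sigmoid(logits).numpy()              # independent per-label probabilities

    pattern_count = int(np.sum(scores[:14] > 0.5))      # first 14 = abuse pattern categories
    danger_flag_count = int(np.sum(scores[14:] > 0.5))  # last 3 = Danger Assessment cues
    print("patterns flagged:", pattern_count, "| danger cues flagged:", danger_flag_count)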
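The diff cuts off at the start of the danger-count mapping, so the commit's actual branches are not visible here. Purely as an illustration of the shape such a mapping takes (hypothetical helper and cut-offs, not the Space's real ones):

    # Hypothetical mapping, for illustration only; the real thresholds are not shown in this diff.
    def map_danger_flags(danger_flag_count: int) -> str:
        if danger_flag_count >= 2:
            return "High"
        elif danger_flag_count == 1:
            return "Moderate"
        return "Low"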