Update app.py
app.py CHANGED
@@ -4,15 +4,14 @@ import numpy as np
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from transformers import RobertaForSequenceClassification, RobertaTokenizer
 
-# Load fine-tuned sentiment model
-
-
-sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
+# Load custom fine-tuned sentiment model
+sentiment_model = AutoModelForSequenceClassification.from_pretrained("SamanthaStorm/tether-sentiment")
+sentiment_tokenizer = AutoTokenizer.from_pretrained("SamanthaStorm/tether-sentiment")
 
-# Load abuse pattern model
-
-
-
+# Load abuse pattern model
+model_name = "SamanthaStorm/abuse-pattern-detector-v2"
+model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
+tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
 LABELS = [
     "gaslighting", "mockery", "dismissiveness", "control", "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
@@ -80,11 +79,11 @@ def analyze_messages(input_text, risk_flags):
     sentiment_label = sentiment['label']
     sentiment_score = sentiment['score']
 
-    adjusted_thresholds = {k: v * 0.8 for k, v in THRESHOLDS.items()} if sentiment_label == "
+    adjusted_thresholds = {k: v * 0.8 for k, v in THRESHOLDS.items()} if sentiment_label == "undermining" else THRESHOLDS.copy()
 
-    inputs =
+    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
     with torch.no_grad():
-        outputs =
+        outputs = model(**inputs)
     scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
 
     pattern_count = sum(score > adjusted_thresholds[label] for label, score in zip(PATTERN_LABELS, scores[:15]))
@@ -101,7 +100,7 @@ def analyze_messages(input_text, risk_flags):
     if non_abusive_score > adjusted_thresholds['non_abusive']:
         return "This message is classified as non-abusive."
 
-    abuse_level = calculate_abuse_level(scores,
+    abuse_level = calculate_abuse_level(scores, adjusted_thresholds)
     abuse_description = interpret_abuse_level(abuse_level)
 
     if danger_flag_count >= 2:
@@ -121,8 +120,8 @@ def analyze_messages(input_text, risk_flags):
         f"Abuse Risk Score: {abuse_level}% – {abuse_description}\n\n"
         f"Most Likely Patterns:\n{top_pattern_explanations}\n\n"
         f"⚠️ Critical Danger Flags Detected: {danger_flag_count} of 3\n"
-
-        f"Sentiment: {sentiment_label} (Confidence: {sentiment_score*100:.2f}%)"
+        "Resources: " + resources + "\n\n"
+        f"Sentiment: {sentiment_label.title()} (Confidence: {sentiment_score*100:.2f}%)"
     )
 
     if contextual_flags:
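For context, here is a minimal, self-contained sketch of how the pieces touched by this commit fit together. It is not the Space's actual code: the THRESHOLDS values and the three-entry PATTERN_LABELS list are placeholders standing in for the full label set app.py defines, the helper names classify_sentiment and score_patterns are invented for illustration, and the sentiment step is reconstructed from standard transformers usage rather than from whatever pipeline app.py builds around the sentiment dict it reads 'label' and 'score' from.

import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    RobertaForSequenceClassification,
    RobertaTokenizer,
)

# Models as loaded in the updated app.py
sentiment_model = AutoModelForSequenceClassification.from_pretrained("SamanthaStorm/tether-sentiment")
sentiment_tokenizer = AutoTokenizer.from_pretrained("SamanthaStorm/tether-sentiment")

model_name = "SamanthaStorm/abuse-pattern-detector-v2"
model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Placeholder labels and thresholds for illustration only; the Space defines the
# full pattern, contextual, and danger-flag label sets with its own thresholds.
PATTERN_LABELS = ["gaslighting", "mockery", "dismissiveness"]
THRESHOLDS = {label: 0.5 for label in PATTERN_LABELS}

def classify_sentiment(text):
    # Run the custom sentiment model; label names come from the model config.
    inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = sentiment_model(**inputs).logits
    probs = torch.softmax(logits, dim=-1).squeeze(0)
    idx = int(probs.argmax())
    return sentiment_model.config.id2label[idx], float(probs[idx])

def score_patterns(input_text, sentiment_label):
    # Lower every threshold by 20% when the sentiment model reads the text as
    # "undermining", as the new line in analyze_messages does; otherwise keep
    # THRESHOLDS unchanged.
    adjusted = (
        {k: v * 0.8 for k, v in THRESHOLDS.items()}
        if sentiment_label == "undermining"
        else THRESHOLDS.copy()
    )
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
    # Map each illustrative label to (sigmoid score, whether it clears its adjusted threshold).
    return {label: (float(s), bool(s > adjusted[label])) for label, s in zip(PATTERN_LABELS, scores)}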
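The design choice this encodes: a sentiment read of "undermining" makes every pattern threshold 20% easier to clear, so with the placeholder threshold of 0.5 above, a pattern scoring 0.45 is not flagged under a neutral read but is flagged once the threshold drops to 0.4, while .copy() keeps the global THRESHOLDS dict unmutated on the neutral path. A hypothetical call against the sketch:

text = "You always twist my words."
label, confidence = classify_sentiment(text)
print(label, confidence)
print(score_patterns(text, label))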