SamanthaStorm committed on
Commit
4472a1d
·
verified ·
1 Parent(s): 32be36e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -22
app.py CHANGED
@@ -1,20 +1,18 @@
1
  import gradio as gr
2
  import torch
3
  import numpy as np
4
- from transformers import pipeline
5
- from transformers import RobertaForSequenceClassification, RobertaTokenizer
6
-
7
- # Load fine-tuned sentiment model from Hugging Face
8
- sentiment_analyzer = pipeline(
9
- "sentiment-analysis",
10
- model="SamanthaStorm/Tether",
11
- tokenizer="SamanthaStorm/Tether"
12
- )
13
 
14
- # Load abuse pattern model
15
- model_name = "SamanthaStorm/abuse-pattern-detector-v2"
16
- model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
17
- tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
 
 
18
 
19
  LABELS = [
20
  "gaslighting", "mockery", "dismissiveness", "control", "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
@@ -52,6 +50,16 @@ EXPLANATIONS = {
52
  "obscure_formal": "Obscure/formal language manipulates through confusion or superiority."
53
  }
54
 
 
 
 
 
 
 
 
 
 
 
55
  def calculate_abuse_level(scores, thresholds):
56
  triggered_scores = [score for label, score in zip(LABELS, scores) if score > thresholds[label]]
57
  return round(np.mean(triggered_scores) * 100, 2) if triggered_scores else 0.0
@@ -68,15 +76,15 @@ def analyze_messages(input_text, risk_flags):
68
  if not input_text:
69
  return "Please enter a message for analysis."
70
 
71
- sentiment = sentiment_analyzer(input_text)[0]
72
  sentiment_label = sentiment['label']
73
  sentiment_score = sentiment['score']
74
 
75
  adjusted_thresholds = {k: v * 0.8 for k, v in THRESHOLDS.items()} if sentiment_label == "NEGATIVE" else THRESHOLDS.copy()
76
 
77
- inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
78
  with torch.no_grad():
79
- outputs = model(**inputs)
80
  scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
81
 
82
  pattern_count = sum(score > adjusted_thresholds[label] for label, score in zip(PATTERN_LABELS, scores[:15]))
@@ -96,16 +104,16 @@ def analyze_messages(input_text, risk_flags):
96
  abuse_level = calculate_abuse_level(scores, THRESHOLDS)
97
  abuse_description = interpret_abuse_level(abuse_level)
98
 
99
- if danger_flag_count >= 2:
100
- resources = "Immediate assistance recommended. Please seek professional help or contact emergency services."
101
- else:
102
- resources = "For more information on abuse patterns, consider reaching out to support groups or professional counselors."
 
103
 
104
  scored_patterns = [(label, score) for label, score in zip(PATTERN_LABELS, scores[:15])]
105
  top_patterns = sorted(scored_patterns, key=lambda x: x[1], reverse=True)[:2]
106
-
107
  top_pattern_explanations = "\n".join([
108
- f" {label.replace('_', ' ').title()}: {EXPLANATIONS.get(label, 'No explanation available.')}"
109
  for label, _ in top_patterns
110
  ])
111
 
 
1
  import gradio as gr
2
  import torch
3
  import numpy as np
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
+
6
+ # Load both sentiment and abuse models from the same Hugging Face repo
7
+ repo_id = "SamanthaStorm/Tether"
 
 
 
 
 
8
 
9
+ # Load fine-tuned sentiment model
10
+ sentiment_model = AutoModelForSequenceClassification.from_pretrained(repo_id, subfolder="sentiment")
11
+ sentiment_tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder="sentiment")
12
+
13
+ # Load abuse detection model
14
+ abuse_model = AutoModelForSequenceClassification.from_pretrained(repo_id, subfolder="abuse")
15
+ abuse_tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder="abuse")
16
 
17
  LABELS = [
18
  "gaslighting", "mockery", "dismissiveness", "control", "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
 
50
  "obscure_formal": "Obscure/formal language manipulates through confusion or superiority."
51
  }
52
 
53
+ def custom_sentiment(text):
54
+ inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
55
+ with torch.no_grad():
56
+ outputs = sentiment_model(**inputs)
57
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
58
+ label_idx = torch.argmax(probs).item()
59
+ label = sentiment_model.config.id2label[label_idx]
60
+ score = probs[0][label_idx].item()
61
+ return {"label": label, "score": score}
62
+
63
  def calculate_abuse_level(scores, thresholds):
64
  triggered_scores = [score for label, score in zip(LABELS, scores) if score > thresholds[label]]
65
  return round(np.mean(triggered_scores) * 100, 2) if triggered_scores else 0.0
 
76
  if not input_text:
77
  return "Please enter a message for analysis."
78
 
79
+ sentiment = custom_sentiment(input_text)
80
  sentiment_label = sentiment['label']
81
  sentiment_score = sentiment['score']
82
 
83
  adjusted_thresholds = {k: v * 0.8 for k, v in THRESHOLDS.items()} if sentiment_label == "NEGATIVE" else THRESHOLDS.copy()
84
 
85
+ inputs = abuse_tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
86
  with torch.no_grad():
87
+ outputs = abuse_model(**inputs)
88
  scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
89
 
90
  pattern_count = sum(score > adjusted_thresholds[label] for label, score in zip(PATTERN_LABELS, scores[:15]))
 
104
  abuse_level = calculate_abuse_level(scores, THRESHOLDS)
105
  abuse_description = interpret_abuse_level(abuse_level)
106
 
107
+ resources = (
108
+ "Immediate assistance recommended. Please seek professional help or contact emergency services."
109
+ if danger_flag_count >= 2 else
110
+ "For more information on abuse patterns, consider reaching out to support groups or professional counselors."
111
+ )
112
 
113
  scored_patterns = [(label, score) for label, score in zip(PATTERN_LABELS, scores[:15])]
114
  top_patterns = sorted(scored_patterns, key=lambda x: x[1], reverse=True)[:2]
 
115
  top_pattern_explanations = "\n".join([
116
+ f"\u2022 {label.replace('_', ' ').title()}: {EXPLANATIONS.get(label, 'No explanation available.')}"
117
  for label, _ in top_patterns
118
  ])
119