SamanthaStorm committed (verified)
Commit 0ff864f · 1 Parent(s): 94e76c4

Update app.py

Files changed (1)
  1. app.py +25 -3
app.py CHANGED
@@ -2,6 +2,10 @@ import gradio as gr
 import torch
 from transformers import RobertaForSequenceClassification, RobertaTokenizer
 import numpy as np
+from transformers import pipeline
+
+# Load sentiment analysis model
+sentiment_analyzer = pipeline("sentiment-analysis")
 
 # Load model and tokenizer with trust_remote_code in case it's needed
 model_name = "SamanthaStorm/abuse-pattern-detector-v2"
@@ -83,7 +87,18 @@ def analyze_messages(input_text):
     input_text = input_text.strip()
     if not input_text:
         return "Please enter a message for analysis.", None
-
+
+    # Sentiment analysis
+    sentiment = sentiment_analyzer(input_text)[0]  # Sentiment result
+    sentiment_label = sentiment['label']
+    sentiment_score = sentiment['score']
+
+    # Adjust thresholds based on sentiment
+    adjusted_thresholds = THRESHOLDS.copy()
+    if sentiment_label == "NEGATIVE":
+        # Lower thresholds for negative sentiment
+        adjusted_thresholds = {key: val * 0.8 for key, val in THRESHOLDS.items()}  # Example adjustment
+
     # Tokenize input and generate model predictions
     inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
     with torch.no_grad():
@@ -91,8 +106,14 @@ def analyze_messages(input_text):
     scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
 
     # Count the number of triggered abuse pattern and danger flags based on thresholds
-    pattern_count = sum(score > THRESHOLDS[label] for label, score in zip(PATTERN_LABELS, scores[:14]))
-    danger_flag_count = sum(score > THRESHOLDS[label] for label, score in zip(DANGER_LABELS, scores[14:17]))
+    pattern_count = sum(score > adjusted_thresholds[label] for label, score in zip(PATTERN_LABELS, scores[:14]))
+    danger_flag_count = sum(score > adjusted_thresholds[label] for label, score in zip(DANGER_LABELS, scores[14:17]))
+
+    # Check if 'non_abusive' label is triggered
+    non_abusive_score = scores[LABELS.index('non_abusive')]
+    if non_abusive_score > adjusted_thresholds['non_abusive']:
+        # If non-abusive threshold is met, return a non-abusive classification
+        return "This message is classified as non-abusive."
 
     # Build formatted raw score display
     score_lines = [
@@ -127,6 +148,7 @@ def analyze_messages(input_text):
         "It flags communication patterns associated with increased risk of severe harm. "
         "For more info, consider reaching out to support groups or professionals.\n\n"
         f"Resources: {resources}"
+        f"Sentiment: {sentiment_label} (Confidence: {sentiment_score*100:.2f}%)"
     )
 
     # Return both a text summary and a JSON-like dict of scores per label
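
Taken together, the commit adds a sentiment pass that loosens the detection thresholds for negative-sounding input and an early exit when the non_abusive label clears its threshold. The following is a minimal, self-contained sketch of that logic with toy labels and thresholds (the names and numbers below are illustrative stand-ins, not the app's real label set), so the effect of the 0.8 scaling can be checked without loading any model:

# Hypothetical standalone sketch, not part of app.py: toy stand-ins for the
# app's PATTERN_LABELS / DANGER_LABELS / THRESHOLDS.
PATTERN_LABELS = ["gaslighting", "blame_shifting"]
DANGER_LABELS = ["threat"]
LABELS = PATTERN_LABELS + DANGER_LABELS + ["non_abusive"]
THRESHOLDS = {label: 0.5 for label in LABELS}

def adjust_thresholds(thresholds, sentiment_label, factor=0.8):
    """Scale every threshold down when sentiment is NEGATIVE, as in the commit."""
    if sentiment_label == "NEGATIVE":
        return {key: val * factor for key, val in thresholds.items()}
    return dict(thresholds)

def classify(scores, sentiment_label):
    adjusted = adjust_thresholds(THRESHOLDS, sentiment_label)
    # Short-circuit: if 'non_abusive' clears its (adjusted) threshold, stop here
    if scores[LABELS.index("non_abusive")] > adjusted["non_abusive"]:
        return "This message is classified as non-abusive."
    n = len(PATTERN_LABELS)
    pattern_count = sum(
        score > adjusted[label]
        for label, score in zip(PATTERN_LABELS, scores[:n])
    )
    danger_flag_count = sum(
        score > adjusted[label]
        for label, score in zip(DANGER_LABELS, scores[n:n + len(DANGER_LABELS)])
    )
    return f"Patterns triggered: {pattern_count}, danger flags: {danger_flag_count}"

# Scores of 0.45 fall below the default 0.5 threshold but above the
# sentiment-lowered 0.4 threshold, so NEGATIVE sentiment flips the outcome.
scores = [0.45, 0.30, 0.45, 0.10]
print(classify(scores, "POSITIVE"))  # Patterns triggered: 0, danger flags: 0
print(classify(scores, "NEGATIVE"))  # Patterns triggered: 1, danger flags: 1

With NEGATIVE sentiment the effective thresholds drop from 0.5 to 0.4, so borderline scores such as 0.45 start counting as triggered patterns.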
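
For reference, a sketch of what the sentiment_analyzer call itself produces: pipeline("sentiment-analysis") returns one dict per input with a label (POSITIVE or NEGATIVE for the default English model) and a confidence score, which is the shape the new code indexes into. The example text below is made up, and the default model is downloaded on first use:

from transformers import pipeline

# Default English sentiment model, fetched on first use
sentiment_analyzer = pipeline("sentiment-analysis")

result = sentiment_analyzer("You never listen and this is all your fault.")[0]
# e.g. NEGATIVE 99.xx% (exact score depends on the model version)
print(result["label"], f"{result['score'] * 100:.2f}%")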