SamanthaStorm committed
Commit 5dfb1ca · verified · 1 Parent(s): 38fd495

Update app.py

Files changed (1)
  app.py  +23 -17
app.py CHANGED
@@ -2,14 +2,13 @@ import gradio as gr
 import torch
 from transformers import RobertaForSequenceClassification, RobertaTokenizer
 import numpy as np
-import tempfile
 
-# Load model and tokenizer
+# Load model and tokenizer with trust_remote_code in case it's needed
 model_name = "SamanthaStorm/abuse-pattern-detector-v2"
-model = RobertaForSequenceClassification.from_pretrained(model_name)
-tokenizer = RobertaTokenizer.from_pretrained(model_name)
+model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
+tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
-# Define labels (total 17 labels)
+# Define labels (17 total)
 LABELS = [
     "gaslighting", "mockery", "dismissiveness", "control",
     "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
@@ -18,11 +17,11 @@ LABELS = [
     "extreme_control"
 ]
 
-# Custom thresholds per label (make sure these are exactly as in the original)
+# Custom thresholds for each label (make sure these match your original settings)
 THRESHOLDS = {
     "gaslighting": 0.15,
     "mockery": 0.15,
-    "dismissiveness": 0.25,  # Keep this as 0.25 (not 0.30)
+    "dismissiveness": 0.25,  # original value, not 0.30
     "control": 0.13,
     "guilt_tripping": 0.15,
     "apology_baiting": 0.15,
@@ -39,7 +38,7 @@ THRESHOLDS = {
     "extreme_control": 0.30,
 }
 
-# Define label groups using slicing (first 14 are abuse patterns, last 3 are danger cues)
+# Define label groups using slicing (first 14: abuse patterns, last 3: danger cues)
 PATTERN_LABELS = LABELS[:14]
 DANGER_LABELS = LABELS[14:]
 
@@ -66,40 +65,47 @@ def analyze_messages(input_text):
     if not input_text:
         return "Please enter a message for analysis.", None
 
-    # Tokenize and predict
+    # Tokenize input and generate model predictions
     inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
     with torch.no_grad():
         outputs = model(**inputs)
     scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
 
-    # Count triggered labels using the correct slices
+    # Count the number of triggered abuse pattern and danger flags based on thresholds
     pattern_count = sum(score > THRESHOLDS[label] for label, score in zip(PATTERN_LABELS, scores[:14]))
     danger_flag_count = sum(score > THRESHOLDS[label] for label, score in zip(DANGER_LABELS, scores[14:]))
 
-    # Abuse level calculation and severity interpretation
+    # Calculate overall abuse level and interpret it
     abuse_level = calculate_abuse_level(scores, THRESHOLDS)
     abuse_description = interpret_abuse_level(abuse_level)
 
-    # Resource logic (example logic; adjust as needed)
+    # Resource logic based on the number of danger cues
     if danger_flag_count >= 2:
         resources = "Immediate assistance recommended. Please seek professional help or contact emergency services."
     else:
         resources = "For more information on abuse patterns, consider reaching out to support groups or professional counselors."
 
-    # Output combining counts, severity, and resource suggestion
+    # Prepare the result summary and detailed scores
     result = (
         f"Abuse Patterns Detected: {pattern_count} out of {len(PATTERN_LABELS)}\n"
         f"Danger Flags Detected: {danger_flag_count} out of {len(DANGER_LABELS)}\n"
         f"Abuse Level: {abuse_level}% - {abuse_description}\n"
         f"Resources: {resources}"
     )
-    return result, scores
+
+    # Return both a text summary and a JSON-like dict of scores per label
+    return result, {"scores": dict(zip(LABELS, scores))}
 
+# Updated Gradio Interface using new component syntax
 iface = gr.Interface(
     fn=analyze_messages,
-    inputs=gr.inputs.Textbox(lines=10, placeholder="Enter message here..."),
-    outputs=[gr.Textbox(), gr.JSON()],
+    inputs=gr.Textbox(lines=10, placeholder="Enter message here..."),
+    outputs=[
+        gr.Textbox(label="Analysis Result"),
+        gr.JSON(label="Scores")
+    ],
     title="Abuse Pattern Detector"
 )
 
-iface.launch()
+if __name__ == "__main__":
+    iface.launch()
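Note: the changed code calls calculate_abuse_level and interpret_abuse_level, which are defined elsewhere in app.py and do not appear in these hunks. The sketch below is only an illustration of what such helpers could look like; the names and signatures come from the diff, but the internal logic, cutoffs, and severity wording are assumptions, not the repository's actual implementation.

# Hypothetical sketch only -- the real helpers live outside the changed hunks.
def calculate_abuse_level(scores, thresholds):
    # Assumed logic: average the scores that exceed their per-label threshold
    # and express the result as a percentage.
    triggered = [score for label, score in zip(LABELS, scores) if score > thresholds[label]]
    if not triggered:
        return 0.0
    return round(float(np.mean(triggered)) * 100, 2)

def interpret_abuse_level(abuse_level):
    # Assumed severity bands; the actual wording and cutoffs may differ.
    if abuse_level >= 75:
        return "Severe concern"
    elif abuse_level >= 50:
        return "High concern"
    elif abuse_level >= 25:
        return "Moderate concern"
    return "Low or no concern"

# Quick local check of the updated return shape (assumes the model downloads successfully):
summary, score_json = analyze_messages("You always twist my words and then blame me.")
print(summary)
print(score_json["scores"])

One small caveat on the new return value: scores holds numpy float32 values, which Python's standard json module cannot serialize directly. If the gr.JSON output ever raises a serialization error, casting each value with float() when building the returned dict is a simple fix.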