SamanthaStorm committed on
Commit 909b775 · verified · 1 Parent(s): 92b1518

Update app.py

Files changed (1)
  1. app.py +9 -24
app.py CHANGED
@@ -1,20 +1,13 @@
  import gradio as gr
  import torch
  import numpy as np
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
  from transformers import RobertaForSequenceClassification, RobertaTokenizer
  from motif_tagging import detect_motifs
  import re

- # --- Sentiment Model ---
- sentiment_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-emotion")
- sentiment_model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-emotion")
-
- EMOTION_TO_SENTIMENT = {
-     "joy": "supportive", "love": "supportive", "surprise": "supportive", "neutral": "supportive",
-     "sadness": "undermining", "anger": "undermining", "fear": "undermining",
-     "disgust": "undermining", "shame": "undermining", "guilt": "undermining"
- }
+ # --- SST Sentiment Model ---
+ sst_pipeline = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

  # --- Abuse Model ---
  model_name = "SamanthaStorm/autotrain-jlpi4-mllvp"
@@ -125,27 +118,19 @@ def generate_risk_snippet(abuse_score, top_label):
  def analyze_single_message(text, thresholds, motif_flags):
      motif_hits, matched_phrases = detect_motifs(text)

-     # Sentiment Analysis
-     input_ids = sentiment_tokenizer(f"emotion: {text}", return_tensors="pt").input_ids
-     with torch.no_grad():
-         outputs = sentiment_model.generate(input_ids)
-     emotion = sentiment_tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower()
-     sentiment = EMOTION_TO_SENTIMENT.get(emotion, "undermining")
-     sentiment_score = 0.5 if sentiment == "undermining" else 0.0
+     # SST Sentiment
+     result = sst_pipeline(text)[0]
+     sentiment = "supportive" if result['label'] == "POSITIVE" else "undermining"
+     sentiment_score = result['score'] if sentiment == "undermining" else 0.0

-     # Raise thresholds slightly if the sentiment is supportive
      adjusted_thresholds = {
          k: v + 0.05 if sentiment == "supportive" else v
          for k, v in thresholds.items()
      }

-     # Contradiction Check
      contradiction_flag = detect_contradiction(text)
-
-     # Motifs
      motifs = [phrase for _, phrase in matched_phrases]

-     # Model Prediction
      inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
      with torch.no_grad():
          outputs = model(**inputs)
@@ -169,7 +154,7 @@ def analyze_single_message(text, thresholds, motif_flags):
          threshold_labels,
          top_patterns,
          darvo_score,
-         {"label": sentiment, "emotion": emotion}
+         {"label": sentiment, "raw_label": result['label'], "score": result['score']}
      )

  def analyze_composite(msg1, msg2, msg3, *answers_and_none):
@@ -211,4 +196,4 @@ iface = gr.Interface(
  )

  if __name__ == "__main__":
-     iface.launch()
+     iface.launch()
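
For reviewers who want to try the new sentiment path in isolation, below is a minimal, self-contained sketch of what this commit wires into analyze_single_message: the SST-2 pipeline call, the POSITIVE/NEGATIVE mapping to "supportive"/"undermining", and the +0.05 bump applied to every threshold when a message reads as supportive. The helper name sentiment_and_thresholds and the demo threshold values are illustrative only; they are not part of app.py.

# Sketch of the sentiment path introduced by this commit (not app.py itself).
from transformers import pipeline

# Same model as in the diff above.
sst_pipeline = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

def sentiment_and_thresholds(text, thresholds):
    # SST-2 returns a single dict like {"label": "POSITIVE", "score": 0.98}.
    result = sst_pipeline(text)[0]
    sentiment = "supportive" if result["label"] == "POSITIVE" else "undermining"
    # Only undermining messages carry a nonzero sentiment score downstream.
    sentiment_score = result["score"] if sentiment == "undermining" else 0.0
    # Supportive messages raise every abuse-pattern threshold by 0.05,
    # making each pattern slightly harder to flag.
    adjusted = {
        k: v + 0.05 if sentiment == "supportive" else v
        for k, v in thresholds.items()
    }
    return sentiment, sentiment_score, adjusted

if __name__ == "__main__":
    demo_thresholds = {"control": 0.30, "gaslighting": 0.25}  # placeholder values, not the app's real thresholds
    print(sentiment_and_thresholds("You always make everything my fault.", demo_thresholds))

Note that the metadata returned by analyze_single_message now carries the raw SST-2 label and score ({"label": sentiment, "raw_label": ..., "score": ...}) instead of the old T5 emotion string.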