SamanthaStorm commited on
Commit
80e6ac9
·
verified ·
1 Parent(s): a9b6112

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -17
app.py CHANGED
@@ -10,6 +10,7 @@ import io
10
  from PIL import Image
11
  from datetime import datetime
12
  from transformers import pipeline as hf_pipeline # prevent name collision with gradio pipeline
 
13
 
14
  def get_emotion_profile(text):
15
  emotions = emotion_pipeline(text)
@@ -382,7 +383,6 @@ THREAT_MOTIFS = [
382
  ]
383
 
384
 
385
- @spaces.GPU
386
  @spaces.GPU
387
  def compute_abuse_score(matched_scores, sentiment):
388
  """
@@ -448,7 +448,7 @@ def compute_abuse_score(matched_scores, sentiment):
448
  base_score = max(base_score, 60.0) # Reduced
449
 
450
  return min(round(base_score, 1), 100.0)
451
-
452
  def analyze_single_message(text, thresholds):
453
  print("⚡ ENTERED analyze_single_message")
454
  stage = 1
@@ -459,9 +459,9 @@ def analyze_single_message(text, thresholds):
459
  sentiment_score = emotion_profile.get("anger", 0) + emotion_profile.get("disgust", 0)
460
 
461
  # Get model scores
462
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
463
- with torch.no_grad():
464
- outputs = model(**inputs)
465
  scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
466
 
467
  # Sentiment override
@@ -475,21 +475,17 @@ def analyze_single_message(text, thresholds):
475
 
476
  weapon_flag = detect_weapon_language(text)
477
 
478
- adjusted_thresholds = {
479
- k: v + 0.05 if sentiment == "supportive" else v
480
- for k, v in thresholds.items()
481
- }
482
-
483
  darvo_score = predict_darvo_score(text)
484
 
485
- threshold_labels = [
486
- label for label, score in zip(LABELS, scores)
487
- if score > adjusted_thresholds[label]
488
- ]
489
 
490
- # Early exit if nothing passed
491
  if not threshold_labels:
492
- return 0.0, [], [], {"label": sentiment}, 1, 0.0, "supportive"
493
 
494
  top_patterns = sorted(
495
  [(label, score) for label, score in zip(LABELS, scores)],
@@ -544,7 +540,7 @@ def analyze_single_message(text, thresholds):
544
  top_patterns = [("insults", insult_score)] + top_patterns
545
  if "insults" not in threshold_labels:
546
  threshold_labels.append("insults")
547
-
548
  # Debug
549
  print(f"Emotional Tone Tag: {tone_tag}")
550
  print("Emotion Profile:")
 
10
  from PIL import Image
11
  from datetime import datetime
12
  from transformers import pipeline as hf_pipeline # prevent name collision with gradio pipeline
13
+ from functools import lru_cache # Import lru_cache
14
 
15
  def get_emotion_profile(text):
16
  emotions = emotion_pipeline(text)
 
383
  ]
384
 
385
 
 
386
  @spaces.GPU
387
  def compute_abuse_score(matched_scores, sentiment):
388
  """
 
448
  base_score = max(base_score, 60.0) # Reduced
449
 
450
  return min(round(base_score, 1), 100.0)
451
+ @lru_cache(maxsize=1024)
452
  def analyze_single_message(text, thresholds):
453
  print("⚡ ENTERED analyze_single_message")
454
  stage = 1
 
459
  sentiment_score = emotion_profile.get("anger", 0) + emotion_profile.get("disgust", 0)
460
 
461
  # Get model scores
462
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device) # Move to device
463
+ with torch.no_grad():
464
+ outputs = model(**inputs)
465
  scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
466
 
467
  # Sentiment override
 
475
 
476
  weapon_flag = detect_weapon_language(text)
477
 
478
+ adjusted_thresholds = {k: v + 0.05 if sentiment == "supportive" else v for k, v in thresholds.items()}
 
 
 
 
479
  darvo_score = predict_darvo_score(text)
480
 
481
+ threshold_labels = [label for label, score in zip(LABELS, scores) if score > adjusted_thresholds[label]]
482
+
483
+ # Calculate matched scores *before* early exit
484
+ matched_scores = [(label, score, PATTERN_WEIGHTS.get(label, 1.0)) for label, score in zip(LABELS, scores) if score > adjusted_thresholds[label]]
485
 
486
+ # Early exit if nothing passed, but return tone_tag as None
487
  if not threshold_labels:
488
+ return 0.0, [], [], {"label": sentiment}, 1, 0.0, None # Return None for tone_tag
489
 
490
  top_patterns = sorted(
491
  [(label, score) for label, score in zip(LABELS, scores)],
 
540
  top_patterns = [("insults", insult_score)] + top_patterns
541
  if "insults" not in threshold_labels:
542
  threshold_labels.append("insults")
543
+ return abuse_score, threshold_labels, top_patterns, {"label": sentiment}, stage, darvo_score, tone_tag
544
  # Debug
545
  print(f"Emotional Tone Tag: {tone_tag}")
546
  print("Emotion Profile:")