# Last change by wozwize, commit 1360e33: "updating logging"
import logging
import re
from typing import Dict, Any, List, Optional

import nltk
import numpy as np
from nltk.tokenize import sent_tokenize
from transformers import pipeline, AutoTokenizer
# Module-level logger named after this module (standard ``logging`` convention).
logger = logging.getLogger(__name__)
class HeadlineAnalyzer:
    """Score how well a headline is supported by its article body.

    Two modes are available:

    * AI mode (``use_ai=True`` and the pipelines loaded): an NLI pipeline
      compares the headline with each content sentence, and a zero-shot
      pipeline rates the headline for sensationalism. The pipelines come from
      ``model_registry`` when it reports itself available; otherwise they are
      created here.
    * Traditional mode (``use_ai=False``, or AI initialisation failed): word
      overlap between headline and content, minus a penalty for clickbait
      phrases.

    ``analyze`` is the public entry point in both modes.
    """

    # Zero-shot labels used to rate the headline's tone.
    SENSATIONALISM_CATEGORIES = (
        "clickbait",
        "sensationalized",
        "misleading",
        "factual reporting",
        "accurate headline",
    )
    # NLI labels averaged per section (pipeline labels are upper-cased before lookup).
    NLI_LABELS = ("ENTAILMENT", "CONTRADICTION", "NEUTRAL")
    # A content sentence is flagged when its contradiction probability exceeds this.
    CONTRADICTION_THRESHOLD = 0.3
    # The headline is flagged when its clickbait or sensationalized score exceeds this.
    SENSATIONALISM_THRESHOLD = 0.6
    # Sentences of this many characters or fewer are skipped in AI mode.
    MIN_SENTENCE_CHARS = 20
    # Number of (headline, sentence) pairs sent to the NLI pipeline per call.
    NLI_BATCH_SIZE = 8
    # Maximum number of flagged phrases returned by each analysis.
    MAX_FLAGGED_PHRASES = 5
    # Substrings that mark a headline or sentence as clickbait in traditional mode.
    CLICKBAIT_PATTERNS = (
        "you won't believe",
        "shocking",
        "mind blowing",
        "amazing",
        "incredible",
        "unbelievable",
        "must see",
        "click here",
        "find out",
        "what happens next",
    )

    # Words for overlap scoring: runs of word characters, keeping in-word
    # apostrophes ("don't", "fed's") but dropping surrounding punctuation.
    _WORD_RE = re.compile(r"\w+(?:['\u2019]\w+)*")
    # NLTK resources that sent_tokenize may need: older releases load "punkt",
    # newer releases load "punkt_tab".
    _NLTK_RESOURCES = (
        ("tokenizers/punkt", "punkt"),
        ("tokenizers/punkt_tab", "punkt_tab"),
    )
    # Set once every NLTK resource has been found or downloaded successfully.
    _nltk_ready = False

    def __init__(self, use_ai: bool = True, model_registry: Optional[Any] = None):
        """
        Initialize the analyzers for headline analysis.

        Args:
            use_ai: Boolean indicating whether to use AI-powered analysis (True) or traditional analysis (False)
            model_registry: Optional shared model registry for better performance. When truthy and its
                ``is_available`` attribute is truthy, its ``nli`` and ``zero_shot`` pipelines are reused.

        Pipeline loading failures are logged and leave ``llm_available`` False,
        so ``analyze`` falls back to the traditional heuristics.
        """
        self.use_ai = use_ai
        self.llm_available = False
        self.model_registry = model_registry
        # Set unconditionally so the attribute exists in every mode.
        self.max_length = 512

        if not use_ai:
            logger.info("Initializing headline analyzer in traditional mode")
            return

        try:
            if model_registry and model_registry.is_available:
                # Use shared models
                self.nli_pipeline = model_registry.nli
                self.zero_shot = model_registry.zero_shot
                ready_message = "Using shared model pipelines for headline analysis"
            else:
                # Initialize own pipelines
                self.nli_pipeline = pipeline(
                    "text-classification",
                    model="roberta-large-mnli",
                    batch_size=16,
                )
                self.zero_shot = pipeline(
                    "zero-shot-classification",
                    model="facebook/bart-large-mnli",
                    device=-1,
                    batch_size=8,
                )
                ready_message = "Initialized dedicated model pipelines for headline analysis"
            # Kept as a public attribute for callers; this class no longer needs it
            # internally because sentence pairs are passed to the pipeline directly.
            self.tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
            self.llm_available = True
            logger.info(ready_message)
        except Exception as e:
            logger.warning("Failed to initialize LLM pipelines: %s", e)
            self.llm_available = False

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _ensure_nltk_data(cls) -> None:
        """Download the NLTK sentence-tokenizer data on first use if it is missing.

        The lookup runs once per process; it is retried on later calls only if a
        download reported failure.
        """
        if cls._nltk_ready:
            return
        ready = True
        for resource_path, package in cls._NLTK_RESOURCES:
            try:
                nltk.data.find(resource_path)
            except LookupError:
                if not nltk.download(package, quiet=True):
                    ready = False
        cls._nltk_ready = ready

    @classmethod
    def _word_set(cls, text: str) -> set:
        """Return the distinct lower-cased words of *text*, ignoring punctuation."""
        return set(cls._WORD_RE.findall(text.lower()))

    @staticmethod
    def _neutral_nli() -> Dict[str, float]:
        """Return a fresh, fully neutral NLI distribution used as a fallback."""
        return {"ENTAILMENT": 0.0, "CONTRADICTION": 0.0, "NEUTRAL": 1.0}

    @staticmethod
    def _section_result(
        accuracy_score: float,
        flagged_phrases: List[Dict[str, Any]],
        nli_scores: Dict[str, float],
        sensationalism_scores: Dict[str, float],
    ) -> Dict[str, Any]:
        """Build the result dict returned by ``_analyze_section``."""
        return {
            "accuracy_score": accuracy_score,
            "flagged_phrases": flagged_phrases,
            "detailed_scores": {
                "nli": nli_scores,
                "sensationalism": sensationalism_scores,
            },
        }

    @classmethod
    def _top_unique_phrases(
        cls, phrases: List[Dict[str, Any]], limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """Return up to *limit* phrases, highest ``score`` first, unique by ``text``.

        *limit* defaults to ``MAX_FLAGGED_PHRASES``. For duplicate texts, the
        highest-scoring entry is kept.
        """
        if limit is None:
            limit = cls.MAX_FLAGGED_PHRASES
        unique: List[Dict[str, Any]] = []
        seen = set()
        for phrase in sorted(phrases, key=lambda p: p["score"], reverse=True):
            if len(unique) >= limit:
                break
            if phrase["text"] in seen:
                continue
            seen.add(phrase["text"])
            unique.append(phrase)
        return unique

    @classmethod
    def _headline_flags(
        cls, headline: str, sensationalism_scores: Dict[str, float]
    ) -> List[Dict[str, Any]]:
        """Flag the headline itself when it scores as sensationalized or clickbait.

        Sensationalism is measured on the headline alone, so the headline (not
        the content sentences) is what gets flagged.
        """
        score = max(
            sensationalism_scores.get("sensationalized", 0),
            sensationalism_scores.get("clickbait", 0),
        )
        if score <= cls.SENSATIONALISM_THRESHOLD:
            return []
        logger_free_flag = {
            "text": headline,
            "type": "Sensationalized",
            "score": score,
            "highlight": f"[SENSATIONALIZED] \"{headline}\"",
        }
        return [logger_free_flag]

    @staticmethod
    def _accuracy_components(
        avg_scores: Dict[str, float], sensationalism_scores: Dict[str, float]
    ) -> Dict[str, float]:
        """Return the weighted terms of the accuracy score (weights 0.4/0.3/0.15/0.15).

        Each sensationalism term adds two independent multi-label scores, so the
        terms can sum to anything from -0.15 up to 1.15. ``_combine_accuracy``
        clamps the result.
        """
        return {
            "entailment": avg_scores.get("ENTAILMENT", 0.0) * 0.4,
            "non_contradiction": (1 - avg_scores.get("CONTRADICTION", 0.0)) * 0.3,
            "non_sensational": (
                sensationalism_scores.get("factual reporting", 0.0)
                + sensationalism_scores.get("accurate headline", 0.0)
            ) * 0.15,
            "non_clickbait": (
                1
                - sensationalism_scores.get("clickbait", 0.0)
                - sensationalism_scores.get("sensationalized", 0.0)
            ) * 0.15,
        }

    @staticmethod
    def _combine_accuracy(components: Dict[str, float]) -> float:
        """Turn weighted components into a 0-100 score (50.0 if the sum is not finite).

        The score is clamped to [0, 100], the same range the traditional score uses.
        """
        score = float(sum(components.values()) * 100)
        if not np.isfinite(score):
            return 50.0
        return min(100.0, max(0.0, score))

    def _score_sensationalism(self, headline: str) -> Dict[str, float]:
        """Rate the headline against ``SENSATIONALISM_CATEGORIES`` with the zero-shot pipeline."""
        logger.info("Checking headline for sensationalism...")
        result = self.zero_shot(
            headline,
            list(self.SENSATIONALISM_CATEGORIES),
            multi_label=True,
        )
        scores = dict(zip(result["labels"], result["scores"]))
        logger.info("Sensationalism scores: %s", scores)
        return scores

    def _score_sentences(self, headline: str, sentences: List[str]) -> tuple:
        """Run the NLI pipeline over (headline, sentence) pairs in batches.

        Returns:
            ``(nli_scores, flags)``. ``nli_scores`` holds one label->probability
            dict per scored sentence. ``flags`` holds a flag dict for each sentence
            whose contradiction probability exceeds ``CONTRADICTION_THRESHOLD``.
            A batch that raises is logged and skipped (best effort).
        """
        nli_scores: List[Dict[str, float]] = []
        flags: List[Dict[str, Any]] = []
        for start in range(0, len(sentences), self.NLI_BATCH_SIZE):
            batch = sentences[start:start + self.NLI_BATCH_SIZE]
            # Pass a real sentence pair. "[SEP]" is BERT's separator, not RoBERTa's,
            # so embedding it in a single string makes the model read it as plain text.
            # NOTE(review): the headline is the first text and the sentence the pair,
            # as in the original; confirm this premise/hypothesis direction is intended.
            batch_inputs = [{"text": headline, "text_pair": sentence} for sentence in batch]
            try:
                batch_results = self.nli_pipeline(batch_inputs, top_k=None)
                if not isinstance(batch_results, list):
                    batch_results = [batch_results]
                for sentence, result in zip(batch, batch_results):
                    scores = {item["label"].upper(): item["score"] for item in result}
                    nli_scores.append(scores)
                    contradiction = scores.get("CONTRADICTION", 0)
                    if contradiction > self.CONTRADICTION_THRESHOLD:
                        logger.info(
                            "Found contradictory sentence (score: %.2f): %s",
                            contradiction,
                            sentence,
                        )
                        flags.append({
                            "text": sentence,
                            "type": "Contradiction",
                            "score": contradiction,
                            "highlight": (
                                f"[CONTRADICTION] (Score: {round(contradiction * 100, 1)}%) "
                                f"\"{sentence}\""
                            ),
                        })
            except Exception as batch_error:
                logger.warning("Batch processing error: %s", batch_error)
        return nli_scores, flags

    # ----------------------------------------------------------------- analysis

    def _split_content(self, headline: str, content: str, max_chars: int = 4000) -> List[str]:
        """Split *content* into whitespace-joined chunks of at most *max_chars* characters.

        Words are never broken, so a single word longer than *max_chars* becomes
        its own chunk. No empty chunks are produced. *headline* is unused and is
        kept for interface compatibility. The limit is in characters, not tokens,
        because each sentence is later paired with the headline individually.
        """
        sections: List[str] = []
        current: List[str] = []
        current_len = 0
        for word in content.split():
            added = len(word) + (1 if current else 0)  # +1 for the joining space
            if current and current_len + added > max_chars:
                sections.append(" ".join(current))
                current, current_len = [word], len(word)
            else:
                current.append(word)
                current_len += added
        if current:
            sections.append(" ".join(current))
        return sections

    def _analyze_section(
        self,
        headline: str,
        section: str,
        sensationalism_scores: Optional[Dict[str, float]] = None,
    ) -> Dict[str, Any]:
        """Analyze a single section for headline accuracy and sensationalism.

        Args:
            headline: The article headline.
            section: A chunk of article text (see ``_split_content``).
            sensationalism_scores: Precomputed zero-shot scores for the headline.
                When omitted they are computed here. ``analyze`` passes them in so
                the zero-shot model runs once per article, not once per section.

        Returns:
            Dict with ``accuracy_score`` (0-100), ``flagged_phrases`` (at most
            ``MAX_FLAGGED_PHRASES`` dicts with text/type/score/highlight) and
            ``detailed_scores`` (averaged ``nli`` and ``sensationalism`` scores).
            An unexpected error yields a neutral result with score 50.0.
        """
        try:
            logger.info("\n" + "-" * 30)
            logger.info("ANALYZING SECTION")
            logger.info("-" * 30)
            logger.info("Headline: %s", headline)
            logger.info("Section length: %d characters", len(section))

            self._ensure_nltk_data()
            sentences = sent_tokenize(section)
            logger.info("Found %d sentences in section", len(sentences))
            if not sentences:
                logger.warning("No sentences found in section")
                return self._section_result(
                    50.0,
                    [],
                    self._neutral_nli(),
                    {"factual reporting": 0.5, "accurate headline": 0.5},
                )

            if sensationalism_scores is None:
                sensationalism_scores = self._score_sensationalism(headline)
            flagged_phrases = self._headline_flags(headline, sensationalism_scores)

            relevant_sentences = [
                s.strip() for s in sentences if len(s.strip()) > self.MIN_SENTENCE_CHARS
            ]
            logger.info("Found %d relevant sentences after filtering", len(relevant_sentences))
            if not relevant_sentences:
                logger.warning("No relevant sentences found in section")
                return self._section_result(
                    50.0, flagged_phrases, self._neutral_nli(), sensationalism_scores
                )

            logger.info("Processing sentences for contradictions...")
            nli_scores, contradiction_flags = self._score_sentences(headline, relevant_sentences)
            flagged_phrases.extend(contradiction_flags)

            if nli_scores:
                avg_scores = {
                    label: float(np.mean([score.get(label, 0.0) for score in nli_scores]))
                    for label in self.NLI_LABELS
                }
                logger.info("Average NLI scores: %s", avg_scores)
            else:
                logger.warning("No NLI scores available")
                avg_scores = self._neutral_nli()

            try:
                components = self._accuracy_components(avg_scores, sensationalism_scores)
                logger.info("Accuracy components: %s", components)
                accuracy_score = self._combine_accuracy(components)
                logger.info("Final accuracy score: %.1f", accuracy_score)
            except Exception as score_error:
                logger.error("Error calculating accuracy score: %s", score_error)
                accuracy_score = 50.0

            unique_phrases = self._top_unique_phrases(flagged_phrases)
            logger.info("Final number of flagged phrases: %d", len(unique_phrases))
            return self._section_result(
                accuracy_score, unique_phrases, avg_scores, sensationalism_scores
            )
        except Exception as e:
            logger.exception("Section analysis failed: %s", e)
            return self._section_result(50.0, [], self._neutral_nli(), {})

    def _analyze_traditional(self, headline: str, content: str) -> Dict[str, Any]:
        """Score the headline with word-overlap and clickbait heuristics (no models).

        The score is the percentage of distinct headline words that also appear in
        the content, minus 10 points per clickbait pattern found in the headline,
        clamped to [0, 100] and rounded to one decimal. ``flagged_phrases`` lists
        up to ``MAX_FLAGGED_PHRASES`` content sentences, in document order, that
        share more than two words with the headline or contain a clickbait
        pattern. Any error yields a score of 0 and no phrases.
        """
        try:
            self._ensure_nltk_data()
            headline_lower = headline.lower()
            headline_words = self._word_set(headline)
            content_words = self._word_set(content)

            overlap_words = headline_words & content_words
            overlap_score = len(overlap_words) / len(headline_words) if headline_words else 0

            clickbait_count = sum(1 for pattern in self.CLICKBAIT_PATTERNS if pattern in headline_lower)
            clickbait_penalty = clickbait_count * 10  # 10 points per clickbait phrase
            final_score = max(0, min(100, overlap_score * 100 - clickbait_penalty))

            flagged_phrases = []
            for sentence in sent_tokenize(content):
                sentence_lower = sentence.lower()
                shares_headline_words = len(headline_words & self._word_set(sentence)) > 2
                has_clickbait = any(pattern in sentence_lower for pattern in self.CLICKBAIT_PATTERNS)
                if shares_headline_words or has_clickbait:
                    flagged_phrases.append(sentence.strip())

            # dict.fromkeys de-duplicates while keeping document order. A set would
            # make the selection vary between runs (string hash randomization).
            unique_phrases = list(dict.fromkeys(flagged_phrases))[:self.MAX_FLAGGED_PHRASES]
            return {
                "headline_vs_content_score": round(final_score, 1),
                "flagged_phrases": unique_phrases,
            }
        except Exception as e:
            logger.exception("Traditional analysis failed: %s", e)
            return {
                "headline_vs_content_score": 0,
                "flagged_phrases": [],
            }

    def analyze(self, headline: str, content: str) -> Dict[str, Any]:
        """Analyze how well the headline matches the content.

        Uses the model-based analysis when ``use_ai`` is set and the pipelines
        loaded; otherwise uses the traditional heuristics.

        Returns:
            ``{"headline_vs_content_score": <0-100, one decimal>, "flagged_phrases": [...]}``.
            In AI mode the phrases are dicts with text/type/score/highlight. In
            traditional mode they are plain sentence strings. Empty input or an
            unexpected error yields a score of 0 and no phrases.
        """
        try:
            logger.info("\n" + "=" * 50)
            logger.info("HEADLINE ANALYSIS STARTED")
            logger.info("=" * 50)

            if not headline or not headline.strip() or not content or not content.strip():
                logger.warning("Empty headline or content provided")
                return {
                    "headline_vs_content_score": 0,
                    "flagged_phrases": [],
                }

            if not (self.use_ai and self.llm_available):
                logger.info("Using traditional headline analysis")
                return self._analyze_traditional(headline, content)

            logger.info("Using LLM analysis for headline")
            # The headline's tone does not depend on the section, so score it once.
            # On failure each section retries and falls back to its neutral result.
            try:
                sensationalism_scores = self._score_sensationalism(headline)
            except Exception as sens_error:
                logger.warning("Headline sensationalism scoring failed: %s", sens_error)
                sensationalism_scores = None

            section_results = [
                self._analyze_section(headline, section, sensationalism_scores)
                for section in self._split_content(headline, content)
            ]

            final_score = float(np.mean([r["accuracy_score"] for r in section_results]))
            all_phrases = [
                phrase
                for result in section_results
                for phrase in result.get("flagged_phrases", [])
            ]
            return {
                "headline_vs_content_score": round(final_score, 1),
                "flagged_phrases": self._top_unique_phrases(all_phrases),
            }
        except Exception as e:
            logger.exception("Headline analysis failed: %s", e)
            return {
                "headline_vs_content_score": 0,
                "flagged_phrases": [],
            }