"""
Core sentiment analysis engine for MCP server.

This module provides sentiment analysis functionality using both TextBlob
for simplicity and Transformers for accuracy, with confidence scoring
and comprehensive error handling.
"""

import logging
from typing import Dict, Any, Optional
from enum import Enum
import asyncio
from concurrent.futures import ThreadPoolExecutor

try:
    from textblob import TextBlob
    TEXTBLOB_AVAILABLE = True
except ImportError:
    TEXTBLOB_AVAILABLE = False
    logging.warning("TextBlob not available. Install with: pip install textblob")

try:
    from transformers import pipeline
    import torch
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    logging.warning("Transformers not available. Install with: pip install transformers torch")


class SentimentLabel(Enum):
    """Sentiment classification labels."""
    POSITIVE = "positive"
    NEGATIVE = "negative"
    NEUTRAL = "neutral"


class SentimentResult:
    """Container for sentiment analysis results."""

    def __init__(self, label: SentimentLabel, confidence: float, raw_scores: Optional[Dict[str, float]] = None):
        self.label = label
        self.confidence = confidence
        self.raw_scores = raw_scores or {}

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary format."""
        return {
            "label": self.label.value,
            "confidence": round(self.confidence, 4),
            "raw_scores": self.raw_scores
        }
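
# Illustrative serialization of a hand-built result (not executed at import time):
#
#     SentimentResult(SentimentLabel.POSITIVE, 0.87312).to_dict()
#     # -> {"label": "positive", "confidence": 0.8731, "raw_scores": {}}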


class SentimentAnalyzer:
    """
    Sentiment analysis engine supporting multiple backends.

    Supports both TextBlob (simple) and Transformers (accurate) backends,
    with confidence scoring and async processing.
    """

    def __init__(self, backend: str = "auto", model_name: str = "cardiffnlp/twitter-roberta-base-sentiment-latest"):
        """
        Initialize the sentiment analyzer.

        Args:
            backend: Analysis backend ("textblob", "transformers", or "auto")
            model_name: Hugging Face model name for the transformers backend
        """
        self.backend = backend
        self.model_name = model_name
        self.logger = logging.getLogger(__name__)
        self.executor = ThreadPoolExecutor(max_workers=2)

        # The transformer pipeline is loaded lazily on first use.
        self._transformer_pipeline = None
        self._model_loaded = False

        self._initialize_backend()

    def _initialize_backend(self) -> None:
        """Initialize the selected backend."""
        if self.backend == "auto":
            if TRANSFORMERS_AVAILABLE:
                self.backend = "transformers"
                self.logger.info("Auto-selected Transformers backend")
            elif TEXTBLOB_AVAILABLE:
                self.backend = "textblob"
                self.logger.info("Auto-selected TextBlob backend")
            else:
                raise RuntimeError("No sentiment analysis backend available. Install textblob or transformers.")

        if self.backend == "transformers" and not TRANSFORMERS_AVAILABLE:
            raise RuntimeError("Transformers backend requested but not available")

        if self.backend == "textblob" and not TEXTBLOB_AVAILABLE:
            raise RuntimeError("TextBlob backend requested but not available")

    async def _load_transformer_model(self) -> None:
        """Load the transformer model asynchronously."""
        if self._model_loaded:
            return

        try:
            self.logger.info(f"Loading transformer model: {self.model_name}")

            # Model loading is blocking, so run it in the thread pool.
            loop = asyncio.get_running_loop()
            self._transformer_pipeline = await loop.run_in_executor(
                self.executor,
                lambda: pipeline(
                    "sentiment-analysis",
                    model=self.model_name,
                    tokenizer=self.model_name,
                    device=0 if torch.cuda.is_available() else -1,
                    return_all_scores=True
                )
            )

            self._model_loaded = True
            self.logger.info("Transformer model loaded successfully")

        except Exception as e:
            self.logger.error(f"Failed to load transformer model: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _validate_input(self, text: str) -> str:
        """
        Validate and sanitize input text.

        Args:
            text: Input text to validate

        Returns:
            Sanitized text

        Raises:
            ValueError: If text is invalid
        """
        if not isinstance(text, str):
            raise ValueError("Input must be a string")

        text = text.strip()

        if not text:
            raise ValueError("Input text cannot be empty")

        if len(text) > 10000:
            raise ValueError("Input text too long (max 10,000 characters)")

        # Remove null bytes.
        text = text.replace('\x00', '')

        return text

    def _analyze_with_textblob(self, text: str) -> SentimentResult:
        """
        Analyze sentiment using TextBlob.

        Args:
            text: Text to analyze

        Returns:
            SentimentResult with classification and confidence
        """
        try:
            blob = TextBlob(text)
            polarity = blob.sentiment.polarity

            # Map polarity in [-1.0, 1.0] to a label; treat |polarity| <= 0.1 as neutral.
            if polarity > 0.1:
                label = SentimentLabel.POSITIVE
                confidence = min(polarity, 1.0)
            elif polarity < -0.1:
                label = SentimentLabel.NEGATIVE
                confidence = min(abs(polarity), 1.0)
            else:
                label = SentimentLabel.NEUTRAL
                confidence = 1.0 - abs(polarity)

            raw_scores = {
                "polarity": polarity,
                "subjectivity": blob.sentiment.subjectivity
            }

            return SentimentResult(label, confidence, raw_scores)

        except Exception as e:
            self.logger.error(f"TextBlob analysis failed: {e}")
            raise RuntimeError(f"Sentiment analysis failed: {e}") from e
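
    # Illustrative mapping (polarity values come from TextBlob and vary by input):
    #
    #     polarity  0.8  -> POSITIVE, confidence 0.8
    #     polarity -0.6  -> NEGATIVE, confidence 0.6
    #     polarity  0.05 -> NEUTRAL,  confidence 0.95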

    async def _analyze_with_transformers(self, text: str) -> SentimentResult:
        """
        Analyze sentiment using Transformers.

        Args:
            text: Text to analyze

        Returns:
            SentimentResult with classification and confidence
        """
        try:
            await self._load_transformer_model()

            # Inference is blocking, so run it in the thread pool.
            loop = asyncio.get_running_loop()
            results = await loop.run_in_executor(
                self.executor,
                lambda: self._transformer_pipeline(text)
            )

            # Collapse the per-label scores for the single input into {label: score}.
            scores = {result['label'].lower(): result['score'] for result in results[0]}

            # Map model-specific label names (including generic LABEL_N names) to our enum.
            label_mapping = {
                'positive': SentimentLabel.POSITIVE,
                'negative': SentimentLabel.NEGATIVE,
                'neutral': SentimentLabel.NEUTRAL,
                'label_0': SentimentLabel.NEGATIVE,
                'label_1': SentimentLabel.NEUTRAL,
                'label_2': SentimentLabel.POSITIVE
            }

            # Pick the recognized label with the highest score.
            best_score = 0.0
            best_label = SentimentLabel.NEUTRAL

            for model_label, score in scores.items():
                if model_label in label_mapping and score > best_score:
                    best_score = score
                    best_label = label_mapping[model_label]

            return SentimentResult(best_label, best_score, scores)

        except Exception as e:
            self.logger.error(f"Transformers analysis failed: {e}")
            raise RuntimeError(f"Sentiment analysis failed: {e}") from e
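
    # Illustrative shape of the pipeline output consumed above, for a single
    # input with return_all_scores=True (exact label strings depend on the
    # model; the cardiffnlp "latest" checkpoint uses lowercase names, older
    # checkpoints use LABEL_0/1/2):
    #
    #     [[{'label': 'negative', 'score': 0.01},
    #       {'label': 'neutral',  'score': 0.08},
    #       {'label': 'positive', 'score': 0.91}]]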

    async def analyze(self, text: str) -> SentimentResult:
        """
        Analyze sentiment of input text.

        Args:
            text: Text to analyze

        Returns:
            SentimentResult with label, confidence, and raw scores

        Raises:
            ValueError: If input is invalid
            RuntimeError: If analysis fails
        """
        text = self._validate_input(text)

        try:
            if self.backend == "transformers":
                return await self._analyze_with_transformers(text)
            elif self.backend == "textblob":
                # TextBlob is synchronous, so run it in the thread pool.
                loop = asyncio.get_running_loop()
                return await loop.run_in_executor(
                    self.executor,
                    self._analyze_with_textblob,
                    text
                )
            else:
                raise RuntimeError(f"Unknown backend: {self.backend}")

        except Exception as e:
            self.logger.error(f"Sentiment analysis failed for text: {text[:100]}... Error: {e}")
            raise
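
    # Example (illustrative), from inside an async function:
    #
    #     analyzer = SentimentAnalyzer(backend="textblob")
    #     result = await analyzer.analyze("I love this product!")
    #     result.label       # e.g. SentimentLabel.POSITIVE
    #     result.confidence  # float in [0.0, 1.0]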

    async def analyze_batch(self, texts: list[str]) -> list[SentimentResult]:
        """
        Analyze sentiment for multiple texts concurrently.

        Args:
            texts: List of texts to analyze

        Returns:
            List of SentimentResult objects, in the same order as the input
        """
        if not texts:
            return []

        # Run all analyses concurrently and capture per-item failures instead
        # of aborting the whole batch.
        tasks = [self.analyze(text) for text in texts]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Replace failures with neutral placeholder results so the output
        # stays aligned with the input.
        processed_results = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                self.logger.error(f"Failed to analyze text {i}: {result}")
                processed_results.append(
                    SentimentResult(SentimentLabel.NEUTRAL, 0.0, {"error": str(result)})
                )
            else:
                processed_results.append(result)

        return processed_results
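
    # Example (illustrative), from inside an async function:
    #
    #     results = await analyzer.analyze_batch(["Great!", "Awful.", ""])
    #     # The empty string fails validation, so its slot comes back as a
    #     # NEUTRAL result with confidence 0.0 and an "error" entry in raw_scores.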

    def get_info(self) -> Dict[str, Any]:
        """Get information about the analyzer configuration."""
        return {
            "backend": self.backend,
            "model_name": self.model_name if self.backend == "transformers" else None,
            "model_loaded": self._model_loaded,
            "textblob_available": TEXTBLOB_AVAILABLE,
            "transformers_available": TRANSFORMERS_AVAILABLE,
            "cuda_available": torch.cuda.is_available() if TRANSFORMERS_AVAILABLE else False
        }

    async def cleanup(self) -> None:
        """Clean up resources."""
        self.executor.shutdown(wait=True)
        self.logger.info("Sentiment analyzer cleaned up")


# Global analyzer instance, created lazily by get_analyzer().
_global_analyzer: Optional[SentimentAnalyzer] = None


async def get_analyzer(backend: str = "auto") -> SentimentAnalyzer:
    """
    Get or create the global sentiment analyzer instance.

    Note:
        The backend argument only takes effect on the first call; later calls
        return the existing instance regardless of the value passed.

    Args:
        backend: Analysis backend to use

    Returns:
        SentimentAnalyzer instance
    """
    global _global_analyzer

    if _global_analyzer is None:
        _global_analyzer = SentimentAnalyzer(backend=backend)

    return _global_analyzer


async def analyze_sentiment(text: str, backend: str = "auto") -> Dict[str, Any]:
    """
    Convenience function for sentiment analysis.

    Args:
        text: Text to analyze
        backend: Analysis backend to use

    Returns:
        Dictionary with sentiment analysis results
    """
    analyzer = await get_analyzer(backend)
    result = await analyzer.analyze(text)
    return result.to_dict()
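
# Example (illustrative) of the convenience API from an async context:
#
#     payload = await analyze_sentiment("Great service, would recommend!")
#     # -> {"label": ..., "confidence": ..., "raw_scores": {...}}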


if __name__ == "__main__":

    async def main():
        analyzer = SentimentAnalyzer(backend="textblob")

        test_texts = [
            "I love this product! It's amazing!",
            "This is terrible and I hate it.",
            "It's okay, nothing special.",
            "The weather is nice today."
        ]

        for text in test_texts:
            result = await analyzer.analyze(text)
            print(f"Text: {text}")
            print(f"Result: {result.to_dict()}")
            print("-" * 50)

        await analyzer.cleanup()

    asyncio.run(main())