stock_sentiment_analysisv1 / src /sentiment_analyzer.py
S6six's picture
Initial commit of stock sentiment analysis project
9719f08
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
import numpy as np
# Load the FinBERT tokenizer and sequence-classification model once, at import
# time, so every analyze_sentiment call below reuses the same objects.
# from_pretrained may download the model files the first time it's run.
_FINBERT_CHECKPOINT = "ProsusAI/finbert"
tokenizer = AutoTokenizer.from_pretrained(_FINBERT_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_FINBERT_CHECKPOINT)
def analyze_sentiment(text):
    """
    Analyzes the sentiment of a given text using the FinBERT model.

    Uses the module-level ``tokenizer`` and ``model``. The encoded input is
    moved to whatever device ``model`` currently lives on, so the function
    keeps working if the caller moves the model off the CPU.

    Args:
        text (str): The input text (e.g., news headline or description).
            It is tokenized with truncation at 512 tokens.

    Returns:
        tuple: A tuple containing:
            - sentiment_label (str): The most probable label, taken from
              ``model.config.id2label`` (e.g. 'positive', 'negative', 'neutral').
            - sentiment_score (float): The probability of the predicted label.
            - scores (dict): Mapping of every label to its probability.
        Returns (None, None, None) if the input is not a non-empty string or
        an error occurs. Errors are reported through ``logging`` together
        with their traceback.
    """
    if not isinstance(text, str) or not text.strip():
        return None, None, None  # Empty, whitespace-only, or non-string input

    try:
        # BatchEncoding.to() moves every tensor to the model's device, so a
        # GPU-resident model does not fail with a device mismatch.
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512, padding=True
        ).to(model.device)
        with torch.no_grad():  # Inference only: skip autograd bookkeeping
            logits = model(**inputs).logits

        # Softmax over the class dimension; [0] picks the single input in the
        # batch. tolist() yields plain Python floats and works on any device.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()

        id2label = model.config.id2label
        predicted_class_id = int(np.argmax(probabilities))  # first max on ties
        all_scores = {id2label[i]: p for i, p in enumerate(probabilities)}
        return id2label[predicted_class_id], probabilities[predicted_class_id], all_scores
    except Exception:
        # Deliberate best-effort boundary: the caller gets (None, None, None),
        # but the failure is logged with its traceback instead of being printed.
        import logging

        snippet = text if len(text) <= 50 else text[:50] + "..."
        logging.getLogger(__name__).exception(
            "Error during sentiment analysis for text: '%s'", snippet
        )
        return None, None, None
# Example usage: run this module directly to smoke-test analyze_sentiment on a
# few headlines, plus two inputs (empty string and None) it should reject.
if __name__ == '__main__':
    demo_inputs = (
        "Stocks rallied on positive economic news.",
        "The company reported a significant drop in profits.",
        "Market remains flat amid uncertainty.",
        "",  # Empty string test
        None,  # None test
    )
    print("--- Testing Sentiment Analysis ---")
    for sample in demo_inputs:
        label, score, breakdown = analyze_sentiment(sample)
        if not label:
            # A missing label means the input was rejected or analysis failed.
            print(f"Text: '{sample}' -> Invalid input or error during analysis.")
            continue
        print(f"Text: '{sample}'")
        print(f" Sentiment: {label} (Score: {score:.4f})")
        print(f" All Scores: {breakdown}")