File size: 2,844 Bytes
1c02c6e
 
47f397e
c2f577f
e7e4d1a
c2f577f
5cb8034
c2f577f
47f397e
 
 
 
 
 
a49896f
c2f577f
7b89327
 
 
 
 
 
 
 
 
 
 
 
 
47f397e
7b89327
 
 
47f397e
cc1edab
47f397e
7b89327
 
17b66cc
47f397e
a49896f
 
cc1edab
 
 
a49896f
 
cc1edab
47f397e
7b89327
 
 
 
 
47f397e
7b89327
 
17b66cc
7b89327
 
 
47f397e
7b89327
 
47f397e
c2f577f
7b89327
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr
from transformers import pipeline
from smolagents import Tool

class SentimentAnalysisTool(Tool):
    """Agent tool that scores the sentiment of a text with a Hugging Face model.

    Classifiers are built with ``transformers.pipeline`` the first time a
    model is requested and are then reused from a per-instance cache, so
    each model is only loaded once per tool instance.
    """

    name = "sentiment_analysis"
    description = "This tool analyses the sentiment of a given text."

    inputs = {
        "text": {
            "type": "string",
            "description": "The text to analyze for sentiment"
        }
    }
    # NOTE(review): confirm that "json" is an output type accepted by the
    # installed smolagents version; "object" may be the expected spelling for
    # dict results. Left unchanged here to keep the tool's declared interface.
    output_type = "json"

    # Available sentiment analysis models: short key -> Hugging Face model id.
    models = {
        "multilingual": "nlptown/bert-base-multilingual-uncased-sentiment",
        "deberta": "microsoft/deberta-xlarge-mnli",
        "distilbert": "distilbert-base-uncased-finetuned-sst-2-english",
        "mobilebert": "lordtt13/emo-mobilebert",
        "reviews": "juliensimon/reviews-sentiment-analysis",
        "sbc": "sbcBI/sentiment_analysis_model",
        "german": "oliverguhr/german-sentiment-bert"
    }

    def __init__(self, default_model="distilbert"):
        """Initialize the tool and pre-load the default model.

        Args:
            default_model: Key of ``models`` used whenever ``predict`` is
                called without a (known) ``model_key``.

        Raises:
            KeyError: If ``default_model`` is not a key of ``models``.
        """
        # Fail early, with the list of valid keys, before any model is loaded.
        if default_model not in self.models:
            raise KeyError(
                f"Unknown default_model {default_model!r}; "
                f"expected one of {sorted(self.models)}"
            )
        super().__init__()
        self.default_model = default_model
        # Cache of instantiated pipelines, keyed by Hugging Face model id.
        self._classifiers = {}
        # Pre-load the default model to speed up first inference.
        self._get_classifier(self.models[default_model])

    def forward(self, text: str):
        """Return the sentiment scores of ``text`` using the default model."""
        return self.predict(text)

    def _parse_output(self, output_json):
        """Flatten pipeline output into a ``{label: score}`` dictionary.

        Args:
            output_json: Pipeline result. The nested form
                ``[[{"label": ..., "score": ...}, ...]]`` is expected (only
                the first inner list is read); a flat list of such dicts is
                accepted as well.

        Returns:
            A dict mapping each label to its score; empty when the pipeline
            returned nothing.
        """
        if not output_json:
            return {}
        first = output_json[0]
        # A dict as first element means the result was not wrapped in an
        # outer list; otherwise use the scores of the first (only) input.
        entries = output_json if isinstance(first, dict) else first
        return {entry["label"]: entry["score"] for entry in entries}

    def _get_classifier(self, model_id):
        """Return the cached classifier for ``model_id``, creating it if needed.

        Args:
            model_id: Hugging Face model identifier passed to ``pipeline``.
        """
        if model_id not in self._classifiers:
            self._classifiers[model_id] = pipeline(
                "text-classification",
                model=model_id,
                top_k=None,  # Return every label's score (replaces return_all_scores=True)
            )
        return self._classifiers[model_id]

    def predict(self, text, model_key=None):
        """Score ``text`` with the requested model.

        Args:
            text: The text to classify.
            model_key: Key of ``models`` selecting the model. ``None`` or an
                unknown key falls back to the default model.

        Returns:
            A dict mapping each sentiment label to its score.
        """
        # Unknown keys deliberately fall back to the default model instead of
        # raising, matching the tool's best-effort behavior.
        model_id = self.models.get(model_key, self.models[self.default_model])
        classifier = self._get_classifier(model_id)
        return self._parse_output(classifier(text))

# Manual smoke test: runs only when this file is executed as a script.
if __name__ == "__main__":
    # Sample sentence to classify.
    sample = "I really enjoyed this product. It exceeded my expectations!"

    # Build the tool and call it directly on the sample.
    tool = SentimentAnalysisTool()
    scores = tool(sample)

    # Show the input next to the scores returned for it.
    print(f"Input: {sample}")
    print(f"Result: {scores}")