File size: 5,655 Bytes
fb17746
c402b7d
 
 
 
 
fb17746
c402b7d
f611cc3
c402b7d
2b79b9f
95e25f2
2b79b9f
95e25f2
 
2b79b9f
 
c402b7d
061fd19
f611cc3
c402b7d
 
 
 
 
 
 
 
 
95e25f2
2b79b9f
c402b7d
 
 
 
 
 
f611cc3
c402b7d
 
 
f611cc3
c402b7d
f611cc3
95e25f2
 
 
 
 
 
c402b7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b79b9f
061fd19
c402b7d
 
 
 
 
 
061fd19
 
c402b7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f611cc3
c402b7d
f611cc3
c402b7d
 
95e25f2
c402b7d
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Provides a robust, asynchronous SentimentAnalyzer class.

This module communicates with the Hugging Face Inference API to perform sentiment
analysis without requiring local model downloads. It's designed for use within
an asynchronous application like FastAPI.
"""
import asyncio
import logging
import os
# NOTE: Optional/Union (rather than PEP 604 `X | None`) keep this module
# compatible with Python 3.9.
from collections.abc import AsyncGenerator
from typing import TypedDict, Union, Optional

import httpx

# --- Configuration & Models ---

# Module-level logger; handlers/levels are inherited from the host application.
logger = logging.getLogger(__name__)

# Define the expected structure of a result payload for type hinting
class SentimentResult(TypedDict):
    """Schema of one analysis payload as published onto the result queue."""

    id: int
    text: str
    # Spelled with Union (not `str | float`) for Python 3.9 compatibility.
    result: dict[str, Union[str, float]]


# --- Main Class: SentimentAnalyzer ---
class SentimentAnalyzer:
    """
    Manages sentiment analysis requests to the Hugging Face Inference API.

    This class handles asynchronous API communication, manages a result queue for
    Server-Sent Events (SSE), and encapsulates all related state and logic.
    """

    HF_API_URL = "https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english"

    def __init__(self, client: httpx.AsyncClient, api_token: Optional[str] = None):
        """
        Initializes the SentimentAnalyzer.

        Args:
            client: An instance of httpx.AsyncClient for making API calls.
            api_token: The Hugging Face API token. Falls back to the
                HF_API_TOKEN environment variable when omitted.

        Raises:
            ValueError: If no token is supplied and HF_API_TOKEN is unset.
        """
        self.client = client
        self.api_token = api_token or os.getenv("HF_API_TOKEN")

        if not self.api_token:
            raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")

        self.headers = {"Authorization": f"Bearer {self.api_token}"}

        # Producer-consumer hand-off: compute_and_publish() is the producer,
        # stream_results() (one per SSE connection) is the consumer.
        self.result_queue: "asyncio.Queue[SentimentResult]" = asyncio.Queue()

    async def compute_and_publish(self, text: str, request_id: int) -> None:
        """
        Performs sentiment analysis via an external API and places the result
        into a queue for consumption by SSE streams.

        A payload is ALWAYS published, even on failure: errors surface through
        the "label"/"error" fields of the result rather than being raised.

        Args:
            text: The input text to analyze.
            request_id: A unique identifier for this request.
        """
        # Default payload; overwritten on success, or its "error" field is
        # refined by the matching except-clause on failure.
        analysis_result: dict[str, Union[str, float]] = {
            "label": "ERROR",
            "score": 0.0,
            "error": "Unknown failure",
        }
        try:
            response = await self.client.post(
                self.HF_API_URL,
                headers=self.headers,
                json={"inputs": text, "options": {"wait_for_model": True}},
                timeout=20.0,
            )
            response.raise_for_status()
            data = response.json()

            # The Inference API wraps results as a list containing a list of
            # {"label": ..., "score": ...} dicts; anything else is an error.
            if isinstance(data, list) and data and isinstance(data[0], list) and data[0]:
                res = data[0][0]
                # Index "label" (rather than .get) so a malformed entry raises
                # KeyError and is reported via the parsing-error branch instead
                # of silently publishing a payload with label=None.
                analysis_result = {"label": res["label"], "score": round(res.get("score", 0.0), 4)}
                logger.info("βœ… Sentiment computed for request #%d", request_id)
            else:
                raise ValueError(f"Unexpected API response format: {data}")

        except httpx.HTTPStatusError as e:
            error_msg = f"API returned status {e.response.status_code}"
            logger.error("❌ Sentiment API error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg
        except httpx.RequestError as e:
            error_msg = f"Network request failed: {e}"
            logger.error("❌ Sentiment network error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg
        except (ValueError, KeyError) as e:
            # json.JSONDecodeError is a ValueError subclass, so bad JSON bodies
            # land here as well.
            error_msg = f"Failed to parse API response: {e}"
            logger.error("❌ Sentiment parsing error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg

        # Always publish a result to the queue, even if it's an error state.
        payload: SentimentResult = {
            "id": request_id,
            "text": text,
            "result": analysis_result
        }
        await self.result_queue.put(payload)

    async def stream_results(self) -> AsyncGenerator[SentimentResult, None]:
        """
        An async generator that yields new results as they become available.

        This is the consumer half of the producer-consumer pattern; it runs
        until the surrounding task (or the generator itself) is cancelled.

        Yields:
            SentimentResult payloads in publication order.
        """
        while True:
            try:
                # Efficiently blocks until an item is available in the queue.
                result = await self.result_queue.get()
                # Mark the item done BEFORE yielding: the consumer may cancel
                # us mid-yield, which would otherwise leave the queue's
                # unfinished-task count permanently unbalanced.
                self.result_queue.task_done()
                yield result
            except asyncio.CancelledError:
                logger.info("Result stream has been cancelled.")
                break