Spaces:
Sleeping
Sleeping
| """ | |
| Provides a robust, asynchronous SentimentAnalyzer class. | |
| This module communicates with the Hugging Face Inference API to perform sentiment | |
| analysis without requiring local model downloads. It's designed for use within | |
| an asynchronous application like FastAPI. | |
| """ | |
import asyncio
import logging
import os
from collections.abc import AsyncIterator
from typing import TypedDict

import httpx
| # --- Configuration & Models --- | |
# Logger named after this module; used by the definitions below.
logger = logging.getLogger(__name__)
| # Define the expected structure of a result payload for type hinting | |
| class SentimentResult(TypedDict): | |
| id: int | |
| text: str | |
| result: dict[str, str | float] | |
| # --- Main Class: SentimentAnalyzer --- | |
| class SentimentAnalyzer: | |
| """ | |
| Manages sentiment analysis requests to the Hugging Face Inference API. | |
| This class handles asynchronous API communication, manages a result queue for | |
| Server-Sent Events (SSE), and encapsulates all related state and logic. | |
| """ | |
| HF_API_URL = "https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english" | |
| def __init__(self, client: httpx.AsyncClient, api_token: str | None = None): | |
| """ | |
| Initializes the SentimentAnalyzer. | |
| Args: | |
| client: An instance of httpx.AsyncClient for making API calls. | |
| api_token: The Hugging Face API token. | |
| """ | |
| self.client = client | |
| self.api_token = api_token or os.getenv("HF_API_TOKEN") | |
| if not self.api_token: | |
| raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.") | |
| self.headers = {"Authorization": f"Bearer {self.api_token}"} | |
| # A queue is the ideal structure for a producer-consumer pattern, | |
| # where the API endpoint is the producer and SSE streamers are consumers. | |
| self.result_queue: asyncio.Queue[SentimentResult] = asyncio.Queue() | |
| async def compute_and_publish(self, text: str, request_id: int) -> None: | |
| """ | |
| Performs sentiment analysis via an external API and places the result | |
| into a queue for consumption by SSE streams. | |
| Args: | |
| text: The input text to analyze. | |
| request_id: A unique identifier for this request. | |
| """ | |
| analysis_result = {"label": "ERROR", "score": 0.0, "error": "Unknown failure"} | |
| try: | |
| response = await self.client.post( | |
| self.HF_API_URL, | |
| headers=self.headers, | |
| json={"inputs": text, "options": {"wait_for_model": True}}, | |
| timeout=20.0 | |
| ) | |
| response.raise_for_status() | |
| data = response.json() | |
| # Validate the expected response structure from the Inference API | |
| if isinstance(data, list) and data and isinstance(data[0], list) and data[0]: | |
| # The model returns a list containing a list of results | |
| res = data[0][0] | |
| analysis_result = {"label": res.get("label"), "score": round(res.get("score", 0.0), 4)} | |
| logger.info("β Sentiment computed for request #%d", request_id) | |
| else: | |
| raise ValueError(f"Unexpected API response format: {data}") | |
| except httpx.HTTPStatusError as e: | |
| error_msg = f"API returned status {e.response.status_code}" | |
| logger.error("β Sentiment API error for request #%d: %s", request_id, error_msg) | |
| analysis_result["error"] = error_msg | |
| except httpx.RequestError as e: | |
| error_msg = f"Network request failed: {e}" | |
| logger.error("β Sentiment network error for request #%d: %s", request_id, error_msg) | |
| analysis_result["error"] = error_msg | |
| except (ValueError, KeyError) as e: | |
| error_msg = f"Failed to parse API response: {e}" | |
| logger.error("β Sentiment parsing error for request #%d: %s", request_id, error_msg) | |
| analysis_result["error"] = error_msg | |
| # Always publish a result to the queue, even if it's an error state | |
| payload: SentimentResult = { | |
| "id": request_id, | |
| "text": text, | |
| "result": analysis_result | |
| } | |
| await self.result_queue.put(payload) | |
| async def stream_results(self) -> SentimentResult: | |
| """ | |
| An async generator that yields new results as they become available. | |
| This is the consumer part of the pattern. | |
| """ | |
| while True: | |
| try: | |
| # This efficiently waits until an item is available in the queue | |
| result = await self.result_queue.get() | |
| yield result | |
| self.result_queue.task_done() | |
| except asyncio.CancelledError: | |
| logger.info("Result stream has been cancelled.") | |
| break |