Spaces:
Runtime error
Runtime error
import logging | |
import torch | |
from app.services.base import load_hf_pipeline | |
from app.core.config import APP_NAME, settings | |
from app.core.exceptions import ServiceError, ModelNotDownloadedError | |
# Module-level logger, namespaced beneath the application's root logger name.
logger = logging.getLogger("{}.services.tone_classification".format(APP_NAME))
class ToneClassifier:
    """Classify the emotional tone of text with a lazily loaded Hugging Face pipeline.

    The pipeline is built on the first call to ``_get_classifier`` and then
    reused for the lifetime of the instance.
    """

    def __init__(self):
        # Cached pipeline object; stays None until _get_classifier() first runs.
        self._classifier = None

    def _get_classifier(self):
        """Return the cached text-classification pipeline, loading it on first use.

        ``top_k=None`` is passed to ``load_hf_pipeline``; presumably it is
        forwarded to transformers' ``pipeline`` so that scores for every label
        are returned rather than only the best one (confirm in load_hf_pipeline).
        """
        if self._classifier is None:
            self._classifier = load_hf_pipeline(
                model_id=settings.TONE_MODEL_ID,
                task="text-classification",
                feature_name="Tone Classification",
                top_k=None,
            )
        return self._classifier

    @staticmethod
    def _extract_scores(raw_results):
        """Return the list of label/score dicts for a single input, or None.

        For one input string the text-classification pipeline may return a
        nested list (``[[{"label": ..., "score": ...}, ...]]``) or a flat list
        (``[{"label": ..., "score": ...}, ...]``); the shape has varied with how
        ``top_k`` is supplied. Both shapes are accepted. Any other shape, or an
        empty list of scores, yields None.
        """
        if not isinstance(raw_results, list) or not raw_results:
            return None
        first = raw_results[0]
        if isinstance(first, list):
            scores = first
        elif isinstance(first, dict):
            scores = raw_results
        else:
            return None
        # An empty score list would make "top emotion" undefined.
        return scores if scores else None

    async def classify(self, text: str) -> dict:
        """Return the dominant tone of *text* as ``{"tone": <label>}``.

        The highest-scoring label is returned when its score is at least
        ``settings.TONE_CONFIDENCE_THRESHOLD``; otherwise the tone falls back
        to ``"neutral"``.

        Args:
            text: Input text; leading/trailing whitespace is stripped first.

        Returns:
            A dict with a single ``"tone"`` key.

        Raises:
            ServiceError: status 400 if the text is empty after stripping;
                status 500 if the model output has an unexpected format or
                any other unexpected error occurs.
            ModelNotDownloadedError: re-raised unchanged if it escapes model
                loading. NOTE(review): assumes callers handle it — confirm.
        """
        try:
            text = text.strip()
            if not text:
                raise ServiceError(status_code=400, detail="Input text is empty for tone classification.")

            # NOTE(review): both the lazy model load and the inference call are
            # synchronous, so they block the event loop while this coroutine runs.
            classifier = self._get_classifier()
            raw_results = classifier(text)

            scores_for_text = self._extract_scores(raw_results)
            if scores_for_text is None:
                logger.error(f"Unexpected raw_results format from pipeline: {raw_results}")
                raise ServiceError(status_code=500, detail="Unexpected model output format for tone classification.")

            sorted_emotions = sorted(scores_for_text, key=lambda x: x['score'], reverse=True)

            # Only build the per-label breakdown when it will actually be emitted.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Input Text: '{text}'")
                logger.debug("--- Emotion Scores (Label: Score) ---")
                for emotion in sorted_emotions:
                    logger.debug(f" {emotion['label']}: {emotion['score']:.4f}")
                logger.debug("-------------------------------------")

            top_emotion = sorted_emotions[0]
            predicted_label = top_emotion.get("label", "Unknown")
            predicted_score = top_emotion.get("score", 0.0)

            if predicted_score >= settings.TONE_CONFIDENCE_THRESHOLD:
                logger.info(f"Final prediction for '{text[:50]}...': '{predicted_label}' (Score: {predicted_score:.4f}, Above Threshold: {settings.TONE_CONFIDENCE_THRESHOLD:.2f})")
                return {"tone": predicted_label}

            # Low-confidence predictions are reported as "neutral".
            logger.info(f"Final prediction for '{text[:50]}...': 'neutral' (Top Score: {predicted_score:.4f}, Below Threshold: {settings.TONE_CONFIDENCE_THRESHOLD:.2f}).")
            return {"tone": "neutral"}
        except (ServiceError, ModelNotDownloadedError):
            # Errors raised deliberately above (e.g. the 400 for empty input) must
            # reach the caller unchanged instead of being rewrapped as a generic 500.
            raise
        except Exception as e:
            # text may not be a str (e.g. None); avoid a second error while logging.
            preview = text[:50] if isinstance(text, str) else repr(text)[:50]
            logger.error(f"Tone classification unexpected error for text '{preview}...': {e}", exc_info=True)
            raise ServiceError(status_code=500, detail="An internal error occurred during tone classification.") from e