Spaces:
Runtime error
Runtime error
File size: 2,735 Bytes
71192d1 ce2ce69 73a6a7e ce2ce69 73a6a7e 71192d1 73a6a7e 71192d1 73a6a7e 71192d1 73a6a7e 71192d1 ce2ce69 73a6a7e ce2ce69 73a6a7e ce2ce69 73a6a7e ce2ce69 73a6a7e ce2ce69 73a6a7e ce2ce69 71192d1 73a6a7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import logging
import torch
from app.services.base import load_hf_pipeline
from app.core.config import APP_NAME, settings
from app.core.exceptions import ServiceError, ModelNotDownloadedError
# Module logger. Its dotted name makes it a descendant of the APP_NAME logger,
# so records propagate to whatever handlers are configured on that logger.
logger = logging.getLogger("{}.services.tone_classification".format(APP_NAME))
class ToneClassifier:
    """Classify the emotional tone of text using a Hugging Face pipeline.

    The ``text-classification`` pipeline is created lazily on first use via
    ``load_hf_pipeline`` and cached on the instance.
    """

    def __init__(self):
        # Loaded on demand by _get_classifier(); None means "not loaded yet".
        self._classifier = None

    def _get_classifier(self):
        """Return the cached classification pipeline, loading it on first call."""
        if self._classifier is None:
            # top_k=None: presumably forwarded to transformers.pipeline, where it
            # requests a score for every label rather than only the best one.
            self._classifier = load_hf_pipeline(
                model_id=settings.TONE_MODEL_ID,
                task="text-classification",
                feature_name="Tone Classification",
                top_k=None,
            )
        return self._classifier

    @staticmethod
    def _preview(text, limit=50):
        """Return the first ``limit`` characters of ``text`` for log messages.

        Uses ``str()`` so that building a log message can never raise, even
        when ``text`` is not a string (e.g. ``None`` passed by a caller).
        """
        return str(text)[:limit]

    @staticmethod
    def _extract_scores(raw_results):
        """Extract the per-label score dicts for a single input from pipeline output.

        Accepts the nested shape ``[[{"label": ..., "score": ...}, ...]]`` as
        well as a flat ``[{"label": ..., "score": ...}, ...]`` list.

        Returns:
            The non-empty list of score dicts, or ``None`` when ``raw_results``
            matches neither shape, contains no scores, or has an entry that is
            not a dict with a ``"score"`` key.
        """
        if not (isinstance(raw_results, list) and raw_results):
            return None
        head = raw_results[0]
        if isinstance(head, list):
            scores = head
        elif isinstance(head, dict):
            scores = raw_results
        else:
            return None
        if not scores or not all(isinstance(s, dict) and "score" in s for s in scores):
            return None
        return scores

    async def classify(self, text: str) -> dict:
        """Return the dominant tone of ``text`` as ``{"tone": <label>}``.

        The highest-scoring label is returned when its score is at least
        ``settings.TONE_CONFIDENCE_THRESHOLD``. Otherwise the tone is reported
        as ``"neutral"``.

        Raises:
            ServiceError: status 400 if ``text`` is empty or whitespace-only;
                status 500 if the model output has an unexpected shape or any
                other unexpected error occurs.
            ModelNotDownloadedError: propagated unchanged if raised while
                loading or running the model.
        """
        try:
            text = text.strip()
            if not text:
                raise ServiceError(status_code=400, detail="Input text is empty for tone classification.")

            classifier = self._get_classifier()
            # NOTE(review): this pipeline call is synchronous and blocks the event
            # loop while the model runs; consider offloading it to a worker thread.
            raw_results = classifier(text)

            scores_for_text = self._extract_scores(raw_results)
            if scores_for_text is None:
                logger.error(f"Unexpected raw_results format from pipeline: {raw_results}")
                raise ServiceError(status_code=500, detail="Unexpected model output format for tone classification.")

            sorted_emotions = sorted(scores_for_text, key=lambda x: x['score'], reverse=True)

            # Skip formatting the per-label dump unless DEBUG output is enabled.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Input Text: '{text}'")
                logger.debug("--- Emotion Scores (Label: Score) ---")
                for emotion in sorted_emotions:
                    logger.debug(f" {emotion.get('label', 'Unknown')}: {emotion['score']:.4f}")
                logger.debug("-------------------------------------")

            top_emotion = sorted_emotions[0]
            predicted_label = top_emotion.get("label", "Unknown")
            predicted_score = top_emotion.get("score", 0.0)
            threshold = settings.TONE_CONFIDENCE_THRESHOLD
            preview = self._preview(text)

            if predicted_score >= threshold:
                logger.info(f"Final prediction for '{preview}...': '{predicted_label}' (Score: {predicted_score:.4f}, Above Threshold: {threshold:.2f})")
                return {"tone": predicted_label}

            logger.info(f"Final prediction for '{preview}...': 'neutral' (Top Score: {predicted_score:.4f}, Below Threshold: {threshold:.2f}).")
            return {"tone": "neutral"}
        except (ServiceError, ModelNotDownloadedError):
            # Deliberate errors (e.g. the 400 for empty input) must reach the
            # caller unchanged instead of being rewrapped as a generic 500.
            raise
        except Exception as e:
            logger.error(f"Tone classification unexpected error for text '{ToneClassifier._preview(text)}...': {e}", exc_info=True)
            raise ServiceError(status_code=500, detail="An internal error occurred during tone classification.") from e
|