# wellsaid/app/services/tone_classification.py
# Author: iamspruce
# Commit: "fixed the api" (73a6a7e)
import logging
import torch
from app.services.base import load_hf_pipeline
from app.core.config import APP_NAME, settings
from app.core.exceptions import ServiceError, ModelNotDownloadedError
logger = logging.getLogger(f"{APP_NAME}.services.tone_classification")
class ToneClassifier:
    """Classify the dominant tone of a text using a lazily loaded pipeline.

    The pipeline is created on first use through ``load_hf_pipeline`` and
    cached on the instance, so constructing a ``ToneClassifier`` is cheap.
    """

    def __init__(self):
        # Populated on the first call to _get_classifier().
        self._classifier = None

    def _get_classifier(self):
        """Return the cached text-classification pipeline, loading it if needed."""
        if self._classifier is None:
            self._classifier = load_hf_pipeline(
                model_id=settings.TONE_MODEL_ID,
                task="text-classification",
                feature_name="Tone Classification",
                # Presumably forwarded to the pipeline so that it returns a
                # score for every label; classify() expects a list of
                # per-label dicts for each input -- confirm in load_hf_pipeline.
                top_k=None,
            )
        return self._classifier

    async def classify(self, text: str) -> dict:
        """Return the predicted tone for *text*.

        Args:
            text: The text to classify. Leading/trailing whitespace is ignored.

        Returns:
            ``{"tone": <label>}`` where ``<label>`` is the top-scoring label
            when its score is at least ``settings.TONE_CONFIDENCE_THRESHOLD``,
            and ``"neutral"`` otherwise.

        Raises:
            ServiceError: With status 400 when *text* is not a non-empty
                string; with status 500 when the model output has an
                unexpected shape or any other unexpected error occurs.
            ModelNotDownloadedError: Propagated unchanged if raised while
                loading the pipeline (assumed to come from
                ``load_hf_pipeline`` -- confirm).
        """
        # Validate outside the try block so the client error is never
        # confused with an internal failure, and so the error handler below
        # can always slice ``text`` safely.
        if not isinstance(text, str) or not text.strip():
            raise ServiceError(status_code=400, detail="Input text is empty for tone classification.")
        text = text.strip()

        try:
            classifier = self._get_classifier()
            # NOTE(review): the pipeline is invoked inline (not awaited), so
            # it runs to completion inside this coroutine.
            raw_results = classifier(text)

            # Expected shape: [[{"label": ..., "score": ...}, ...]] with a
            # non-empty inner list for the single input text.
            well_formed = (
                isinstance(raw_results, list)
                and raw_results
                and isinstance(raw_results[0], list)
                and raw_results[0]
            )
            if not well_formed:
                logger.error("Unexpected raw_results format from pipeline: %s", raw_results)
                raise ServiceError(status_code=500, detail="Unexpected model output format for tone classification.")

            sorted_emotions = sorted(raw_results[0], key=lambda item: item["score"], reverse=True)

            # Skip the per-label formatting entirely unless DEBUG is enabled.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug("Input Text: '%s'", text)
                logger.debug("--- Emotion Scores (Label: Score) ---")
                for emotion in sorted_emotions:
                    logger.debug(" %s: %.4f", emotion["label"], emotion["score"])
                logger.debug("-------------------------------------")

            top_emotion = sorted_emotions[0]
            predicted_label = top_emotion.get("label", "Unknown")
            predicted_score = top_emotion.get("score", 0.0)
            threshold = settings.TONE_CONFIDENCE_THRESHOLD

            if predicted_score >= threshold:
                logger.info(
                    "Final prediction for '%s...': '%s' (Score: %.4f, Above Threshold: %.2f)",
                    text[:50], predicted_label, predicted_score, threshold,
                )
                return {"tone": predicted_label}

            logger.info(
                "Final prediction for '%s...': 'neutral' (Top Score: %.4f, Below Threshold: %.2f).",
                text[:50], predicted_score, threshold,
            )
            return {"tone": "neutral"}
        except (ServiceError, ModelNotDownloadedError):
            # Deliberately raised errors already carry the right status code
            # and detail; re-wrapping them as a generic 500 would hide them.
            raise
        except Exception as e:
            logger.error(
                "Tone classification unexpected error for text '%s...': %s",
                text[:50], e, exc_info=True,
            )
            raise ServiceError(status_code=500, detail="An internal error occurred during tone classification.") from e