File size: 2,735 Bytes
71192d1
ce2ce69
73a6a7e
 
 
ce2ce69
73a6a7e
71192d1
 
 
73a6a7e
 
 
 
 
 
 
 
 
71192d1
73a6a7e
71192d1
73a6a7e
71192d1
ce2ce69
 
73a6a7e
ce2ce69
73a6a7e
 
ce2ce69
 
 
73a6a7e
ce2ce69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73a6a7e
 
ce2ce69
73a6a7e
 
ce2ce69
71192d1
73a6a7e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import logging
import torch
from app.services.base import load_hf_pipeline
from app.core.config import APP_NAME, settings
from app.core.exceptions import ServiceError, ModelNotDownloadedError

# Module logger, namespaced under the application's root logger name.
logger = logging.getLogger("{}.services.tone_classification".format(APP_NAME))

class ToneClassifier:
    """Predict the dominant emotional tone of a text with a Hugging Face pipeline.

    The ``text-classification`` pipeline for ``settings.TONE_MODEL_ID`` is built
    lazily on first use and cached on the instance. A prediction whose top score
    is below ``settings.TONE_CONFIDENCE_THRESHOLD`` is reported as ``"neutral"``.
    """

    def __init__(self):
        # Cached pipeline; stays None until the first _get_classifier() call.
        self._classifier = None

    def _get_classifier(self):
        """Return the tone pipeline, creating it via ``load_hf_pipeline`` on first call."""
        if self._classifier is None:
            self._classifier = load_hf_pipeline(
                model_id=settings.TONE_MODEL_ID,
                task="text-classification",
                feature_name="Tone Classification",
                # Presumably requests a score for every label rather than only
                # the best one -- confirm against load_hf_pipeline.
                top_k=None,
            )
        return self._classifier

    @staticmethod
    def _extract_scores(raw_results) -> list:
        """Return the list of ``{"label": ..., "score": ...}`` entries for one input.

        Accepts both the nested shape ``[[{...}, ...]]`` and the flat shape
        ``[{...}, ...]``; a text-classification pipeline may return either for
        a single string depending on library version and how ``top_k`` is passed.

        Raises:
            ServiceError: 500 if the output has neither shape, holds no entries,
                or an entry is not a dict with a ``"score"`` key.
        """
        scores = None
        if isinstance(raw_results, list) and raw_results:
            if isinstance(raw_results[0], list):
                scores = raw_results[0]
            elif isinstance(raw_results[0], dict):
                scores = raw_results
        if not scores or not all(isinstance(entry, dict) and "score" in entry for entry in scores):
            logger.error("Unexpected raw_results format from pipeline: %s", raw_results)
            raise ServiceError(status_code=500, detail="Unexpected model output format for tone classification.")
        return scores

    async def classify(self, text: str) -> dict:
        """Classify the tone of ``text``.

        Args:
            text: Input text; surrounding whitespace is stripped. ``None`` is
                treated as empty.

        Returns:
            ``{"tone": label}`` with the highest-scoring label when its score is
            at least ``settings.TONE_CONFIDENCE_THRESHOLD``, otherwise
            ``{"tone": "neutral"}``.

        Raises:
            ServiceError: 400 if the text is empty or whitespace-only; 500 if the
                model output is malformed or any other error occurs.
        """
        try:
            text = (text or "").strip()
            if not text:
                raise ServiceError(status_code=400, detail="Input text is empty for tone classification.")

            # NOTE(review): the pipeline call is synchronous, so it blocks the
            # event loop for the whole inference despite `async def`; consider
            # offloading it (e.g. asyncio.to_thread) if concurrency matters.
            raw_results = self._get_classifier()(text)
            sorted_emotions = sorted(
                self._extract_scores(raw_results),
                key=lambda entry: entry["score"],
                reverse=True,
            )

            # Only build the per-label dump when DEBUG output will actually be emitted.
            # NOTE(review): this logs the full input text; confirm that is acceptable for PII.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug("Input Text: '%s'", text)
                logger.debug("--- Emotion Scores (Label: Score) ---")
                for emotion in sorted_emotions:
                    logger.debug("  %s: %.4f", emotion.get("label", "Unknown"), emotion["score"])
                logger.debug("-------------------------------------")

            top_emotion = sorted_emotions[0]
            predicted_label = top_emotion.get("label", "Unknown")
            predicted_score = top_emotion["score"]
            threshold = settings.TONE_CONFIDENCE_THRESHOLD

            if predicted_score >= threshold:
                logger.info(
                    "Final prediction for '%s...': '%s' (Score: %.4f, Above Threshold: %.2f)",
                    text[:50], predicted_label, predicted_score, threshold,
                )
                return {"tone": predicted_label}

            logger.info(
                "Final prediction for '%s...': 'neutral' (Top Score: %.4f, Below Threshold: %.2f).",
                text[:50], predicted_score, threshold,
            )
            return {"tone": "neutral"}

        except ServiceError:
            # Already carries the intended status/detail (e.g. the 400 for empty
            # input); re-wrapping it below would turn client errors into 500s.
            # NOTE(review): ModelNotDownloadedError is imported by this module but
            # unused; if load_hf_pipeline raises it and callers expect it unwrapped,
            # re-raise it here as well.
            raise
        except Exception as e:
            # str() keeps the preview safe even if `text` was not a string.
            logger.exception("Tone classification unexpected error for text '%s...': %s", str(text)[:50], e)
            raise ServiceError(status_code=500, detail="An internal error occurred during tone classification.") from e