File size: 2,434 Bytes
9712d04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780954b
9712d04
 
 
 
 
 
 
780954b
9712d04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import logging
import os
from typing import Optional

import librosa
import numpy as np
from transformers.pipelines import pipeline

from .base import AbstractASRModel
from .registry import register_asr_model

# Optional Hugging Face auth token read from the HF_TOKEN env var; None if unset.
hf_token = os.getenv("HF_TOKEN")


@register_asr_model("openai/whisper")
class WhisperASR(AbstractASRModel):
    """ASR model wrapping the Hugging Face automatic-speech-recognition
    pipeline for OpenAI Whisper checkpoints."""

    def __init__(
        self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
    ):
        """Build the transformers ASR pipeline for ``model_id``.

        Args:
            model_id: Hugging Face model identifier (e.g. "openai/whisper-base").
            device: Device map forwarded to the pipeline ("auto" by default).
            cache_dir: Directory used to cache downloaded model weights.
            **kwargs: Extra arguments forwarded to ``pipeline(...)``; an inner
                "model_kwargs" dict, if present, is extended with ``cache_dir``.
        """
        super().__init__(model_id, device, cache_dir, **kwargs)
        # Copy the caller's model_kwargs before injecting cache_dir so we never
        # mutate a dict the caller still owns (the original setdefault-based
        # code wrote into the caller's nested dict in place).
        model_kwargs = dict(kwargs.get("model_kwargs") or {})
        model_kwargs["cache_dir"] = cache_dir
        kwargs["model_kwargs"] = model_kwargs
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model_id,
            device_map=device,
            token=hf_token,
            **kwargs,
        )

    def transcribe(
        self,
        audio: np.ndarray,
        audio_sample_rate: int,
        language: Optional[str] = None,
        **kwargs,
    ) -> str:
        """
        Transcribe audio using the Whisper pipeline.

        Args:
            audio: Audio samples as a numpy array.
            audio_sample_rate: Sample rate of ``audio`` in Hz.
            language: Optional language hint forwarded to generation.
            **kwargs: Extra arguments forwarded to the pipeline call.

        Returns:
            Transcribed text, or "" if transcription fails (best-effort:
            errors are logged, never raised to the caller).
        """
        try:
            # Whisper pipelines expect 16 kHz input; resample if needed.
            if audio_sample_rate != 16000:
                audio = librosa.resample(
                    audio, orig_sr=audio_sample_rate, target_sr=16000
                )

            # Only pass a language hint when one was actually supplied.
            generate_kwargs = {}
            if language:
                generate_kwargs["language"] = language

            result = self.pipe(
                audio,
                generate_kwargs=generate_kwargs,
                return_timestamps=False,
                **kwargs,
            )

            return self._extract_text(result)

        except Exception:
            # Preserve the best-effort contract (return "" on failure) but log
            # the full traceback instead of print()-ing only the message.
            logging.getLogger(__name__).exception(
                "Error during Whisper transcription"
            )
            return ""

    @staticmethod
    def _extract_text(result) -> str:
        """Pull the transcription text out of the pipeline's result, which may
        be a dict with a "text" key, a list of such dicts, or something else
        entirely (stringified as a last resort)."""
        if isinstance(result, dict) and "text" in result:
            return result["text"]
        if isinstance(result, list) and result:
            first = result[0]
            if isinstance(first, dict):
                return first.get("text", str(first))
            return str(first)
        return str(result)