jhansss committed on
Commit 9712d04 · 1 Parent(s): bb840e1

Refactor ASR module; add Paraformer support

config/interface/options.yaml CHANGED
@@ -7,10 +7,8 @@ asr_models:
     name: Whisper medium
   - id: openai/whisper-small
     name: Whisper small
-  - id: sanchit-gandhi/whisper-small-dv
-    name: Whisper small-dv
-  - id: facebook/wav2vec2-base-960h
-    name: Wav2Vec2-Base-960h
+  - id: funasr/paraformer-zh
+    name: Paraformer-zh
 
 llm_models:
   - id: gemini-2.5-flash
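Each `id` in this list doubles as the lookup key for the ASR registry added below, so a new backend needs only a config entry plus a registered wrapper. A rough sketch of how a caller might go from config to model (the loading code here is illustrative, not part of this commit):

```python
import yaml

from modules.asr import get_asr_model

# Illustrative only: read the interface options and build the chosen model.
with open("config/interface/options.yaml") as f:
    options = yaml.safe_load(f)

model_id = options["asr_models"][-1]["id"]  # e.g. "funasr/paraformer-zh"
asr_model = get_asr_model(model_id, device="cpu", cache_dir="cache")
```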
modules/asr/__init__.py ADDED
@@ -0,0 +1,15 @@
+from .base import AbstractASRModel
+from .registry import ASR_MODEL_REGISTRY, get_asr_model, register_asr_model
+
+# Importing the wrappers also runs their @register_asr_model decorators.
+from .whisper import WhisperASR
+from .paraformer import ParaformerASR
+
+__all__ = [
+    "AbstractASRModel",
+    "WhisperASR",
+    "ParaformerASR",
+    "get_asr_model",
+    "register_asr_model",
+    "ASR_MODEL_REGISTRY",
+]
modules/asr/base.py ADDED
@@ -0,0 +1,24 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import numpy as np
+
+
+class AbstractASRModel(ABC):
+    def __init__(
+        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
+    ):
+        print(f"Loading ASR model {model_id}...")
+        self.model_id = model_id
+        self.device = device
+        self.cache_dir = cache_dir
+
+    @abstractmethod
+    def transcribe(
+        self,
+        audio: np.ndarray,
+        audio_sample_rate: int,
+        language: Optional[str] = None,
+        **kwargs,
+    ) -> str:
+        pass
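A new backend only has to subclass this interface and implement `transcribe`; a minimal sketch with a hypothetical dummy wrapper (not part of this commit):

```python
from typing import Optional

import numpy as np

from modules.asr.base import AbstractASRModel
from modules.asr.registry import register_asr_model


# Hypothetical backend showing the required surface area.
@register_asr_model("example/dummy")
class DummyASR(AbstractASRModel):
    def transcribe(
        self,
        audio: np.ndarray,
        audio_sample_rate: int,
        language: Optional[str] = None,
        **kwargs,
    ) -> str:
        # A real backend would run inference here.
        return f"[{len(audio) / audio_sample_rate:.1f}s of audio]"
```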
modules/asr/paraformer.py ADDED
@@ -0,0 +1,87 @@
+import os
+import tempfile
+from typing import Optional
+
+import numpy as np
+import soundfile as sf
+
+try:
+    from funasr import AutoModel
+except ImportError:
+    AutoModel = None
+
+from .base import AbstractASRModel
+from .registry import register_asr_model
+
+
+@register_asr_model("funasr/paraformer-zh")
+class ParaformerASR(AbstractASRModel):
+    def __init__(
+        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
+    ):
+        super().__init__(model_id, device, cache_dir, **kwargs)
+
+        if AutoModel is None:
+            raise ImportError(
+                "funasr is not installed. Please install it with: pip3 install -U funasr"
+            )
+
+        model_name = model_id.replace("funasr/", "")
+        language = model_name.split("-")[-1]
+        if language == "zh":
+            self.language = "mandarin"
+        elif language == "en":
+            self.language = "english"
+        else:
+            raise ValueError(
+                f"Language cannot be determined from {model_id}; only -zh and -en are supported"
+            )
+
+        # FunASR downloads weights through ModelScope; point its cache at
+        # cache_dir and restore the original setting even if loading fails.
+        original_cache_dir = os.getenv("MODELSCOPE_CACHE")
+        os.makedirs(cache_dir, exist_ok=True)
+        os.environ["MODELSCOPE_CACHE"] = cache_dir
+        try:
+            self.model = AutoModel(
+                model=model_name,
+                model_revision="v2.0.4",
+                vad_model="fsmn-vad",
+                vad_model_revision="v2.0.4",
+                punc_model="ct-punc-c",
+                punc_model_revision="v2.0.4",
+                device=device,
+            )
+        except Exception as e:
+            raise ValueError(f"Error loading Paraformer model: {e}") from e
+        finally:
+            if original_cache_dir:
+                os.environ["MODELSCOPE_CACHE"] = original_cache_dir
+            else:
+                os.environ.pop("MODELSCOPE_CACHE", None)
+
+    def transcribe(
+        self,
+        audio: np.ndarray,
+        audio_sample_rate: int,
+        language: Optional[str] = None,
+        **kwargs,
+    ) -> str:
+        if language and language != self.language:
+            raise ValueError(
+                f"Paraformer model {self.model_id} only supports {self.language}, but {language} was requested"
+            )
+
+        # FunASR expects a file path, so write the audio to a temporary WAV.
+        temp_file = None
+        try:
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+                sf.write(f.name, audio, audio_sample_rate)
+                temp_file = f.name
+            result = self.model.generate(input=temp_file, batch_size_s=300, **kwargs)
+            return result[0]["text"]
+        except Exception as e:
+            raise ValueError(f"Error during transcription: {e}") from e
+        finally:
+            if temp_file:
+                os.unlink(temp_file)
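The `MODELSCOPE_CACHE` save-and-restore in `__init__` is a generic pattern; if other backends end up needing it, it could be factored into a small context manager, sketched here (a hypothetical helper, not in this commit):

```python
import os
from contextlib import contextmanager


@contextmanager
def scoped_env(name: str, value: str):
    """Temporarily set an environment variable, restoring it on exit."""
    original = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if original is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = original


# Usage in ParaformerASR.__init__ would then shrink to:
#     with scoped_env("MODELSCOPE_CACHE", cache_dir):
#         self.model = AutoModel(...)
```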
modules/asr/registry.py ADDED
@@ -0,0 +1,19 @@
+from .base import AbstractASRModel
+
+ASR_MODEL_REGISTRY = {}
+
+
+def register_asr_model(prefix: str):
+    def wrapper(cls):
+        assert issubclass(cls, AbstractASRModel), f"{cls} must inherit AbstractASRModel"
+        ASR_MODEL_REGISTRY[prefix] = cls
+        return cls
+
+    return wrapper
+
+
+def get_asr_model(model_id: str, device="cpu", **kwargs) -> AbstractASRModel:
+    for prefix, cls in ASR_MODEL_REGISTRY.items():
+        if model_id.startswith(prefix):
+            return cls(model_id, device=device, **kwargs)
+    raise ValueError(f"No ASR wrapper found for model: {model_id}")
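Lookup is by prefix, so one registration covers a whole model family. A quick illustration of the resolution behavior (the unregistered ID is hypothetical):

```python
from modules.asr import ASR_MODEL_REGISTRY, get_asr_model

# Importing modules.asr registers both wrappers:
print(ASR_MODEL_REGISTRY)
# {"openai/whisper": WhisperASR, "funasr/paraformer-zh": ParaformerASR}

# Any "openai/whisper-*" checkpoint matches the "openai/whisper" prefix, so
# get_asr_model("openai/whisper-large-v3-turbo") needs no extra registration.

# Unregistered IDs fail loudly instead of falling back silently:
try:
    get_asr_model("nvidia/parakeet-ctc")  # hypothetical unregistered ID
except ValueError as e:
    print(e)  # No ASR wrapper found for model: nvidia/parakeet-ctc
```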
modules/asr/whisper.py ADDED
@@ -0,0 +1,81 @@
+import os
+from typing import Optional
+
+import librosa
+import numpy as np
+from transformers.pipelines import pipeline
+
+from .base import AbstractASRModel
+from .registry import register_asr_model
+
+hf_token = os.getenv("HF_TOKEN")
+
+
+@register_asr_model("openai/whisper")
+class WhisperASR(AbstractASRModel):
+    def __init__(
+        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
+    ):
+        super().__init__(model_id, device, cache_dir, **kwargs)
+        # Route Hugging Face downloads through the shared cache directory.
+        model_kwargs = kwargs.setdefault("model_kwargs", {})
+        model_kwargs["cache_dir"] = cache_dir
+        self.pipe = pipeline(
+            "automatic-speech-recognition",
+            model=model_id,
+            device=0 if device == "cuda" else -1,
+            token=hf_token,
+            **kwargs,
+        )
+
+    def transcribe(
+        self,
+        audio: np.ndarray,
+        audio_sample_rate: int,
+        language: Optional[str] = None,
+        **kwargs,
+    ) -> str:
+        """
+        Transcribe audio using a Whisper model.
+
+        Args:
+            audio: Audio as a numpy array
+            audio_sample_rate: Sample rate of the audio
+            language: Language hint (optional)
+
+        Returns:
+            Transcribed text as a string
+        """
+        try:
+            # Whisper expects 16 kHz input; resample if needed
+            if audio_sample_rate != 16000:
+                audio = librosa.resample(
+                    audio, orig_sr=audio_sample_rate, target_sr=16000
+                )
+
+            # Pass the language hint through to generation
+            generate_kwargs = {}
+            if language:
+                generate_kwargs["language"] = language
+
+            result = self.pipe(
+                audio,
+                generate_kwargs=generate_kwargs,
+                return_timestamps=False,
+                **kwargs,
+            )
+
+            # The pipeline normally returns {"text": ...}; handle lists defensively
+            if isinstance(result, dict) and "text" in result:
+                return result["text"]
+            elif isinstance(result, list) and len(result) > 0:
+                first_result = result[0]
+                if isinstance(first_result, dict):
+                    return first_result.get("text", str(first_result))
+                else:
+                    return str(first_result)
+            else:
+                return str(result)
+
+        except Exception as e:
+            raise ValueError(f"Error during Whisper transcription: {e}") from e
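The pipeline here receives a bare numpy array, which transformers treats as already being at the model's 16 kHz rate, hence the manual `librosa.resample`. Worth noting: the ASR pipeline also accepts a dict carrying the true sampling rate and resamples internally; a sketch of that alternative (assuming a recent transformers version):

```python
import numpy as np
from transformers.pipelines import pipeline

pipe = pipeline("automatic-speech-recognition", model="openai/whisper-small")

# One second of silence at 44.1 kHz; the dict form carries the real rate,
# so no manual resampling to 16 kHz is needed.
audio = np.zeros(44100, dtype=np.float32)
result = pipe({"raw": audio, "sampling_rate": 44100}, return_timestamps=False)
print(result["text"])
```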
tests/test_asr_infer.py ADDED
@@ -0,0 +1,18 @@
+import librosa
+
+from modules.asr import get_asr_model
+
+if __name__ == "__main__":
+    supported_asrs = [
+        "funasr/paraformer-zh",
+        "openai/whisper-large-v3-turbo",
+    ]
+    for model_id in supported_asrs:
+        try:
+            print(f"Loading model: {model_id}")
+            asr_model = get_asr_model(model_id, device="cpu", cache_dir=".cache")
+            audio, sample_rate = librosa.load("tests/audio/hello.wav", sr=None)
+            result = asr_model.transcribe(audio, sample_rate, language="mandarin")
+            print(result)
+        except Exception as e:
+            print(f"Failed to run model {model_id}: {e}")