import torch


class WhisperMixin:
    """Mixin that adds Whisper transcription and embedding extraction to an
    AudioSignal-like host class. The host is expected to provide the `to`,
    `device`, `clone`, `resample`, and `audio_data` members used below.
    """

    is_initialized = False

    def setup_whisper(
        self,
        pretrained_model_name_or_path: str = "openai/whisper-base.en",
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):
        # Imported lazily so transformers stays an optional dependency.
        from transformers import WhisperForConditionalGeneration
        from transformers import WhisperProcessor

        self.whisper_device = device
        self.whisper_processor = WhisperProcessor.from_pretrained(
            pretrained_model_name_or_path
        )
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path
        ).to(self.whisper_device)
        self.is_initialized = True
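    # Example (sketch, not from the original file): setup_whisper can be
    # called explicitly to choose a different checkpoint or device, e.g.
    #   signal.setup_whisper("openai/whisper-small.en", device=torch.device("cpu"))
    # Otherwise, the methods below call it lazily with the defaults above.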
    def get_whisper_features(self) -> torch.Tensor:
        """Preprocess the audio signal as per the Whisper model's training config.

        Returns
        -------
        torch.Tensor
            The preprocessed input features of the audio signal.
            Shape: (batch, n_mels, seq_len)
        """
        import torch

        if not self.is_initialized:
            self.setup_whisper()

        signal = self.to(self.device)
        # Resample to the processor's expected rate, keep the first channel,
        # and pass the batch as a list of 1-D numpy arrays. Move to CPU
        # first, since .numpy() raises on CUDA tensors.
        raw_speech = list(
            signal.clone()
            .resample(self.whisper_processor.feature_extractor.sampling_rate)
            .audio_data[:, 0, :]
            .cpu()
            .numpy()
        )

        with torch.inference_mode():
            input_features = self.whisper_processor(
                raw_speech,
                sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
                return_tensors="pt",
            ).input_features

        return input_features
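    # Note (property of the pretrained feature extractor, not of this mixin):
    # for "openai/whisper-base.en" the processor pads/truncates each clip to
    # 30 s of 16 kHz audio, so the returned features are (batch, 80, 3000),
    # i.e. 80 log-mel bins over 3000 frames.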
    def get_whisper_transcript(self) -> str:
        """Get the transcript of the audio signal using the Whisper model.

        Returns
        -------
        str
            The transcript of the audio signal, including special tokens
            such as <|startoftranscript|> and <|endoftext|>.
        """
        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()

        with torch.inference_mode():
            input_features = input_features.to(self.whisper_device)
            generated_ids = self.whisper_model.generate(inputs=input_features)

        transcription = self.whisper_processor.batch_decode(generated_ids)
        return transcription[0]
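    # Note (assumption about the transformers API, not original code): pass
    # skip_special_tokens=True to batch_decode above to strip markers such as
    # <|startoftranscript|> when plain text is wanted.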
    def get_whisper_embeddings(self) -> torch.Tensor:
        """Get the last-hidden-state embeddings of the audio signal from the Whisper encoder.

        Returns
        -------
        torch.Tensor
            The Whisper embeddings of the audio signal.
            Shape: (batch, seq_len, hidden_size)
        """
        import torch

        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()
        encoder = self.whisper_model.get_encoder()

        with torch.inference_mode():
            input_features = input_features.to(self.whisper_device)
            embeddings = encoder(input_features)

        return embeddings.last_hidden_state
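

# Usage sketch (assumption, not part of the original file): in
# descript-audiotools this mixin is composed into AudioSignal, which supplies
# the `to`, `device`, `clone`, `resample`, and `audio_data` members used
# above. Any AudioSignal-like host with those members works the same way.
if __name__ == "__main__":
    from audiotools import AudioSignal

    signal = AudioSignal("speech.wav")  # hypothetical input file
    print(signal.get_whisper_transcript())  # includes special tokens
    embeddings = signal.get_whisper_embeddings()
    print(embeddings.shape)  # (batch, seq_len, hidden_size)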