import logging
import numpy as np
from scipy import signal
import torch
from transformers.models.auto.modeling_auto import AutoModelForSpeechSeq2Seq
from transformers.models.auto.processing_auto import AutoProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
from pydub import AudioSegment

from transcription import resample_audio

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)


def get_device():
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def transcribe_audio(audio_path: str) -> str:
    """
    Transcribe an audio file using Whisper.

    Args:
        audio_path: Path to the audio file

    Returns:
        Transcribed text
    """
    # Setup device and model
    device = get_device()
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}")
    attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"

    # Load model and processor
    stt_model_name = "openai/whisper-large-v2"
    try:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            stt_model_name,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation=attention
        )
        model.to(device)
    except Exception as e:
        logger.error(f"Error loading ASR model: {e}")
        raise

    processor = AutoProcessor.from_pretrained(stt_model_name)

    # Add custom vocabulary
    # custom_vocab = [
    #     "LuxDev", "luxdev", "LUXDEV",
    #     "Sasan", "sasan", "SASAN",
    #     "Jafarnejad", "jafarnejad", "JAFARNEJAD",
    #     "Sasan Jafarnejad", "sasan jafarnejad", "SASAN JAFARNEJAD"
    # ]
    # processor.tokenizer.add_tokens(custom_vocab)
    # model.resize_token_embeddings(len(processor.tokenizer))

    # Create a custom prompt to guide the model
    initial_prompt = "LuxDev, Sasan, Jafarnejad, LUXDEV"
    prompt_ids = processor.get_prompt_ids(initial_prompt, return_tensors="pt").to(device)

    # Load and process audio
    try:
        # Read audio file using pydub
        audio = AudioSegment.from_file(audio_path)

        # Convert to mono if stereo
        if audio.channels > 1:
            audio = audio.set_channels(1)

        # Convert to numpy array
        audio_array = np.array(audio.get_array_of_samples())
        sample_rate = audio.frame_rate

        # Convert to float32 and normalize (16-bit or 32-bit PCM)
        audio_array = audio_array.astype(np.float32) / (2**15 if audio.sample_width == 2 else 2**31)

        # Resample to 16kHz if needed
        audio_array, sample_rate = resample_audio(audio_array, sample_rate)

        # Process audio input
        input_features = processor(
            audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).input_features
        input_features = input_features.to(device=device, dtype=torch_dtype)

        # Generate transcription with custom prompt
        predicted_ids = model.generate(
            input_features,
            # task="transcribe",
            # language="english",
            max_length=448,
            num_beams=5,
            temperature=0.0,
            no_repeat_ngram_size=3,
            length_penalty=1.0,
            repetition_penalty=1.0,
            # Use the prompt tokens directly
            prompt_ids=prompt_ids
        )

        # Decode the transcription
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()

        return transcription

    except Exception as e:
        logger.error(f"Error processing audio: {e}")
        raise


if __name__ == "__main__":
    # Example usage
    audio_path = "sample.m4a"
    try:
        transcription = transcribe_audio(audio_path)
        print(f"Transcription: {transcription}")
    except Exception as e:
        print(f"Error: {e}")
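

# ---------------------------------------------------------------------------
# Note: `resample_audio` is imported from the local `transcription` module,
# which is not shown in this file. The function below is a hypothetical
# stand-in (deliberately named `_resample_audio_sketch` so it does not shadow
# the real import), assuming the helper resamples a mono float32 waveform to
# Whisper's expected 16 kHz, e.g. with scipy.signal.resample_poly.
# ---------------------------------------------------------------------------
def _resample_audio_sketch(audio_array: np.ndarray, sample_rate: int,
                           target_rate: int = 16000) -> tuple:
    """Hypothetical resampler: return (waveform, rate) at `target_rate` Hz."""
    if sample_rate == target_rate:
        return audio_array, sample_rate
    # Reduce the up/down ratio before polyphase resampling
    from math import gcd
    g = gcd(target_rate, sample_rate)
    resampled = signal.resample_poly(audio_array, target_rate // g, sample_rate // g)
    return resampled.astype(np.float32), target_rate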