import logging
import numpy as np
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.utils import is_flash_attn_2_available
from pydub import AudioSegment
from transcription import resample_audio
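# Note: `transcription` is a local module that is not shown in this file.
# Judging from the call site below, `resample_audio` is assumed to resample
# the signal to Whisper's expected 16 kHz and return (audio_array,
# sample_rate). A minimal sketch of such a helper, using
# scipy.signal.resample_poly, could look like:
#
#     from scipy import signal
#
#     def resample_audio(audio_array, sample_rate, target_rate=16000):
#         if sample_rate == target_rate:
#             return audio_array, sample_rate
#         resampled = signal.resample_poly(audio_array, target_rate, sample_rate)
#         return resampled.astype(np.float32), target_rate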
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

def get_device() -> str:
    """Return "cuda" when an NVIDIA GPU is available, otherwise "cpu".

    Note: on Apple-silicon machines one could additionally check
    torch.backends.mps.is_available(); this script only distinguishes
    CUDA from CPU.
    """
    return "cuda" if torch.cuda.is_available() else "cpu"

def transcribe_audio(audio_path: str) -> str:
    """
    Transcribe an audio file using Whisper.
    
    Args:
        audio_path: Path to the audio file
        
    Returns:
        Transcribed text
    """
    # Setup device and model
    device = get_device()
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}")
    attention = "flash_attention_2" if (device == "cuda" and is_flash_attn_2_available()) else "sdpa"
    
    # Load model and processor
    stt_model_name = "openai/whisper-large-v2"
    try:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            stt_model_name, 
            torch_dtype=torch_dtype, 
            low_cpu_mem_usage=True, 
            use_safetensors=True,
            attn_implementation=attention
        )
        model.to(device)
    except Exception as e:
        logger.error(f"Error loading ASR model: {e}")
        raise

    processor = AutoProcessor.from_pretrained(stt_model_name)
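    # The processor bundles Whisper's feature extractor (waveform -> log-mel
    # spectrogram) and tokenizer (token ids <-> text), both used below.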
    
    # Add custom vocabulary
    # custom_vocab = [
    #     "LuxDev", "luxdev", "LUXDEV",
    #     "Sasan", "sasan", "SASAN",
    #     "Jafarnejad", "jafarnejad", "JAFARNEJAD",
    #     "Sasan Jafarnejad", "sasan jafarnejad", "SASAN JAFARNEJAD"
    # ]
    # processor.tokenizer.add_tokens(custom_vocab)
    # model.resize_token_embeddings(len(processor.tokenizer))
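    # Note: tokens added this way would start with untrained embeddings, so
    # the vocabulary route generally requires fine-tuning; the prompt below
    # biases decoding without touching the model weights.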
    
    # Create a custom prompt to guide the model
    initial_prompt = "LuxDev, Sasan, Jafarnejad, LUXDEV"
    prompt_ids = processor.get_prompt_ids(initial_prompt, return_tensors="pt").to(device)
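    # Whisper treats these tokens as "previous context" (prepended after the
    # <|startofprev|> token), which nudges decoding toward the listed
    # spellings without any fine-tuning.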
    
    # Load and process audio
    try:
        # Read audio file using pydub
        audio = AudioSegment.from_file(audio_path)
        
        # Convert to mono if stereo
        if audio.channels > 1:
            audio = audio.set_channels(1)
        
        # Convert to numpy array
        audio_array = np.array(audio.get_array_of_samples())
        sample_rate = audio.frame_rate
        
        # Convert to float32 and normalize by the full-scale value for the
        # sample width (e.g. 2**15 for 16-bit audio)
        audio_array = audio_array.astype(np.float32) / (2 ** (8 * audio.sample_width - 1))
        
        # Resample to 16kHz if needed
        audio_array, sample_rate = resample_audio(audio_array, sample_rate)
        
        # Extract log-mel input features (note: the Whisper feature extractor
        # pads or truncates audio to 30-second windows)
        input_features = processor(
            audio_array, 
            sampling_rate=sample_rate, 
            return_tensors="pt"
        ).input_features
        input_features = input_features.to(device=device, dtype=torch_dtype)
        
        # Generate transcription with the custom prompt; beam search here is
        # deterministic, so no sampling temperature is set
        predicted_ids = model.generate(
            input_features,
            # task="transcribe",
            # language="english",
            max_length=448,
            num_beams=5,
            no_repeat_ngram_size=3,
            length_penalty=1.0,
            repetition_penalty=1.0,
            # Bias decoding with the prompt tokens
            prompt_ids=prompt_ids
        )
        
        # Decode the transcription. Note: depending on the transformers
        # version, the decoded text may include the prompt text itself, which
        # would then need to be stripped from the result.
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
        return transcription
        
    except Exception as e:
        logger.error(f"Error processing audio: {e}")
        raise

if __name__ == "__main__":
    # Example usage
    audio_path = "sample.m4a"
    try:
        transcription = transcribe_audio(audio_path)
        print(f"Transcription: {transcription}")
    except Exception as e:
        print(f"Error: {e}")