Spaces:
Sleeping
Sleeping
import logging | |
import numpy as np | |
from scipy import signal | |
import torch | |
from transformers.models.auto.modeling_auto import AutoModelForSpeechSeq2Seq | |
from transformers.models.auto.processing_auto import AutoProcessor | |
from transformers.utils.import_utils import is_flash_attn_2_available | |
from pydub import AudioSegment | |
from transcription import resample_audio | |
# Module-wide logging setup: INFO level, timestamped "time - level - message" lines.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
# Logger named after this module; used by the functions below.
logger = logging.getLogger(__name__)
def get_device():
    """Pick the torch device string used for inference.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` reports a usable GPU,
        otherwise ``"cpu"``.
    """
    return "cuda" if torch.cuda.is_available() else "cpu"
def _load_mono_float32(audio_path: str):
    """Decode an audio file into a mono float32 waveform.

    Args:
        audio_path: Path to a file that pydub can decode.

    Returns:
        A ``(samples, sample_rate)`` tuple: a float32 NumPy array of the
        decoded samples divided by the integer full-scale value of the source
        sample width, and the frame rate reported by pydub.
    """
    audio = AudioSegment.from_file(audio_path)

    # Down-mix multi-channel audio so the sample array holds a single channel.
    if audio.channels > 1:
        audio = audio.set_channels(1)

    samples = np.array(audio.get_array_of_samples())

    # Normalise by the full-scale magnitude of the *actual* sample width
    # (1 byte -> 2**7, 2 bytes -> 2**15, 4 bytes -> 2**31). Dividing every
    # non-16-bit file by 2**31 would squash 8-bit audio to near silence.
    full_scale = float(1 << (8 * audio.sample_width - 1))
    return samples.astype(np.float32) / full_scale, audio.frame_rate


def transcribe_audio(audio_path: str) -> str:
    """Transcribe an audio file using Whisper.

    Loads the ``openai/whisper-large-v2`` model and processor on every call,
    decodes the file with pydub, resamples it through ``resample_audio`` and
    runs beam-search generation guided by a short prompt of expected proper
    nouns.

    Args:
        audio_path: Path to the audio file.

    Returns:
        The decoded transcription of the first (only) sequence, with
        surrounding whitespace stripped.

    Raises:
        Exception: Any error raised while loading the model or while
            decoding/transcribing the audio is logged and re-raised unchanged.
    """
    # Setup device and model; half precision is only used on CUDA.
    device = get_device()
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    logger.info("Using device: %s, torch_dtype: %s", device, torch_dtype)
    attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"

    # Load model and processor
    stt_model_name = "openai/whisper-large-v2"
    try:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            stt_model_name,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation=attention,
        )
        model.to(device)
    except Exception as e:
        logger.error("Error loading ASR model: %s", e)
        raise
    processor = AutoProcessor.from_pretrained(stt_model_name)

    # Domain-specific names are supplied as a decoder prompt instead of being
    # added to the tokenizer vocabulary, so the model's embeddings stay
    # untouched.
    initial_prompt = "LuxDev, Sasan, Jafarnejad, LUXDEV"
    prompt_ids = processor.get_prompt_ids(initial_prompt, return_tensors="pt").to(device)

    # Load and process audio
    try:
        audio_array, sample_rate = _load_mono_float32(audio_path)

        # resample_audio comes from the project's `transcription` module;
        # presumably it returns the waveform at the rate Whisper expects
        # together with that rate -- confirm against its definition.
        audio_array, sample_rate = resample_audio(audio_array, sample_rate)

        # Turn the waveform into model input features on the model's
        # device and dtype.
        input_features = processor(
            audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).input_features
        input_features = input_features.to(device=device, dtype=torch_dtype)

        # Generate transcription with the custom prompt. Task and language
        # are not forced here.
        predicted_ids = model.generate(
            input_features,
            max_length=448,
            num_beams=5,
            temperature=0.0,
            no_repeat_ngram_size=3,
            length_penalty=1.0,
            repetition_penalty=1.0,
            prompt_ids=prompt_ids,
        )

        # Decode the transcription
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
        return transcription
    except Exception as e:
        logger.error("Error processing audio: %s", e)
        raise
if __name__ == "__main__":
    # Demo run: transcribe a sample file and print either the text or the
    # error that prevented transcription.
    sample_file = "sample.m4a"
    try:
        text = transcribe_audio(sample_file)
        print(f"Transcription: {text}")
    except Exception as exc:
        print(f"Error: {exc}")