oracle-demo / transcribe.py
sasan's picture
Upload folder using huggingface_hub
617df14 verified
raw
history blame
4.16 kB
import logging
import numpy as np
from scipy import signal
import torch
from transformers.models.auto.modeling_auto import AutoModelForSpeechSeq2Seq
from transformers.models.auto.processing_auto import AutoProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
from pydub import AudioSegment
from transcription import resample_audio
# Logging setup: configure the root logger at import time with timestamped,
# INFO-level records, then bind the logger used throughout this module.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def get_device():
    """
    Pick the torch device string used for inference.
    Returns:
        "cuda" when torch reports CUDA as available, otherwise "cpu"
    """
    return "cuda" if torch.cuda.is_available() else "cpu"
def _load_whisper(model_name: str, device: str, torch_dtype):
    """
    Load the speech-to-text model and its processor.
    Args:
        model_name: Hub identifier (or local path) passed to from_pretrained
        device: Device string the model is moved to
        torch_dtype: dtype the model weights are loaded in
    Returns:
        (model, processor) tuple
    Raises:
        Exception: whatever from_pretrained / .to() raises, after logging it
    """
    attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
    try:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation=attention,
        )
        model.to(device)
        # Loaded inside the same try so a processor failure is logged too.
        processor = AutoProcessor.from_pretrained(model_name)
    except Exception as e:
        logger.error("Error loading ASR model: %s", e)
        raise
    return model, processor


def _pcm_to_float32(samples, sample_width: int) -> np.ndarray:
    """
    Scale signed integer PCM samples to float32 in the range [-1.0, 1.0).
    Args:
        samples: Signed integer PCM samples (any array-like)
        sample_width: Bytes per sample (e.g. 1, 2 or 4)
    Returns:
        float32 numpy array where integer full scale maps to 1.0
    """
    # Full scale depends on the sample width: 2**7 for 1 byte, 2**15 for
    # 2 bytes, 2**31 for 4 bytes. A fixed 2**15 / 2**31 choice would divide
    # 8-bit audio by 2**31 and yield a near-silent signal.
    full_scale = float(1 << (8 * sample_width - 1))
    return np.asarray(samples, dtype=np.float32) / full_scale


def _load_mono_waveform(audio_path: str):
    """
    Decode an audio file to a mono float32 waveform and resample it.
    Args:
        audio_path: Path to the audio file
    Returns:
        (audio_array, sample_rate) as returned by resample_audio
    """
    audio = AudioSegment.from_file(audio_path)
    # Downmix so the sample array holds a single channel.
    if audio.channels > 1:
        audio = audio.set_channels(1)
    waveform = _pcm_to_float32(audio.get_array_of_samples(), audio.sample_width)
    # resample_audio is a project helper; presumably targets 16 kHz — TODO confirm.
    return resample_audio(waveform, audio.frame_rate)


def transcribe_audio(
    audio_path: str,
    *,
    model_name: str = "openai/whisper-large-v2",
    initial_prompt: str = "LuxDev, Sasan, Jafarnejad, LUXDEV",
) -> str:
    """
    Transcribe an audio file using Whisper.
    Args:
        audio_path: Path to the audio file
        model_name: Model identifier to load; defaults to the large-v2 checkpoint
        initial_prompt: Text turned into prompt token ids and passed to
            generate() to guide the model; an empty string disables the prompt
    Returns:
        Transcribed text (first decoded sequence, stripped of surrounding whitespace)
    Raises:
        Exception: any error from model loading or audio processing is logged
            and re-raised unchanged
    """
    # Setup device and model: half precision on CUDA, full precision otherwise.
    device = get_device()
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    logger.info("Using device: %s, torch_dtype: %s", device, torch_dtype)

    model, processor = _load_whisper(model_name, device, torch_dtype)

    try:
        audio_array, sample_rate = _load_mono_waveform(audio_path)

        # NOTE(review): the processor is called with its default
        # padding/truncation settings; confirm how recordings longer than
        # the model's input window are handled.
        input_features = processor(
            audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).input_features
        # Match the model's device and dtype before generation.
        input_features = input_features.to(device=device, dtype=torch_dtype)

        # task/language are not passed, so generate() uses its own defaults.
        generate_kwargs = {
            "max_length": 448,
            "num_beams": 5,
            "temperature": 0.0,
            "no_repeat_ngram_size": 3,
            "length_penalty": 1.0,
            "repetition_penalty": 1.0,
        }
        if initial_prompt:
            # Custom prompt tokens to guide the model toward these spellings.
            generate_kwargs["prompt_ids"] = processor.get_prompt_ids(
                initial_prompt, return_tensors="pt"
            ).to(device)

        predicted_ids = model.generate(input_features, **generate_kwargs)

        # Decode the first (only) sequence of the batch.
        decoded = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        return decoded[0].strip()
    except Exception as e:
        logger.error("Error processing audio: %s", e)
        raise
if __name__ == "__main__":
    import sys

    # Example usage: an optional first command-line argument overrides the
    # default sample file.
    audio_path = sys.argv[1] if len(sys.argv) > 1 else "sample.m4a"
    try:
        transcription = transcribe_audio(audio_path)
    except Exception as e:
        # Report the failure and exit non-zero so callers/shell scripts can
        # tell that transcription did not succeed.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
    else:
        print(f"Transcription: {transcription}")