# oracle-demo / transcription.py
# Uploaded by sasan via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 617df14 (verified), 1.31 kB.
import numpy as np
import logging
from scipy import signal
# Module-level logger used by resample_audio; warmup_model instead logs
# through the logger passed in as its own `logger` parameter.
logger = logging.getLogger(__name__)
def warmup_model(processor, model, device, np_dtype, torch_dtype, logger):
    """Push one second of silence through ``processor`` and ``model.generate``.

    The generated output is discarded; the call is made only for its side
    effects, presumably so one-time setup costs are paid before real
    requests arrive (TODO confirm intent).

    Args:
        processor: Callable feature extractor, invoked as
            ``processor(audio, sampling_rate=16000, return_tensors="pt")``;
            its result must expose ``.input_features``.
        model: Object exposing ``generate(input_features)``.
        device: Passed as ``device=`` to ``input_features.to(...)``.
        np_dtype: NumPy dtype of the silent input buffer.
        torch_dtype: Passed as ``dtype=`` to ``input_features.to(...)``.
        logger: Logger that receives the start and finish messages.
    """
    logger.info("Warming up Whisper model with dummy input")
    rate_hz = 16000
    silence = np.zeros((rate_hz,), dtype=np_dtype)  # rate_hz samples == 1 s of audio
    extracted = processor(silence, sampling_rate=rate_hz, return_tensors="pt")
    # Move/cast the features to the model's device and dtype before generating.
    features_on_device = extracted.input_features.to(device=device, dtype=torch_dtype)
    model.generate(features_on_device)
    logger.info("Model warmup complete")
def resample_audio(audio, sample_rate, *, target_rate=16000):
    """Return ``audio`` as floating-point samples at ``target_rate`` Hz.

    ``np.int16`` input is converted to ``np.float32`` and scaled into
    [-1.0, 1.0). Singleton axes are then dropped for every input dtype, so
    single-channel ``(N, 1)`` audio always comes back 1-D. When
    ``sample_rate`` differs from ``target_rate``, the samples are resampled
    along axis 0 with the FFT-based :func:`scipy.signal.resample`.

    Args:
        audio: NumPy array of samples; after squeezing, axis 0 is time.
        sample_rate: Sampling rate of ``audio`` in Hz; must be positive.
        target_rate: Output sampling rate in Hz (keyword-only). Defaults to
            16000, the rate ``warmup_model`` feeds to the processor.

    Returns:
        ``(audio, rate)``: the converted/resampled array and its sampling
        rate (``target_rate`` after resampling, otherwise ``sample_rate``).

    Raises:
        ValueError: if ``sample_rate`` or ``target_rate`` is not positive.

    NOTE(review): multi-channel input such as ``(N, 2)`` is not downmixed;
    each channel is resampled independently and the result stays 2-D.
    Confirm callers pass mono audio.
    """
    if sample_rate <= 0 or target_rate <= 0:
        raise ValueError(
            f"sample rates must be positive (sample_rate={sample_rate!r}, "
            f"target_rate={target_rate!r})"
        )
    if audio.dtype == np.int16:
        logger.info("Converting audio from np.int16 to np.float32 and normalizing to range -1 to 1")
        # int16 spans [-32768, 32767]; dividing by 32768 (the magnitude of the
        # minimum) maps it onto [-1.0, 1.0) without clipping.
        audio = audio.astype(np.float32) / 32768.0
    # Drop singleton axes (e.g. a trailing mono channel) for every dtype, but
    # keep at least one axis so a one-sample input still has a length.
    audio = np.atleast_1d(np.squeeze(audio))
    if sample_rate != target_rate:
        logger.info(
            "Resampling audio from %sHz to %sHz len:%s",
            sample_rate,
            target_rate,
            len(audio),
        )
        number_of_samples = round(len(audio) * float(target_rate) / sample_rate)
        if number_of_samples > 0:
            audio = signal.resample(audio, number_of_samples)
        else:
            # Empty (or vanishingly short) input: a zero-length FFT resample
            # raises in scipy, so return an empty array instead.
            audio = audio[:0]
        sample_rate = target_rate
    return audio, sample_rate