import json
import re
import math
from collections import OrderedDict
from pathlib import Path
from typing import Union

import numpy as np
import torch
import torch.nn.functional as F
from whisper.tokenizer import get_tokenizer
from whisper_live.transcriber.tensorrt_utils import (
    mel_filters,
    load_audio_wav_format,
    pad_or_trim,
    load_audio
)

import tensorrt_llm
from tensorrt_llm.logger import logger
from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
                                 trt_dtype_to_torch)
from tensorrt_llm.bindings import GptJsonConfig, KVCacheType
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelConfig, SamplingConfig
from tensorrt_llm.runtime.session import Session, TensorInfo

if PYTHON_BINDINGS:
    from tensorrt_llm.runtime import ModelRunnerCpp

SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE

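# Quick sanity math (follows from the constants above): a full chunk is
# CHUNK_LENGTH * SAMPLE_RATE = 30 * 16000 = 480000 samples, and an STFT with
# HOP_LENGTH = 160 maps it to 480000 / 160 = 3000 mel frames, which is why
# 3000 appears as the padded feature length further down.
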
def read_config(component, engine_dir):
    config_path = engine_dir / component / 'config.json'
    with open(config_path, 'r') as f:
        config = json.load(f)
    model_config = OrderedDict()
    model_config.update(config['pretrained_config'])
    model_config.update(config['build_config'])
    return model_config

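# Illustrative note: the merged dict flattens both halves of TensorRT-LLM's
# config.json into one namespace, e.g. read_config('encoder', engine_dir)
# exposes the keys this module reads below: 'n_mels', 'dtype',
# 'num_languages', 'vocab_size', 'max_batch_size', and 'plugin_config'.
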
def remove_tensor_padding(input_tensor,
                          input_tensor_lengths=None,
                          pad_value=None):
    if pad_value is not None:
        # 2D token-ID case: drop pad tokens and pack the batch into one row.
        assert input_tensor_lengths is None, "input_tensor_lengths should be None when pad_value is provided"
        assert torch.all(
            input_tensor[:, 0] != pad_value
        ), "First token in each sequence should not be pad_value"

        mask = input_tensor != pad_value
        output_tensor = input_tensor[mask].view(1, -1)
    else:
        # 3D feature case: keep only the valid prefix of each sequence.
        assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
        valid_sequences = []
        for i in range(input_tensor.shape[0]):
            valid_length = input_tensor_lengths[i]
            valid_sequences.append(input_tensor[i, :valid_length])
        output_tensor = torch.cat(valid_sequences, dim=0)
    return output_tensor

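# Illustrative sketch (not executed at import time; toy values assumed):
# with pad_value set, a 2D [batch, seq_len] ID tensor is packed into a
# single row, which is the layout the remove_input_padding path expects:
#
#     ids = torch.tensor([[7, 8, 50256], [9, 10, 11]])
#     remove_tensor_padding(ids, pad_value=50256)
#     # -> tensor([[ 7,  8,  9, 10, 11]])
#
# Without pad_value, a 3D [batch, seq_len, hidden] tensor plus per-sequence
# lengths concatenates the valid prefix of each sequence along dim 0.
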
class WhisperEncoding:

    def __init__(self, engine_dir):
        self.session = self.get_session(engine_dir)
        config = read_config('encoder', engine_dir)
        self.n_mels = config['n_mels']
        self.dtype = config['dtype']
        self.num_languages = config['num_languages']
        self.encoder_config = config

    def get_session(self, engine_dir):
        serialize_path = engine_dir / 'encoder' / 'rank0.engine'
        with open(serialize_path, 'rb') as f:
            session = Session.from_serialized_engine(f.read())
        return session

    def get_audio_features(self,
                           mel,
                           mel_input_lengths,
                           encoder_downsampling_factor=2):
        if isinstance(mel, list):
            # Pad a list of mels to the longest one and stack into a batch.
            longest_mel = max([f.shape[-1] for f in mel])
            mel = [
                torch.nn.functional.pad(f, (0, longest_mel - f.shape[-1]),
                                        mode='constant') for f in mel
            ]
            mel = torch.cat(mel, dim=0).type(
                str_dtype_to_torch("float16")).contiguous()
        bsz, seq_len = mel.shape[0], mel.shape[2]
        position_ids = torch.arange(
            math.ceil(seq_len / encoder_downsampling_factor),
            dtype=torch.int32,
            device=mel.device).expand(bsz, -1).contiguous()
        if self.encoder_config['plugin_config']['remove_input_padding']:
            # The remove_input_padding path expects packed [frames, n_mels]
            # inputs, so transpose before stripping the padding.
            mel = mel.transpose(1, 2)
            mel = remove_tensor_padding(mel, mel_input_lengths)
            position_ids = remove_tensor_padding(
                position_ids, mel_input_lengths // encoder_downsampling_factor)
        inputs = OrderedDict()
        inputs['input_features'] = mel
        inputs['input_lengths'] = mel_input_lengths
        inputs['position_ids'] = position_ids

        input_tensor_info = [
            TensorInfo('input_features', str_dtype_to_trt(self.dtype),
                       mel.shape),
            TensorInfo('input_lengths', str_dtype_to_trt('int32'),
                       mel_input_lengths.shape),
            TensorInfo('position_ids', str_dtype_to_trt('int32'),
                       inputs['position_ids'].shape)
        ]

        output_info = self.session.infer_shapes(input_tensor_info)

        logger.debug(f'output info {output_info}')
        outputs = {
            t.name: torch.empty(tuple(t.shape),
                                dtype=trt_dtype_to_torch(t.dtype),
                                device='cuda')
            for t in output_info
        }
        stream = torch.cuda.current_stream()
        ok = self.session.run(inputs=inputs,
                              outputs=outputs,
                              stream=stream.cuda_stream)
        assert ok, 'Engine execution failed'
        stream.synchronize()
        encoder_output = outputs['encoder_output']
        encoder_output_lengths = mel_input_lengths // encoder_downsampling_factor
        return encoder_output, encoder_output_lengths

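# Shape sketch (assuming a single full 30 s chunk): a [1, n_mels, 3000] mel
# with mel_input_lengths == tensor([3000]) comes back as encoder_output of
# shape [1, 1500, hidden] (packed to [1500, hidden] when remove_input_padding
# is enabled) and encoder_output_lengths == tensor([1500]), since the
# encoder downsamples frames by a factor of 2.
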
class WhisperDecoding:

    def __init__(self, engine_dir, runtime_mapping, debug_mode=False):
        self.decoder_config = read_config('decoder', engine_dir)
        self.decoder_generation_session = self.get_session(
            engine_dir, runtime_mapping, debug_mode)

    def get_session(self, engine_dir, runtime_mapping, debug_mode=False):
        serialize_path = engine_dir / 'decoder' / 'rank0.engine'
        with open(serialize_path, "rb") as f:
            decoder_engine_buffer = f.read()

        plugin_config = self.decoder_config['plugin_config']
        decoder_model_config = ModelConfig(
            max_batch_size=self.decoder_config['max_batch_size'],
            max_beam_width=self.decoder_config['max_beam_width'],
            num_heads=self.decoder_config['num_attention_heads'],
            num_kv_heads=self.decoder_config['num_attention_heads'],
            hidden_size=self.decoder_config['hidden_size'],
            vocab_size=self.decoder_config['vocab_size'],
            cross_attention=True,
            num_layers=self.decoder_config['num_hidden_layers'],
            gpt_attention_plugin=plugin_config['gpt_attention_plugin'],
            remove_input_padding=plugin_config['remove_input_padding'],
            kv_cache_type=(KVCacheType.PAGED
                           if plugin_config['paged_kv_cache'] else
                           KVCacheType.CONTINUOUS),
            has_position_embedding=self.decoder_config['has_position_embedding'],
            dtype=self.decoder_config['dtype'],
            has_token_type_embedding=False,
        )
        decoder_generation_session = tensorrt_llm.runtime.GenerationSession(
            decoder_model_config,
            decoder_engine_buffer,
            runtime_mapping,
            debug_mode=debug_mode)

        return decoder_generation_session

    def generate(self,
                 decoder_input_ids,
                 encoder_outputs,
                 encoder_max_input_length,
                 encoder_input_lengths,
                 eot_id,
                 max_new_tokens=40,
                 num_beams=1):
        batch_size = decoder_input_ids.shape[0]
        # All prompts in a batch share the same length.
        decoder_input_lengths = torch.full((batch_size, ),
                                           decoder_input_ids.shape[-1],
                                           dtype=torch.int32,
                                           device='cuda')
        decoder_max_input_length = torch.max(decoder_input_lengths).item()

        # Allow every decoder position to attend to every encoder frame.
        cross_attention_mask = torch.ones([
            batch_size, decoder_max_input_length + max_new_tokens,
            encoder_max_input_length
        ]).int().cuda()

        sampling_config = SamplingConfig(end_id=eot_id,
                                         pad_id=eot_id,
                                         num_beams=num_beams)
        self.decoder_generation_session.setup(
            decoder_input_lengths.size(0),
            decoder_max_input_length,
            max_new_tokens,
            beam_width=num_beams,
            encoder_max_input_length=encoder_max_input_length)

        torch.cuda.synchronize()

        decoder_input_ids = decoder_input_ids.type(torch.int32).cuda()
        if self.decoder_config['plugin_config']['remove_input_padding']:
            # 50256 is the <pad> index for Whisper decoders.
            WHISPER_PAD_TOKEN_ID = 50256
            decoder_input_ids = remove_tensor_padding(
                decoder_input_ids, pad_value=WHISPER_PAD_TOKEN_ID)
            if encoder_outputs.dim() == 3:
                encoder_output_lens = torch.full((encoder_outputs.shape[0], ),
                                                 encoder_outputs.shape[1],
                                                 dtype=torch.int32,
                                                 device='cuda')
                encoder_outputs = remove_tensor_padding(encoder_outputs,
                                                        encoder_output_lens)
        output_ids = self.decoder_generation_session.decode(
            decoder_input_ids,
            decoder_input_lengths,
            sampling_config,
            encoder_output=encoder_outputs,
            encoder_input_lengths=encoder_input_lengths,
            cross_attention_mask=cross_attention_mask,
        )
        torch.cuda.synchronize()

        output_ids = output_ids.cpu().numpy().tolist()
        return output_ids

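# Call-flow sketch for the Python-session path: WhisperTRTLLM.process_batch
# (below) builds decoder_input_ids by tiling the encoded text_prefix across
# the batch and passes tokenizer.eot as eot_id, so generation stops at
# Whisper's <|endoftext|> token.
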
class WhisperTRTLLM(object):

    def __init__(self,
                 engine_dir,
                 assets_dir=None,
                 device=None,
                 is_multilingual=False,
                 language="en",
                 task="transcribe",
                 use_py_session=False,
                 num_beams=1,
                 debug_mode=False,
                 max_output_len=96):
        world_size = 1
        runtime_rank = tensorrt_llm.mpi_rank()
        runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
        torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
        engine_dir = Path(engine_dir)
        encoder_config = read_config('encoder', engine_dir)
        decoder_config = read_config('decoder', engine_dir)
        self.n_mels = encoder_config['n_mels']
        self.num_languages = encoder_config['num_languages']
        # Infer multilinguality from the vocabulary size; this deliberately
        # overrides the is_multilingual argument.
        is_multilingual = (decoder_config['vocab_size'] >= 51865)

        self.device = device
        self.tokenizer = get_tokenizer(
            is_multilingual,
            num_languages=self.num_languages,
            language=language,
            task=task,
        )

        if use_py_session:
            self.encoder = WhisperEncoding(engine_dir)
            self.decoder = WhisperDecoding(engine_dir,
                                           runtime_mapping,
                                           debug_mode=debug_mode)
        else:
            json_config = GptJsonConfig.parse_file(engine_dir / 'decoder' /
                                                   'config.json')
            assert json_config.model_config.supports_inflight_batching
            runner_kwargs = dict(engine_dir=engine_dir,
                                 is_enc_dec=True,
                                 max_batch_size=1,
                                 max_input_len=3000,
                                 max_output_len=max_output_len,
                                 max_beam_width=num_beams,
                                 debug_mode=debug_mode,
                                 kv_cache_free_gpu_memory_fraction=0.9,
                                 cross_kv_cache_fraction=0.5)
            self.model_runner_cpp = ModelRunnerCpp.from_dir(**runner_kwargs)
        self.filters = mel_filters(self.device, self.n_mels, assets_dir)
        self.use_py_session = use_py_session

    def log_mel_spectrogram(
            self,
            audio: Union[str, np.ndarray, torch.Tensor],
            padding: int = 0,
            return_duration=True):
        """
        Compute the log-Mel spectrogram of an audio waveform.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
            The path to an audio file, or a NumPy array or Tensor containing
            the audio waveform at 16 kHz

        padding: int
            Number of zero samples to pad to the right

        return_duration: bool
            If True, also return the duration of the audio in seconds

        Returns
        -------
        torch.Tensor, shape = (80 or 128, n_frames)
            A Tensor that contains the log-Mel spectrogram
        """
        if not torch.is_tensor(audio):
            if isinstance(audio, str):
                if audio.endswith('.wav'):
                    audio, _ = load_audio_wav_format(audio)
                else:
                    audio = load_audio(audio)
            assert isinstance(
                audio, np.ndarray), f"Unsupported audio type: {type(audio)}"
            duration = audio.shape[-1] / SAMPLE_RATE
            audio = pad_or_trim(audio, N_SAMPLES)
            audio = audio.astype(np.float32)
            audio = torch.from_numpy(audio)
        else:
            # Tensor input: measure the duration on the tensor as given.
            duration = audio.shape[-1] / SAMPLE_RATE

        if self.device is not None:
            audio = audio.to(self.device)
        if padding > 0:
            audio = F.pad(audio, (0, padding))
        window = torch.hann_window(N_FFT).to(audio.device)
        stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window,
                          return_complex=True)
        magnitudes = stft[..., :-1].abs()**2

        mel_spec = self.filters @ magnitudes

        # Same normalization as openai/whisper: clamp, log10, then compress
        # the dynamic range to roughly [-1, 1].
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        if return_duration:
            return log_spec, duration
        else:
            return log_spec

    def process_batch(
            self,
            mel,
            mel_input_lengths,
            text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
            num_beams=1,
            max_new_tokens=96):
        prompt_id = self.tokenizer.encode(
            text_prefix,
            allowed_special=set(self.tokenizer.special_tokens.keys()))

        prompt_id = torch.tensor(prompt_id)
        batch_size = mel.shape[0]
        decoder_input_ids = prompt_id.repeat(batch_size, 1)
        if self.use_py_session:
            encoder_output, encoder_output_lengths = self.encoder.get_audio_features(
                mel, mel_input_lengths)
            encoder_max_input_length = torch.max(encoder_output_lengths).item()
            output_ids = self.decoder.generate(decoder_input_ids,
                                               encoder_output,
                                               encoder_max_input_length,
                                               encoder_output_lengths,
                                               self.tokenizer.eot,
                                               max_new_tokens=max_new_tokens,
                                               num_beams=num_beams)
        else:
            with torch.no_grad():
                if isinstance(mel, list):
                    mel = [
                        m.transpose(1, 2).type(
                            str_dtype_to_torch("float16")).squeeze(0)
                        for m in mel
                    ]
                else:
                    mel = mel.transpose(1, 2)
                # The encoder halves the frame count, hence the // 2.
                outputs = self.model_runner_cpp.generate(
                    batch_input_ids=decoder_input_ids,
                    encoder_input_features=mel,
                    encoder_output_lengths=mel_input_lengths // 2,
                    max_new_tokens=max_new_tokens,
                    end_id=self.tokenizer.eot,
                    pad_id=self.tokenizer.eot,
                    num_beams=num_beams,
                    output_sequence_lengths=True,
                    return_dict=True)
                torch.cuda.synchronize()
                output_ids = outputs['output_ids'].cpu().numpy().tolist()
        texts = []
        for i in range(len(output_ids)):
            text = self.tokenizer.decode(output_ids[i][0]).strip()
            texts.append(text)
        return texts

    def transcribe(
            self,
            mel,
            text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
            dtype='float16',
            batch_size=1,
            num_beams=1,
            padding_strategy="max",
            max_new_tokens=96):
        mel = mel.type(str_dtype_to_torch(dtype))
        mel = mel.unsqueeze(0)
        mel = mel.repeat(batch_size, 1, 1)
        if padding_strategy == "longest":
            # Keep the natural feature length.
            pass
        else:
            # Pad to the full 30 s window of 3000 mel frames.
            mel = torch.nn.functional.pad(mel, (0, 3000 - mel.shape[2]))
        features_input_lengths = torch.full((mel.shape[0], ),
                                            mel.shape[2],
                                            dtype=torch.int32,
                                            device=mel.device)

        predictions = self.process_batch(
            mel,
            features_input_lengths,
            text_prefix,
            num_beams,
            max_new_tokens=max_new_tokens)
        prediction = predictions[0]

        # Remove special tokens such as <|en|> or <|notimestamps|>.
        prediction = re.sub(r'<\|.*?\|>', '', prediction)
        return prediction.strip()

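# End-to-end usage sketch (the engine directory path is hypothetical; the
# encoder/decoder engines must have been built beforehand with the
# TensorRT-LLM Whisper example's trtllm-build flow):
#
#     model = WhisperTRTLLM("/path/to/whisper_trt_engines",
#                           assets_dir="/path/to/assets", device="cuda")
#     mel, duration = model.log_mel_spectrogram("sample.wav")
#     text = model.transcribe(mel)
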
def decode_wav_file(
        model,
        mel,
        text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
        dtype='float16',
        batch_size=1,
        num_beams=1,
        normalizer=None,
        mel_filters_dir=None):
    mel = mel.type(str_dtype_to_torch(dtype))
    mel = mel.unsqueeze(0)
    mel = mel.repeat(batch_size, 1, 1)
    # process_batch expects per-sample feature lengths as its second argument.
    features_input_lengths = torch.full((mel.shape[0], ),
                                        mel.shape[2],
                                        dtype=torch.int32,
                                        device=mel.device)
    predictions = model.process_batch(mel, features_input_lengths, text_prefix,
                                      num_beams)
    prediction = predictions[0]

    # Remove special tokens such as <|en|> or <|notimestamps|>.
    prediction = re.sub(r'<\|.*?\|>', '', prediction)
    if normalizer:
        prediction = normalizer(prediction)

    return prediction.strip()