import os
import sys
import gzip
import zlib
import tqdm
import torch
import base64
import string
import logging
import tiktoken
import itertools

import numba as nb
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from contextlib import contextmanager
from torch.distributions import Categorical
from functools import cached_property, lru_cache
from dataclasses import dataclass, replace
from torch.nn.functional import scaled_dot_product_attention

sys.path.append(os.getcwd())

from main.library.utils import load_audio

LANGUAGES = {"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish", "pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese", "he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian", "th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian", "br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili", "gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian", "be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy", "as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese", "yue": "cantonese"}
TO_LANGUAGE_CODE = {**{language: code for code, language in LANGUAGES.items()}, "burmese": "my", "valencian": "ca", "flemish": "nl", "haitian": "ht", "letzeburgesch": "lb", "pushto": "ps", "panjabi": "pa", "moldavian": "ro", "moldovan": "ro", "sinhalese": "si", "castilian": "es", "mandarin": "zh"}
_ALIGNMENT_HEADS = {"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO", "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m", "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000", "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj", "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`"}

SAMPLE_RATE, N_FFT, HOP_LENGTH, CHUNK_LENGTH = 16000, 400, 160, 30

N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2


def exact_div(x, y):
    assert x % y == 0
    return x // y


N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)

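# Loads a Whisper checkpoint from the local speaker_diarization assets directory,
# rebuilds the model from its stored dimensions, and installs the cross-attention
# alignment-head mask used for word-level timing.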
def load_model(name="base", device="cpu"):
    checkpoint_file = os.path.join("assets", "models", "speaker_diarization", "models", name + ".pt")
    alignment_heads = _ALIGNMENT_HEADS[name]

    with open(checkpoint_file, "rb") as fp:
        checkpoint = torch.load(fp, map_location=device)

    del checkpoint_file

    model = Whisper(ModelDimensions(**checkpoint["dims"]))
    model.load_state_dict(checkpoint["model_state_dict"])
    model.set_alignment_heads(alignment_heads)

    return model.to(device)

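# Glues leading punctuation onto the following word and trailing punctuation onto
# the preceding word so punctuation never forms a "word" of its own in the timings.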
def merge_punctuations(alignment, prepended, appended):
    i = len(alignment) - 2
    j = len(alignment) - 1

    while i >= 0:
        previous = alignment[i]
        following = alignment[j]

        if previous.word.startswith(" ") and previous.word.strip() in prepended:
            following.word = previous.word + following.word
            following.tokens = previous.tokens + following.tokens

            previous.word = ""
            previous.tokens = []
        else: j = i

        i -= 1

    i = 0
    j = 1

    while j < len(alignment):
        previous = alignment[i]
        following = alignment[j]

        if not previous.word.endswith(" ") and following.word in appended:
            previous.word = previous.word + following.word
            previous.tokens = previous.tokens + following.tokens

            following.word = ""
            following.tokens = []
        else: i = j

        j += 1

class WordTiming:
    def __init__(self, word, tokens, start, end, probability):
        self.word = word
        self.tokens = tokens
        self.start = start
        self.end = end
        self.probability = probability

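# Temporarily switches MultiHeadAttention to its explicit attention path so the
# cross-attention hooks in find_alignment() receive real QK matrices instead of None.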
@contextmanager
def disable_sdpa():
    prev_state = MultiHeadAttention.use_sdpa
    try:
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        MultiHeadAttention.use_sdpa = prev_state

def median_filter(x, filter_width):
    pad_width = filter_width // 2

    if x.shape[-1] <= pad_width: return x
    if (ndim := x.ndim) <= 2: x = x[None, None, :]

    assert filter_width > 0 and filter_width % 2 == 1

    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
    result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
    if ndim <= 2: result = result[0, 0]

    return result

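# Dynamic time warping over the negative attention matrix: dtw_cpu() fills the
# cumulative cost/trace tables and backtrace() recovers the monotonic alignment
# path between text-token indices and audio-frame indices.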
@nb.jit(nopython=True)
def backtrace(trace):
    i = trace.shape[0] - 1
    j = trace.shape[1] - 1

    trace[0, :] = 2
    trace[:, 0] = 1

    result = []
    while i > 0 or j > 0:
        result.append((i - 1, j - 1))

        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1: i -= 1
        elif trace[i, j] == 2: j -= 1
        else: raise ValueError

    return np.array(result)[::-1, :].T

@nb.jit(nopython=True, parallel=True)
def dtw_cpu(x):
    N, M = x.shape

    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
    cost[0, 0] = 0

    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]

            if c0 < c1 and c0 < c2: c, t = c0, 0
            elif c1 < c0 and c1 < c2: c, t = c1, 1
            else: c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    return backtrace(trace)

def dtw(x):
    return dtw_cpu(x.double().cpu().numpy())

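# Hooks the decoder's cross-attention layers, runs a teacher-forced pass over the
# text tokens, and DTW-aligns the selected attention heads to audio frames to
# produce a WordTiming (start, end, probability) for every word.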
def find_alignment(model, tokenizer, text_tokens, mel, num_frames, *, medfilt_width=7, qk_scale=1.0):
    if len(text_tokens) == 0: return []

    tokens = torch.tensor([*tokenizer.sot_sequence, tokenizer.no_timestamps, *text_tokens, tokenizer.eot]).to(model.device)

    QKs = [None] * model.dims.n_text_layer
    hooks = [block.cross_attn.register_forward_hook(lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])) for i, block in enumerate(model.decoder.blocks)]

    with torch.no_grad(), disable_sdpa():
        token_probs = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0][len(tokenizer.sot_sequence) :, : tokenizer.eot].softmax(dim=-1)
        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()

    for hook in hooks:
        hook.remove()

    weights = (torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])[:, :, : num_frames // 2] * qk_scale).softmax(dim=-1)
    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
    weights = median_filter((weights - mean) / std, medfilt_width)

    text_indices, time_indices = dtw(-weights.mean(axis=0)[len(tokenizer.sot_sequence) : -1])

    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
    if len(word_tokens) <= 1: return []

    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
    jump_times = time_indices[np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)] / TOKENS_PER_SECOND

    return [WordTiming(word, tokens, start, end, probability) for word, tokens, start, end, probability in zip(words, word_tokens, jump_times[word_boundaries[:-1]], jump_times[word_boundaries[1:]], [np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])])]

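# Attaches word-level timing dicts to each segment and snaps segment boundaries to
# the aligned words, truncating implausibly long words around sentence punctuation.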
def add_word_timestamps(*, segments, model, tokenizer, mel, num_frames, prepend_punctuations="\"'“¿([{-", append_punctuations="\"'.。,,!!??::”)]}、", last_speech_timestamp, **kwargs):
    if len(segments) == 0: return

    text_tokens_per_segment = [[token for token in segment["tokens"] if token < tokenizer.eot] for segment in segments]

    text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)

    word_durations = np.array([t.end - t.start for t in alignment])
    word_durations = word_durations[word_durations.nonzero()]

    median_duration = min(0.7, float(np.median(word_durations) if len(word_durations) > 0 else 0.0))
    max_duration = median_duration * 2

    if len(word_durations) > 0:
        sentence_end_marks = ".。!!??"
        for i in range(1, len(alignment)):
            if alignment[i].end - alignment[i].start > max_duration:
                if alignment[i].word in sentence_end_marks: alignment[i].end = alignment[i].start + max_duration
                elif alignment[i - 1].word in sentence_end_marks: alignment[i].start = alignment[i].end - max_duration

    merge_punctuations(alignment, prepend_punctuations, append_punctuations)

    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    word_index = 0

    for segment, text_tokens in zip(segments, text_tokens_per_segment):
        saved_tokens = 0
        words = []

        while word_index < len(alignment) and saved_tokens < len(text_tokens):
            timing = alignment[word_index]

            if timing.word: words.append(dict(word=timing.word, start=round(time_offset + timing.start, 2), end=round(time_offset + timing.end, 2), probability=timing.probability))

            saved_tokens += len(timing.tokens)
            word_index += 1

        if len(words) > 0:
            if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (words[0]["end"] - words[0]["start"] > max_duration or (len(words) > 1 and words[1]["end"] - words[0]["start"] > max_duration * 2)):
                if (len(words) > 1 and words[1]["end"] - words[1]["start"] > max_duration): words[0]["end"] = words[1]["start"] = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
                words[0]["start"] = max(0, words[0]["end"] - max_duration)

            if (segment["start"] < words[0]["end"] and segment["start"] - 0.5 > words[0]["start"]): words[0]["start"] = max(0, min(words[0]["end"] - median_duration, segment["start"]))
            else: segment["start"] = words[0]["start"]

            if (segment["end"] > words[-1]["start"] and segment["end"] + 0.5 < words[-1]["end"]): words[-1]["end"] = max(words[-1]["start"] + median_duration, segment["end"])
            else: segment["end"] = words[-1]["end"]

            last_speech_timestamp = segment["end"]

        segment["words"] = words

@lru_cache(maxsize=None)
def mel_filters(device, n_mels):
    assert n_mels in {80, 128}

    with np.load(os.path.join("assets", "models", "speaker_diarization", "assets", "mel_filters.npz"), allow_pickle=False) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)

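# 80/128-bin log-Mel spectrogram at 16 kHz: STFT magnitudes are projected through the
# precomputed Mel filterbank, clamped to 8 log10 units below the peak, and rescaled.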
def log_mel_spectrogram(audio, n_mels=80, padding=0, device=None):
    if not torch.is_tensor(audio):
        if isinstance(audio, str): audio = load_audio(logging.getLogger(__name__), audio, sample_rate=SAMPLE_RATE).astype(np.float32)
        audio = torch.from_numpy(audio)

    if device is not None: audio = audio.to(device)
    if padding > 0: audio = F.pad(audio, (0, padding))

    log_spec = torch.clamp(mel_filters(audio.device, n_mels) @ torch.stft(audio, N_FFT, HOP_LENGTH, window=torch.hann_window(N_FFT).to(audio.device), return_complex=True)[..., :-1].abs() ** 2, min=1e-10).log10()
    return (torch.maximum(log_spec, log_spec.max() - 8.0) + 4.0) / 4.0

def pad_or_trim(array, length=N_SAMPLES, *, axis=-1):
    if torch.is_tensor(array):
        if array.shape[axis] > length: array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length: array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array

def get_end(segments):
    return next((w["end"] for s in reversed(segments) for w in reversed(s["words"])), segments[-1]["end"] if segments else None)

def transcribe_function(model, audio, *, verbose=None, temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), compression_ratio_threshold=2.4, logprob_threshold=-1.0, no_speech_threshold=0.6, condition_on_previous_text=True, initial_prompt=None, carry_initial_prompt=False, word_timestamps=False, prepend_punctuations="\"'“¿([{-", append_punctuations="\"'.。,,!!??::”)]}、", clip_timestamps="0", hallucination_silence_threshold=None, fp16=False, **decode_options):
    dtype = torch.float32
    decode_options["fp16"] = fp16

    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
    content_frames = mel.shape[-1] - N_FRAMES
    content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)

    if decode_options.get("language", None) is None:
        if not model.is_multilingual: decode_options["language"] = "vi"
        else:
            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
            _, probs = model.detect_language(mel_segment)
            decode_options["language"] = max(probs, key=probs.get)

            if verbose is not None: print(f"{LANGUAGES[decode_options['language']].title()}")

    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages, language=language, task=task)

    if isinstance(clip_timestamps, str): clip_timestamps = [float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])]
    seek_points = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]

    if len(seek_points) == 0: seek_points.append(0)
    if len(seek_points) % 2 == 1: seek_points.append(content_frames)

    seek_clips = list(zip(seek_points[::2], seek_points[1::2]))
    punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"

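    # Retries the segment at increasing temperatures whenever the current decoding
    # looks degenerate (compression ratio too high or average log-probability too low).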
    def decode_with_fallback(segment):
        temperatures = ([temperature] if isinstance(temperature, (int, float)) else temperature)
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}

            if t > 0:
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else: kwargs.pop("best_of", None)

            decode_result = model.decode(segment, DecodingOptions(**kwargs, temperature=t))
            needs_fallback = False

            if (compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold): needs_fallback = True
            if (logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold): needs_fallback = True
            if (no_speech_threshold is not None and decode_result.no_speech_prob > no_speech_threshold and logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold): needs_fallback = False

            if not needs_fallback: break

        return decode_result

    clip_idx = 0
    seek = seek_clips[clip_idx][0]

    input_stride = exact_div(N_FRAMES, model.dims.n_audio_ctx)
    time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE

    all_tokens, all_segments = [], []
    prompt_reset_since = 0

    remaining_prompt_length = model.dims.n_text_ctx // 2 - 1

    if initial_prompt is not None:
        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
        all_tokens.extend(initial_prompt_tokens)
        remaining_prompt_length -= len(initial_prompt_tokens)
    else: initial_prompt_tokens = []

    def new_segment(*, start, end, tokens, result):
        tokens = tokens.tolist()
        return {"seek": seek, "start": start, "end": end, "text": tokenizer.decode([token for token in tokens if token < tokenizer.eot]), "tokens": tokens, "temperature": result.temperature, "avg_logprob": result.avg_logprob, "compression_ratio": result.compression_ratio, "no_speech_prob": result.no_speech_prob}

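    # Main loop: slide a 30-second window over the Mel spectrogram, decode each window,
    # split the output on timestamp tokens, and advance `seek` past the consumed audio.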
    with tqdm.tqdm(total=content_frames, unit="frames", disable=verbose is not False) as pbar:
        last_speech_timestamp = 0.0
        while clip_idx < len(seek_clips):
            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
            if seek < seek_clip_start: seek = seek_clip_start

            if seek >= seek_clip_end:
                clip_idx += 1
                if clip_idx < len(seek_clips): seek = seek_clips[clip_idx][0]
                continue

            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)

            segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
            mel_segment = mel[:, seek : seek + segment_size]

            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

            if carry_initial_prompt: decode_options["prompt"] = initial_prompt_tokens + all_tokens[max(len(initial_prompt_tokens), prompt_reset_since):][-remaining_prompt_length:]
            else: decode_options["prompt"] = all_tokens[prompt_reset_since:]

            result = decode_with_fallback(mel_segment)
            tokens = torch.tensor(result.tokens)

            if no_speech_threshold is not None:
                should_skip = result.no_speech_prob > no_speech_threshold
                if (logprob_threshold is not None and result.avg_logprob > logprob_threshold):
                    should_skip = False

                if should_skip:
                    seek += segment_size
                    continue

            previous_seek = seek
            current_segments = []

            def word_anomaly_score(word):
                probability = word.get("probability", 0.0)
                duration = word["end"] - word["start"]
                score = 0.0

                if probability < 0.15: score += 1.0
                if duration < 0.133: score += (0.133 - duration) * 15
                if duration > 2.0: score += duration - 2.0

                return score

            def is_segment_anomaly(segment):
                if segment is None or not segment["words"]: return False

                words = [w for w in segment["words"] if w["word"] not in punctuation]
                words = words[:8]

                score = sum(word_anomaly_score(w) for w in words)

                return score >= 3 or score + 0.01 >= len(words)

            def next_words_segment(segments):
                return next((s for s in segments if s["words"]), None)

            timestamp_tokens = tokens.ge(tokenizer.timestamp_begin)
            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            consecutive.add_(1)

            if len(consecutive) > 0:
                slices = consecutive.tolist()
                if single_timestamp_ending:
                    slices.append(len(tokens))

                last_slice = 0
                for current_slice in slices:
                    sliced_tokens = tokens[last_slice:current_slice]
                    current_segments.append(new_segment(start=time_offset + (sliced_tokens[0].item() - tokenizer.timestamp_begin) * time_precision, end=time_offset + (sliced_tokens[-1].item() - tokenizer.timestamp_begin) * time_precision, tokens=sliced_tokens, result=result))
                    last_slice = current_slice

                if single_timestamp_ending: seek += segment_size
                else: seek += (tokens[last_slice - 1].item() - tokenizer.timestamp_begin) * input_stride
            else:
                duration = segment_duration

                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin): duration = (timestamps[-1].item() - tokenizer.timestamp_begin) * time_precision

                current_segments.append(new_segment(start=time_offset, end=time_offset + duration, tokens=tokens, result=result))
                seek += segment_size

            if word_timestamps:
                add_word_timestamps(segments=current_segments, model=model, tokenizer=tokenizer, mel=mel_segment, num_frames=segment_size, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, last_speech_timestamp=last_speech_timestamp)

                if not single_timestamp_ending:
                    last_word_end = get_end(current_segments)
                    if last_word_end is not None and last_word_end > time_offset: seek = round(last_word_end * FRAMES_PER_SECOND)

                if hallucination_silence_threshold is not None:
                    threshold = hallucination_silence_threshold

                    if not single_timestamp_ending:
                        last_word_end = get_end(current_segments)
                        if last_word_end is not None and last_word_end > time_offset: seek = round(last_word_end * FRAMES_PER_SECOND) if (window_end_time - last_word_end) > threshold else (previous_seek + segment_size)

                    first_segment = next_words_segment(current_segments)

                    if first_segment is not None and is_segment_anomaly(first_segment):
                        gap = first_segment["start"] - time_offset

                        if gap > threshold:
                            seek = previous_seek + round(gap * FRAMES_PER_SECOND)
                            continue

                    hal_last_end = last_speech_timestamp

                    for si in range(len(current_segments)):
                        segment = current_segments[si]
                        if not segment["words"]: continue

                        if is_segment_anomaly(segment):
                            next_segment = next_words_segment(current_segments[si + 1 :])
                            hal_next_start = next_segment["words"][0]["start"] if next_segment is not None else (time_offset + segment_duration)

                            if (segment["start"] - hal_last_end > threshold or segment["start"] < threshold or segment["start"] - time_offset < 2.0) and (hal_next_start - segment["end"] > threshold or is_segment_anomaly(next_segment) or window_end_time - segment["end"] < 2.0):
                                seek = round(max(time_offset + 1, segment["start"]) * FRAMES_PER_SECOND)
                                if content_duration - segment["end"] < threshold: seek = content_frames

                                current_segments[si:] = []
                                break

                        hal_last_end = segment["end"]

                last_word_end = get_end(current_segments)
                if last_word_end is not None: last_speech_timestamp = last_word_end

            for _, segment in enumerate(current_segments):
                if segment["start"] == segment["end"] or segment["text"].strip() == "":
                    segment["text"] = ""
                    segment["tokens"] = []
                    segment["words"] = []

            all_segments.extend([{"id": i, **segment} for i, segment in enumerate(current_segments, start=len(all_segments))])
            all_tokens.extend([token for segment in current_segments for token in segment["tokens"]])

            if not condition_on_previous_text or result.temperature > 0.5: prompt_reset_since = len(all_tokens)

            pbar.update(min(content_frames, seek) - previous_seek)

    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), segments=all_segments, language=language)

def compression_ratio(text):
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))

def sinusoids(length, channels, max_timescale=10000):
    assert channels % 2 == 0

    scaled_time = torch.arange(length)[:, np.newaxis] * torch.exp(-(np.log(max_timescale) / (channels // 2 - 1)) * torch.arange(channels // 2))[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

@torch.no_grad()
def detect_language_function(model, mel, tokenizer=None):
    if tokenizer is None: tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
    if (tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence): raise ValueError

    single = mel.ndim == 2

    if single: mel = mel.unsqueeze(0)
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): mel = model.encoder(mel)

    n_audio = mel.shape[0]
    logits = model.logits(torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device), mel)[:, 0]

    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False

    logits[:, mask] = -np.inf

    language_tokens = logits.argmax(dim=-1)
    language_probs = [{c: logits.softmax(dim=-1).cpu()[i, j].item() for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)} for i in range(n_audio)]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs

@lru_cache(maxsize=None)
def get_tokenizer(multilingual, *, num_languages=99, language=None, task=None):
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE: language = TO_LANGUAGE_CODE[language]
            else: raise ValueError

    if multilingual:
        encoding_name = "multilingual"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None

    return Tokenizer(encoding_name=encoding_name, num_languages=num_languages, language=language, task=task)

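# Builds the tiktoken encoding from the bundled *.tiktoken rank file and appends
# Whisper's special tokens (language tags, task tags, and 1501 timestamp tokens in
# 0.02-second steps).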
@lru_cache(maxsize=None)
def get_encoding(name="gpt2", num_languages=99):
    vocab_path = os.path.join("assets", "models", "speaker_diarization", "assets", f"{name}.tiktoken")
    ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in open(vocab_path) if line)}

    n_vocab = len(ranks)
    special_tokens = {}

    specials = ["<|endoftext|>", "<|startoftranscript|>", *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], "<|translate|>", "<|transcribe|>", "<|startoflm|>", "<|startofprev|>", "<|nospeech|>", "<|notimestamps|>", *[f"<|{i * 0.02:.2f}|>" for i in range(1501)]]

    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(name=os.path.basename(vocab_path), explicit_n_vocab=n_vocab, pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", mergeable_ranks=ranks, special_tokens=special_tokens)

@dataclass
class DecodingOptions:
    # Declared as a dataclass so dataclasses.replace() in decode_function works.
    task: str = "transcribe"
    language: str = None
    temperature: float = 0.0
    sample_len: int = None
    best_of: int = None
    beam_size: int = None
    patience: float = None
    length_penalty: float = None
    prompt: list = None
    prefix: str = None
    suppress_tokens: str = "-1"
    suppress_blank: bool = True
    without_timestamps: bool = False
    max_initial_timestamp: float = 1.0
    fp16: bool = False

@torch.no_grad()
def decode_function(model, mel, options=DecodingOptions(), **kwargs):
    if single := mel.ndim == 2: mel = mel.unsqueeze(0)
    if kwargs: options = replace(options, **kwargs)

    result = DecodingTask(model, options).run(mel)
    return result[0] if single else result

@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int

class LayerNorm(nn.LayerNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    def forward(self, x):
        return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))


class Conv1d(nn.Conv1d):
    def _conv_forward(self, x, weight, bias):
        return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))

class TextDecoder(nn.Module):
    def __init__(self, n_vocab, n_ctx, n_state, n_head, n_layer):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks = nn.ModuleList([ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)])
        self.ln = LayerNorm(n_state)
        self.register_buffer("mask", torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1), persistent=False)

    def forward(self, x, xa, kv_cache=None):
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]).to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        return (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()

class AudioEncoder(nn.Module):
    def __init__(self, n_mels, n_ctx, n_state, n_head, n_layer):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks = nn.ModuleList([ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])
        self.ln_post = LayerNorm(n_state)

    def forward(self, x):
        x = F.gelu(self.conv2(F.gelu(self.conv1(x)))).permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        return self.ln_post(x)

class Whisper(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(self.dims.n_mels, self.dims.n_audio_ctx, self.dims.n_audio_state, self.dims.n_audio_head, self.dims.n_audio_layer)
        self.decoder = TextDecoder(self.dims.n_vocab, self.dims.n_text_ctx, self.dims.n_text_state, self.dims.n_text_head, self.dims.n_text_layer)

        all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump):
        self.register_buffer("alignment_heads", torch.from_numpy(np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()).reshape(self.dims.n_text_layer, self.dims.n_text_head).to_sparse(), persistent=False)

    def embed_audio(self, mel):
        return self.encoder(mel)

    def logits(self, tokens, audio_features):
        return self.decoder(tokens, audio_features)

    def forward(self, mel, tokens):
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

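    # Caches the key/value projections of every decoder attention layer via forward
    # hooks so subsequent decoding steps only process the newly generated token.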
    def install_kv_cache_hooks(self, cache=None):
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            cache[module] = output if module not in cache or output.shape[1] > self.dims.n_text_ctx else torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function

class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (MultiHeadAttention(n_state, n_head) if cross_attention else None)
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(self, x, xa=None, mask=None, kv_cache=None):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]

        return x + self.mlp(self.mlp_ln(x))

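# Note: the explicit (non-SDPA) branch below mirrors the reference Whisper attention
# and is what allows disable_sdpa()/find_alignment() to read back the QK weights;
# with use_sdpa enabled, qk is returned as None.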
class MultiHeadAttention(nn.Module):
    use_sdpa = True

    def __init__(self, n_state, n_head):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(self, x, xa=None, mask=None, kv_cache=None):
        k, v = (self.key(x if xa is None else xa), self.value(x if xa is None else xa)) if kv_cache is None or xa is None or self.key not in kv_cache else (kv_cache[self.key], kv_cache[self.value])
        wv, qk = self.qkv_attention(self.query(x), k, v, mask)

        return self.out(wv), qk

    def qkv_attention(self, q, k, v, mask=None):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25

        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        if MultiHeadAttention.use_sdpa:
            wv = scaled_dot_product_attention(q, k, v, is_causal=mask is not None and n_ctx > 1)
            qk = None
        else:
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
            if mask is not None: qk = qk + mask[:n_ctx, :n_ctx]
            qk = qk.float()

            wv = F.softmax(qk, dim=-1).to(q.dtype) @ v
            qk = qk.detach()

        return wv.permute(0, 2, 1, 3).flatten(start_dim=2), qk

class LogitFilter:
    def apply(self, logits, tokens):
        pass

class SuppressBlank(LogitFilter):
    def __init__(self, tokenizer, sample_begin):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits, tokens):
        if tokens.shape[1] == self.sample_begin: logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf


class SuppressTokens(LogitFilter):
    def __init__(self, suppress_tokens):
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits, tokens):
        logits[:, self.suppress_tokens] = -np.inf

class Inference:
    def logits(self, tokens, audio_features):
        pass

    def rearrange_kv_cache(self, source_indices):
        pass

    def cleanup_caching(self):
        pass

class PyTorchInference(Inference):
    def __init__(self, model, initial_token_length):
        self.model = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []

        self.kv_modules = [block.attn.key for block in self.model.decoder.blocks] + [block.attn.value for block in self.model.decoder.blocks]

    def logits(self, tokens, audio_features):
        if not self.kv_cache: self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
        if tokens.shape[-1] > self.initial_token_length: tokens = tokens[:, -1:]

        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)

    def cleanup_caching(self):
        for hook in self.hooks:
            hook.remove()

        self.kv_cache = {}
        self.hooks = []

    def rearrange_kv_cache(self, source_indices):
        if source_indices != list(range(len(source_indices))):
            for module in self.kv_modules:
                self.kv_cache[module] = self.kv_cache[module][source_indices].detach()

class SequenceRanker:
    def rank(self, tokens, sum_logprobs):
        pass


class MaximumLikelihoodRanker(SequenceRanker):
    def __init__(self, length_penalty):
        self.length_penalty = length_penalty

    def rank(self, tokens, sum_logprobs):
        def scores(logprobs, lengths):
            result = []
            for logprob, length in zip(logprobs, lengths):
                result.append(logprob / (length if self.length_penalty is None else ((5 + length) / 6) ** self.length_penalty))
            return result

        return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, [[len(t) for t in s] for s in tokens])]

class TokenDecoder:
    def reset(self):
        pass

    def update(self, tokens, logits, sum_logprobs):
        pass

    def finalize(self, tokens, sum_logprobs):
        pass

class GreedyDecoder(TokenDecoder):
    def __init__(self, temperature, eot):
        self.temperature = temperature
        self.eot = eot

    def update(self, tokens, logits, sum_logprobs):
        next_tokens = logits.argmax(dim=-1) if self.temperature == 0 else Categorical(logits=logits / self.temperature).sample()

        logprobs = F.log_softmax(logits.float(), dim=-1)
        sum_logprobs += logprobs[torch.arange(logprobs.shape[0]), next_tokens] * (tokens[:, -1] != self.eot)

        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

        return tokens, (tokens[:, -1] == self.eot).all()

    def finalize(self, tokens, sum_logprobs):
        return F.pad(tokens, (0, 1), value=self.eot), sum_logprobs.tolist()

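# Beam search: keeps `beam_size` live hypotheses per audio input, collects finished
# (EOT-terminated) sequences up to round(beam_size * patience) candidates, and
# reorders the KV cache to follow the surviving beams.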
class BeamSearchDecoder(TokenDecoder):
    def __init__(self, beam_size, eot, inference, patience=None):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        self.patience = patience or 1.0
        self.max_candidates = round(beam_size * self.patience)
        self.finished_sequences = None

        assert self.max_candidates > 0

    def reset(self):
        self.finished_sequences = None

    def update(self, tokens, logits, sum_logprobs):
        if tokens.shape[0] % self.beam_size != 0: raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None: self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []

        for i in range(n_audio):
            scores, sources, finished = {}, {}, {}

            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = (sum_logprobs[idx] + logprob).item()
                    sources[sequence] = idx

            saved = 0

            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot: finished[sequence] = scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size: break

            finished_sequences.append(finished)

        self.inference.rearrange_kv_cache(source_indices)
        assert len(self.finished_sequences) == len(finished_sequences)

        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates: break
                previously_finished[seq] = newly_finished[seq]

        return torch.tensor(next_tokens, device=tokens.device), all(len(sequences) >= self.max_candidates for sequences in self.finished_sequences)

    def finalize(self, preceding_tokens, sum_logprobs):
        sum_logprobs = sum_logprobs.cpu()

        for i, sequences in enumerate(self.finished_sequences):
            if len(sequences) < self.beam_size:
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size: break

        return [[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences], [list(sequences.values()) for sequences in self.finished_sequences]

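# Forces sampled timestamp tokens to appear in valid, non-decreasing pairs and hands
# the step over to a timestamp whenever timestamps are jointly more probable than
# any text token.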
class ApplyTimestampRules(LogitFilter):
    def __init__(self, tokenizer, sample_begin, max_initial_timestamp_index):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits, tokens):
        if self.tokenizer.no_timestamps is not None: logits[:, self.tokenizer.no_timestamps] = -np.inf

        for k in range(tokens.shape[0]):
            sampled_tokens = tokens[k, self.sample_begin :]
            seq = [t for t in sampled_tokens.tolist()]

            last_was_timestamp = (len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin)
            penultimate_was_timestamp = (len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin)

            if last_was_timestamp:
                if penultimate_was_timestamp: logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else: logits[k, : self.tokenizer.eot] = -np.inf

            timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]

            if timestamps.numel() > 0: logits[k, self.tokenizer.timestamp_begin : timestamps[-1] if last_was_timestamp and not penultimate_was_timestamp else (timestamps[-1] + 1)] = -np.inf

        if tokens.shape[1] == self.sample_begin:
            logits[:, : self.tokenizer.timestamp_begin] = -np.inf

            if self.max_initial_timestamp_index is not None:
                last_allowed = (self.tokenizer.timestamp_begin + self.max_initial_timestamp_index)
                logits[:, last_allowed + 1 :] = -np.inf

        logprobs = F.log_softmax(logits.float(), dim=-1)
        for k in range(tokens.shape[0]):
            if logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1) > logprobs[k, : self.tokenizer.timestamp_begin].max(): logits[k, : self.tokenizer.timestamp_begin] = -np.inf

class DecodingTask:
    def __init__(self, model, options):
        self.model = model

        language = options.language or "en"
        tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages, language=language, task=options.task)

        self.tokenizer = tokenizer
        self.options = self._verify_options(options)

        self.n_group = options.beam_size or options.best_of or 1
        self.n_ctx = model.dims.n_text_ctx
        self.sample_len = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence = tokenizer.sot_sequence
        if self.options.without_timestamps: self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens = self._get_initial_tokens()
        self.sample_begin = len(self.initial_tokens)
        self.sot_index = self.initial_tokens.index(tokenizer.sot)
        self.inference = PyTorchInference(model, len(self.initial_tokens))
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
        self.decoder = BeamSearchDecoder(options.beam_size, tokenizer.eot, self.inference, options.patience) if options.beam_size is not None else GreedyDecoder(options.temperature, tokenizer.eot)

        self.logit_filters = []

        if self.options.suppress_blank: self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens: self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))

        if not options.without_timestamps:
            max_initial_timestamp_index = None
            if options.max_initial_timestamp: max_initial_timestamp_index = round(self.options.max_initial_timestamp / (CHUNK_LENGTH / model.dims.n_audio_ctx))
            self.logit_filters.append(ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index))

    def _verify_options(self, options):
        if options.beam_size is not None and options.best_of is not None: raise ValueError
        if options.temperature == 0 and options.best_of is not None: raise ValueError
        if options.patience is not None and options.beam_size is None: raise ValueError
        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1): raise ValueError

        return options

    def _get_initial_tokens(self):
        tokens = list(self.sot_sequence)

        if prefix := self.options.prefix:
            prefix_tokens = (self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix)
            if self.sample_len is not None: prefix_tokens = prefix_tokens[-(self.n_ctx // 2 - self.sample_len):]
            tokens = tokens + prefix_tokens

        if prompt := self.options.prompt: tokens = ([self.tokenizer.sot_prev] + (self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt)[-(self.n_ctx // 2 - 1) :] + tokens)

        return tuple(tokens)

    def _get_suppress_tokens(self):
        suppress_tokens = self.options.suppress_tokens
        if isinstance(suppress_tokens, str): suppress_tokens = [int(t) for t in suppress_tokens.split(",")]

        if -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        elif suppress_tokens is None or len(suppress_tokens) == 0: suppress_tokens = []
        else: assert isinstance(suppress_tokens, list)

        suppress_tokens.extend([self.tokenizer.transcribe, self.tokenizer.translate, self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm])

        if self.tokenizer.no_speech is not None: suppress_tokens.append(self.tokenizer.no_speech)
        return tuple(sorted(set(suppress_tokens)))

    def _get_audio_features(self, mel):
        if self.options.fp16: mel = mel.half()

        audio_features = mel if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state) else self.model.encoder(mel)
        if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32): raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")

        return audio_features

    def _detect_language(self, audio_features, tokens):
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
            languages = [max(probs, key=probs.get) for probs in lang_probs]

            if self.options.language is None: tokens[:, self.sot_index + 1] = lang_tokens

        return languages, lang_probs

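    # Autoregressive sampling loop: one decoder forward pass per step, with the logit
    # filters applied before the token decoder picks the next tokens.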
    def _main_loop(self, audio_features, tokens):
        n_batch = tokens.shape[0]
        sum_logprobs = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if (i == 0 and self.tokenizer.no_speech is not None):
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                logits = logits[:, -1]
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
                if completed or tokens.shape[-1] > self.n_ctx: break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs

    @torch.no_grad()
    def run(self, mel):
        self.decoder.reset()
        tokenizer = self.tokenizer
        n_audio = mel.shape[0]

        audio_features = self._get_audio_features(mel)
        tokens = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id": return [DecodingResult(audio_features=features, language=language, language_probs=probs) for features, language, probs in zip(audio_features, languages, language_probs)]

        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]

        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens = [[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens]

        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens = [t[i].tolist() for i, t in zip(selected, tokens)]

        fields = ([tokenizer.decode(t).strip() for t in tokens], languages, tokens, audio_features, [lp / (len(t) + 1) for t, lp in zip(tokens, [lp[i] for i, lp in zip(selected, sum_logprobs)])], no_speech_probs)
        if len(set(map(len, fields))) != 1: raise RuntimeError

        return [DecodingResult(audio_features=features, language=language, tokens=tokens, text=text, avg_logprob=avg_logprob, no_speech_prob=no_speech_prob, temperature=self.options.temperature, compression_ratio=compression_ratio(text)) for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)]

class DecodingResult:
    def __init__(self, audio_features, language, language_probs=None, tokens=None, text="", avg_logprob=np.nan, no_speech_prob=np.nan, temperature=np.nan, compression_ratio=np.nan):
        self.audio_features = audio_features
        self.language = language
        self.language_probs = language_probs if language_probs is not None else {}
        self.tokens = tokens if tokens is not None else []
        self.text = text
        self.avg_logprob = avg_logprob
        self.no_speech_prob = no_speech_prob
        self.temperature = temperature
        self.compression_ratio = compression_ratio

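# Thin wrapper around the tiktoken encoding: resolves Whisper's special tokens once
# and builds the <|startoftranscript|> [<|lang|>] [<|transcribe|>/<|translate|>]
# prompt sequence.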
class Tokenizer:
    def __init__(self, encoding_name, num_languages=2, language=None, task=None, sot_sequence=()):
        self.encoding = get_encoding(name=encoding_name, num_languages=num_languages)
        self.num_languages = num_languages
        self.language = language
        self.task = task
        self.sot_sequence = sot_sequence
        self.special_tokens = {}

        for special in self.encoding.special_tokens_set:
            special_token = self.encoding.encode_single_token(special)
            self.special_tokens[special] = special_token

        sot = self.special_tokens["<|startoftranscript|>"]
        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        sot_sequence = [sot]

        if self.language is not None: sot_sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None: sot_sequence.append(self.special_tokens["<|transcribe|>"] if self.task == "transcribe" else self.special_tokens["<|translate|>"])

        self.sot_sequence = tuple(sot_sequence)

    def encode(self, text, **kwargs):
        return self.encoding.encode(text, **kwargs)

    def decode(self, token_ids, **kwargs):
        return self.encoding.decode([t for t in token_ids if t < self.timestamp_begin], **kwargs)

    def decode_with_timestamps(self, token_ids, **kwargs):
        return self.encoding.decode(token_ids, **kwargs)

    @cached_property
    def eot(self):
        return self.encoding.eot_token

    @cached_property
    def transcribe(self):
        return self.special_tokens["<|transcribe|>"]

    @cached_property
    def translate(self):
        return self.special_tokens["<|translate|>"]

    @cached_property
    def sot(self):
        return self.special_tokens["<|startoftranscript|>"]

    @cached_property
    def sot_lm(self):
        return self.special_tokens["<|startoflm|>"]

    @cached_property
    def sot_prev(self):
        return self.special_tokens["<|startofprev|>"]

    @cached_property
    def no_speech(self):
        return self.special_tokens["<|nospeech|>"]

    @cached_property
    def no_timestamps(self):
        return self.special_tokens["<|notimestamps|>"]

    @cached_property
    def timestamp_begin(self):
        return self.special_tokens["<|0.00|>"]

    @cached_property
    def language_token(self):
        if self.language is None: raise ValueError
        return self.to_language_token(self.language)

    def to_language_token(self, language):
        if token := self.special_tokens.get(f"<|{language}|>", None): return token
        raise KeyError

    @cached_property
    def all_language_tokens(self):
        result = []
        for token, token_id in self.special_tokens.items():
            if token.strip("<|>") in LANGUAGES: result.append(token_id)

        return tuple(result)[: self.num_languages]

    @cached_property
    def all_language_codes(self):
        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)

    @cached_property
    def sot_sequence_including_notimestamps(self):
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @cached_property
    def non_speech_tokens(self):
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()

        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [self.encoding.encode(symbol), self.encoding.encode(" " + symbol)]:
                if len(tokens) == 1 or symbol in miscellaneous: result.add(tokens[0])

        return tuple(sorted(result))

    def split_to_word_tokens(self, tokens):
        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: return self.split_tokens_on_unicode(tokens)
        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(self, tokens):
        replacement_char = "\ufffd"

        words, word_tokens, current_tokens = [], [], []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            if (replacement_char not in decoded or self.decode_with_timestamps(tokens)[unicode_offset + decoded.index(replacement_char)] == replacement_char):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens

    def split_tokens_on_spaces(self, tokens):
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words, word_tokens = [], []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            if (subword_tokens[0] >= self.eot) or (subword.startswith(" ")) or (subword.strip() in string.punctuation) or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens
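

# Minimal usage sketch (hypothetical paths; assumes the "base" checkpoint and
# mel_filters.npz are present under assets/models/speaker_diarization/ and that
# "example.wav" is an audio file readable by load_audio at 16 kHz):
#
#     model = load_model("base", device="cpu")
#     result = model.transcribe("example.wav", word_timestamps=True, verbose=True)
#     for segment in result["segments"]:
#         print(segment["start"], segment["end"], segment["text"])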