Spaces: Runtime error
tonic committed
Commit 33d9042 · Parent(s): 89d01e6
Laion WhisperSpeech Demo
- README.md +3 -3
- app.py +61 -0
- requirements.txt +3 -1
- whisperspeech/__init__.py +1 -0
- whisperspeech/_modidx.py +615 -0
- whisperspeech/a2wav.py +45 -0
- whisperspeech/extract_acoustic.py +56 -0
- whisperspeech/fetch_models.py +17 -0
- whisperspeech/languages.py +131 -0
- whisperspeech/modules.py +331 -0
- whisperspeech/pipeline.py +93 -0
- whisperspeech/prepare_s2a_dataset.py +112 -0
- whisperspeech/prepare_t2s_dataset.py +111 -0
- whisperspeech/s2a_delar_mup_wds.py +688 -0
- whisperspeech/s2a_delar_mup_wds_mlang.py +564 -0
- whisperspeech/t2s_up_wds.py +442 -0
- whisperspeech/t2s_up_wds_mlang_enclm.py +519 -0
- whisperspeech/train.py +271 -0
- whisperspeech/train_multi.py +263 -0
- whisperspeech/utils.py +159 -0
- whisperspeech/vad.py +71 -0
- whisperspeech/vq_stoks.py +493 -0
- whisperspeech/wer_metrics.py +77 -0
- whisperspeech/wh_transcribe.py +146 -0
README.md
CHANGED
@@ -1,12 +1,12 @@
 ---
-title:
-emoji:
+title: WhisperSpeech
+emoji: 🌬️💬📝
 colorFrom: pink
 colorTo: gray
 sdk: gradio
 sdk_version: 4.15.0
 app_file: app.py
-pinned:
+pinned: True
 license: mit
 ---
app.py
ADDED
@@ -0,0 +1,61 @@
import spaces
import gradio as gr
import os
import tempfile
import torch
import torch.nn.functional as F
import torchaudio
from whisperspeech.languages import LANGUAGES
from whisperspeech.pipeline import Pipeline

title = """#🙋🏻‍♂️ Welcome to🌟Tonic's🌬️💬📝WhisperSpeech
You can use this ZeroGPU Space to test out the current model [🌬️💬📝collabora/whisperspeech](https://huggingface.co/collabora/whisperspeech). 🌬️💬📝collabora/whisperspeech is an open-source text-to-speech system built by inverting Whisper, previously known as spear-tts-pytorch. It's like Stable Diffusion but for speech: both powerful and easily customizable.
You can also use 🌬️💬📝WhisperSpeech by cloning this space. 🧬🔬 Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/laion-whisper?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r059973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>
Join us: 🌟TeamTonic🌟 is always making cool demos! Join our active builders' 🛠️ community 💻 [](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟 [Poly](https://github.com/tonic-ai/poly) 🤗 Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""

@spaces.GPU
def whisper_speech_demo(text, lang, speaker_audio=None, mix_lang=None, mix_text=None):
    pipe = Pipeline(s2a_ref='collabora/whisperspeech:s2a-q4-tiny-en+pl.model')

    # Use the uploaded speaker audio if provided
    # (gr.File with type="filepath" passes the uploaded file's path as a string)
    speaker_url = None
    if speaker_audio is not None:
        speaker_url = speaker_audio

    if mix_lang and mix_text:
        # Cross-lingual generation: run the text-to-semantic stage on all language/text pairs
        mixed_langs = lang.split(',') + mix_lang.split(',')
        mixed_texts = [text] + mix_text.split(',')
        stoks = pipe.t2s.generate(mixed_texts, lang=mixed_langs)
        audio_data = pipe.generate(stoks, speaker_url, lang=mixed_langs[0])
    else:
        audio_data = pipe.generate(text, speaker_url, lang)

    # The pipeline returns an audio tensor; save it as a 24 kHz wav
    # (the rate used by the vocoder in whisperspeech/a2wav.py).
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
        tmp_file_name = tmp_file.name
    torchaudio.save(tmp_file_name, audio_data.cpu(), 24000)

    return tmp_file_name

with gr.Blocks() as demo:
    gr.Markdown(title)
    with gr.Row():
        text_input = gr.Textbox(label="Enter text")
        lang_input = gr.Dropdown(choices=list(LANGUAGES.keys()), label="Language")
        # gradio 4.x: gr.File takes type in {"filepath", "binary"} and file_types
        speaker_input = gr.File(label="Upload Speaker Audio (optional)", type="filepath", file_types=["audio"])
    with gr.Row():
        mix_lang_input = gr.Textbox(label="Mixed Languages (optional, comma-separated)", placeholder="e.g., en,pl")
        mix_text_input = gr.Textbox(label="Mixed Texts (optional, for mixed languages)", placeholder="e.g., Hello, Cześć")
    with gr.Row():
        submit_button = gr.Button("Generate Speech")
        output_audio = gr.Audio(label="Generated Speech")

    submit_button.click(
        whisper_speech_demo,
        inputs=[text_input, lang_input, speaker_input, mix_lang_input, mix_text_input],
        outputs=output_audio,
    )

demo.launch()
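For readers who want to exercise the model without the Gradio UI, here is a minimal sketch that mirrors the handler above. It assumes pipe.generate returns an audio tensor and that 24 kHz (the rate used by the vocoder in whisperspeech/a2wav.py) is the right sample rate to save at; the output file name is a placeholder.

import torchaudio
from whisperspeech.pipeline import Pipeline

pipe = Pipeline(s2a_ref='collabora/whisperspeech:s2a-q4-tiny-en+pl.model')
# Same call pattern as whisper_speech_demo: (text, speaker, lang); None means the default voice
audio = pipe.generate("Hello from WhisperSpeech.", None, 'en')
torchaudio.save('hello.wav', audio.cpu(), 24000)  # placeholder output path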
requirements.txt
CHANGED
@@ -1 +1,3 @@
-
+torch
+transformers
+accelerate
whisperspeech/__init__.py
ADDED
@@ -0,0 +1 @@
__version__ = "0.5.6"
whisperspeech/_modidx.py
ADDED
@@ -0,0 +1,615 @@
# Autogenerated by nbdev

d = { 'settings': { 'branch': 'master',
                    'doc_baseurl': '/WhisperSpeech',
                    'doc_host': 'https://collabora.github.io',
                    'git_url': 'https://github.com/collabora/WhisperSpeech',
                    'lib_path': 'whisperspeech'},
      'syms': { 'whisperspeech.a2wav': { ... },               # per-symbol links into '6. quality-boosting vocoder.html'
                'whisperspeech.extract_acoustic': { ... },    # '1. acoustic token extraction.html'
                'whisperspeech.extract_semb': { ... },        # '2c. whisper semantic embedding extraction.html'
                'whisperspeech.fetch_models': { ... },        # '0. download models.html'
                'whisperspeech.modules': { ... },             # 'a. neural modules.html'
                'whisperspeech.pipeline': { ... },            # '7. pipeline.html'
                'whisperspeech.prepare_s2a_dataset': { ... }, # '4a. s2a dataset preparation.html'
                'whisperspeech.prepare_t2s_dataset': { ... }, # '5a. t2s dataset preparation.html'
                'whisperspeech.s2a_delar_mup_wds': { ... },   # '4b. semantic to acoustic token modeling.html'
                'whisperspeech.t2s_up_wds': { ... },          # '5b. text to semantic token modeling.html'
                'whisperspeech.train': { ... },               # 'b1. training.html'
                'whisperspeech.train_multi': { ... },         # 'b2. training (lightning).html'
                'whisperspeech.vad': { ... },                 # '1b. voice activity detection.html'
                'whisperspeech.verify_wds': { ... },          # '0. verify webdataset archives.html'
                'whisperspeech.vq_stoks': { ... },            # '2b. whisper quantization (semantic token) model.html'
                'whisperspeech.wer_metrics': { ... },         # 'c. word error rate metrics.html'
                'whisperspeech.wh_transcribe': { ... }}}      # '2a. whisper quantization dataset preparation.html'
whisperspeech/a2wav.py
ADDED
@@ -0,0 +1,45 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/6. Quality-boosting vocoder.ipynb.

# %% auto 0
__all__ = ['Vocoder']

# %% ../nbs/6. Quality-boosting vocoder.ipynb 1
from vocos import Vocos
import torch
import torchaudio

# %% ../nbs/6. Quality-boosting vocoder.ipynb 2
class Vocoder:
    def __init__(self, repo_id="charactr/vocos-encodec-24khz"):
        self.vocos = Vocos.from_pretrained(repo_id).cuda()

    def is_notebook(self):
        try:
            return get_ipython().__class__.__name__ == "ZMQInteractiveShell"
        except:
            return False

    @torch.no_grad()
    def decode(self, atoks):
        if len(atoks.shape) == 3:
            b,q,t = atoks.shape
            atoks = atoks.permute(1,0,2)
        else:
            q,t = atoks.shape

        features = self.vocos.codes_to_features(atoks)
        bandwidth_id = torch.tensor({2:0,4:1,8:2}[q]).cuda()
        return self.vocos.decode(features, bandwidth_id=bandwidth_id)

    def decode_to_file(self, fname, atoks):
        audio = self.decode(atoks)
        torchaudio.save(fname, audio.cpu(), 24000)
        if self.is_notebook():
            from IPython.display import display, HTML, Audio
            display(HTML(f'<a href="{fname}" target="_blank">Listen to {fname}</a>'))

    def decode_to_notebook(self, atoks):
        from IPython.display import display, HTML, Audio

        audio = self.decode(atoks)
        display(Audio(audio.cpu().numpy(), rate=24000))
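A hedged usage sketch for the Vocoder above: the token tensor is assumed to come from a .encodec file written by whisperspeech/extract_acoustic.py (added later in this commit), and a CUDA GPU is assumed because the constructor moves Vocos to .cuda(). File names are placeholders.

import torch
from whisperspeech.a2wav import Vocoder

vocoder = Vocoder()                          # defaults to charactr/vocos-encodec-24khz
atoks = torch.load('sample.encodec')         # placeholder path: EnCodec codes saved by extract_acoustic
vocoder.decode_to_file('sample.wav', atoks)  # decodes and writes a 24 kHz wav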
whisperspeech/extract_acoustic.py
ADDED
@@ -0,0 +1,56 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/1. Acoustic token extraction.ipynb.

# %% auto 0
__all__ = ['load', 'load_model', 'extract_Atoks', 'extract_acoustic']

# %% ../nbs/1. Acoustic token extraction.ipynb 2
import torch
import torchaudio
import gc

from pathlib import Path
from fastcore.script import *
from fastprogress import progress_bar, master_bar

# %% ../nbs/1. Acoustic token extraction.ipynb 5
def load(fname, newsr=24000):
    """Load an audio file to the GPU and resample to `newsr`."""
    x, sr = torchaudio.load(fname)
    _tform = torchaudio.transforms.Resample(sr, newsr)
    return _tform(x).cuda().unsqueeze(0)

# %% ../nbs/1. Acoustic token extraction.ipynb 6
def load_model():
    "Load the pretrained EnCodec model"
    from encodec.model import EncodecModel
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(1.5)
    model.cuda().eval();
    return model

# %% ../nbs/1. Acoustic token extraction.ipynb 7
def extract_Atoks(model, audio):
    """Extract EnCodec tokens for the given `audio` tensor (or file path)
    using the given `model` (see `load_model`)."""
    if isinstance(audio, (Path, str)):
        audio = load(audio)
    with torch.no_grad():
        frames = torch.cat([model.encode(segment)[0][0]
                            for segment in torch.split(audio, 320*20000, dim=-1)], dim=-1)
    return frames

# %% ../nbs/1. Acoustic token extraction.ipynb 8
@call_parse
def extract_acoustic(
        srcdir:Path,  # source dir, should contain *.flac files
        outdir:Path,  # output dir, will get the *.encodec files
    ):
    "Convert audio files to .encodec files with tensors of tokens"
    model = load_model()
    outdir.mkdir(exist_ok=True, parents=True)
    for name in progress_bar(list(srcdir.rglob('*.flac'))):
        outname = outdir/name.with_suffix('.encodec').name
        tokens = extract_Atoks(model, name)
        torch.save(tokens, outname)
        del tokens
        gc.collect()
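A quick sketch of using these helpers directly from Python; file names are placeholders, and it mirrors what the extract_acoustic entry point does for a whole directory of .flac files.

import torch
from whisperspeech.extract_acoustic import load_model, extract_Atoks

model = load_model()                         # EnCodec 24 kHz, target bandwidth 1.5 kbps
atoks = extract_Atoks(model, 'speech.flac')  # placeholder path; accepts a path or an audio tensor
torch.save(atoks, 'speech.encodec')          # the same format the directory-level CLI writes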
whisperspeech/fetch_models.py
ADDED
@@ -0,0 +1,17 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/0. Download models.ipynb.

# %% auto 0
__all__ = []

# %% ../nbs/0. Download models.ipynb 1
from fastcore.script import call_parse
import whisperx
import whisper

# %% ../nbs/0. Download models.ipynb 3
@call_parse
def main():
    whisper.load_model('base.en')
    whisper.load_model('small.en')
    whisperx.vad.load_vad_model('cpu')
    whisperx.asr.load_model('medium.en', "cpu", compute_type="float16", language='en')
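Because main is wrapped with fastcore's @call_parse, it is normally invoked as a command-line entry point; a hedged sketch of triggering the same downloads from a plain Python script (assuming the wrapped, zero-argument function can be called directly, which is how fastcore-wrapped scripts typically behave):

from whisperspeech import fetch_models

# Pre-downloads the whisper base.en/small.en checkpoints plus the whisperx VAD
# and ASR models so the dataset-preparation steps do not hit the network later.
fetch_models.main()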
whisperspeech/languages.py
ADDED
@@ -0,0 +1,131 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B. Languages.ipynb.

# %% auto 0
__all__ = ['to_id']

# %% ../nbs/B. Languages.ipynb 3
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}

# %% ../nbs/B. Languages.ipynb 4
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}

# %% ../nbs/B. Languages.ipynb 5
languages = tuple(LANGUAGES.keys())

# %% ../nbs/B. Languages.ipynb 6
def to_id(lang):
    return languages.index(TO_LANGUAGE_CODE.get(lang, lang))
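A quick illustration of the lookup helpers above; indices simply follow the key order of LANGUAGES:

from whisperspeech.languages import LANGUAGES, to_id

to_id('en')        # 0, since 'en' is the first key in LANGUAGES
to_id('polish')    # same as to_id('pl'), via the TO_LANGUAGE_CODE aliases
LANGUAGES['pl']    # 'polish'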
whisperspeech/modules.py
ADDED
|
@@ -0,0 +1,331 @@
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/A. Neural modules.ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = ['LayerNorm', 'LinearHead', 'QueryHead', 'init_transformer', 'sinusoids', 'MultiHeadAttention',
|
| 5 |
+
'ResidualAttentionBlock', 'BaseDecoder', 'EmbeddingProjector', 'FlexEmbeddings']
|
| 6 |
+
|
| 7 |
+
# %% ../nbs/A. Neural modules.ipynb 2
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
from torch import Tensor, nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from typing import Dict, Iterable, Optional
|
| 15 |
+
|
| 16 |
+
# import xformers.ops as xops
|
| 17 |
+
|
| 18 |
+
# %% ../nbs/A. Neural modules.ipynb 3
|
| 19 |
+
# Code in this file is mostly borrowed from
|
| 20 |
+
# https://github.com/openai/whisper/blob/main/whisper/model.py
|
| 21 |
+
# and is under the MIT License
|
| 22 |
+
|
| 23 |
+
class LayerNorm(nn.LayerNorm):
|
| 24 |
+
def forward(self, x):
|
| 25 |
+
return super().forward(x.float()).type(x.dtype)
|
| 26 |
+
|
| 27 |
+
# Used in ΞΌP to initialize the weights and configure the optimizer
|
| 28 |
+
# These two layers map the transformer width into a fixed dimension
|
| 29 |
+
class LinearHead(nn.Linear):
|
| 30 |
+
pass
|
| 31 |
+
|
| 32 |
+
class QueryHead(nn.Linear):
|
| 33 |
+
pass
|
| 34 |
+
|
| 35 |
+
# based on https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L163
|
| 36 |
+
def init_transformer(m):
|
| 37 |
+
if isinstance(m, (nn.Linear, nn.Embedding)):
|
| 38 |
+
torch.nn.init.trunc_normal_(m.weight, std=.02)
|
| 39 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 40 |
+
torch.nn.init.constant_(m.bias, 0)
|
| 41 |
+
elif isinstance(m, nn.LayerNorm):
|
| 42 |
+
torch.nn.init.constant_(m.bias, 0)
|
| 43 |
+
torch.nn.init.constant_(m.weight, 1.0)
|
| 44 |
+
|
| 45 |
+
# %% ../nbs/A. Neural modules.ipynb 4
|
| 46 |
+
def sinusoids(length, channels, max_timescale=10000):
|
| 47 |
+
"""Returns sinusoids for positional embedding"""
|
| 48 |
+
assert channels % 2 == 0
|
| 49 |
+
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
| 50 |
+
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
| 51 |
+
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
| 52 |
+
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
| 53 |
+
|
| 54 |
+
# %% ../nbs/A. Neural modules.ipynb 5
|
| 55 |
+
class MultiHeadAttention(nn.Module):
|
| 56 |
+
def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False, cross=False):
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.n_state = n_state
|
| 59 |
+
self.n_head = n_head
|
| 60 |
+
self.sqrt_qk_scale = math.sqrt(qk_scale)
|
| 61 |
+
self.query = QueryHead(n_state, n_state)
|
| 62 |
+
self.key = nn.Linear(n_state, n_state, bias=False)
|
| 63 |
+
self.value = nn.Linear(n_state, n_state)
|
| 64 |
+
self.out = nn.Linear(n_state, n_state)
|
| 65 |
+
self.cross = cross
|
| 66 |
+
self.query_subsampling = 1
|
| 67 |
+
self.key_subsampling = 1
|
| 68 |
+
|
| 69 |
+
self.cached_kvx = None
|
| 70 |
+
self.register_buffer('k_cache', None)
|
| 71 |
+
self.register_buffer('v_cache', None)
|
| 72 |
+
|
| 73 |
+
self.rotary = None
|
| 74 |
+
if rope:
|
| 75 |
+
self.rotary = Rotary(n_state // n_head)
|
| 76 |
+
self.qkv = None
|
| 77 |
+
self.kv = None
|
| 78 |
+
|
| 79 |
+
def setup_kv_cache(self, max_batch_size, max_seq_len, dtype=torch.float32):
|
| 80 |
+
cache_shape = (max_batch_size, self.n_head, max_seq_len, self.n_state//self.n_head)
|
| 81 |
+
self.k_cache = torch.zeros(cache_shape, dtype=dtype, device=self.key.weight.device)
|
| 82 |
+
self.v_cache = torch.zeros(cache_shape, dtype=dtype, device=self.value.weight.device)
|
| 83 |
+
|
| 84 |
+
def merge_linears(self, layers, mults):
|
| 85 |
+
bias = [x.bias for x in layers if x.bias is not None][0]
|
| 86 |
+
din, dout = layers[0].weight.shape
|
| 87 |
+
new = nn.Linear(din, len(layers) * dout).to(layers[0].weight.device)
|
| 88 |
+
with torch.no_grad():
|
| 89 |
+
new.weight[:] = torch.cat([x.weight * m for x,m in zip(layers, mults)])
|
| 90 |
+
new.bias[:] = torch.cat([torch.zeros_like(bias) if x.bias is None else x.bias * m for x, m in zip(layers, mults)])
|
| 91 |
+
return new
|
| 92 |
+
|
| 93 |
+
def convert_for_eval(self):
|
| 94 |
+
if self.qkv or self.kv: raise AttributeError("already converted")
|
| 95 |
+
|
| 96 |
+
self.odim = self.key.weight.shape[1]
|
| 97 |
+
if self.cross:
|
| 98 |
+
self.q = self.merge_linears([self.query], [self.sqrt_qk_scale])
|
| 99 |
+
self.kv = self.merge_linears([self.key, self.value],
|
| 100 |
+
[self.sqrt_qk_scale, 1])
|
| 101 |
+
else:
|
| 102 |
+
self.qkv = self.merge_linears([self.query, self.key, self.value],
|
| 103 |
+
[self.sqrt_qk_scale, self.sqrt_qk_scale, 1])
|
| 104 |
+
|
| 105 |
+
def split_heads(self, x, x_positions, rope=False, subsampling=1):
|
| 106 |
+
x = x.view(*x.shape[:2], self.n_head, -1)
|
| 107 |
+
if rope:
|
| 108 |
+
x = rope_rotate(x, x_positions * subsampling, *self.rotary(x))
|
| 109 |
+
return x.permute(0, 2, 1, 3)
|
| 110 |
+
|
| 111 |
+
def forward(
|
| 112 |
+
self,
|
| 113 |
+
qx,
|
| 114 |
+
q_positions,
|
| 115 |
+
kvx,
|
| 116 |
+
kv_positions,
|
| 117 |
+
causal = False,
|
| 118 |
+
mask=None,
|
| 119 |
+
):
|
| 120 |
+
if self.qkv:
|
| 121 |
+
q,k,v = self.qkv(qx).split(self.odim, dim=-1)
|
| 122 |
+
elif self.kv:
|
| 123 |
+
q = self.q(qx)
|
| 124 |
+
k,v = self.kv(kvx).split(self.odim, dim=-1)
|
| 125 |
+
else:
|
| 126 |
+
q,k,v = None,None,None
|
| 127 |
+
|
| 128 |
+
if q is None: q = self.query(qx) * self.sqrt_qk_scale
|
| 129 |
+
q = self.split_heads(q, q_positions, rope = self.rotary, subsampling = self.query_subsampling)
|
| 130 |
+
|
| 131 |
+
if kvx is not self.cached_kvx:
|
| 132 |
+
if k is None: k = self.key(kvx) * self.sqrt_qk_scale
|
| 133 |
+
k = self.split_heads(k, kv_positions, rope = self.rotary, subsampling = self.key_subsampling)
|
| 134 |
+
if v is None: v = self.value(kvx)
|
| 135 |
+
v = self.split_heads(v, kv_positions)
|
| 136 |
+
if self.k_cache is not None:
|
| 137 |
+
self.k_cache[:,:,kv_positions] = k
|
| 138 |
+
self.v_cache[:,:,kv_positions] = v
|
| 139 |
+
|
| 140 |
+
if self.k_cache is not None:
|
| 141 |
+
k, v = self.k_cache, self.v_cache
|
| 142 |
+
|
| 143 |
+
if mask is not None:
|
| 144 |
+
mask = mask[q_positions]
|
| 145 |
+
|
| 146 |
+
wv = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0, is_causal=causal)
|
| 147 |
+
|
| 148 |
+
return self.out(wv.permute(0, 2, 1, 3).flatten(start_dim=2))
|
| 149 |
+
|
| 150 |
+
# %% ../nbs/A. Neural modules.ipynb 6
|
| 151 |
+
# modified from https://blog.eleuther.ai/rotary-embeddings/
|
| 152 |
+
|
| 153 |
+
import torch
|
| 154 |
+
|
| 155 |
+
class Rotary(torch.nn.Module):
|
| 156 |
+
def __init__(self, dim, base=10000):
|
| 157 |
+
super().__init__()
|
| 158 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 159 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 160 |
+
self.seq_len_cached = None
|
| 161 |
+
self.cos_cached = None
|
| 162 |
+
self.sin_cached = None
|
| 163 |
+
|
| 164 |
+
def forward(self, x, seq_dim=1):
|
| 165 |
+
seq_len = x.shape[seq_dim]
|
| 166 |
+
if not self.seq_len_cached or seq_len > self.seq_len_cached:
|
| 167 |
+
self.seq_len_cached = 2500
|
| 168 |
+
# self.seq_len_cached = seq_len
|
| 169 |
+
|
| 170 |
+
t = torch.arange(self.seq_len_cached, device=x.device).type_as(self.inv_freq)
|
| 171 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 172 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
| 173 |
+
self.cos_cached = emb.cos()[None, :, None, :]
|
| 174 |
+
self.sin_cached = emb.sin()[None, :, None, :]
|
| 175 |
+
return self.cos_cached, self.sin_cached
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# rotary pos emb helpers:
|
| 179 |
+
def rotate_half(x):
|
| 180 |
+
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
| 181 |
+
return torch.cat(
|
| 182 |
+
(-x2, x1), dim=len(x.shape)-1
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
def rope_rotate(x, positions, cos, sin):
|
| 186 |
+
return x * cos[:,positions] + rotate_half(x) * sin[:,positions]
|
| 187 |
+
|
| 188 |
+
# %% ../nbs/A. Neural modules.ipynb 7
|
| 189 |
+
class ResidualAttentionBlock(nn.Module):
|
| 190 |
+
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False,
|
| 191 |
+
qk_scale: float = 1, ffn_mult: int = 4):
|
| 192 |
+
super().__init__()
|
| 193 |
+
self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
|
| 194 |
+
self.attn_ln = LayerNorm(n_state)
|
| 195 |
+
|
| 196 |
+
self.cross_attn = (
|
| 197 |
+
MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope, cross=True) if cross_attention else None
|
| 198 |
+
)
|
| 199 |
+
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
| 200 |
+
|
| 201 |
+
n_mlp = n_state * ffn_mult
|
| 202 |
+
self.mlp = nn.Sequential(
|
| 203 |
+
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
|
| 204 |
+
)
|
| 205 |
+
self.mlp_ln = LayerNorm(n_state)
|
| 206 |
+
|
| 207 |
+
def setup_kv_cache(self, max_batch_size, max_seq_len, max_cross_seq_len=None):
|
| 208 |
+
self.attn.setup_kv_cache(max_batch_size, max_seq_len)
|
| 209 |
+
if self.cross_attn:
|
| 210 |
+
self.cross_attn.setup_kv_cache(max_batch_size, max_cross_seq_len)
|
| 211 |
+
|
| 212 |
+
def forward(
|
| 213 |
+
self,
|
| 214 |
+
x: Tensor,
|
| 215 |
+
x_positions: Tensor = None,
|
| 216 |
+
xa: Optional[Tensor] = None,
|
| 217 |
+
xa_positions: Optional[Tensor] = None,
|
| 218 |
+
causal = False,
|
| 219 |
+
mask=None,
|
| 220 |
+
):
|
| 221 |
+
lnx = self.attn_ln(x)
|
| 222 |
+
x = x + self.attn(lnx, x_positions, lnx, x_positions, causal=causal, mask=mask)
|
| 223 |
+
if self.cross_attn:
|
| 224 |
+
lnx = self.cross_attn_ln(x)
|
| 225 |
+
x = x + self.cross_attn(lnx, x_positions, xa, xa_positions)
|
| 226 |
+
x = x + self.mlp(self.mlp_ln(x))
|
| 227 |
+
return x
|
| 228 |
+
|
| 229 |
+
# %% ../nbs/A. Neural modules.ipynb 8
|
| 230 |
+
class BaseDecoder(nn.Module):
|
| 231 |
+
def __init__(self, depth=6, n_head=6, width=384, qk_scale=1, ffn_mult=4, length=2250, rope=False):
|
| 232 |
+
super().__init__()
|
| 233 |
+
self.length = length
|
| 234 |
+
self.width = width
|
| 235 |
+
self.layers = nn.ModuleList([
|
| 236 |
+
ResidualAttentionBlock(
|
| 237 |
+
self.width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, cross_attention=True, rope=rope
|
| 238 |
+
) for _ in range(math.floor(depth))
|
| 239 |
+
])
|
| 240 |
+
|
| 241 |
+
self.ln_post = LayerNorm(width)
|
| 242 |
+
|
| 243 |
+
mask = torch.empty(length, length).fill_(-torch.inf).triu_(1)
|
| 244 |
+
self.register_buffer("mask", mask, persistent=False)
|
| 245 |
+
|
| 246 |
+
def forward(self, x, x_positions, xenc, xenc_positions):
|
| 247 |
+
for i,l in enumerate(self.layers):
|
| 248 |
+
x = l(x, x_positions, xenc, xenc_positions, causal=False, mask=self.mask)
|
| 249 |
+
|
| 250 |
+
x = self.ln_post(x)
|
| 251 |
+
|
| 252 |
+
return x
|
| 253 |
+
|
| 254 |
+
# %% ../nbs/A. Neural modules.ipynb 9
|
| 255 |
+
class EmbeddingProjector(nn.Linear):
|
| 256 |
+
pass
|
| 257 |
+
|
| 258 |
+
class FlexEmbeddings(nn.Module):
|
| 259 |
+
def __init__(self, codes, width, special_codes=None, frozen_width=None, special_embedding=None, unembed=True):
|
| 260 |
+
super().__init__()
|
| 261 |
+
self.codes = codes
|
| 262 |
+
self.special_codes = special_codes
|
| 263 |
+
if frozen_width is None: frozen_width = width
|
| 264 |
+
|
| 265 |
+
self.main = nn.Embedding(codes, frozen_width or width)
|
| 266 |
+
self.emb_to_hidden = EmbeddingProjector(frozen_width, width) if frozen_width != width else None
|
| 267 |
+
self.hidden_to_emb = EmbeddingProjector(width, frozen_width) if unembed and frozen_width != width else None
|
| 268 |
+
if special_codes:
|
| 269 |
+
self.special = special_embedding or nn.Embedding(special_codes, width)
|
| 270 |
+
|
| 271 |
+
self.register_buffer('merged_in', None)
|
| 272 |
+
self.register_buffer('merged_out', None)
|
| 273 |
+
self.register_buffer('bias_out', None)
|
| 274 |
+
|
| 275 |
+
def set_frozen_embeddings(self, values):
|
| 276 |
+
with torch.no_grad():
|
| 277 |
+
self.main.weight[:] = values
|
| 278 |
+
self.main.lr_scale = 0
|
| 279 |
+
|
| 280 |
+
@torch.no_grad()
|
| 281 |
+
def convert_for_eval(self):
|
| 282 |
+
if not self.special_codes: return
|
| 283 |
+
# in
|
| 284 |
+
main_w = self.main.weight
|
| 285 |
+
if self.emb_to_hidden is not None: main_w = self.emb_to_hidden(main_w)
|
| 286 |
+
weight = torch.cat([main_w, self.special.weight], dim=0)
|
| 287 |
+
self.merged_in = nn.Embedding(*weight.shape, _weight=weight)
|
| 288 |
+
|
| 289 |
+
# out
|
| 290 |
+
weight = self.main.weight
|
| 291 |
+
if self.hidden_to_emb: weight = weight @ self.hidden_to_emb.weight
|
| 292 |
+
self.merged_out = torch.cat([weight.T, self.special.weight.T], dim=1).T.contiguous() # T is for F.linear
|
| 293 |
+
if self.hidden_to_emb:
|
| 294 |
+
self.bias_out = torch.cat([
|
| 295 |
+
self.hidden_to_emb.bias @ self.main.weight.T,
|
| 296 |
+
torch.zeros(self.special.weight.shape[0], device=weight.device, dtype=weight.dtype)
|
| 297 |
+
], dim=0)
|
| 298 |
+
else:
|
| 299 |
+
self.bias_out = None
|
| 300 |
+
|
| 301 |
+
def forward(self, toks):
|
| 302 |
+
if not self.training and self.merged_in is not None:
|
| 303 |
+
return self.merged_in(toks)
|
| 304 |
+
|
| 305 |
+
if self.special_codes:
|
| 306 |
+
special_mask = toks >= self.codes
|
| 307 |
+
embs = self.main(torch.where(special_mask, 0, toks))
|
| 308 |
+
else:
|
| 309 |
+
embs = self.main(toks)
|
| 310 |
+
|
| 311 |
+
if self.emb_to_hidden: embs = self.emb_to_hidden(embs)
|
| 312 |
+
|
| 313 |
+
if self.special_codes:
|
| 314 |
+
embs[special_mask] = self.special(toks[special_mask] - self.codes).to(embs.dtype)
|
| 315 |
+
|
| 316 |
+
return embs
|
| 317 |
+
|
| 318 |
+
def unembed(self, embs):
|
| 319 |
+
if not self.training and self.merged_out is not None:
|
| 320 |
+
return F.linear(embs, self.merged_out, self.bias_out) # embs @ self.merged_out + self.bias_out
|
| 321 |
+
|
| 322 |
+
orig_embs = embs
|
| 323 |
+
if self.hidden_to_emb: embs = self.hidden_to_emb(embs)
|
| 324 |
+
|
| 325 |
+
main_logits = (embs @ self.main.weight.to(embs.dtype).T).float()
|
| 326 |
+
|
| 327 |
+
if not self.special_codes:
|
| 328 |
+
return main_logits
|
| 329 |
+
|
| 330 |
+
special_logits = (orig_embs @ self.special.weight.to(orig_embs.dtype).T).float()
|
| 331 |
+
return torch.cat([main_logits, special_logits], dim=-1)
|
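A small, self-contained sketch of the two simplest helpers defined above (positional embeddings and weight init); the attention blocks and decoders need a full model context, so they are not exercised here:

import torch
from whisperspeech.modules import sinusoids, init_transformer

pos = sinusoids(1500, 384)     # positional table of shape (length, channels); channels must be even
print(pos.shape)               # torch.Size([1500, 384])

layer = torch.nn.Linear(384, 384)
layer.apply(init_transformer)  # truncated-normal weights (std 0.02), zero bias, as defined above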
whisperspeech/pipeline.py
ADDED
|
@@ -0,0 +1,93 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/7. Pipeline.ipynb.

# %% auto 0
__all__ = ['Pipeline']

# %% ../nbs/7. Pipeline.ipynb 1
import torch
from whisperspeech.t2s_up_wds_mlang_enclm import TSARTransformer
from whisperspeech.s2a_delar_mup_wds_mlang import SADelARTransformer
from whisperspeech.a2wav import Vocoder
import traceback
from pathlib import Path

# %% ../nbs/7. Pipeline.ipynb 2
class Pipeline:
    default_speaker = torch.tensor(
        [-0.2929, -0.4503,  0.4155, -0.1417,  0.0473, -0.1624, -0.2322,  0.7071,
          0.4800,  0.5496,  0.0410,  0.6236,  0.4729,  0.0587,  0.2194, -0.0466,
         -0.3036,  0.0497,  0.5028, -0.1703,  0.5039, -0.6464,  0.3857, -0.7350,
         -0.1605,  0.4808,  0.5397, -0.4851,  0.1774, -0.8712,  0.5789,  0.1785,
         -0.1417,  0.3039,  0.4232, -0.0186,  0.2685,  0.6153, -0.3103, -0.5706,
         -0.4494,  0.3394, -0.6184, -0.3617,  1.1041, -0.1178, -0.1885,  0.1997,
          0.5571, -0.2906, -0.0477, -0.4048, -0.1062,  1.4779,  0.1639, -0.3712,
         -0.1776, -0.0568, -0.6162,  0.0110, -0.0207, -0.1319, -0.3854,  0.7248,
          0.0343,  0.5724,  0.0670,  0.0486, -0.3813,  0.1738,  0.3017,  1.0502,
          0.1550,  0.5708,  0.0366,  0.5093,  0.0294, -0.7091, -0.8220, -0.1583,
         -0.2343,  0.1366,  0.7372, -0.0631,  0.1505,  0.4600, -0.1252, -0.5245,
          0.7523, -0.0386, -0.2587,  1.0066, -0.2037,  0.1617, -0.3800,  0.2790,
          0.0184, -0.5111, -0.7291,  0.1627,  0.2367, -0.0192,  0.4822, -0.4458,
          0.1457, -0.5884,  0.1909,  0.2563, -0.2035, -0.0377,  0.7771,  0.2139,
          0.3801,  0.6047, -0.6043, -0.2563, -0.0726,  0.3856,  0.3217,  0.0823,
         -0.1302,  0.3287,  0.5693,  0.2453,  0.8231,  0.0072,  1.0327,  0.6065,
         -0.0620, -0.5572,  0.5220,  0.2485,  0.1520,  0.0222, -0.2179, -0.7392,
         -0.3855,  0.1822,  0.1042,  0.7133,  0.3583,  0.0606, -0.0424, -0.9189,
         -0.4882, -0.5480, -0.5719, -0.1660, -0.3439, -0.5814, -0.2542,  0.0197,
          0.4942,  0.0915, -0.0420, -0.0035,  0.5578,  0.1051, -0.0891,  0.2348,
          0.6876, -0.6685,  0.8215, -0.3692, -0.3150, -0.0462, -0.6806, -0.2661,
         -0.0308, -0.0050,  0.6756, -0.1647,  1.0734,  0.0049,  0.4969,  0.0259,
         -0.8949,  0.0731,  0.0886,  0.3442, -0.1433, -0.6804,  0.2204,  0.1859,
          0.2702,  0.1699, -0.1443, -0.9614,  0.3261,  0.1718,  0.3545, -0.0686]
    )

    def __init__(self, t2s_ref=None, s2a_ref=None, optimize=True, torch_compile=False):
        args = dict()
        try:
            if t2s_ref:
                args["ref"] = t2s_ref
            self.t2s = TSARTransformer.load_model(**args).cuda()
            if optimize: self.t2s.optimize(torch_compile=torch_compile)
        except:
            print("Failed to load the T2S model:")
            print(traceback.format_exc())
        try:
            if s2a_ref:
                args["ref"] = s2a_ref
            self.s2a = SADelARTransformer.load_model(**args).cuda()
            if optimize: self.s2a.optimize(torch_compile=torch_compile)
        except:
            print("Failed to load the S2A model:")
            print(traceback.format_exc())
        self.vocoder = Vocoder()
        self.encoder = None

    def extract_spk_emb(self, fname):
        """Extracts a speaker embedding from the first 30 seconds of the given audio file.
        """
        import torchaudio
        if self.encoder is None:
            from speechbrain.pretrained import EncoderClassifier
            self.encoder = EncoderClassifier.from_hparams("speechbrain/spkrec-ecapa-voxceleb",
                                                          savedir="~/.cache/speechbrain/",
                                                          run_opts={"device": "cuda"})
        samples, sr = torchaudio.load(fname)
        samples = self.encoder.audio_normalizer(samples[0,:30*sr], sr)
        spk_emb = self.encoder.encode_batch(samples)
        return spk_emb[0,0]

    def generate_atoks(self, text, speaker=None, lang='en', cps=15, step_callback=None):
        if speaker is None: speaker = self.default_speaker
        elif isinstance(speaker, (str, Path)): speaker = self.extract_spk_emb(speaker)
        text = text.replace("\n", " ")
        stoks = self.t2s.generate(text, cps=cps, lang=lang, step=step_callback)
        atoks = self.s2a.generate(stoks, speaker.unsqueeze(0), step=step_callback)
        return atoks

    def generate(self, text, speaker=None, lang='en', cps=15, step_callback=None):
        return self.vocoder.decode(self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=step_callback))

    def generate_to_file(self, fname, text, speaker=None, lang='en', cps=15, step_callback=None):
        # pass the callback through (it was previously dropped here)
        self.vocoder.decode_to_file(fname, self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=step_callback))

    def generate_to_notebook(self, text, speaker=None, lang='en', cps=15, step_callback=None):
        # pass the callback through (it was previously dropped here)
        self.vocoder.decode_to_notebook(self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=step_callback))
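A minimal usage sketch for the Pipeline class above, assuming a CUDA GPU and that the default T2S/S2A checkpoints download successfully; the wav file names are hypothetical:

from whisperspeech.pipeline import Pipeline

pipe = Pipeline()  # loads the default T2S and S2A models onto the GPU
pipe.generate_to_file('hello.wav', "Hello from WhisperSpeech!", lang='en', cps=15)

# voice cloning from a (hypothetical) reference recording, via extract_spk_emb:
pipe.generate_to_file('cloned.wav', "Hello again!", speaker='reference.wav')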
whisperspeech/prepare_s2a_dataset.py
ADDED
|
@@ -0,0 +1,112 @@
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4A. S2A dataset preparation.ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = ['flac_to_s2a_name']
|
| 5 |
+
|
| 6 |
+
# %% ../nbs/4A. S2A dataset preparation.ipynb 2
|
| 7 |
+
import sys
|
| 8 |
+
import os
|
| 9 |
+
import itertools
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torchaudio
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from torch.profiler import profile, record_function, ProfilerActivity
|
| 17 |
+
|
| 18 |
+
from fastprogress import progress_bar
|
| 19 |
+
from fastcore.script import *
|
| 20 |
+
|
| 21 |
+
import whisper
|
| 22 |
+
from . import vad, wh_transcribe, vq_stoks, extract_acoustic
|
| 23 |
+
import webdataset as wds
|
| 24 |
+
|
| 25 |
+
# %% ../nbs/4A. S2A dataset preparation.ipynb 4
|
| 26 |
+
def flac_to_s2a_name(input):
|
| 27 |
+
if '-flac-' in input:
|
| 28 |
+
return input.rsplit("/", 1)[1].replace('flac', 's2a') + ".gz"
|
| 29 |
+
else:
|
| 30 |
+
return input.rsplit("/", 1)[1].replace('raw', 's2a') + ".gz"
|
| 31 |
+
|
| 32 |
+
# %% ../nbs/4A. S2A dataset preparation.ipynb 6
|
| 33 |
+
def resampler(newsr = 24000, key = 'samples_24k'):
|
| 34 |
+
_last_sr = None
|
| 35 |
+
tform = None
|
| 36 |
+
|
| 37 |
+
def _resample(samples):
|
| 38 |
+
for s in samples:
|
| 39 |
+
sr = s['sample_rate']
|
| 40 |
+
if sr != newsr:
|
| 41 |
+
if sr != _last_sr: tform = torchaudio.transforms.Resample(sr, newsr)
|
| 42 |
+
s[key] = tform(s['samples'])
|
| 43 |
+
else:
|
| 44 |
+
s[key] = s['samples']
|
| 45 |
+
yield s
|
| 46 |
+
|
| 47 |
+
return _resample
|
| 48 |
+
|
| 49 |
+
# %% ../nbs/4A. S2A dataset preparation.ipynb 9
|
| 50 |
+
@call_parse
|
| 51 |
+
def prepare_s2a(
|
| 52 |
+
input:str, # FLAC webdataset file path (or - to read the names from stdin)
|
| 53 |
+
proc_dataset_path:Path, # processed VAD files path
|
| 54 |
+
output:str=None, # output file name
|
| 55 |
+
vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # the model path (use repo_id:filename to download it from hugginface)
|
| 56 |
+
n_samples:int=None, # process a limited amount of samples
|
| 57 |
+
batch_size:int=1, # process several segments at once
|
| 58 |
+
fix_dots:bool=False, # fix dots in file names
|
| 59 |
+
):
|
| 60 |
+
if ":" in vq_model:
|
| 61 |
+
repo, fname = vq_model.split(":", 1)
|
| 62 |
+
vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
|
| 63 |
+
else:
|
| 64 |
+
vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
|
| 65 |
+
amodel = extract_acoustic.load_model()
|
| 66 |
+
amodel.set_target_bandwidth(3)
|
| 67 |
+
|
| 68 |
+
if input == "-":
|
| 69 |
+
input = [f.strip() for f in sys.stdin.readlines()]
|
| 70 |
+
assert output, "please provide the output shard name"
|
| 71 |
+
else:
|
| 72 |
+
if output is None: output = flac_to_s2a_name(input)
|
| 73 |
+
input = [input]
|
| 74 |
+
|
| 75 |
+
total = n_samples//batch_size if n_samples else 'noinfer'
|
| 76 |
+
|
| 77 |
+
ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names if fix_dots else None).compose(
|
| 78 |
+
wds.decode(wds.torch_audio),
|
| 79 |
+
wds.select(lambda x: 'wav' in x or 'flac' in x),
|
| 80 |
+
vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
|
| 81 |
+
wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
|
| 82 |
+
lambda x: wh_transcribe.split_to_chunks(x),
|
| 83 |
+
resampler(),
|
| 84 |
+
resampler(16000, 'samples_16k'),
|
| 85 |
+
wds.to_tuple('__key__', 'rpad_s', 'samples_16k', 'samples_24k'),
|
| 86 |
+
wds.batched(64),
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)
|
| 90 |
+
|
| 91 |
+
speakers = set()
|
| 92 |
+
tmp = output+".tmp"
|
| 93 |
+
with wds.TarWriter(tmp) as sink:
|
| 94 |
+
for keys, rpad_ss, samples, samples24k in progress_bar(dl, total=total):
|
| 95 |
+
with record_function('to_cuda'):
|
| 96 |
+
samples, samples24k = samples.cuda(), samples24k.unsqueeze(1).cuda()
|
| 97 |
+
with record_function('encodec'):
|
| 98 |
+
atoks = amodel.encode(samples24k)[0][0]
|
| 99 |
+
with record_function('vq_stoks'):
|
| 100 |
+
stoks = vq_model.encode_audio(samples)
|
| 101 |
+
with record_function('from_cuda'):
|
| 102 |
+
atoks, stoks = atoks.cpu().numpy().astype(np.int16), stoks.cpu().numpy().astype(np.int16)
|
| 103 |
+
for key, rpad_s, _atoks, _stoks in zip(keys, rpad_ss, atoks, stoks):
|
| 104 |
+
speakers.add(key.split('/')[1])
|
| 105 |
+
sink.write({
|
| 106 |
+
"__key__": key,
|
| 107 |
+
"atoks.npy": _atoks[:,:int(-rpad_s * 75)],
|
| 108 |
+
"stoks.npy": _stoks[:int(-rpad_s * 25)],
|
| 109 |
+
})
|
| 110 |
+
with open(output+".speakers.txt", "w") as f: f.write("\n".join(speakers))
|
| 111 |
+
if not n_samples:
|
| 112 |
+
os.rename(tmp, output)
|
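The shard-naming helper above can be checked in isolation (the shard name below is hypothetical; importing the module assumes the data-prep dependencies such as webdataset and whisper are installed):

from whisperspeech.prepare_s2a_dataset import flac_to_s2a_name

flac_to_s2a_name('librilight/small-flac-000000.tar')  # -> 'small-s2a-000000.tar.gz'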
whisperspeech/prepare_t2s_dataset.py
ADDED
|
@@ -0,0 +1,111 @@
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/5A. T2S dataset preparation.ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = []
|
| 5 |
+
|
| 6 |
+
# %% ../nbs/5A. T2S dataset preparation.ipynb 2
|
| 7 |
+
import sys
|
| 8 |
+
import os
|
| 9 |
+
import itertools
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torchaudio
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from torch.profiler import profile, record_function, ProfilerActivity
|
| 17 |
+
|
| 18 |
+
from fastprogress import progress_bar
|
| 19 |
+
from fastcore.script import *
|
| 20 |
+
|
| 21 |
+
import whisper, whisperx
|
| 22 |
+
from . import vad, wh_transcribe, vq_stoks, extract_acoustic
|
| 23 |
+
import webdataset as wds
|
| 24 |
+
|
| 25 |
+
# %% ../nbs/5A. T2S dataset preparation.ipynb 4
|
| 26 |
+
def flac_to_t2s_name(input):
|
| 27 |
+
return input.rsplit("/", 1)[1].replace('flac', 't2s') + ".gz"
|
| 28 |
+
|
| 29 |
+
# %% ../nbs/5A. T2S dataset preparation.ipynb 6
|
| 30 |
+
class Transcriber:
|
| 31 |
+
"""
|
| 32 |
+
A helper class to transcribe a batch of 30 second audio chunks.
|
| 33 |
+
"""
|
| 34 |
+
def __init__(self, model_size, lang=False):
|
| 35 |
+
self.model = whisperx.asr.load_model(model_size, "cuda", compute_type="float16", language=lang)
|
| 36 |
+
# without calling vad_model at least once the rest segfaults for some reason...
|
| 37 |
+
self.model.vad_model({"waveform": torch.zeros(1, 16000), "sample_rate": 16000})
|
| 38 |
+
|
| 39 |
+
def transcribe(self, batch):
|
| 40 |
+
batch = whisper.log_mel_spectrogram(batch)
|
| 41 |
+
embs = self.model.model.encode(batch.cpu().numpy())
|
| 42 |
+
return self.model.tokenizer.tokenizer.decode_batch([x.sequences_ids[0] for x in
|
| 43 |
+
self.model.model.model.generate(
|
| 44 |
+
embs,
|
| 45 |
+
[self.model.model.get_prompt(self.model.tokenizer, [], without_timestamps=True)]*len(batch),
|
| 46 |
+
)])
|
| 47 |
+
|
| 48 |
+
# %% ../nbs/5A. T2S dataset preparation.ipynb 7
|
| 49 |
+
@call_parse
|
| 50 |
+
def prepare_t2s(
|
| 51 |
+
input:str, # FLAC webdataset file path (or - to read the names from stdin)
|
| 52 |
+
proc_dataset_path:Path, # processed VAD files path
|
| 53 |
+
output:str=None, # output file name
|
| 54 |
+
vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # the model path (use repo_id:filename to download it from hugginface)
|
| 55 |
+
n_samples:int=None, # process a limited amount of samples
|
| 56 |
+
batch_size:int=1, # process several segments at once
|
| 57 |
+
transcription_model:str="small.en",
|
| 58 |
+
):
|
| 59 |
+
if ":" in vq_model:
|
| 60 |
+
repo, fname = vq_model.split(":", 1)
|
| 61 |
+
vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
|
| 62 |
+
else:
|
| 63 |
+
vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
|
| 64 |
+
transcriber = Transcriber(transcription_model)
|
| 65 |
+
|
| 66 |
+
if input == "-":
|
| 67 |
+
input = [f.strip() for f in sys.stdin.readlines()]
|
| 68 |
+
assert output, "please provide the output shard name"
|
| 69 |
+
else:
|
| 70 |
+
if output is None: output = flac_to_t2s_name(input)
|
| 71 |
+
input = [input]
|
| 72 |
+
|
| 73 |
+
total = n_samples//batch_size if n_samples else 'noinfer'
|
| 74 |
+
if n_samples: print(f"Benchmarking run of {n_samples} samples ({total} batches)")
|
| 75 |
+
|
| 76 |
+
ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names).compose(
|
| 77 |
+
wds.decode(wds.torch_audio),
|
| 78 |
+
vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
|
| 79 |
+
wds.map_dict(**{"vad.npy": lambda s: wh_transcribe.chunk_merger(s, wh_transcribe.random_cutter)}),
|
| 80 |
+
lambda x: wh_transcribe.split_to_chunks(x),
|
| 81 |
+
# drop the first and last segment because they tend to be inaccurate
|
| 82 |
+
# (the transcriptions don't have the "LibriVox" header and "end of chapter" suffix)
|
| 83 |
+
wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
|
| 84 |
+
wds.to_tuple('__key__', 'rpad', 'samples'),
|
| 85 |
+
wds.batched(64),
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)
|
| 89 |
+
|
| 90 |
+
speakers = set()
|
| 91 |
+
tmp = output+".tmp"
|
| 92 |
+
with wds.TarWriter(tmp) as sink:
|
| 93 |
+
for keys, rpads, samples in progress_bar(dl, total=total):
|
| 94 |
+
with record_function('to_cuda'):
|
| 95 |
+
csamples = samples.cuda()
|
| 96 |
+
with record_function('transcribe'):
|
| 97 |
+
txts = transcriber.transcribe(csamples)
|
| 98 |
+
with record_function('vq_stoks'):
|
| 99 |
+
stoks = vq_model.encode_audio(csamples)
|
| 100 |
+
with record_function('from_cuda'):
|
| 101 |
+
stoks = stoks.cpu().numpy().astype(np.int16)
|
| 102 |
+
for key, rpad, txt, _stoks in zip(keys, rpads, txts, stoks):
|
| 103 |
+
speakers.add(key.split('/')[1])
|
| 104 |
+
sink.write({
|
| 105 |
+
"__key__": key,
|
| 106 |
+
"txt": txt,
|
| 107 |
+
"stoks.npy": _stoks[:int(-rpad/16000 * 25)],
|
| 108 |
+
})
|
| 109 |
+
with open(output+".speakers.txt", "w") as f: f.write("\n".join(speakers))
|
| 110 |
+
if not n_samples:
|
| 111 |
+
os.rename(tmp, output)
|
whisperspeech/s2a_delar_mup_wds.py
ADDED
|
@@ -0,0 +1,688 @@
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4B. Semantic to acoustic token modeling.ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = ['load_datasets', 'CMLMVisual', 'Rotary', 'rotate_half', 'apply_rotary_pos_emb', 'ResidualAttentionBlock',
|
| 5 |
+
'MultiHeadAttention', 'DelSumDecoder', 'EmbeddingProjector', 'rand', 'Tunables', 'SADelARTransformer']
|
| 6 |
+
|
| 7 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 1
|
| 8 |
+
import io
|
| 9 |
+
import time
|
| 10 |
+
import math
|
| 11 |
+
import random
|
| 12 |
+
import dataclasses
|
| 13 |
+
|
| 14 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 2
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from torch.profiler import profile, record_function, ProfilerActivity, schedule
|
| 19 |
+
from fastcore.basics import store_attr
|
| 20 |
+
from huggingface_hub import hf_hub_download
|
| 21 |
+
|
| 22 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 3
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
import json
|
| 25 |
+
from fastprogress import progress_bar, master_bar
|
| 26 |
+
import webdataset as wds
|
| 27 |
+
|
| 28 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 4
|
| 29 |
+
from .train import *
|
| 30 |
+
from .modules import *
|
| 31 |
+
from . import vq_stoks
|
| 32 |
+
|
| 33 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 8
|
| 34 |
+
def rand(start, end):
|
| 35 |
+
return random.random() * (end - start) + start
|
| 36 |
+
|
| 37 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 9
|
| 38 |
+
def random_trunc(random_trunc_p, atoks_len = 2250, stoks_len = 750):
|
| 39 |
+
atoks_per_second = atoks_len / 30
|
| 40 |
+
def _trunc(samples):
|
| 41 |
+
for s in samples:
|
| 42 |
+
if random.random() < random_trunc_p:
|
| 43 |
+
seconds = rand(0.3, 30)
|
| 44 |
+
s['atoks.npy'] = s['atoks.npy'][:,:math.ceil(seconds * atoks_per_second)]
|
| 45 |
+
s['stoks.npy'] = s['stoks.npy'][:math.ceil(s['atoks.npy'].shape[-1]/atoks_len*stoks_len)]
|
| 46 |
+
yield s
|
| 47 |
+
return _trunc
|
| 48 |
+
|
| 49 |
+
def pad_samples(atoks_len = 2250, stoks_len = 750, stoks_pad_token = 4096):
|
| 50 |
+
def _pad(samples):
|
| 51 |
+
for s in samples:
|
| 52 |
+
s['stoks.npy'] = F.pad(torch.tensor(s['stoks.npy']), (0, stoks_len - s['stoks.npy'].shape[-1]), value=stoks_pad_token)
|
| 53 |
+
s['atoks.npy'] = F.pad(torch.tensor(s['atoks.npy']), (0, atoks_len - s['atoks.npy'].shape[-1]), value=-100)
|
| 54 |
+
yield s
|
| 55 |
+
return _pad
|
| 56 |
+
|
| 57 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 10
|
| 58 |
+
def speaker_id_extractor(speaker_map):
|
| 59 |
+
def _extractor(samples):
|
| 60 |
+
for s in samples:
|
| 61 |
+
s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]])
|
| 62 |
+
yield s
|
| 63 |
+
return _extractor
|
| 64 |
+
|
| 65 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 14
|
| 66 |
+
def load_datasets(
|
| 67 |
+
input:str, # webdataset folder
|
| 68 |
+
samples:int, # samples per epoch
|
| 69 |
+
subsample:float=1, # use a fraction of the files
|
| 70 |
+
val_samples:int=512,
|
| 71 |
+
random_trunc_p:float=0,# probability of truncating the input to less than 30 seconds
|
| 72 |
+
stoks_pad_token=4096,
|
| 73 |
+
):
|
| 74 |
+
|
| 75 |
+
if isinstance(input, (Path, str)):
|
| 76 |
+
path = Path(input)
|
| 77 |
+
if path.is_dir():
|
| 78 |
+
glob = '*-s2a-*.tar.gz'
|
| 79 |
+
else:
|
| 80 |
+
glob = path.name
|
| 81 |
+
path = path.parent
|
| 82 |
+
input = Path(path).glob(glob)
|
| 83 |
+
elif isinstance(input, list):
|
| 84 |
+
pass
|
| 85 |
+
else:
|
| 86 |
+
raise ArgumentError("input should be either a list or a path with an optional glob specifier")
|
| 87 |
+
shards = [str(x) for x in input]
|
| 88 |
+
|
| 89 |
+
speakers = set()
|
| 90 |
+
for shard in shards:
|
| 91 |
+
with open(shard+'.speakers.txt') as f: speakers = speakers.union(set(x.strip() for x in f.readlines()))
|
| 92 |
+
speakers = {id:i for i,id in enumerate(sorted(speakers))}
|
| 93 |
+
|
| 94 |
+
def ds(shards, length):
|
| 95 |
+
ds = wds.WebDataset(wds.ResampledShards(shards)).compose(
|
| 96 |
+
wds.decode(),
|
| 97 |
+
speaker_id_extractor(speakers),
|
| 98 |
+
random_trunc(random_trunc_p) if random_trunc_p > 0 else lambda x: x,
|
| 99 |
+
pad_samples(stoks_pad_token=stoks_pad_token),
|
| 100 |
+
wds.to_tuple('stoks.npy', 'atoks.npy', 'speaker'),
|
| 101 |
+
wds.batched(64),
|
| 102 |
+
)
|
| 103 |
+
ds.speakers = speakers
|
| 104 |
+
ds.total_samples = length
|
| 105 |
+
return ds.compose(wds.slice(length // 64)).with_epoch(length // 64).with_length(length // 64)
|
| 106 |
+
|
| 107 |
+
return (
|
| 108 |
+
ds(shards[1:], samples),
|
| 109 |
+
ds(shards[:1], val_samples),
|
| 110 |
+
)
|
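A hypothetical call to the dataset loader defined above, assuming a local directory of *-s2a-*.tar.gz shards (and their .speakers.txt files) produced by prepare_s2a, and that the training dependencies import cleanly:

from whisperspeech.s2a_delar_mup_wds import load_datasets

train_ds, val_ds = load_datasets('s2a-dataset/', samples=100_000, random_trunc_p=0.1)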
| 111 |
+
|
| 112 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 33
|
| 113 |
+
import pylab as plt
|
| 114 |
+
import fastprogress
|
| 115 |
+
import IPython
|
| 116 |
+
import numpy as np
|
| 117 |
+
|
| 118 |
+
class CMLMVisual:
|
| 119 |
+
"""Visualize training progress"""
|
| 120 |
+
def __init__ (self, model, masterbar, total_steps):
|
| 121 |
+
self.model = model
|
| 122 |
+
self.masterbar = masterbar
|
| 123 |
+
self.total_steps = total_steps
|
| 124 |
+
self.epochs = total_steps // masterbar.main_bar.total
|
| 125 |
+
|
| 126 |
+
gs = plt.GridSpec(3, 1, height_ratios=[2,2,1])
|
| 127 |
+
graph_fig = plt.figure(figsize=(10,6))
|
| 128 |
+
self.graph_fig = graph_fig
|
| 129 |
+
self.loss_p = graph_fig.add_subplot(gs[0])
|
| 130 |
+
self.acc_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
|
| 131 |
+
self.acc_p.tick_params('x', labelbottom=False)
|
| 132 |
+
self.lr_p = graph_fig.add_subplot(gs[2], sharex=self.loss_p)
|
| 133 |
+
self.lr_p.tick_params('x', labelbottom=False)
|
| 134 |
+
self.graph_out = None
|
| 135 |
+
|
| 136 |
+
self.its = []
|
| 137 |
+
self.train_losses = []
|
| 138 |
+
self.val_losses = []
|
| 139 |
+
self.lr_history = []
|
| 140 |
+
self.acc = np.nan
|
| 141 |
+
self.acc_history = []
|
| 142 |
+
self.pacc_history = []
|
| 143 |
+
|
| 144 |
+
def show(self):
|
| 145 |
+
self.start_t = time.time()
|
| 146 |
+
self.masterbar.write(["samples", "train", "val", "time"], table=True)
|
| 147 |
+
self.graph_out = display(self.graph_fig, display_id=True)
|
| 148 |
+
self.acc_out = display(IPython.display.HTML(''), display_id=True)
|
| 149 |
+
|
| 150 |
+
def hide(self):
|
| 151 |
+
if self.graph_out is not None:
|
| 152 |
+
self.graph_out.update(IPython.display.HTML(''))
|
| 153 |
+
|
| 154 |
+
def plot(self):
|
| 155 |
+
loss_p, acc_p, lr_p = self.loss_p, self.acc_p, self.lr_p
|
| 156 |
+
loss_p.clear()
|
| 157 |
+
loss_p.plot(self.its, self.train_losses)
|
| 158 |
+
loss_p.plot(self.its, self.val_losses)
|
| 159 |
+
loss_p.set_xlim(0, self.total_steps)
|
| 160 |
+
loss_p.set_yscale('log')
|
| 161 |
+
acc_p.clear()
|
| 162 |
+
for k in self.acc_history[-1].keys():
|
| 163 |
+
acc_p.plot(self.its, [x[k] for x in self.acc_history], ':')
|
| 164 |
+
# acc_p.plot(self.its, np.stack(self.pacc_history), label=range(len(self.pacc_history[0])))
|
| 165 |
+
lr_p.clear()
|
| 166 |
+
lrs = np.array(self.lr_history)
|
| 167 |
+
lr_p.plot(self.its, lrs)
|
| 168 |
+
self.graph_out.update(self.graph_fig)
|
| 169 |
+
|
| 170 |
+
def add_data(self, it, lr, train_loss, val_los):
|
| 171 |
+
self.its.append(it)
|
| 172 |
+
self.train_losses.append(train_loss)
|
| 173 |
+
self.val_losses.append(val_los)
|
| 174 |
+
self.lr_history.append(lr)
|
| 175 |
+
metrics = self.model.get_metrics()
|
| 176 |
+
self.acc_history.append(metrics)
|
| 177 |
+
# self.acc_out.update(f"Accuracy: {self.entropy_history[-1]:.2f}")
|
| 178 |
+
# self.pacc_history.append((self.model.pval_true / self.model.pval_total).cpu().numpy())
|
| 179 |
+
# if self.acc_history:
|
| 180 |
+
html = "<h5>Accuracies:</h5><table>"
|
| 181 |
+
html += "<thead>"+(''.join([f"<td>{k}<td>" for k,x in metrics.items()]))+"</thead>"
|
| 182 |
+
html += "<tr>"+(''.join([f"<td>{x*100:.1f}%<td>" for k,x in metrics.items()]))+"</tr>"
|
| 183 |
+
html += "</table>"
|
| 184 |
+
self.acc_out.update(IPython.display.HTML(html))
|
| 185 |
+
self.plot()
|
| 186 |
+
|
| 187 |
+
def add_table_row(self, it, avg_train_loss, val_loss):
|
| 188 |
+
elapsed_t = time.time() - self.start_t
|
| 189 |
+
self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)
|
| 190 |
+
|
| 191 |
+
def on_iter(self, bar, it, avg_train_loss, val_loss):
|
| 192 |
+
epoch = math.ceil(it / self.total_steps * self.epochs)
|
| 193 |
+
bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"
|
| 194 |
+
|
| 195 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 34
|
| 196 |
+
# modified from https://blog.eleuther.ai/rotary-embeddings/
|
| 197 |
+
import torch
|
| 198 |
+
|
| 199 |
+
class Rotary(torch.nn.Module):
|
| 200 |
+
def __init__(self, dim, base=10000):
|
| 201 |
+
super().__init__()
|
| 202 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 203 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 204 |
+
self.seq_len_cached = None
|
| 205 |
+
self.cos_cached = None
|
| 206 |
+
self.sin_cached = None
|
| 207 |
+
|
| 208 |
+
def forward(self, x, seq_dim=1):
|
| 209 |
+
seq_len = x.shape[seq_dim]
|
| 210 |
+
if seq_len != self.seq_len_cached:
|
| 211 |
+
self.seq_len_cached = seq_len
|
| 212 |
+
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
|
| 213 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 214 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
| 215 |
+
self.cos_cached = emb.cos()[None, :, None, :]
|
| 216 |
+
self.sin_cached = emb.sin()[None, :, None, :]
|
| 217 |
+
return self.cos_cached, self.sin_cached
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# rotary pos emb helpers:
|
| 221 |
+
def rotate_half(x):
|
| 222 |
+
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
| 223 |
+
return torch.cat(
|
| 224 |
+
(-x2, x1), dim=-1
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
#@torch.jit.script
|
| 228 |
+
def apply_rotary_pos_emb(q, k, cos, sin):
|
| 229 |
+
return (q * cos[:,:q.shape[1]]) + (rotate_half(q) * sin[:,:q.shape[1]]), (k * cos) + (rotate_half(k) * sin)
|
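A minimal shape check for the rotary-embedding helpers above; the sizes are arbitrary, and q/k use the pre-permute (batch, seq, heads, head_dim) layout that apply_rotary_pos_emb expects:

import torch
from whisperspeech.s2a_delar_mup_wds import Rotary, apply_rotary_pos_emb

q = torch.randn(1, 10, 6, 64)
k = torch.randn(1, 10, 6, 64)
cos, sin = Rotary(64)(k)                       # cached tables of shape (1, 10, 1, 64)
q2, k2 = apply_rotary_pos_emb(q, k, cos, sin)  # rotated q and k
assert q2.shape == q.shape and k2.shape == k.shape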
| 230 |
+
|
| 231 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 35
|
| 232 |
+
from torch import Tensor, nn
|
| 233 |
+
import torch.nn.functional as F
|
| 234 |
+
from typing import Dict, Iterable, Optional
|
| 235 |
+
|
| 236 |
+
class ResidualAttentionBlock(nn.Module):
|
| 237 |
+
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False,
|
| 238 |
+
qk_scale: float = 1, ffn_mult: int = 4):
|
| 239 |
+
super().__init__()
|
| 240 |
+
|
| 241 |
+
self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
|
| 242 |
+
self.attn_ln = LayerNorm(n_state)
|
| 243 |
+
|
| 244 |
+
self.cross_attn = (
|
| 245 |
+
MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope) if cross_attention else None
|
| 246 |
+
)
|
| 247 |
+
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
| 248 |
+
|
| 249 |
+
n_mlp = n_state * ffn_mult
|
| 250 |
+
self.mlp = nn.Sequential(
|
| 251 |
+
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
|
| 252 |
+
)
|
| 253 |
+
self.mlp_ln = LayerNorm(n_state)
|
| 254 |
+
|
| 255 |
+
def forward(
|
| 256 |
+
self,
|
| 257 |
+
x: Tensor,
|
| 258 |
+
xa: Optional[Tensor] = None,
|
| 259 |
+
causal = False,
|
| 260 |
+
kv_cache: Optional[dict] = None,
|
| 261 |
+
):
|
| 262 |
+
x = x + self.attn(self.attn_ln(x), causal=causal, kv_cache=kv_cache)[0]
|
| 263 |
+
if self.cross_attn:
|
| 264 |
+
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
| 265 |
+
x = x + self.mlp(self.mlp_ln(x))
|
| 266 |
+
return x
|
| 267 |
+
|
| 268 |
+
class MultiHeadAttention(nn.Module):
|
| 269 |
+
def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False):
|
| 270 |
+
super().__init__()
|
| 271 |
+
self.n_head = n_head
|
| 272 |
+
self.sqrt_qk_scale = math.sqrt(qk_scale)
|
| 273 |
+
self.query = QueryHead(n_state, n_state)
|
| 274 |
+
self.key = nn.Linear(n_state, n_state, bias=False)
|
| 275 |
+
self.value = nn.Linear(n_state, n_state)
|
| 276 |
+
self.out = nn.Linear(n_state, n_state)
|
| 277 |
+
|
| 278 |
+
self.rotary = None
|
| 279 |
+
if rope:
|
| 280 |
+
self.rotary = Rotary(n_state // n_head)
|
| 281 |
+
|
| 282 |
+
def forward(
|
| 283 |
+
self,
|
| 284 |
+
x: Tensor,
|
| 285 |
+
xa: Optional[Tensor] = None,
|
| 286 |
+
causal = False,
|
| 287 |
+
kv_cache: Optional[dict] = None,
|
| 288 |
+
):
|
| 289 |
+
q = self.query(x)
|
| 290 |
+
|
| 291 |
+
if kv_cache is None or xa is None or self.key not in kv_cache:
|
| 292 |
+
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
| 293 |
+
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
| 294 |
+
k = self.key(x if xa is None else xa)
|
| 295 |
+
v = self.value(x if xa is None else xa)
|
| 296 |
+
else:
|
| 297 |
+
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
| 298 |
+
k = kv_cache[self.key]
|
| 299 |
+
v = kv_cache[self.value]
|
| 300 |
+
|
| 301 |
+
if self.sqrt_qk_scale != 1:
|
| 302 |
+
q *= self.sqrt_qk_scale
|
| 303 |
+
k *= self.sqrt_qk_scale
|
| 304 |
+
|
| 305 |
+
wv, qk = self.qkv_attention_pth20(q, k, v, causal)
|
| 306 |
+
# wv, qk = self.qkv_attention_xformers(q, k, v, causal)
|
| 307 |
+
|
| 308 |
+
return self.out(wv), qk
|
| 309 |
+
|
| 310 |
+
def qkv_attention_pth20(
|
| 311 |
+
self, q: Tensor, k: Tensor, v: Tensor, causal = False
|
| 312 |
+
):
|
| 313 |
+
n_batch, n_ctx, n_state = q.shape
|
| 314 |
+
q = q.view(*q.shape[:2], self.n_head, -1)
|
| 315 |
+
k = k.view(*k.shape[:2], self.n_head, -1)
|
| 316 |
+
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
| 317 |
+
|
| 318 |
+
#print('before rot:', q.shape, k.shape)
|
| 319 |
+
if self.rotary:
|
| 320 |
+
q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))
|
| 321 |
+
#print(' after rot:', q.shape, k.shape)
|
| 322 |
+
|
| 323 |
+
k = k.permute(0, 2, 1, 3)
|
| 324 |
+
q = q.permute(0, 2, 1, 3)
|
| 325 |
+
# modified for better performance under PyTorch 2.0
|
| 326 |
+
wv = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=causal)
|
| 327 |
+
|
| 328 |
+
# previously we've returned q@k which we don't have now
|
| 329 |
+
# since it's not actually used anywhere else, let's just keep two return values for compatibility
|
| 330 |
+
return wv.permute(0, 2, 1, 3).flatten(start_dim=2), None
|
| 331 |
+
|
| 332 |
+
def qkv_attention_xformers(
|
| 333 |
+
self, q: Tensor, k: Tensor, v: Tensor, causal = False
|
| 334 |
+
):
|
| 335 |
+
n_batch, n_ctx, n_state = q.shape
|
| 336 |
+
q = q.view(*q.shape[:2], self.n_head, -1)
|
| 337 |
+
k = k.view(*k.shape[:2], self.n_head, -1)
|
| 338 |
+
v = v.view(*v.shape[:2], self.n_head, -1)
|
| 339 |
+
|
| 340 |
+
if self.rotary:
|
| 341 |
+
q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))
|
| 342 |
+
|
| 343 |
+
bias = xops.LowerTriangularMask() if causal else None
|
| 344 |
+
wv = xops.memory_efficient_attention(q,k,v, attn_bias=bias)
|
| 345 |
+
|
| 346 |
+
# previously we've returned q@k which we don't have now
|
| 347 |
+
# since it's not actually used anywhere else, let's just keep two return values for compatibility
|
| 348 |
+
return wv.flatten(start_dim=2), None
|
| 349 |
+
|
| 350 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 36
|
| 351 |
+
class DelSumDecoder(nn.Module):
|
| 352 |
+
def __init__(self, depth=6, n_head=6, head_width=64, qk_scale=1, ffn_mult=4, length=2250, codes=1024, quantizers=8, linear_heads=True, rope=False, pos_embs=None):
|
| 353 |
+
super().__init__()
|
| 354 |
+
self.length = length
|
| 355 |
+
width = n_head * head_width
|
| 356 |
+
self.width = width
|
| 357 |
+
self.codes = codes
|
| 358 |
+
self.quantizers = quantizers
|
| 359 |
+
self.linear_heads = linear_heads
|
| 360 |
+
|
| 361 |
+
self.embeddings = nn.ModuleList([nn.Embedding(codes+1, width) for _ in range(quantizers)])
|
| 362 |
+
if pos_embs is not None:
|
| 363 |
+
self.register_buffer("positional_embedding", pos_embs)
|
| 364 |
+
|
| 365 |
+
self.layers = nn.ModuleList([
|
| 366 |
+
ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, cross_attention=True, rope=rope) for _ in range(math.floor(depth))
|
| 367 |
+
])
|
| 368 |
+
|
| 369 |
+
self.ln_post = LayerNorm(width)
|
| 370 |
+
|
| 371 |
+
if self.linear_heads:
|
| 372 |
+
self.heads = LinearHead(width, (codes+1) * quantizers, bias=False)
|
| 373 |
+
else:
|
| 374 |
+
self.splitter = nn.Sequential(
|
| 375 |
+
nn.Linear(width, width * quantizers),
|
| 376 |
+
nn.GELU(),
|
| 377 |
+
)
|
| 378 |
+
self.heads = nn.ModuleList([
|
| 379 |
+
LinearHead(width, codes+1, bias=True) for _ in range(quantizers)
|
| 380 |
+
])
|
| 381 |
+
|
| 382 |
+
def forward(self, toks, xenc):
|
| 383 |
+
b,_,n = toks.shape
|
| 384 |
+
newn = min(n+1, self.length)
|
| 385 |
+
embs = torch.zeros((b,newn,self.width), dtype=xenc.dtype, device=xenc.device)
|
| 386 |
+
for i in range(self.quantizers):
|
| 387 |
+
embs[:,:i+1] += self.embeddings[i](torch.tensor([self.codes], device=xenc.device))
|
| 388 |
+
if i < n:
|
| 389 |
+
embs[:,i+1:] += self.embeddings[i](toks[:,i,:newn-i-1])
|
| 390 |
+
|
| 391 |
+
x = embs.to(xenc.dtype)
|
| 392 |
+
|
| 393 |
+
for l in self.layers:
|
| 394 |
+
x = l(x, xenc, causal=True)
|
| 395 |
+
x = self.ln_post(x)
|
| 396 |
+
|
| 397 |
+
if self.linear_heads:
|
| 398 |
+
logits = self.heads(x).view(b,newn,self.quantizers,self.codes+1).permute(0,2,1,3)
|
| 399 |
+
else:
|
| 400 |
+
split = self.splitter(x).view(b,newn,self.quantizers,self.width)
|
| 401 |
+
logits = torch.stack([self.heads[q](split[:,:,q]) for q in range(self.quantizers)], dim=1)
|
| 402 |
+
|
| 403 |
+
return logits
|
| 404 |
+
|
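# An equivalent token-level view of the per-quantizer delay that DelSumDecoder applies above
# in embedding space: quantizer i is shifted right by i steps and its first i positions see
# the special `codes` token, so at step t the model predicts quantizer 0 for frame t,
# quantizer 1 for frame t-1, and so on. The token values below are arbitrary examples.
import torch

codes = 1024                                  # the special start/unknown id used above
toks = torch.arange(12).reshape(1, 3, 4)      # (batch, quantizers, frames)
delayed = torch.full((1, 3, 4), codes)
for i in range(3):
    delayed[0, i, i:] = toks[0, i, :4 - i]
# delayed[0] is now [[0, 1, 2, 3], [1024, 4, 5, 6], [1024, 1024, 8, 9]]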
| 405 |
+
class EmbeddingProjector(nn.Linear):
|
| 406 |
+
pass
|
| 407 |
+
|
| 408 |
+
def rand(start, end):
|
| 409 |
+
return random.random() * (end - start) + start
|
| 410 |
+
|
| 411 |
+
@dataclasses.dataclass
|
| 412 |
+
class Tunables:
|
| 413 |
+
init_std :float = 9
|
| 414 |
+
embeddings_std :float = 0.2
|
| 415 |
+
embeddings_lr_scale: float = 10
|
| 416 |
+
output_mult :float = 5.6
|
| 417 |
+
# FIXME: try separate mults for self and cross attention
|
| 418 |
+
query_mult :float = .3
|
| 419 |
+
encoder_depth_ratio :float = 0.25
|
| 420 |
+
linear_heads :bool = False
|
| 421 |
+
rope :bool = True
|
| 422 |
+
|
| 423 |
+
lr0 :float = 3e-3
|
| 424 |
+
clip_gradient_norm :float = 2
|
| 425 |
+
weight_decay :float = 1e-3
|
| 426 |
+
warmup_steps :float = 2000
|
| 427 |
+
|
| 428 |
+
random :bool = False
|
| 429 |
+
|
| 430 |
+
def __post_init__(self):
|
| 431 |
+
# randomize the hyperparams if requested
|
| 432 |
+
if self.random:
|
| 433 |
+
self.init_std = 2*10**rand(0,1)
|
| 434 |
+
self.embeddings_std = 10**rand(-1.7,-0.22)
|
| 435 |
+
self.embeddings_lr_scale = 2**rand(2,4)
|
| 436 |
+
self.output_mult = 2**rand(1.5,3)
|
| 437 |
+
self.query_mult = 2**rand(-3,-1.3)
|
| 438 |
+
self.encoder_depth_ratio = random.choice([0.25,0.5])
|
| 439 |
+
self.linear_heads = False
|
| 440 |
+
self.rope = True
|
| 441 |
+
|
| 442 |
+
self.lr0 = 3e-3
|
| 443 |
+
self.clip_gradient_norm = 10**rand(-1,1)
|
| 444 |
+
self.warmup_steps = 100*(10**rand(1.18,1.3))
|
| 445 |
+
|
| 446 |
+
@staticmethod
|
| 447 |
+
def upgrade(args):
|
| 448 |
+
args = {k:v for k,v in args.items()}
|
| 449 |
+
def old_default(name, value):
|
| 450 |
+
if name not in args: args[name] = value
|
| 451 |
+
old_default('rope', False)
|
| 452 |
+
old_default('linear_heads', True)
|
| 453 |
+
return args
|
| 454 |
+
|
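# Tunables.upgrade exists so that checkpoints saved before `rope`/`linear_heads` were added
# can still be loaded: the old defaults are filled in before constructing Tunables (this is
# what load_model does below). The dict here is a made-up example of such an older
# `spec['tunables']` payload.
old_args = {'init_std': 9, 'embeddings_std': 0.2}
t = Tunables(**Tunables.upgrade(old_args))
assert t.rope is False and t.linear_heads is True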
| 455 |
+
class SADelARTransformer(nn.Module):
|
| 456 |
+
def __init__(self, depth=3, ctx_n=2250, stoks_len=750, stoks_codes=4097, stoks_width=None, spk_width=None, n_head=3, head_width=64, ffn_mult=4,
|
| 457 |
+
quantizers=8, speaker_map={"1":0}, tunables=Tunables()):
|
| 458 |
+
super().__init__()
|
| 459 |
+
self.quantizers = quantizers
|
| 460 |
+
width = n_head * head_width
|
| 461 |
+
store_attr("depth,ctx_n,stoks_len,stoks_codes,stoks_width,spk_width,n_head,head_width,ffn_mult,quantizers,speaker_map")
|
| 462 |
+
self.width = width
|
| 463 |
+
self.base_width = 3 * head_width
|
| 464 |
+
self.tunables = tunables
|
| 465 |
+
|
| 466 |
+
if stoks_width is None: stoks_width = width
|
| 467 |
+
if spk_width is None: spk_width = width
|
| 468 |
+
self.emb_factor = width != stoks_width
|
| 469 |
+
self.spk_factor = width != spk_width
|
| 470 |
+
|
| 471 |
+
if tunables.rope:
|
| 472 |
+
self.positional_embeddings = None
|
| 473 |
+
else:
|
| 474 |
+
self.register_buffer('positional_embeddings', sinusoids(ctx_n, width))
|
| 475 |
+
|
| 476 |
+
self.speaker_embedding = nn.Embedding(len(speaker_map), width)
|
| 477 |
+
self.semantic_embedding = nn.Embedding(stoks_codes, stoks_width)
|
| 478 |
+
if self.emb_factor:
|
| 479 |
+
self.emb_to_hidden = nn.Linear(stoks_width, width)
|
| 480 |
+
|
| 481 |
+
if self.spk_factor:
|
| 482 |
+
self.spk_to_hidden = EmbeddingProjector(spk_width, width)
|
| 483 |
+
|
| 484 |
+
qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)
|
| 485 |
+
|
| 486 |
+
encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
|
| 487 |
+
decoder_depth = depth * 2 - encoder_depth
|
| 488 |
+
self.encoder = nn.Sequential(*[
|
| 489 |
+
ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(encoder_depth)
|
| 490 |
+
])
|
| 491 |
+
self.ln_post = LayerNorm(width)
|
| 492 |
+
|
| 493 |
+
self.decoder = DelSumDecoder(pos_embs=self.positional_embeddings, qk_scale=qk_scale,
|
| 494 |
+
length=ctx_n, n_head=n_head, head_width=head_width, ffn_mult=ffn_mult,
|
| 495 |
+
depth=decoder_depth, quantizers=quantizers,
|
| 496 |
+
linear_heads=tunables.linear_heads, rope=tunables.rope)
|
| 497 |
+
|
| 498 |
+
self.register_buffer('val_true', torch.zeros(self.quantizers).cuda())
|
| 499 |
+
self.register_buffer('val_total', torch.zeros(self.quantizers).cuda())
|
| 500 |
+
self.apply(self.init_transformer)
|
| 501 |
+
|
| 502 |
+
def setup(self, device):
|
| 503 |
+
pass
|
| 504 |
+
|
| 505 |
+
def load_frozen_semantic_embeddings(self, vqmodel):
|
| 506 |
+
with torch.no_grad():
|
| 507 |
+
self.semantic_embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
|
| 508 |
+
self.semantic_embedding.lr_scale = 0
|
| 509 |
+
|
| 510 |
+
def init_transformer(self, m):
|
| 511 |
+
if isinstance(m, LinearHead):
|
| 512 |
+
m.no_weight_decay = True
|
| 513 |
+
torch.nn.init.constant_(m.weight, 0)
|
| 514 |
+
elif isinstance(m, QueryHead):
|
| 515 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
| 516 |
+
torch.nn.init.constant_(m.weight, 0)
|
| 517 |
+
elif isinstance(m, nn.Embedding):
|
| 518 |
+
m.no_weight_decay = True
|
| 519 |
+
m.lr_scale = self.tunables.embeddings_lr_scale
|
| 520 |
+
std = self.tunables.embeddings_std
|
| 521 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 522 |
+
elif isinstance(m, EmbeddingProjector):
|
| 523 |
+
m.lr_scale = self.tunables.embeddings_lr_scale/2
|
| 524 |
+
std = self.tunables.init_std
|
| 525 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 526 |
+
elif isinstance(m, nn.Linear):
|
| 527 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
| 528 |
+
std = self.tunables.init_std / m.weight.shape[1]
|
| 529 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 530 |
+
if m.bias is not None:
|
| 531 |
+
torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
|
| 532 |
+
elif isinstance(m, nn.LayerNorm):
|
| 533 |
+
m.no_weight_decay = True
|
| 534 |
+
torch.nn.init.constant_(m.bias, 0)
|
| 535 |
+
torch.nn.init.constant_(m.weight, 1)
|
| 536 |
+
|
| 537 |
+
def embed_stoks(self, Stoks):
|
| 538 |
+
b,n = Stoks.shape
|
| 539 |
+
if self.stoks_len == 1500:
|
| 540 |
+
# converts 50 toks/s to 75 toks/s by adding padding between every two tokens
|
| 541 |
+
x = Stoks.reshape(b,n//2,2)
|
| 542 |
+
x = x.repeat_interleave(2, -1)[:,:,:3]
|
| 543 |
+
x[:,:,1] = 1024
|
| 544 |
+
x = x.reshape(b,n//2*3)
|
| 545 |
+
else:
|
| 546 |
+
# it's a lot easier with 25 toks/s
|
| 547 |
+
x = Stoks.repeat_interleave(3, -1)
|
| 548 |
+
# embed semantic tokens
|
| 549 |
+
Sembs = self.semantic_embedding(x.to(torch.long))
|
| 550 |
+
if self.emb_factor:
|
| 551 |
+
Sembs = self.emb_to_hidden(Sembs)
|
| 552 |
+
return Sembs
|
| 553 |
+
|
| 554 |
+
def forward(self, Stoks, Atoks, speakers, noloss=False):
|
| 555 |
+
Atoks = Atoks.to(torch.long)
|
| 556 |
+
semb = self.embed_stoks(Stoks)
|
| 557 |
+
with record_function("encoder"):
|
| 558 |
+
if self.positional_embeddings is not None: semb = semb + self.positional_embeddings
|
| 559 |
+
xenc = self.ln_post(self.encoder(semb))
|
| 560 |
+
# xenc = torch.zeros_like(xenc)
|
| 561 |
+
with record_function("decoder"):
|
| 562 |
+
Atoks_gt = Atoks.clone()
|
| 563 |
+
Atoks_gt[Atoks == -100] = 1024
|
| 564 |
+
# we can randomize speaker ids during validation to measure
|
| 565 |
+
# the importance of the speaker embedding vs. just the acoustic prompt/prefix
|
| 566 |
+
# if not self.training: speakers = speakers[torch.randperm(speakers.nelement())]
|
| 567 |
+
spk_embs = self.speaker_embedding(speakers)
|
| 568 |
+
if self.spk_factor: spk_embs = self.spk_to_hidden(spk_embs)
|
| 569 |
+
logits = self.decoder(Atoks_gt, xenc + spk_embs.unsqueeze(1))
|
| 570 |
+
logits *= self.tunables.output_mult / (self.width / self.base_width)
|
| 571 |
+
|
| 572 |
+
if noloss:
|
| 573 |
+
return logits
|
| 574 |
+
|
| 575 |
+
with record_function("loss"):
|
| 576 |
+
N = Atoks.shape[-1]
|
| 577 |
+
loss = 0
|
| 578 |
+
for i in range(self.quantizers):
|
| 579 |
+
loss += F.cross_entropy(logits[:,i,i:].reshape(-1,logits.shape[-1]), Atoks[:,i,:N-i].reshape(-1))
|
| 580 |
+
loss /= self.quantizers
|
| 581 |
+
|
| 582 |
+
if not self.training:
|
| 583 |
+
for i in range(self.quantizers):
|
| 584 |
+
Atoks_i = Atoks[:,i,:N-i]
|
| 585 |
+
valid_Atoks = Atoks_i != -100
|
| 586 |
+
self.val_true[i] += (logits[:,i,i:].argmax(-1)[valid_Atoks] == Atoks_i[valid_Atoks]).float().sum()
|
| 587 |
+
self.val_total[i] += valid_Atoks.float().sum()
|
| 588 |
+
|
| 589 |
+
return logits, loss
|
| 590 |
+
|
| 591 |
+
def get_metrics(self):
|
| 592 |
+
metrics = {
|
| 593 |
+
f'acc_{i}':x.item() for i,x in enumerate(self.val_true / self.val_total)
|
| 594 |
+
}
|
| 595 |
+
self.val_true[:] = 0
|
| 596 |
+
self.val_total[:] = 0
|
| 597 |
+
return metrics
|
| 598 |
+
|
| 599 |
+
#
|
| 600 |
+
# inference
|
| 601 |
+
#
|
| 602 |
+
@classmethod
|
| 603 |
+
def load_model(cls, repo_id="collabora/whisperspeech", filename="s2a_up_wds.model", local_filename=None):
|
| 604 |
+
if not local_filename:
|
| 605 |
+
local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 606 |
+
spec = torch.load(local_filename)
|
| 607 |
+
if '_extra_state' not in spec['state_dict']: spec['state_dict']['_extra_state'] = { 'speaker_map': spec['config']['speaker_map'] }
|
| 608 |
+
model = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec['tunables'])))
|
| 609 |
+
model.load_state_dict(spec['state_dict'])
|
| 610 |
+
model.eval()
|
| 611 |
+
return model
|
| 612 |
+
|
| 613 |
+
def get_extra_state(self):
|
| 614 |
+
return { 'speaker_map': self.speaker_map }
|
| 615 |
+
|
| 616 |
+
def set_extra_state(self, st):
|
| 617 |
+
self.speaker_map = st['speaker_map']
|
| 618 |
+
|
| 619 |
+
def load_checkpoint(self, local_filename):
|
| 620 |
+
spec = torch.load(local_filename, map_location='cpu')
|
| 621 |
+
assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
|
| 622 |
+
state_dict = {k.replace('model.', ''):v
|
| 623 |
+
for k,v in spec['state_dict'].items()}
|
| 624 |
+
self.load_state_dict(state_dict)
|
| 625 |
+
return self
|
| 626 |
+
|
| 627 |
+
def save_model(self, fname):
|
| 628 |
+
torch.save(dict(config = self.__stored_args__,
|
| 629 |
+
tunables = dataclasses.asdict(self.tunables),
|
| 630 |
+
state_dict = self.state_dict()), fname)
|
| 631 |
+
|
| 632 |
+
@property
|
| 633 |
+
def device(self):
|
| 634 |
+
return next(self.parameters()).device
|
| 635 |
+
|
| 636 |
+
@torch.no_grad()
|
| 637 |
+
def generate(self, stoks, speakers, N=None, T=0.7, top_k=None, show_progress_bar=True):
|
| 638 |
+
dev = self.device
|
| 639 |
+
if self.stoks_len == 1500:
|
| 640 |
+
N = N or len(stoks) * 3 // 2
|
| 641 |
+
else:
|
| 642 |
+
N = N or len(stoks) * 3
|
| 643 |
+
stoks = F.pad(stoks.to(dev), (0, self.stoks_len - len(stoks)), value=self.stoks_codes-1).unsqueeze(0)
|
| 644 |
+
speakers = torch.tensor([self.speaker_map[spk] for spk in speakers], device=dev)
|
| 645 |
+
toks = torch.zeros((1,self.quantizers,N), dtype=torch.long, device=dev)
|
| 646 |
+
it = range(0,N)
|
| 647 |
+
if show_progress_bar: it = progress_bar(it)
|
| 648 |
+
for i in it:
|
| 649 |
+
p = self(stoks, toks[:,:,:i], speakers, noloss=True)
|
| 650 |
+
last_p = p[0,:,-1]
|
| 651 |
+
if top_k:
|
| 652 |
+
last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
|
| 653 |
+
for j,tok in enumerate(torch.multinomial((last_p / float(T)).softmax(-1), 1)):
|
| 654 |
+
toks[0,j,max(0,i-j)] = tok
|
| 655 |
+
if toks[0,0,i] == 1024: return toks[0,:,:i]
|
| 656 |
+
return toks[0]
|
| 657 |
+
|
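# Typical use of the trained model (a sketch in comments; the checkpoint name comes from
# load_model's defaults above, while the semantic tokens and the speaker key are
# placeholders that normally come from the T2S stage and the dataset's speaker map):
#
#   model = SADelARTransformer.load_model().cuda()
#   stoks = torch.zeros(750, dtype=torch.long)          # placeholder semantic tokens
#   atoks = model.generate(stoks, ["1"], T=0.7, top_k=64)
#   # `atoks` are EnCodec residual codes (quantizers x N); decode them with EnCodec
#   # (see whisperspeech/a2wav.py) to obtain a waveform.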
| 658 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 37
|
| 659 |
+
def _make_model(size:str, quantizers:int=4, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None, **kwargs):
|
| 660 |
+
assert(dataset is not None)
|
| 661 |
+
kwargs = dict(speaker_map=dataset.speakers, quantizers=quantizers, tunables=tunables, **kwargs)
|
| 662 |
+
if size == 'micro':
|
| 663 |
+
return SADelARTransformer(depth=4, n_head=3, ffn_mult=2, **kwargs)
|
| 664 |
+
if size == 'tiny-narrow':
|
| 665 |
+
return SADelARTransformer(depth=4, n_head=6, ffn_mult=1, **kwargs)
|
| 666 |
+
if size == 'tiny':
|
| 667 |
+
return SADelARTransformer(depth=4, n_head=6, **kwargs)
|
| 668 |
+
if size == 'base':
|
| 669 |
+
return SADelARTransformer(depth=6, n_head=8, **kwargs)
|
| 670 |
+
if size == 'base-deep':
|
| 671 |
+
return SADelARTransformer(depth=9, n_head=8, **kwargs)
|
| 672 |
+
if size == 'base-wide':
|
| 673 |
+
return SADelARTransformer(depth=6, n_head=12, **kwargs)
|
| 674 |
+
if size == 'small/2':
|
| 675 |
+
return SADelARTransformer(depth=9, n_head=12, **kwargs)
|
| 676 |
+
if size == 'small':
|
| 677 |
+
return SADelARTransformer(depth=12, n_head=12, **kwargs)
|
| 678 |
+
if size == 'medium':
|
| 679 |
+
return SADelARTransformer(depth=24, n_head=16, **kwargs)
|
| 680 |
+
|
| 681 |
+
def make_model(size:str, quantizers:int=4, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
|
| 682 |
+
if frozen_embeddings_model:
|
| 683 |
+
vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
|
| 684 |
+
model = _make_model(size, quantizers, tunables, dataset, stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
|
| 685 |
+
model.load_frozen_semantic_embeddings(vqmodel)
|
| 686 |
+
else:
|
| 687 |
+
model = _make_model(size, quantizers, tunables, dataset)
|
| 688 |
+
return model
|
whisperspeech/s2a_delar_mup_wds_mlang.py ADDED
@@ -0,0 +1,564 @@
|
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = ['load_dataset', 'DelSumEmbedding', 'DelSumHead', 'rand', 'Tunables', 'SADelARTransformer']
|
| 5 |
+
|
| 6 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 1
|
| 7 |
+
import io
|
| 8 |
+
import time
|
| 9 |
+
import math
|
| 10 |
+
import random
|
| 11 |
+
import dataclasses
|
| 12 |
+
|
| 13 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 2
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
import numpy as np
|
| 18 |
+
from torch.profiler import profile, record_function, ProfilerActivity, schedule
|
| 19 |
+
from fastcore.basics import store_attr
|
| 20 |
+
from huggingface_hub import hf_hub_download
|
| 21 |
+
|
| 22 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 3
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
import json
|
| 25 |
+
from fastprogress import progress_bar, master_bar
|
| 26 |
+
|
| 27 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 4
|
| 28 |
+
from .modules import *
|
| 29 |
+
|
| 30 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 8
|
| 31 |
+
def rand(start, end):
|
| 32 |
+
return random.random() * (end - start) + start
|
| 33 |
+
|
| 34 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 9
|
| 35 |
+
def random_trunc(random_trunc_p, atoks_len = 2250, stoks_len = 750):
|
| 36 |
+
atoks_per_second = atoks_len / 30
|
| 37 |
+
def _trunc(samples):
|
| 38 |
+
for s in samples:
|
| 39 |
+
if random.random() < random_trunc_p:
|
| 40 |
+
seconds = rand(0.3, 30)
|
| 41 |
+
s['atoks.npy'] = s['atoks.npy'][:,:math.ceil(seconds * atoks_per_second)]
|
| 42 |
+
s['stoks.npy'] = s['stoks.npy'][:math.ceil(s['atoks.npy'].shape[-1]/atoks_len*stoks_len)]
|
| 43 |
+
yield s
|
| 44 |
+
return _trunc
|
| 45 |
+
|
| 46 |
+
def pad_samples(atoks_len = 2250, stoks_len = 750, stoks_pad_token = 4096):
|
| 47 |
+
def _pad(samples):
|
| 48 |
+
for s in samples:
|
| 49 |
+
s['stoks.npy'] = F.pad(torch.tensor(s['stoks.npy']), (1, stoks_len - s['stoks.npy'].shape[-1]-1), value=stoks_pad_token)
|
| 50 |
+
s['out_stoks'] = F.pad(torch.tensor(s['stoks.npy']), (0, stoks_len - s['stoks.npy'].shape[-1]), value=stoks_pad_token)
|
| 51 |
+
s['atoks.npy'] = F.pad(torch.tensor(s['atoks.npy']), (0, atoks_len - s['atoks.npy'].shape[-1]), value=-100)
|
| 52 |
+
yield s
|
| 53 |
+
return _pad
|
| 54 |
+
|
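# The truncation above keeps the two token streams aligned: the acoustic tokens run at
# 2250/30 = 75 tokens/s and the semantic tokens at 750/30 = 25 tokens/s, so cutting the
# acoustic stream to `seconds` of audio cuts the semantic stream to roughly a third as
# many tokens. A quick numeric check with made-up lengths:
import math

seconds = 7.2
atoks_cut = math.ceil(seconds * 2250 / 30)      # 540 acoustic frames
stoks_cut = math.ceil(atoks_cut / 2250 * 750)   # 180 semantic frames
assert stoks_cut * 3 == atoks_cut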
| 55 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 10
|
| 56 |
+
def make_speaker_map(shards):
|
| 57 |
+
speakers = set()
|
| 58 |
+
for shard in shards:
|
| 59 |
+
with open(shard+'.speakers.txt') as f: speakers = speakers.union(set(x.strip() for x in f.readlines()))
|
| 60 |
+
return {id:i for i,id in enumerate(sorted(speakers))}
|
| 61 |
+
|
| 62 |
+
def speaker_id_extractor(speaker_map):
|
| 63 |
+
def _extractor(samples):
|
| 64 |
+
for s in samples:
|
| 65 |
+
s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]])
|
| 66 |
+
yield s
|
| 67 |
+
return _extractor
|
| 68 |
+
|
| 69 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 27
|
| 70 |
+
def load_dataset(
|
| 71 |
+
atoks_shard_spec:str, # webdataset folder
|
| 72 |
+
stoks_shard_dir:str, # stoks webdataset base dir
|
| 73 |
+
samples:int, # samples per epoch
|
| 74 |
+
random_trunc_p:float=0,# probability of truncating the input to less than 30 seconds
|
| 75 |
+
vq_codes:int=4096,
|
| 76 |
+
language:str='en',
|
| 77 |
+
weight:float=1,
|
| 78 |
+
validation:bool=False,
|
| 79 |
+
exclude_files:str=None,
|
| 80 |
+
randomize_speakers:bool=False,
|
| 81 |
+
):
|
| 82 |
+
import webdataset as wds
|
| 83 |
+
from whisperspeech import utils
|
| 84 |
+
|
| 85 |
+
shards = utils.shard_glob(atoks_shard_spec)
|
| 86 |
+
excludes = {x for file in exclude_files.split() for x in utils.readlines(file)} if exclude_files else set()
|
| 87 |
+
|
| 88 |
+
def check_for_nan(s):
|
| 89 |
+
if torch.tensor(s['spk_emb.npy']).isnan().any(): print("found NaN:", s['__key__'])
|
| 90 |
+
return s
|
| 91 |
+
|
| 92 |
+
def set_language(x):
|
| 93 |
+
x['language'] = language
|
| 94 |
+
return x
|
| 95 |
+
|
| 96 |
+
same_on_all_nodes = lambda urls: urls # will only be used for validation
|
| 97 |
+
ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
|
| 98 |
+
wds.decode(),
|
| 99 |
+
utils.merge_in(utils.derived_dataset('maxvad-stoks', base='atoks-3kbps', suffix='', dir=stoks_shard_dir)),
|
| 100 |
+
wds.map(check_for_nan),
|
| 101 |
+
wds.select(lambda s: s['__key__'] not in excludes),
|
| 102 |
+
wds.map_dict(**{'spk_emb.npy':np.nan_to_num}), # remove nans from the speaker embedding model
|
| 103 |
+
random_trunc(random_trunc_p) if random_trunc_p > 0 else lambda x: x,
|
| 104 |
+
pad_samples(stoks_pad_token=vq_codes-1),
|
| 105 |
+
wds.map(set_language),
|
| 106 |
+
wds.to_tuple('stoks.npy', 'atoks.npy', 'spk_emb.npy', 'language', 'out_stoks'),
|
| 107 |
+
wds.shuffle(20000, initial=20000),
|
| 108 |
+
wds.batched(64),
|
| 109 |
+
)
|
| 110 |
+
if randomize_speakers:
|
| 111 |
+
rng = np.random.default_rng()
|
| 112 |
+
ds = ds.compose(
|
| 113 |
+
wds.map_tuple(None, None, lambda x: rng.permutation(x), None),
|
| 114 |
+
)
|
| 115 |
+
if validation:
|
| 116 |
+
ds = ds.slice(samples // 64)
|
| 117 |
+
ds.total_samples = samples
|
| 118 |
+
ds.weight = weight
|
| 119 |
+
|
| 120 |
+
return ds
|
| 121 |
+
|
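# Example invocation (a sketch; the shard specs and sample counts below are placeholders,
# and the shards must have been produced by the prepare_s2a_dataset / vq_stoks pipelines so
# that the merged `atoks`/`stoks`/`spk_emb` keys exist):
#
#   train_ds = load_dataset('path/to/*-atoks-3kbps-*.tar.gz', 'path/to/stoks/',
#                           samples=500_000, random_trunc_p=0.1, language='en')
#   val_ds   = load_dataset('path/to/validation-atoks-3kbps-000000.tar.gz',
#                           'path/to/stoks/', samples=512, validation=True)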
| 122 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 37
|
| 123 |
+
class DelSumEmbedding(nn.Module):
|
| 124 |
+
def __init__(self, n_head=6, head_width=64, atoks_width=None, length=2250, codes=1024, quantizers=8, pos_embs=None):
|
| 125 |
+
super().__init__()
|
| 126 |
+
self.length = length
|
| 127 |
+
width = n_head * head_width
|
| 128 |
+
if atoks_width is None: atoks_width = width
|
| 129 |
+
self.width = width
|
| 130 |
+
self.quantizers = quantizers
|
| 131 |
+
|
| 132 |
+
emb = None
|
| 133 |
+
embs = []
|
| 134 |
+
for _ in range(quantizers):
|
| 135 |
+
emb = FlexEmbeddings(codes, width, special_codes=2, frozen_width=atoks_width,
|
| 136 |
+
special_embedding=emb and emb.special)
|
| 137 |
+
embs.append(emb)
|
| 138 |
+
self.embeddings = nn.ModuleList(embs)
|
| 139 |
+
if pos_embs is not None:
|
| 140 |
+
self.register_buffer("positional_embedding", pos_embs)
|
| 141 |
+
|
| 142 |
+
def forward(self, toks, xenc):
|
| 143 |
+
with record_function("embeddings"):
|
| 144 |
+
b,_,n = toks.shape
|
| 145 |
+
newn = min(n, self.length)
|
| 146 |
+
|
| 147 |
+
embs = torch.zeros((b,newn,self.width), dtype=xenc.dtype, device=xenc.device)
|
| 148 |
+
for i in range(self.quantizers):
|
| 149 |
+
embs[:, :] += self.embeddings[i](toks[:,i,:])
|
| 150 |
+
|
| 151 |
+
x = embs.to(xenc.dtype)
|
| 152 |
+
return x
|
| 153 |
+
|
| 154 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 38
|
| 155 |
+
class DelSumHead(nn.Module):
|
| 156 |
+
def __init__(self, quantizers=8, n_head=6, head_width=64):
|
| 157 |
+
super().__init__()
|
| 158 |
+
self.width = n_head * head_width
|
| 159 |
+
self.quantizers = quantizers
|
| 160 |
+
self.splitter = nn.Sequential(
|
| 161 |
+
nn.Linear(self.width, self.width * quantizers),
|
| 162 |
+
nn.GELU(),
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
def forward(self, x, embeddings=None):
|
| 166 |
+
b, newn, _ = x.shape
|
| 167 |
+
with record_function("splitter"):
|
| 168 |
+
split = self.splitter(x).view(b,newn,self.quantizers,self.width)
|
| 169 |
+
with record_function("unembed"):
|
| 170 |
+
logits = torch.stack([embeddings[q].unembed(split[:,:,q]) for q in range(self.quantizers)], dim=1)
|
| 171 |
+
return logits
|
| 172 |
+
|
| 173 |
+
def rand(start, end):
|
| 174 |
+
return random.random() * (end - start) + start
|
| 175 |
+
|
| 176 |
+
@dataclasses.dataclass
|
| 177 |
+
class Tunables:
|
| 178 |
+
init_std :float = 9
|
| 179 |
+
embeddings_std :float = 0.2
|
| 180 |
+
embeddings_lr_scale: float = 10
|
| 181 |
+
output_mult :float = 5.6
|
| 182 |
+
# FIXME: try separate mults for self and cross attention
|
| 183 |
+
query_mult :float = .3
|
| 184 |
+
encoder_depth_ratio :float = 0.25
|
| 185 |
+
linear_heads :bool = False
|
| 186 |
+
rope :bool = True
|
| 187 |
+
|
| 188 |
+
lr0 :float = 3e-3
|
| 189 |
+
clip_gradient_norm :float = 2
|
| 190 |
+
weight_decay :float = 1e-3
|
| 191 |
+
warmup_steps :float = 2000
|
| 192 |
+
|
| 193 |
+
random :bool = False
|
| 194 |
+
|
| 195 |
+
def __post_init__(self):
|
| 196 |
+
# randomize the hyperparams if requested
|
| 197 |
+
if self.random:
|
| 198 |
+
self.init_std = 2*10**rand(0,1)
|
| 199 |
+
self.embeddings_std = 10**rand(-1.7,-0.22)
|
| 200 |
+
self.embeddings_lr_scale = 2**rand(2,4)
|
| 201 |
+
self.output_mult = 2**rand(1.5,3)
|
| 202 |
+
self.query_mult = 2**rand(-3,-1.3)
|
| 203 |
+
self.encoder_depth_ratio = random.choice([0.25,0.5])
|
| 204 |
+
self.linear_heads = False
|
| 205 |
+
self.rope = True
|
| 206 |
+
|
| 207 |
+
self.lr0 = 3e-3
|
| 208 |
+
self.clip_gradient_norm = 10**rand(-1,1)
|
| 209 |
+
self.warmup_steps = 100*(10**rand(1.18,1.3))
|
| 210 |
+
|
| 211 |
+
@staticmethod
|
| 212 |
+
def upgrade(args):
|
| 213 |
+
args = {k:v for k,v in args.items()}
|
| 214 |
+
def old_default(name, value):
|
| 215 |
+
if name not in args: args[name] = value
|
| 216 |
+
old_default('rope', False)
|
| 217 |
+
old_default('linear_heads', True)
|
| 218 |
+
return args
|
| 219 |
+
|
| 220 |
+
class SADelARTransformer(nn.Module):
|
| 221 |
+
def __init__(self, depth=3, ctx_n=2250,
|
| 222 |
+
stoks_len=750, stoks_codes=4097, stoks_width=None,
|
| 223 |
+
spk_width=None,
|
| 224 |
+
atoks_width=None,
|
| 225 |
+
n_head=3, head_width=64, ffn_mult=4,
|
| 226 |
+
quantizers=8, speaker_map={"1":0}, tunables=Tunables()):
|
| 227 |
+
super().__init__()
|
| 228 |
+
self.quantizers = quantizers
|
| 229 |
+
self.codes = 1024
|
| 230 |
+
width = n_head * head_width
|
| 231 |
+
store_attr("depth,ctx_n,stoks_len,stoks_codes,stoks_width,spk_width,atoks_width,n_head,head_width,ffn_mult,quantizers,speaker_map")
|
| 232 |
+
self.width = width
|
| 233 |
+
self.base_width = 3 * head_width
|
| 234 |
+
self.tunables = tunables
|
| 235 |
+
|
| 236 |
+
if stoks_width is None: stoks_width = width
|
| 237 |
+
if spk_width is None: spk_width = width
|
| 238 |
+
self.emb_factor = width != stoks_width
|
| 239 |
+
self.spk_factor = width != spk_width
|
| 240 |
+
|
| 241 |
+
if tunables.rope:
|
| 242 |
+
self.positional_embeddings = None
|
| 243 |
+
else:
|
| 244 |
+
self.register_buffer('positional_embeddings', sinusoids(ctx_n, width))
|
| 245 |
+
|
| 246 |
+
# self.speaker_embedding = nn.Embedding(len(speaker_map), spk_width)
|
| 247 |
+
self.semantic_embedding = nn.Embedding(stoks_codes, stoks_width)
|
| 248 |
+
if self.emb_factor:
|
| 249 |
+
self.emb_to_hidden = nn.Linear(stoks_width, width)
|
| 250 |
+
self.hidden_to_emb = nn.Linear(width, stoks_width)
|
| 251 |
+
|
| 252 |
+
if self.spk_factor:
|
| 253 |
+
self.spk_to_hidden = nn.Linear(spk_width, width)
|
| 254 |
+
|
| 255 |
+
qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)
|
| 256 |
+
|
| 257 |
+
encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
|
| 258 |
+
decoder_depth = depth * 2 - encoder_depth
|
| 259 |
+
self.encoder = nn.Sequential(*[
|
| 260 |
+
ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(encoder_depth)
|
| 261 |
+
]) # FIXME: enclm requires causal attention here
|
| 262 |
+
self.ln_post = LayerNorm(width)
|
| 263 |
+
|
| 264 |
+
self.embds = DelSumEmbedding(
|
| 265 |
+
pos_embs=self.positional_embeddings, length=ctx_n,
|
| 266 |
+
n_head=n_head, head_width=head_width, atoks_width=atoks_width,
|
| 267 |
+
quantizers=quantizers,
|
| 268 |
+
)
|
| 269 |
+
self.decoder = BaseDecoder(qk_scale=qk_scale, length=ctx_n,
|
| 270 |
+
n_head=n_head, width=n_head * head_width,
|
| 271 |
+
ffn_mult=ffn_mult, depth=decoder_depth,
|
| 272 |
+
rope=tunables.rope)
|
| 273 |
+
self.head = DelSumHead(n_head=n_head, head_width=head_width, quantizers=quantizers)
|
| 274 |
+
for l in self.decoder.layers:
|
| 275 |
+
l.cross_attn.key_subsampling = 3
|
| 276 |
+
# for l in self.encoder:
|
| 277 |
+
# l.attn.key_subsampling = 3
|
| 278 |
+
# l.attn.query_subsampling = 3
|
| 279 |
+
|
| 280 |
+
self.register_buffer('val_true', torch.zeros(self.quantizers).cuda())
|
| 281 |
+
self.register_buffer('val_total', torch.zeros(self.quantizers).cuda())
|
| 282 |
+
self.apply(self.init_transformer)
|
| 283 |
+
|
| 284 |
+
def setup(self, device):
|
| 285 |
+
pass
|
| 286 |
+
|
| 287 |
+
def load_frozen_semantic_embeddings(self, vqmodel):
|
| 288 |
+
with torch.no_grad():
|
| 289 |
+
self.semantic_embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
|
| 290 |
+
self.semantic_embedding.lr_scale = 0
|
| 291 |
+
|
| 292 |
+
def load_frozen_acoustic_embeddings(self, amodel):
|
| 293 |
+
for i in range(self.quantizers):
|
| 294 |
+
self.decoder.embeddings[i].set_frozen_embeddings(amodel.quantizer.vq.layers[i].codebook)
|
| 295 |
+
|
| 296 |
+
def init_transformer(self, m):
|
| 297 |
+
if isinstance(m, LinearHead):
|
| 298 |
+
m.no_weight_decay = True
|
| 299 |
+
torch.nn.init.constant_(m.weight, 0)
|
| 300 |
+
elif isinstance(m, QueryHead):
|
| 301 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
| 302 |
+
torch.nn.init.constant_(m.weight, 0)
|
| 303 |
+
elif isinstance(m, nn.Embedding):
|
| 304 |
+
m.no_weight_decay = True
|
| 305 |
+
m.lr_scale = self.tunables.embeddings_lr_scale
|
| 306 |
+
std = self.tunables.embeddings_std
|
| 307 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 308 |
+
# elif isinstance(m, EmbeddingProjector):
|
| 309 |
+
# m.lr_scale = self.tunables.embeddings_lr_scale #1/(m.weight.shape[1] / self.base_width)
|
| 310 |
+
# m.lr_scale = 2/(m.weight.shape[1] / self.base_width)
|
| 311 |
+
# std = self.tunables.init_std / m.weight.shape[1]
|
| 312 |
+
# torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 313 |
+
elif isinstance(m, nn.Linear):
|
| 314 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
| 315 |
+
std = self.tunables.init_std / m.weight.shape[1]
|
| 316 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 317 |
+
if m.bias is not None:
|
| 318 |
+
torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
|
| 319 |
+
elif isinstance(m, nn.LayerNorm):
|
| 320 |
+
m.no_weight_decay = True
|
| 321 |
+
torch.nn.init.constant_(m.bias, 0)
|
| 322 |
+
torch.nn.init.constant_(m.weight, 1)
|
| 323 |
+
|
| 324 |
+
def embed_stoks(self, Stoks):
|
| 325 |
+
b,n = Stoks.shape
|
| 326 |
+
if self.stoks_len == 1500:
|
| 327 |
+
# converts 50 toks/s to 75 toks/s by adding padding between every two tokens
|
| 328 |
+
x = Stoks.reshape(b,n//2,2)
|
| 329 |
+
x = x.repeat_interleave(2, -1)[:,:,:3]
|
| 330 |
+
x[:,:,1] = 1024
|
| 331 |
+
x = x.reshape(b,n//2*3)
|
| 332 |
+
else:
|
| 333 |
+
# it's a lot easier with 25 toks/s
|
| 334 |
+
# x = Stoks.repeat_interleave(3, -1)
|
| 335 |
+
x = Stoks
|
| 336 |
+
# embed semantic tokens
|
| 337 |
+
Sembs = self.semantic_embedding(x.to(torch.long))
|
| 338 |
+
if self.emb_factor:
|
| 339 |
+
Sembs = self.emb_to_hidden(Sembs)
|
| 340 |
+
return Sembs
|
| 341 |
+
|
| 342 |
+
def _encoder(self, semb, positions):
|
| 343 |
+
x = semb
|
| 344 |
+
for l in self.encoder: x = l(x, positions)
|
| 345 |
+
return self.ln_post(x)
|
| 346 |
+
|
| 347 |
+
def run_encoder(self, Stoks, speakers):
|
| 348 |
+
semb = self.embed_stoks(Stoks)
|
| 349 |
+
with record_function("encoder"):
|
| 350 |
+
if self.positional_embeddings is not None: semb = semb + self.positional_embeddings
|
| 351 |
+
positions = torch.arange(0, semb.shape[1], device=semb.device)
|
| 352 |
+
xenc = self._encoder(semb, positions)
|
| 353 |
+
if self.training:
|
| 354 |
+
enc_logits = (self.hidden_to_emb(xenc) @ self.semantic_embedding.weight.to(xenc.dtype).T).float()
|
| 355 |
+
enc_logits = enc_logits * self.tunables.output_mult / (self.width / self.base_width)
|
| 356 |
+
else:
|
| 357 |
+
enc_logits = None
|
| 358 |
+
# print(xenc.shape, speakers.shape)
|
| 359 |
+
spk_embs = F.normalize(speakers, dim=-1) # use extracted embeddings
|
| 360 |
+
if self.spk_factor: spk_embs = self.spk_to_hidden(spk_embs)
|
| 361 |
+
return xenc + spk_embs.unsqueeze(1), positions, enc_logits
|
| 362 |
+
|
| 363 |
+
def forward(self, Stoks, Atoks, speakers, langs=None, out_stoks=None, noloss=False, xenc=None, xenc_positions=None, atoks_positions=None):
|
| 364 |
+
if xenc is None:
|
| 365 |
+
Atoks = Atoks.to(torch.long)
|
| 366 |
+
out_stoks = out_stoks.to(torch.long)
|
| 367 |
+
Atoks_gt = Atoks.clone()
|
| 368 |
+
Atoks_gt[Atoks == -100] = 1024
|
| 369 |
+
xenc, xenc_positions, enc_logits = self.run_encoder(Stoks, speakers)  # run_encoder returns (xenc, positions, enc_logits)
|
| 370 |
+
else:
|
| 371 |
+
Atoks_gt = Atoks
|
| 372 |
+
with record_function("decoder"):
|
| 373 |
+
embs = self.embds(Atoks, xenc)
|
| 374 |
+
if atoks_positions is None: atoks_positions = torch.arange(0, embs.shape[1], device=embs.device)
|
| 375 |
+
x = self.decoder(embs, atoks_positions, xenc, xenc_positions)
|
| 376 |
+
logits = self.head(x, embeddings=self.embds.embeddings)
|
| 377 |
+
logits *= self.tunables.output_mult / (self.width / self.base_width)
|
| 378 |
+
|
| 379 |
+
if noloss:
|
| 380 |
+
return logits
|
| 381 |
+
|
| 382 |
+
with record_function("loss"):
|
| 383 |
+
N = Atoks.shape[-1]
|
| 384 |
+
loss = 0
|
| 385 |
+
for i in range(self.quantizers):
|
| 386 |
+
loss += F.cross_entropy(logits[:,i,i:].reshape(-1,logits.shape[-1]), Atoks[:,i,:N-i].reshape(-1))
|
| 387 |
+
if self.training and i == 0:
|
| 388 |
+
loss *= 5
|
| 389 |
+
loss /= self.quantizers
|
| 390 |
+
if self.training:
|
| 391 |
+
loss += 0.1 * F.cross_entropy(enc_logits.transpose(-1,-2), out_stoks)
|
| 392 |
+
|
| 393 |
+
if not self.training:
|
| 394 |
+
for i in range(self.quantizers):
|
| 395 |
+
Atoks_i = Atoks[:,i,:N-i]
|
| 396 |
+
valid_Atoks = Atoks_i != -100
|
| 397 |
+
self.val_true[i] += (logits[:,i,i:].argmax(-1)[valid_Atoks] == Atoks_i[valid_Atoks]).float().sum()
|
| 398 |
+
self.val_total[i] += valid_Atoks.float().sum()
|
| 399 |
+
|
| 400 |
+
return logits, loss
|
| 401 |
+
|
| 402 |
+
def get_metrics(self):
|
| 403 |
+
metrics = {
|
| 404 |
+
f'acc_{i}':x.item() for i,x in enumerate(self.val_true / self.val_total)
|
| 405 |
+
}
|
| 406 |
+
self.val_true[:] = 0
|
| 407 |
+
self.val_total[:] = 0
|
| 408 |
+
return metrics
|
| 409 |
+
|
| 410 |
+
#
|
| 411 |
+
# inference
|
| 412 |
+
#
|
| 413 |
+
@classmethod
|
| 414 |
+
def load_model(cls, ref="collabora/whisperspeech:s2a-q4-small-en+pl.model",
|
| 415 |
+
repo_id=None, filename=None, local_filename=None):
|
| 416 |
+
if repo_id is None and filename is None and local_filename is None:
|
| 417 |
+
if ":" in ref:
|
| 418 |
+
repo_id, filename = ref.split(":", 1)
|
| 419 |
+
else:
|
| 420 |
+
local_filename = ref
|
| 421 |
+
if not local_filename:
|
| 422 |
+
local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 423 |
+
spec = torch.load(local_filename)
|
| 424 |
+
if '_extra_state' not in spec['state_dict']: spec['state_dict']['_extra_state'] = { 'speaker_map': spec['config']['speaker_map'] }
|
| 425 |
+
model = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec['tunables'])))
|
| 426 |
+
model.load_state_dict(spec['state_dict'])
|
| 427 |
+
model.eval()
|
| 428 |
+
return model
|
| 429 |
+
|
| 430 |
+
def get_extra_state(self):
|
| 431 |
+
return { 'speaker_map': self.speaker_map }
|
| 432 |
+
|
| 433 |
+
def set_extra_state(self, st):
|
| 434 |
+
self.speaker_map = st['speaker_map']
|
| 435 |
+
|
| 436 |
+
def load_checkpoint(self, local_filename):
|
| 437 |
+
spec = torch.load(local_filename, map_location='cpu')
|
| 438 |
+
assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
|
| 439 |
+
state_dict = {k.replace('model.', ''):v
|
| 440 |
+
for k,v in spec['state_dict'].items()}
|
| 441 |
+
self.load_state_dict(state_dict)
|
| 442 |
+
return self
|
| 443 |
+
|
| 444 |
+
def save_model(self, fname):
|
| 445 |
+
torch.save(dict(config = self.__stored_args__,
|
| 446 |
+
tunables = dataclasses.asdict(self.tunables),
|
| 447 |
+
state_dict = self.state_dict()), fname)
|
| 448 |
+
|
| 449 |
+
def switch_dtypes(self, dtype=torch.float16):
|
| 450 |
+
self.dtype = dtype
|
| 451 |
+
for n,m in self.named_modules():
|
| 452 |
+
# convert every leaf layer apart from the LayerNorms
|
| 453 |
+
if isinstance(m, (nn.Linear, nn.Embedding)):
|
| 454 |
+
m.to(dtype)
|
| 455 |
+
# take care of buffers ([kv]_cache, masks) that are not in the leaf layers
|
| 456 |
+
for bn,b in m.named_buffers(recurse=False):
|
| 457 |
+
setattr(m,bn,b.to(dtype))
|
| 458 |
+
|
| 459 |
+
def optimize(self, max_batch_size=1, dtype=torch.float16, torch_compile=True):
|
| 460 |
+
for emb in self.embds.embeddings:
|
| 461 |
+
emb.convert_for_eval()
|
| 462 |
+
for l in self.encoder:
|
| 463 |
+
l.attn.convert_for_eval()
|
| 464 |
+
for l in self.decoder.layers:
|
| 465 |
+
l.attn.convert_for_eval()
|
| 466 |
+
l.cross_attn.convert_for_eval()
|
| 467 |
+
l.setup_kv_cache(max_batch_size, self.ctx_n, self.stoks_len)
|
| 468 |
+
self.switch_dtypes(dtype)
|
| 469 |
+
if torch_compile:
|
| 470 |
+
self.generate_next = torch.compile(self.generate_next, mode="reduce-overhead", fullgraph=True)
|
| 471 |
+
|
| 472 |
+
@property
|
| 473 |
+
def device(self):
|
| 474 |
+
return next(self.parameters()).device
|
| 475 |
+
|
| 476 |
+
# from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
|
| 477 |
+
def multinomial_sample_one_no_sync(self, probs_sort): # Does multinomial sampling without a cuda synchronization
|
| 478 |
+
q = torch.empty_like(probs_sort).exponential_(1)
|
| 479 |
+
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
| 480 |
+
|
| 481 |
+
def logits_to_probs(self, logits, T=1.0, top_k=None):
|
| 482 |
+
logits = logits / max(T, 1e-5)
|
| 483 |
+
|
| 484 |
+
if top_k is not None:
|
| 485 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 486 |
+
pivot = v.select(-1, -1).unsqueeze(-1)
|
| 487 |
+
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
| 488 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
| 489 |
+
return probs
|
| 490 |
+
|
| 491 |
+
def sample(self, logits, T=1.0, top_k=None):
|
| 492 |
+
probs = self.logits_to_probs(logits[0,:,-1], T, top_k)
|
| 493 |
+
idx_next = self.multinomial_sample_one_no_sync(probs)
|
| 494 |
+
return idx_next
|
| 495 |
+
|
| 496 |
+
def generate_one(self, toks, positions, langs, xenc, xenc_positions, T, top_k):
|
| 497 |
+
probs = self(None, toks, None, langs, noloss=True, xenc=xenc, xenc_positions=xenc_positions, atoks_positions=positions)
|
| 498 |
+
return self.sample(probs, T, top_k)
|
| 499 |
+
|
| 500 |
+
def generate_next(self, *args, **kwargs):
|
| 501 |
+
return self.generate_one(*args, **kwargs)
|
| 502 |
+
|
| 503 |
+
@torch.no_grad()
|
| 504 |
+
def generate(self, stoks, speakers, langs=None, N=None, T=0.7, top_k=None, show_progress_bar=True, step=None, subsample_enc=False):
|
| 505 |
+
dev = self.device
|
| 506 |
+
N = N or len(stoks) * 3
|
| 507 |
+
stoks = F.pad(stoks.to(dev), (1, self.stoks_len - len(stoks)-1), value=self.stoks_codes-1).unsqueeze(0)
|
| 508 |
+
speakers = speakers.to(device=dev, dtype=self.dtype)
|
| 509 |
+
toks = torch.full((1,self.quantizers,2250), self.codes+1, dtype=torch.long, device=dev)
|
| 510 |
+
it = range(1,min(N,2250-1))
|
| 511 |
+
if show_progress_bar: it = progress_bar(it)
|
| 512 |
+
with record_function("encode"):
|
| 513 |
+
xenc, xenc_positions, _ = self.run_encoder(stoks, speakers)
|
| 514 |
+
toks_positions = torch.arange(N, device=dev)
|
| 515 |
+
with record_function("prefill"):
|
| 516 |
+
toks[0,0,1] = self.generate_one(toks[:,:,:1], toks_positions[:1], langs, xenc, xenc_positions, T, top_k)[0,0]
|
| 517 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
| 518 |
+
for i in it:
|
| 519 |
+
with record_function("generate_one"):
|
| 520 |
+
toks[0,:i+1,i+1] = self.generate_next(toks[:,:,i:i+1], toks_positions[i:i+1], langs, xenc, xenc_positions, T, top_k)[:i+1,0]
|
| 521 |
+
|
| 522 |
+
# for profiling, debugging or early exit
|
| 523 |
+
if step is not None: step()
|
| 524 |
+
# shift tokens
|
| 525 |
+
toks = toks[:,:,1:N]
|
| 526 |
+
for j in range(self.quantizers):
|
| 527 |
+
toks[0, j] = torch.roll(toks[0, j], -j)
|
| 528 |
+
return toks[0]
|
| 529 |
+
|
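# multinomial_sample_one_no_sync above relies on the "exponential race" trick from gpt-fast:
# if E_i ~ Exp(1), then argmax_i(p_i / E_i) picks index i with probability p_i, so it matches
# torch.multinomial without forcing a CUDA synchronization. A quick CPU check with a toy
# distribution (20k draws):
import torch

probs = torch.tensor([0.1, 0.6, 0.3])
draws = torch.empty(20000, 3).exponential_(1)
samples = torch.argmax(probs / draws, dim=-1)
freqs = torch.bincount(samples, minlength=3) / samples.numel()
# freqs is approximately [0.10, 0.60, 0.30]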
| 530 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 39
|
| 531 |
+
def _make_model(size:str, quantizers:int=4, tunables:Tunables=Tunables(), **kwargs):
|
| 532 |
+
kwargs = dict(quantizers=quantizers, tunables=tunables, **kwargs)
|
| 533 |
+
if size == 'micro':
|
| 534 |
+
return SADelARTransformer(depth=4, n_head=3, ffn_mult=2, **kwargs)
|
| 535 |
+
if size == 'tiny-narrow':
|
| 536 |
+
return SADelARTransformer(depth=4, n_head=6, ffn_mult=1, **kwargs)
|
| 537 |
+
if size == 'tiny':
|
| 538 |
+
return SADelARTransformer(depth=4, n_head=6, **kwargs)
|
| 539 |
+
if size == 'base':
|
| 540 |
+
return SADelARTransformer(depth=6, n_head=8, **kwargs)
|
| 541 |
+
if size == 'base-deep':
|
| 542 |
+
return SADelARTransformer(depth=9, n_head=8, **kwargs)
|
| 543 |
+
if size == 'base-wide':
|
| 544 |
+
return SADelARTransformer(depth=6, n_head=12, **kwargs)
|
| 545 |
+
if size == 'small/2':
|
| 546 |
+
return SADelARTransformer(depth=9, n_head=12, **kwargs)
|
| 547 |
+
if size == 'small':
|
| 548 |
+
return SADelARTransformer(depth=12, n_head=12, **kwargs)
|
| 549 |
+
if size == 'medium':
|
| 550 |
+
return SADelARTransformer(depth=24, n_head=16, **kwargs)
|
| 551 |
+
|
| 552 |
+
def make_model(size:str, quantizers:int=4, frozen_embeddings_model:str=None, frozen_acoustic_embeddings:bool=False, spk_width:int=None, tunables:Tunables=Tunables(), dataset=None):
|
| 553 |
+
from encodec.model import EncodecModel
|
| 554 |
+
from whisperspeech import vq_stoks
|
| 555 |
+
|
| 556 |
+
amodel = EncodecModel.encodec_model_24khz() if frozen_acoustic_embeddings else None
|
| 557 |
+
vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model) if frozen_embeddings_model else None
|
| 558 |
+
model = _make_model(size, quantizers, tunables,
|
| 559 |
+
spk_width=spk_width,
|
| 560 |
+
atoks_width=amodel and amodel.quantizer.vq.layers[0]._codebook.embed.shape[-1],
|
| 561 |
+
stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
|
| 562 |
+
if vqmodel: model.load_frozen_semantic_embeddings(vqmodel)
|
| 563 |
+
if amodel: model.load_frozen_acoustic_embeddings(amodel)
|
| 564 |
+
return model
|
whisperspeech/t2s_up_wds.py ADDED
@@ -0,0 +1,442 @@
|
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/5B. Text to semantic token modeling.ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = ['load_datasets', 'rand', 'Tunables', 'Encoder', 'Decoder', 'TSARTransformer', 'make_model']
|
| 5 |
+
|
| 6 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 1
|
| 7 |
+
import dataclasses
|
| 8 |
+
import random
|
| 9 |
+
import math
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.profiler import record_function
|
| 14 |
+
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
+
from fastcore.basics import store_attr
|
| 17 |
+
from fastprogress import progress_bar
|
| 18 |
+
|
| 19 |
+
import webdataset as wds
|
| 20 |
+
|
| 21 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 2
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
import pylab as plt
|
| 24 |
+
import pandas as pd
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 3
|
| 28 |
+
import whisper
|
| 29 |
+
from whisperspeech.train import *
|
| 30 |
+
from whisperspeech.modules import *
|
| 31 |
+
from whisperspeech import vq_stoks
|
| 32 |
+
|
| 33 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 8
|
| 34 |
+
import re
|
| 35 |
+
|
| 36 |
+
class CharTokenizer:
|
| 37 |
+
"""Trivial tokenizer β just use UTF-8 bytes"""
|
| 38 |
+
eot = 0
|
| 39 |
+
|
| 40 |
+
def encode(self, txt):
|
| 41 |
+
return list(bytes(txt.strip(), 'utf-8'))
|
| 42 |
+
|
| 43 |
+
def decode(self, tokens):
|
| 44 |
+
return bytes(tokens).decode('utf-8')
|
| 45 |
+
|
| 46 |
+
def tokenizer(ikey, okey, length):
|
| 47 |
+
"""Tokenizes a transcript"""
|
| 48 |
+
tok = CharTokenizer()
|
| 49 |
+
def _tokenizer(samples):
|
| 50 |
+
for s in samples:
|
| 51 |
+
toks = torch.tensor(tok.encode(s[ikey]))
|
| 52 |
+
s[okey] = F.pad(toks, (0, length - toks.shape[-1]), value=tok.eot)
|
| 53 |
+
yield s
|
| 54 |
+
return _tokenizer
|
| 55 |
+
|
| 56 |
+
def ar_padder(ikey, okey, length, pad_token):
|
| 57 |
+
"""Pads the tokens for autoregresive training"""
|
| 58 |
+
def _ar_padder(samples):
|
| 59 |
+
for s in samples:
|
| 60 |
+
toks = s[ikey]
|
| 61 |
+
if isinstance(toks, (list, np.ndarray)): toks = torch.tensor(toks)
|
| 62 |
+
toks = toks.to(torch.long)
|
| 63 |
+
s['in_' +okey] = F.pad(toks, (1, length - toks.shape[-1] - 1), value=pad_token)
|
| 64 |
+
s['out_'+okey] = F.pad(toks, (0, length - toks.shape[-1]), value=pad_token)
|
| 65 |
+
yield s
|
| 66 |
+
return _ar_padder
|
| 67 |
+
|
| 68 |
+
def char_per_seconder(txt_key, stoks_key, cps_key, stoks_per_second=25):
|
| 69 |
+
"""Adds the characters per second metric to the input data"""
|
| 70 |
+
def _char_per_seconder(samples):
|
| 71 |
+
for s in samples:
|
| 72 |
+
secs = s[stoks_key].shape[-1] / stoks_per_second
|
| 73 |
+
s[cps_key] = len(s[txt_key]) / secs
|
| 74 |
+
yield s
|
| 75 |
+
return _char_per_seconder
|
| 76 |
+
|
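# The "tokenizer" really is just UTF-8 bytes: at most 256 distinct ids, with 0 reused as the
# end-of-text/padding id, and it works unchanged for any language. Round trip:
tok = CharTokenizer()
ids = tok.encode("  żółw ")        # surrounding whitespace is stripped before encoding
assert tok.decode(ids) == "żółw"
assert max(ids) < 256 and tok.eot == 0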
| 77 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 9
|
| 78 |
+
def build_speaker_map(shards):
|
| 79 |
+
speakers = set()
|
| 80 |
+
for shard in shards:
|
| 81 |
+
with open(shard+'.speakers.txt') as f: speakers = speakers.union(set(x.strip() for x in f.readlines()))
|
| 82 |
+
return {id:i for i,id in enumerate(speakers)}
|
| 83 |
+
|
| 84 |
+
def speaker_id_extractor(speaker_map):
|
| 85 |
+
def _extractor(samples):
|
| 86 |
+
for s in samples:
|
| 87 |
+
s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]])
|
| 88 |
+
yield s
|
| 89 |
+
return _extractor
|
| 90 |
+
|
| 91 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 10
|
| 92 |
+
def load_datasets(
|
| 93 |
+
input:str, # webdataset folder or shard list
|
| 94 |
+
samples:int, # samples per epoch
|
| 95 |
+
subsample:float=1, # use a fraction of the files
|
| 96 |
+
val_samples:int=512,
|
| 97 |
+
vq_codes:int=4096,
|
| 98 |
+
):
|
| 99 |
+
if isinstance(input, (Path, str)):
|
| 100 |
+
path = Path(input)
|
| 101 |
+
if path.is_dir():
|
| 102 |
+
glob = '*-t2s-*.tar.gz'
|
| 103 |
+
else:
|
| 104 |
+
glob = path.name
|
| 105 |
+
path = path.parent
|
| 106 |
+
input = Path(path).glob(glob)
|
| 107 |
+
elif isinstance(input, list):
|
| 108 |
+
pass
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError("input should be either a list or a path with an optional glob specifier")
|
| 111 |
+
shards = [str(x) for x in input]
|
| 112 |
+
|
| 113 |
+
speaker_map = build_speaker_map(shards)
|
| 114 |
+
|
| 115 |
+
def ds(shards, length):
|
| 116 |
+
ds = wds.WebDataset(wds.ResampledShards(shards)).compose(
|
| 117 |
+
wds.decode(),
|
| 118 |
+
speaker_id_extractor(speaker_map),
|
| 119 |
+
wds.select(lambda s: s['stoks.npy'].shape[-1] > 12), # select samples > .5s
|
| 120 |
+
tokenizer('txt', 'ttoks', length=550),
|
| 121 |
+
ar_padder('stoks.npy', 'stoks', length=750, pad_token=vq_codes-1),
|
| 122 |
+
char_per_seconder('txt', 'stoks.npy', 'cps', stoks_per_second=25),
|
| 123 |
+
wds.to_tuple('ttoks', 'speaker', 'cps', 'in_stoks', 'out_stoks'),
|
| 124 |
+
wds.batched(64)
|
| 125 |
+
)
|
| 126 |
+
ds.speakers = speaker_map
|
| 127 |
+
ds.total_samples = length
|
| 128 |
+
ds.stoks_len = 750
|
| 129 |
+
ds.stoks_codes = vq_codes
|
| 130 |
+
ds.ttoks_len = 550
|
| 131 |
+
return ds.compose(wds.slice(length // 64)).with_epoch(length // 64).with_length(length // 64)
|
| 132 |
+
|
| 133 |
+
return (
|
| 134 |
+
ds(shards[1:], samples),
|
| 135 |
+
ds(shards[:1], val_samples),
|
| 136 |
+
)
|
| 137 |
+
|
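# Example invocation (a sketch; the directory is a placeholder and must contain
# `*-t2s-*.tar.gz` webdataset shards with `txt` and `stoks.npy` entries; note that the
# first shard is held back for validation):
#
#   train_ds, val_ds = load_datasets('path/to/t2s-shards/', samples=1_000_000, vq_codes=4096)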
| 138 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 14
|
| 139 |
+
def rand(start, end):
|
| 140 |
+
return random.random() * (end - start) + start
|
| 141 |
+
|
| 142 |
+
@dataclasses.dataclass
|
| 143 |
+
class Tunables:
|
| 144 |
+
init_std :float = 1
|
| 145 |
+
embeddings_std :float = .01
|
| 146 |
+
embeddings_lr_scale: float = 5
|
| 147 |
+
embedding_projector_lr_scale: float = 2.5
|
| 148 |
+
output_mult :float = .35
|
| 149 |
+
query_mult :float = 1
|
| 150 |
+
encoder_depth_ratio :float = 0.25
|
| 151 |
+
eot_dropout_p :float = .5
|
| 152 |
+
cps_input: bool = True
|
| 153 |
+
cps_bins: int = 32
|
| 154 |
+
|
| 155 |
+
lr0 :float = 1.5e-3
|
| 156 |
+
clip_gradient_norm :float = .2
|
| 157 |
+
weight_decay :float = 1e-1
|
| 158 |
+
warmup_steps :float = 4000
|
| 159 |
+
|
| 160 |
+
random :bool = False
|
| 161 |
+
|
| 162 |
+
def __post_init__(self):
|
| 163 |
+
# randomize the hyperparams if requested
|
| 164 |
+
if self.random:
|
| 165 |
+
self.init_std = 10**rand(-1,1)
|
| 166 |
+
self.embeddings_std = 10**rand(-3,-.7)
|
| 167 |
+
self.embeddings_lr_scale = rand(2,6)
|
| 168 |
+
self.output_mult = rand(0.25,0.65)
|
| 169 |
+
self.query_mult = 2**rand(-2,3)
|
| 170 |
+
self.encoder_depth_ratio = 0.25
|
| 171 |
+
|
| 172 |
+
self.lr0 = rand(1,5)*1e-3
|
| 173 |
+
self.clip_gradient_norm = 10**rand(-3,0)
|
| 174 |
+
self.warmup_steps = 100*(10**rand(1,1.85))
|
| 175 |
+
|
| 176 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 15
|
| 177 |
+
class EmbeddingProjector(nn.Linear):
|
| 178 |
+
pass
|
| 179 |
+
|
| 180 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 16
|
| 181 |
+
class Encoder(nn.Module):
|
| 182 |
+
def __init__(self, depth=6, width=384, n_head=6, length=1500, codes=1024, emb_width=384, ffn_mult=4, pos_embs=None, tunables=Tunables()):
|
| 183 |
+
super().__init__()
|
| 184 |
+
self.emb_width = emb_width
|
| 185 |
+
|
| 186 |
+
self.emb_factor = width != emb_width
|
| 187 |
+
|
| 188 |
+
self.embedding = nn.Embedding(codes, emb_width)
|
| 189 |
+
if self.emb_factor:
|
| 190 |
+
self.emb_to_hidden = EmbeddingProjector(emb_width, width)
|
| 191 |
+
|
| 192 |
+
if pos_embs is None: pos_embs = sinusoids(length, width)
|
| 193 |
+
self.register_buffer("positional_embedding", pos_embs)
|
| 194 |
+
|
| 195 |
+
self.layers = nn.Sequential(*[
|
| 196 |
+
ResidualAttentionBlock(width, n_head,
|
| 197 |
+
qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth)
|
| 198 |
+
])
|
| 199 |
+
|
| 200 |
+
self.ln_post = LayerNorm(width)
|
| 201 |
+
|
| 202 |
+
def forward(self, Stoks):
|
| 203 |
+
xin = self.embedding(Stoks)
|
| 204 |
+
if self.emb_factor:
|
| 205 |
+
xin = self.emb_to_hidden(xin)
|
| 206 |
+
|
| 207 |
+
assert xin.shape[1:] == self.positional_embedding.shape, "incorrect semantic token shape"
|
| 208 |
+
xin = (xin + self.positional_embedding).to(xin.dtype)
|
| 209 |
+
|
| 210 |
+
return self.ln_post(self.layers(xin))
|
| 211 |
+
|
| 212 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 17
|
| 213 |
+
class Decoder(nn.Module):
|
| 214 |
+
def __init__(self, depth=6, stoks_width=384, width=384, n_head=6, length=1500, codes=1024, ffn_mult=4, pos_embs=None, tunables=Tunables()):
|
| 215 |
+
super().__init__()
|
| 216 |
+
self.length = length
|
| 217 |
+
self.codes = codes
|
| 218 |
+
self.width = width
|
| 219 |
+
self.stoks_width = stoks_width
|
| 220 |
+
|
| 221 |
+
self.emb_factor = width != stoks_width
|
| 222 |
+
|
| 223 |
+
# embed semantic tokens
|
| 224 |
+
self.embedding = nn.Embedding(codes, stoks_width)
|
| 225 |
+
if self.emb_factor:
|
| 226 |
+
self.emb_to_hidden = EmbeddingProjector(stoks_width, width)
|
| 227 |
+
self.hidden_to_emb = EmbeddingProjector(width, stoks_width)
|
| 228 |
+
|
| 229 |
+
if pos_embs is None: pos_embs = sinusoids(length, width)
|
| 230 |
+
self.register_buffer("positional_embedding", pos_embs)
|
| 231 |
+
|
| 232 |
+
self.layers = nn.ModuleList([
|
| 233 |
+
ResidualAttentionBlock(width, n_head, cross_attention=True,
|
| 234 |
+
qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth)
|
| 235 |
+
])
|
| 236 |
+
self.ln_post = LayerNorm(width)
|
| 237 |
+
|
| 238 |
+
def forward(self, Stoks, xenc, cps=None):
|
| 239 |
+
Sembs = self.embedding(Stoks)
|
| 240 |
+
|
| 241 |
+
if self.emb_factor:
|
| 242 |
+
Sembs = self.emb_to_hidden(Sembs)
|
| 243 |
+
|
| 244 |
+
xin = (Sembs + self.positional_embedding[:Sembs.shape[1]]).to(xenc.dtype)
|
| 245 |
+
if cps is not None: xin = xin + cps
|
| 246 |
+
|
| 247 |
+
x = xin
|
| 248 |
+
for l in self.layers: x = l(x, xenc, causal=True)
|
| 249 |
+
|
| 250 |
+
x = self.ln_post(x)
|
| 251 |
+
|
| 252 |
+
if self.emb_factor:
|
| 253 |
+
x = self.hidden_to_emb(x)
|
| 254 |
+
|
| 255 |
+
logits = (x @ self.embedding.weight.to(x.dtype).T).float()
|
| 256 |
+
return logits
|
| 257 |
+
|
| 258 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 18
|
| 259 |
+
class TSARTransformer(nn.Module):
|
| 260 |
+
def __init__(self, depth=6, n_head=6, head_width=64, ffn_mult=4, language='en',
|
| 261 |
+
ttoks_len=200, ttoks_codes=50364, ttoks_width=None,
|
| 262 |
+
stoks_len=1500, stoks_codes=1024, stoks_width=None,
|
| 263 |
+
tunables=Tunables()):
|
| 264 |
+
assert language == 'en', "only english is supported right now"
|
| 265 |
+
super().__init__()
|
| 266 |
+
store_attr("depth,n_head,head_width,ffn_mult,stoks_width,ttoks_width,ttoks_len,stoks_len,ttoks_codes,stoks_codes,language")
|
| 267 |
+
|
| 268 |
+
width = n_head * head_width
|
| 269 |
+
self.width = width
|
| 270 |
+
self.base_width = 3 * head_width
|
| 271 |
+
self.tunables = tunables
|
| 272 |
+
if self.stoks_width is None: self.stoks_width = self.width
|
| 273 |
+
if self.ttoks_width is None: self.ttoks_width = self.width
|
| 274 |
+
|
| 275 |
+
if tunables.cps_input:
|
| 276 |
+
self.cps_embeddings = nn.Embedding(tunables.cps_bins, self.width)
|
| 277 |
+
else:
|
| 278 |
+
self.cps_embeddings = None
|
| 279 |
+
|
| 280 |
+
encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
|
| 281 |
+
decoder_depth = depth * 2 - encoder_depth
|
| 282 |
+
tformer_args = dict(width=width, n_head=n_head, ffn_mult=ffn_mult, tunables=tunables)
|
| 283 |
+
self.encoder = Encoder(length=ttoks_len, codes=ttoks_codes, emb_width=self.ttoks_width, depth=encoder_depth, **tformer_args)
|
| 284 |
+
self.decoder = Decoder(length=stoks_len, codes=stoks_codes, stoks_width=self.stoks_width, depth=decoder_depth, **tformer_args)
|
| 285 |
+
|
| 286 |
+
self.tokenizer = None
|
| 287 |
+
|
| 288 |
+
self.apply(self.init_transformer)
|
| 289 |
+
|
| 290 |
+
def load_frozen_semantic_embeddings(self, vqmodel):
|
| 291 |
+
with torch.no_grad():
|
| 292 |
+
self.decoder.embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
|
| 293 |
+
self.decoder.embedding.lr_scale = 0
|
| 294 |
+
|
| 295 |
+
def setup(self, device):
|
| 296 |
+
pass
|
| 297 |
+
|
| 298 |
+
def init_transformer(self, m):
|
| 299 |
+
if isinstance(m, LinearHead):
|
| 300 |
+
m.no_weight_decay = True
|
| 301 |
+
torch.nn.init.constant_(m.weight, 0)
|
| 302 |
+
elif isinstance(m, QueryHead):
|
| 303 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
| 304 |
+
torch.nn.init.constant_(m.weight, 0)
|
| 305 |
+
elif isinstance(m, nn.Embedding):
|
| 306 |
+
m.no_weight_decay = True
|
| 307 |
+
m.lr_scale = self.tunables.embeddings_lr_scale
|
| 308 |
+
std = self.tunables.embeddings_std
|
| 309 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 310 |
+
elif isinstance(m, EmbeddingProjector):
|
| 311 |
+
m.lr_scale = self.tunables.embedding_projector_lr_scale
|
| 312 |
+
std = self.tunables.init_std
|
| 313 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 314 |
+
elif isinstance(m, nn.Linear):
|
| 315 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
| 316 |
+
std = self.tunables.init_std / m.weight.shape[1]
|
| 317 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 318 |
+
if m.bias is not None:
|
| 319 |
+
torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
|
| 320 |
+
elif isinstance(m, nn.LayerNorm):
|
| 321 |
+
m.no_weight_decay = True
|
| 322 |
+
torch.nn.init.constant_(m.bias, 0)
|
| 323 |
+
torch.nn.init.constant_(m.weight, 1)
|
| 324 |
+
|
| 325 |
+
def forward(self, Ttoks, speakers, cpss, in_stoks, out_stoks=None, loss=True):
|
| 326 |
+
with record_function("encoder"):
|
| 327 |
+
xenc = self.encoder(Ttoks.to(torch.long))
|
| 328 |
+
with record_function("decoder"):
|
| 329 |
+
if self.cps_embeddings:
|
| 330 |
+
cps_bin = (cpss / 20 * self.tunables.cps_bins).to(torch.long)
|
| 331 |
+
cps_bin[cps_bin >= self.tunables.cps_bins] = self.tunables.cps_bins-1
|
| 332 |
+
cps_embs = self.cps_embeddings(cps_bin).unsqueeze(1)
|
| 333 |
+
else:
|
| 334 |
+
cps_embs = None
|
| 335 |
+
logits = self.decoder(in_stoks, xenc, cps=cps_embs) * self.tunables.output_mult / (self.width / self.base_width)
|
| 336 |
+
if loss is not None:
|
| 337 |
+
with record_function("loss"):
|
| 338 |
+
loss = F.cross_entropy(logits.transpose(-1,-2), out_stoks)#, reduction='none')
|
| 339 |
+
return logits, loss
|
| 340 |
+
|
| 341 |
+
#
|
| 342 |
+
# inference
|
| 343 |
+
#
|
| 344 |
+
@classmethod
|
| 345 |
+
def load_model(cls, repo_id="collabora/whisperspeech", filename="t2s_up_wds.model", local_filename=None):
|
| 346 |
+
if not local_filename:
|
| 347 |
+
local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 348 |
+
spec = torch.load(local_filename)
|
| 349 |
+
model = cls(**spec['config'], tunables=Tunables(**spec['tunables']))
|
| 350 |
+
model.load_state_dict(spec['state_dict'])
|
| 351 |
+
model.eval()
|
| 352 |
+
return model
|
| 353 |
+
|
| 354 |
+
def load_checkpoint(self, local_filename):
|
| 355 |
+
spec = torch.load(local_filename, map_location='cpu')
|
| 356 |
+
assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
|
| 357 |
+
state_dict = {k.replace('model.', ''):v
|
| 358 |
+
for k,v in spec['state_dict'].items()}
|
| 359 |
+
self.load_state_dict(state_dict)
|
| 360 |
+
return self
|
| 361 |
+
|
| 362 |
+
def save_model(self, fname):
|
| 363 |
+
torch.save(dict(config = self.__stored_args__,
|
| 364 |
+
tunables = dataclasses.asdict(self.tunables),
|
| 365 |
+
state_dict = self.state_dict()), fname)
|
| 366 |
+
|
| 367 |
+
def ensure_tokenizer(self):
|
| 368 |
+
assert not self.training
|
| 369 |
+
if self.tokenizer is None: self.tokenizer = CharTokenizer()
|
| 370 |
+
#whisper.tokenizer.get_tokenizer(multilingual=True)
|
| 371 |
+
|
| 372 |
+
@property
|
| 373 |
+
def device(self):
|
| 374 |
+
return next(self.parameters()).device
|
| 375 |
+
|
| 376 |
+
@torch.no_grad()
|
| 377 |
+
def generate(self, txt, cps=15, N=None, T=0.7, top_k=None, show_progress_bar=True):
|
| 378 |
+
self.ensure_tokenizer()
|
| 379 |
+
N = N or self.stoks_len
|
| 380 |
+
dev = self.device
|
| 381 |
+
ttoks = torch.tensor(self.tokenizer.encode(txt), device=dev)
|
| 382 |
+
ttoks = F.pad(ttoks, (0, self.ttoks_len - len(ttoks)), value=self.tokenizer.eot).unsqueeze(0)
|
| 383 |
+
cpss = torch.tensor([cps], device=dev)
|
| 384 |
+
toks = torch.zeros((1,N), dtype=torch.long, device=dev)
|
| 385 |
+
toks[0,0] = self.stoks_codes-1
|
| 386 |
+
it = range(1,N)
|
| 387 |
+
if show_progress_bar: it = progress_bar(it)
|
| 388 |
+
for i in it:
|
| 389 |
+
p, _ = self(ttoks, None, cpss, toks[:,:i], loss=None)
|
| 390 |
+
last_p = p[0,-1]
|
| 391 |
+
if top_k:
|
| 392 |
+
last_p[last_p < torch.topk(last_p, top_k).values[-1,None]] = -torch.inf
|
| 393 |
+
tok = torch.multinomial((last_p / float(T)).softmax(-1), 1)
|
| 394 |
+
toks[0,i] = tok
|
| 395 |
+
if toks[0,i] == self.stoks_codes-1: return toks[0,1:i]
|
| 396 |
+
return toks[0,1:]
|
| 397 |
+
|
| 398 |
+
@torch.no_grad()
|
| 399 |
+
def generate_batch(self, txts, N=None, T=1.1, top_k=7, show_progress_bar=True):
|
| 400 |
+
self.ensure_tokenizer()
|
| 401 |
+
N = self.stoks_len
|
| 402 |
+
dev = self.device
|
| 403 |
+
ttoks = []
|
| 404 |
+
for txt in txts:
|
| 405 |
+
ttoks_ = torch.tensor(self.tokenizer.encode(txt), device=dev)
|
| 406 |
+
ttoks_ = F.pad(ttoks_, (0, self.ttoks_len - len(ttoks_)), value=self.tokenizer.eot).unsqueeze(0)
|
| 407 |
+
ttoks.append(ttoks_)
|
| 408 |
+
ttoks = torch.cat(ttoks, dim=0)
|
| 409 |
+
toks = torch.zeros((len(ttoks),N), dtype=torch.long, device=dev)
|
| 410 |
+
it = range(N)
|
| 411 |
+
if show_progress_bar: it = progress_bar(it)
|
| 412 |
+
for i in it:
|
| 413 |
+
p, _ = self(ttoks, toks[:,:i], loss=None)
|
| 414 |
+
last_p = p[:,-1]
|
| 415 |
+
if top_k:
|
| 416 |
+
last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
|
| 417 |
+
tok = torch.multinomial((last_p / float(T)).softmax(-1), 1)
|
| 418 |
+
toks[:,i] = tok[:,0]
|
| 419 |
+
if (toks[:,i] == self.stoks_codes-1).all(): return toks[:,:i]
|
| 420 |
+
return toks
|
| 421 |
+
|
| 422 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 19
|
| 423 |
+
def _make_model(size:str, tunables:Tunables=Tunables(), dataset=None, **kwargs):
|
| 424 |
+
kwargs = dict(stoks_len = dataset.stoks_len, ttoks_len = dataset.ttoks_len, tunables=tunables, **kwargs)
|
| 425 |
+
if 'stoks_codes' not in kwargs: kwargs['stoks_codes'] = dataset.stoks_codes
|
| 426 |
+
if size == 'micro':
|
| 427 |
+
return TSARTransformer(depth=2, n_head=3, ffn_mult=1, **kwargs)
|
| 428 |
+
if size == 'tiny':
|
| 429 |
+
return TSARTransformer(depth=4, n_head=6, **kwargs)
|
| 430 |
+
if size == 'base':
|
| 431 |
+
return TSARTransformer(depth=6, n_head=8, **kwargs)
|
| 432 |
+
if size == 'small':
|
| 433 |
+
return TSARTransformer(depth=12, n_head=16, **kwargs)
|
| 434 |
+
|
| 435 |
+
def make_model(size:str, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
|
| 436 |
+
if frozen_embeddings_model:
|
| 437 |
+
vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
|
| 438 |
+
model = _make_model(size, tunables, dataset, stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
|
| 439 |
+
model.load_frozen_semantic_embeddings(vqmodel)
|
| 440 |
+
else:
|
| 441 |
+
model = _make_model(size, tunables, dataset)
|
| 442 |
+
return model
|
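For context, a minimal usage sketch of the TSARTransformer defined above. It is not part of the committed file; it assumes the default collabora/whisperspeech checkpoint named in load_model() is downloadable and that a GPU is optional:
# sketch only: load the English text-to-semantic model and sample semantic tokens
import torch
from whisperspeech.t2s_up_wds import TSARTransformer

t2s = TSARTransformer.load_model()  # default repo_id/filename from load_model() above
t2s = t2s.to('cuda' if torch.cuda.is_available() else 'cpu')
stoks = t2s.generate("Hello from WhisperSpeech!", cps=15, T=0.7)
print(stoks.shape)  # 1-D tensor of semantic token ids (at most stoks_len entries)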
whisperspeech/t2s_up_wds_mlang_enclm.py
ADDED
|
@@ -0,0 +1,519 @@
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/5B. Multi-lang text to semantic token modeling.ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = ['load_dataset', 'rand', 'Tunables', 'T2SEmbedding', 'Encoder', 'TSARTransformer', 'make_model']
|
| 5 |
+
|
| 6 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 1
|
| 7 |
+
import dataclasses
|
| 8 |
+
import random
|
| 9 |
+
import math
|
| 10 |
+
import itertools
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from torch.profiler import record_function
|
| 15 |
+
|
| 16 |
+
from huggingface_hub import hf_hub_download
|
| 17 |
+
from fastcore.basics import store_attr
|
| 18 |
+
from fastprogress import progress_bar
|
| 19 |
+
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 2
|
| 23 |
+
from whisperspeech.modules import *
|
| 24 |
+
from whisperspeech import languages
|
| 25 |
+
|
| 26 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 6
|
| 27 |
+
import re
|
| 28 |
+
|
| 29 |
+
class CharTokenizer:
|
| 30 |
+
"""Trivial tokenizer β just use UTF-8 bytes"""
|
| 31 |
+
eot = 0
|
| 32 |
+
|
| 33 |
+
def encode(self, txt):
|
| 34 |
+
return list(bytes(txt.strip(), 'utf-8'))
|
| 35 |
+
|
| 36 |
+
def decode(self, tokens):
|
| 37 |
+
return bytes(tokens).decode('utf-8')
|
| 38 |
+
|
| 39 |
+
def tokenizer(ikey, okey, length):
|
| 40 |
+
"""Tokenizes a transcript"""
|
| 41 |
+
tok = CharTokenizer()
|
| 42 |
+
def _tokenizer(samples):
|
| 43 |
+
for s in samples:
|
| 44 |
+
toks = torch.tensor(tok.encode(s[ikey]))
|
| 45 |
+
s[okey] = F.pad(toks, (0, length - toks.shape[-1]), value=tok.eot)
|
| 46 |
+
yield s
|
| 47 |
+
return _tokenizer
|
| 48 |
+
|
| 49 |
+
def ar_padder(ikey, okey, length, pad_token):
|
| 50 |
+
"""Pads the tokens for autoregresive training"""
|
| 51 |
+
import numpy as np
|
| 52 |
+
|
| 53 |
+
def _ar_padder(samples):
|
| 54 |
+
for s in samples:
|
| 55 |
+
toks = s[ikey]
|
| 56 |
+
if isinstance(toks, (list, np.ndarray)): toks = torch.tensor(toks)
|
| 57 |
+
toks = toks.to(torch.long)
|
| 58 |
+
s['in_' +okey] = F.pad(toks, (1, length - toks.shape[-1] - 1), value=pad_token)
|
| 59 |
+
s['out_'+okey] = F.pad(toks, (0, length - toks.shape[-1]), value=pad_token)
|
| 60 |
+
yield s
|
| 61 |
+
return _ar_padder
|
| 62 |
+
|
| 63 |
+
def char_per_seconder(txt_key, stoks_key, cps_key, stoks_per_second=25):
|
| 64 |
+
"""Adds the characters per second metric to the input data"""
|
| 65 |
+
def _char_per_seconder(samples):
|
| 66 |
+
for s in samples:
|
| 67 |
+
secs = s[stoks_key].shape[-1] / stoks_per_second
|
| 68 |
+
s[cps_key] = len(s[txt_key]) / secs
|
| 69 |
+
yield s
|
| 70 |
+
return _char_per_seconder
|
| 71 |
+
|
| 72 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 7
|
| 73 |
+
def load_dataset(
|
| 74 |
+
txt_shard_spec:str, # transcription webdataset shards
|
| 75 |
+
stoks_shard_dir:str, # stoks webdataset base dir
|
| 76 |
+
samples:int, # samples per epoch
|
| 77 |
+
txt_kind:str='small.en-txt',
|
| 78 |
+
vq_codes:int=4096,
|
| 79 |
+
language:str='en',
|
| 80 |
+
weight:float=1,
|
| 81 |
+
validation:bool=False,
|
| 82 |
+
exclude_files:str=None,
|
| 83 |
+
):
|
| 84 |
+
import webdataset as wds
|
| 85 |
+
from whisperspeech import utils
|
| 86 |
+
|
| 87 |
+
shards = utils.shard_glob(txt_shard_spec)
|
| 88 |
+
excludes = {x for file in exclude_files.split() for x in utils.readlines(file)} if exclude_files else set()
|
| 89 |
+
|
| 90 |
+
language = languages.to_id(language)
|
| 91 |
+
|
| 92 |
+
def set_language(x):
|
| 93 |
+
x['language'] = language
|
| 94 |
+
return x
|
| 95 |
+
|
| 96 |
+
same_on_all_nodes = lambda urls: urls # will only be used for validation
|
| 97 |
+
ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
|
| 98 |
+
wds.decode(),
|
| 99 |
+
utils.merge_in(utils.derived_dataset('eqvad-stoks', base=txt_kind, suffix='', dir=stoks_shard_dir)),
|
| 100 |
+
# discard validation samples, select samples > .5s
|
| 101 |
+
wds.select(lambda s: s['__key__'] not in excludes and s['stoks.npy'].shape[-1] > 12),
|
| 102 |
+
tokenizer('txt', 'ttoks', length=550),
|
| 103 |
+
ar_padder('stoks.npy', 'stoks', length=750, pad_token=vq_codes-1),
|
| 104 |
+
ar_padder('ttoks', 'ttoks', length=550, pad_token=CharTokenizer.eot),
|
| 105 |
+
char_per_seconder('txt', 'stoks.npy', 'cps', stoks_per_second=25),
|
| 106 |
+
wds.map(set_language),
|
| 107 |
+
wds.to_tuple('in_ttoks', 'out_ttoks', 'language', 'cps', 'in_stoks', 'out_stoks'),
|
| 108 |
+
wds.shuffle(20000, initial=20000),
|
| 109 |
+
wds.batched(64)
|
| 110 |
+
)
|
| 111 |
+
if validation:
|
| 112 |
+
ds = ds.slice(samples // 64)
|
| 113 |
+
ds.total_samples = samples
|
| 114 |
+
ds.stoks_len = 750
|
| 115 |
+
ds.stoks_codes = vq_codes
|
| 116 |
+
ds.ttoks_len = 550
|
| 117 |
+
ds.weight = weight
|
| 118 |
+
|
| 119 |
+
return ds
|
| 120 |
+
|
| 121 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 14
|
| 122 |
+
def rand(start, end):
|
| 123 |
+
return random.random() * (end - start) + start
|
| 124 |
+
|
| 125 |
+
@dataclasses.dataclass
|
| 126 |
+
class Tunables:
|
| 127 |
+
init_std :float = 1
|
| 128 |
+
embeddings_std :float = .01
|
| 129 |
+
embeddings_lr_scale: float = 5
|
| 130 |
+
embedding_projector_lr_scale: float = 2.5
|
| 131 |
+
output_mult :float = .35
|
| 132 |
+
query_mult :float = 1
|
| 133 |
+
encoder_depth_ratio :float = 0.25
|
| 134 |
+
eot_dropout_p :float = .5
|
| 135 |
+
cps_input: bool = True
|
| 136 |
+
cps_bins: int = 32
|
| 137 |
+
|
| 138 |
+
lr0 :float = 1.5e-3
|
| 139 |
+
clip_gradient_norm :float = .2
|
| 140 |
+
weight_decay :float = 1e-1
|
| 141 |
+
warmup_steps :float = 4000
|
| 142 |
+
|
| 143 |
+
random :bool = False
|
| 144 |
+
|
| 145 |
+
def __post_init__(self):
|
| 146 |
+
# randomize the hyperparams if requested
|
| 147 |
+
if self.random:
|
| 148 |
+
self.init_std = 10**rand(-1,1)
|
| 149 |
+
self.embeddings_std = 10**rand(-3,-.7)
|
| 150 |
+
self.embeddings_lr_scale = rand(2,6)
|
| 151 |
+
self.output_mult = rand(0.25,0.65)
|
| 152 |
+
self.query_mult = 2**rand(-2,3)
|
| 153 |
+
self.encoder_depth_ratio = 0.25
|
| 154 |
+
|
| 155 |
+
self.lr0 = rand(1,5)*1e-3
|
| 156 |
+
self.clip_gradient_norm = 10**rand(-3,0)
|
| 157 |
+
self.warmup_steps = 100*(10**rand(1,1.85))
|
| 158 |
+
|
| 159 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 15
|
| 160 |
+
class T2SEmbedding(nn.Module):
|
| 161 |
+
def __init__(self, length=1500, codes=1024, width=384, pos_embs=None, stoks_width=384):
|
| 162 |
+
super().__init__()
|
| 163 |
+
self.embedding = FlexEmbeddings(codes, width, special_codes=1, frozen_width=stoks_width)
|
| 164 |
+
if pos_embs is None: pos_embs = sinusoids(length, width)
|
| 165 |
+
self.register_buffer("positional_embedding", pos_embs)
|
| 166 |
+
|
| 167 |
+
def forward(self, Stoks, xenc, cps=None, offset=0):
|
| 168 |
+
Sembs = self.embedding(Stoks)
|
| 169 |
+
xin = (Sembs + self.positional_embedding[offset : offset + Sembs.shape[1]]).to(xenc.dtype)
|
| 170 |
+
if cps is not None: xin = xin + cps
|
| 171 |
+
return xin, offset
|
| 172 |
+
|
| 173 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 16
|
| 174 |
+
class Encoder(nn.Module):
|
| 175 |
+
def __init__(self, depth=6, width=384, n_head=6, length=1500, codes=1024, emb_width=384, ffn_mult=4, pos_embs=None, tunables=Tunables()):
|
| 176 |
+
super().__init__()
|
| 177 |
+
self.emb_width = emb_width
|
| 178 |
+
|
| 179 |
+
self.embedding = FlexEmbeddings(codes, width, frozen_width=emb_width)
|
| 180 |
+
|
| 181 |
+
if pos_embs is None: pos_embs = sinusoids(length, width)
|
| 182 |
+
self.register_buffer("positional_embedding", pos_embs)
|
| 183 |
+
|
| 184 |
+
self.layers = nn.ModuleList([
|
| 185 |
+
ResidualAttentionBlock(width, n_head,
|
| 186 |
+
qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth)
|
| 187 |
+
])
|
| 188 |
+
|
| 189 |
+
self.ln_post = LayerNorm(width)
|
| 190 |
+
|
| 191 |
+
mask = torch.empty(length, length).fill_(-torch.inf).triu_(1)
|
| 192 |
+
self.register_buffer("mask", mask, persistent=False)
|
| 193 |
+
|
| 194 |
+
def forward(self, Stoks, positions, lang_emb=None):
|
| 195 |
+
xin = self.embedding(Stoks)
|
| 196 |
+
|
| 197 |
+
if lang_emb is not None: xin += lang_emb
|
| 198 |
+
|
| 199 |
+
# assert xin.shape[1:] == self.positional_embedding.shape, "incorrect semantic token shape"
|
| 200 |
+
x = (xin +
|
| 201 |
+
self.positional_embedding[positions]).to(xin.dtype)
|
| 202 |
+
|
| 203 |
+
for l in self.layers: x = l(x, positions, causal=False, mask=self.mask)
|
| 204 |
+
|
| 205 |
+
return self.ln_post(x)
|
| 206 |
+
|
| 207 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 17
|
| 208 |
+
class TSARTransformer(nn.Module):
|
| 209 |
+
def __init__(self, depth=6, n_head=6, head_width=64, ffn_mult=4,
|
| 210 |
+
ttoks_len=200, ttoks_codes=256, ttoks_width=None,
|
| 211 |
+
stoks_len=1500, stoks_codes=1024, stoks_width=None,
|
| 212 |
+
tunables=Tunables()):
|
| 213 |
+
super().__init__()
|
| 214 |
+
store_attr("depth,n_head,head_width,ffn_mult,stoks_width,ttoks_width,ttoks_len,stoks_len,ttoks_codes,stoks_codes")
|
| 215 |
+
|
| 216 |
+
width = n_head * head_width
|
| 217 |
+
self.width = width
|
| 218 |
+
self.base_width = 3 * head_width
|
| 219 |
+
self.tunables = tunables
|
| 220 |
+
if self.stoks_width is None: self.stoks_width = self.width
|
| 221 |
+
if self.ttoks_width is None: self.ttoks_width = self.width
|
| 222 |
+
|
| 223 |
+
self.lang_embeddings = nn.Embedding(len(languages.languages), width)
|
| 224 |
+
if tunables.cps_input:
|
| 225 |
+
self.cps_embeddings = nn.Embedding(tunables.cps_bins, self.width)
|
| 226 |
+
else:
|
| 227 |
+
self.cps_embeddings = None
|
| 228 |
+
|
| 229 |
+
encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
|
| 230 |
+
decoder_depth = depth * 2 - encoder_depth
|
| 231 |
+
tformer_args = dict(width=width, n_head=n_head, ffn_mult=ffn_mult, tunables=tunables)
|
| 232 |
+
self.encoder = Encoder(length=ttoks_len, codes=ttoks_codes, emb_width=self.ttoks_width, depth=encoder_depth, **tformer_args)
|
| 233 |
+
self.embeddings = T2SEmbedding(length=stoks_len, codes=stoks_codes, width=width, stoks_width=self.stoks_width)
|
| 234 |
+
|
| 235 |
+
self.decoder = BaseDecoder(
|
| 236 |
+
length=stoks_len,
|
| 237 |
+
depth=decoder_depth,
|
| 238 |
+
qk_scale=tunables.query_mult*8/math.sqrt(width/n_head),
|
| 239 |
+
width=width, n_head=n_head, ffn_mult=ffn_mult,
|
| 240 |
+
)
|
| 241 |
+
self.tokenizer = None
|
| 242 |
+
|
| 243 |
+
self.apply(self.init_transformer)
|
| 244 |
+
|
| 245 |
+
def load_frozen_semantic_embeddings(self, vqmodel):
|
| 246 |
+
self.embeddings.embedding.set_frozen_embeddings(vqmodel.rq.layers[0]._codebook.embed[0])
|
| 247 |
+
|
| 248 |
+
def setup(self, device):
|
| 249 |
+
pass
|
| 250 |
+
|
| 251 |
+
def init_transformer(self, m):
|
| 252 |
+
if isinstance(m, LinearHead):
|
| 253 |
+
m.no_weight_decay = True
|
| 254 |
+
torch.nn.init.constant_(m.weight, 0)
|
| 255 |
+
elif isinstance(m, QueryHead):
|
| 256 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
| 257 |
+
torch.nn.init.constant_(m.weight, 0)
|
| 258 |
+
elif isinstance(m, nn.Embedding):
|
| 259 |
+
m.no_weight_decay = True
|
| 260 |
+
m.lr_scale = self.tunables.embeddings_lr_scale
|
| 261 |
+
std = self.tunables.embeddings_std
|
| 262 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 263 |
+
elif isinstance(m, EmbeddingProjector):
|
| 264 |
+
m.lr_scale = self.tunables.embedding_projector_lr_scale
|
| 265 |
+
std = self.tunables.init_std
|
| 266 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 267 |
+
elif isinstance(m, nn.Linear):
|
| 268 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
| 269 |
+
std = self.tunables.init_std / m.weight.shape[1]
|
| 270 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 271 |
+
if m.bias is not None:
|
| 272 |
+
torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
|
| 273 |
+
elif isinstance(m, nn.LayerNorm):
|
| 274 |
+
m.no_weight_decay = True
|
| 275 |
+
torch.nn.init.constant_(m.bias, 0)
|
| 276 |
+
torch.nn.init.constant_(m.weight, 1)
|
| 277 |
+
|
| 278 |
+
def _embed_cps(self, cpss):
|
| 279 |
+
if self.cps_embeddings is None: return None
|
| 280 |
+
|
| 281 |
+
cps_bin = (cpss / 20 * self.tunables.cps_bins).to(torch.long)
|
| 282 |
+
cps_bin[cps_bin >= self.tunables.cps_bins] = self.tunables.cps_bins-1
|
| 283 |
+
return self.cps_embeddings(cps_bin).unsqueeze(1)
|
| 284 |
+
|
| 285 |
+
def run_encoder(self, in_ttoks, languages, cpss):
|
| 286 |
+
if len(languages.shape) != 3: lang_embs = self.lang_embeddings(languages)
|
| 287 |
+
else: lang_embs = languages
|
| 288 |
+
if len(lang_embs.shape) == 2: lang_embs = lang_embs.unsqueeze(1)
|
| 289 |
+
|
| 290 |
+
cps_emb = self._embed_cps(cpss)
|
| 291 |
+
|
| 292 |
+
with record_function("encoder"):
|
| 293 |
+
positions = torch.arange(0, in_ttoks.shape[1], device=in_ttoks.device)
|
| 294 |
+
xenc = self.encoder(in_ttoks.to(torch.long), positions, lang_emb=lang_embs)
|
| 295 |
+
|
| 296 |
+
return xenc, positions, cps_emb
|
| 297 |
+
|
| 298 |
+
def forward(self, in_ttoks, out_ttoks, languages, cpss, in_stoks, in_stoks_positions, out_stoks=None, loss=True, offset=None, xenc=None, xenc_positions=None, cps_emb=None):
|
| 299 |
+
if xenc is None:
|
| 300 |
+
xenc, xenc_positions, cps_emb = self.run_encoder(in_ttoks, languages, cpss)
|
| 301 |
+
|
| 302 |
+
with record_function("decoder"):
|
| 303 |
+
x = (self.embeddings.embedding(in_stoks) +
|
| 304 |
+
self.embeddings.positional_embedding[in_stoks_positions] +
|
| 305 |
+
cps_emb).to(xenc[0].dtype)
|
| 306 |
+
x = self.decoder(x, in_stoks_positions, xenc, xenc_positions)
|
| 307 |
+
logits = self.embeddings.embedding.unembed(x)
|
| 308 |
+
logits = logits * self.tunables.output_mult / (self.width / self.base_width)
|
| 309 |
+
|
| 310 |
+
if loss is not None:
|
| 311 |
+
enc_logits = self.encoder.embedding.unembed(xenc[0])
|
| 312 |
+
enc_logits = enc_logits * self.tunables.output_mult / (self.width / self.base_width)
|
| 313 |
+
with record_function("loss"):
|
| 314 |
+
loss = F.cross_entropy(logits.transpose(-1,-2), out_stoks)
|
| 315 |
+
if self.training:
|
| 316 |
+
loss += 0.1 * F.cross_entropy(enc_logits.transpose(-1,-2), out_ttoks)
|
| 317 |
+
|
| 318 |
+
return logits, loss
|
| 319 |
+
|
| 320 |
+
#
|
| 321 |
+
# inference
|
| 322 |
+
#
|
| 323 |
+
@classmethod
|
| 324 |
+
def load_model(cls, ref="collabora/whisperspeech:t2s-small-en+pl.model",
|
| 325 |
+
repo_id=None, filename=None, local_filename=None):
|
| 326 |
+
if repo_id is None and filename is None and local_filename is None:
|
| 327 |
+
if ":" in ref:
|
| 328 |
+
repo_id, filename = ref.split(":", 1)
|
| 329 |
+
else:
|
| 330 |
+
local_filename = ref
|
| 331 |
+
if not local_filename:
|
| 332 |
+
local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 333 |
+
spec = torch.load(local_filename)
|
| 334 |
+
model = cls(**spec['config'], tunables=Tunables(**spec['tunables']))
|
| 335 |
+
model.load_state_dict(spec['state_dict'])
|
| 336 |
+
model.eval()
|
| 337 |
+
return model
|
| 338 |
+
|
| 339 |
+
def load_checkpoint(self, local_filename):
|
| 340 |
+
spec = torch.load(local_filename, map_location='cpu')
|
| 341 |
+
assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
|
| 342 |
+
state_dict = {k.replace('model.', ''):v
|
| 343 |
+
for k,v in spec['state_dict'].items()}
|
| 344 |
+
self.load_state_dict(state_dict)
|
| 345 |
+
return self
|
| 346 |
+
|
| 347 |
+
def save_model(self, fname):
|
| 348 |
+
torch.save(dict(config = self.__stored_args__,
|
| 349 |
+
tunables = dataclasses.asdict(self.tunables),
|
| 350 |
+
state_dict = self.state_dict()), fname)
|
| 351 |
+
|
| 352 |
+
def ensure_tokenizer(self):
|
| 353 |
+
assert not self.training
|
| 354 |
+
if self.tokenizer is None: self.tokenizer = CharTokenizer()
|
| 355 |
+
|
| 356 |
+
def switch_dtypes(self, dtype=torch.float16):
|
| 357 |
+
self.dtype = dtype
|
| 358 |
+
for n,m in self.named_modules():
|
| 359 |
+
# convert every leaf layer apart from the LayerNorms
|
| 360 |
+
if isinstance(m, (nn.Linear, nn.Embedding)):
|
| 361 |
+
m.to(dtype)
|
| 362 |
+
# take care of buffers ([kv]_cache, masks) that are not in the leaf layers
|
| 363 |
+
for bn,b in m.named_buffers(recurse=False):
|
| 364 |
+
setattr(m,bn,b.to(dtype))
|
| 365 |
+
|
| 366 |
+
def optimize(self, max_batch_size=1, dtype=torch.float16, torch_compile=True):
|
| 367 |
+
for emb in [self.embeddings.embedding, self.embeddings.embedding]:
|
| 368 |
+
emb.convert_for_eval()
|
| 369 |
+
for l in self.encoder.layers:
|
| 370 |
+
l.attn.convert_for_eval()
|
| 371 |
+
for l in self.decoder.layers:
|
| 372 |
+
l.attn.convert_for_eval()
|
| 373 |
+
l.cross_attn.convert_for_eval()
|
| 374 |
+
l.setup_kv_cache(max_batch_size, self.stoks_len, self.ttoks_len)
|
| 375 |
+
self.switch_dtypes(dtype)
|
| 376 |
+
if torch_compile:
|
| 377 |
+
self.generate_next = torch.compile(self.generate_next, mode="reduce-overhead", fullgraph=True)
|
| 378 |
+
|
| 379 |
+
@property
|
| 380 |
+
def device(self):
|
| 381 |
+
return next(self.parameters()).device
|
| 382 |
+
|
| 383 |
+
# from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
|
| 384 |
+
def multinomial_sample_one_no_sync(self, probs_sort): # Does multinomial sampling without a cuda synchronization
|
| 385 |
+
q = torch.empty_like(probs_sort).exponential_(1)
|
| 386 |
+
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
| 387 |
+
|
| 388 |
+
def logits_to_probs(self, logits, T=1.0, top_k=None):
|
| 389 |
+
logits = logits / max(T, 1e-5)
|
| 390 |
+
|
| 391 |
+
logits[self.embeddings.embedding.codes:] = -torch.inf
|
| 392 |
+
if top_k is not None:
|
| 393 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 394 |
+
pivot = v.select(-1, -1).unsqueeze(-1)
|
| 395 |
+
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
| 396 |
+
|
| 397 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
| 398 |
+
return probs
|
| 399 |
+
|
| 400 |
+
def sample(self, logits, T=1.0, top_k=None):
|
| 401 |
+
probs = self.logits_to_probs(logits[0,-1], T, top_k)
|
| 402 |
+
idx_next = self.multinomial_sample_one_no_sync(probs)
|
| 403 |
+
return idx_next
|
| 404 |
+
|
| 405 |
+
def generate_one(self, toks, toks_positions, cps_emb, xenc, xenc_positions, T, top_k):
|
| 406 |
+
probs, _ = self(None, None, None, None, toks, toks_positions, loss=None, xenc=xenc, xenc_positions=xenc_positions, cps_emb=cps_emb)
|
| 407 |
+
return self.sample(probs, T, top_k)
|
| 408 |
+
|
| 409 |
+
def generate_next(self, *args, **kwargs):
|
| 410 |
+
return self.generate_one(*args, **kwargs)
|
| 411 |
+
|
| 412 |
+
@torch.no_grad()
|
| 413 |
+
def prep(self, txt, cps=15, lang="en"):
|
| 414 |
+
dev = self.device
|
| 415 |
+
ttoks = torch.tensor(self.tokenizer.encode(txt), device=dev)
|
| 416 |
+
ttoks = F.pad(ttoks, (0, self.ttoks_len - len(ttoks)), value=self.tokenizer.eot).unsqueeze(0)
|
| 417 |
+
cpss = torch.tensor([cps], device=dev)
|
| 418 |
+
langs = torch.tensor([languages.to_id(lang)], device=dev)
|
| 419 |
+
return ttoks, cpss, langs
|
| 420 |
+
|
| 421 |
+
@torch.no_grad()
|
| 422 |
+
def generate(self, txt, cps=15, lang="en", N=None, T=0.7, top_k=None, step=None, show_progress_bar=True):
|
| 423 |
+
self.ensure_tokenizer()
|
| 424 |
+
N = N or self.stoks_len
|
| 425 |
+
dev = self.device
|
| 426 |
+
ttoks = []
|
| 427 |
+
langs = []
|
| 428 |
+
if isinstance(lang, list):
|
| 429 |
+
lang0 = lang[0]
|
| 430 |
+
assert isinstance(txt, list), "lang and txt have to be both lists or strings"
|
| 431 |
+
for txt, lang in zip(txt, lang):
|
| 432 |
+
tt = self.tokenizer.encode(txt)
|
| 433 |
+
ttoks += tt
|
| 434 |
+
langs += [languages.to_id(lang)] * len(tt)
|
| 435 |
+
elif isinstance(lang, torch.Tensor):
|
| 436 |
+
langs = lang
|
| 437 |
+
ttoks = self.tokenizer.encode(txt)
|
| 438 |
+
else:
|
| 439 |
+
lang0 = lang
|
| 440 |
+
ttoks = self.tokenizer.encode(txt)
|
| 441 |
+
langs = torch.tensor([languages.to_id(lang)], device=dev).unsqueeze(0)
|
| 442 |
+
ttoks = torch.tensor(ttoks, device=dev)
|
| 443 |
+
ttoks = F.pad(ttoks, (1, self.ttoks_len - len(ttoks) - 1), value=self.tokenizer.eot).unsqueeze(0)
|
| 444 |
+
cpss = torch.tensor([cps], device=dev)
|
| 445 |
+
if not isinstance(langs, torch.Tensor):
|
| 446 |
+
langs = torch.tensor(langs, device=dev)
|
| 447 |
+
langs = F.pad(langs, (1, self.ttoks_len - len(langs) - 1), value=languages.to_id(lang0)).unsqueeze(0)
|
| 448 |
+
it = range(0,N-1)
|
| 449 |
+
if show_progress_bar: it = progress_bar(it)
|
| 450 |
+
|
| 451 |
+
toks = torch.zeros((1,N), dtype=torch.long, device=dev)
|
| 452 |
+
toks[:,0] = self.stoks_codes-1
|
| 453 |
+
toks_positions = torch.arange(N, device=dev)
|
| 454 |
+
with record_function("encode"):
|
| 455 |
+
xenc, xenc_positions, cps_emb = self.run_encoder(ttoks, langs, cpss)
|
| 456 |
+
toks_positions = torch.arange(N+1, device=dev)
|
| 457 |
+
# contrary to S2A this model works without prefill and is actually a tiny bit faster
|
| 458 |
+
# with record_function("prefill"):
|
| 459 |
+
# toks[0,1] = self.generate_one(toks[:,:1], toks_positions[:1], cps_emb, xenc, xenc_positions, T, top_k)
|
| 460 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
| 461 |
+
for i in it:
|
| 462 |
+
toks[0,i+1] = self.generate_next(toks[:,i:i+1], toks_positions[i:i+1], cps_emb, xenc, xenc_positions, T, top_k)
|
| 463 |
+
if i % 25 == 0 and toks[0,i+1] == self.stoks_codes-1: return toks[0,:i+1]
|
| 464 |
+
|
| 465 |
+
# for profiling, debugging or early exit
|
| 466 |
+
if step is not None: step()
|
| 467 |
+
return toks[0,:]
|
| 468 |
+
|
| 469 |
+
@torch.no_grad()
|
| 470 |
+
def generate_batch(self, txts, N=None, T=1.1, top_k=7, show_progress_bar=True):
|
| 471 |
+
self.ensure_tokenizer()
|
| 472 |
+
N = self.stoks_len
|
| 473 |
+
dev = self.device
|
| 474 |
+
ttoks = []
|
| 475 |
+
for txt in txts:
|
| 476 |
+
ttoks_ = torch.tensor(self.tokenizer.encode(txt), device=dev)
|
| 477 |
+
ttoks_ = F.pad(ttoks_, (0, self.ttoks_len - len(ttoks_)), value=self.tokenizer.eot).unsqueeze(0)
|
| 478 |
+
ttoks.append(ttoks_)
|
| 479 |
+
ttoks = torch.cat(ttoks, dim=0)
|
| 480 |
+
toks = torch.zeros((len(ttoks),N), dtype=torch.long, device=dev)
|
| 481 |
+
it = range(N)
|
| 482 |
+
if show_progress_bar: it = progress_bar(it)
|
| 483 |
+
for i in it:
|
| 484 |
+
p, _ = self(ttoks, toks[:,:i], loss=None)
|
| 485 |
+
last_p = p[:,-1]
|
| 486 |
+
if top_k:
|
| 487 |
+
last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
|
| 488 |
+
tok = torch.multinomial((last_p / float(T)).softmax(-1), 1)
|
| 489 |
+
toks[:,i] = tok[:,0]
|
| 490 |
+
if (toks[:,i] == self.stoks_codes-1).all(): return toks[:,:i]
|
| 491 |
+
return toks
|
| 492 |
+
|
| 493 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 18
|
| 494 |
+
def _make_model(size:str, tunables:Tunables=Tunables(), dataset=None, **kwargs):
|
| 495 |
+
kwargs = dict(stoks_len = dataset.stoks_len, ttoks_len = dataset.ttoks_len, tunables=tunables, **kwargs)
|
| 496 |
+
if 'stoks_codes' not in kwargs: kwargs['stoks_codes'] = dataset.stoks_codes
|
| 497 |
+
if size == 'micro':
|
| 498 |
+
return TSARTransformer(depth=2, n_head=3, ffn_mult=1, **kwargs)
|
| 499 |
+
if size == 'tiny':
|
| 500 |
+
return TSARTransformer(depth=4, n_head=6, **kwargs)
|
| 501 |
+
if size == 'base':
|
| 502 |
+
return TSARTransformer(depth=6, n_head=8, **kwargs)
|
| 503 |
+
if size == 'small':
|
| 504 |
+
return TSARTransformer(depth=12, n_head=12, **kwargs)
|
| 505 |
+
if size == 'small+':
|
| 506 |
+
return TSARTransformer(depth=12, n_head=16, **kwargs)
|
| 507 |
+
if size == 'medium':
|
| 508 |
+
return TSARTransformer(depth=24, n_head=16, **kwargs)
|
| 509 |
+
|
| 510 |
+
def make_model(size:str, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
|
| 511 |
+
from whisperspeech import vq_stoks
|
| 512 |
+
|
| 513 |
+
if frozen_embeddings_model:
|
| 514 |
+
vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
|
| 515 |
+
model = _make_model(size, tunables, dataset, stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
|
| 516 |
+
model.load_frozen_semantic_embeddings(vqmodel)
|
| 517 |
+
else:
|
| 518 |
+
model = _make_model(size, tunables, dataset)
|
| 519 |
+
return model
|
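A similar sketch for the multilingual model above; the ref string is simply the default argument of load_model(), and a CUDA device plus the availability of that checkpoint are assumptions:
# sketch only: multilingual text-to-semantic generation with an explicit language id
import torch
from whisperspeech.t2s_up_wds_mlang_enclm import TSARTransformer

t2s = TSARTransformer.load_model("collabora/whisperspeech:t2s-small-en+pl.model").to('cuda')
t2s.optimize(max_batch_size=1, torch_compile=False)  # fp16 weights and kv-caches, as defined above
stoks = t2s.generate("Dzien dobry!", lang="pl", cps=14, T=0.7)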
whisperspeech/train.py
ADDED
|
@@ -0,0 +1,271 @@
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B1. Training.ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = ['SimpleVisual', 'validate', 'train']
|
| 5 |
+
|
| 6 |
+
# %% ../nbs/B1. Training.ipynb 2
|
| 7 |
+
import io
|
| 8 |
+
import time
|
| 9 |
+
import random
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from fastprogress import progress_bar, master_bar
|
| 13 |
+
import fastprogress
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pylab as plt
|
| 17 |
+
import math
|
| 18 |
+
|
| 19 |
+
import IPython
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
from torch.utils.data.dataloader import DataLoader
|
| 24 |
+
from torch.profiler import record_function
|
| 25 |
+
|
| 26 |
+
import webdataset as wds
|
| 27 |
+
|
| 28 |
+
torch.backends.cudnn.benchmark = True
|
| 29 |
+
torch.backends.cudnn.enabled = True
|
| 30 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 31 |
+
torch.set_float32_matmul_precision('medium')
|
| 32 |
+
|
| 33 |
+
# %% ../nbs/B1. Training.ipynb 3
|
| 34 |
+
class SimpleVisual:
|
| 35 |
+
def __init__ (self, model, masterbar, total_steps):
|
| 36 |
+
self.model = model
|
| 37 |
+
self.masterbar = masterbar
|
| 38 |
+
self.total_steps = total_steps
|
| 39 |
+
self.epochs = total_steps // masterbar.main_bar.total
|
| 40 |
+
|
| 41 |
+
gs = plt.GridSpec(2, 1, height_ratios=[3,1])
|
| 42 |
+
graph_fig = plt.figure(figsize=(10,6))
|
| 43 |
+
self.graph_fig = graph_fig
|
| 44 |
+
self.loss_p = graph_fig.add_subplot(gs[0])
|
| 45 |
+
self.lr_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
|
| 46 |
+
self.lr_p.tick_params('x', labelbottom=False)
|
| 47 |
+
self.graph_out = None
|
| 48 |
+
|
| 49 |
+
self.its = []
|
| 50 |
+
self.train_losses = []
|
| 51 |
+
self.val_losses = []
|
| 52 |
+
self.lr_history = []
|
| 53 |
+
|
| 54 |
+
def show(self):
|
| 55 |
+
self.start_t = time.time()
|
| 56 |
+
self.masterbar.write(["samples", "train", "val", "time"], table=True)
|
| 57 |
+
self.graph_out = IPython.display.display(self.graph_fig, display_id=True, clear=True)
|
| 58 |
+
|
| 59 |
+
def hide(self):
|
| 60 |
+
if self.graph_out is not None:
|
| 61 |
+
self.graph_out.update(IPython.display.HTML(''))
|
| 62 |
+
|
| 63 |
+
def plot(self):
|
| 64 |
+
loss_p, lr_p = self.loss_p, self.lr_p
|
| 65 |
+
loss_p.clear()
|
| 66 |
+
loss_p.plot(self.its, self.train_losses)
|
| 67 |
+
loss_p.plot(self.its, self.val_losses)
|
| 68 |
+
loss_p.set_xlim(0, self.total_steps)
|
| 69 |
+
loss_p.set_yscale('log')
|
| 70 |
+
lr_p.clear()
|
| 71 |
+
lrs = np.array(self.lr_history)
|
| 72 |
+
lr_p.plot(self.its, lrs)
|
| 73 |
+
self.graph_out.update(self.graph_fig)
|
| 74 |
+
|
| 75 |
+
def add_data(self, it, lr, train_loss, val_los):
|
| 76 |
+
self.its.append(it)
|
| 77 |
+
self.train_losses.append(train_loss)
|
| 78 |
+
self.val_losses.append(val_los)
|
| 79 |
+
self.lr_history.append(lr)
|
| 80 |
+
self.plot()
|
| 81 |
+
|
| 82 |
+
def add_table_row(self, it, avg_train_loss, val_loss):
|
| 83 |
+
elapsed_t = time.time() - self.start_t
|
| 84 |
+
self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)
|
| 85 |
+
|
| 86 |
+
def on_iter(self, bar, it, avg_train_loss, val_loss):
|
| 87 |
+
epoch = math.ceil(it / self.total_steps * self.epochs)
|
| 88 |
+
bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"
|
| 89 |
+
|
| 90 |
+
# %% ../nbs/B1. Training.ipynb 4
|
| 91 |
+
# FIXME: we need to keep this synchronised with the validation code below...
|
| 92 |
+
def validate(model, val, half=True, bs=16, drop_last=False, dl_workers=8, device="cuda"):
|
| 93 |
+
if isinstance(val, torch.utils.data.IterableDataset):
|
| 94 |
+
val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
|
| 95 |
+
.unbatched().shuffle(1024).batched(bs)
|
| 96 |
+
else:
|
| 97 |
+
val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)
|
| 98 |
+
|
| 99 |
+
with torch.no_grad():
|
| 100 |
+
val_loss = 0
|
| 101 |
+
val_samples = 0
|
| 102 |
+
for args in val_loader:
|
| 103 |
+
args = [x.to(device, non_blocking=True) for x in args]
|
| 104 |
+
with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
|
| 105 |
+
ps, loss = model(*args)
|
| 106 |
+
N = args[0].shape[0]
|
| 107 |
+
val_loss += loss.mean().item() * N
|
| 108 |
+
val_samples += N
|
| 109 |
+
val_loss = val_loss / val_samples
|
| 110 |
+
|
| 111 |
+
return val_loss
|
| 112 |
+
|
| 113 |
+
# %% ../nbs/B1. Training.ipynb 5
|
| 114 |
+
def train(checkpoint_path, model, train, val, half=True, bs=16, lr=1e-4, drop_last=False,
|
| 115 |
+
weight_decay=0.1, warmup_steps=10000, epochs=10, clip_gradient_norm=None,
|
| 116 |
+
dl_workers=8, visual_class = SimpleVisual, profiler=None,
|
| 117 |
+
run_valid_every_iters=8000, table_row_every_iters=80000, chkpt_every_iters=None,
|
| 118 |
+
device="cuda", trainable_params=None):
|
| 119 |
+
if chkpt_every_iters is None:
|
| 120 |
+
chkpt_every_iters = table_row_every_iters
|
| 121 |
+
|
| 122 |
+
mb = master_bar(range(epochs))
|
| 123 |
+
if isinstance(train, torch.utils.data.IterableDataset):
|
| 124 |
+
pct_start = min(0.3, warmup_steps / (epochs * (train.total_samples//bs)))
|
| 125 |
+
visual = visual_class(model, mb, epochs * train.total_samples)
|
| 126 |
+
# pct_start = min(0.3, warmup_steps / (epochs * len(train)))
|
| 127 |
+
# visual = visual_class(model, mb, epochs*len(train)*bs)
|
| 128 |
+
else:
|
| 129 |
+
pct_start = min(0.3, warmup_steps / (epochs * len(train) / bs))
|
| 130 |
+
visual = visual_class(model, mb, epochs*len(train))
|
| 131 |
+
model.visual = visual
|
| 132 |
+
|
| 133 |
+
Path(checkpoint_path).mkdir(exist_ok=True)
|
| 134 |
+
|
| 135 |
+
if isinstance(train, torch.utils.data.IterableDataset):
|
| 136 |
+
# train_loader = DataLoader(train, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False, shuffle=False)
|
| 137 |
+
# val_loader = DataLoader(val, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False)
|
| 138 |
+
train_loader = wds.WebLoader(train, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
|
| 139 |
+
.unbatched().shuffle(1024).batched(bs, partial=False)
|
| 140 |
+
val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
|
| 141 |
+
.unbatched().shuffle(1024).batched(bs)
|
| 142 |
+
else:
|
| 143 |
+
train_loader = DataLoader(train, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last, shuffle=True)
|
| 144 |
+
val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)
|
| 145 |
+
|
| 146 |
+
val_loss = torch.nan
|
| 147 |
+
avg_train_loss = torch.nan
|
| 148 |
+
|
| 149 |
+
if hasattr(model, 'setup'):
|
| 150 |
+
model.setup(device)
|
| 151 |
+
|
| 152 |
+
try:
|
| 153 |
+
scheduler = None
|
| 154 |
+
|
| 155 |
+
if trainable_params is None: trainable_params = model.parameters()
|
| 156 |
+
all_params = set(trainable_params)
|
| 157 |
+
customized_params = set()
|
| 158 |
+
groups = []
|
| 159 |
+
group_map = {}
|
| 160 |
+
for name,m in model.named_modules():
|
| 161 |
+
if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
|
| 162 |
+
m_trainable = [x for x in m.parameters() if x in all_params]
|
| 163 |
+
if not m_trainable: continue
|
| 164 |
+
customized_params |= set(m_trainable)
|
| 165 |
+
m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
|
| 166 |
+
m_lr = lr * getattr(m, 'lr_scale', 1)
|
| 167 |
+
group = group_map.get((m_wd, m_lr), None)
|
| 168 |
+
if not group:
|
| 169 |
+
group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
|
| 170 |
+
groups.append(group)
|
| 171 |
+
group_map[(m_wd, m_lr)] = group
|
| 172 |
+
group['params'] += m_trainable
|
| 173 |
+
group['names'].append(name)
|
| 174 |
+
|
| 175 |
+
other_params = all_params - customized_params
|
| 176 |
+
|
| 177 |
+
if other_params:
|
| 178 |
+
groups = groups + [
|
| 179 |
+
{"names": ["other"], "params": list(other_params), "weight_decay": weight_decay },
|
| 180 |
+
]
|
| 181 |
+
|
| 182 |
+
optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), fused=device!='cpu', params=groups)
|
| 183 |
+
model._optimizer = optimizer
|
| 184 |
+
scaler = torch.cuda.amp.GradScaler(enabled=half)
|
| 185 |
+
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
| 186 |
+
optimizer, pct_start=pct_start, steps_per_epoch=math.ceil(train.total_samples/bs), epochs=epochs,
|
| 187 |
+
max_lr=[pg.get('lr', lr) for pg in groups],
|
| 188 |
+
final_div_factor=25)
|
| 189 |
+
|
| 190 |
+
it = 0
|
| 191 |
+
next_val_it = it + 50
|
| 192 |
+
next_chkpt_it = chkpt_every_iters
|
| 193 |
+
next_table_it = table_row_every_iters
|
| 194 |
+
|
| 195 |
+
visual.show()
|
| 196 |
+
|
| 197 |
+
running_loss = [0]
|
| 198 |
+
|
| 199 |
+
for epoch in mb:
|
| 200 |
+
bar = progress_bar(train_loader, total=train.total_samples//bs, parent=mb)
|
| 201 |
+
for args in bar:
|
| 202 |
+
with record_function("forward"):
|
| 203 |
+
args = [x.to(device, non_blocking=True) for x in args]
|
| 204 |
+
|
| 205 |
+
# zero the parameter gradients
|
| 206 |
+
optimizer.zero_grad(set_to_none=True)
|
| 207 |
+
|
| 208 |
+
with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
|
| 209 |
+
ps, loss = model(*args)
|
| 210 |
+
loss = loss.mean()
|
| 211 |
+
|
| 212 |
+
with record_function("backward"):
|
| 213 |
+
scaler.scale(loss).backward()
|
| 214 |
+
|
| 215 |
+
if clip_gradient_norm:
|
| 216 |
+
scaler.unscale_(optimizer)
|
| 217 |
+
# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
|
| 218 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient_norm)
|
| 219 |
+
|
| 220 |
+
scaler.step(optimizer)
|
| 221 |
+
scaler.update()
|
| 222 |
+
|
| 223 |
+
scheduler.step()
|
| 224 |
+
|
| 225 |
+
if profiler is not None: profiler.step()
|
| 226 |
+
|
| 227 |
+
with record_function("running_loss"):
|
| 228 |
+
running_loss.append(loss.item())
|
| 229 |
+
running_loss = running_loss[-5:]
|
| 230 |
+
avg_train_loss = sum(running_loss)/len(running_loss)
|
| 231 |
+
|
| 232 |
+
if it >= next_chkpt_it:
|
| 233 |
+
with record_function("checkpoint"):
|
| 234 |
+
next_chkpt_it += chkpt_every_iters
|
| 235 |
+
torch.save(model.state_dict(), f'{checkpoint_path}/{it:08d}.pt')
|
| 236 |
+
|
| 237 |
+
if it >= next_val_it:
|
| 238 |
+
next_val_it += run_valid_every_iters
|
| 239 |
+
with record_function("validation"):
|
| 240 |
+
with record_function("model.eval"):
|
| 241 |
+
model.eval()
|
| 242 |
+
with torch.no_grad():
|
| 243 |
+
val_loss = 0
|
| 244 |
+
val_samples = 0
|
| 245 |
+
for args in val_loader:
|
| 246 |
+
args = [x.to(device, non_blocking=True) for x in args]
|
| 247 |
+
with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
|
| 248 |
+
ps, loss = model(*args)
|
| 249 |
+
N = args[0].shape[0]
|
| 250 |
+
val_loss += loss.mean().item() * N
|
| 251 |
+
val_samples += N
|
| 252 |
+
val_loss = val_loss / val_samples
|
| 253 |
+
with record_function("model.train"):
|
| 254 |
+
model.train()
|
| 255 |
+
with record_function("plotting"):
|
| 256 |
+
visual.add_data(it, scheduler.get_last_lr(), avg_train_loss, val_loss)
|
| 257 |
+
|
| 258 |
+
if it >= next_table_it:
|
| 259 |
+
visual.add_table_row(it, avg_train_loss, val_loss)
|
| 260 |
+
next_table_it += table_row_every_iters
|
| 261 |
+
|
| 262 |
+
it += bs
|
| 263 |
+
visual.on_iter(bar, it, avg_train_loss, val_loss)
|
| 264 |
+
except KeyboardInterrupt:
|
| 265 |
+
mb.write(f"interrupted")
|
| 266 |
+
mb.show()
|
| 267 |
+
pass
|
| 268 |
+
finally:
|
| 269 |
+
visual.add_table_row(it, avg_train_loss, val_loss)
|
| 270 |
+
mb.show()
|
| 271 |
+
visual.hide()
|
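A rough sketch of how the train() loop above is typically driven; train_ds and val_ds are placeholders for webdataset-style splits built by the dataset helpers in this package and are not defined in this commit:
# sketch only: wire a model and dataset splits into the training loop above
from whisperspeech import t2s_up_wds
from whisperspeech.train import train

def run(train_ds, val_ds):
    # train_ds/val_ds: splits exposing total_samples, stoks_len, ttoks_len and stoks_codes
    model = t2s_up_wds.make_model('tiny', dataset=train_ds)
    train('checkpoints', model, train_ds, val_ds,
          bs=16, lr=model.tunables.lr0, epochs=1,
          warmup_steps=model.tunables.warmup_steps,
          weight_decay=model.tunables.weight_decay,
          clip_gradient_norm=model.tunables.clip_gradient_norm)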
whisperspeech/train_multi.py
ADDED
|
@@ -0,0 +1,263 @@
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B2. Training (Lightning).ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = []
|
| 5 |
+
|
| 6 |
+
# %% ../nbs/B2. Training (Lightning).ipynb 2
|
| 7 |
+
import io
|
| 8 |
+
import time
|
| 9 |
+
import random
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from fastprogress import progress_bar, master_bar
|
| 13 |
+
import fastprogress
|
| 14 |
+
import wandb
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pylab as plt
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from torch.utils.data.dataloader import DataLoader
|
| 22 |
+
from torch.profiler import record_function
|
| 23 |
+
|
| 24 |
+
# %% ../nbs/B2. Training (Lightning).ipynb 3
|
| 25 |
+
import lightning.pytorch as pl
|
| 26 |
+
import math
|
| 27 |
+
|
| 28 |
+
class TrainingTask(pl.LightningModule):
|
| 29 |
+
def __init__(self, model, model_hparams=None):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.model = model
|
| 32 |
+
self.model_hparams = model_hparams
|
| 33 |
+
|
| 34 |
+
def on_fit_start(self):
|
| 35 |
+
if hasattr(self.model, 'setup'):
|
| 36 |
+
self.model.setup(self.device)
|
| 37 |
+
|
| 38 |
+
def configure_optimizers(self):
|
| 39 |
+
""" Initialize AdamW optimizer"""
|
| 40 |
+
lr = self.model_hparams['lr0']
|
| 41 |
+
weight_decay = self.model_hparams['weight_decay']
|
| 42 |
+
|
| 43 |
+
all_params = set(self.model.parameters())
|
| 44 |
+
customized_params = set()
|
| 45 |
+
groups = []
|
| 46 |
+
group_map = {}
|
| 47 |
+
for name,m in self.model.named_modules():
|
| 48 |
+
if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
|
| 49 |
+
customized_params |= set(m.parameters())
|
| 50 |
+
m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
|
| 51 |
+
m_lr = lr * getattr(m, 'lr_scale', 1)
|
| 52 |
+
group = group_map.get((m_wd, m_lr), None)
|
| 53 |
+
if not group:
|
| 54 |
+
group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
|
| 55 |
+
groups.append(group)
|
| 56 |
+
group_map[(m_wd, m_lr)] = group
|
| 57 |
+
group['params'] += m.parameters()
|
| 58 |
+
group['names'].append(name)
|
| 59 |
+
|
| 60 |
+
other_params = all_params - customized_params
|
| 61 |
+
|
| 62 |
+
param_groups = groups + [
|
| 63 |
+
{"names": ["other"], "params": list(other_params), "weight_decay": weight_decay },
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), params=param_groups)
|
| 67 |
+
|
| 68 |
+
# modified from https://github.com/Lightning-AI/lightning/issues/5449#issuecomment-1501597319
|
| 69 |
+
def num_steps_per_epoch() -> int:
|
| 70 |
+
"""Get number of steps"""
|
| 71 |
+
# Accessing _data_source is flaky and might break
|
| 72 |
+
dataset = self.trainer.fit_loop._data_source.dataloader()
|
| 73 |
+
dataset_size = len(dataset)
|
| 74 |
+
# math.ceil so always overestimate (underestimating throws exceptions)
|
| 75 |
+
num_steps = math.ceil(dataset_size / self.trainer.accumulate_grad_batches)
|
| 76 |
+
return num_steps
|
| 77 |
+
|
| 78 |
+
total_steps = self.model_hparams['epochs'] * num_steps_per_epoch()
|
| 79 |
+
self.model_hparams['pct_start'] = min(0.3, self.model_hparams['warmup_steps'] / total_steps)
|
| 80 |
+
|
| 81 |
+
print(f"{self.model_hparams['epochs']=} epochs x {num_steps_per_epoch()=} steps")
|
| 82 |
+
|
| 83 |
+
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
| 84 |
+
optimizer,
|
| 85 |
+
pct_start=self.model_hparams['pct_start'],
|
| 86 |
+
max_lr=[pg.get('lr', lr) for pg in param_groups],
|
| 87 |
+
steps_per_epoch=num_steps_per_epoch(),
|
| 88 |
+
epochs=int(self.model_hparams['epochs']),
|
| 89 |
+
final_div_factor=25
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]
|
| 93 |
+
|
| 94 |
+
def training_step(self, train_batch, batch_idx):
|
| 95 |
+
train_logits, train_loss = self.model.forward(*train_batch)
|
| 96 |
+
|
| 97 |
+
self.log("train_loss", train_loss, sync_dist=True)
|
| 98 |
+
return train_loss
|
| 99 |
+
|
| 100 |
+
def validation_step(self, val_batch, batch_idx):
|
| 101 |
+
val_logits, val_loss = self.model.forward(*val_batch)
|
| 102 |
+
|
| 103 |
+
self.log("val_loss", val_loss, sync_dist=True)
|
| 104 |
+
return val_loss
|
| 105 |
+
|
| 106 |
+
def on_validation_epoch_end(self):
|
| 107 |
+
if hasattr(self.model, 'get_metrics'):
|
| 108 |
+
self.log_dict({'metrics/'+k:v for k,v in self.model.get_metrics().items()}, sync_dist=True)
|
| 109 |
+
|
| 110 |
+
def test_step(self, val_batch, batch_idx):
|
| 111 |
+
test_logits, test_loss = self.model.forward(*val_batch)
|
| 112 |
+
|
| 113 |
+
self.log("test_loss", test_loss, sync_dist=True)
|
| 114 |
+
return test_loss
|
| 115 |
+
|
| 116 |
+
# %% ../nbs/B2. Training (Lightning).ipynb 4
|
| 117 |
+
from fastcore.script import anno_parser
|
| 118 |
+
import shlex
|
| 119 |
+
|
| 120 |
+
# watch out: we can only pass Python values as keyword arguments (not positional)
|
| 121 |
+
# everything else has to be a string
|
| 122 |
+
def parse_and_call(name, fun, args, kwargs={}, log_to_wandb=True):
|
| 123 |
+
p = anno_parser(fun)
|
| 124 |
+
args = p.parse_args(args).__dict__
|
| 125 |
+
args.pop('xtra'); args.pop('pdb')
|
| 126 |
+
args.update({k:v for k, v in kwargs.items()})
|
| 127 |
+
if log_to_wandb and type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
|
| 128 |
+
wandb_logger.experiment.config[name] = {k:v for k,v in args.items() if k not in ['dataset', 'tunables']}
|
| 129 |
+
return fun(**args)
|
| 130 |
+
|
| 131 |
+
# %% ../nbs/B2. Training (Lightning).ipynb 8
|
| 132 |
+
import argparse
|
| 133 |
+
|
| 134 |
+
parser = argparse.ArgumentParser()
|
| 135 |
+
parser.add_argument('--task', type=str, help='Task to train')
|
| 136 |
+
parser.add_argument('--seed', type=int, default=0, help='Global training seed')
|
| 137 |
+
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
|
| 138 |
+
parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
|
| 139 |
+
parser.add_argument('--input-dir', type=str, default='', help='input data path') # fixed in the model for now
|
| 140 |
+
parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints/", help="directory to save the checkpoints")
|
| 141 |
+
parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
|
| 142 |
+
parser.add_argument('--validate-every-n-steps', type=int, default=500, help='how many training steps to run between validations')
|
| 143 |
+
parser.add_argument('--weight-decay', type=float, default=1e-2, help='optimizer weight decay')
|
| 144 |
+
parser.add_argument('--lr0', type=float, default=1e-4, help='optimizer initial learning rate')
|
| 145 |
+
parser.add_argument('--clip-gradient-norm', type=float, default=None, help='enable gradient norm clipping')
|
| 146 |
+
parser.add_argument('--accumulate-grad-batches', type=int, default=1, help='perform the optimizer step only after going through several batches of samples')
|
| 147 |
+
parser.add_argument('--precision', type=str, default="16-mixed", help="floating point precision")
|
| 148 |
+
parser.add_argument('--warmup-steps', type=int, default=10000, help='total number of steps during which the learning rate rises (defaults to 10k updates)')
|
| 149 |
+
parser.add_argument('--tunables', type=str, default="", help='tunable hyperparameters')
|
| 150 |
+
parser.add_argument('--resume-from', type=Path, default=None, help='resume training from the given checkpoint')
|
| 151 |
+
parser.add_argument('--strategy', type=str, default='ddp', help='distributed training strategy')
|
| 152 |
+
parser.add_argument('--wandb-suffix', type=str, default=None, help='W&B project name suffix')
|
| 153 |
+
parser.add_argument('--wandb-task-name', type=str, default=None, help='Task name for the W&B project name')
|
| 154 |
+
|
| 155 |
+
args = parser.parse_args().__dict__
|
| 156 |
+
|
| 157 |
+
task_args: list = shlex.split(args.pop("task"))
|
| 158 |
+
task_name, task_args = task_args[0], task_args[1:]
|
| 159 |
+
input_args: list = shlex.split(args.pop("input_dir"))
|
| 160 |
+
checkpoint_dir: str = args.pop("checkpoint_dir")
|
| 161 |
+
num_workers: int = args.pop("workers")
|
| 162 |
+
batch_size: int = args.pop("batch_size")
|
| 163 |
+
epochs: int = args.pop("epochs")
|
| 164 |
+
tunables_args: list = shlex.split(args.pop("tunables"))
|
| 165 |
+
|
| 166 |
+
hyp_params = {}
|
| 167 |
+
hyp_params['batch_size'] = batch_size
|
| 168 |
+
hyp_params['warmup_steps'] = args['warmup_steps']
|
| 169 |
+
hyp_params['weight_decay'] = args['weight_decay']
|
| 170 |
+
hyp_params['clip_gradient_norm'] = args['clip_gradient_norm']
|
| 171 |
+
hyp_params['accumulate_grad_batches'] = args['accumulate_grad_batches']
|
| 172 |
+
hyp_params['precision'] = args['precision']
|
| 173 |
+
hyp_params['lr0'] = args['lr0']
|
| 174 |
+
hyp_params['epochs'] = epochs
|
| 175 |
+
hyp_params['strategy'] = args['strategy']
|
| 176 |
+
|
| 177 |
+
# %% ../nbs/B2. Training (Lightning).ipynb 9
|
| 178 |
+
from lightning.pytorch.loggers import WandbLogger
|
| 179 |
+
from lightning.pytorch.callbacks import LearningRateMonitor
|
| 180 |
+
import datetime
|
| 181 |
+
import webdataset as wds
|
| 182 |
+
import importlib
|
| 183 |
+
|
| 184 |
+
torch.set_float32_matmul_precision('medium')
|
| 185 |
+
|
| 186 |
+
project = f"WhisperSpeech-{args['wandb_task_name'] or task_name}"
|
| 187 |
+
if args['wandb_suffix']:
|
| 188 |
+
project += "-"+args['wandb_suffix']
|
| 189 |
+
|
| 190 |
+
wandb_logger = WandbLogger(project=project)
|
| 191 |
+
|
| 192 |
+
ckpt_callback = pl.callbacks.ModelCheckpoint(
|
| 193 |
+
dirpath=f'{task_name}-{epochs}e',
|
| 194 |
+
filename=task_name+"-{epoch}-{step}-{val_loss:.2f}",
|
| 195 |
+
monitor="val_loss",
|
| 196 |
+
save_top_k=4,
|
| 197 |
+
train_time_interval=datetime.timedelta(minutes=5),
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
lr_monitor_callback = LearningRateMonitor(logging_interval='step')
|
| 201 |
+
|
| 202 |
+
from torch.utils.data import DataLoader
|
| 203 |
+
|
| 204 |
+
task = importlib.import_module("whisperspeech."+task_name)
|
| 205 |
+
|
| 206 |
+
train_ds, val_ds = parse_and_call('dataset', task.load_datasets, input_args)
|
| 207 |
+
|
| 208 |
+
tunables = None
|
| 209 |
+
if hasattr(task, "Tunables"):
|
| 210 |
+
import dataclasses
|
| 211 |
+
tunables = parse_and_call('tunables', task.Tunables, tunables_args, log_to_wandb=False)
|
| 212 |
+
if type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
|
| 213 |
+
wandb_logger.experiment.config['tunables'] = dataclasses.asdict(tunables)
|
| 214 |
+
|
| 215 |
+
for name in ["lr0", "clip_gradient_norm", "weight_decay", "warmup_steps"]:
|
| 216 |
+
val = getattr(tunables, name, None)
|
| 217 |
+
if val is not None: hyp_params[name] = val
|
| 218 |
+
|
| 219 |
+
if isinstance(train_ds, torch.utils.data.IterableDataset):
|
| 220 |
+
dl_batch_size, dl_shuffle = None, False
|
| 221 |
+
pin_memory = False
|
| 222 |
+
else:
|
| 223 |
+
dl_batch_size, dl_shuffle = batch_size, True
|
| 224 |
+
pin_memory = True
|
| 225 |
+
|
| 226 |
+
val_loader = wds.WebLoader(val_ds,
|
| 227 |
+
batch_size=dl_batch_size,
|
| 228 |
+
num_workers=num_workers,
|
| 229 |
+
drop_last=False,
|
| 230 |
+
pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(val_ds.total_samples // batch_size)
|
| 231 |
+
|
| 232 |
+
train_loader = wds.WebLoader(train_ds,
|
| 233 |
+
batch_size=dl_batch_size,
|
| 234 |
+
num_workers=num_workers,
|
| 235 |
+
drop_last=False,
|
| 236 |
+
shuffle=dl_shuffle,
|
| 237 |
+
pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(train_ds.total_samples // batch_size)
|
| 238 |
+
|
| 239 |
+
model_kwargs = dict(dataset=train_ds)
|
| 240 |
+
if tunables is not None: model_kwargs['tunables'] = tunables
|
| 241 |
+
model = parse_and_call('model', task.make_model, task_args, model_kwargs)
|
| 242 |
+
|
| 243 |
+
task = TrainingTask(model, model_hparams=hyp_params)
|
| 244 |
+
|
| 245 |
+
trainer = pl.Trainer(strategy=hyp_params['strategy'],
|
| 246 |
+
max_epochs=hyp_params['epochs'],
|
| 247 |
+
accelerator="gpu",
|
| 248 |
+
profiler="simple",
|
| 249 |
+
precision=hyp_params['precision'],
|
| 250 |
+
gradient_clip_val=hyp_params['clip_gradient_norm'],
|
| 251 |
+
accumulate_grad_batches=hyp_params['accumulate_grad_batches'],
|
| 252 |
+
val_check_interval=args.pop("validate_every_n_steps"),
|
| 253 |
+
enable_checkpointing=True,
|
| 254 |
+
logger=wandb_logger,
|
| 255 |
+
callbacks=[ckpt_callback, lr_monitor_callback])
|
| 256 |
+
|
| 257 |
+
if type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
|
| 258 |
+
wandb_logger.experiment.config.update(hyp_params)
|
| 259 |
+
|
| 260 |
+
kwargs = {}
|
| 261 |
+
if 'resume_from' in args:
|
| 262 |
+
kwargs['ckpt_path'] = args['resume_from']
|
| 263 |
+
trainer.fit(model=task, train_dataloaders=train_loader, val_dataloaders=val_loader, **kwargs)
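Note: the Lightning wrapper above only assumes that the wrapped model's forward() returns a (logits, loss) pair and that the hyperparameter dict carries keys such as lr0, weight_decay, warmup_steps and epochs. A minimal sketch of that contract (the toy model and values below are illustrative, not part of the commit):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyModel(nn.Module):
    """Stand-in for the models returned by task.make_model(): forward(*batch) -> (logits, loss)."""
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 2)

    def forward(self, x, y):
        logits = self.lin(x)
        loss = F.cross_entropy(logits, y)
        return logits, loss

hparams = dict(lr0=1e-4, weight_decay=1e-2, warmup_steps=10, epochs=1,
               clip_gradient_norm=None, accumulate_grad_batches=1,
               precision="16-mixed", strategy="ddp")
# task = TrainingTask(ToyModel(), model_hparams=hparams)
# trainer.fit(task, train_dataloaders=..., val_dataloaders=...)  # as at the end of the script
```

In practice the script is driven through the argparse block above (--task, --batch-size, --epochs, ...), with the task module resolved from the whisperspeech package via importlib.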
|
whisperspeech/utils.py
ADDED
|
@@ -0,0 +1,159 @@
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/D. Common dataset utilities.ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = ['shard_glob', 'join_datasets', 'resampler', 'derived_name', 'derived_dataset', 'merge_in', 'AtomicTarWriter',
|
| 5 |
+
'readlines']
|
| 6 |
+
|
| 7 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 1
|
| 8 |
+
import os
|
| 9 |
+
import torch
|
| 10 |
+
import torchaudio
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import webdataset as wds
|
| 13 |
+
from contextlib import contextmanager
|
| 14 |
+
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 2
|
| 18 |
+
def shard_glob(input):
|
| 19 |
+
if '{' in input:
|
| 20 |
+
return wds.shardlists.expand_urls(input)
|
| 21 |
+
if isinstance(input, (Path, str)):
|
| 22 |
+
path = Path(input)
|
| 23 |
+
if path.is_dir():
|
| 24 |
+
glob = '*.tar.gz'
|
| 25 |
+
else:
|
| 26 |
+
glob = path.name
|
| 27 |
+
path = path.parent
|
| 28 |
+
input = Path(path).glob(glob)
|
| 29 |
+
else:
|
| 30 |
+
        raise ValueError("input should be either a list or a path with an optional glob specifier")
|
| 31 |
+
return [str(x) for x in input]
|
| 32 |
+
|
| 33 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 3
|
| 34 |
+
class join_datasets(torch.utils.data.IterableDataset):
|
| 35 |
+
def __init__(self, datasets):
|
| 36 |
+
self.datasets = datasets
|
| 37 |
+
|
| 38 |
+
def __iter__(self):
|
| 39 |
+
probs = torch.tensor([getattr(ds, 'weight', 1) for ds in self.datasets], dtype=torch.float)
|
| 40 |
+
its = [iter(ds) for ds in self.datasets]
|
| 41 |
+
while True:
|
| 42 |
+
try:
|
| 43 |
+
yield next(its[torch.multinomial(probs, 1)])
|
| 44 |
+
except StopIteration:
|
| 45 |
+
return
|
| 46 |
+
|
| 47 |
+
def __len__(self):
|
| 48 |
+
return sum([ds.total_samples for ds in self.datasets])
|
| 49 |
+
|
| 50 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 5
|
| 51 |
+
def resampler(newsr = 24000, key = 'samples_24k'):
|
| 52 |
+
_last_sr = None
|
| 53 |
+
tform = None
|
| 54 |
+
|
| 55 |
+
def _resample(samples):
|
| 56 |
+
for s in samples:
|
| 57 |
+
sr = s['sample_rate']
|
| 58 |
+
if sr != newsr:
|
| 59 |
+
if sr != _last_sr: tform = torchaudio.transforms.Resample(sr, newsr)
|
| 60 |
+
s[key] = tform(s['samples'])
|
| 61 |
+
else:
|
| 62 |
+
s[key] = s['samples']
|
| 63 |
+
yield s
|
| 64 |
+
|
| 65 |
+
return _resample
|
| 66 |
+
|
| 67 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 6
|
| 68 |
+
def derived_name(input, kind, base="audio", suffix=".gz", dir=None):
|
| 69 |
+
dir = Path(dir) if dir else Path(input).parent
|
| 70 |
+
return str(dir/(Path(input).name.replace(f"-{base}-", f"-{kind}-") + suffix))
|
| 71 |
+
|
| 72 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 7
|
| 73 |
+
def derived_dataset(kind, base='audio', suffix=".gz", decoders=[], dir=None):
|
| 74 |
+
def deriver(url):
|
| 75 |
+
url = str(derived_name(url, kind, base=base, suffix=suffix, dir=dir))
|
| 76 |
+
return wds.WebDataset(
|
| 77 |
+
wds.SimpleShardList([url])
|
| 78 |
+
).decode(*decoders)
|
| 79 |
+
return deriver
|
| 80 |
+
|
| 81 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 8
|
| 82 |
+
def merge_in(dataset_fun):
|
| 83 |
+
"""Merge a dataset into the current one returning samples with the union of keys. Pass in a function
|
| 84 |
+
    that takes a URL of a sample and returns a dataset for it (called every time the URL changes).
|
| 85 |
+
|
| 86 |
+
It requires (and validates) that both datasets have the same ordering of keys so you have
|
| 87 |
+
to use it before any sample shuffling. Shard shuffling is ok.
|
| 88 |
+
"""
|
| 89 |
+
def merge_loop(main_samples):
|
| 90 |
+
#print("new merge loop:", dataset_fun)
|
| 91 |
+
merged_samples = None
|
| 92 |
+
cur_url = None
|
| 93 |
+
i = None
|
| 94 |
+
for s in main_samples:
|
| 95 |
+
url = s['__url__']
|
| 96 |
+
if url != cur_url:
|
| 97 |
+
# this will open a new file when we get the first sample with a new __url__
|
| 98 |
+
merged_samples = iter(dataset_fun(url))
|
| 99 |
+
cur_url = url
|
| 100 |
+
try:
|
| 101 |
+
merge_s = next(merged_samples)
|
| 102 |
+
except StopIteration:
|
| 103 |
+
# if the original shard got repeated we won't observe a __url__ change
|
| 104 |
+
# in this case restart the dataset from the beginning
|
| 105 |
+
merged_samples = iter(dataset_fun(url))
|
| 106 |
+
merge_s = next(merged_samples)
|
| 107 |
+
assert merge_s['__key__'] == s['__key__'], f"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}"
|
| 108 |
+
news = {}
|
| 109 |
+
news.update(merge_s)
|
| 110 |
+
news.update(s)
|
| 111 |
+
yield news
|
| 112 |
+
return merge_loop
|
| 113 |
+
|
| 114 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 9
|
| 115 |
+
def split_to_chunks(stream, ikey='vad.npy', metakeys=[], pad_to_seconds=30, random_shift=False):
|
| 116 |
+
for s in stream:
|
| 117 |
+
audio, sr = s['audio']
|
| 118 |
+
imax = len(s[ikey]) - 1
|
| 119 |
+
for i,(ts,te) in enumerate(s[ikey]):
|
| 120 |
+
samples = audio[0,int(ts*sr):int(te*sr)]
|
| 121 |
+
if pad_to_seconds is not None:
|
| 122 |
+
padding = pad_to_seconds*sr-samples.shape[-1]
|
| 123 |
+
lpad = random.randint(0, padding) if random_shift else 0
|
| 124 |
+
samples = F.pad(samples, (lpad, padding-lpad))
|
| 125 |
+
subs = {"__key__": s['__key__'] + f"_{i:03d}",
|
| 126 |
+
"src_key": s['__key__'],
|
| 127 |
+
"__url__": s['__url__'],
|
| 128 |
+
"i": i, "imax": imax,
|
| 129 |
+
"tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
|
| 130 |
+
"lpad": lpad, "rpad": padding-lpad,
|
| 131 |
+
"lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
|
| 132 |
+
"samples": samples, "sample_rate": sr}
|
| 133 |
+
for k in metakeys:
|
| 134 |
+
subs[k] = s[k][i]
|
| 135 |
+
yield subs
|
| 136 |
+
|
| 137 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 10
|
| 138 |
+
def vad_dataset(shards, ikey='vad.npy', kind='vad'):
|
| 139 |
+
return wds.WebDataset(shards).compose(
|
| 140 |
+
wds.decode(wds.torch_audio),
|
| 141 |
+
merge_in(derived_dataset(kind)),
|
| 142 |
+
wds.select(lambda x: 'wav' in x or 'flac' in x or 'mp3' in x or 'ogg' in x), # skip samples without audio
|
| 143 |
+
wds.rename(audio="flac;mp3;wav;ogg"),
|
| 144 |
+
lambda x: split_to_chunks(x, ikey=ikey),
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 11
|
| 148 |
+
@contextmanager
|
| 149 |
+
def AtomicTarWriter(name, throwaway=False):
|
| 150 |
+
tmp = name+".tmp"
|
| 151 |
+
with wds.TarWriter(tmp, compress=name.endswith('gz')) as sink:
|
| 152 |
+
yield sink
|
| 153 |
+
if not throwaway:
|
| 154 |
+
os.rename(tmp, name)
|
| 155 |
+
|
| 156 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 12
|
| 157 |
+
def readlines(fname):
|
| 158 |
+
with open(fname) as file:
|
| 159 |
+
return [line.rstrip() for line in file]
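A small usage sketch of the helpers above (shard names invented for illustration): shard_glob expands a brace pattern into a shard list and derived_name maps an audio shard onto the shard holding derived data such as VAD segments.

```python
from whisperspeech import utils

# brace patterns are expanded via webdataset, plain paths/globs via pathlib
shards = utils.shard_glob('librilight-small-audio-{000000..000002}.tar')
print(shards)
# ['librilight-small-audio-000000.tar', 'librilight-small-audio-000001.tar',
#  'librilight-small-audio-000002.tar']

# '-audio-' is swapped for the derived kind and the compression suffix appended
print(utils.derived_name(shards[0], 'vad'))
# 'librilight-small-vad-000000.tar.gz'
```

merge_in(derived_dataset('vad')) then joins those per-shard derived files back onto the audio stream, which is how vad_dataset above assembles its pipeline.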
|
whisperspeech/vad.py
ADDED
|
@@ -0,0 +1,71 @@
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/1B. Voice activity detection.ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = []
|
| 5 |
+
|
| 6 |
+
# %% ../nbs/1B. Voice activity detection.ipynb 3
|
| 7 |
+
import os
|
| 8 |
+
import torch
|
| 9 |
+
import torchaudio
|
| 10 |
+
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from fastprogress import progress_bar
|
| 13 |
+
from fastcore.script import call_parse
|
| 14 |
+
|
| 15 |
+
import whisperx
|
| 16 |
+
import random
|
| 17 |
+
import numpy as np
|
| 18 |
+
import webdataset as wds
|
| 19 |
+
|
| 20 |
+
# %% ../nbs/1B. Voice activity detection.ipynb 5
|
| 21 |
+
# some of the original file names have a dot in their name
|
| 22 |
+
# webdataset does not like it so let's patch it
|
| 23 |
+
def fix_dots_in_names(name):
|
| 24 |
+
name, ext = name.rsplit('.', 1)
|
| 25 |
+
return ".".join((name.replace('.', '_'), ext))
|
| 26 |
+
|
| 27 |
+
def load_dataset(url, decode=True, rename_files=None):
|
| 28 |
+
ds = wds.WebDataset(url, rename_files=rename_files)
|
| 29 |
+
if not decode: return ds
|
| 30 |
+
return ds.decode(wds.torch_audio)
|
| 31 |
+
|
| 32 |
+
# %% ../nbs/1B. Voice activity detection.ipynb 7
|
| 33 |
+
def extract_segments(vad_result, max_duration):
|
| 34 |
+
binarize = whisperx.vad.Binarize(max_duration=max_duration)
|
| 35 |
+
segments = binarize(vad_result)
|
| 36 |
+
return [(x.start, x.end) for x in segments.get_timeline()]
|
| 37 |
+
|
| 38 |
+
def segment_audio(vad_model, audio, sr=16000):
|
| 39 |
+
vad_result = vad_model({"waveform": audio, "sample_rate": sr})
|
| 40 |
+
return extract_segments(vad_result, 30)
|
| 41 |
+
|
| 42 |
+
# %% ../nbs/1B. Voice activity detection.ipynb 13
|
| 43 |
+
def flac_to_vad_name(input):
|
| 44 |
+
if '-flac-' in input:
|
| 45 |
+
return input.rsplit("/", 1)[1].replace('flac', 'vad') + ".gz"
|
| 46 |
+
else:
|
| 47 |
+
return input.rsplit("/", 1)[1].replace('raw', 'vad') + ".gz"
|
| 48 |
+
|
| 49 |
+
@call_parse
|
| 50 |
+
def process_shard(
|
| 51 |
+
input:str, # input shard URL/path
|
| 52 |
+
output:str=None, # output shard URL/path
|
| 53 |
+
fix_dots:bool=False, # fix dots in LibriLight filenames
|
| 54 |
+
):
|
| 55 |
+
if output is None: output = flac_to_vad_name(input)
|
| 56 |
+
|
| 57 |
+
ds = torch.utils.data.DataLoader(load_dataset(input, rename_files=fix_dots_in_names if fix_dots else None), num_workers=2, batch_size=None)
|
| 58 |
+
vad_model = whisperx.vad.load_vad_model('cuda')
|
| 59 |
+
|
| 60 |
+
tmp = output+".tmp"
|
| 61 |
+
with wds.TarWriter(tmp) as sink:
|
| 62 |
+
for s in progress_bar(ds, total='noinfer'):
|
| 63 |
+
audio, sr = s.get('flac', s.get('wav', (None, None)))
|
| 64 |
+
if audio is None:
|
| 65 |
+
print(f"warning: '{s['__key__']}' does not contain an audio file")
|
| 66 |
+
continue
|
| 67 |
+
sink.write({
|
| 68 |
+
"__key__": s['__key__'],
|
| 69 |
+
"vad.npy": np.array(segment_audio(vad_model, audio, sr=sr), dtype=np.float16)
|
| 70 |
+
})
|
| 71 |
+
os.rename(tmp, output)
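process_shard is exposed as a small CLI through fastcore's @call_parse; it writes one vad.npy entry per input sample. A hedged sketch of reading such a VAD shard back (the shard name is a placeholder):

```python
import webdataset as wds

ds = wds.WebDataset('librilight-small-vad-000000.tar.gz').decode()
for s in ds:
    segments = s['vad.npy']  # float16 array of (start, end) speech segment times in seconds
    for ts, te in segments:
        print(s['__key__'], round(float(ts), 2), round(float(te), 2))
    break
```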
|
whisperspeech/vq_stoks.py
ADDED
|
@@ -0,0 +1,493 @@
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/2B. Whisper quantization (semantic token) model.ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = ['RQBottleneckTransformer', 'make_model']
|
| 5 |
+
|
| 6 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 2
|
| 7 |
+
import io
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
import torch
|
| 11 |
+
import torchaudio
|
| 12 |
+
|
| 13 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 3
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import json
|
| 16 |
+
from fastprogress import progress_bar, master_bar
|
| 17 |
+
import fastprogress
|
| 18 |
+
import numpy as np
|
| 19 |
+
import pylab as plt
|
| 20 |
+
import pandas as pd
|
| 21 |
+
import random
|
| 22 |
+
|
| 23 |
+
import whisper
|
| 24 |
+
from huggingface_hub import hf_hub_download
|
| 25 |
+
from fastcore.basics import store_attr
|
| 26 |
+
|
| 27 |
+
from torch import nn
|
| 28 |
+
import torch.optim as optim
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
from torch.utils.data.dataloader import DataLoader
|
| 31 |
+
import webdataset as wds
|
| 32 |
+
from . import utils
|
| 33 |
+
|
| 34 |
+
from vector_quantize_pytorch import ResidualVQ
|
| 35 |
+
|
| 36 |
+
from fastcore.script import *
|
| 37 |
+
|
| 38 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 9
|
| 39 |
+
def merge_in(dataset_fun):
|
| 40 |
+
"""Merge a dataset into the current one returning samples with the union of keys. Pass in a function
|
| 41 |
+
    that takes a URL of a sample and returns a dataset for it (called every time the URL changes).
|
| 42 |
+
|
| 43 |
+
It requires (and validates) that both datasets have the same ordering of keys so you have
|
| 44 |
+
to use it before any sample shuffling. Shard shuffling is ok.
|
| 45 |
+
"""
|
| 46 |
+
def merge_loop(main_samples):
|
| 47 |
+
#print("new merge loop:", dataset_fun)
|
| 48 |
+
merged_samples = None
|
| 49 |
+
cur_url = None
|
| 50 |
+
i = None
|
| 51 |
+
for s in main_samples:
|
| 52 |
+
url = s['__url__']
|
| 53 |
+
if url != cur_url:
|
| 54 |
+
# this will open a new file when we get the first sample with a new __url__
|
| 55 |
+
merged_samples = iter(dataset_fun(url))
|
| 56 |
+
cur_url = url
|
| 57 |
+
try:
|
| 58 |
+
merge_s = next(merged_samples)
|
| 59 |
+
except StopIteration:
|
| 60 |
+
# if the original shard got repeated we won't observe a __url__ change
|
| 61 |
+
# in this case restart the dataset from the beginning
|
| 62 |
+
merged_samples = iter(dataset_fun(url))
|
| 63 |
+
merge_s = next(merged_samples)
|
| 64 |
+
assert merge_s['__key__'] == s['__key__'], f"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}"
|
| 65 |
+
news = {}
|
| 66 |
+
news.update(merge_s)
|
| 67 |
+
news.update(s)
|
| 68 |
+
yield news
|
| 69 |
+
return merge_loop
|
| 70 |
+
|
| 71 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 10
|
| 72 |
+
def derived_dataset(kind, key='audio'):
|
| 73 |
+
def deriver(url):
|
| 74 |
+
url = str(Path(url).parent/(Path(url).name.replace(key, kind) + ".gz"))
|
| 75 |
+
return wds.WebDataset(
|
| 76 |
+
wds.SimpleShardList([url])
|
| 77 |
+
).decode()
|
| 78 |
+
return deriver
|
| 79 |
+
|
| 80 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 17
|
| 81 |
+
def add_masks(samples):
|
| 82 |
+
for s in samples:
|
| 83 |
+
seconds = s['tend'] - s['tstart']
|
| 84 |
+
# a mask (downsampled to the Whisper encoder token rate of 50/s) is used
|
| 85 |
+
# to teach the model the concept of padding
|
| 86 |
+
        # this lets us decode shorter sequences later
|
| 87 |
+
mask = torch.zeros(30*16000//320, dtype=torch.bool)
|
| 88 |
+
mask[:int(seconds * 16000) // 320] = 1
|
| 89 |
+
s['mask'] = mask
|
| 90 |
+
yield s
|
| 91 |
+
|
| 92 |
+
def tokenize_text(samples, ttoks_size=200, model="base.en", language="en"):
|
| 93 |
+
multilingual = not model.endswith(".en")
|
| 94 |
+
tokenizer = whisper.tokenizer.get_tokenizer(multilingual, language=language, task="transcribe")
|
| 95 |
+
for s in samples:
|
| 96 |
+
ttoks = tokenizer.encode(s['txt'])
|
| 97 |
+
tokens = list(tokenizer.sot_sequence) + ttoks
|
| 98 |
+
rpad = ttoks_size - len(tokens)
|
| 99 |
+
s['in_ttoks'] = F.pad(torch.tensor(tokens), (0, rpad), value=tokenizer.eot)
|
| 100 |
+
s['out_ttoks'] = F.pad(torch.tensor(tokens[1:] + [tokenizer.eot]), (0, rpad), value=-100)
|
| 101 |
+
yield s
|
| 102 |
+
|
| 103 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 22
|
| 104 |
+
def load_dataset(
|
| 105 |
+
shard_spec:str,
|
| 106 |
+
proc_dataset_path:Path, # processed VAD and txt files
|
| 107 |
+
samples:int, # set the per-GPU sample count
|
| 108 |
+
txt_label:str="base.en-txt", # the label of the files containing transcriptions
|
| 109 |
+
model:str="base.en",
|
| 110 |
+
key:str="flac",
|
| 111 |
+
language:str=None,
|
| 112 |
+
validation:bool=False,
|
| 113 |
+
):
|
| 114 |
+
from . import wh_transcribe
|
| 115 |
+
shards = utils.shard_glob(shard_spec)
|
| 116 |
+
|
| 117 |
+
if not language and model.endswith('en'): language = 'en'
|
| 118 |
+
assert language, "please provide the dataset language for multilang models"
|
| 119 |
+
|
| 120 |
+
same_on_all_nodes = lambda urls: urls # will only be used for validation
|
| 121 |
+
ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
|
| 122 |
+
wds.decode(wds.torch_audio),
|
| 123 |
+
wds.select(lambda x: 'wav' in x or 'flac' in x or 'mp3' in x or 'ogg' in x), # skip samples without audio
|
| 124 |
+
wds.rename(audio="flac;mp3;wav;ogg"),
|
| 125 |
+
merge_in(derived_dataset(proc_dataset_path, 'vad', key=key)),
|
| 126 |
+
wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
|
| 127 |
+
wh_transcribe.split_to_chunks,
|
| 128 |
+
utils.resampler(16000, 'samples_16k'),
|
| 129 |
+
merge_in(derived_dataset(proc_dataset_path, txt_label, key=key)),
|
| 130 |
+
)
|
| 131 |
+
if 'librilight' in shards[0]:
|
| 132 |
+
ds = ds.compose(
|
| 133 |
+
# drop the first and last segment because they tend to be inaccurate
|
| 134 |
+
# (the transcriptions don't have the "LibriVox" headers and "end of chapter" suffixes)
|
| 135 |
+
wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
|
| 136 |
+
)
|
| 137 |
+
ds = ds.compose(
|
| 138 |
+
add_masks,
|
| 139 |
+
lambda x: tokenize_text(x, model=model, language=language),
|
| 140 |
+
wds.to_tuple('samples_16k', 'mask', 'in_ttoks', 'out_ttoks'),
|
| 141 |
+
wds.batched(32),
|
| 142 |
+
)
|
| 143 |
+
ds.total_samples = samples
|
| 144 |
+
|
| 145 |
+
return ds
|
| 146 |
+
|
| 147 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 28
|
| 148 |
+
from whisperspeech.train import *
|
| 149 |
+
from whisperspeech.modules import *
|
| 150 |
+
|
| 151 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 29
|
| 152 |
+
import dataclasses
|
| 153 |
+
|
| 154 |
+
def rand(start, end):
|
| 155 |
+
return random.random() * (end - start) + start
|
| 156 |
+
|
| 157 |
+
def logrand(start, end):
|
| 158 |
+
return 10**rand(math.log10(start), math.log10(end))
|
| 159 |
+
|
| 160 |
+
@dataclasses.dataclass
|
| 161 |
+
class Tunables:
|
| 162 |
+
init_std :float = 1.5
|
| 163 |
+
embeddings_std :float = 4.5e-2
|
| 164 |
+
embeddings_lr_scale: float = 1
|
| 165 |
+
output_mult :float = 1
|
| 166 |
+
query_mult :float = 2
|
| 167 |
+
rope :bool = True
|
| 168 |
+
mask_embs :bool = True # force embeddings corresponding to the input audio padding to a constant value
|
| 169 |
+
downsample_conv: bool = False
|
| 170 |
+
downsample_mean: bool = True
|
| 171 |
+
|
| 172 |
+
codebook_dim: int = 32
|
| 173 |
+
codebook_decay: float = 0.9
|
| 174 |
+
|
| 175 |
+
lr0 :float = .9e-3
|
| 176 |
+
clip_gradient_norm :float = 2
|
| 177 |
+
weight_decay :float = 1e-3
|
| 178 |
+
warmup_steps :float = 850
|
| 179 |
+
|
| 180 |
+
random :bool = False
|
| 181 |
+
|
| 182 |
+
def __post_init__(self):
|
| 183 |
+
# randomize the hyperparams if requested
|
| 184 |
+
if self.random:
|
| 185 |
+
self.init_std = logrand(1, 2)
|
| 186 |
+
self.embeddings_std = logrand(3e-2,6e-2)
|
| 187 |
+
self.embeddings_lr_scale = 2**rand(0,3)
|
| 188 |
+
self.output_mult = 2**rand(-3,3)
|
| 189 |
+
self.query_mult = logrand(1,8)
|
| 190 |
+
self.codebook_dim = int(logrand(30,50))
|
| 191 |
+
self.codebook_decay = logrand(0.86,0.95)
|
| 192 |
+
self.rope = True
|
| 193 |
+
self.mask_embs = True
|
| 194 |
+
self.downsample_mean = True
|
| 195 |
+
|
| 196 |
+
self.lr0 = logrand(.8e-3,1e-3)
|
| 197 |
+
self.clip_gradient_norm = 10**rand(-1,1)
|
| 198 |
+
self.warmup_steps = logrand(700,1000)
|
| 199 |
+
|
| 200 |
+
@staticmethod
|
| 201 |
+
def upgrade(args):
|
| 202 |
+
args = {k:v for k,v in args.items()}
|
| 203 |
+
def old_default(name, value):
|
| 204 |
+
if name not in args: args[name] = value
|
| 205 |
+
old_default('output_mult', 1)
|
| 206 |
+
old_default('query_mult', 1)
|
| 207 |
+
old_default('rope', False)
|
| 208 |
+
old_default('mask_embs', False)
|
| 209 |
+
old_default('downsample_conv', False)
|
| 210 |
+
old_default('downsample_mean', False)
|
| 211 |
+
if 'encoder_depth_ratio' in args: del args['encoder_depth_ratio']
|
| 212 |
+
if 'vq_codes' in args: del args['vq_codes']
|
| 213 |
+
return args
|
| 214 |
+
|
| 215 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 30
|
| 216 |
+
import math
|
| 217 |
+
|
| 218 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 31
|
| 219 |
+
class RQBottleneckTransformer(nn.Module):
|
| 220 |
+
def __init__(self, vq_codes=512, q_depth=12, depth=1, n_head=2, head_width=64, ffn_mult=4,
|
| 221 |
+
codebook_dim=2, threshold_ema_dead_code=2, use_cosine_sim = False, kl_loss_mul=1,
|
| 222 |
+
downsample=1,
|
| 223 |
+
whisper_model_name='tiny.en', tunables=Tunables()):
|
| 224 |
+
super().__init__()
|
| 225 |
+
width = n_head * head_width
|
| 226 |
+
store_attr("codebook_dim,vq_codes,q_depth,n_head,head_width,ffn_mult,depth,use_cosine_sim,downsample,whisper_model_name")
|
| 227 |
+
self.width = width
|
| 228 |
+
self.base_width = 3 * head_width
|
| 229 |
+
self.vq_codes = vq_codes
|
| 230 |
+
self.tunables = tunables
|
| 231 |
+
self.stoks_len = 1500//downsample
|
| 232 |
+
self.stoks_per_sec = self.stoks_len//30
|
| 233 |
+
|
| 234 |
+
qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)
|
| 235 |
+
|
| 236 |
+
self.kl_loss_mul = kl_loss_mul
|
| 237 |
+
|
| 238 |
+
n_mlp = width * ffn_mult
|
| 239 |
+
self.mlp = nn.Sequential(
|
| 240 |
+
nn.Linear(width, n_mlp), nn.GELU(), nn.Linear(n_mlp, width)
|
| 241 |
+
)
|
| 242 |
+
self.mlp_ln = LayerNorm(width)
|
| 243 |
+
|
| 244 |
+
if tunables.downsample_conv:
|
| 245 |
+
self.downsample_conv = nn.Conv1d(width, width, kernel_size=3, stride=downsample, padding=1)
|
| 246 |
+
else:
|
| 247 |
+
self.downsample_conv = None
|
| 248 |
+
|
| 249 |
+
if tunables.mask_embs: vq_codes = vq_codes + 1
|
| 250 |
+
self.rq = ResidualVQ(
|
| 251 |
+
dim = width,
|
| 252 |
+
codebook_size = vq_codes, # codebook size
|
| 253 |
+
decay = tunables.codebook_decay, # the exponential moving average decay, lower means the dictionary will change faster
|
| 254 |
+
commitment_weight = 1., # the weight on the commitment loss
|
| 255 |
+
threshold_ema_dead_code = threshold_ema_dead_code,
|
| 256 |
+
use_cosine_sim = use_cosine_sim,
|
| 257 |
+
codebook_dim = codebook_dim,
|
| 258 |
+
num_quantizers= 1,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
self.ce_lossf = nn.CrossEntropyLoss(ignore_index=-100)
|
| 262 |
+
self.kl_lossf = nn.KLDivLoss(reduction='batchmean')
|
| 263 |
+
|
| 264 |
+
self.positional_embedding = nn.Embedding(1500, width) # FIXME: should be self.stoks_len
|
| 265 |
+
|
| 266 |
+
self.out_blocks = nn.Sequential(*[
|
| 267 |
+
ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(depth)
|
| 268 |
+
])
|
| 269 |
+
self.ln_post = LayerNorm(width)
|
| 270 |
+
|
| 271 |
+
self.whmodel = None
|
| 272 |
+
|
| 273 |
+
self.apply(self.init_transformer)
|
| 274 |
+
self.register_buffer('val_true', torch.zeros(1).cuda())
|
| 275 |
+
self.register_buffer('val_total', torch.zeros(1).cuda())
|
| 276 |
+
|
| 277 |
+
def setup(self, device):
|
| 278 |
+
self.ensure_whisper(device)
|
| 279 |
+
|
| 280 |
+
def init_transformer(self, m):
|
| 281 |
+
if isinstance(m, LinearHead):
|
| 282 |
+
m.no_weight_decay = True
|
| 283 |
+
torch.nn.init.constant_(m.weight, 0)
|
| 284 |
+
elif isinstance(m, QueryHead):
|
| 285 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
| 286 |
+
torch.nn.init.constant_(m.weight, 0)
|
| 287 |
+
elif isinstance(m, nn.Embedding):
|
| 288 |
+
m.no_weight_decay = True
|
| 289 |
+
m.lr_scale = self.tunables.embeddings_lr_scale
|
| 290 |
+
std = self.tunables.embeddings_std
|
| 291 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 292 |
+
elif isinstance(m, nn.Linear):
|
| 293 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
| 294 |
+
std = self.tunables.init_std / m.weight.shape[1]
|
| 295 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
| 296 |
+
if m.bias is not None:
|
| 297 |
+
torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
|
| 298 |
+
elif isinstance(m, nn.LayerNorm):
|
| 299 |
+
m.no_weight_decay = True
|
| 300 |
+
torch.nn.init.constant_(m.bias, 0)
|
| 301 |
+
torch.nn.init.constant_(m.weight, 1)
|
| 302 |
+
|
| 303 |
+
@property
|
| 304 |
+
def device(self):
|
| 305 |
+
return next(self.parameters()).device
|
| 306 |
+
|
| 307 |
+
#
|
| 308 |
+
# training
|
| 309 |
+
#
|
| 310 |
+
@torch.no_grad()
|
| 311 |
+
def extract_teacher(self, samples, input_toks, output_toks):
|
| 312 |
+
embs = self.whmodel[0].encoder(whisper.log_mel_spectrogram(samples))
|
| 313 |
+
teacher_logits = self.whmodel[0].decoder(input_toks, embs)
|
| 314 |
+
# set teacher logits to 0 for padding positions so KLDivLoss ignores them
|
| 315 |
+
teacher_logits[output_toks == -100] = 0
|
| 316 |
+
return embs, teacher_logits
|
| 317 |
+
|
| 318 |
+
def downsample_embeddings(self, x):
|
| 319 |
+
if self.downsample_conv is not None:
|
| 320 |
+
return x[:,::self.downsample] + self.downsample_conv(x.transpose(-1,-2)).transpose(-2,-1)
|
| 321 |
+
elif self.tunables.downsample_mean:
|
| 322 |
+
bs,slen,depth = x.shape
|
| 323 |
+
return x.reshape(bs,slen//self.downsample,self.downsample,depth).mean(-2)
|
| 324 |
+
else:
|
| 325 |
+
return x[:,::self.downsample]
|
| 326 |
+
|
| 327 |
+
def forward(self, samples, mask, input_toks, output_toks):
|
| 328 |
+
embs, teacher_logits = self.extract_teacher(samples, input_toks, output_toks)
|
| 329 |
+
|
| 330 |
+
x = self.downsample_embeddings(embs)
|
| 331 |
+
x = x + self.mlp(self.mlp_ln(x))
|
| 332 |
+
# VQ bottleneck
|
| 333 |
+
quantized, self.indices, self.commit_loss = self.rq(x)
|
| 334 |
+
self.commit_loss = self.commit_loss.mean()
|
| 335 |
+
|
| 336 |
+
x = quantized.repeat_interleave(self.downsample, -2)
|
| 337 |
+
project_out = getattr(self.rq, 'project_out', None) or self.rq.layers[0].project_out
|
| 338 |
+
if self.tunables.mask_embs: x[~mask] = project_out(self.rq.layers[0]._codebook.embed[0,self.vq_codes])
|
| 339 |
+
positions = torch.arange(0, x.shape[-2], dtype=torch.long, device=x.device)
|
| 340 |
+
x = x + self.positional_embedding(positions)
|
| 341 |
+
x = self.ln_post(self.out_blocks(x))
|
| 342 |
+
|
| 343 |
+
logits = self.whmodel[0].decoder(input_toks, x)
|
| 344 |
+
self.ce_loss = self.ce_lossf(logits.view(-1,logits.shape[-1]), output_toks.view(-1))
|
| 345 |
+
self.kl_loss = self.kl_lossf(F.log_softmax(logits, dim=-1), F.softmax(teacher_logits, dim=-1))
|
| 346 |
+
loss = self.ce_loss + self.kl_loss_mul * self.kl_loss + self.commit_loss
|
| 347 |
+
|
| 348 |
+
if not self.training:
|
| 349 |
+
valid_toks = output_toks != -100
|
| 350 |
+
self.val_true += (logits.argmax(-1)[valid_toks] == output_toks[valid_toks]).float().sum()
|
| 351 |
+
self.val_total += valid_toks.float().sum()
|
| 352 |
+
|
| 353 |
+
return x, loss
|
| 354 |
+
|
| 355 |
+
def get_metrics(self):
|
| 356 |
+
metrics = {
|
| 357 |
+
'acc_0': (self.val_true / self.val_total).item(),
|
| 358 |
+
}
|
| 359 |
+
self.val_true[:] = 0
|
| 360 |
+
self.val_total[:] = 0
|
| 361 |
+
return metrics
|
| 362 |
+
|
| 363 |
+
#
|
| 364 |
+
# inference
|
| 365 |
+
#
|
| 366 |
+
@classmethod
|
| 367 |
+
def load_model(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model",
|
| 368 |
+
repo_id=None, filename=None, local_filename=None):
|
| 369 |
+
if repo_id is None and filename is None and local_filename is None:
|
| 370 |
+
if ":" in ref:
|
| 371 |
+
repo_id, filename = ref.split(":", 1)
|
| 372 |
+
else:
|
| 373 |
+
local_filename = ref
|
| 374 |
+
if not local_filename:
|
| 375 |
+
local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 376 |
+
spec = torch.load(local_filename)
|
| 377 |
+
vqmodel = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec.get('tunables', {}))))
|
| 378 |
+
vqmodel.load_state_dict(spec['state_dict'])
|
| 379 |
+
vqmodel.eval()
|
| 380 |
+
return vqmodel
|
| 381 |
+
|
| 382 |
+
def load_checkpoint(self, local_filename):
|
| 383 |
+
spec = torch.load(local_filename, map_location='cpu')
|
| 384 |
+
assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
|
| 385 |
+
state_dict = {k.replace('model.', ''):v
|
| 386 |
+
for k,v in spec['state_dict'].items()}
|
| 387 |
+
self.load_state_dict(state_dict)
|
| 388 |
+
return self
|
| 389 |
+
|
| 390 |
+
def save_model(self, fname, store_parameters=True):
|
| 391 |
+
torch.save(dict(config = self.__stored_args__,
|
| 392 |
+
tunables = dataclasses.asdict(self.tunables),
|
| 393 |
+
state_dict = self.state_dict() if store_parameters else None), fname)
|
| 394 |
+
|
| 395 |
+
def ensure_whisper(self, device):
|
| 396 |
+
# the list wrapper is a hack to make sure the whole of Whisper is not sucked into self.parameters()
|
| 397 |
+
if self.whmodel is None: self.whmodel = [whisper.load_model(self.whisper_model_name, device=device)]
|
| 398 |
+
self.decoding_options = whisper.DecodingOptions()
|
| 399 |
+
multilingual = not self.whisper_model_name.endswith('.en')
|
| 400 |
+
self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual)
|
| 401 |
+
|
| 402 |
+
def quantize(self, embs):
|
| 403 |
+
x = self.downsample_embeddings(embs)
|
| 404 |
+
x = x + self.mlp(self.mlp_ln(x))
|
| 405 |
+
_, stoks, _ = self.rq(x)
|
| 406 |
+
if self.q_depth == 1:
|
| 407 |
+
stoks = stoks.squeeze(-1)
|
| 408 |
+
return stoks
|
| 409 |
+
|
| 410 |
+
def dequantize(self, stoks):
|
| 411 |
+
assert self.q_depth == 1
|
| 412 |
+
assert len(stoks.shape) == 1, "batch processing is not supported"
|
| 413 |
+
if isinstance(stoks, np.ndarray): stoks = torch.tensor(stoks)
|
| 414 |
+
# remove padding
|
| 415 |
+
padding = torch.nonzero(stoks == self.vq_codes)
|
| 416 |
+
if padding.any(): stoks = stoks[:padding[0,0]]
|
| 417 |
+
stoks = F.pad(stoks, (0,self.stoks_len - stoks.shape[-1]), value=self.vq_codes if self.tunables.mask_embs else 0)
|
| 418 |
+
x = self.rq.layers[0]._codebook.embed[0,stoks.to(torch.long).view(-1)]
|
| 419 |
+
x = x.repeat_interleave(self.downsample, -2)
|
| 420 |
+
project_out = getattr(self.rq, 'project_out', None) or self.rq.layers[0].project_out
|
| 421 |
+
x = project_out(x).unsqueeze(0)
|
| 422 |
+
positions = torch.arange(0, x.shape[-2], dtype=torch.long, device=x.device)
|
| 423 |
+
x = x + self.positional_embedding(positions)
|
| 424 |
+
return self.ln_post(self.out_blocks(x))
|
| 425 |
+
|
| 426 |
+
def encode_audio(self, audio):
|
| 427 |
+
if isinstance(audio, str):
|
| 428 |
+
x, sr = torchaudio.load(audio)
|
| 429 |
+
x = torchaudio.transforms.Resample(sr, 16000)(x)[0]
|
| 430 |
+
audio = x.unsqueeze(0)
|
| 431 |
+
return self.encode_mel(whisper.log_mel_spectrogram(audio).to(self.device))
|
| 432 |
+
|
| 433 |
+
def encode_mel(self, mel):
|
| 434 |
+
assert len(mel.shape) == 3, "invalid mel spectrogram shape, expect (batch,chn,time)"
|
| 435 |
+
self.ensure_whisper(self.device)
|
| 436 |
+
n = mel.shape[-1]
|
| 437 |
+
if n > whisper.audio.N_FRAMES:
|
| 438 |
+
padding = 0
|
| 439 |
+
padded = mel[:,:,:whisper.audio.N_FRAMES]
|
| 440 |
+
else:
|
| 441 |
+
padding = -n % whisper.audio.N_FRAMES
|
| 442 |
+
padded = F.pad(mel, (0, padding), value=-1.5)
|
| 443 |
+
embs = self.whmodel[0].encoder(padded)#.to(self.whmodel[0].device))#[:,:n//2]
|
| 444 |
+
stoks = self.quantize(embs)
|
| 445 |
+
if self.tunables.mask_embs:
|
| 446 |
+
return stoks[:,:n//2//self.downsample]
|
| 447 |
+
else:
|
| 448 |
+
return stoks
|
| 449 |
+
|
| 450 |
+
def decode_text(self, stoks, decoding_options=None):
|
| 451 |
+
self.ensure_whisper(self.device)
|
| 452 |
+
if decoding_options is None: decoding_options = self.decoding_options
|
| 453 |
+
embs = self.dequantize(stoks).to(self.whmodel[0].device)
|
| 454 |
+
return self.whmodel[0].decode(embs, decoding_options)
|
| 455 |
+
|
| 456 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 33
|
| 457 |
+
def make_model(size:str, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
|
| 458 |
+
if size == 'base.en-2d-4096c':
|
| 459 |
+
model = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
|
| 460 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
| 461 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
| 462 |
+
return model
|
| 463 |
+
if size == 'base.en-2d-512c':
|
| 464 |
+
model = RQBottleneckTransformer(codebook_dim=32, vq_codes=512, q_depth=1, n_head=8, depth=1,
|
| 465 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
| 466 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
| 467 |
+
return model
|
| 468 |
+
if size == 'base.en-2d-512c-dim64':
|
| 469 |
+
model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=8, depth=1,
|
| 470 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
| 471 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
| 472 |
+
return model
|
| 473 |
+
if size == 'base-2d-512c-dim64':
|
| 474 |
+
model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=8, depth=1,
|
| 475 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
| 476 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
| 477 |
+
return model
|
| 478 |
+
if size == 'base-2d-1024c-dim64':
|
| 479 |
+
model = RQBottleneckTransformer(codebook_dim=64, vq_codes=1024, q_depth=1, n_head=8, depth=1,
|
| 480 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
| 481 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
| 482 |
+
return model
|
| 483 |
+
if size == 'medium-2d-512c-dim64':
|
| 484 |
+
model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=16, depth=1,
|
| 485 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
| 486 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
| 487 |
+
return model
|
| 488 |
+
if size == 'medium-2d-1024c-dim64':
|
| 489 |
+
model = RQBottleneckTransformer(codebook_dim=64, vq_codes=1024, q_depth=1, n_head=16, depth=1,
|
| 490 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
| 491 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
| 492 |
+
return model
|
| 493 |
+
    raise ValueError(f"invalid model size: {size}")
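A rough usage sketch (not part of this commit) of the quantizer defined above: load the published checkpoint, turn a recording into semantic tokens, and decode them back into text through the frozen Whisper decoder. The audio path is a placeholder and a CUDA device is assumed, since the module registers CUDA buffers at construction time.

```python
from whisperspeech.vq_stoks import RQBottleneckTransformer

vqmodel = RQBottleneckTransformer.load_model(
    "collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model").cuda()

stoks = vqmodel.encode_audio("speech-sample.wav")    # (1, N) tensor of semantic token ids
print(stoks.shape, vqmodel.stoks_per_sec)

# round-trip check: dequantize the tokens and let the Whisper decoder transcribe them
print(vqmodel.decode_text(stoks[0])[0].text)
```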
|
whisperspeech/wer_metrics.py
ADDED
|
@@ -0,0 +1,77 @@
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/C. Word error rate metrics.ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = ['librispeech_data', 'DfBuilder', 'WERStats']
|
| 5 |
+
|
| 6 |
+
# %% ../nbs/C. Word error rate metrics.ipynb 2
|
| 7 |
+
import jiwer
|
| 8 |
+
from whisper_normalizer.english import EnglishTextNormalizer
|
| 9 |
+
|
| 10 |
+
import torchaudio
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import pandas as pd
|
| 13 |
+
|
| 14 |
+
# %% ../nbs/C. Word error rate metrics.ipynb 3
|
| 15 |
+
engnorm = EnglishTextNormalizer()
|
| 16 |
+
def whisper_normalize(x):
|
| 17 |
+
if type(x) == list:
|
| 18 |
+
return [engnorm(y) for y in x]
|
| 19 |
+
else:
|
| 20 |
+
return engnorm(x)
|
| 21 |
+
|
| 22 |
+
default_transform = jiwer.transforms.Compose([
|
| 23 |
+
jiwer.transforms.ToLowerCase(),
|
| 24 |
+
jiwer.transforms.ExpandCommonEnglishContractions(),
|
| 25 |
+
whisper_normalize,
|
| 26 |
+
jiwer.transforms.RemoveMultipleSpaces(),
|
| 27 |
+
jiwer.transforms.Strip(),
|
| 28 |
+
jiwer.transforms.RemovePunctuation(),
|
| 29 |
+
jiwer.transforms.ReduceToListOfListOfWords(),
|
| 30 |
+
])
|
| 31 |
+
|
| 32 |
+
# %% ../nbs/C. Word error rate metrics.ipynb 5
|
| 33 |
+
def librispeech_data(datadir, sample_rate=16000):
|
| 34 |
+
for file in Path(datadir).rglob('*.txt'):
|
| 35 |
+
for line in file.read_text().split('\n'):
|
| 36 |
+
if not line: continue
|
| 37 |
+
idx, text = line.split(" ", 1)
|
| 38 |
+
x, sr = torchaudio.load((file.parent/idx).with_suffix('.flac'))
|
| 39 |
+
if sr != sample_rate:
|
| 40 |
+
                x = torchaudio.transforms.Resample(sr, sample_rate)(x)
|
| 41 |
+
yield x, text
|
| 42 |
+
|
| 43 |
+
# %% ../nbs/C. Word error rate metrics.ipynb 6
|
| 44 |
+
class DfBuilder:
|
| 45 |
+
def __init__(self):
|
| 46 |
+
self.data = {}
|
| 47 |
+
|
| 48 |
+
def push(self, **kwargs):
|
| 49 |
+
for k,v in kwargs.items():
|
| 50 |
+
if k not in self.data:
|
| 51 |
+
self.data[k] = [v]
|
| 52 |
+
else:
|
| 53 |
+
self.data[k].append(v)
|
| 54 |
+
|
| 55 |
+
def df(self):
|
| 56 |
+
return pd.DataFrame(self.data)
|
| 57 |
+
|
| 58 |
+
# %% ../nbs/C. Word error rate metrics.ipynb 7
|
| 59 |
+
class WERStats(DfBuilder):
|
| 60 |
+
def __init__(self, transform=default_transform):
|
| 61 |
+
super().__init__()
|
| 62 |
+
self.reference_transform = transform
|
| 63 |
+
self.hypothesis_transform = transform
|
| 64 |
+
|
| 65 |
+
def push_sample(self, snd, gt_text, text, idx=None):
|
| 66 |
+
if snd is not None: self.push(secs = snd.shape[-1]/16000)
|
| 67 |
+
diff = jiwer.process_words(gt_text, text, reference_transform=self.reference_transform, hypothesis_transform=self.hypothesis_transform)
|
| 68 |
+
self.push(
|
| 69 |
+
idx = idx,
|
| 70 |
+
gt_text = gt_text,
|
| 71 |
+
text = text,
|
| 72 |
+
wer = diff.wer,
|
| 73 |
+
mer = diff.mer,
|
| 74 |
+
wil = diff.wil,
|
| 75 |
+
wip = diff.wip,
|
| 76 |
+
)
|
| 77 |
+
return diff
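A short sketch of collecting WER statistics with the class above (reference/hypothesis strings invented): push_sample runs both texts through the Whisper English normalizer and the jiwer transforms before scoring, and df() turns the accumulated pushes into a pandas DataFrame.

```python
from whisperspeech.wer_metrics import WERStats

stats = WERStats()
stats.push_sample(None, "Hello world, how are you?", "hello world how are you", idx=0)
stats.push_sample(None, "The quick brown fox jumps.", "the quick brown box jumps", idx=1)

df = stats.df()
print(df[["idx", "wer", "mer"]])
print("mean WER:", df.wer.mean())
```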
|
whisperspeech/wh_transcribe.py
ADDED
|
@@ -0,0 +1,146 @@
| 1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/2A. Whisper quantization dataset preparation.ipynb.
|
| 2 |
+
|
| 3 |
+
# %% auto 0
|
| 4 |
+
__all__ = []
|
| 5 |
+
|
| 6 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 3
|
| 7 |
+
import os
|
| 8 |
+
import io
|
| 9 |
+
import time
|
| 10 |
+
import torch
|
| 11 |
+
import torchaudio
|
| 12 |
+
|
| 13 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 4
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import json
|
| 16 |
+
from fastprogress import progress_bar, master_bar
|
| 17 |
+
import numpy as np
|
| 18 |
+
import random
|
| 19 |
+
|
| 20 |
+
import whisper
|
| 21 |
+
|
| 22 |
+
from torch import nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from torch.utils.data.dataloader import DataLoader
|
| 25 |
+
|
| 26 |
+
from fastcore.script import *
|
| 27 |
+
|
| 28 |
+
from . import vad
|
| 29 |
+
import webdataset as wds
|
| 30 |
+
|
| 31 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 9
|
| 32 |
+
# let's make it a bit more conservative
|
| 33 |
+
# with full 30 second chunks it sometimes misses a small part of the transcript
|
| 34 |
+
def random_cutter(dur):
|
| 35 |
+
if random.random() < 0.5:
|
| 36 |
+
return dur > 28 * (random.random()*0.95+0.05)
|
| 37 |
+
else:
|
| 38 |
+
return dur > 28
|
| 39 |
+
|
| 40 |
+
def chunk_merger(segments, should_cut=lambda x: x > 28):
|
| 41 |
+
if len(segments) == 0: return segments
|
| 42 |
+
curr_start = segments[0][0]
|
| 43 |
+
curr_end = 0
|
| 44 |
+
merged = []
|
| 45 |
+
|
| 46 |
+
for ts,te in segments:
|
| 47 |
+
if should_cut(te - curr_start) and curr_end - curr_start > 0:
|
| 48 |
+
merged.append((curr_start, curr_end))
|
| 49 |
+
curr_start = ts
|
| 50 |
+
curr_end = te
|
| 51 |
+
merged.append((curr_start, curr_end))
|
| 52 |
+
return merged
|
| 53 |
+
|
| 54 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 18
|
| 55 |
+
def merge_in(*datasets):
|
| 56 |
+
"""Merge multiple datasets into the current one returning samples with the union of keys.
|
| 57 |
+
|
| 58 |
+
It requires (and validates) all datasets to have the same ordering of keys so you have
|
| 59 |
+
to use it before any sample shuffling. Shard shuffling is ok.
|
| 60 |
+
"""
|
| 61 |
+
def merge_loop(main_samples):
|
| 62 |
+
for samples in zip(*[main_samples]+[iter(x) for x in datasets]):
|
| 63 |
+
key = samples[0]['__key__']
|
| 64 |
+
news = {}
|
| 65 |
+
for s in samples:
|
| 66 |
+
assert s['__key__'] == key
|
| 67 |
+
news.update(s)
|
| 68 |
+
yield news
|
| 69 |
+
return merge_loop
|
| 70 |
+
|
| 71 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 19
|
| 72 |
+
import copy
|
| 73 |
+
|
| 74 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 20
|
| 75 |
+
# a workaround for https://github.com/webdataset/webdataset/issues/297
|
| 76 |
+
# should be possible to use ds.compose here
|
| 77 |
+
def wds_compose(ds, *args):
|
| 78 |
+
ds = copy.copy(ds)
|
| 79 |
+
ds.pipeline = copy.copy(ds.pipeline)
|
| 80 |
+
for f in args:
|
| 81 |
+
ds.append(f)
|
| 82 |
+
return ds
|
| 83 |
+
|
| 84 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 24
|
| 85 |
+
def split_to_chunks(stream, pad_to_seconds=30, random_shift=False):
|
| 86 |
+
for s in stream:
|
| 87 |
+
audio, sr = s.get('flac', s.get('wav', (None, None)))
|
| 88 |
+
if audio is None:
|
| 89 |
+
print(f"warning: '{s['__key__']}' does not contain an audio file")
|
| 90 |
+
continue
|
| 91 |
+
imax = len(s['vad.npy']) - 1
|
| 92 |
+
for i,(ts,te) in enumerate(s['vad.npy']):
|
| 93 |
+
samples = audio[0,int(ts*sr):int(te*sr)]
|
| 94 |
+
if pad_to_seconds is not None:
|
| 95 |
+
padding = pad_to_seconds*sr-samples.shape[-1]
|
| 96 |
+
lpad = random.randint(0, padding) if random_shift else 0
|
| 97 |
+
samples = F.pad(samples, (lpad, padding-lpad))
|
| 98 |
+
yield {"__key__": s['__key__'] + f"_{i:03d}",
|
| 99 |
+
"__url__": s['__url__'],
|
| 100 |
+
"i": i, "imax": imax,
|
| 101 |
+
"tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
|
| 102 |
+
"lpad": lpad, "rpad": padding-lpad,
|
| 103 |
+
"lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
|
| 104 |
+
"samples": samples, "sample_rate": sr}
|
| 105 |
+
|
| 106 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 38
|
| 107 |
+
def flac_to_txt_name(input, model_size):
|
| 108 |
+
return input.rsplit("/", 1)[1].replace('flac', f'{model_size}-txt') + ".gz"
|
| 109 |
+
|
| 110 |
+
@call_parse
|
| 111 |
+
def process_shard(
|
| 112 |
+
input:str, # input shard URL/path
|
| 113 |
+
output:str=None, # output shard URL/path
|
| 114 |
+
bs:int=None, # batch size (16 uses around 11GB of VRAM)
|
| 115 |
+
n_samples:int=None, # limit the number of samples (useful for quick benchmarking)
|
| 116 |
+
whisper_model:str="base.en" # Whisper model size
|
| 117 |
+
):
|
| 118 |
+
if output is None: output = flac_to_txt_name(input, whisper_model)
|
| 119 |
+
if bs is None: bs = 16
|
| 120 |
+
if n_samples is None: n_samples = 'noinfer'
|
| 121 |
+
else: n_samples = n_samples // bs
|
| 122 |
+
|
| 123 |
+
ds = wds_compose(vad.load_dataset(input),
|
| 124 |
+
merge_in(wds.WebDataset(vad.flac_to_vad_name(input)).decode()),
|
| 125 |
+
wds.map_dict(**{"vad.npy":chunk_merger}),
|
| 126 |
+
split_to_chunks,
|
| 127 |
+
wds.to_tuple('__key__', 'samples'),
|
| 128 |
+
wds.batched(bs),
|
| 129 |
+
)
|
| 130 |
+
dl = DataLoader(ds, num_workers=2, batch_size=None)
|
| 131 |
+
|
| 132 |
+
whmodel = whisper.load_model(whisper_model)
|
| 133 |
+
decoding_options = whisper.DecodingOptions(language='en')
|
| 134 |
+
|
| 135 |
+
tmp = output+".tmp"
|
| 136 |
+
with wds.TarWriter(tmp) as sink:
|
| 137 |
+
for keys, samples in progress_bar(dl, total=n_samples):
|
| 138 |
+
with torch.no_grad():
|
| 139 |
+
embs = whmodel.encoder(whisper.log_mel_spectrogram(samples).cuda())
|
| 140 |
+
decs = whmodel.decode(embs, decoding_options)
|
| 141 |
+
for key, dec in zip(keys, decs):
|
| 142 |
+
sink.write({
|
| 143 |
+
"__key__": key,
|
| 144 |
+
"txt": dec.text,
|
| 145 |
+
})
|
| 146 |
+
os.rename(tmp, output)
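A tiny worked example of chunk_merger above (segment times invented): consecutive VAD segments are merged until a chunk would stretch past the ~28 second cutoff, at which point a new chunk is started.

```python
from whisperspeech.wh_transcribe import chunk_merger

segments = [(0.0, 10.0), (11.0, 20.0), (21.0, 27.0), (28.5, 40.0)]
print(chunk_merger(segments))
# the first three segments together span 0.0-27.0s (under 28s); the last one would
# stretch the chunk to 40s, so it starts a new one:
# [(0.0, 27.0), (28.5, 40.0)]
```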
|