initial commit
- .gitignore +9 -0
- app.py +278 -0
- ckpt/encodec_voicecraft.pt +3 -0
- encodec/LICENSE +21 -0
- encodec/__init__.py +6 -0
- encodec/distrib.py +125 -0
- encodec/encodec.py +171 -0
- encodec/modules/__init__.py +17 -0
- encodec/modules/conv.py +342 -0
- encodec/modules/lstm.py +27 -0
- encodec/modules/norm.py +32 -0
- encodec/modules/seanet.py +352 -0
- encodec/quantization/__init__.py +7 -0
- encodec/quantization/core_vq.py +370 -0
- encodec/quantization/vq.py +95 -0
- requirements.txt +5 -0
.gitignore
ADDED
@@ -0,0 +1,9 @@
**/..git/*
__pycache__/
.DS_Store
._.DS_Store
.ipynb_checkpoints/
.vscode/
*.egg-info/
.pytest_cache
*.ipynb_checkpoints/
app.py
ADDED
@@ -0,0 +1,278 @@
import os
import sys
import uuid
from pathlib import Path
from contextlib import contextmanager

import numpy as np
import torch
import matplotlib.pyplot as plt
import gradio as gr
from scipy.io.wavfile import write as wavwrite

from audiotools import AudioSignal
from audioseal import AudioSeal

# allow local imports of your encodec folder
@contextmanager
def chdir(path: str):
    origin = Path().absolute()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(origin)


_path = Path(__file__).parent
sys.path.insert(0, str(_path))
with chdir(_path):
    from encodec import Encodec


OUT_DIR = _path / "gradio-outputs"
OUT_DIR.mkdir(exist_ok=True)

LOUDNESS_DB = -16.
SAMPLE_RATE = 48_000
ENCODEC_SAMPLE_RATE = 16_000
AUDIOSEAL_SAMPLE_RATE = 16_000

# load codec
config = {
    "sample_rate": 16_000,
    "target_bandwidths": [2.2],
    "channels": 1,
    "causal": False,
    "codebook_size": 2048,
    "n_filters": 64,
    "model_norm": "weight_norm",
    "audio_normalize": False,
    "true_skip": True,
    "ratios": [8, 5, 4, 2],
    "encoder_kwargs": {"pad_mode": "constant"},
    "decoder_kwargs": {"pad_mode": "constant"},
}
codec = Encodec(**config)
codec.load_state_dict(torch.load("ckpt/encodec_voicecraft.pt", map_location="cpu"))
codec.eval()
for p in codec.parameters(): p.requires_grad_(False)
codec.set_target_bandwidth(2.2)

# watermark models
embedder = AudioSeal.load_generator("audioseal_wm_16bits")
detector = AudioSeal.load_detector("audioseal_detector_16bits")


@torch.no_grad()
def encode(signal: AudioSignal, codec: torch.nn.Module):
    n_b, n_ch, n_s = signal.shape
    sr = signal.sample_rate
    loud_db = signal.loudness()
    x = signal.clone().resample(ENCODEC_SAMPLE_RATE).audio_data
    x = x.reshape(n_b * n_ch, 1, -1)
    codes, *_ = codec.encode(x)
    return codes, n_b, n_ch, n_s, sr, loud_db

@torch.no_grad()
def decode(codes, n_b, n_ch, n_s, sr, loud_db, codec):
    x = codec.decode(codes).reshape(n_b, n_ch, -1)
    sig = AudioSignal(x, sample_rate=ENCODEC_SAMPLE_RATE)
    sig = sig.resample(sr)
    sig.audio_data = sig.audio_data[..., :n_s]
    sig.audio_data = torch.nn.functional.pad(
        sig.audio_data, (0, max(0, n_s - sig.signal_length))
    )
    return sig.normalize(loud_db)

@torch.no_grad()
def split_bands(signal: AudioSignal, sample_rate: float = ENCODEC_SAMPLE_RATE):
    nyq = sample_rate // 2
    high = signal.clone().high_pass(cutoffs=int(nyq * 0.95), zeros=51)
    low = signal.clone().low_pass(cutoffs=int(nyq * 1.05), zeros=51)
    loud_db = low.loudness()
    low = low.resample(sample_rate)
    return low, high, loud_db

@torch.no_grad()
def merge_bands(low, high, loud_db):
    low = low.clone().to(high.device).resample(high.sample_rate)
    low.audio_data = low.audio_data[..., :high.signal_length]
    low.audio_data = torch.nn.functional.pad(
        low.audio_data, (0, max(0, high.signal_length - low.signal_length))
    )
    return low.normalize(loud_db) + high

@torch.no_grad()
def attack(signal: AudioSignal, codec, split_rate_hz=AUDIOSEAL_SAMPLE_RATE):
    if split_rate_hz:
        low, high, loud_db = split_bands(signal, split_rate_hz)
        low = decode(*encode(low, codec), codec)
        return merge_bands(low, high, loud_db)
    else:
        return decode(*encode(signal, codec), codec)

@torch.no_grad()
def embed(signal: AudioSignal, embedder: torch.nn.Module):
    orig_ch, orig_sr = signal.num_channels, signal.sample_rate
    sig = signal.clone().resample(SAMPLE_RATE)
    if orig_ch > 1:
        b, c, n = sig.audio_data.shape
        sig.audio_data = sig.audio_data.reshape(b * c, 1, n)
    low, high, loud = split_bands(sig.clone(), AUDIOSEAL_SAMPLE_RATE)
    wm = embedder.get_watermark(low.audio_data, AUDIOSEAL_SAMPLE_RATE)
    low.audio_data = low.audio_data + wm
    merged = merge_bands(low, high, loud)
    if orig_ch > 1:
        b2, c2, n2 = merged.audio_data.shape
        merged.audio_data = merged.audio_data.reshape(-1, orig_ch * c2, n2)
    return merged.resample(orig_sr)

@torch.no_grad()
def detect(signal: AudioSignal, detector: torch.nn.Module):
    sig = signal.clone().to_mono().resample(AUDIOSEAL_SAMPLE_RATE)
    result, _ = detector.forward(sig.audio_data, sample_rate=AUDIOSEAL_SAMPLE_RATE)
    return result[0, 1, :].detach().cpu().numpy()

def pipeline(audio_tuple):

    sr, audio_np = audio_tuple

    print("GOT SR", sr)
    print("GOT AUDIO", audio_np.shape)

    if audio_np.ndim == 1:
        audio_np = audio_np[None, None, :]
    else:
        audio_np = np.transpose(audio_np, (1, 0))[None, ...]

    print("FORMATTED AUDIO", audio_np.shape)

    sig = AudioSignal(torch.from_numpy(audio_np).float(), sample_rate=sr)
    orig_loud = sig.loudness()
    sig = sig.to_mono().resample(SAMPLE_RATE).normalize(LOUDNESS_DB).ensure_max_of_audio()

    print("REFORMATTED AUDIO")
    print(sig)

    # Detect
    scores = detect(sig, detector)

    # Embed + detect without attack
    wm_sig = embed(sig.clone(), embedder).normalize(LOUDNESS_DB).ensure_max_of_audio()
    scores_clean = detect(wm_sig, detector)

    print(np.mean(scores_clean))

    # Attack + detect
    att_sig = attack(wm_sig.clone(), codec).normalize(LOUDNESS_DB).ensure_max_of_audio()
    scores_att = detect(att_sig, detector)

    print(np.mean(scores_att))

    # Match loudness prior to writing
    wm_sig.normalize(orig_loud).ensure_max_of_audio()
    att_sig.normalize(orig_loud).ensure_max_of_audio()

    # Write audio files to disk
    uid = uuid.uuid4().hex
    wm_path = OUT_DIR / f"watermarked_{uid}.wav"
    att_path = OUT_DIR / f"attacked_{uid}.wav"

    wm_arr = wm_sig.audio_data.squeeze().numpy()
    att_arr = att_sig.audio_data.squeeze().numpy()
    wavwrite(str(wm_path), SAMPLE_RATE, wm_arr)
    wavwrite(str(att_path), SAMPLE_RATE, att_arr)

    # Plot scores with waveform background
    # Plot: waveform on top, detection scores on bottom
    sig_bg = sig.clone().to_mono().resample(AUDIOSEAL_SAMPLE_RATE)
    wav = sig_bg.audio_data.squeeze().numpy()
    N = len(scores)
    if wav.shape[0] < N:
        wav = np.pad(wav, (0, N - wav.shape[0]), mode="constant")
    else:
        wav = wav[:N]

    fig, (ax_wav, ax_score) = plt.subplots(2, 1, sharex=True, figsize=(8, 6))
    # Top: waveform (no labels)
    ax_wav.plot(wav, alpha=0.3)
    ax_wav.axis("off")

    # Bottom: detection scores
    ax_score.plot(scores, label="No watermark", color="blue")
    ax_score.plot(scores_clean, label="Watermark (no attack)", color="green")
    ax_score.plot(scores_att, label="Watermark (codec attack)", color="red")
    ax_score.set_xlabel("Frame Index")
    ax_score.set_ylabel("Detection Score")
    ax_score.set_ylim(-0.05, 1.05)
    ax_score.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
    ax_score.legend()

    plt.tight_layout()
    plot_path = OUT_DIR / f"detection_plot_{uid}.png"
    fig.savefig(str(plot_path), format="png")
    plt.close(fig)

    return str(wm_path), str(att_path), str(plot_path)

demo = gr.Interface(
    fn=pipeline,
    inputs=gr.Audio(sources=["upload"], type="numpy", label="Upload Input Audio"),
    outputs=[
        gr.Audio(type="filepath", label="Watermarked Audio"),
        gr.Audio(type="filepath", label="Attacked Audio"),
        gr.Image(type="filepath", label="Detection Scores Plot"),
    ],
    title="Watermark Stress Test",
    description="""
This is an educational demonstration of state-of-the-art audio watermark performance under codec processing. Upload any (speech) audio file to test watermark performance before and after processing with a low-bitrate neural codec [1].

For this demo, we use the AudioSeal [2] watermark, which is well documented, open source, and provides state-of-the-art localized detection performance. Both the watermark and codec operate at 16kHz, meaning all frequencies above 8kHz are left unaltered. To ensure consistent watermark performance, we normalize audio to -16 dB LUFS and downmix to mono prior to embedding.

[1] https://github.com/jasonppy/VoiceCraft
[2] https://github.com/facebookresearch/audioseal
""",
    article="""
The citation info for our corresponding paper is:

```
@inproceedings{deepwatermarksareshallow,
  author = {Patrick O'Reilly and Zeyu Jin and Jiaqi Su and Bryan Pardo},
  title = {Deep Audio Watermarks are Shallow: Limitations of Post-Hoc Watermarking Techniques for Speech},
  booktitle = {ICLR Workshop on GenAI Watermarking},
  year = {2025}
}
```

For the VoiceCraft codec:

```
@article{voicecraft,
  author={Puyuan Peng and Po-Yao Huang and Daniel Li and Abdelrahman Mohamed and David Harwath},
  year={2024},
  title={VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild},
  journal={arXiv preprint arXiv:2403.16973v1},
}
```

And for the AudioSeal watermark:

```
@article{audioseal,
  title={Proactive Detection of Voice Cloning with Localized Watermarking},
  author={San Roman, Robin and Fernandez, Pierre and Elsahar, Hady and Défossez, Alexandre and Furon, Teddy and Tran, Tuan},
  journal={International Conference on Machine Learning (ICML)},
  year={2024}
}
```
""",
    allow_flagging=False,
)

if __name__ == "__main__":
    demo.launch(share=True)
ckpt/encodec_voicecraft.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:42b224ba5b193a8fb66eb692fe377831bb14b1dcf556638db7afc5d108099bfb
size 235735922
encodec/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Meta Platforms, Inc. and affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
encodec/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .encodec import Encodec
encodec/distrib.py
ADDED
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch distributed utilities."""
import typing as tp

import torch


def rank():
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def is_distributed():
    return world_size() > 1


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)


def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # utility function to check that the number of params in all workers is the same,
    # and thus avoid a deadlock with distributed all reduce.
    if not is_distributed() or not params:
        return
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If not all the workers have the same number, for at least one of them,
        # this inequality will be verified.
        raise RuntimeError(
            f"Mismatch in number of params: ours is {len(params)}, "
            "at least one worker has a different one."
        )


def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the tensors from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()


def sync_buffer(buffers, average=True):
    """
    Sync grad for buffers. If average is False, broadcast instead of averaging.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(
                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True
                )
            else:
                handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            buffer.data /= world_size()  # world_size is a function; calling it fixes the division


def sync_grad(params):
    """
    Simpler alternative to DistributedDataParallel, that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(
                p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True
            )
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size()


def average_metrics(metrics: tp.Dict[str, float], count=1.0):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
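For reference, a minimal sketch (illustrative, not part of this commit) of how `sync_grad` slots into a training step, as its docstring suggests; the model and optimizer below are placeholders, and the call is a no-op unless `torch.distributed` has been initialized:

```python
# Sketch: gradient averaging with encodec.distrib.sync_grad (placeholder model/optimizer).
import torch
from encodec import distrib

model = torch.nn.Linear(16, 16)              # stand-in for the network being trained
optimizer = torch.optim.Adam(model.parameters())

x = torch.randn(8, 16)
loss = model(x).pow(2).mean()
loss.backward()

# Average gradients across workers; returns immediately in a single-process run.
distrib.sync_grad(model.parameters())
optimizer.step()
optimizer.zero_grad()
```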
encodec/encodec.py
ADDED
@@ -0,0 +1,171 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp

import torch

from .modules import SEANetDecoder
from .modules import SEANetEncoder
from .quantization import ResidualVectorQuantizer

################################################################################
# Encodec neural audio codec
################################################################################


class Encodec(torch.nn.Module):
    """
    Encodec neural audio codec proposed in "High Fidelity Neural Audio
    Compression" (https://arxiv.org/abs/2210.13438) by Défossez et al.
    """

    def __init__(
        self,
        sample_rate: int,
        channels: int,
        causal: bool,
        model_norm: str,
        target_bandwidths: tp.Sequence[float],
        audio_normalize: bool,
        ratios: tp.List[int] = (8, 5, 4, 2),
        codebook_size: int = 1024,
        n_filters: int = 32,
        true_skip: bool = False,
        encoder_kwargs: tp.Dict = None,
        decoder_kwargs: tp.Dict = None,
    ):
        """
        Parameters
        ----------
        sample_rate : int
            Audio sample rate in Hz.
        channels : int
            Number of audio channels expected at input.
        causal : bool
            Whether to use causal convolution layers in encoder/decoder.
        model_norm : str
            Type of normalization to use in encoder/decoder.
        target_bandwidths : tp.Sequence[float]
            List of target bandwidths in kb/s.
        audio_normalize : bool
            Whether to normalize encoded and decoded audio segments using
            simple scaling factors
        ratios : tp.List[int], optional
            List of downsampling ratios used in encoder/decoder, by default (8, 5, 4, 2)
        codebook_size : int, optional
            Size of residual vector quantizer codebooks, by default 1024
        n_filters : int, optional
            Number of filters used in encoder/decoder, by default 32
        true_skip : bool, optional
            Whether to use true skip connections in encoder/decoder rather than
            convolutional skip connections, by default False
        """
        super().__init__()

        encoder_kwargs = encoder_kwargs or {}
        decoder_kwargs = decoder_kwargs or {}

        self.encoder = SEANetEncoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **encoder_kwargs,
        )
        self.decoder = SEANetDecoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **decoder_kwargs,
        )

        n_q = int(
            1000
            * target_bandwidths[-1]
            // (math.ceil(sample_rate / self.encoder.hop_length) * 10)
        )
        self.n_q = n_q  # Maximum number of quantizers
        self.quantizer = ResidualVectorQuantizer(
            dimension=self.encoder.dimension,
            n_q=n_q,
            bins=codebook_size,
        )

        self.sample_rate = sample_rate
        self.normalize = audio_normalize
        self.channels = channels

        self.frame_rate = math.ceil(self.sample_rate / math.prod(self.encoder.ratios))

        self.target_bandwidths = target_bandwidths
        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
        assert (
            2**self.bits_per_codebook == self.quantizer.bins
        ), "quantizer bins must be a power of 2."

        self.bandwidth = self.target_bandwidths[-1]

    def set_target_bandwidth(self, bandwidth: float):
        """
        Set the target bandwidth for the codec by adjusting the
        number of residual vector quantizers used
        """
        if bandwidth not in self.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. "
                f"Select one of {self.target_bandwidths}."
            )
        self.bandwidth = bandwidth

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map a given audio waveform `x` to discrete residual latent codes.

        Parameters
        ----------
        x : torch.Tensor
            Audio waveform of shape `(n_batch, n_channels, n_samples)`.

        Returns
        -------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`.
        """
        assert x.dim() == 3
        _, channels, length = x.shape
        assert 0 < channels <= 2

        z = self.encoder(x)
        codes, z_O, z_o = self.quantizer.encode(z, self.frame_rate, self.bandwidth)
        codes = codes.transpose(0, 1)

        return codes, z_O, z_o, z

    def decode(self, codes: torch.Tensor):
        """
        Decode quantized latents to obtain waveform audio.

        Parameters
        ----------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`.

        Returns
        -------
        out : torch.Tensor
            Tensor of shape `(n_batch, n_channels, n_samples)`.
        """
        codes = codes.transpose(0, 1)
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)

        return out
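For reference, a minimal sketch (not part of the commit itself) of the encode/decode round-trip that app.py builds its "codec attack" from, using the same configuration and checkpoint path as above; the 3-second random input is an assumption for illustration:

```python
# Sketch: encode/decode round-trip with the VoiceCraft Encodec checkpoint used in app.py.
import torch
from encodec import Encodec

config = {
    "sample_rate": 16_000, "target_bandwidths": [2.2], "channels": 1,
    "causal": False, "codebook_size": 2048, "n_filters": 64,
    "model_norm": "weight_norm", "audio_normalize": False, "true_skip": True,
    "ratios": [8, 5, 4, 2],
    "encoder_kwargs": {"pad_mode": "constant"}, "decoder_kwargs": {"pad_mode": "constant"},
}
codec = Encodec(**config)
codec.load_state_dict(torch.load("ckpt/encodec_voicecraft.pt", map_location="cpu"))
codec.eval()
codec.set_target_bandwidth(2.2)

with torch.no_grad():
    wav = torch.randn(1, 1, 3 * 16_000)   # (n_batch, n_channels, n_samples), assumed 3 s at 16 kHz
    codes, *_ = codec.encode(wav)          # codes: (n_batch, n_codebooks, n_frames)
    recon = codec.decode(codes)            # reconstructed waveform (n_batch, n_channels, n_samples)
```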
encodec/modules/__init__.py
ADDED
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch modules."""
from .conv import NormConv1d
from .conv import NormConv2d
from .conv import NormConvTranspose1d
from .conv import NormConvTranspose2d
from .conv import pad1d
from .conv import SConv1d
from .conv import SConvTranspose1d
from .conv import unpad1d
from .lstm import SLSTM
from .seanet import SEANetDecoder
from .seanet import SEANetEncoder
encodec/modules/conv.py
ADDED
@@ -0,0 +1,342 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Convolutional layers wrappers and utilities."""
import math
import typing as tp
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm
from torch.nn.utils import weight_norm

from .norm import ConvLayerNorm


CONV_NORMALIZATIONS = frozenset(
    [
        "none",
        "weight_norm",
        "spectral_norm",
        "time_layer_norm",
        "layer_norm",
        "time_group_norm",
    ]
)


def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
    assert norm in CONV_NORMALIZATIONS
    if norm == "weight_norm":
        return weight_norm(module)
    elif norm == "spectral_norm":
        return spectral_norm(module)
    else:
        # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other
        # choice doesn't need reparametrization.
        return module


def get_norm_module(
    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
) -> nn.Module:
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == "layer_norm":
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    elif norm == "time_group_norm":
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step!
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zero",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]


class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.convtr = apply_parametrization_norm(
            nn.ConvTranspose1d(*args, **kwargs), norm
        )
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.convtr = apply_parametrization_norm(
            nn.ConvTranspose2d(*args, **kwargs), norm
        )
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class SConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        pad_mode: str = "reflect",
    ):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn(
                "SConv1d has been initialized with stride > 1 and dilation > 1"
                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
            )
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (
            kernel_size - 1
        ) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(
            x, kernel_size, stride, padding_total
        )
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(
                x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
            )
        return self.conv(x)


class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        norm: str = "none",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Dict[str, tp.Any] = {},
    ):
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y
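As a quick illustration of the padding logic above (a sketch, not part of the commit): `SConv1d` adds just enough padding that the strided convolution always sees full windows, so with dilation 1 it produces `ceil(T / stride)` output frames for any input length `T`; the concrete channel counts and lengths below are arbitrary.

```python
# Sketch: SConv1d's automatic padding keeps the frame count at ceil(T / stride).
import math
import torch
from encodec.modules import SConv1d

conv = SConv1d(1, 8, kernel_size=7, stride=4, causal=True)
for T in (100, 101, 103):
    y = conv(torch.randn(1, 1, T))
    assert y.shape[-1] == math.ceil(T / 4)
```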
encodec/modules/lstm.py
ADDED
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""LSTM layers module."""
from torch import nn


class SLSTM(nn.Module):
    """
    LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """

    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y
encodec/modules/norm.py
ADDED
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Normalization modules."""
import typing as tp

import torch
from torch import nn


class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.
    """

    def __init__(
        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
    ):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):

        assert x.ndim == 3  # (n_batch, n_channels, n_samples)

        x = x.transpose(1, 2)
        x = super().forward(x)
        x = x.transpose(1, 2)

        return x
encodec/modules/seanet.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Encodec SEANet-based encoder and decoder implementation."""
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from . import SConv1d
|
13 |
+
from . import SConvTranspose1d
|
14 |
+
from . import SLSTM
|
15 |
+
|
16 |
+
|
17 |
+
class SEANetResnetBlock(nn.Module):
|
18 |
+
"""Residual block from SEANet model.
|
19 |
+
Args:
|
20 |
+
dim (int): Dimension of the input/output
|
21 |
+
kernel_sizes (list): List of kernel sizes for the convolutions.
|
22 |
+
dilations (list): List of dilations for the convolutions.
|
23 |
+
activation (str): Activation function.
|
24 |
+
activation_params (dict): Parameters to provide to the activation function
|
25 |
+
norm (str): Normalization method.
|
26 |
+
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
27 |
+
causal (bool): Whether to use fully causal convolution.
|
28 |
+
pad_mode (str): Padding mode for the convolutions.
|
29 |
+
compress (int): Reduced dimensionality in residual branches (from Demucs v3)
|
30 |
+
true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
dim: int,
|
36 |
+
kernel_sizes: tp.List[int] = [3, 1],
|
37 |
+
dilations: tp.List[int] = [1, 1],
|
38 |
+
activation: str = "ELU",
|
39 |
+
activation_params: dict = {"alpha": 1.0},
|
40 |
+
norm: str = "weight_norm",
|
41 |
+
norm_params: tp.Dict[str, tp.Any] = {},
|
42 |
+
causal: bool = False,
|
43 |
+
pad_mode: str = "reflect",
|
44 |
+
compress: int = 2,
|
45 |
+
true_skip: bool = True,
|
46 |
+
):
|
47 |
+
super().__init__()
|
48 |
+
assert len(kernel_sizes) == len(
|
49 |
+
dilations
|
50 |
+
), "Number of kernel sizes should match number of dilations"
|
51 |
+
act = getattr(nn, activation)
|
52 |
+
hidden = dim // compress
|
53 |
+
block = []
|
54 |
+
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
|
55 |
+
in_chs = dim if i == 0 else hidden
|
56 |
+
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
|
57 |
+
block += [
|
58 |
+
act(**activation_params),
|
59 |
+
SConv1d(
|
60 |
+
in_chs,
|
61 |
+
out_chs,
|
62 |
+
kernel_size=kernel_size,
|
63 |
+
dilation=dilation,
|
64 |
+
norm=norm,
|
65 |
+
norm_kwargs=norm_params,
|
66 |
+
causal=causal,
|
67 |
+
pad_mode=pad_mode,
|
68 |
+
),
|
69 |
+
]
|
70 |
+
self.block = nn.Sequential(*block)
|
71 |
+
self.shortcut: nn.Module
|
72 |
+
if true_skip:
|
73 |
+
self.shortcut = nn.Identity()
|
74 |
+
else:
|
75 |
+
self.shortcut = SConv1d(
|
76 |
+
dim,
|
77 |
+
dim,
|
78 |
+
kernel_size=1,
|
79 |
+
norm=norm,
|
80 |
+
norm_kwargs=norm_params,
|
81 |
+
causal=causal,
|
82 |
+
pad_mode=pad_mode,
|
83 |
+
)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
return self.shortcut(x) + self.block(x)
|
87 |
+
|
88 |
+
|
89 |
+
class SEANetEncoder(nn.Module):
|
90 |
+
"""SEANet encoder.
|
91 |
+
Args:
|
92 |
+
channels (int): Audio channels.
|
93 |
+
dimension (int): Intermediate representation dimension.
|
94 |
+
n_filters (int): Base width for the model.
|
95 |
+
n_residual_layers (int): nb of residual layers.
|
96 |
+
ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
|
97 |
+
upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
|
98 |
+
that must match the decoder order
|
99 |
+
activation (str): Activation function.
|
100 |
+
activation_params (dict): Parameters to provide to the activation function
|
101 |
+
norm (str): Normalization method.
|
102 |
+
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
103 |
+
kernel_size (int): Kernel size for the initial convolution.
|
104 |
+
last_kernel_size (int): Kernel size for the initial convolution.
|
105 |
+
residual_kernel_size (int): Kernel size for the residual layers.
|
106 |
+
dilation_base (int): How much to increase the dilation with each layer.
|
107 |
+
causal (bool): Whether to use fully causal convolution.
|
108 |
+
pad_mode (str): Padding mode for the convolutions.
|
109 |
+
true_skip (bool): Whether to use true skip connection or a simple
|
110 |
+
(streamable) convolution as the skip connection in the residual network blocks.
|
111 |
+
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
112 |
+
lstm (int): Number of LSTM layers at the end of the encoder.
|
113 |
+
"""
|
114 |
+
|
115 |
+
def __init__(
|
116 |
+
self,
|
117 |
+
channels: int = 1,
|
118 |
+
dimension: int = 128,
|
119 |
+
n_filters: int = 32,
|
120 |
+
n_residual_layers: int = 1,
|
121 |
+
ratios: tp.List[int] = [8, 5, 4, 2],
|
122 |
+
activation: str = "ELU",
|
123 |
+
activation_params: dict = {"alpha": 1.0},
|
124 |
+
norm: str = "weight_norm",
|
125 |
+
norm_params: tp.Dict[str, tp.Any] = {},
|
126 |
+
kernel_size: int = 7,
|
127 |
+
last_kernel_size: int = 7,
|
128 |
+
residual_kernel_size: int = 3,
|
129 |
+
dilation_base: int = 2,
|
130 |
+
causal: bool = False,
|
131 |
+
pad_mode: str = "reflect",
|
132 |
+
true_skip: bool = False,
|
133 |
+
compress: int = 2,
|
134 |
+
lstm: int = 2,
|
135 |
+
):
|
136 |
+
super().__init__()
|
137 |
+
self.channels = channels
|
138 |
+
self.dimension = dimension
|
139 |
+
self.n_filters = n_filters
|
140 |
+
self.ratios = list(reversed(ratios))
|
141 |
+
del ratios
|
142 |
+
self.n_residual_layers = n_residual_layers
|
143 |
+
self.hop_length = np.prod(self.ratios)
|
144 |
+
|
145 |
+
act = getattr(nn, activation)
|
146 |
+
mult = 1
|
147 |
+
model: tp.List[nn.Module] = [
|
148 |
+
SConv1d(
|
149 |
+
channels,
|
150 |
+
mult * n_filters,
|
151 |
+
kernel_size,
|
152 |
+
norm=norm,
|
153 |
+
norm_kwargs=norm_params,
|
154 |
+
causal=causal,
|
155 |
+
pad_mode=pad_mode,
|
156 |
+
)
|
157 |
+
]
|
158 |
+
# Downsample to raw audio scale
|
159 |
+
for i, ratio in enumerate(self.ratios):
|
160 |
+
# Add residual layers
|
161 |
+
for j in range(n_residual_layers):
|
162 |
+
model += [
|
163 |
+
SEANetResnetBlock(
|
164 |
+
mult * n_filters,
|
165 |
+
kernel_sizes=[residual_kernel_size, 1],
|
166 |
+
dilations=[dilation_base**j, 1],
|
167 |
+
norm=norm,
|
168 |
+
norm_params=norm_params,
|
169 |
+
activation=activation,
|
170 |
+
activation_params=activation_params,
|
171 |
+
causal=causal,
|
172 |
+
pad_mode=pad_mode,
|
173 |
+
compress=compress,
|
174 |
+
true_skip=true_skip,
|
175 |
+
)
|
176 |
+
]
|
177 |
+
|
178 |
+
# Add downsampling layers
|
179 |
+
model += [
|
180 |
+
act(**activation_params),
|
181 |
+
SConv1d(
|
182 |
+
mult * n_filters,
|
183 |
+
mult * n_filters * 2,
|
184 |
+
kernel_size=ratio * 2,
|
185 |
+
stride=ratio,
|
186 |
+
norm=norm,
|
187 |
+
norm_kwargs=norm_params,
|
188 |
+
causal=causal,
|
189 |
+
pad_mode=pad_mode,
|
190 |
+
),
|
191 |
+
]
|
192 |
+
mult *= 2
|
193 |
+
|
194 |
+
if lstm:
|
195 |
+
model += [SLSTM(mult * n_filters, num_layers=lstm)]
|
196 |
+
|
197 |
+
model += [
|
198 |
+
act(**activation_params),
|
199 |
+
SConv1d(
|
200 |
+
mult * n_filters,
|
201 |
+
dimension,
|
202 |
+
last_kernel_size,
|
203 |
+
norm=norm,
|
204 |
+
norm_kwargs=norm_params,
|
205 |
+
causal=causal,
|
206 |
+
pad_mode=pad_mode,
|
207 |
+
),
|
208 |
+
]
|
209 |
+
|
210 |
+
self.model = nn.Sequential(*model)
|
211 |
+
|
212 |
+
def forward(self, x):
|
213 |
+
return self.model(x)
|
214 |
+
|
215 |
+
|
216 |
+
class SEANetDecoder(nn.Module):
|
217 |
+
"""SEANet decoder.
|
218 |
+
Args:
|
219 |
+
channels (int): Audio channels.
|
220 |
+
dimension (int): Intermediate representation dimension.
|
221 |
+
n_filters (int): Base width for the model.
|
222 |
+
        n_residual_layers (int): number of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the activation function
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 1,
        ratios: tp.List[int] = [8, 5, 4, 2],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        final_activation: tp.Optional[str] = None,
        final_activation_params: tp.Optional[dict] = None,
        norm: str = "weight_norm",
        norm_params: tp.Dict[str, tp.Any] = {},
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        causal: bool = False,
        pad_mode: str = "reflect",
        true_skip: bool = False,
        compress: int = 2,
        lstm: int = 2,
        trim_right_ratio: float = 1.0,
    ):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)

        act = getattr(nn, activation)
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            SConv1d(
                dimension,
                mult * n_filters,
                kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )
        ]

        if lstm:
            model += [SLSTM(mult * n_filters, num_layers=lstm)]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            # Add upsampling layers
            model += [
                act(**activation_params),
                SConvTranspose1d(
                    mult * n_filters,
                    mult * n_filters // 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    trim_right_ratio=trim_right_ratio,
                ),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters // 2,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation_base**j, 1],
                        activation=activation,
                        activation_params=activation_params,
                        norm=norm,
                        norm_params=norm_params,
                        causal=causal,
                        pad_mode=pad_mode,
                        compress=compress,
                        true_skip=true_skip,
                    )
                ]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params),
            SConv1d(
                n_filters,
                channels,
                last_kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            ),
        ]
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [final_act(**final_activation_params)]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        y = self.model(z)
        return y
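A quick orientation sketch (not part of the repo): the decoder above upsamples by the product of its ratios, so a latent sequence of T frames comes back as roughly T * hop_length samples. The import path, dimension=128 and the exact output length are assumptions for illustration, following the upstream EnCodec layout.

# Illustrative only: build a decoder similar to this Space's configuration and
# check the upsampling factor (hop_length = 8 * 5 * 4 * 2 = 320).
import torch
from encodec.modules import SEANetDecoder  # assumed export, as in upstream EnCodec

decoder = SEANetDecoder(
    channels=1,
    dimension=128,            # latent dimension; assumption for this sketch
    n_filters=64,
    ratios=[8, 5, 4, 2],
    norm="weight_norm",
    pad_mode="constant",
)
z = torch.randn(1, 128, 50)   # (batch, dimension, frames)
wav = decoder(z)              # (batch, channels, ~frames * 320)
print(wav.shape)              # expected around torch.Size([1, 1, 16000])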
encodec/quantization/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from .vq import ResidualVectorQuantizer
encodec/quantization/core_vq.py
ADDED
@@ -0,0 +1,370 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import warnings

import torch
import torch.nn.functional as F
from torch import nn

from .. import distrib

################################################################################
# Core vector quantization implementation
################################################################################


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if val is not None else d


def ema_inplace(moving_avg, new, decay: float):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters: int, num_iters: int = 10):
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        diffs = samples.unsqueeze(1) - means.unsqueeze(0)
        dists = -(diffs**2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)

        new_means.scatter_add_(0, buckets.unsqueeze(-1).expand(-1, dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins


class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: int = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
            uniform_init if not kmeans_init else torch.zeros
        )
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        distrib.broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = batch_samples.view(-1, batch_samples.shape[-1])
        self.replace_(batch_samples, mask=expired_codes)
        distrib.broadcast_tensors(self.buffers())

    def preprocess(self, x):
        x = x.view(-1, x.shape[-1])
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind


class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = (
            nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
        )

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        return self._codebook.embed

    def encode(self, x):
        x = x.transpose(1, 2).contiguous()
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = quantize.transpose(1, 2).contiguous()

        return quantize

    def forward(self, x):
        device = x.device
        x = x.transpose(1, 2).contiguous()
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            warnings.warn(
                "When using RVQ in training model, first check "
                "https://github.com/facebookresearch/encodec/issues/25 . "
                "The bug wasn't fixed here for reproducibility."
            )
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = quantize.transpose(1, 2).contiguous()
        return quantize, embed_ind, loss


class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )

    def forward(self, x, n_q: tp.Optional[int] = None):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        n_q = n_q or len(self.layers)

        for layer in self.layers[:n_q]:
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        residual = x

        # Return quantized latents, both summed and at each quantizer level
        z_O = 0.0  # Summed quantized latents
        z_o = []  # Quantized latents at each quantizer level

        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)

            z_o += [quantized]
            z_O = z_O + quantized

            residual = residual - quantized
            all_indices.append(indices)

        out_indices = torch.stack(all_indices)
        z_o = torch.stack(z_o, dim=1)

        return out_indices, z_O, z_o

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
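A small round-trip sketch of the residual scheme above (illustrative, not part of the repo): each VectorQuantization layer encodes whatever residual the previous layers left behind, and decode re-sums the per-layer codebook lookups, so the summed latents returned by this modified encode match decode's output exactly.

# Illustrative round trip through ResidualVectorQuantization with random codebooks.
import torch
from encodec.quantization.core_vq import ResidualVectorQuantization

rvq = ResidualVectorQuantization(
    num_quantizers=4, dim=128, codebook_size=1024, kmeans_init=False
).eval()

x = torch.randn(1, 128, 50)                 # (batch, dim, frames)
codes, z_sum, z_levels = rvq.encode(x)      # codes: (n_q, batch, frames)
x_hat = rvq.decode(codes)                   # (batch, dim, frames)
assert torch.allclose(x_hat, z_sum)         # decode just re-sums the per-level lookups
# Each additional quantizer shrinks the remaining residual, improving the fit to x.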
encodec/quantization/vq.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp

import torch
from torch import nn

from .core_vq import ResidualVectorQuantization

################################################################################
# Residual quantization module
################################################################################


class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.
    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.n_q = n_q
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
        )

    def get_num_quantizers_for_bandwidth(
        self, frame_rate: int, bandwidth: tp.Optional[float] = None
    ) -> int:
        """Return n_q based on specified target bandwidth."""
        bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
        n_q = self.n_q
        if bandwidth and bandwidth > 0.0:
            # bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as
            # bandwidth == 6.0
            n_q = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
        return n_q

    def get_bandwidth_per_quantizer(self, frame_rate: int):
        """Return bandwidth per quantizer for a given input frame rate.
        Each quantizer encodes a frame with lg(bins) bits.
        """
        return math.log2(self.bins) * frame_rate

    def encode(
        self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None
    ) -> torch.Tensor:
        """Encode a given input tensor with the specified frame rate at the given bandwidth.
        The RVQ encode method sets the appropriate number of quantizers to use
        and returns indices for each quantizer.
        """
        n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
        codes, z_O, z_o = self.vq.encode(x, n_q=n_q)
        return codes, z_O, z_o

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Decode the given codes to the quantized representation.
        """
        quantized = self.vq.decode(codes)
        return quantized
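The bandwidth bookkeeping in ResidualVectorQuantizer is plain arithmetic; as a sketch with illustrative numbers (16 kHz audio, hop length 320, a 2048-entry codebook), a 2.2 kbps target resolves to four quantizers:

# Worked example of get_num_quantizers_for_bandwidth's arithmetic (illustrative numbers).
import math

frame_rate = 16_000 // 320                     # 50 latent frames per second
bits_per_code = math.log2(2048)                # 11 bits per codebook index
bw_per_quantizer = bits_per_code * frame_rate  # 550 bits/s per quantizer
n_q = int(max(1, math.floor(2.2 * 1000 / bw_per_quantizer)))
print(n_q)                                     # 4 quantizers -> 2200 bits/s total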
requirements.txt
ADDED
@@ -0,0 +1,5 @@
gradio
torch
audioseal
git+https://github.com/descriptinc/audiotools
pydantic==2.10.6