oreillyp committed
Commit f872c8a · 1 Parent(s): 40bed59

initial commit

.gitignore ADDED
@@ -0,0 +1,9 @@
+**/.git/*
+__pycache__/
+.DS_Store
+._.DS_Store
+.ipynb_checkpoints/
+.vscode/
+*.egg-info/
+.pytest_cache
+*.ipynb_checkpoints/
app.py ADDED
@@ -0,0 +1,278 @@
+import os
+import sys
+import uuid
+from pathlib import Path
+from contextlib import contextmanager
+
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+import gradio as gr
+from scipy.io.wavfile import write as wavwrite
+
+from audiotools import AudioSignal
+from audioseal import AudioSeal
+
+
+# allow local imports of the bundled encodec folder
+@contextmanager
+def chdir(path: str):
+    origin = Path().absolute()
+    try:
+        os.chdir(path)
+        yield
+    finally:
+        os.chdir(origin)
+
+
+_path = Path(__file__).parent
+sys.path.insert(0, str(_path))
+with chdir(_path):
+    from encodec import Encodec
+
+
+OUT_DIR = _path / "gradio-outputs"
+OUT_DIR.mkdir(exist_ok=True)
+
+LOUDNESS_DB = -16.0
+SAMPLE_RATE = 48_000
+ENCODEC_SAMPLE_RATE = 16_000
+AUDIOSEAL_SAMPLE_RATE = 16_000
+
+# load codec
+config = {
+    "sample_rate": 16_000,
+    "target_bandwidths": [2.2],
+    "channels": 1,
+    "causal": False,
+    "codebook_size": 2048,
+    "n_filters": 64,
+    "model_norm": "weight_norm",
+    "audio_normalize": False,
+    "true_skip": True,
+    "ratios": [8, 5, 4, 2],
+    "encoder_kwargs": {"pad_mode": "constant"},
+    "decoder_kwargs": {"pad_mode": "constant"},
+}
+codec = Encodec(**config)
+codec.load_state_dict(torch.load("ckpt/encodec_voicecraft.pt", map_location="cpu"))
+codec.eval()
+for p in codec.parameters():
+    p.requires_grad_(False)
+codec.set_target_bandwidth(2.2)
+
+# watermark models
+embedder = AudioSeal.load_generator("audioseal_wm_16bits")
+detector = AudioSeal.load_detector("audioseal_detector_16bits")
+
+
+@torch.no_grad()
+def encode(signal: AudioSignal, codec: torch.nn.Module):
+    n_b, n_ch, n_s = signal.shape
+    sr = signal.sample_rate
+    loud_db = signal.loudness()
+    x = signal.clone().resample(ENCODEC_SAMPLE_RATE).audio_data
+    x = x.reshape(n_b * n_ch, 1, -1)
+    codes, *_ = codec.encode(x)
+    return codes, n_b, n_ch, n_s, sr, loud_db
+
+
+@torch.no_grad()
+def decode(codes, n_b, n_ch, n_s, sr, loud_db, codec):
+    x = codec.decode(codes).reshape(n_b, n_ch, -1)
+    sig = AudioSignal(x, sample_rate=ENCODEC_SAMPLE_RATE)
+    sig = sig.resample(sr)
+    sig.audio_data = sig.audio_data[..., :n_s]
+    sig.audio_data = torch.nn.functional.pad(
+        sig.audio_data, (0, max(0, n_s - sig.signal_length))
+    )
+    return sig.normalize(loud_db)
+
+
+@torch.no_grad()
+def split_bands(signal: AudioSignal, sample_rate: float = ENCODEC_SAMPLE_RATE):
+    nyq = sample_rate // 2
+    high = signal.clone().high_pass(cutoffs=int(nyq * 0.95), zeros=51)
+    low = signal.clone().low_pass(cutoffs=int(nyq * 1.05), zeros=51)
+    loud_db = low.loudness()
+    low = low.resample(sample_rate)
+    return low, high, loud_db
+
+
+@torch.no_grad()
+def merge_bands(low, high, loud_db):
+    low = low.clone().to(high.device).resample(high.sample_rate)
+    low.audio_data = low.audio_data[..., :high.signal_length]
+    low.audio_data = torch.nn.functional.pad(
+        low.audio_data, (0, max(0, high.signal_length - low.signal_length))
+    )
+    return low.normalize(loud_db) + high
+
+
+@torch.no_grad()
+def attack(signal: AudioSignal, codec, split_rate_hz=AUDIOSEAL_SAMPLE_RATE):
+    if split_rate_hz:
+        low, high, loud_db = split_bands(signal, split_rate_hz)
+        low = decode(*encode(low, codec), codec)
+        return merge_bands(low, high, loud_db)
+    else:
+        return decode(*encode(signal, codec), codec)
+
+
+@torch.no_grad()
+def embed(signal: AudioSignal, embedder: torch.nn.Module):
+    orig_ch, orig_sr = signal.num_channels, signal.sample_rate
+    sig = signal.clone().resample(SAMPLE_RATE)
+    if orig_ch > 1:
+        b, c, n = sig.audio_data.shape
+        sig.audio_data = sig.audio_data.reshape(b * c, 1, n)
+    low, high, loud = split_bands(sig.clone(), AUDIOSEAL_SAMPLE_RATE)
+    wm = embedder.get_watermark(low.audio_data, AUDIOSEAL_SAMPLE_RATE)
+    low.audio_data = low.audio_data + wm
+    merged = merge_bands(low, high, loud)
+    if orig_ch > 1:
+        b2, c2, n2 = merged.audio_data.shape
+        merged.audio_data = merged.audio_data.reshape(-1, orig_ch * c2, n2)
+    return merged.resample(orig_sr)
+
+
+@torch.no_grad()
+def detect(signal: AudioSignal, detector: torch.nn.Module):
+    sig = signal.clone().to_mono().resample(AUDIOSEAL_SAMPLE_RATE)
+    result, _ = detector.forward(sig.audio_data, sample_rate=AUDIOSEAL_SAMPLE_RATE)
+    return result[0, 1, :].detach().cpu().numpy()
+
+
+def pipeline(audio_tuple):
+
+    sr, audio_np = audio_tuple
+
+    print("GOT SR", sr)
+    print("GOT AUDIO", audio_np.shape)
+
+    if audio_np.ndim == 1:
+        audio_np = audio_np[None, None, :]
+    else:
+        audio_np = np.transpose(audio_np, (1, 0))[None, ...]
+
+    print("FORMATTED AUDIO", audio_np.shape)
+
+    sig = AudioSignal(torch.from_numpy(audio_np).float(), sample_rate=sr)
+    orig_loud = sig.loudness()
+    sig = sig.to_mono().resample(SAMPLE_RATE).normalize(LOUDNESS_DB).ensure_max_of_audio()
+
+    print("REFORMATTED AUDIO")
+    print(sig)
+
+    # Detect
+    scores = detect(sig, detector)
+
+    # Embed + detect without attack
+    wm_sig = embed(sig.clone(), embedder).normalize(LOUDNESS_DB).ensure_max_of_audio()
+    scores_clean = detect(wm_sig, detector)
+
+    print(np.mean(scores_clean))
+
+    # Attack + detect
+    att_sig = attack(wm_sig.clone(), codec).normalize(LOUDNESS_DB).ensure_max_of_audio()
+    scores_att = detect(att_sig, detector)
+
+    print(np.mean(scores_att))
+
+    # Match loudness prior to writing
+    wm_sig.normalize(orig_loud).ensure_max_of_audio()
+    att_sig.normalize(orig_loud).ensure_max_of_audio()
+
+    # Write audio files to disk
+    uid = uuid.uuid4().hex
+    wm_path = OUT_DIR / f"watermarked_{uid}.wav"
+    att_path = OUT_DIR / f"attacked_{uid}.wav"
+
+    wm_arr = wm_sig.audio_data.squeeze().numpy()
+    att_arr = att_sig.audio_data.squeeze().numpy()
+    wavwrite(str(wm_path), SAMPLE_RATE, wm_arr)
+    wavwrite(str(att_path), SAMPLE_RATE, att_arr)
+
+    # Plot: waveform on top, detection scores on bottom
+    sig_bg = sig.clone().to_mono().resample(AUDIOSEAL_SAMPLE_RATE)
+    wav = sig_bg.audio_data.squeeze().numpy()
+    N = len(scores)
+    if wav.shape[0] < N:
+        wav = np.pad(wav, (0, N - wav.shape[0]), mode="constant")
+    else:
+        wav = wav[:N]
+
+    fig, (ax_wav, ax_score) = plt.subplots(2, 1, sharex=True, figsize=(8, 6))
+    # Top: waveform (no labels)
+    ax_wav.plot(wav, alpha=0.3)
+    ax_wav.axis("off")
+
+    # Bottom: detection scores
+    ax_score.plot(scores, label="No watermark", color="blue")
+    ax_score.plot(scores_clean, label="Watermark (no attack)", color="green")
+    ax_score.plot(scores_att, label="Watermark (codec attack)", color="red")
+    ax_score.set_xlabel("Frame Index")
+    ax_score.set_ylabel("Detection Score")
+    ax_score.set_ylim(-0.05, 1.05)
+    ax_score.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
+    ax_score.legend()
+
+    plt.tight_layout()
+    plot_path = OUT_DIR / f"detection_plot_{uid}.png"
+    fig.savefig(str(plot_path), format="png")
+    plt.close(fig)
+
+    return str(wm_path), str(att_path), str(plot_path)
+
+
+demo = gr.Interface(
+    fn=pipeline,
+    inputs=gr.Audio(sources=["upload"], type="numpy", label="Upload Input Audio"),
+    outputs=[
+        gr.Audio(type="filepath", label="Watermarked Audio"),
+        gr.Audio(type="filepath", label="Attacked Audio"),
+        gr.Image(type="filepath", label="Detection Scores Plot"),
+    ],
+    title="Watermark Stress Test",
+    description="""
+This is an educational demonstration of state-of-the-art audio watermark performance under codec processing. Upload any (speech) audio file to test watermark performance before and after processing with a low-bitrate neural codec [1].
+
+For this demo, we use the AudioSeal [2] watermark, which is well documented, open source, and provides state-of-the-art localized detection performance. Both the watermark and codec operate at 16 kHz, meaning all frequencies above 8 kHz are left unaltered. To ensure consistent watermark performance, we normalize audio to -16 dB LUFS and downmix to mono prior to embedding.
+
+[1] https://github.com/jasonppy/VoiceCraft
+[2] https://github.com/facebookresearch/audioseal
+""",
+    article="""
+The citation info for our corresponding paper is:
+
+```
+@inproceedings{deepwatermarksareshallow,
+    author={Patrick O'Reilly and Zeyu Jin and Jiaqi Su and Bryan Pardo},
+    title={Deep Audio Watermarks are Shallow: Limitations of Post-Hoc Watermarking Techniques for Speech},
+    booktitle={ICLR Workshop on GenAI Watermarking},
+    year={2025}
+}
+```
+
+For the VoiceCraft codec:
+
+```
+@article{voicecraft,
+    author={Puyuan Peng and Po-Yao Huang and Daniel Li and Abdelrahman Mohamed and David Harwath},
+    year={2024},
+    title={VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild},
+    journal={arXiv preprint arXiv:2403.16973v1},
+}
+```
+
+And for the AudioSeal watermark:
+
+```
+@article{audioseal,
+    title={Proactive Detection of Voice Cloning with Localized Watermarking},
+    author={San Roman, Robin and Fernandez, Pierre and Elsahar, Hady and Défossez, Alexandre and Furon, Teddy and Tran, Tuan},
+    journal={International Conference on Machine Learning (ICML)},
+    year={2024}
+}
+```
+""",
+    allow_flagging=False,
+)
+
+if __name__ == "__main__":
+    demo.launch(share=True)
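
A minimal sketch of how the pieces in `app.py` compose outside of the Gradio interface, assuming the module's globals (`embed`, `attack`, `detect`, `embedder`, `detector`, `codec`, `SAMPLE_RATE`, `LOUDNESS_DB`) are in scope; `example.wav` is a hypothetical speech clip, not a file shipped with the repo:

```
import numpy as np
from audiotools import AudioSignal

# load and condition the input the same way pipeline() does
sig = AudioSignal("example.wav").to_mono().resample(SAMPLE_RATE)
sig = sig.normalize(LOUDNESS_DB).ensure_max_of_audio()

wm_sig = embed(sig.clone(), embedder)    # watermark the band below 8 kHz
att_sig = attack(wm_sig.clone(), codec)  # re-encode that band with the 2.2 kb/s codec

print("watermarked:", np.mean(detect(wm_sig, detector)))   # high if watermark is detectable
print("attacked:   ", np.mean(detect(att_sig, detector)))  # lower if the codec strips it
```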
ckpt/encodec_voicecraft.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42b224ba5b193a8fb66eb692fe377831bb14b1dcf556638db7afc5d108099bfb
+size 235735922
encodec/LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Meta Platforms, Inc. and affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
encodec/__init__.py ADDED
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from .encodec import Encodec
encodec/distrib.py ADDED
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Torch distributed utilities."""
+import typing as tp
+
+import torch
+
+
+def rank():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_rank()
+    else:
+        return 0
+
+
+def world_size():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_world_size()
+    else:
+        return 1
+
+
+def is_distributed():
+    return world_size() > 1
+
+
+def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
+    if is_distributed():
+        return torch.distributed.all_reduce(tensor, op)
+
+
+def _is_complex_or_float(tensor):
+    return torch.is_floating_point(tensor) or torch.is_complex(tensor)
+
+
+def _check_number_of_params(params: tp.List[torch.Tensor]):
+    # Utility function to check that the number of params in all workers is the same,
+    # and thus avoid a deadlock with distributed all-reduce.
+    if not is_distributed() or not params:
+        return
+    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
+    all_reduce(tensor)
+    if tensor.item() != len(params) * world_size():
+        # If the workers do not all have the same number of params, this
+        # inequality holds for at least one of them.
+        raise RuntimeError(
+            f"Mismatch in number of params: ours is {len(params)}, "
+            "at least one worker has a different one."
+        )
+
+
+def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
+    """Broadcast the tensors from the given parameters to all workers.
+    This can be used to ensure that all workers have the same model to start with.
+    """
+    if not is_distributed():
+        return
+    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
+    _check_number_of_params(tensors)
+    handles = []
+    for tensor in tensors:
+        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
+        handles.append(handle)
+    for handle in handles:
+        handle.wait()
+
+
+def sync_buffer(buffers, average=True):
+    """
+    Sync buffers across workers. If average is False, broadcast instead of averaging.
+    """
+    if not is_distributed():
+        return
+    handles = []
+    for buffer in buffers:
+        if torch.is_floating_point(buffer.data):
+            if average:
+                handle = torch.distributed.all_reduce(
+                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True
+                )
+            else:
+                handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
+            handles.append((buffer, handle))
+    for buffer, handle in handles:
+        handle.wait()
+        if average:
+            buffer.data /= world_size()
+
+
+def sync_grad(params):
+    """
+    Simpler alternative to DistributedDataParallel that doesn't rely
+    on any black magic. For simple models it can also be as fast.
+    Just call this on your model parameters after the call to backward!
+    """
+    if not is_distributed():
+        return
+    handles = []
+    for p in params:
+        if p.grad is not None:
+            handle = torch.distributed.all_reduce(
+                p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True
+            )
+            handles.append((p, handle))
+    for p, handle in handles:
+        handle.wait()
+        p.grad.data /= world_size()
+
+
+def average_metrics(metrics: tp.Dict[str, float], count=1.0):
+    """Average a dictionary of metrics across all workers, using the optional
+    `count` as unnormalized weight.
+    """
+    if not is_distributed():
+        return metrics
+    keys, values = zip(*metrics.items())
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
+    tensor *= count
+    all_reduce(tensor)
+    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
+    return dict(zip(keys, averaged))
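
A minimal sketch of where `sync_grad` fits in a training step, per its docstring ("call this on your model parameters after the call to backward"); the model and data here are hypothetical, and the call is a no-op unless `torch.distributed` has been initialized with more than one worker:

```
import torch

model = torch.nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

x, y = torch.randn(4, 8), torch.randn(4, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
sync_grad(model.parameters())  # sum gradients across workers, then divide by world_size()
opt.step()
opt.zero_grad()
```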
encodec/encodec.py ADDED
@@ -0,0 +1,171 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import math
+import typing as tp
+
+import torch
+
+from .modules import SEANetDecoder
+from .modules import SEANetEncoder
+from .quantization import ResidualVectorQuantizer
+
+################################################################################
+# Encodec neural audio codec
+################################################################################
+
+
+class Encodec(torch.nn.Module):
+    """
+    Encodec neural audio codec proposed in "High Fidelity Neural Audio
+    Compression" (https://arxiv.org/abs/2210.13438) by Défossez et al.
+    """
+
+    def __init__(
+        self,
+        sample_rate: int,
+        channels: int,
+        causal: bool,
+        model_norm: str,
+        target_bandwidths: tp.Sequence[float],
+        audio_normalize: bool,
+        ratios: tp.List[int] = (8, 5, 4, 2),
+        codebook_size: int = 1024,
+        n_filters: int = 32,
+        true_skip: bool = False,
+        encoder_kwargs: tp.Dict = None,
+        decoder_kwargs: tp.Dict = None,
+    ):
+        """
+        Parameters
+        ----------
+        sample_rate : int
+            Audio sample rate in Hz.
+        channels : int
+            Number of audio channels expected at input.
+        causal : bool
+            Whether to use causal convolution layers in encoder/decoder.
+        model_norm : str
+            Type of normalization to use in encoder/decoder.
+        target_bandwidths : tp.Sequence[float]
+            List of target bandwidths in kb/s.
+        audio_normalize : bool
+            Whether to normalize encoded and decoded audio segments using
+            simple scaling factors.
+        ratios : tp.List[int], optional
+            List of downsampling ratios used in encoder/decoder, by default (8, 5, 4, 2)
+        codebook_size : int, optional
+            Size of residual vector quantizer codebooks, by default 1024
+        n_filters : int, optional
+            Number of filters used in encoder/decoder, by default 32
+        true_skip : bool, optional
+            Whether to use true skip connections in encoder/decoder rather than
+            convolutional skip connections, by default False
+        """
+        super().__init__()
+
+        encoder_kwargs = encoder_kwargs or {}
+        decoder_kwargs = decoder_kwargs or {}
+
+        self.encoder = SEANetEncoder(
+            channels=channels,
+            causal=causal,
+            norm=model_norm,
+            ratios=ratios,
+            n_filters=n_filters,
+            true_skip=true_skip,
+            **encoder_kwargs,
+        )
+        self.decoder = SEANetDecoder(
+            channels=channels,
+            causal=causal,
+            norm=model_norm,
+            ratios=ratios,
+            n_filters=n_filters,
+            true_skip=true_skip,
+            **decoder_kwargs,
+        )
+
+        n_q = int(
+            1000
+            * target_bandwidths[-1]
+            // (math.ceil(sample_rate / self.encoder.hop_length) * 10)
+        )
+        self.n_q = n_q  # Maximum number of quantizers
+        self.quantizer = ResidualVectorQuantizer(
+            dimension=self.encoder.dimension,
+            n_q=n_q,
+            bins=codebook_size,
+        )
+
+        self.sample_rate = sample_rate
+        self.normalize = audio_normalize
+        self.channels = channels
+
+        self.frame_rate = math.ceil(self.sample_rate / math.prod(self.encoder.ratios))
+
+        self.target_bandwidths = target_bandwidths
+        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
+        assert (
+            2**self.bits_per_codebook == self.quantizer.bins
+        ), "quantizer bins must be a power of 2."
+
+        self.bandwidth = self.target_bandwidths[-1]
+
+    def set_target_bandwidth(self, bandwidth: float):
+        """
+        Set the target bandwidth for the codec by adjusting the
+        number of residual vector quantizers used.
+        """
+        if bandwidth not in self.target_bandwidths:
+            raise ValueError(
+                f"This model doesn't support the bandwidth {bandwidth}. "
+                f"Select one of {self.target_bandwidths}."
+            )
+        self.bandwidth = bandwidth
+
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, ...]:
+        """
+        Map a given audio waveform `x` to discrete residual latent codes.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Audio waveform of shape `(n_batch, n_channels, n_samples)`.
+
+        Returns
+        -------
+        codes : torch.Tensor
+            Tensor of shape `(n_batch, n_codebooks, n_frames)`.
+        """
+        assert x.dim() == 3
+        _, channels, length = x.shape
+        assert 0 < channels <= 2
+
+        z = self.encoder(x)
+        codes, z_O, z_o = self.quantizer.encode(z, self.frame_rate, self.bandwidth)
+        codes = codes.transpose(0, 1)
+
+        return codes, z_O, z_o, z
+
+    def decode(self, codes: torch.Tensor):
+        """
+        Decode quantized latents to obtain waveform audio.
+
+        Parameters
+        ----------
+        codes : torch.Tensor
+            Tensor of shape `(n_batch, n_codebooks, n_frames)`.
+
+        Returns
+        -------
+        out : torch.Tensor
+            Tensor of shape `(n_batch, n_channels, n_samples)`.
+        """
+        codes = codes.transpose(0, 1)
+        emb = self.quantizer.decode(codes)
+        out = self.decoder(emb)
+
+        return out
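
A minimal round-trip sketch at the 2.2 kb/s setting the demo uses, mirroring how `app.py` calls `encode`/`decode`; `config` is the dict defined in `app.py`, and the weights here are random (no checkpoint), so only shapes are meaningful. With hop length 8·5·4·2 = 320 and 2048-entry codebooks (11 bits each), the bandwidth works out to 4 codebooks × 11 bits × 50 frames/s = 2,200 bits/s:

```
import torch

model = Encodec(**config).eval()
model.set_target_bandwidth(2.2)

x = torch.randn(1, 1, 16_000)      # (n_batch, n_channels, n_samples): 1 s at 16 kHz
with torch.no_grad():
    codes, *_ = model.encode(x)    # (n_batch, n_codebooks, n_frames) = (1, 4, 50)
    y = model.decode(codes)        # may be slightly longer than x due to padding
y = y[..., : x.shape[-1]]          # trim to input length, as app.py's decode() does
```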
encodec/modules/__init__.py ADDED
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Torch modules."""
+from .conv import NormConv1d
+from .conv import NormConv2d
+from .conv import NormConvTranspose1d
+from .conv import NormConvTranspose2d
+from .conv import pad1d
+from .conv import SConv1d
+from .conv import SConvTranspose1d
+from .conv import unpad1d
+from .lstm import SLSTM
+from .seanet import SEANetDecoder
+from .seanet import SEANetEncoder
encodec/modules/conv.py ADDED
@@ -0,0 +1,342 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Convolutional layers wrappers and utilities."""
+import math
+import typing as tp
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm
+from torch.nn.utils import weight_norm
+
+from .norm import ConvLayerNorm
+
+
+CONV_NORMALIZATIONS = frozenset(
+    [
+        "none",
+        "weight_norm",
+        "spectral_norm",
+        "time_layer_norm",
+        "layer_norm",
+        "time_group_norm",
+    ]
+)
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
+    assert norm in CONV_NORMALIZATIONS
+    if norm == "weight_norm":
+        return weight_norm(module)
+    elif norm == "spectral_norm":
+        return spectral_norm(module)
+    else:
+        # We already checked that `norm` is in CONV_NORMALIZATIONS, so any
+        # other choice doesn't need reparametrization.
+        return module
+
+
+def get_norm_module(
+    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
+) -> nn.Module:
+    """Return the proper normalization module. If causal is True, this will ensure the returned
+    module is causal, or raise an error if the normalization doesn't support causal evaluation.
+    """
+    assert norm in CONV_NORMALIZATIONS
+    if norm == "layer_norm":
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return ConvLayerNorm(module.out_channels, **norm_kwargs)
+    elif norm == "time_group_norm":
+        if causal:
+            raise ValueError("GroupNorm doesn't support causal evaluation.")
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+    else:
+        return nn.Identity()
+
+
+def get_extra_padding_for_conv1d(
+    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+) -> int:
+    """See `pad_for_conv1d`."""
+    length = x.shape[-1]
+    n_frames = (length - kernel_size + padding_total) / stride + 1
+    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    return ideal_length - length
+
+
+def pad_for_conv1d(
+    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+):
+    """Pad for a convolution to make sure that the last window is full.
+    Extra padding is added at the end. This is required to ensure that we can rebuild
+    an output of the same length, as otherwise, even with padding, some time steps
+    might get removed.
+    For instance, with total padding = 4, kernel size = 4, stride = 2:
+        0 0 1 2 3 4 5 0 0   # (0s are padding)
+        1   2   3           # (output frames of a convolution, last 0 is never used)
+        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
+        1 2 3 4             # once the padding is removed, we are missing one time step!
+    """
+    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+    return F.pad(x, (0, extra_padding))
+
+
+def pad1d(
+    x: torch.Tensor,
+    paddings: tp.Tuple[int, int],
+    mode: str = "zero",
+    value: float = 0.0,
+):
+    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+    If this is the case, we insert extra 0 padding to the right before the reflection happens.
+    """
+    length = x.shape[-1]
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    if mode == "reflect":
+        max_pad = max(padding_left, padding_right)
+        extra_pad = 0
+        if length <= max_pad:
+            extra_pad = max_pad - length + 1
+            x = F.pad(x, (0, extra_pad))
+        padded = F.pad(x, paddings, mode, value)
+        end = padded.shape[-1] - extra_pad
+        return padded[..., :end]
+    else:
+        return F.pad(x, paddings, mode, value)
+
+
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+    """Remove padding from x, handling properly zero padding. Only for 1d!"""
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    assert (padding_left + padding_right) <= x.shape[-1]
+    end = x.shape[-1] - padding_right
+    return x[..., padding_left:end]
+
+
+class NormConv1d(nn.Module):
+    """Wrapper around Conv1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+
+    def __init__(
+        self,
+        *args,
+        causal: bool = False,
+        norm: str = "none",
+        norm_kwargs: tp.Dict[str, tp.Any] = {},
+        **kwargs,
+    ):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConv2d(nn.Module):
+    """Wrapper around Conv2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+
+    def __init__(
+        self,
+        *args,
+        norm: str = "none",
+        norm_kwargs: tp.Dict[str, tp.Any] = {},
+        **kwargs,
+    ):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConvTranspose1d(nn.Module):
+    """Wrapper around ConvTranspose1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+
+    def __init__(
+        self,
+        *args,
+        causal: bool = False,
+        norm: str = "none",
+        norm_kwargs: tp.Dict[str, tp.Any] = {},
+        **kwargs,
+    ):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(
+            nn.ConvTranspose1d(*args, **kwargs), norm
+        )
+        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConvTranspose2d(nn.Module):
+    """Wrapper around ConvTranspose2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+
+    def __init__(
+        self,
+        *args,
+        norm: str = "none",
+        norm_kwargs: tp.Dict[str, tp.Any] = {},
+        **kwargs,
+    ):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(
+            nn.ConvTranspose2d(*args, **kwargs), norm
+        )
+        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
+
+    def forward(self, x):
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+
+class SConv1d(nn.Module):
+    """Conv1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        causal: bool = False,
+        norm: str = "none",
+        norm_kwargs: tp.Dict[str, tp.Any] = {},
+        pad_mode: str = "reflect",
+    ):
+        super().__init__()
+        # warn user on unusual setup between dilation and stride
+        if stride > 1 and dilation > 1:
+            warnings.warn(
+                "SConv1d has been initialized with stride > 1 and dilation > 1"
+                f" (kernel_size={kernel_size}, stride={stride}, dilation={dilation})."
+            )
+        self.conv = NormConv1d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+            causal=causal,
+            norm=norm,
+            norm_kwargs=norm_kwargs,
+        )
+        self.causal = causal
+        self.pad_mode = pad_mode
+
+    def forward(self, x):
+        B, C, T = x.shape
+        kernel_size = self.conv.conv.kernel_size[0]
+        stride = self.conv.conv.stride[0]
+        dilation = self.conv.conv.dilation[0]
+        kernel_size = (
+            kernel_size - 1
+        ) * dilation + 1  # effective kernel size with dilations
+        padding_total = kernel_size - stride
+        extra_padding = get_extra_padding_for_conv1d(
+            x, kernel_size, stride, padding_total
+        )
+        if self.causal:
+            # Left padding for causal
+            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            x = pad1d(
+                x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
+            )
+        return self.conv(x)
+
+
+class SConvTranspose1d(nn.Module):
+    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        causal: bool = False,
+        norm: str = "none",
+        trim_right_ratio: float = 1.0,
+        norm_kwargs: tp.Dict[str, tp.Any] = {},
+    ):
+        super().__init__()
+        self.convtr = NormConvTranspose1d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            causal=causal,
+            norm=norm,
+            norm_kwargs=norm_kwargs,
+        )
+        self.causal = causal
+        self.trim_right_ratio = trim_right_ratio
+        assert (
+            self.causal or self.trim_right_ratio == 1.0
+        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
+
+    def forward(self, x):
+        kernel_size = self.convtr.convtr.kernel_size[0]
+        stride = self.convtr.convtr.stride[0]
+        padding_total = kernel_size - stride
+
+        y = self.convtr(x)
+
+        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+        # removed at the very end, when keeping only the right length for the output,
+        # as removing it here would require also passing the length at the matching layer
+        # in the encoder.
+        if self.causal:
+            # Trim the padding on the right according to the specified ratio;
+            # if trim_right_ratio = 1.0, trim everything from the right.
+            padding_right = math.ceil(padding_total * self.trim_right_ratio)
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        return y
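
A minimal check of the padding arithmetic described in `pad_for_conv1d`'s docstring, using its own worked example (kernel size 4, stride 2, total padding 4): a 5-sample input needs one extra trailing sample so the last window is full and the length is recoverable after the matching transposed convolution:

```
import torch

x = torch.zeros(1, 1, 5)
extra = get_extra_padding_for_conv1d(x, kernel_size=4, stride=2, padding_total=4)
print(extra)                                                     # 1
print(pad_for_conv1d(x, kernel_size=4, stride=2, padding_total=4).shape[-1])  # 6
```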
encodec/modules/lstm.py ADDED
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""LSTM layers module."""
+from torch import nn
+
+
+class SLSTM(nn.Module):
+    """
+    LSTM without worrying about the hidden state, nor the layout of the data.
+    Expects input as convolutional layout.
+    """
+
+    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+        super().__init__()
+        self.skip = skip
+        self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+    def forward(self, x):
+        x = x.permute(2, 0, 1)
+        y, _ = self.lstm(x)
+        if self.skip:
+            y = y + x
+        y = y.permute(1, 2, 0)
+        return y
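
A minimal shape check of the "convolutional layout" claim above: `SLSTM` consumes and returns `(n_batch, n_channels, n_frames)`, internally permuting to the `(T, B, C)` layout `nn.LSTM` expects and back:

```
import torch

lstm = SLSTM(dimension=128, num_layers=2)
z = torch.randn(2, 128, 50)   # (n_batch, n_channels, n_frames)
print(lstm(z).shape)          # torch.Size([2, 128, 50])
```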
encodec/modules/norm.py ADDED
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Normalization modules."""
+import typing as tp
+
+import torch
+from torch import nn
+
+
+class ConvLayerNorm(nn.LayerNorm):
+    """
+    Convolution-friendly LayerNorm that moves channels to the last dimension
+    before running the normalization and moves them back to their original position right after.
+    """
+
+    def __init__(
+        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
+    ):
+        super().__init__(normalized_shape, **kwargs)
+
+    def forward(self, x):
+        assert x.ndim == 3  # (n_batch, n_channels, n_samples)
+
+        x = x.transpose(1, 2)
+        x = super().forward(x)
+        x = x.transpose(1, 2)
+
+        return x
encodec/modules/seanet.py ADDED
@@ -0,0 +1,352 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Encodec SEANet-based encoder and decoder implementation."""
+import typing as tp
+
+import numpy as np
+import torch.nn as nn
+
+from . import SConv1d
+from . import SConvTranspose1d
+from . import SLSTM
+
+
+class SEANetResnetBlock(nn.Module):
+    """Residual block from SEANet model.
+    Args:
+        dim (int): Dimension of the input/output.
+        kernel_sizes (list): List of kernel sizes for the convolutions.
+        dilations (list): List of dilations for the convolutions.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        kernel_sizes: tp.List[int] = [3, 1],
+        dilations: tp.List[int] = [1, 1],
+        activation: str = "ELU",
+        activation_params: dict = {"alpha": 1.0},
+        norm: str = "weight_norm",
+        norm_params: tp.Dict[str, tp.Any] = {},
+        causal: bool = False,
+        pad_mode: str = "reflect",
+        compress: int = 2,
+        true_skip: bool = True,
+    ):
+        super().__init__()
+        assert len(kernel_sizes) == len(
+            dilations
+        ), "Number of kernel sizes should match number of dilations"
+        act = getattr(nn, activation)
+        hidden = dim // compress
+        block = []
+        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+            in_chs = dim if i == 0 else hidden
+            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+            block += [
+                act(**activation_params),
+                SConv1d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=kernel_size,
+                    dilation=dilation,
+                    norm=norm,
+                    norm_kwargs=norm_params,
+                    causal=causal,
+                    pad_mode=pad_mode,
+                ),
+            ]
+        self.block = nn.Sequential(*block)
+        self.shortcut: nn.Module
+        if true_skip:
+            self.shortcut = nn.Identity()
+        else:
+            self.shortcut = SConv1d(
+                dim,
+                dim,
+                kernel_size=1,
+                norm=norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+            )
+
+    def forward(self, x):
+        return self.shortcut(x) + self.block(x)
+
+
+class SEANetEncoder(nn.Module):
+    """SEANet encoder.
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): Number of residual layers.
+        ratios (Sequence[int]): Kernel size and stride ratios. The encoder uses downsampling
+            ratios instead of upsampling ratios, hence it will use the ratios in reverse order
+            from the ones specified here, which must match the decoder order.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the final convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        lstm (int): Number of LSTM layers at the end of the encoder.
+    """
+
+    def __init__(
+        self,
+        channels: int = 1,
+        dimension: int = 128,
+        n_filters: int = 32,
+        n_residual_layers: int = 1,
+        ratios: tp.List[int] = [8, 5, 4, 2],
+        activation: str = "ELU",
+        activation_params: dict = {"alpha": 1.0},
+        norm: str = "weight_norm",
+        norm_params: tp.Dict[str, tp.Any] = {},
+        kernel_size: int = 7,
+        last_kernel_size: int = 7,
+        residual_kernel_size: int = 3,
+        dilation_base: int = 2,
+        causal: bool = False,
+        pad_mode: str = "reflect",
+        true_skip: bool = False,
+        compress: int = 2,
+        lstm: int = 2,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.dimension = dimension
+        self.n_filters = n_filters
+        self.ratios = list(reversed(ratios))
+        del ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)
+
+        act = getattr(nn, activation)
+        mult = 1
+        model: tp.List[nn.Module] = [
+            SConv1d(
+                channels,
+                mult * n_filters,
+                kernel_size,
+                norm=norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+            )
+        ]
+        # Downsample from raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    SEANetResnetBlock(
+                        mult * n_filters,
+                        kernel_sizes=[residual_kernel_size, 1],
+                        dilations=[dilation_base**j, 1],
+                        norm=norm,
+                        norm_params=norm_params,
+                        activation=activation,
+                        activation_params=activation_params,
+                        causal=causal,
+                        pad_mode=pad_mode,
+                        compress=compress,
+                        true_skip=true_skip,
+                    )
+                ]
+
+            # Add downsampling layers
+            model += [
+                act(**activation_params),
+                SConv1d(
+                    mult * n_filters,
+                    mult * n_filters * 2,
+                    kernel_size=ratio * 2,
+                    stride=ratio,
+                    norm=norm,
+                    norm_kwargs=norm_params,
+                    causal=causal,
+                    pad_mode=pad_mode,
+                ),
+            ]
+            mult *= 2
+
+        if lstm:
+            model += [SLSTM(mult * n_filters, num_layers=lstm)]
+
+        model += [
+            act(**activation_params),
+            SConv1d(
+                mult * n_filters,
+                dimension,
+                last_kernel_size,
+                norm=norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+            ),
+        ]
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class SEANetDecoder(nn.Module):
+    """SEANet decoder.
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): Number of residual layers.
+        ratios (Sequence[int]): Kernel size and stride ratios.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        final_activation (str): Final activation function after all convolutions.
+        final_activation_params (dict): Parameters to provide to the final activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the final convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        lstm (int): Number of LSTM layers at the start of the decoder.
+        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+            If equal to 1.0, it means that all the trimming is done at the right.
+    """
+
+    def __init__(
+        self,
+        channels: int = 1,
+        dimension: int = 128,
+        n_filters: int = 32,
+        n_residual_layers: int = 1,
+        ratios: tp.List[int] = [8, 5, 4, 2],
+        activation: str = "ELU",
+        activation_params: dict = {"alpha": 1.0},
+        final_activation: tp.Optional[str] = None,
+        final_activation_params: tp.Optional[dict] = None,
+        norm: str = "weight_norm",
+        norm_params: tp.Dict[str, tp.Any] = {},
+        kernel_size: int = 7,
+        last_kernel_size: int = 7,
+        residual_kernel_size: int = 3,
+        dilation_base: int = 2,
+        causal: bool = False,
+        pad_mode: str = "reflect",
+        true_skip: bool = False,
+        compress: int = 2,
+        lstm: int = 2,
+        trim_right_ratio: float = 1.0,
+    ):
+        super().__init__()
+        self.dimension = dimension
+        self.channels = channels
+        self.n_filters = n_filters
+        self.ratios = ratios
+        del ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)
+
+        act = getattr(nn, activation)
+        mult = int(2 ** len(self.ratios))
+        model: tp.List[nn.Module] = [
+            SConv1d(
+                dimension,
+                mult * n_filters,
+                kernel_size,
+                norm=norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+            )
+        ]
+
+        if lstm:
+            model += [SLSTM(mult * n_filters, num_layers=lstm)]
+
+        # Upsample to raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            # Add upsampling layers
+            model += [
+                act(**activation_params),
+                SConvTranspose1d(
+                    mult * n_filters,
+                    mult * n_filters // 2,
+                    kernel_size=ratio * 2,
+                    stride=ratio,
+                    norm=norm,
+                    norm_kwargs=norm_params,
+                    causal=causal,
+                    trim_right_ratio=trim_right_ratio,
+                ),
+            ]
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    SEANetResnetBlock(
+                        mult * n_filters // 2,
+                        kernel_sizes=[residual_kernel_size, 1],
+                        dilations=[dilation_base**j, 1],
+                        activation=activation,
+                        activation_params=activation_params,
+                        norm=norm,
+                        norm_params=norm_params,
+                        causal=causal,
+                        pad_mode=pad_mode,
+                        compress=compress,
+                        true_skip=true_skip,
+                    )
+                ]
+
+            mult //= 2
+
+        # Add final layers
+        model += [
+            act(**activation_params),
+            SConv1d(
+                n_filters,
+                channels,
+                last_kernel_size,
+                norm=norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+            ),
+        ]
+        # Add optional final activation to decoder (e.g. tanh)
+        if final_activation is not None:
+            final_act = getattr(nn, final_activation)
+            final_activation_params = final_activation_params or {}
+            model += [final_act(**final_activation_params)]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, z):
+        y = self.model(z)
+        return y
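
A minimal shape check of the encoder, under default settings: with `ratios=[8, 5, 4, 2]` the hop length is 8·5·4·2 = 320 samples, so one second of 16 kHz audio maps to 50 latent frames of dimension 128:

```
import torch

enc = SEANetEncoder(channels=1, dimension=128, n_filters=32, ratios=[8, 5, 4, 2])
x = torch.randn(1, 1, 16_000)      # (n_batch, n_channels, n_samples)
z = enc(x)
print(enc.hop_length, z.shape)     # 320 torch.Size([1, 128, 50])
```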
encodec/quantization/__init__.py ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# flake8: noqa
+from .vq import ResidualVectorQuantizer
encodec/quantization/core_vq.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import typing as tp
7
+ import warnings
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+
13
+ from .. import distrib
14
+
15
+ ################################################################################
16
+ # Core vector quantization implementation
17
+ ################################################################################
18
+
19
+
20
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
21
+ return val if val is not None else d
22
+
23
+
24
+ def ema_inplace(moving_avg, new, decay: float):
25
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
26
+
27
+
28
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
29
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
30
+
31
+
32
+ def uniform_init(*shape: int):
33
+ t = torch.empty(shape)
34
+ nn.init.kaiming_uniform_(t)
35
+ return t
36
+
37
+
38
+ def sample_vectors(samples, num: int):
39
+ num_samples, device = samples.shape[0], samples.device
40
+
41
+ if num_samples >= num:
42
+ indices = torch.randperm(num_samples, device=device)[:num]
43
+ else:
44
+ indices = torch.randint(0, num_samples, (num,), device=device)
45
+
46
+ return samples[indices]
47
+
48
+
49
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
50
+ dim, dtype = samples.shape[-1], samples.dtype
51
+
52
+ means = sample_vectors(samples, num_clusters)
53
+
54
+ for _ in range(num_iters):
55
+ diffs = samples.unsqueeze(1) - means.unsqueeze(0)
56
+ dists = -(diffs**2).sum(dim=-1)
57
+
58
+ buckets = dists.max(dim=-1).indices
59
+ bins = torch.bincount(buckets, minlength=num_clusters)
60
+ zero_mask = bins == 0
61
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
62
+
63
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
64
+
65
+ new_means.scatter_add_(0, buckets.unsqueeze(-1).expand(-1, dim), samples)
66
+ new_means = new_means / bins_min_clamped[..., None]
67
+
68
+ means = torch.where(zero_mask[..., None], means, new_means)
69
+
70
+ return means, bins
71
+
72
+
73
+ class EuclideanCodebook(nn.Module):
74
+ """Codebook with Euclidean distance.
75
+ Args:
76
+ dim (int): Dimension.
77
+ codebook_size (int): Codebook size.
78
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
79
+ If set to true, run the k-means algorithm on the first training batch and use
80
+ the learned centroids as initialization.
81
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
82
+ decay (float): Decay for exponential moving average over the codebooks.
83
+ epsilon (float): Epsilon value for numerical stability.
84
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
85
+ that have an exponential moving average cluster size less than the specified threshold with
86
+ randomly selected vector from the current batch.
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ dim: int,
92
+ codebook_size: int,
93
+ kmeans_init: int = False,
94
+ kmeans_iters: int = 10,
95
+ decay: float = 0.99,
96
+ epsilon: float = 1e-5,
97
+ threshold_ema_dead_code: int = 2,
98
+ ):
99
+ super().__init__()
100
+ self.decay = decay
101
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
102
+ uniform_init if not kmeans_init else torch.zeros
103
+ )
104
+ embed = init_fn(codebook_size, dim)
105
+
106
+ self.codebook_size = codebook_size
107
+
108
+ self.kmeans_iters = kmeans_iters
109
+ self.epsilon = epsilon
110
+ self.threshold_ema_dead_code = threshold_ema_dead_code
111
+
112
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
113
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
114
+ self.register_buffer("embed", embed)
115
+ self.register_buffer("embed_avg", embed.clone())
116
+
117
+ @torch.jit.ignore
118
+ def init_embed_(self, data):
119
+ if self.inited:
120
+ return
121
+
122
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
123
+ self.embed.data.copy_(embed)
124
+ self.embed_avg.data.copy_(embed.clone())
125
+ self.cluster_size.data.copy_(cluster_size)
126
+ self.inited.data.copy_(torch.Tensor([True]))
127
+ # Make sure all buffers across workers are in sync after initialization
128
+ distrib.broadcast_tensors(self.buffers())
129
+
130
+ def replace_(self, samples, mask):
131
+ modified_codebook = torch.where(
132
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
133
+ )
134
+ self.embed.data.copy_(modified_codebook)
135
+
136
+ def expire_codes_(self, batch_samples):
137
+ if self.threshold_ema_dead_code == 0:
138
+ return
139
+
140
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
141
+ if not torch.any(expired_codes):
142
+ return
143
+
144
+ batch_samples = batch_samples.view(-1, batch_samples.shape[-1])
145
+ self.replace_(batch_samples, mask=expired_codes)
146
+ distrib.broadcast_tensors(self.buffers())
147
+
148
+ def preprocess(self, x):
149
+ x = x.view(-1, x.shape[-1])
150
+ return x
151
+
152
+ def quantize(self, x):
153
+ embed = self.embed.t()
154
+ dist = -(
155
+ x.pow(2).sum(1, keepdim=True)
156
+ - 2 * x @ embed
157
+ + embed.pow(2).sum(0, keepdim=True)
158
+ )
159
+ embed_ind = dist.max(dim=-1).indices
160
+ return embed_ind
161
+
162
+ def postprocess_emb(self, embed_ind, shape):
163
+ return embed_ind.view(*shape[:-1])
164
+
165
+ def dequantize(self, embed_ind):
166
+ quantize = F.embedding(embed_ind, self.embed)
167
+ return quantize
168
+
169
+ def encode(self, x):
170
+ shape = x.shape
171
+ # pre-process
172
+ x = self.preprocess(x)
173
+ # quantize
174
+ embed_ind = self.quantize(x)
175
+ # post-process
176
+ embed_ind = self.postprocess_emb(embed_ind, shape)
177
+ return embed_ind
178
+
179
+ def decode(self, embed_ind):
180
+ quantize = self.dequantize(embed_ind)
181
+ return quantize
182
+
183
+     def forward(self, x):
+         shape, dtype = x.shape, x.dtype
+         x = self.preprocess(x)
+
+         self.init_embed_(x)
+
+         embed_ind = self.quantize(x)
+         embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+         embed_ind = self.postprocess_emb(embed_ind, shape)
+         quantize = self.dequantize(embed_ind)
+
+         if self.training:
+             # We do the expiry of code at that point as buffers are in sync
+             # and all the workers will take the same decision.
+             self.expire_codes_(x)
+             ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+             embed_sum = x.t() @ embed_onehot
+             ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+             cluster_size = (
+                 laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+                 * self.cluster_size.sum()
+             )
+             embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+             self.embed.data.copy_(embed_normalized)
+
+         return quantize, embed_ind
+
+
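For reference, the EMA codebook update in forward() above is the standard exponential-moving-average k-means step. A minimal standalone sketch of the same arithmetic on a toy batch (decay and epsilon match the defaults above; the sizes and data are purely illustrative):

import torch
import torch.nn.functional as F

decay, epsilon = 0.99, 1e-5
codebook_size, dim = 4, 2
embed = torch.randn(codebook_size, dim)       # codebook vectors
embed_avg = embed.clone()                     # EMA of per-code vector sums
cluster_size = torch.zeros(codebook_size)     # EMA of per-code assignment counts

x = torch.randn(16, dim)                      # flattened batch of latents
onehot = F.one_hot(torch.cdist(x, embed).argmin(-1), codebook_size).float()

# ema_inplace(buf, new, decay) amounts to buf = decay * buf + (1 - decay) * new
cluster_size.mul_(decay).add_(onehot.sum(0), alpha=1 - decay)
embed_avg.mul_(decay).add_(onehot.t() @ x, alpha=1 - decay)

# Laplace smoothing keeps rarely used codes from dividing by ~zero
n = cluster_size.sum()
smoothed = (cluster_size + epsilon) / (n + codebook_size * epsilon) * n
embed = embed_avg / smoothed.unsqueeze(1)     # new codebook = EMA mean per cluster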
211
+ class VectorQuantization(nn.Module):
+     """Vector quantization implementation.
+     Currently supports only euclidean distance.
+     Args:
+         dim (int): Dimension.
+         codebook_size (int): Codebook size.
+         codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+         decay (float): Decay for the exponential moving average over the codebooks.
+         epsilon (float): Epsilon value for numerical stability.
+         kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+         kmeans_iters (int): Number of iterations used for kmeans initialization.
+         threshold_ema_dead_code (int): Threshold for dead-code expiration. Replace any code
+             whose exponential moving average cluster size is less than the specified threshold
+             with a randomly selected vector from the current batch.
+         commitment_weight (float): Weight for the commitment loss.
+     """
227
+
+     def __init__(
+         self,
+         dim: int,
+         codebook_size: int,
+         codebook_dim: tp.Optional[int] = None,
+         decay: float = 0.99,
+         epsilon: float = 1e-5,
+         kmeans_init: bool = True,
+         kmeans_iters: int = 50,
+         threshold_ema_dead_code: int = 2,
+         commitment_weight: float = 1.0,
+     ):
+         super().__init__()
+         _codebook_dim: int = default(codebook_dim, dim)
+
+         requires_projection = _codebook_dim != dim
+         self.project_in = (
+             nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
+         )
+         self.project_out = (
+             nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
+         )
+
+         self.epsilon = epsilon
+         self.commitment_weight = commitment_weight
+
+         self._codebook = EuclideanCodebook(
+             dim=_codebook_dim,
+             codebook_size=codebook_size,
+             kmeans_init=kmeans_init,
+             kmeans_iters=kmeans_iters,
+             decay=decay,
+             epsilon=epsilon,
+             threshold_ema_dead_code=threshold_ema_dead_code,
+         )
+         self.codebook_size = codebook_size
+
+     @property
+     def codebook(self):
+         return self._codebook.embed
+
269
+     def encode(self, x):
+         x = x.transpose(1, 2).contiguous()
+         x = self.project_in(x)
+         embed_in = self._codebook.encode(x)
+         return embed_in
+
+     def decode(self, embed_ind):
+         quantize = self._codebook.decode(embed_ind)
+         quantize = self.project_out(quantize)
+         quantize = quantize.transpose(1, 2).contiguous()
+         return quantize
+
+     def forward(self, x):
+         device = x.device
+         x = x.transpose(1, 2).contiguous()
+         x = self.project_in(x)
+
+         quantize, embed_ind = self._codebook(x)
+
+         if self.training:
+             # Straight-through estimator: the forward value is the quantized vector,
+             # but gradients flow to x as if quantization were the identity.
+             quantize = x + (quantize - x).detach()
+
+         loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+         if self.training:
+             warnings.warn(
+                 "When using RVQ in training model, first check "
+                 "https://github.com/facebookresearch/encodec/issues/25 . "
+                 "The bug wasn't fixed here for reproducibility."
+             )
+             if self.commitment_weight > 0:
+                 commit_loss = F.mse_loss(quantize.detach(), x)
+                 loss = loss + commit_loss * self.commitment_weight
+
+         quantize = self.project_out(quantize)
+         quantize = quantize.transpose(1, 2).contiguous()
+         return quantize, embed_ind, loss
+
+
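The x + (quantize - x).detach() line above is the usual straight-through estimator. A toy demonstration of where gradients flow (hypothetical values, purely illustrative):

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
quantize = torch.tensor([0.9, 2.2])      # stand-in for the codebook lookup

out = x + (quantize - x).detach()        # forward value equals quantize
out.sum().backward()
print(out)      # values equal quantize: 0.9, 2.2
print(x.grad)   # tensor([1., 1.]) -- gradient passes straight through to x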
309
+ class ResidualVectorQuantization(nn.Module):
+     """Residual vector quantization implementation.
+     Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+     """
+
+     def __init__(self, *, num_quantizers, **kwargs):
+         super().__init__()
+         self.layers = nn.ModuleList(
+             [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+         )
+
+     def forward(self, x, n_q: tp.Optional[int] = None):
+         quantized_out = 0.0
+         residual = x
+
+         all_losses = []
+         all_indices = []
+
+         n_q = n_q or len(self.layers)
+
+         for layer in self.layers[:n_q]:
+             quantized, indices, loss = layer(residual)
+             residual = residual - quantized
+             quantized_out = quantized_out + quantized
+
+             all_indices.append(indices)
+             all_losses.append(loss)
+
+         out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+         return quantized_out, out_indices, out_losses
+
+     def encode(
+         self, x: torch.Tensor, n_q: tp.Optional[int] = None
+     ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         residual = x
+
+         # Return quantized latents, both summed and at each quantizer level
+         z_O = 0.0  # Summed quantized latents
+         z_o = []   # Quantized latents at each quantizer level
+
+         all_indices = []
+         n_q = n_q or len(self.layers)
+         for layer in self.layers[:n_q]:
+             indices = layer.encode(residual)
+             quantized = layer.decode(indices)
+
+             z_o += [quantized]
+             z_O = z_O + quantized
+
+             residual = residual - quantized
+             all_indices.append(indices)
+
+         out_indices = torch.stack(all_indices)
+         z_o = torch.stack(z_o, dim=1)
+
+         return out_indices, z_O, z_o
+
+     def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+         quantized_out = torch.tensor(0.0, device=q_indices.device)
+         for i, indices in enumerate(q_indices):
+             layer = self.layers[i]
+             quantized = layer.decode(indices)
+             quantized_out = quantized_out + quantized
+         return quantized_out
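A minimal round trip through the residual stack, assuming this module is importable as encodec.quantization.core_vq and using illustrative sizes (kmeans_init=False so the codebooks are usable without a training batch):

import torch
from encodec.quantization.core_vq import ResidualVectorQuantization

rvq = ResidualVectorQuantization(
    num_quantizers=4, dim=128, codebook_size=2048, kmeans_init=False
)
rvq.eval()

z = torch.randn(2, 128, 50)                # (batch, dim, frames)
codes, z_sum, z_levels = rvq.encode(z)     # codes: (n_q, batch, frames)
recon = rvq.decode(codes)                  # (batch, dim, frames), equals z_sum
assert recon.shape == z.shape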
encodec/quantization/vq.py ADDED
@@ -0,0 +1,95 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ import math
+ import typing as tp
+
+ import torch
+ from torch import nn
+
+ from .core_vq import ResidualVectorQuantization
+
+ ################################################################################
+ # Residual quantization module
+ ################################################################################
+
+
+ class ResidualVectorQuantizer(nn.Module):
+     """Residual Vector Quantizer.
+     Args:
+         dimension (int): Dimension of the codebooks.
+         n_q (int): Number of residual vector quantizers used.
+         bins (int): Codebook size.
+         decay (float): Decay for the exponential moving average over the codebooks.
+         kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+         kmeans_iters (int): Number of iterations used for kmeans initialization.
+         threshold_ema_dead_code (int): Threshold for dead-code expiration. Replace any code
+             whose exponential moving average cluster size is less than the specified threshold
+             with a randomly selected vector from the current batch.
+     """
32
+
+     def __init__(
+         self,
+         dimension: int = 256,
+         n_q: int = 8,
+         bins: int = 1024,
+         decay: float = 0.99,
+         kmeans_init: bool = True,
+         kmeans_iters: int = 50,
+         threshold_ema_dead_code: int = 2,
+     ):
+         super().__init__()
+         self.n_q = n_q
+         self.dimension = dimension
+         self.bins = bins
+         self.decay = decay
+         self.kmeans_init = kmeans_init
+         self.kmeans_iters = kmeans_iters
+         self.threshold_ema_dead_code = threshold_ema_dead_code
+         self.vq = ResidualVectorQuantization(
+             dim=self.dimension,
+             codebook_size=self.bins,
+             num_quantizers=self.n_q,
+             decay=self.decay,
+             kmeans_init=self.kmeans_init,
+             kmeans_iters=self.kmeans_iters,
+             threshold_ema_dead_code=self.threshold_ema_dead_code,
+         )
60
+
+     def get_num_quantizers_for_bandwidth(
+         self, frame_rate: int, bandwidth: tp.Optional[float] = None
+     ) -> int:
+         """Return n_q based on the specified target bandwidth."""
+         bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
+         n_q = self.n_q
+         if bandwidth and bandwidth > 0.0:
+             # bandwidth is given in kbps (e.g. a 6 kbps target is bandwidth == 6.0),
+             # while bw_per_q is in bps, hence the factor of 1000.
+             n_q = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
+         return n_q
+
+     def get_bandwidth_per_quantizer(self, frame_rate: int):
+         """Return the bandwidth (bits per second) used by one quantizer at a given frame rate.
+         Each quantizer encodes a frame with log2(bins) bits.
+         """
+         return math.log2(self.bins) * frame_rate
+
+     def encode(
+         self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None
+     ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Encode a given input tensor with the specified frame rate at the given bandwidth.
+         The RVQ encode method sets the appropriate number of quantizers to use
+         and returns indices for each quantizer.
+         """
+         n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
+         codes, z_O, z_o = self.vq.encode(x, n_q=n_q)
+         return codes, z_O, z_o
+
+     def decode(self, codes: torch.Tensor) -> torch.Tensor:
+         """Decode the given codes to the quantized representation."""
+         quantized = self.vq.decode(codes)
+         return quantized
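As a worked example of the bandwidth arithmetic above (illustrative numbers: 2048 bins and a 50 Hz frame rate, i.e. 16 kHz audio with a total encoder stride of 320): each quantizer costs log2(2048) * 50 = 550 bps, so a 2.2 kbps target yields floor(2200 / 550) = 4 quantizers.

import math

bins, frame_rate, bandwidth = 2048, 50, 2.2   # bandwidth in kbps
bw_per_q = math.log2(bins) * frame_rate       # 11 bits/frame * 50 frames/s = 550 bps
n_q = max(1, math.floor(bandwidth * 1000 / bw_per_q))
print(n_q)                                    # 4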
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio
+ torch
+ audioseal
+ git+https://github.com/descriptinc/audiotools
+ pydantic==2.10.6