Spaces:
Running
on
Zero
Running
on
Zero
liuhuadai
commited on
Commit
·
052cf68
1
Parent(s):
70bc476
support cot
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- {think_sound → ThinkSound}/__init__.py +0 -0
- {think_sound/configs/model_configs/autoencoders → ThinkSound/configs/model_configs}/stable_audio_2_0_vae.json +0 -0
- think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json → ThinkSound/configs/model_configs/thinksound.json +1 -1
- ThinkSound/configs/multimodal_dataset_demo.json +53 -0
- {data_utils → ThinkSound/data}/__init__.py +0 -0
- {think_sound → ThinkSound}/data/datamodule.py +4 -2
- {think_sound → ThinkSound}/data/dataset.py +6 -8
- {think_sound → ThinkSound}/data/utils.py +0 -0
- {think_sound/data → ThinkSound/inference}/__init__.py +0 -0
- {think_sound → ThinkSound}/inference/generation.py +0 -0
- {think_sound → ThinkSound}/inference/sampling.py +0 -0
- {think_sound → ThinkSound}/inference/utils.py +0 -0
- {think_sound → ThinkSound}/models/__init__.py +0 -0
- {think_sound → ThinkSound}/models/autoencoders.py +0 -0
- {think_sound → ThinkSound}/models/blocks.py +92 -1
- {think_sound → ThinkSound}/models/bottleneck.py +0 -0
- {think_sound → ThinkSound}/models/codebook_patterns.py +0 -0
- {think_sound → ThinkSound}/models/conditioners.py +0 -1
- {think_sound → ThinkSound}/models/diffusion.py +1 -3
- {think_sound → ThinkSound}/models/dit.py +0 -0
- {think_sound/models/mmmodules/model → ThinkSound/models}/embeddings.py +36 -0
- {think_sound → ThinkSound}/models/factory.py +0 -0
- {think_sound → ThinkSound}/models/local_attention.py +0 -0
- {think_sound → ThinkSound}/models/mmdit.py +56 -9
- {think_sound → ThinkSound}/models/pretrained.py +0 -0
- {think_sound → ThinkSound}/models/pretransforms.py +0 -0
- {think_sound → ThinkSound}/models/transformer.py +0 -0
- {think_sound/models/mmmodules/model → ThinkSound/models}/transformer_layers.py +2 -2
- {think_sound → ThinkSound}/models/utils.py +0 -0
- {think_sound → ThinkSound}/training/__init__.py +0 -0
- {think_sound → ThinkSound}/training/autoencoders.py +0 -1
- {think_sound → ThinkSound}/training/diffusion.py +1 -948
- {think_sound → ThinkSound}/training/factory.py +0 -0
- {think_sound → ThinkSound}/training/losses/__init__.py +0 -0
- {think_sound → ThinkSound}/training/losses/auraloss.py +0 -0
- {think_sound → ThinkSound}/training/losses/losses.py +0 -0
- {think_sound → ThinkSound}/training/utils.py +0 -0
- app.py +50 -59
- cot_vgg_demo_caption.txt +1 -0
- data_utils/__pycache__/__init__.cpython-310.pyc +0 -0
- data_utils/__pycache__/utils.cpython-310.pyc +0 -0
- data_utils/__pycache__/utils.cpython-39.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc +0 -0
{think_sound → ThinkSound}/__init__.py
RENAMED
File without changes
|
{think_sound/configs/model_configs/autoencoders → ThinkSound/configs/model_configs}/stable_audio_2_0_vae.json
RENAMED
File without changes
|
think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json → ThinkSound/configs/model_configs/thinksound.json
RENAMED
@@ -85,7 +85,7 @@
|
|
85 |
"clip_dim":1024,
|
86 |
"sync_dim":768,
|
87 |
"text_dim":2048,
|
88 |
-
"hidden_dim":1024
|
89 |
"depth":21,
|
90 |
"fused_depth":14,
|
91 |
"num_heads":16,
|
|
|
85 |
"clip_dim":1024,
|
86 |
"sync_dim":768,
|
87 |
"text_dim":2048,
|
88 |
+
"hidden_dim":1024,
|
89 |
"depth":21,
|
90 |
"fused_depth":14,
|
91 |
"num_heads":16,
|
ThinkSound/configs/multimodal_dataset_demo.json
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_type": "multimodal_dir",
|
3 |
+
"video_datasets": [
|
4 |
+
{
|
5 |
+
"id": "vggsound",
|
6 |
+
"path": "dataset/vggsound/video_latents_t5_clip_npz/train",
|
7 |
+
"split_path": "dataset/vggsound/split_txt/train_cot.txt"
|
8 |
+
}
|
9 |
+
],
|
10 |
+
"audio_datasets": [
|
11 |
+
{
|
12 |
+
"id": "audiostock",
|
13 |
+
"path": "dataset/Laion-Audio-630k/audiostock_latents_npz",
|
14 |
+
"split_path": "dataset/Laion-Audio-630k/split_txt/cot_audiostock_1.txt"
|
15 |
+
},
|
16 |
+
{
|
17 |
+
"id": "freesound_no_overlap",
|
18 |
+
"path": "dataset/Laion-Audio-630k/freesound_no_overlap_latents_npz",
|
19 |
+
"split_path": "dataset/Laion-Audio-630k/split_txt/cot_freesound.txt"
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"id": "audioset_sl",
|
23 |
+
"path": "dataset/wavcaps/audioset_sl_latents_npz",
|
24 |
+
"split_path": "dataset/wavcaps/split_txt/cot_audio_sl_1.txt"
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"id": "audiocaps",
|
28 |
+
"path": "dataset/1_audiocaps/audiocaps_latents_npz",
|
29 |
+
"split_path": "dataset/1_audiocaps/split_txt/train_cot.txt"
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"id": "bbc",
|
33 |
+
"path": "dataset/Laion-Audio-630k/bbc_latents_npz",
|
34 |
+
"split_path": "dataset/Laion-Audio-630k/split_txt/cot_bbc_1.txt"
|
35 |
+
}
|
36 |
+
],
|
37 |
+
"val_datasets": [
|
38 |
+
{
|
39 |
+
"id": "vggsound",
|
40 |
+
"path": "dataset/vggsound/video_latents_t5_clip_npz/test",
|
41 |
+
"split_path": "dataset/vggsound/split_txt/test_cot.txt"
|
42 |
+
}
|
43 |
+
],
|
44 |
+
"test_datasets": [
|
45 |
+
{
|
46 |
+
"id": "vggsound",
|
47 |
+
"path": "cot_coarse",
|
48 |
+
"split_path": "cot_vgg_demo_caption.txt"
|
49 |
+
}
|
50 |
+
],
|
51 |
+
"random_crop": true,
|
52 |
+
"input_type": "prompt"
|
53 |
+
}
|
{data_utils → ThinkSound/data}/__init__.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/data/datamodule.py
RENAMED
@@ -33,13 +33,14 @@ def get_configs(audio_configs):
|
|
33 |
return configs
|
34 |
|
35 |
class DataModule(L.LightningDataModule):
|
36 |
-
def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4,repeat_num=5):
|
37 |
super().__init__()
|
38 |
dataset_type = dataset_config.get("dataset_type", None)
|
39 |
self.batch_size = batch_size
|
40 |
self.num_workers = num_workers
|
41 |
self.test_batch_size = test_batch_size
|
42 |
self.repeat_num = repeat_num
|
|
|
43 |
assert dataset_type is not None, "Dataset type must be specified in dataset config"
|
44 |
|
45 |
if audio_channels == 1:
|
@@ -140,7 +141,8 @@ class DataModule(L.LightningDataModule):
|
|
140 |
random_crop=random_crop,
|
141 |
input_type=self.input_type,
|
142 |
fps=self.input_type,
|
143 |
-
force_channels=self.force_channels
|
|
|
144 |
)
|
145 |
|
146 |
if stage == 'fit':
|
|
|
33 |
return configs
|
34 |
|
35 |
class DataModule(L.LightningDataModule):
|
36 |
+
def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4,repeat_num=5,latent_length=194):
|
37 |
super().__init__()
|
38 |
dataset_type = dataset_config.get("dataset_type", None)
|
39 |
self.batch_size = batch_size
|
40 |
self.num_workers = num_workers
|
41 |
self.test_batch_size = test_batch_size
|
42 |
self.repeat_num = repeat_num
|
43 |
+
self.latent_length = latent_length
|
44 |
assert dataset_type is not None, "Dataset type must be specified in dataset config"
|
45 |
|
46 |
if audio_channels == 1:
|
|
|
141 |
random_crop=random_crop,
|
142 |
input_type=self.input_type,
|
143 |
fps=self.input_type,
|
144 |
+
force_channels=self.force_channels,
|
145 |
+
latent_length=self.latent_length
|
146 |
)
|
147 |
|
148 |
if stage == 'fit':
|
{think_sound → ThinkSound}/data/dataset.py
RENAMED
@@ -342,8 +342,7 @@ class LatentDataset(torch.utils.data.Dataset):
|
|
342 |
info = {}
|
343 |
audio, video = self.load_file(audio_filename, info)
|
344 |
info["path"] = audio_filename
|
345 |
-
|
346 |
-
assert video.shape == (72,1024), f'{video.shape} input error, id: {id}'
|
347 |
info['id'] = Path(audio_filename).stem
|
348 |
for root_path in self.root_paths:
|
349 |
if root_path in audio_filename:
|
@@ -434,8 +433,7 @@ class AudioDataset(torch.utils.data.Dataset):
|
|
434 |
info = {}
|
435 |
audio, video = self.load_file(audio_filename, info)
|
436 |
info["path"] = audio_filename
|
437 |
-
|
438 |
-
assert video.shape == (72,1024), f'{video.shape} input error, id: {id}'
|
439 |
info['id'] = Path(audio_filename).stem
|
440 |
for root_path in self.root_paths:
|
441 |
if root_path in audio_filename:
|
@@ -454,8 +452,9 @@ class VideoDataset(torch.utils.data.Dataset):
|
|
454 |
input_type="prompt",
|
455 |
fps=4,
|
456 |
force_channels="stereo",
|
|
|
457 |
):
|
458 |
-
|
459 |
super().__init__()
|
460 |
self.filenames = []
|
461 |
print(f'configs: {configs[0]}')
|
@@ -523,7 +522,7 @@ class VideoDataset(torch.utils.data.Dataset):
|
|
523 |
if 'latent' in data.keys():
|
524 |
audio = data['latent']
|
525 |
else:
|
526 |
-
audio = torch.zeros(64,
|
527 |
info['video_exist'] = self.video_exist
|
528 |
# except:
|
529 |
# print(f'error load file: {filename}')
|
@@ -540,8 +539,7 @@ class VideoDataset(torch.utils.data.Dataset):
|
|
540 |
info = {}
|
541 |
audio, video = self.load_file(audio_filename, info)
|
542 |
info["path"] = audio_filename
|
543 |
-
|
544 |
-
assert video.shape == (72,1024), f'{video.shape} input error, id: {id}'
|
545 |
info['id'] = Path(audio_filename).stem
|
546 |
for root_path in self.root_paths:
|
547 |
if root_path in audio_filename:
|
|
|
342 |
info = {}
|
343 |
audio, video = self.load_file(audio_filename, info)
|
344 |
info["path"] = audio_filename
|
345 |
+
|
|
|
346 |
info['id'] = Path(audio_filename).stem
|
347 |
for root_path in self.root_paths:
|
348 |
if root_path in audio_filename:
|
|
|
433 |
info = {}
|
434 |
audio, video = self.load_file(audio_filename, info)
|
435 |
info["path"] = audio_filename
|
436 |
+
|
|
|
437 |
info['id'] = Path(audio_filename).stem
|
438 |
for root_path in self.root_paths:
|
439 |
if root_path in audio_filename:
|
|
|
452 |
input_type="prompt",
|
453 |
fps=4,
|
454 |
force_channels="stereo",
|
455 |
+
latent_length=194, # default latent length for video dataset
|
456 |
):
|
457 |
+
self.latent_length = latent_length
|
458 |
super().__init__()
|
459 |
self.filenames = []
|
460 |
print(f'configs: {configs[0]}')
|
|
|
522 |
if 'latent' in data.keys():
|
523 |
audio = data['latent']
|
524 |
else:
|
525 |
+
audio = torch.zeros(64,self.latent_length)
|
526 |
info['video_exist'] = self.video_exist
|
527 |
# except:
|
528 |
# print(f'error load file: {filename}')
|
|
|
539 |
info = {}
|
540 |
audio, video = self.load_file(audio_filename, info)
|
541 |
info["path"] = audio_filename
|
542 |
+
|
|
|
543 |
info['id'] = Path(audio_filename).stem
|
544 |
for root_path in self.root_paths:
|
545 |
if root_path in audio_filename:
|
{think_sound → ThinkSound}/data/utils.py
RENAMED
File without changes
|
{think_sound/data → ThinkSound/inference}/__init__.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/inference/generation.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/inference/sampling.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/inference/utils.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/models/__init__.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/models/autoencoders.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/models/blocks.py
RENAMED
@@ -336,4 +336,95 @@ class SnakeBeta(nn.Module):
|
|
336 |
beta = torch.exp(beta)
|
337 |
x = snake_beta(x, alpha, beta)
|
338 |
|
339 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
336 |
beta = torch.exp(beta)
|
337 |
x = snake_beta(x, alpha, beta)
|
338 |
|
339 |
+
return x
|
340 |
+
|
341 |
+
class ChannelLastConv1d(nn.Conv1d):
|
342 |
+
|
343 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
344 |
+
x = x.permute(0, 2, 1)
|
345 |
+
x = super().forward(x)
|
346 |
+
x = x.permute(0, 2, 1)
|
347 |
+
return x
|
348 |
+
|
349 |
+
|
350 |
+
# https://github.com/Stability-AI/sd3-ref
|
351 |
+
class MLP(nn.Module):
|
352 |
+
|
353 |
+
def __init__(
|
354 |
+
self,
|
355 |
+
dim: int,
|
356 |
+
hidden_dim: int,
|
357 |
+
multiple_of: int = 256,
|
358 |
+
):
|
359 |
+
"""
|
360 |
+
Initialize the FeedForward module.
|
361 |
+
|
362 |
+
Args:
|
363 |
+
dim (int): Input dimension.
|
364 |
+
hidden_dim (int): Hidden dimension of the feedforward layer.
|
365 |
+
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
366 |
+
|
367 |
+
Attributes:
|
368 |
+
w1 (ColumnParallelLinear): Linear transformation for the first layer.
|
369 |
+
w2 (RowParallelLinear): Linear transformation for the second layer.
|
370 |
+
w3 (ColumnParallelLinear): Linear transformation for the third layer.
|
371 |
+
|
372 |
+
"""
|
373 |
+
super().__init__()
|
374 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
375 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
376 |
+
|
377 |
+
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
378 |
+
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
379 |
+
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
380 |
+
|
381 |
+
def forward(self, x):
|
382 |
+
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
383 |
+
|
384 |
+
|
385 |
+
class ConvMLP(nn.Module):
|
386 |
+
|
387 |
+
def __init__(
|
388 |
+
self,
|
389 |
+
dim: int,
|
390 |
+
hidden_dim: int,
|
391 |
+
multiple_of: int = 256,
|
392 |
+
kernel_size: int = 3,
|
393 |
+
padding: int = 1,
|
394 |
+
):
|
395 |
+
"""
|
396 |
+
Initialize the FeedForward module.
|
397 |
+
|
398 |
+
Args:
|
399 |
+
dim (int): Input dimension.
|
400 |
+
hidden_dim (int): Hidden dimension of the feedforward layer.
|
401 |
+
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
402 |
+
|
403 |
+
Attributes:
|
404 |
+
w1 (ColumnParallelLinear): Linear transformation for the first layer.
|
405 |
+
w2 (RowParallelLinear): Linear transformation for the second layer.
|
406 |
+
w3 (ColumnParallelLinear): Linear transformation for the third layer.
|
407 |
+
|
408 |
+
"""
|
409 |
+
super().__init__()
|
410 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
411 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
412 |
+
|
413 |
+
self.w1 = ChannelLastConv1d(dim,
|
414 |
+
hidden_dim,
|
415 |
+
bias=False,
|
416 |
+
kernel_size=kernel_size,
|
417 |
+
padding=padding)
|
418 |
+
self.w2 = ChannelLastConv1d(hidden_dim,
|
419 |
+
dim,
|
420 |
+
bias=False,
|
421 |
+
kernel_size=kernel_size,
|
422 |
+
padding=padding)
|
423 |
+
self.w3 = ChannelLastConv1d(dim,
|
424 |
+
hidden_dim,
|
425 |
+
bias=False,
|
426 |
+
kernel_size=kernel_size,
|
427 |
+
padding=padding)
|
428 |
+
|
429 |
+
def forward(self, x):
|
430 |
+
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
{think_sound → ThinkSound}/models/bottleneck.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/models/codebook_patterns.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/models/conditioners.py
RENAMED
@@ -7,7 +7,6 @@ import typing as tp
|
|
7 |
import gc
|
8 |
from typing import Literal, Optional
|
9 |
import os
|
10 |
-
from .adp import NumberEmbedder
|
11 |
from ..inference.utils import set_audio_channels
|
12 |
from .factory import create_pretransform_from_config
|
13 |
from .pretransforms import Pretransform
|
|
|
7 |
import gc
|
8 |
from typing import Literal, Optional
|
9 |
import os
|
|
|
10 |
from ..inference.utils import set_audio_channels
|
11 |
from .factory import create_pretransform_from_config
|
12 |
from .pretransforms import Pretransform
|
{think_sound → ThinkSound}/models/diffusion.py
RENAMED
@@ -7,14 +7,12 @@ import typing as tp
|
|
7 |
|
8 |
from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
|
9 |
from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
|
10 |
-
from .dit import DiffusionTransformer
|
11 |
from .mmdit import MMAudio
|
12 |
from .factory import create_pretransform_from_config
|
13 |
from .pretransforms import Pretransform
|
14 |
from ..inference.generation import generate_diffusion_cond
|
15 |
|
16 |
-
from .adp import UNetCFG1d, UNet1d
|
17 |
-
|
18 |
from time import time
|
19 |
|
20 |
class Profiler:
|
|
|
7 |
|
8 |
from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
|
9 |
from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
|
10 |
+
# from .dit import DiffusionTransformer
|
11 |
from .mmdit import MMAudio
|
12 |
from .factory import create_pretransform_from_config
|
13 |
from .pretransforms import Pretransform
|
14 |
from ..inference.generation import generate_diffusion_cond
|
15 |
|
|
|
|
|
16 |
from time import time
|
17 |
|
18 |
class Profiler:
|
{think_sound → ThinkSound}/models/dit.py
RENAMED
File without changes
|
{think_sound/models/mmmodules/model → ThinkSound/models}/embeddings.py
RENAMED
@@ -3,6 +3,42 @@ import torch.nn as nn
|
|
3 |
|
4 |
# https://github.com/facebookresearch/DiT
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
class TimestepEmbedder(nn.Module):
|
8 |
"""
|
|
|
3 |
|
4 |
# https://github.com/facebookresearch/DiT
|
5 |
|
6 |
+
from typing import Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from einops import rearrange
|
10 |
+
from torch import Tensor
|
11 |
+
|
12 |
+
# Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
|
13 |
+
# Ref: https://github.com/lucidrains/rotary-embedding-torch
|
14 |
+
|
15 |
+
|
16 |
+
def compute_rope_rotations(length: int,
|
17 |
+
dim: int,
|
18 |
+
theta: int,
|
19 |
+
*,
|
20 |
+
freq_scaling: float = 1.0,
|
21 |
+
device: Union[torch.device, str] = 'cpu') -> Tensor:
|
22 |
+
assert dim % 2 == 0
|
23 |
+
|
24 |
+
with torch.amp.autocast(device_type='cuda', enabled=False):
|
25 |
+
pos = torch.arange(length, dtype=torch.float32, device=device)
|
26 |
+
freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
27 |
+
freqs *= freq_scaling
|
28 |
+
|
29 |
+
rot = torch.einsum('..., f -> ... f', pos, freqs)
|
30 |
+
rot = torch.stack([torch.cos(rot), -torch.sin(rot), torch.sin(rot), torch.cos(rot)], dim=-1)
|
31 |
+
rot = rearrange(rot, 'n d (i j) -> 1 n d i j', i=2, j=2)
|
32 |
+
return rot
|
33 |
+
|
34 |
+
|
35 |
+
def apply_rope(x: Tensor, rot: Tensor) -> tuple[Tensor, Tensor]:
|
36 |
+
with torch.amp.autocast(device_type='cuda', enabled=False):
|
37 |
+
_x = x.float()
|
38 |
+
_x = _x.view(*_x.shape[:-1], -1, 1, 2)
|
39 |
+
x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
|
40 |
+
return x_out.reshape(*x.shape).to(dtype=x.dtype)
|
41 |
+
|
42 |
|
43 |
class TimestepEmbedder(nn.Module):
|
44 |
"""
|
{think_sound → ThinkSound}/models/factory.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/models/local_attention.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/models/mmdit.py
RENAMED
@@ -6,10 +6,10 @@ import torch
|
|
6 |
import torch.nn as nn
|
7 |
import torch.nn.functional as F
|
8 |
import sys
|
9 |
-
from .
|
10 |
-
from .
|
11 |
-
from .
|
12 |
-
from .
|
13 |
from .utils import resample
|
14 |
|
15 |
log = logging.getLogger()
|
@@ -24,7 +24,6 @@ class PreprocessedConditions:
|
|
24 |
text_f_c: torch.Tensor
|
25 |
|
26 |
|
27 |
-
# Partially from https://github.com/facebookresearch/DiT
|
28 |
class MMAudio(nn.Module):
|
29 |
|
30 |
def __init__(self,
|
@@ -94,7 +93,6 @@ class MMAudio(nn.Module):
|
|
94 |
nn.Linear(hidden_dim * 4, hidden_dim, bias=False),
|
95 |
nn.Sigmoid()
|
96 |
)
|
97 |
-
# 初始化最后一层权重为零,促进初始均匀融合
|
98 |
nn.init.zeros_(self.gated_mlp_v[3].weight)
|
99 |
nn.init.zeros_(self.gated_mlp_t[3].weight)
|
100 |
if v2:
|
@@ -441,9 +439,9 @@ class MMAudio(nn.Module):
|
|
441 |
# clip_f = torch.cat([clip_f,empty_clip_f], dim=0)
|
442 |
# sync_f = torch.cat([sync_f,empty_sync_f], dim=0)
|
443 |
# text_f = torch.cat([text_f,empty_text_f], dim=0)
|
444 |
-
clip_f =
|
445 |
-
sync_f =
|
446 |
-
text_f =
|
447 |
if t5_features is not None:
|
448 |
empty_t5_features = torch.zeros_like(t5_features, device=latent.device)
|
449 |
# t5_features = torch.cat([t5_features,empty_t5_features], dim=0)
|
@@ -529,3 +527,52 @@ class MMAudio(nn.Module):
|
|
529 |
def sync_seq_len(self) -> int:
|
530 |
return self._sync_seq_len
|
531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import torch.nn as nn
|
7 |
import torch.nn.functional as F
|
8 |
import sys
|
9 |
+
from .embeddings import compute_rope_rotations
|
10 |
+
from .embeddings import TimestepEmbedder
|
11 |
+
from .blocks import MLP, ChannelLastConv1d, ConvMLP
|
12 |
+
from .transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock)
|
13 |
from .utils import resample
|
14 |
|
15 |
log = logging.getLogger()
|
|
|
24 |
text_f_c: torch.Tensor
|
25 |
|
26 |
|
|
|
27 |
class MMAudio(nn.Module):
|
28 |
|
29 |
def __init__(self,
|
|
|
93 |
nn.Linear(hidden_dim * 4, hidden_dim, bias=False),
|
94 |
nn.Sigmoid()
|
95 |
)
|
|
|
96 |
nn.init.zeros_(self.gated_mlp_v[3].weight)
|
97 |
nn.init.zeros_(self.gated_mlp_t[3].weight)
|
98 |
if v2:
|
|
|
439 |
# clip_f = torch.cat([clip_f,empty_clip_f], dim=0)
|
440 |
# sync_f = torch.cat([sync_f,empty_sync_f], dim=0)
|
441 |
# text_f = torch.cat([text_f,empty_text_f], dim=0)
|
442 |
+
clip_f = safe_cat(clip_f,self.get_empty_clip_sequence(bsz), dim=0, match_dim=1)
|
443 |
+
sync_f = safe_cat(sync_f,self.get_empty_sync_sequence(bsz), dim=0, match_dim=1)
|
444 |
+
text_f = safe_cat(text_f,self.get_empty_string_sequence(bsz), dim=0, match_dim=1)
|
445 |
if t5_features is not None:
|
446 |
empty_t5_features = torch.zeros_like(t5_features, device=latent.device)
|
447 |
# t5_features = torch.cat([t5_features,empty_t5_features], dim=0)
|
|
|
527 |
def sync_seq_len(self) -> int:
|
528 |
return self._sync_seq_len
|
529 |
|
530 |
+
|
531 |
+
|
532 |
+
|
533 |
+
|
534 |
+
|
535 |
+
|
536 |
+
|
537 |
+
|
538 |
+
|
539 |
+
|
540 |
+
|
541 |
+
|
542 |
+
|
543 |
+
|
544 |
+
|
545 |
+
|
546 |
+
def truncate_to_target(tensor, target_size, dim=1):
|
547 |
+
current_size = tensor.size(dim)
|
548 |
+
if current_size > target_size:
|
549 |
+
slices = [slice(None)] * tensor.dim()
|
550 |
+
slices[dim] = slice(0, target_size)
|
551 |
+
return tensor[slices]
|
552 |
+
return tensor
|
553 |
+
|
554 |
+
def pad_to_target(tensor, target_size, dim=1, pad_value=0):
|
555 |
+
current_size = tensor.size(dim)
|
556 |
+
if current_size < target_size:
|
557 |
+
pad_size = target_size - current_size
|
558 |
+
|
559 |
+
pad_config = [0, 0] * tensor.dim()
|
560 |
+
pad_index = 2 * (tensor.dim() - dim - 1) + 1
|
561 |
+
pad_config[pad_index] = pad_size
|
562 |
+
|
563 |
+
return torch.nn.functional.pad(tensor, pad_config, value=pad_value)
|
564 |
+
return tensor
|
565 |
+
|
566 |
+
|
567 |
+
def safe_cat(tensor1, tensor2, dim=0, match_dim=1):
|
568 |
+
|
569 |
+
target_size = tensor2.size(match_dim)
|
570 |
+
|
571 |
+
if tensor1.size(match_dim) > target_size:
|
572 |
+
tensor1 = truncate_to_target(tensor1, target_size, match_dim)
|
573 |
+
|
574 |
+
else:
|
575 |
+
tensor1 = pad_to_target(tensor1, target_size, match_dim)
|
576 |
+
|
577 |
+
return torch.cat([tensor1, tensor2], dim=dim)
|
578 |
+
|
{think_sound → ThinkSound}/models/pretrained.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/models/pretransforms.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/models/transformer.py
RENAMED
File without changes
|
{think_sound/models/mmmodules/model → ThinkSound/models}/transformer_layers.py
RENAMED
@@ -6,8 +6,8 @@ import torch.nn.functional as F
|
|
6 |
from einops import rearrange
|
7 |
from einops.layers.torch import Rearrange
|
8 |
|
9 |
-
from
|
10 |
-
from
|
11 |
try:
|
12 |
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
|
13 |
print('flash_attn installed, using Flash Attention')
|
|
|
6 |
from einops import rearrange
|
7 |
from einops.layers.torch import Rearrange
|
8 |
|
9 |
+
from .embeddings import apply_rope
|
10 |
+
from .blocks import MLP, ChannelLastConv1d, ConvMLP
|
11 |
try:
|
12 |
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
|
13 |
print('flash_attn installed, using Flash Attention')
|
{think_sound → ThinkSound}/models/utils.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/training/__init__.py
RENAMED
File without changes
|
{think_sound → ThinkSound}/training/autoencoders.py
RENAMED
@@ -9,7 +9,6 @@ from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss,
|
|
9 |
import lightning as L
|
10 |
from lightning.pytorch.callbacks import Callback
|
11 |
from ..models.autoencoders import AudioAutoencoder
|
12 |
-
from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss
|
13 |
from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck
|
14 |
from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss
|
15 |
from .utils import create_optimizer_from_config, create_scheduler_from_config
|
|
|
9 |
import lightning as L
|
10 |
from lightning.pytorch.callbacks import Callback
|
11 |
from ..models.autoencoders import AudioAutoencoder
|
|
|
12 |
from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck
|
13 |
from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss
|
14 |
from .utils import create_optimizer_from_config, create_scheduler_from_config
|
{think_sound → ThinkSound}/training/diffusion.py
RENAMED
@@ -20,7 +20,6 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
|
20 |
from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
|
21 |
from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper
|
22 |
from ..models.autoencoders import DiffusionAutoencoder
|
23 |
-
from ..models.diffusion_prior import PriorType
|
24 |
from .autoencoders import create_loss_modules_from_bottleneck
|
25 |
from .losses import AuralossLoss, MSELoss, MultiLoss
|
26 |
from .utils import create_optimizer_from_config, create_scheduler_from_config, mask_from_frac_lengths, generate_mask, generate_channel_mask
|
@@ -846,10 +845,9 @@ class DiffusionCondTrainingWrapper(L.LightningModule):
|
|
846 |
|
847 |
def predict_step(self, batch, batch_idx):
|
848 |
reals, metadata = batch
|
849 |
-
# import ipdb
|
850 |
-
# ipdb.set_trace()
|
851 |
ids = [item['id'] for item in metadata]
|
852 |
batch_size, length = reals.shape[0], reals.shape[2]
|
|
|
853 |
with torch.amp.autocast('cuda'):
|
854 |
conditioning = self.diffusion.conditioner(metadata, self.device)
|
855 |
|
@@ -878,7 +876,6 @@ class DiffusionCondTrainingWrapper(L.LightningModule):
|
|
878 |
end_time = time.time()
|
879 |
execution_time = end_time - start_time
|
880 |
print(f"执行时间: {execution_time:.2f} 秒")
|
881 |
-
breakpoint()
|
882 |
if self.diffusion.pretransform is not None:
|
883 |
fakes = self.diffusion.pretransform.decode(fakes)
|
884 |
|
@@ -1077,947 +1074,3 @@ class DiffusionCondDemoCallback(Callback):
|
|
1077 |
gc.collect()
|
1078 |
torch.cuda.empty_cache()
|
1079 |
module.train()
|
1080 |
-
|
1081 |
-
class DiffusionCondInpaintTrainingWrapper(L.LightningModule):
|
1082 |
-
'''
|
1083 |
-
Wrapper for training a conditional audio diffusion model.
|
1084 |
-
'''
|
1085 |
-
def __init__(
|
1086 |
-
self,
|
1087 |
-
model: ConditionedDiffusionModelWrapper,
|
1088 |
-
lr: float = 1e-4,
|
1089 |
-
max_mask_segments = 10,
|
1090 |
-
log_loss_info: bool = False,
|
1091 |
-
optimizer_configs: dict = None,
|
1092 |
-
use_ema: bool = True,
|
1093 |
-
pre_encoded: bool = False,
|
1094 |
-
cfg_dropout_prob = 0.1,
|
1095 |
-
timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform",
|
1096 |
-
):
|
1097 |
-
super().__init__()
|
1098 |
-
|
1099 |
-
self.diffusion = model
|
1100 |
-
|
1101 |
-
self.use_ema = use_ema
|
1102 |
-
|
1103 |
-
if self.use_ema:
|
1104 |
-
self.diffusion_ema = EMA(
|
1105 |
-
self.diffusion.model,
|
1106 |
-
beta=0.9999,
|
1107 |
-
power=3/4,
|
1108 |
-
update_every=1,
|
1109 |
-
update_after_step=1,
|
1110 |
-
include_online_model=False
|
1111 |
-
)
|
1112 |
-
else:
|
1113 |
-
self.diffusion_ema = None
|
1114 |
-
|
1115 |
-
self.cfg_dropout_prob = cfg_dropout_prob
|
1116 |
-
|
1117 |
-
self.lr = lr
|
1118 |
-
self.max_mask_segments = max_mask_segments
|
1119 |
-
|
1120 |
-
self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
|
1121 |
-
|
1122 |
-
self.timestep_sampler = timestep_sampler
|
1123 |
-
|
1124 |
-
self.diffusion_objective = model.diffusion_objective
|
1125 |
-
|
1126 |
-
self.loss_modules = [
|
1127 |
-
MSELoss("output",
|
1128 |
-
"targets",
|
1129 |
-
weight=1.0,
|
1130 |
-
name="mse_loss"
|
1131 |
-
)
|
1132 |
-
]
|
1133 |
-
|
1134 |
-
self.losses = MultiLoss(self.loss_modules)
|
1135 |
-
|
1136 |
-
self.log_loss_info = log_loss_info
|
1137 |
-
|
1138 |
-
assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
|
1139 |
-
|
1140 |
-
if optimizer_configs is None:
|
1141 |
-
optimizer_configs = {
|
1142 |
-
"diffusion": {
|
1143 |
-
"optimizer": {
|
1144 |
-
"type": "Adam",
|
1145 |
-
"config": {
|
1146 |
-
"lr": lr
|
1147 |
-
}
|
1148 |
-
}
|
1149 |
-
}
|
1150 |
-
}
|
1151 |
-
else:
|
1152 |
-
if lr is not None:
|
1153 |
-
print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
|
1154 |
-
|
1155 |
-
self.optimizer_configs = optimizer_configs
|
1156 |
-
|
1157 |
-
self.pre_encoded = pre_encoded
|
1158 |
-
|
1159 |
-
def configure_optimizers(self):
|
1160 |
-
diffusion_opt_config = self.optimizer_configs['diffusion']
|
1161 |
-
opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters())
|
1162 |
-
|
1163 |
-
if "scheduler" in diffusion_opt_config:
|
1164 |
-
sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff)
|
1165 |
-
sched_diff_config = {
|
1166 |
-
"scheduler": sched_diff,
|
1167 |
-
"interval": "step"
|
1168 |
-
}
|
1169 |
-
return [opt_diff], [sched_diff_config]
|
1170 |
-
|
1171 |
-
return [opt_diff]
|
1172 |
-
|
1173 |
-
def random_mask(self, sequence, max_mask_length):
|
1174 |
-
b, _, sequence_length = sequence.size()
|
1175 |
-
|
1176 |
-
# Create a mask tensor for each batch element
|
1177 |
-
masks = []
|
1178 |
-
|
1179 |
-
for i in range(b):
|
1180 |
-
mask_type = random.randint(0, 2)
|
1181 |
-
|
1182 |
-
if mask_type == 0: # Random mask with multiple segments
|
1183 |
-
num_segments = random.randint(1, self.max_mask_segments)
|
1184 |
-
max_segment_length = max_mask_length // num_segments
|
1185 |
-
|
1186 |
-
segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments)
|
1187 |
-
|
1188 |
-
mask = torch.ones((1, 1, sequence_length))
|
1189 |
-
for length in segment_lengths:
|
1190 |
-
mask_start = random.randint(0, sequence_length - length)
|
1191 |
-
mask[:, :, mask_start:mask_start + length] = 0
|
1192 |
-
|
1193 |
-
elif mask_type == 1: # Full mask
|
1194 |
-
mask = torch.zeros((1, 1, sequence_length))
|
1195 |
-
|
1196 |
-
elif mask_type == 2: # Causal mask
|
1197 |
-
mask = torch.ones((1, 1, sequence_length))
|
1198 |
-
mask_length = random.randint(1, max_mask_length)
|
1199 |
-
mask[:, :, -mask_length:] = 0
|
1200 |
-
|
1201 |
-
mask = mask.to(sequence.device)
|
1202 |
-
masks.append(mask)
|
1203 |
-
|
1204 |
-
# Concatenate the mask tensors into a single tensor
|
1205 |
-
mask = torch.cat(masks, dim=0).to(sequence.device)
|
1206 |
-
|
1207 |
-
# Apply the mask to the sequence tensor for each batch element
|
1208 |
-
masked_sequence = sequence * mask
|
1209 |
-
|
1210 |
-
return masked_sequence, mask
|
1211 |
-
|
1212 |
-
def training_step(self, batch, batch_idx):
|
1213 |
-
reals, metadata = batch
|
1214 |
-
|
1215 |
-
p = Profiler()
|
1216 |
-
|
1217 |
-
if reals.ndim == 4 and reals.shape[0] == 1:
|
1218 |
-
reals = reals[0]
|
1219 |
-
|
1220 |
-
loss_info = {}
|
1221 |
-
|
1222 |
-
diffusion_input = reals
|
1223 |
-
|
1224 |
-
if not self.pre_encoded:
|
1225 |
-
loss_info["audio_reals"] = diffusion_input
|
1226 |
-
|
1227 |
-
p.tick("setup")
|
1228 |
-
|
1229 |
-
with torch.amp.autocast('cuda'):
|
1230 |
-
conditioning = self.diffusion.conditioner(metadata, self.device)
|
1231 |
-
|
1232 |
-
p.tick("conditioning")
|
1233 |
-
|
1234 |
-
if self.diffusion.pretransform is not None:
|
1235 |
-
self.diffusion.pretransform.to(self.device)
|
1236 |
-
|
1237 |
-
if not self.pre_encoded:
|
1238 |
-
with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
|
1239 |
-
diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
|
1240 |
-
p.tick("pretransform")
|
1241 |
-
|
1242 |
-
# If mask_padding is on, interpolate the padding masks to the size of the pretransformed input
|
1243 |
-
# if use_padding_mask:
|
1244 |
-
# padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool()
|
1245 |
-
else:
|
1246 |
-
# Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
|
1247 |
-
if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
|
1248 |
-
diffusion_input = diffusion_input / self.diffusion.pretransform.scale
|
1249 |
-
|
1250 |
-
# Max mask size is the full sequence length
|
1251 |
-
max_mask_length = diffusion_input.shape[2]
|
1252 |
-
|
1253 |
-
# Create a mask of random length for a random slice of the input
|
1254 |
-
masked_input, mask = self.random_mask(diffusion_input, max_mask_length)
|
1255 |
-
|
1256 |
-
# conditioning['inpaint_mask'] = [mask]
|
1257 |
-
conditioning['inpaint_masked_input'] = [masked_input]
|
1258 |
-
|
1259 |
-
if self.timestep_sampler == "uniform":
|
1260 |
-
# Draw uniformly distributed continuous timesteps
|
1261 |
-
t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
|
1262 |
-
elif self.timestep_sampler == "logit_normal":
|
1263 |
-
t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
|
1264 |
-
|
1265 |
-
# Calculate the noise schedule parameters for those timesteps
|
1266 |
-
if self.diffusion_objective == "v":
|
1267 |
-
alphas, sigmas = get_alphas_sigmas(t)
|
1268 |
-
elif self.diffusion_objective == "rectified_flow":
|
1269 |
-
alphas, sigmas = 1-t, t
|
1270 |
-
|
1271 |
-
# Combine the ground truth data and the noise
|
1272 |
-
alphas = alphas[:, None, None]
|
1273 |
-
sigmas = sigmas[:, None, None]
|
1274 |
-
noise = torch.randn_like(diffusion_input)
|
1275 |
-
noised_inputs = diffusion_input * alphas + noise * sigmas
|
1276 |
-
|
1277 |
-
if self.diffusion_objective == "v":
|
1278 |
-
targets = noise * alphas - diffusion_input * sigmas
|
1279 |
-
elif self.diffusion_objective == "rectified_flow":
|
1280 |
-
targets = noise - diffusion_input
|
1281 |
-
|
1282 |
-
p.tick("noise")
|
1283 |
-
|
1284 |
-
extra_args = {}
|
1285 |
-
|
1286 |
-
with torch.amp.autocast('cuda'):
|
1287 |
-
p.tick("amp")
|
1288 |
-
output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args)
|
1289 |
-
p.tick("diffusion")
|
1290 |
-
|
1291 |
-
loss_info.update({
|
1292 |
-
"output": output,
|
1293 |
-
"targets": targets,
|
1294 |
-
})
|
1295 |
-
|
1296 |
-
loss, losses = self.losses(loss_info)
|
1297 |
-
|
1298 |
-
if self.log_loss_info:
|
1299 |
-
# Loss debugging logs
|
1300 |
-
num_loss_buckets = 10
|
1301 |
-
bucket_size = 1 / num_loss_buckets
|
1302 |
-
loss_all = F.mse_loss(output, targets, reduction="none")
|
1303 |
-
|
1304 |
-
sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
|
1305 |
-
|
1306 |
-
# gather loss_all across all GPUs
|
1307 |
-
loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
|
1308 |
-
|
1309 |
-
# Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
|
1310 |
-
loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
|
1311 |
-
|
1312 |
-
# Log bucketed losses with corresponding sigma bucket values, if it's not NaN
|
1313 |
-
debug_log_dict = {
|
1314 |
-
f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
|
1315 |
-
}
|
1316 |
-
|
1317 |
-
self.log_dict(debug_log_dict)
|
1318 |
-
|
1319 |
-
log_dict = {
|
1320 |
-
'train/loss': loss.detach(),
|
1321 |
-
'train/std_data': diffusion_input.std(),
|
1322 |
-
'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
|
1323 |
-
}
|
1324 |
-
|
1325 |
-
for loss_name, loss_value in losses.items():
|
1326 |
-
log_dict[f"train/{loss_name}"] = loss_value.detach()
|
1327 |
-
|
1328 |
-
self.log_dict(log_dict, prog_bar=True, on_step=True)
|
1329 |
-
p.tick("log")
|
1330 |
-
#print(f"Profiler: {p}")
|
1331 |
-
return loss
|
1332 |
-
|
1333 |
-
def on_before_zero_grad(self, *args, **kwargs):
|
1334 |
-
if self.diffusion_ema is not None:
|
1335 |
-
self.diffusion_ema.update()
|
1336 |
-
|
1337 |
-
def export_model(self, path, use_safetensors=False):
|
1338 |
-
if self.diffusion_ema is not None:
|
1339 |
-
self.diffusion.model = self.diffusion_ema.ema_model
|
1340 |
-
|
1341 |
-
if use_safetensors:
|
1342 |
-
save_file(self.diffusion.state_dict(), path)
|
1343 |
-
else:
|
1344 |
-
torch.save({"state_dict": self.diffusion.state_dict()}, path)
|
1345 |
-
|
1346 |
-
class DiffusionCondInpaintDemoCallback(Callback):
|
1347 |
-
def __init__(
|
1348 |
-
self,
|
1349 |
-
demo_dl,
|
1350 |
-
demo_every=2000,
|
1351 |
-
demo_steps=250,
|
1352 |
-
sample_size=65536,
|
1353 |
-
sample_rate=48000,
|
1354 |
-
demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7]
|
1355 |
-
):
|
1356 |
-
super().__init__()
|
1357 |
-
self.demo_every = demo_every
|
1358 |
-
self.demo_steps = demo_steps
|
1359 |
-
self.demo_samples = sample_size
|
1360 |
-
self.demo_dl = iter(demo_dl)
|
1361 |
-
self.sample_rate = sample_rate
|
1362 |
-
self.demo_cfg_scales = demo_cfg_scales
|
1363 |
-
self.last_demo_step = -1
|
1364 |
-
|
1365 |
-
@rank_zero_only
|
1366 |
-
@torch.no_grad()
|
1367 |
-
def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx):
|
1368 |
-
if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
|
1369 |
-
return
|
1370 |
-
|
1371 |
-
self.last_demo_step = trainer.global_step
|
1372 |
-
|
1373 |
-
try:
|
1374 |
-
log_dict = {}
|
1375 |
-
|
1376 |
-
demo_reals, metadata = next(self.demo_dl)
|
1377 |
-
|
1378 |
-
# Remove extra dimension added by WebDataset
|
1379 |
-
if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
|
1380 |
-
demo_reals = demo_reals[0]
|
1381 |
-
|
1382 |
-
demo_reals = demo_reals.to(module.device)
|
1383 |
-
|
1384 |
-
if not module.pre_encoded:
|
1385 |
-
# Log the real audio
|
1386 |
-
log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu()))
|
1387 |
-
# log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals")
|
1388 |
-
|
1389 |
-
if module.diffusion.pretransform is not None:
|
1390 |
-
module.diffusion.pretransform.to(module.device)
|
1391 |
-
with torch.amp.autocast('cuda'):
|
1392 |
-
demo_reals = module.diffusion.pretransform.encode(demo_reals)
|
1393 |
-
|
1394 |
-
demo_samples = demo_reals.shape[2]
|
1395 |
-
|
1396 |
-
# Get conditioning
|
1397 |
-
conditioning = module.diffusion.conditioner(metadata, module.device)
|
1398 |
-
|
1399 |
-
masked_input, mask = module.random_mask(demo_reals, demo_reals.shape[2])
|
1400 |
-
|
1401 |
-
conditioning['inpaint_mask'] = [mask]
|
1402 |
-
conditioning['inpaint_masked_input'] = [masked_input]
|
1403 |
-
|
1404 |
-
if module.diffusion.pretransform is not None:
|
1405 |
-
log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(masked_input.cpu()))
|
1406 |
-
else:
|
1407 |
-
log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu()))
|
1408 |
-
|
1409 |
-
cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)
|
1410 |
-
|
1411 |
-
noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device)
|
1412 |
-
|
1413 |
-
trainer.logger.experiment.log(log_dict)
|
1414 |
-
|
1415 |
-
for cfg_scale in self.demo_cfg_scales:
|
1416 |
-
model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model
|
1417 |
-
print(f"Generating demo for cfg scale {cfg_scale}")
|
1418 |
-
|
1419 |
-
if module.diffusion_objective == "v":
|
1420 |
-
fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
|
1421 |
-
elif module.diffusion_objective == "rectified_flow":
|
1422 |
-
fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
|
1423 |
-
|
1424 |
-
if module.diffusion.pretransform is not None:
|
1425 |
-
with torch.amp.autocast('cuda'):
|
1426 |
-
fakes = module.diffusion.pretransform.decode(fakes)
|
1427 |
-
|
1428 |
-
# Put the demos together
|
1429 |
-
fakes = rearrange(fakes, 'b d n -> d (b n)')
|
1430 |
-
|
1431 |
-
log_dict = {}
|
1432 |
-
|
1433 |
-
filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
|
1434 |
-
fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
|
1435 |
-
torchaudio.save(filename, fakes, self.sample_rate)
|
1436 |
-
|
1437 |
-
log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
|
1438 |
-
sample_rate=self.sample_rate,
|
1439 |
-
caption=f'Reconstructed')
|
1440 |
-
|
1441 |
-
log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
|
1442 |
-
|
1443 |
-
trainer.logger.experiment.log(log_dict)
|
1444 |
-
except Exception as e:
|
1445 |
-
print(f'{type(e).__name__}: {e}')
|
1446 |
-
raise e
|
1447 |
-
|
1448 |
-
class DiffusionAutoencoderTrainingWrapper(L.LightningModule):
|
1449 |
-
'''
|
1450 |
-
Wrapper for training a diffusion autoencoder
|
1451 |
-
'''
|
1452 |
-
def __init__(
|
1453 |
-
self,
|
1454 |
-
model: DiffusionAutoencoder,
|
1455 |
-
lr: float = 1e-4,
|
1456 |
-
ema_copy = None,
|
1457 |
-
use_reconstruction_loss: bool = False
|
1458 |
-
):
|
1459 |
-
super().__init__()
|
1460 |
-
|
1461 |
-
self.diffae = model
|
1462 |
-
|
1463 |
-
self.diffae_ema = EMA(
|
1464 |
-
self.diffae,
|
1465 |
-
ema_model=ema_copy,
|
1466 |
-
beta=0.9999,
|
1467 |
-
power=3/4,
|
1468 |
-
update_every=1,
|
1469 |
-
update_after_step=1,
|
1470 |
-
include_online_model=False
|
1471 |
-
)
|
1472 |
-
|
1473 |
-
self.lr = lr
|
1474 |
-
|
1475 |
-
self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
|
1476 |
-
|
1477 |
-
loss_modules = [
|
1478 |
-
MSELoss("v",
|
1479 |
-
"targets",
|
1480 |
-
weight=1.0,
|
1481 |
-
name="mse_loss"
|
1482 |
-
)
|
1483 |
-
]
|
1484 |
-
|
1485 |
-
if model.bottleneck is not None:
|
1486 |
-
# TODO: Use loss config for configurable bottleneck weights and reconstruction losses
|
1487 |
-
loss_modules += create_loss_modules_from_bottleneck(model.bottleneck, {})
|
1488 |
-
|
1489 |
-
self.use_reconstruction_loss = use_reconstruction_loss
|
1490 |
-
|
1491 |
-
if use_reconstruction_loss:
|
1492 |
-
scales = [2048, 1024, 512, 256, 128, 64, 32]
|
1493 |
-
hop_sizes = []
|
1494 |
-
win_lengths = []
|
1495 |
-
overlap = 0.75
|
1496 |
-
for s in scales:
|
1497 |
-
hop_sizes.append(int(s * (1 - overlap)))
|
1498 |
-
win_lengths.append(s)
|
1499 |
-
|
1500 |
-
sample_rate = model.sample_rate
|
1501 |
-
|
1502 |
-
stft_loss_args = {
|
1503 |
-
"fft_sizes": scales,
|
1504 |
-
"hop_sizes": hop_sizes,
|
1505 |
-
"win_lengths": win_lengths,
|
1506 |
-
"perceptual_weighting": True
|
1507 |
-
}
|
1508 |
-
|
1509 |
-
out_channels = model.out_channels
|
1510 |
-
|
1511 |
-
if model.pretransform is not None:
|
1512 |
-
out_channels = model.pretransform.io_channels
|
1513 |
-
|
1514 |
-
if out_channels == 2:
|
1515 |
-
self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
|
1516 |
-
else:
|
1517 |
-
self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
|
1518 |
-
|
1519 |
-
loss_modules.append(
|
1520 |
-
AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss
|
1521 |
-
)
|
1522 |
-
|
1523 |
-
self.losses = MultiLoss(loss_modules)
|
1524 |
-
|
1525 |
-
def configure_optimizers(self):
|
1526 |
-
return optim.Adam([*self.diffae.parameters()], lr=self.lr)
|
1527 |
-
|
1528 |
-
def training_step(self, batch, batch_idx):
|
1529 |
-
reals = batch[0]
|
1530 |
-
|
1531 |
-
if reals.ndim == 4 and reals.shape[0] == 1:
|
1532 |
-
reals = reals[0]
|
1533 |
-
|
1534 |
-
loss_info = {}
|
1535 |
-
|
1536 |
-
loss_info["audio_reals"] = reals
|
1537 |
-
|
1538 |
-
if self.diffae.pretransform is not None:
|
1539 |
-
with torch.no_grad():
|
1540 |
-
reals = self.diffae.pretransform.encode(reals)
|
1541 |
-
|
1542 |
-
loss_info["reals"] = reals
|
1543 |
-
|
1544 |
-
#Encode reals, skipping the pretransform since it was already applied
|
1545 |
-
latents, encoder_info = self.diffae.encode(reals, return_info=True, skip_pretransform=True)
|
1546 |
-
|
1547 |
-
loss_info["latents"] = latents
|
1548 |
-
loss_info.update(encoder_info)
|
1549 |
-
|
1550 |
-
if self.diffae.decoder is not None:
|
1551 |
-
latents = self.diffae.decoder(latents)
|
1552 |
-
|
1553 |
-
# Upsample latents to match diffusion length
|
1554 |
-
if latents.shape[2] != reals.shape[2]:
|
1555 |
-
latents = F.interpolate(latents, size=reals.shape[2], mode='nearest')
|
1556 |
-
|
1557 |
-
loss_info["latents_upsampled"] = latents
|
1558 |
-
|
1559 |
-
# Draw uniformly distributed continuous timesteps
|
1560 |
-
t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
|
1561 |
-
|
1562 |
-
# Calculate the noise schedule parameters for those timesteps
|
1563 |
-
alphas, sigmas = get_alphas_sigmas(t)
|
1564 |
-
|
1565 |
-
# Combine the ground truth data and the noise
|
1566 |
-
alphas = alphas[:, None, None]
|
1567 |
-
sigmas = sigmas[:, None, None]
|
1568 |
-
noise = torch.randn_like(reals)
|
1569 |
-
noised_reals = reals * alphas + noise * sigmas
|
1570 |
-
targets = noise * alphas - reals * sigmas
|
1571 |
-
|
1572 |
-
with torch.amp.autocast('cuda'):
|
1573 |
-
v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents)
|
1574 |
-
|
1575 |
-
loss_info.update({
|
1576 |
-
"v": v,
|
1577 |
-
"targets": targets
|
1578 |
-
})
|
1579 |
-
|
1580 |
-
if self.use_reconstruction_loss:
|
1581 |
-
pred = noised_reals * alphas - v * sigmas
|
1582 |
-
|
1583 |
-
loss_info["pred"] = pred
|
1584 |
-
|
1585 |
-
if self.diffae.pretransform is not None:
|
1586 |
-
pred = self.diffae.pretransform.decode(pred)
|
1587 |
-
loss_info["audio_pred"] = pred
|
1588 |
-
|
1589 |
-
loss, losses = self.losses(loss_info)
|
1590 |
-
|
1591 |
-
log_dict = {
|
1592 |
-
'train/loss': loss.detach(),
|
1593 |
-
'train/std_data': reals.std(),
|
1594 |
-
'train/latent_std': latents.std(),
|
1595 |
-
}
|
1596 |
-
|
1597 |
-
for loss_name, loss_value in losses.items():
|
1598 |
-
log_dict[f"train/{loss_name}"] = loss_value.detach()
|
1599 |
-
|
1600 |
-
self.log_dict(log_dict, prog_bar=True, on_step=True)
|
1601 |
-
return loss
|
1602 |
-
|
1603 |
-
def on_before_zero_grad(self, *args, **kwargs):
|
1604 |
-
self.diffae_ema.update()
|
1605 |
-
|
1606 |
-
def export_model(self, path, use_safetensors=False):
|
1607 |
-
|
1608 |
-
model = self.diffae_ema.ema_model
|
1609 |
-
|
1610 |
-
if use_safetensors:
|
1611 |
-
save_file(model.state_dict(), path)
|
1612 |
-
else:
|
1613 |
-
torch.save({"state_dict": model.state_dict()}, path)
|
1614 |
-
|
1615 |
-
class DiffusionAutoencoderDemoCallback(Callback):
|
1616 |
-
def __init__(
|
1617 |
-
self,
|
1618 |
-
demo_dl,
|
1619 |
-
demo_every=2000,
|
1620 |
-
demo_steps=250,
|
1621 |
-
sample_size=65536,
|
1622 |
-
sample_rate=48000
|
1623 |
-
):
|
1624 |
-
super().__init__()
|
1625 |
-
self.demo_every = demo_every
|
1626 |
-
self.demo_steps = demo_steps
|
1627 |
-
self.demo_samples = sample_size
|
1628 |
-
self.demo_dl = iter(demo_dl)
|
1629 |
-
self.sample_rate = sample_rate
|
1630 |
-
self.last_demo_step = -1
|
1631 |
-
|
1632 |
-
@rank_zero_only
|
1633 |
-
@torch.no_grad()
|
1634 |
-
def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx):
|
1635 |
-
if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
|
1636 |
-
return
|
1637 |
-
|
1638 |
-
self.last_demo_step = trainer.global_step
|
1639 |
-
|
1640 |
-
demo_reals, _ = next(self.demo_dl)
|
1641 |
-
|
1642 |
-
# Remove extra dimension added by WebDataset
|
1643 |
-
if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
|
1644 |
-
demo_reals = demo_reals[0]
|
1645 |
-
|
1646 |
-
encoder_input = demo_reals
|
1647 |
-
|
1648 |
-
encoder_input = encoder_input.to(module.device)
|
1649 |
-
|
1650 |
-
demo_reals = demo_reals.to(module.device)
|
1651 |
-
|
1652 |
-
with torch.no_grad() and torch.amp.autocast('cuda'):
|
1653 |
-
latents = module.diffae_ema.ema_model.encode(encoder_input).float()
|
1654 |
-
fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps)
|
1655 |
-
|
1656 |
-
#Interleave reals and fakes
|
1657 |
-
reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
|
1658 |
-
|
1659 |
-
# Put the demos together
|
1660 |
-
reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
|
1661 |
-
|
1662 |
-
log_dict = {}
|
1663 |
-
|
1664 |
-
filename = f'recon_{trainer.global_step:08}.wav'
|
1665 |
-
reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu()
|
1666 |
-
torchaudio.save(filename, reals_fakes, self.sample_rate)
|
1667 |
-
|
1668 |
-
log_dict[f'recon'] = wandb.Audio(filename,
|
1669 |
-
sample_rate=self.sample_rate,
|
1670 |
-
caption=f'Reconstructed')
|
1671 |
-
|
1672 |
-
log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents)
|
1673 |
-
log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents))
|
1674 |
-
|
1675 |
-
log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
|
1676 |
-
|
1677 |
-
if module.diffae_ema.ema_model.pretransform is not None:
|
1678 |
-
with torch.no_grad() and torch.amp.autocast('cuda'):
|
1679 |
-
initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input)
|
1680 |
-
first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents)
|
1681 |
-
first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)')
|
1682 |
-
first_stage_fakes = first_stage_fakes.to(torch.float32).mul(32767).to(torch.int16).cpu()
|
1683 |
-
first_stage_filename = f'first_stage_{trainer.global_step:08}.wav'
|
1684 |
-
torchaudio.save(first_stage_filename, first_stage_fakes, self.sample_rate)
|
1685 |
-
|
1686 |
-
log_dict[f'first_stage_latents'] = wandb.Image(tokens_spectrogram_image(initial_latents))
|
1687 |
-
|
1688 |
-
log_dict[f'first_stage'] = wandb.Audio(first_stage_filename,
|
1689 |
-
sample_rate=self.sample_rate,
|
1690 |
-
caption=f'First Stage Reconstructed')
|
1691 |
-
|
1692 |
-
log_dict[f'first_stage_melspec_left'] = wandb.Image(audio_spectrogram_image(first_stage_fakes))
|
1693 |
-
|
1694 |
-
|
1695 |
-
trainer.logger.experiment.log(log_dict)
|
1696 |
-
|
1697 |
-
def create_source_mixture(reals, num_sources=2):
|
1698 |
-
# Create a fake mixture source by mixing elements from the training batch together with random offsets
|
1699 |
-
source = torch.zeros_like(reals)
|
1700 |
-
for i in range(reals.shape[0]):
|
1701 |
-
sources_added = 0
|
1702 |
-
|
1703 |
-
js = list(range(reals.shape[0]))
|
1704 |
-
random.shuffle(js)
|
1705 |
-
for j in js:
|
1706 |
-
if i == j or (i != j and sources_added < num_sources):
|
1707 |
-
# Randomly offset the mixed element between 0 and the length of the source
|
1708 |
-
seq_len = reals.shape[2]
|
1709 |
-
-            offset = random.randint(0, seq_len-1)
-            source[i, :, offset:] += reals[j, :, :-offset]
-            if i == j:
-                # If this is the real one, shift the reals as well to ensure alignment
-                new_reals = torch.zeros_like(reals[i])
-                new_reals[:, offset:] = reals[i, :, :-offset]
-                reals[i] = new_reals
-            sources_added += 1
-
-        return source
-
-class DiffusionPriorTrainingWrapper(L.LightningModule):
-    '''
-    Wrapper for training a diffusion prior for inverse problems
-    Prior types:
-        mono_stereo: The prior is conditioned on a mono version of the audio to generate a stereo version
-    '''
-    def __init__(
-            self,
-            model: ConditionedDiffusionModelWrapper,
-            lr: float = 1e-4,
-            ema_copy = None,
-            prior_type: PriorType = PriorType.MonoToStereo,
-            use_reconstruction_loss: bool = False,
-            log_loss_info: bool = False,
-    ):
-        super().__init__()
-
-        self.diffusion = model
-
-        self.diffusion_ema = EMA(
-            self.diffusion,
-            ema_model=ema_copy,
-            beta=0.9999,
-            power=3/4,
-            update_every=1,
-            update_after_step=1,
-            include_online_model=False
-        )
-
-        self.lr = lr
-
-        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
-
-        self.log_loss_info = log_loss_info
-
-        loss_modules = [
-            MSELoss("v",
-                    "targets",
-                    weight=1.0,
-                    name="mse_loss"
-            )
-        ]
-
-        self.use_reconstruction_loss = use_reconstruction_loss
-
-        if use_reconstruction_loss:
-            scales = [2048, 1024, 512, 256, 128, 64, 32]
-            hop_sizes = []
-            win_lengths = []
-            overlap = 0.75
-            for s in scales:
-                hop_sizes.append(int(s * (1 - overlap)))
-                win_lengths.append(s)
-
-            sample_rate = model.sample_rate
-
-            stft_loss_args = {
-                "fft_sizes": scales,
-                "hop_sizes": hop_sizes,
-                "win_lengths": win_lengths,
-                "perceptual_weighting": True
-            }
-
-            out_channels = model.io_channels
-
-
-            if model.pretransform is not None:
-                out_channels = model.pretransform.io_channels
-            self.audio_out_channels = out_channels
-
-            if self.audio_out_channels == 2:
-                self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
-                self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
-
-                # Add left and right channel reconstruction losses in addition to the sum and difference
-                loss_modules += [
-                    AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05),
-                    AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05),
-                ]
-
-            else:
-                self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
-
-            loss_modules.append(
-                AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss
-            )
-
-        self.losses = MultiLoss(loss_modules)
-
-        self.prior_type = prior_type
-
-    def configure_optimizers(self):
-        return optim.Adam([*self.diffusion.parameters()], lr=self.lr)
-
-    def training_step(self, batch, batch_idx):
-        reals, metadata = batch
-
-        if reals.ndim == 4 and reals.shape[0] == 1:
-            reals = reals[0]
-
-        loss_info = {}
-
-        loss_info["audio_reals"] = reals
-
-        if self.prior_type == PriorType.MonoToStereo:
-            source = reals.mean(dim=1, keepdim=True).repeat(1, reals.shape[1], 1).to(self.device)
-            loss_info["audio_reals_mono"] = source
-        else:
-            raise ValueError(f"Unknown prior type {self.prior_type}")
-
-        if self.diffusion.pretransform is not None:
-            with torch.no_grad():
-                reals = self.diffusion.pretransform.encode(reals)
-
-                if self.prior_type in [PriorType.MonoToStereo]:
-                    source = self.diffusion.pretransform.encode(source)
-
-        if self.diffusion.conditioner is not None:
-            with torch.amp.autocast('cuda'):
-                conditioning = self.diffusion.conditioner(metadata, self.device)
-        else:
-            conditioning = {}
-
-        loss_info["reals"] = reals
-
-        # Draw uniformly distributed continuous timesteps
-        t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
-
-        # Calculate the noise schedule parameters for those timesteps
-        alphas, sigmas = get_alphas_sigmas(t)
-
-        # Combine the ground truth data and the noise
-        alphas = alphas[:, None, None]
-        sigmas = sigmas[:, None, None]
-        noise = torch.randn_like(reals)
-        noised_reals = reals * alphas + noise * sigmas
-        targets = noise * alphas - reals * sigmas
-
-        with torch.amp.autocast('cuda'):
-
-            conditioning['source'] = [source]
-
-            v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob = 0.1)
-
-            loss_info.update({
-                "v": v,
-                "targets": targets
-            })
-
-            if self.use_reconstruction_loss:
-                pred = noised_reals * alphas - v * sigmas
-
-                loss_info["pred"] = pred
-
-                if self.diffusion.pretransform is not None:
-                    pred = self.diffusion.pretransform.decode(pred)
-                    loss_info["audio_pred"] = pred
-
-                if self.audio_out_channels == 2:
-                    loss_info["pred_left"] = pred[:, 0:1, :]
-                    loss_info["pred_right"] = pred[:, 1:2, :]
-                    loss_info["audio_reals_left"] = loss_info["audio_reals"][:, 0:1, :]
-                    loss_info["audio_reals_right"] = loss_info["audio_reals"][:, 1:2, :]
-
-            loss, losses = self.losses(loss_info)
-
-        if self.log_loss_info:
-            # Loss debugging logs
-            num_loss_buckets = 10
-            bucket_size = 1 / num_loss_buckets
-            loss_all = F.mse_loss(v, targets, reduction="none")
-
-            sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
-
-            # gather loss_all across all GPUs
-            loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
-
-            # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
-            loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
-
-            # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
-            debug_log_dict = {
-                f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
-            }
-
-            self.log_dict(debug_log_dict)
-
-        log_dict = {
-            'train/loss': loss.detach(),
-            'train/std_data': reals.std()
-        }
-
-        for loss_name, loss_value in losses.items():
-            log_dict[f"train/{loss_name}"] = loss_value.detach()
-
-        self.log_dict(log_dict, prog_bar=True, on_step=True)
-        return loss
-
-    def on_before_zero_grad(self, *args, **kwargs):
-        self.diffusion_ema.update()
-
-    def export_model(self, path, use_safetensors=False):
-
-        #model = self.diffusion_ema.ema_model
-        model = self.diffusion
-
-        if use_safetensors:
-            save_file(model.state_dict(), path)
-        else:
-            torch.save({"state_dict": model.state_dict()}, path)
-
-class DiffusionPriorDemoCallback(Callback):
-    def __init__(
-        self,
-        demo_dl,
-        demo_every=2000,
-        demo_steps=250,
-        sample_size=65536,
-        sample_rate=48000
-    ):
-        super().__init__()
-
-        self.demo_every = demo_every
-        self.demo_steps = demo_steps
-        self.demo_samples = sample_size
-        self.demo_dl = iter(demo_dl)
-        self.sample_rate = sample_rate
-        self.last_demo_step = -1
-
-    @rank_zero_only
-    @torch.no_grad()
-    def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx):
-        if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
-            return
-
-        self.last_demo_step = trainer.global_step
-
-        demo_reals, metadata = next(self.demo_dl)
-        # import ipdb
-        # ipdb.set_trace()
-        # Remove extra dimension added by WebDataset
-        if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
-            demo_reals = demo_reals[0]
-
-        demo_reals = demo_reals.to(module.device)
-
-        encoder_input = demo_reals
-
-        if module.diffusion.conditioner is not None:
-            with torch.amp.autocast('cuda'):
-                conditioning_tensors = module.diffusion.conditioner(metadata, module.device)
-
-        else:
-            conditioning_tensors = {}
-
-
-        with torch.no_grad() and torch.amp.autocast('cuda'):
-            if module.prior_type == PriorType.MonoToStereo and encoder_input.shape[1] > 1:
-                source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device)
-
-            if module.diffusion.pretransform is not None:
-                encoder_input = module.diffusion.pretransform.encode(encoder_input)
-                source_input = module.diffusion.pretransform.encode(source)
-            else:
-                source_input = source
-
-            conditioning_tensors['source'] = [source_input]
-
-        fakes = sample(module.diffusion_ema.model, torch.randn_like(encoder_input), self.demo_steps, 0, cond=conditioning_tensors)
-
-        if module.diffusion.pretransform is not None:
-            fakes = module.diffusion.pretransform.decode(fakes)
-
-        # Interleave reals and fakes
-        reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
-
-        # Put the demos together
-        reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
-
-        log_dict = {}
-
-        filename = f'recon_mono_{trainer.global_step:08}.wav'
-        reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu()
-        torchaudio.save(filename, reals_fakes, self.sample_rate)
-
-        log_dict[f'recon'] = wandb.Audio(filename,
-                                         sample_rate=self.sample_rate,
-                                         caption=f'Reconstructed')
-
-        log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
-
-        # Log the source
-        filename = f'source_{trainer.global_step:08}.wav'
-        source = rearrange(source, 'b d n -> d (b n)')
-        source = source.to(torch.float32).mul(32767).to(torch.int16).cpu()
-        torchaudio.save(filename, source, self.sample_rate)
-
-        log_dict[f'source'] = wandb.Audio(filename,
-                                          sample_rate=self.sample_rate,
-                                          caption=f'Source')
-
-        log_dict[f'source_melspec_left'] = wandb.Image(audio_spectrogram_image(source))
-
-        trainer.logger.experiment.log(log_dict)
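The deleted DiffusionPriorTrainingWrapper above trains with the v-objective. As an aid to reading the removed code, here is a minimal, self-contained sketch, not part of this commit, assuming the cosine schedule that get_alphas_sigmas-style helpers typically implement, showing how the noising, the v target, and the recovered clean sample relate:

import math
import torch

def get_alphas_sigmas_sketch(t: torch.Tensor):
    # Cosine schedule commonly paired with the v-objective:
    # alpha^2 + sigma^2 = 1 for every t in [0, 1].
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

# Toy latents standing in for `reals` (batch, channels, length).
reals = torch.randn(2, 64, 215)
t = torch.rand(reals.shape[0])

alphas, sigmas = get_alphas_sigmas_sketch(t)
alphas = alphas[:, None, None]
sigmas = sigmas[:, None, None]

noise = torch.randn_like(reals)
noised = reals * alphas + noise * sigmas        # forward (noising) process
v_target = noise * alphas - reals * sigmas      # what the network regresses to

# Given a perfect v prediction, the clean sample is recovered exactly,
# which is the identity the reconstruction loss branch relies on.
recon = noised * alphas - v_target * sigmas
assert torch.allclose(recon, reals, atol=1e-5)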
 from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
 from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper
 from ..models.autoencoders import DiffusionAutoencoder

 from .autoencoders import create_loss_modules_from_bottleneck
 from .losses import AuralossLoss, MSELoss, MultiLoss
 from .utils import create_optimizer_from_config, create_scheduler_from_config, mask_from_frac_lengths, generate_mask, generate_channel_mask

     def predict_step(self, batch, batch_idx):
         reals, metadata = batch
         ids = [item['id'] for item in metadata]
         batch_size, length = reals.shape[0], reals.shape[2]
+        print(f"Predicting {batch_size} samples with length {length} for ids: {ids}")
         with torch.amp.autocast('cuda'):
             conditioning = self.diffusion.conditioner(metadata, self.device)

         end_time = time.time()
         execution_time = end_time - start_time
         print(f"Execution time: {execution_time:.2f} s")

         if self.diffusion.pretransform is not None:
             fakes = self.diffusion.pretransform.decode(fakes)

         gc.collect()
         torch.cuda.empty_cache()
         module.train()
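For readers following the predict_step context above, a small sketch, with hypothetical values and not part of this commit, of the metadata shape the conditioner consumes after the CoT change and of the autocast call pattern shown in those lines:

import torch

# Hypothetical metadata entries as predict_step would receive them,
# now carrying a 'caption_cot' field alongside the plain caption.
metadata = [{
    "id": "demo_video",                      # placeholder id
    "caption": "Printer Printing",
    "caption_cot": "Generate a continuous printer printing sound with periodic beeps.",
}]

def build_conditioning(diffusion, metadata, device):
    # Mirrors the context lines: run the conditioner under autocast when present.
    if diffusion.conditioner is None:
        return {}
    with torch.amp.autocast('cuda'):
        return diffusion.conditioner(metadata, device)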
{think_sound → ThinkSound}/training/factory.py
RENAMED
File without changes

{think_sound → ThinkSound}/training/losses/__init__.py
RENAMED
File without changes

{think_sound → ThinkSound}/training/losses/auraloss.py
RENAMED
File without changes

{think_sound → ThinkSound}/training/losses/losses.py
RENAMED
File without changes

{think_sound → ThinkSound}/training/utils.py
RENAMED
File without changes
app.py
CHANGED
@@ -14,13 +14,12 @@ from lightning.pytorch.tuner import Tuner
 from lightning.pytorch import seed_everything
 import random
 from datetime import datetime
-
-from
-from
-from
-from
-from
-from think_sound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
+from ThinkSound.data.datamodule import DataModule
+from ThinkSound.models import create_model_from_config
+from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
+from ThinkSound.training import create_training_wrapper_from_config, create_demo_callback_from_config
+from ThinkSound.training.utils import copy_state_dict
+from ThinkSound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
 from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
 from torch.utils.data import Dataset
 from typing import Optional, Union
@@ -34,7 +33,7 @@ import tempfile
 import subprocess
 from huggingface_hub import hf_hub_download
 from moviepy.editor import VideoFileClip
-os.system("conda install -c conda-forge 'ffmpeg<7'")
+# os.system("conda install -c conda-forge 'ffmpeg<7'")
 
 _CLIP_SIZE = 224
 _CLIP_FPS = 8.0
@@ -101,7 +100,7 @@ class VGGSound(Dataset):
 
         self.resampler = {}
 
-    def sample(self, video_path,label):
+    def sample(self, video_path,label,cot):
         video_id = video_path
 
         reader = StreamingMediaDecoder(video_path)
@@ -156,7 +155,7 @@ class VGGSound(Dataset):
         # padding using the last frame, but no more than 2
         current_length = sync_chunk.shape[0]
         last_frame = sync_chunk[-1]
-
+
         padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
         assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
         sync_chunk = torch.cat((sync_chunk, padding), dim=0)
@@ -170,6 +169,7 @@ class VGGSound(Dataset):
         data = {
             'id': video_id,
             'caption': label,
+            'caption_cot': cot,
            # 'audio': audio_chunk,
            'clip_video': clip_chunk,
            'sync_video': sync_chunk,
@@ -187,17 +187,16 @@ else:
 
 print(f"load in device {device}")
 
-vae_ckpt = hf_hub_download(repo_id="
-synchformer_ckpt = hf_hub_download(repo_id="
+vae_ckpt = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="vae.ckpt",repo_type="model")
+synchformer_ckpt = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="synchformer_state_dict.pth",repo_type="model")
+
 feature_extractor = FeaturesUtils(
-    vae_ckpt=
-    vae_config='
+    vae_ckpt=None,
+    vae_config='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json',
     enable_conditions=True,
     synchformer_ckpt=synchformer_ckpt
 ).eval().to(extra_device)
 
-
-
 args = get_all_args()
 
 seed = 10086
@@ -206,7 +205,7 @@ seed_everything(seed, workers=True)
 
 
 #Get JSON config from args.model_config
-with open("
+with open("ThinkSound/configs/model_configs/thinksound.json") as f:
     model_config = json.load(f)
 
 model = create_model_from_config(model_config)
@@ -229,7 +228,7 @@ model.pretransform.load_state_dict(load_vae_state)
 # Remove weight_norm from the pretransform if specified
 if args.remove_pretransform_weight_norm == "post_load":
     remove_weight_norm_from_model(model.pretransform)
-ckpt_path = hf_hub_download(repo_id="
+ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound.ckpt",repo_type="model")
 training_wrapper = create_training_wrapper_from_config(model_config, model)
 # Choose map_location based on the device when loading the model weights
 training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
@@ -243,16 +242,17 @@ def get_video_duration(video_path):
 @spaces.GPU(duration=60)
 @torch.inference_mode()
 @torch.no_grad()
-def get_audio(video_path, caption):
-    # Allow the caption to be empty
+def get_audio(video_path, caption, cot):
     if caption is None:
         caption = ''
+    if cot is None:
+        cot = caption
     timer = Timer(duration="00:15:00:00")
     #get video duration
     duration_sec = get_video_duration(video_path)
     print(duration_sec)
     preprocesser = VGGSound(duration_sec=duration_sec)
-    data = preprocesser.sample(video_path, caption)
+    data = preprocesser.sample(video_path, caption, cot)
 
 
 
@@ -261,7 +261,7 @@ def get_audio(video_path, caption):
     preprocessed_data['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu().squeeze(0)
     preprocessed_data['metaclip_text_features'] = metaclip_text_features.detach().cpu().squeeze(0)
 
-    t5_features = feature_extractor.encode_t5_text(data['
+    t5_features = feature_extractor.encode_t5_text(data['caption_cot'])
     preprocessed_data['t5_features'] = t5_features.detach().cpu().squeeze(0)
 
     clip_features = feature_extractor.encode_video_with_clip(data['clip_video'].unsqueeze(0).to(extra_device))
@@ -305,56 +305,47 @@ def get_audio(video_path, caption):
     fakes = training_wrapper.diffusion.pretransform.decode(fakes)
 
     audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
-    # Save a temporary audio file
     with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
         torchaudio.save(tmp_audio.name, audios[0], 44100)
         audio_path = tmp_audio.name
+
     return audio_path
 
-def synthesize_video_with_audio(video_file, caption):
-
-    if caption is None:
-        caption = ''
-    audio_path = get_audio(video_file, caption)
+def synthesize_video_with_audio(video_file, caption, cot):
+    audio_path = get_audio(video_file, caption, cot)
     with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
         output_video_path = tmp_video.name
-
+
     cmd = [
        'ffmpeg', '-y', '-i', video_file, '-i', audio_path,
        '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0',
        '-shortest', output_video_path
    ]
    subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
    return output_video_path
 
-
-
-
-""
-
-
-
-
-
-
-
-with
-
-
-
-
-
-
-
-
-
-
-
-        ["./examples/4_mute.mp4", "train passing by", "./examples/4.mp4"],
-        ["./examples/5_mute.mp4", "Lighting Firecrackers", "./examples/5.mp4"]
-    ],
-    inputs=[video_input, caption_input,output_video],
-)
-
-demo.launch(share=True)
+demo = gr.Interface(
+    fn=synthesize_video_with_audio,
+    inputs=[
+        gr.Video(label="Upload Video"),
+        gr.Textbox(label="Caption (optional)", placeholder="can be empty",),
+        gr.Textbox(label="CoT Description (optional)", lines=6, placeholder="can be empty",),
+    ],
+    outputs=[
+        gr.Video(label="Result"),
+    ],
+    title="ThinkSound Demo",
+    description="Upload a video, caption, or CoT to generate audio. For an enhanced experience, we automatically merge the generated audio with your original silent video. (Note: Flexible audio generation lengths are supported.:)",
+    examples=[
+        ["examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier", "Begin by creating a soft, steady background of light pacifier suckling. Add subtle, breathy rhythms to mimic a newborn's gentle mouth movements. Keep the sound smooth, natural, and soothing."],
+        ["examples/2_mute.mp4", "Printer Printing", "Generate a continuous printer printing sound with periodic beeps and paper movement, plus a cat pawing at the machine. Add subtle ambient room noise for authenticity, keeping the focus on printing, beeps, and the cat's interaction."],
+        ["examples/4_mute.mp4", "Plastic Debris Handling", "Begin with the sound of hands scooping up loose plastic debris, followed by the subtle cascading noise as the pieces fall and scatter back down. Include soft crinkling and rustling to emphasize the texture of the plastic. Add ambient factory background noise with distant machinery to create an industrial atmosphere."],
+        ["examples/5_mute.mp4", "Lighting Firecrackers", "Generate the sound of firecrackers lighting and exploding repeatedly on the ground, followed by fireworks bursting in the sky. Incorporate occasional subtle echoes to mimic an outdoor night ambiance, with no human voices present."]
+    ],
+    cache_examples=True
+)
+
+if __name__ == "__main__":
+    demo.launch(share=True)
 
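A short usage sketch, assuming the Space's dependencies are installed and app.py is importable, of the CoT-aware entry point added above; passing cot=None exercises the fallback in get_audio where the plain caption is reused as the CoT text:

from app import synthesize_video_with_audio

video = "examples/5_mute.mp4"
caption = "Lighting Firecrackers"
cot = None  # get_audio falls back to the caption when no CoT description is given

output_video_path = synthesize_video_with_audio(video, caption, cot)
print(output_video_path)  # path to the muxed MP4 carrying the generated audio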
cot_vgg_demo_caption.txt
ADDED
@@ -0,0 +1 @@
+demo.npz
data_utils/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (149 Bytes)

data_utils/__pycache__/utils.cpython-310.pyc
DELETED
Binary file (4.56 kB)

data_utils/__pycache__/utils.cpython-39.pyc
DELETED
Binary file (4.56 kB)

data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (243 Bytes)

data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc
DELETED
Binary file (241 Bytes)

data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc
DELETED
Binary file (12.7 kB)

data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc
DELETED
Binary file (12.7 kB)

data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc
DELETED
Binary file (1.91 kB)

data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc
DELETED
Binary file (1.9 kB)

data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc
DELETED
Binary file (3.97 kB)

data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc
DELETED
Binary file (3.78 kB)