liuhuadai committed
Commit 052cf68 · Parent: 70bc476

support cot

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. {think_sound → ThinkSound}/__init__.py +0 -0
  2. {think_sound/configs/model_configs/autoencoders → ThinkSound/configs/model_configs}/stable_audio_2_0_vae.json +0 -0
  3. think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json → ThinkSound/configs/model_configs/thinksound.json +1 -1
  4. ThinkSound/configs/multimodal_dataset_demo.json +53 -0
  5. {data_utils → ThinkSound/data}/__init__.py +0 -0
  6. {think_sound → ThinkSound}/data/datamodule.py +4 -2
  7. {think_sound → ThinkSound}/data/dataset.py +6 -8
  8. {think_sound → ThinkSound}/data/utils.py +0 -0
  9. {think_sound/data → ThinkSound/inference}/__init__.py +0 -0
  10. {think_sound → ThinkSound}/inference/generation.py +0 -0
  11. {think_sound → ThinkSound}/inference/sampling.py +0 -0
  12. {think_sound → ThinkSound}/inference/utils.py +0 -0
  13. {think_sound → ThinkSound}/models/__init__.py +0 -0
  14. {think_sound → ThinkSound}/models/autoencoders.py +0 -0
  15. {think_sound → ThinkSound}/models/blocks.py +92 -1
  16. {think_sound → ThinkSound}/models/bottleneck.py +0 -0
  17. {think_sound → ThinkSound}/models/codebook_patterns.py +0 -0
  18. {think_sound → ThinkSound}/models/conditioners.py +0 -1
  19. {think_sound → ThinkSound}/models/diffusion.py +1 -3
  20. {think_sound → ThinkSound}/models/dit.py +0 -0
  21. {think_sound/models/mmmodules/model → ThinkSound/models}/embeddings.py +36 -0
  22. {think_sound → ThinkSound}/models/factory.py +0 -0
  23. {think_sound → ThinkSound}/models/local_attention.py +0 -0
  24. {think_sound → ThinkSound}/models/mmdit.py +56 -9
  25. {think_sound → ThinkSound}/models/pretrained.py +0 -0
  26. {think_sound → ThinkSound}/models/pretransforms.py +0 -0
  27. {think_sound → ThinkSound}/models/transformer.py +0 -0
  28. {think_sound/models/mmmodules/model → ThinkSound/models}/transformer_layers.py +2 -2
  29. {think_sound → ThinkSound}/models/utils.py +0 -0
  30. {think_sound → ThinkSound}/training/__init__.py +0 -0
  31. {think_sound → ThinkSound}/training/autoencoders.py +0 -1
  32. {think_sound → ThinkSound}/training/diffusion.py +1 -948
  33. {think_sound → ThinkSound}/training/factory.py +0 -0
  34. {think_sound → ThinkSound}/training/losses/__init__.py +0 -0
  35. {think_sound → ThinkSound}/training/losses/auraloss.py +0 -0
  36. {think_sound → ThinkSound}/training/losses/losses.py +0 -0
  37. {think_sound → ThinkSound}/training/utils.py +0 -0
  38. app.py +50 -59
  39. cot_vgg_demo_caption.txt +1 -0
  40. data_utils/__pycache__/__init__.cpython-310.pyc +0 -0
  41. data_utils/__pycache__/utils.cpython-310.pyc +0 -0
  42. data_utils/__pycache__/utils.cpython-39.pyc +0 -0
  43. data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc +0 -0
  44. data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc +0 -0
  45. data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc +0 -0
  46. data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc +0 -0
  47. data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc +0 -0
  48. data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc +0 -0
  49. data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc +0 -0
  50. data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc +0 -0
{think_sound → ThinkSound}/__init__.py RENAMED
File without changes
{think_sound/configs/model_configs/autoencoders → ThinkSound/configs/model_configs}/stable_audio_2_0_vae.json RENAMED
File without changes
think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json → ThinkSound/configs/model_configs/thinksound.json RENAMED
@@ -85,7 +85,7 @@
  "clip_dim":1024,
  "sync_dim":768,
  "text_dim":2048,
- "hidden_dim":1024 ,
+ "hidden_dim":1024,
  "depth":21,
  "fused_depth":14,
  "num_heads":16,
ThinkSound/configs/multimodal_dataset_demo.json ADDED
@@ -0,0 +1,53 @@
+{
+    "dataset_type": "multimodal_dir",
+    "video_datasets": [
+        {
+            "id": "vggsound",
+            "path": "dataset/vggsound/video_latents_t5_clip_npz/train",
+            "split_path": "dataset/vggsound/split_txt/train_cot.txt"
+        }
+    ],
+    "audio_datasets": [
+        {
+            "id": "audiostock",
+            "path": "dataset/Laion-Audio-630k/audiostock_latents_npz",
+            "split_path": "dataset/Laion-Audio-630k/split_txt/cot_audiostock_1.txt"
+        },
+        {
+            "id": "freesound_no_overlap",
+            "path": "dataset/Laion-Audio-630k/freesound_no_overlap_latents_npz",
+            "split_path": "dataset/Laion-Audio-630k/split_txt/cot_freesound.txt"
+        },
+        {
+            "id": "audioset_sl",
+            "path": "dataset/wavcaps/audioset_sl_latents_npz",
+            "split_path": "dataset/wavcaps/split_txt/cot_audio_sl_1.txt"
+        },
+        {
+            "id": "audiocaps",
+            "path": "dataset/1_audiocaps/audiocaps_latents_npz",
+            "split_path": "dataset/1_audiocaps/split_txt/train_cot.txt"
+        },
+        {
+            "id": "bbc",
+            "path": "dataset/Laion-Audio-630k/bbc_latents_npz",
+            "split_path": "dataset/Laion-Audio-630k/split_txt/cot_bbc_1.txt"
+        }
+    ],
+    "val_datasets": [
+        {
+            "id": "vggsound",
+            "path": "dataset/vggsound/video_latents_t5_clip_npz/test",
+            "split_path": "dataset/vggsound/split_txt/test_cot.txt"
+        }
+    ],
+    "test_datasets": [
+        {
+            "id": "vggsound",
+            "path": "cot_coarse",
+            "split_path": "cot_vgg_demo_caption.txt"
+        }
+    ],
+    "random_crop": true,
+    "input_type": "prompt"
+}
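Note (illustrative sketch, not part of this commit): each entry in this demo config points a dataset id at a directory of precomputed latent .npz files plus a chain-of-thought split list. Assuming the repository root is on the Python path, loading and inspecting it is just standard `json`:

```python
import json

# Read the demo multimodal dataset config added in this commit.
with open("ThinkSound/configs/multimodal_dataset_demo.json") as f:
    cfg = json.load(f)

# Every train/val entry pairs a latent directory with a CoT split file.
for entry in cfg["video_datasets"] + cfg["audio_datasets"] + cfg["val_datasets"]:
    print(f'{entry["id"]}: {entry["path"]} (splits: {entry["split_path"]})')
```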
{data_utils → ThinkSound/data}/__init__.py RENAMED
File without changes
{think_sound → ThinkSound}/data/datamodule.py RENAMED
@@ -33,13 +33,14 @@ def get_configs(audio_configs):
     return configs
 
 class DataModule(L.LightningDataModule):
-    def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4,repeat_num=5):
+    def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4,repeat_num=5,latent_length=194):
         super().__init__()
         dataset_type = dataset_config.get("dataset_type", None)
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.test_batch_size = test_batch_size
         self.repeat_num = repeat_num
+        self.latent_length = latent_length
         assert dataset_type is not None, "Dataset type must be specified in dataset config"
 
         if audio_channels == 1:
@@ -140,7 +141,8 @@ class DataModule(L.LightningDataModule):
                 random_crop=random_crop,
                 input_type=self.input_type,
                 fps=self.input_type,
-                force_channels=self.force_channels
+                force_channels=self.force_channels,
+                latent_length=self.latent_length
             )
 
         if stage == 'fit':
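Note (illustrative sketch, not part of this commit): the new `latent_length` argument is stored on the DataModule and forwarded to the video dataset below; only the constructor signature comes from this diff, the batch/sample values are assumptions:

```python
import json
from ThinkSound.data.datamodule import DataModule

with open("ThinkSound/configs/multimodal_dataset_demo.json") as f:
    dataset_config = json.load(f)

dm = DataModule(
    dataset_config,
    batch_size=8,          # assumed value
    test_batch_size=1,     # assumed value
    sample_size=441000,    # assumed; depends on the audio/VAE setup
    sample_rate=44100,     # assumed
    latent_length=194,     # new in this commit; defaults to 194
)
dm.setup("fit")            # standard LightningDataModule hook, per the `stage == 'fit'` branch above
```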
{think_sound → ThinkSound}/data/dataset.py RENAMED
@@ -342,8 +342,7 @@ class LatentDataset(torch.utils.data.Dataset):
         info = {}
         audio, video = self.load_file(audio_filename, info)
         info["path"] = audio_filename
-        assert audio.shape == (64,194), f'{audio.shape} input error, id: {id}'
-        assert video.shape == (72,1024), f'{video.shape} input error, id: {id}'
+
         info['id'] = Path(audio_filename).stem
         for root_path in self.root_paths:
             if root_path in audio_filename:
@@ -434,8 +433,7 @@ class AudioDataset(torch.utils.data.Dataset):
         info = {}
         audio, video = self.load_file(audio_filename, info)
         info["path"] = audio_filename
-        assert audio.shape == (64,194), f'{audio.shape} input error, id: {id}'
-        assert video.shape == (72,1024), f'{video.shape} input error, id: {id}'
+
         info['id'] = Path(audio_filename).stem
         for root_path in self.root_paths:
             if root_path in audio_filename:
@@ -454,8 +452,9 @@ class VideoDataset(torch.utils.data.Dataset):
         input_type="prompt",
         fps=4,
         force_channels="stereo",
+        latent_length=194, # default latent length for video dataset
     ):
-
+        self.latent_length = latent_length
         super().__init__()
         self.filenames = []
         print(f'configs: {configs[0]}')
@@ -523,7 +522,7 @@
             if 'latent' in data.keys():
                 audio = data['latent']
             else:
-                audio = torch.zeros(64,194)
+                audio = torch.zeros(64,self.latent_length)
             info['video_exist'] = self.video_exist
         # except:
         #     print(f'error load file: {filename}')
@@ -540,8 +539,7 @@
         info = {}
         audio, video = self.load_file(audio_filename, info)
         info["path"] = audio_filename
-        assert audio is None or audio.shape == (64,194), f'{audio.shape} input error, id: {id}'
-        assert video.shape == (72,1024), f'{video.shape} input error, id: {id}'
+
         info['id'] = Path(audio_filename).stem
         for root_path in self.root_paths:
             if root_path in audio_filename:
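Note (illustrative sketch, not part of this commit): the hard-coded `(64, 194)` shape asserts are removed, and when a sample has no precomputed `latent`, VideoDataset now falls back to a zero tensor sized by the configurable latent length rather than a fixed 194 frames:

```python
import torch

latent_length = 194                      # default; now configurable via DataModule/VideoDataset
audio = torch.zeros(64, latent_length)   # placeholder latent: 64 channels x latent_length frames
print(audio.shape)                       # torch.Size([64, 194])
```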
{think_sound → ThinkSound}/data/utils.py RENAMED
File without changes
{think_sound/data → ThinkSound/inference}/__init__.py RENAMED
File without changes
{think_sound → ThinkSound}/inference/generation.py RENAMED
File without changes
{think_sound → ThinkSound}/inference/sampling.py RENAMED
File without changes
{think_sound → ThinkSound}/inference/utils.py RENAMED
File without changes
{think_sound → ThinkSound}/models/__init__.py RENAMED
File without changes
{think_sound → ThinkSound}/models/autoencoders.py RENAMED
File without changes
{think_sound → ThinkSound}/models/blocks.py RENAMED
@@ -336,4 +336,95 @@ class SnakeBeta(nn.Module):
         beta = torch.exp(beta)
         x = snake_beta(x, alpha, beta)
 
-        return x
+        return x
+
+
+class ChannelLastConv1d(nn.Conv1d):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.permute(0, 2, 1)
+        x = super().forward(x)
+        x = x.permute(0, 2, 1)
+        return x
+
+
+# https://github.com/Stability-AI/sd3-ref
+class MLP(nn.Module):
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int = 256,
+    ):
+        """
+        Initialize the FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+
+        Attributes:
+            w1 (ColumnParallelLinear): Linear transformation for the first layer.
+            w2 (RowParallelLinear): Linear transformation for the second layer.
+            w3 (ColumnParallelLinear): Linear transformation for the third layer.
+
+        """
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class ConvMLP(nn.Module):
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int = 256,
+        kernel_size: int = 3,
+        padding: int = 1,
+    ):
+        """
+        Initialize the FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+
+        Attributes:
+            w1 (ColumnParallelLinear): Linear transformation for the first layer.
+            w2 (RowParallelLinear): Linear transformation for the second layer.
+            w3 (ColumnParallelLinear): Linear transformation for the third layer.
+
+        """
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = ChannelLastConv1d(dim,
+                                    hidden_dim,
+                                    bias=False,
+                                    kernel_size=kernel_size,
+                                    padding=padding)
+        self.w2 = ChannelLastConv1d(hidden_dim,
+                                    dim,
+                                    bias=False,
+                                    kernel_size=kernel_size,
+                                    padding=padding)
+        self.w3 = ChannelLastConv1d(dim,
+                                    hidden_dim,
+                                    bias=False,
+                                    kernel_size=kernel_size,
+                                    padding=padding)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
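Note (illustrative sketch, not part of this commit): these blocks were previously imported from the `mmmodules` sub-package and now live in `ThinkSound/models/blocks.py`. `ChannelLastConv1d` and `ConvMLP` accept channel-last `(batch, time, dim)` tensors, and `MLP` is the gated (SwiGLU-style) feed-forward from the SD3 reference. A quick shape check under assumed sizes:

```python
import torch
from ThinkSound.models.blocks import MLP, ChannelLastConv1d, ConvMLP

x = torch.randn(2, 194, 1024)  # (batch, time, dim); sizes here are assumptions

mlp = MLP(dim=1024, hidden_dim=4 * 1024)
conv_mlp = ConvMLP(dim=1024, hidden_dim=4 * 1024, kernel_size=3, padding=1)
conv = ChannelLastConv1d(1024, 1024, kernel_size=3, padding=1)

# All three keep the channel-last layout, so the caller never needs to permute.
print(mlp(x).shape, conv_mlp(x).shape, conv(x).shape)  # each torch.Size([2, 194, 1024])
```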
{think_sound → ThinkSound}/models/bottleneck.py RENAMED
File without changes
{think_sound → ThinkSound}/models/codebook_patterns.py RENAMED
File without changes
{think_sound → ThinkSound}/models/conditioners.py RENAMED
@@ -7,7 +7,6 @@ import typing as tp
 import gc
 from typing import Literal, Optional
 import os
-from .adp import NumberEmbedder
 from ..inference.utils import set_audio_channels
 from .factory import create_pretransform_from_config
 from .pretransforms import Pretransform
{think_sound → ThinkSound}/models/diffusion.py RENAMED
@@ -7,14 +7,12 @@ import typing as tp
 
 from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
 from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
-from .dit import DiffusionTransformer
+# from .dit import DiffusionTransformer
 from .mmdit import MMAudio
 from .factory import create_pretransform_from_config
 from .pretransforms import Pretransform
 from ..inference.generation import generate_diffusion_cond
 
-from .adp import UNetCFG1d, UNet1d
-
 from time import time
 
 class Profiler:
{think_sound → ThinkSound}/models/dit.py RENAMED
File without changes
{think_sound/models/mmmodules/model → ThinkSound/models}/embeddings.py RENAMED
@@ -3,6 +3,42 @@ import torch.nn as nn
 
 # https://github.com/facebookresearch/DiT
 
+from typing import Union
+
+import torch
+from einops import rearrange
+from torch import Tensor
+
+# Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
+# Ref: https://github.com/lucidrains/rotary-embedding-torch
+
+
+def compute_rope_rotations(length: int,
+                           dim: int,
+                           theta: int,
+                           *,
+                           freq_scaling: float = 1.0,
+                           device: Union[torch.device, str] = 'cpu') -> Tensor:
+    assert dim % 2 == 0
+
+    with torch.amp.autocast(device_type='cuda', enabled=False):
+        pos = torch.arange(length, dtype=torch.float32, device=device)
+        freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
+        freqs *= freq_scaling
+
+        rot = torch.einsum('..., f -> ... f', pos, freqs)
+        rot = torch.stack([torch.cos(rot), -torch.sin(rot), torch.sin(rot), torch.cos(rot)], dim=-1)
+        rot = rearrange(rot, 'n d (i j) -> 1 n d i j', i=2, j=2)
+        return rot
+
+
+def apply_rope(x: Tensor, rot: Tensor) -> tuple[Tensor, Tensor]:
+    with torch.amp.autocast(device_type='cuda', enabled=False):
+        _x = x.float()
+        _x = _x.view(*_x.shape[:-1], -1, 1, 2)
+        x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
+        return x_out.reshape(*x.shape).to(dtype=x.dtype)
+
 
 class TimestepEmbedder(nn.Module):
     """
{think_sound → ThinkSound}/models/factory.py RENAMED
File without changes
{think_sound → ThinkSound}/models/local_attention.py RENAMED
File without changes
{think_sound → ThinkSound}/models/mmdit.py RENAMED
@@ -6,10 +6,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import sys
-from .mmmodules.ext.rotary_embeddings import compute_rope_rotations
-from .mmmodules.model.embeddings import TimestepEmbedder
-from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
-from .mmmodules.model.transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock)
+from .embeddings import compute_rope_rotations
+from .embeddings import TimestepEmbedder
+from .blocks import MLP, ChannelLastConv1d, ConvMLP
+from .transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock)
 from .utils import resample
 
 log = logging.getLogger()
@@ -24,7 +24,6 @@ class PreprocessedConditions:
     text_f_c: torch.Tensor
 
 
-# Partially from https://github.com/facebookresearch/DiT
 class MMAudio(nn.Module):
 
     def __init__(self,
@@ -94,7 +93,6 @@ class MMAudio(nn.Module):
             nn.Linear(hidden_dim * 4, hidden_dim, bias=False),
             nn.Sigmoid()
         )
-        # Initialize the last layer's weights to zero to encourage uniform fusion at initialization
         nn.init.zeros_(self.gated_mlp_v[3].weight)
         nn.init.zeros_(self.gated_mlp_t[3].weight)
         if v2:
@@ -441,9 +439,9 @@
             # clip_f = torch.cat([clip_f,empty_clip_f], dim=0)
             # sync_f = torch.cat([sync_f,empty_sync_f], dim=0)
             # text_f = torch.cat([text_f,empty_text_f], dim=0)
-            clip_f = torch.cat([clip_f,self.get_empty_clip_sequence(bsz)], dim=0)
-            sync_f = torch.cat([sync_f,self.get_empty_sync_sequence(bsz)], dim=0)
-            text_f = torch.cat([text_f,self.get_empty_string_sequence(bsz)], dim=0)
+            clip_f = safe_cat(clip_f,self.get_empty_clip_sequence(bsz), dim=0, match_dim=1)
+            sync_f = safe_cat(sync_f,self.get_empty_sync_sequence(bsz), dim=0, match_dim=1)
+            text_f = safe_cat(text_f,self.get_empty_string_sequence(bsz), dim=0, match_dim=1)
             if t5_features is not None:
                 empty_t5_features = torch.zeros_like(t5_features, device=latent.device)
                 # t5_features = torch.cat([t5_features,empty_t5_features], dim=0)
@@ -529,3 +527,52 @@
     def sync_seq_len(self) -> int:
         return self._sync_seq_len
 
+
+def truncate_to_target(tensor, target_size, dim=1):
+    current_size = tensor.size(dim)
+    if current_size > target_size:
+        slices = [slice(None)] * tensor.dim()
+        slices[dim] = slice(0, target_size)
+        return tensor[slices]
+    return tensor
+
+
+def pad_to_target(tensor, target_size, dim=1, pad_value=0):
+    current_size = tensor.size(dim)
+    if current_size < target_size:
+        pad_size = target_size - current_size
+
+        pad_config = [0, 0] * tensor.dim()
+        pad_index = 2 * (tensor.dim() - dim - 1) + 1
+        pad_config[pad_index] = pad_size
+
+        return torch.nn.functional.pad(tensor, pad_config, value=pad_value)
+    return tensor
+
+
+def safe_cat(tensor1, tensor2, dim=0, match_dim=1):
+    target_size = tensor2.size(match_dim)
+
+    if tensor1.size(match_dim) > target_size:
+        tensor1 = truncate_to_target(tensor1, target_size, match_dim)
+    else:
+        tensor1 = pad_to_target(tensor1, target_size, match_dim)
+
+    return torch.cat([tensor1, tensor2], dim=dim)
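Note (illustrative sketch, not part of this commit): `safe_cat` replaces the plain `torch.cat` when the learned empty (classifier-free-guidance) sequences are appended, truncating or zero-padding the real features along `match_dim` so both halves have the same sequence length before concatenation. Toy example, assuming the repo's dependencies import cleanly:

```python
import torch
from ThinkSound.models.mmdit import safe_cat  # module-level helper added in this commit

real = torch.randn(2, 70, 1024)    # e.g. clip features that came in slightly short
empty = torch.randn(2, 72, 1024)   # stand-in for the model's empty sequence (assumed shapes)

# real is zero-padded from 70 to 72 along match_dim=1, then stacked on the batch dim.
both = safe_cat(real, empty, dim=0, match_dim=1)
print(both.shape)                  # torch.Size([4, 72, 1024])
```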
{think_sound → ThinkSound}/models/pretrained.py RENAMED
File without changes
{think_sound → ThinkSound}/models/pretransforms.py RENAMED
File without changes
{think_sound → ThinkSound}/models/transformer.py RENAMED
File without changes
{think_sound/models/mmmodules/model → ThinkSound/models}/transformer_layers.py RENAMED
@@ -6,8 +6,8 @@ import torch.nn.functional as F
 from einops import rearrange
 from einops.layers.torch import Rearrange
 
-from ..ext.rotary_embeddings import apply_rope
-from ..model.low_level import MLP, ChannelLastConv1d, ConvMLP
+from .embeddings import apply_rope
+from .blocks import MLP, ChannelLastConv1d, ConvMLP
 try:
     from flash_attn import flash_attn_func, flash_attn_kvpacked_func
     print('flash_attn installed, using Flash Attention')
{think_sound → ThinkSound}/models/utils.py RENAMED
File without changes
{think_sound → ThinkSound}/training/__init__.py RENAMED
File without changes
{think_sound → ThinkSound}/training/autoencoders.py RENAMED
@@ -9,7 +9,6 @@ from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss,
 import lightning as L
 from lightning.pytorch.callbacks import Callback
 from ..models.autoencoders import AudioAutoencoder
-from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss
 from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck
 from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss
 from .utils import create_optimizer_from_config, create_scheduler_from_config
{think_sound → ThinkSound}/training/diffusion.py RENAMED
@@ -20,7 +20,6 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_only
 from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
 from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper
 from ..models.autoencoders import DiffusionAutoencoder
-from ..models.diffusion_prior import PriorType
 from .autoencoders import create_loss_modules_from_bottleneck
 from .losses import AuralossLoss, MSELoss, MultiLoss
 from .utils import create_optimizer_from_config, create_scheduler_from_config, mask_from_frac_lengths, generate_mask, generate_channel_mask
@@ -846,10 +845,9 @@ class DiffusionCondTrainingWrapper(L.LightningModule):
 
     def predict_step(self, batch, batch_idx):
         reals, metadata = batch
-        # import ipdb
-        # ipdb.set_trace()
         ids = [item['id'] for item in metadata]
         batch_size, length = reals.shape[0], reals.shape[2]
+
         with torch.amp.autocast('cuda'):
             conditioning = self.diffusion.conditioner(metadata, self.device)
 
@@ -878,7 +876,6 @@ class DiffusionCondTrainingWrapper(L.LightningModule):
         end_time = time.time()
         execution_time = end_time - start_time
         print(f"Execution time: {execution_time:.2f} s")
-        breakpoint()
         if self.diffusion.pretransform is not None:
             fakes = self.diffusion.pretransform.decode(fakes)
 
@@ -1077,947 +1074,3 @@ class DiffusionCondDemoCallback(Callback):
1077
  gc.collect()
1078
  torch.cuda.empty_cache()
1079
  module.train()
1080
-
1081
- class DiffusionCondInpaintTrainingWrapper(L.LightningModule):
1082
- '''
1083
- Wrapper for training a conditional audio diffusion model.
1084
- '''
1085
- def __init__(
1086
- self,
1087
- model: ConditionedDiffusionModelWrapper,
1088
- lr: float = 1e-4,
1089
- max_mask_segments = 10,
1090
- log_loss_info: bool = False,
1091
- optimizer_configs: dict = None,
1092
- use_ema: bool = True,
1093
- pre_encoded: bool = False,
1094
- cfg_dropout_prob = 0.1,
1095
- timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform",
1096
- ):
1097
- super().__init__()
1098
-
1099
- self.diffusion = model
1100
-
1101
- self.use_ema = use_ema
1102
-
1103
- if self.use_ema:
1104
- self.diffusion_ema = EMA(
1105
- self.diffusion.model,
1106
- beta=0.9999,
1107
- power=3/4,
1108
- update_every=1,
1109
- update_after_step=1,
1110
- include_online_model=False
1111
- )
1112
- else:
1113
- self.diffusion_ema = None
1114
-
1115
- self.cfg_dropout_prob = cfg_dropout_prob
1116
-
1117
- self.lr = lr
1118
- self.max_mask_segments = max_mask_segments
1119
-
1120
- self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
1121
-
1122
- self.timestep_sampler = timestep_sampler
1123
-
1124
- self.diffusion_objective = model.diffusion_objective
1125
-
1126
- self.loss_modules = [
1127
- MSELoss("output",
1128
- "targets",
1129
- weight=1.0,
1130
- name="mse_loss"
1131
- )
1132
- ]
1133
-
1134
- self.losses = MultiLoss(self.loss_modules)
1135
-
1136
- self.log_loss_info = log_loss_info
1137
-
1138
- assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
1139
-
1140
- if optimizer_configs is None:
1141
- optimizer_configs = {
1142
- "diffusion": {
1143
- "optimizer": {
1144
- "type": "Adam",
1145
- "config": {
1146
- "lr": lr
1147
- }
1148
- }
1149
- }
1150
- }
1151
- else:
1152
- if lr is not None:
1153
- print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
1154
-
1155
- self.optimizer_configs = optimizer_configs
1156
-
1157
- self.pre_encoded = pre_encoded
1158
-
1159
- def configure_optimizers(self):
1160
- diffusion_opt_config = self.optimizer_configs['diffusion']
1161
- opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters())
1162
-
1163
- if "scheduler" in diffusion_opt_config:
1164
- sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff)
1165
- sched_diff_config = {
1166
- "scheduler": sched_diff,
1167
- "interval": "step"
1168
- }
1169
- return [opt_diff], [sched_diff_config]
1170
-
1171
- return [opt_diff]
1172
-
1173
- def random_mask(self, sequence, max_mask_length):
1174
- b, _, sequence_length = sequence.size()
1175
-
1176
- # Create a mask tensor for each batch element
1177
- masks = []
1178
-
1179
- for i in range(b):
1180
- mask_type = random.randint(0, 2)
1181
-
1182
- if mask_type == 0: # Random mask with multiple segments
1183
- num_segments = random.randint(1, self.max_mask_segments)
1184
- max_segment_length = max_mask_length // num_segments
1185
-
1186
- segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments)
1187
-
1188
- mask = torch.ones((1, 1, sequence_length))
1189
- for length in segment_lengths:
1190
- mask_start = random.randint(0, sequence_length - length)
1191
- mask[:, :, mask_start:mask_start + length] = 0
1192
-
1193
- elif mask_type == 1: # Full mask
1194
- mask = torch.zeros((1, 1, sequence_length))
1195
-
1196
- elif mask_type == 2: # Causal mask
1197
- mask = torch.ones((1, 1, sequence_length))
1198
- mask_length = random.randint(1, max_mask_length)
1199
- mask[:, :, -mask_length:] = 0
1200
-
1201
- mask = mask.to(sequence.device)
1202
- masks.append(mask)
1203
-
1204
- # Concatenate the mask tensors into a single tensor
1205
- mask = torch.cat(masks, dim=0).to(sequence.device)
1206
-
1207
- # Apply the mask to the sequence tensor for each batch element
1208
- masked_sequence = sequence * mask
1209
-
1210
- return masked_sequence, mask
1211
-
1212
- def training_step(self, batch, batch_idx):
1213
- reals, metadata = batch
1214
-
1215
- p = Profiler()
1216
-
1217
- if reals.ndim == 4 and reals.shape[0] == 1:
1218
- reals = reals[0]
1219
-
1220
- loss_info = {}
1221
-
1222
- diffusion_input = reals
1223
-
1224
- if not self.pre_encoded:
1225
- loss_info["audio_reals"] = diffusion_input
1226
-
1227
- p.tick("setup")
1228
-
1229
- with torch.amp.autocast('cuda'):
1230
- conditioning = self.diffusion.conditioner(metadata, self.device)
1231
-
1232
- p.tick("conditioning")
1233
-
1234
- if self.diffusion.pretransform is not None:
1235
- self.diffusion.pretransform.to(self.device)
1236
-
1237
- if not self.pre_encoded:
1238
- with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
1239
- diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
1240
- p.tick("pretransform")
1241
-
1242
- # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input
1243
- # if use_padding_mask:
1244
- # padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool()
1245
- else:
1246
- # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
1247
- if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
1248
- diffusion_input = diffusion_input / self.diffusion.pretransform.scale
1249
-
1250
- # Max mask size is the full sequence length
1251
- max_mask_length = diffusion_input.shape[2]
1252
-
1253
- # Create a mask of random length for a random slice of the input
1254
- masked_input, mask = self.random_mask(diffusion_input, max_mask_length)
1255
-
1256
- # conditioning['inpaint_mask'] = [mask]
1257
- conditioning['inpaint_masked_input'] = [masked_input]
1258
-
1259
- if self.timestep_sampler == "uniform":
1260
- # Draw uniformly distributed continuous timesteps
1261
- t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
1262
- elif self.timestep_sampler == "logit_normal":
1263
- t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
1264
-
1265
- # Calculate the noise schedule parameters for those timesteps
1266
- if self.diffusion_objective == "v":
1267
- alphas, sigmas = get_alphas_sigmas(t)
1268
- elif self.diffusion_objective == "rectified_flow":
1269
- alphas, sigmas = 1-t, t
1270
-
1271
- # Combine the ground truth data and the noise
1272
- alphas = alphas[:, None, None]
1273
- sigmas = sigmas[:, None, None]
1274
- noise = torch.randn_like(diffusion_input)
1275
- noised_inputs = diffusion_input * alphas + noise * sigmas
1276
-
1277
- if self.diffusion_objective == "v":
1278
- targets = noise * alphas - diffusion_input * sigmas
1279
- elif self.diffusion_objective == "rectified_flow":
1280
- targets = noise - diffusion_input
1281
-
1282
- p.tick("noise")
1283
-
1284
- extra_args = {}
1285
-
1286
- with torch.amp.autocast('cuda'):
1287
- p.tick("amp")
1288
- output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args)
1289
- p.tick("diffusion")
1290
-
1291
- loss_info.update({
1292
- "output": output,
1293
- "targets": targets,
1294
- })
1295
-
1296
- loss, losses = self.losses(loss_info)
1297
-
1298
- if self.log_loss_info:
1299
- # Loss debugging logs
1300
- num_loss_buckets = 10
1301
- bucket_size = 1 / num_loss_buckets
1302
- loss_all = F.mse_loss(output, targets, reduction="none")
1303
-
1304
- sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
1305
-
1306
- # gather loss_all across all GPUs
1307
- loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
1308
-
1309
- # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
1310
- loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
1311
-
1312
- # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
1313
- debug_log_dict = {
1314
- f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
1315
- }
1316
-
1317
- self.log_dict(debug_log_dict)
1318
-
1319
- log_dict = {
1320
- 'train/loss': loss.detach(),
1321
- 'train/std_data': diffusion_input.std(),
1322
- 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
1323
- }
1324
-
1325
- for loss_name, loss_value in losses.items():
1326
- log_dict[f"train/{loss_name}"] = loss_value.detach()
1327
-
1328
- self.log_dict(log_dict, prog_bar=True, on_step=True)
1329
- p.tick("log")
1330
- #print(f"Profiler: {p}")
1331
- return loss
1332
-
1333
- def on_before_zero_grad(self, *args, **kwargs):
1334
- if self.diffusion_ema is not None:
1335
- self.diffusion_ema.update()
1336
-
1337
- def export_model(self, path, use_safetensors=False):
1338
- if self.diffusion_ema is not None:
1339
- self.diffusion.model = self.diffusion_ema.ema_model
1340
-
1341
- if use_safetensors:
1342
- save_file(self.diffusion.state_dict(), path)
1343
- else:
1344
- torch.save({"state_dict": self.diffusion.state_dict()}, path)
1345
-
1346
- class DiffusionCondInpaintDemoCallback(Callback):
1347
- def __init__(
1348
- self,
1349
- demo_dl,
1350
- demo_every=2000,
1351
- demo_steps=250,
1352
- sample_size=65536,
1353
- sample_rate=48000,
1354
- demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7]
1355
- ):
1356
- super().__init__()
1357
- self.demo_every = demo_every
1358
- self.demo_steps = demo_steps
1359
- self.demo_samples = sample_size
1360
- self.demo_dl = iter(demo_dl)
1361
- self.sample_rate = sample_rate
1362
- self.demo_cfg_scales = demo_cfg_scales
1363
- self.last_demo_step = -1
1364
-
1365
- @rank_zero_only
1366
- @torch.no_grad()
1367
- def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx):
1368
- if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
1369
- return
1370
-
1371
- self.last_demo_step = trainer.global_step
1372
-
1373
- try:
1374
- log_dict = {}
1375
-
1376
- demo_reals, metadata = next(self.demo_dl)
1377
-
1378
- # Remove extra dimension added by WebDataset
1379
- if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
1380
- demo_reals = demo_reals[0]
1381
-
1382
- demo_reals = demo_reals.to(module.device)
1383
-
1384
- if not module.pre_encoded:
1385
- # Log the real audio
1386
- log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu()))
1387
- # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals")
1388
-
1389
- if module.diffusion.pretransform is not None:
1390
- module.diffusion.pretransform.to(module.device)
1391
- with torch.amp.autocast('cuda'):
1392
- demo_reals = module.diffusion.pretransform.encode(demo_reals)
1393
-
1394
- demo_samples = demo_reals.shape[2]
1395
-
1396
- # Get conditioning
1397
- conditioning = module.diffusion.conditioner(metadata, module.device)
1398
-
1399
- masked_input, mask = module.random_mask(demo_reals, demo_reals.shape[2])
1400
-
1401
- conditioning['inpaint_mask'] = [mask]
1402
- conditioning['inpaint_masked_input'] = [masked_input]
1403
-
1404
- if module.diffusion.pretransform is not None:
1405
- log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(masked_input.cpu()))
1406
- else:
1407
- log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu()))
1408
-
1409
- cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)
1410
-
1411
- noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device)
1412
-
1413
- trainer.logger.experiment.log(log_dict)
1414
-
1415
- for cfg_scale in self.demo_cfg_scales:
1416
- model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model
1417
- print(f"Generating demo for cfg scale {cfg_scale}")
1418
-
1419
- if module.diffusion_objective == "v":
1420
- fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
1421
- elif module.diffusion_objective == "rectified_flow":
1422
- fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
1423
-
1424
- if module.diffusion.pretransform is not None:
1425
- with torch.amp.autocast('cuda'):
1426
- fakes = module.diffusion.pretransform.decode(fakes)
1427
-
1428
- # Put the demos together
1429
- fakes = rearrange(fakes, 'b d n -> d (b n)')
1430
-
1431
- log_dict = {}
1432
-
1433
- filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
1434
- fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
1435
- torchaudio.save(filename, fakes, self.sample_rate)
1436
-
1437
- log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
1438
- sample_rate=self.sample_rate,
1439
- caption=f'Reconstructed')
1440
-
1441
- log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
1442
-
1443
- trainer.logger.experiment.log(log_dict)
1444
- except Exception as e:
1445
- print(f'{type(e).__name__}: {e}')
1446
- raise e
1447
-
1448
- class DiffusionAutoencoderTrainingWrapper(L.LightningModule):
1449
- '''
1450
- Wrapper for training a diffusion autoencoder
1451
- '''
1452
- def __init__(
1453
- self,
1454
- model: DiffusionAutoencoder,
1455
- lr: float = 1e-4,
1456
- ema_copy = None,
1457
- use_reconstruction_loss: bool = False
1458
- ):
1459
- super().__init__()
1460
-
1461
- self.diffae = model
1462
-
1463
- self.diffae_ema = EMA(
1464
- self.diffae,
1465
- ema_model=ema_copy,
1466
- beta=0.9999,
1467
- power=3/4,
1468
- update_every=1,
1469
- update_after_step=1,
1470
- include_online_model=False
1471
- )
1472
-
1473
- self.lr = lr
1474
-
1475
- self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
1476
-
1477
- loss_modules = [
1478
- MSELoss("v",
1479
- "targets",
1480
- weight=1.0,
1481
- name="mse_loss"
1482
- )
1483
- ]
1484
-
1485
- if model.bottleneck is not None:
1486
- # TODO: Use loss config for configurable bottleneck weights and reconstruction losses
1487
- loss_modules += create_loss_modules_from_bottleneck(model.bottleneck, {})
1488
-
1489
- self.use_reconstruction_loss = use_reconstruction_loss
1490
-
1491
- if use_reconstruction_loss:
1492
- scales = [2048, 1024, 512, 256, 128, 64, 32]
1493
- hop_sizes = []
1494
- win_lengths = []
1495
- overlap = 0.75
1496
- for s in scales:
1497
- hop_sizes.append(int(s * (1 - overlap)))
1498
- win_lengths.append(s)
1499
-
1500
- sample_rate = model.sample_rate
1501
-
1502
- stft_loss_args = {
1503
- "fft_sizes": scales,
1504
- "hop_sizes": hop_sizes,
1505
- "win_lengths": win_lengths,
1506
- "perceptual_weighting": True
1507
- }
1508
-
1509
- out_channels = model.out_channels
1510
-
1511
- if model.pretransform is not None:
1512
- out_channels = model.pretransform.io_channels
1513
-
1514
- if out_channels == 2:
1515
- self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1516
- else:
1517
- self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1518
-
1519
- loss_modules.append(
1520
- AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss
1521
- )
1522
-
1523
- self.losses = MultiLoss(loss_modules)
1524
-
1525
- def configure_optimizers(self):
1526
- return optim.Adam([*self.diffae.parameters()], lr=self.lr)
1527
-
1528
- def training_step(self, batch, batch_idx):
1529
- reals = batch[0]
1530
-
1531
- if reals.ndim == 4 and reals.shape[0] == 1:
1532
- reals = reals[0]
1533
-
1534
- loss_info = {}
1535
-
1536
- loss_info["audio_reals"] = reals
1537
-
1538
- if self.diffae.pretransform is not None:
1539
- with torch.no_grad():
1540
- reals = self.diffae.pretransform.encode(reals)
1541
-
1542
- loss_info["reals"] = reals
1543
-
1544
- #Encode reals, skipping the pretransform since it was already applied
1545
- latents, encoder_info = self.diffae.encode(reals, return_info=True, skip_pretransform=True)
1546
-
1547
- loss_info["latents"] = latents
1548
- loss_info.update(encoder_info)
1549
-
1550
- if self.diffae.decoder is not None:
1551
- latents = self.diffae.decoder(latents)
1552
-
1553
- # Upsample latents to match diffusion length
1554
- if latents.shape[2] != reals.shape[2]:
1555
- latents = F.interpolate(latents, size=reals.shape[2], mode='nearest')
1556
-
1557
- loss_info["latents_upsampled"] = latents
1558
-
1559
- # Draw uniformly distributed continuous timesteps
1560
- t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
1561
-
1562
- # Calculate the noise schedule parameters for those timesteps
1563
- alphas, sigmas = get_alphas_sigmas(t)
1564
-
1565
- # Combine the ground truth data and the noise
1566
- alphas = alphas[:, None, None]
1567
- sigmas = sigmas[:, None, None]
1568
- noise = torch.randn_like(reals)
1569
- noised_reals = reals * alphas + noise * sigmas
1570
- targets = noise * alphas - reals * sigmas
1571
-
1572
- with torch.amp.autocast('cuda'):
1573
- v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents)
1574
-
1575
- loss_info.update({
1576
- "v": v,
1577
- "targets": targets
1578
- })
1579
-
1580
- if self.use_reconstruction_loss:
1581
- pred = noised_reals * alphas - v * sigmas
1582
-
1583
- loss_info["pred"] = pred
1584
-
1585
- if self.diffae.pretransform is not None:
1586
- pred = self.diffae.pretransform.decode(pred)
1587
- loss_info["audio_pred"] = pred
1588
-
1589
- loss, losses = self.losses(loss_info)
1590
-
1591
- log_dict = {
1592
- 'train/loss': loss.detach(),
1593
- 'train/std_data': reals.std(),
1594
- 'train/latent_std': latents.std(),
1595
- }
1596
-
1597
- for loss_name, loss_value in losses.items():
1598
- log_dict[f"train/{loss_name}"] = loss_value.detach()
1599
-
1600
- self.log_dict(log_dict, prog_bar=True, on_step=True)
1601
- return loss
1602
-
1603
- def on_before_zero_grad(self, *args, **kwargs):
1604
- self.diffae_ema.update()
1605
-
1606
- def export_model(self, path, use_safetensors=False):
1607
-
1608
- model = self.diffae_ema.ema_model
1609
-
1610
- if use_safetensors:
1611
- save_file(model.state_dict(), path)
1612
- else:
1613
- torch.save({"state_dict": model.state_dict()}, path)
1614
-
1615
- class DiffusionAutoencoderDemoCallback(Callback):
1616
- def __init__(
1617
- self,
1618
- demo_dl,
1619
- demo_every=2000,
1620
- demo_steps=250,
1621
- sample_size=65536,
1622
- sample_rate=48000
1623
- ):
1624
- super().__init__()
1625
- self.demo_every = demo_every
1626
- self.demo_steps = demo_steps
1627
- self.demo_samples = sample_size
1628
- self.demo_dl = iter(demo_dl)
1629
- self.sample_rate = sample_rate
1630
- self.last_demo_step = -1
1631
-
1632
- @rank_zero_only
1633
- @torch.no_grad()
1634
- def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx):
1635
- if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
1636
- return
1637
-
1638
- self.last_demo_step = trainer.global_step
1639
-
1640
- demo_reals, _ = next(self.demo_dl)
1641
-
1642
- # Remove extra dimension added by WebDataset
1643
- if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
1644
- demo_reals = demo_reals[0]
1645
-
1646
- encoder_input = demo_reals
1647
-
1648
- encoder_input = encoder_input.to(module.device)
1649
-
1650
- demo_reals = demo_reals.to(module.device)
1651
-
1652
- with torch.no_grad() and torch.amp.autocast('cuda'):
1653
- latents = module.diffae_ema.ema_model.encode(encoder_input).float()
1654
- fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps)
1655
-
1656
- #Interleave reals and fakes
1657
- reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
1658
-
1659
- # Put the demos together
1660
- reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
1661
-
1662
- log_dict = {}
1663
-
1664
- filename = f'recon_{trainer.global_step:08}.wav'
1665
- reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu()
1666
- torchaudio.save(filename, reals_fakes, self.sample_rate)
1667
-
1668
- log_dict[f'recon'] = wandb.Audio(filename,
1669
- sample_rate=self.sample_rate,
1670
- caption=f'Reconstructed')
1671
-
1672
- log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents)
1673
- log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents))
1674
-
1675
- log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
1676
-
1677
- if module.diffae_ema.ema_model.pretransform is not None:
1678
- with torch.no_grad() and torch.amp.autocast('cuda'):
1679
- initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input)
1680
- first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents)
1681
- first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)')
1682
- first_stage_fakes = first_stage_fakes.to(torch.float32).mul(32767).to(torch.int16).cpu()
1683
- first_stage_filename = f'first_stage_{trainer.global_step:08}.wav'
1684
- torchaudio.save(first_stage_filename, first_stage_fakes, self.sample_rate)
1685
-
1686
- log_dict[f'first_stage_latents'] = wandb.Image(tokens_spectrogram_image(initial_latents))
1687
-
1688
- log_dict[f'first_stage'] = wandb.Audio(first_stage_filename,
1689
- sample_rate=self.sample_rate,
1690
- caption=f'First Stage Reconstructed')
1691
-
1692
- log_dict[f'first_stage_melspec_left'] = wandb.Image(audio_spectrogram_image(first_stage_fakes))
1693
-
1694
-
1695
- trainer.logger.experiment.log(log_dict)
1696
-
1697
- def create_source_mixture(reals, num_sources=2):
1698
- # Create a fake mixture source by mixing elements from the training batch together with random offsets
1699
- source = torch.zeros_like(reals)
1700
- for i in range(reals.shape[0]):
1701
- sources_added = 0
1702
-
1703
- js = list(range(reals.shape[0]))
1704
- random.shuffle(js)
1705
- for j in js:
1706
- if i == j or (i != j and sources_added < num_sources):
1707
- # Randomly offset the mixed element between 0 and the length of the source
1708
- seq_len = reals.shape[2]
1709
- offset = random.randint(0, seq_len-1)
1710
- source[i, :, offset:] += reals[j, :, :-offset]
1711
- if i == j:
1712
- # If this is the real one, shift the reals as well to ensure alignment
1713
- new_reals = torch.zeros_like(reals[i])
1714
- new_reals[:, offset:] = reals[i, :, :-offset]
1715
- reals[i] = new_reals
1716
- sources_added += 1
1717
-
1718
- return source
1719
-
1720
- class DiffusionPriorTrainingWrapper(L.LightningModule):
1721
- '''
1722
- Wrapper for training a diffusion prior for inverse problems
1723
- Prior types:
1724
- mono_stereo: The prior is conditioned on a mono version of the audio to generate a stereo version
1725
- '''
1726
- def __init__(
1727
- self,
1728
- model: ConditionedDiffusionModelWrapper,
1729
- lr: float = 1e-4,
1730
- ema_copy = None,
1731
- prior_type: PriorType = PriorType.MonoToStereo,
1732
- use_reconstruction_loss: bool = False,
1733
- log_loss_info: bool = False,
1734
- ):
1735
- super().__init__()
1736
-
1737
- self.diffusion = model
1738
-
1739
- self.diffusion_ema = EMA(
1740
- self.diffusion,
1741
- ema_model=ema_copy,
1742
- beta=0.9999,
1743
- power=3/4,
1744
- update_every=1,
1745
- update_after_step=1,
1746
- include_online_model=False
1747
- )
1748
-
1749
- self.lr = lr
1750
-
1751
- self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
1752
-
1753
- self.log_loss_info = log_loss_info
1754
-
1755
- loss_modules = [
1756
- MSELoss("v",
1757
- "targets",
1758
- weight=1.0,
1759
- name="mse_loss"
1760
- )
1761
- ]
1762
-
1763
- self.use_reconstruction_loss = use_reconstruction_loss
1764
-
1765
- if use_reconstruction_loss:
1766
- scales = [2048, 1024, 512, 256, 128, 64, 32]
1767
- hop_sizes = []
1768
- win_lengths = []
1769
- overlap = 0.75
1770
- for s in scales:
1771
- hop_sizes.append(int(s * (1 - overlap)))
1772
- win_lengths.append(s)
1773
-
1774
- sample_rate = model.sample_rate
1775
-
1776
- stft_loss_args = {
1777
- "fft_sizes": scales,
1778
- "hop_sizes": hop_sizes,
1779
- "win_lengths": win_lengths,
1780
- "perceptual_weighting": True
1781
- }
1782
-
1783
- out_channels = model.io_channels
1784
-
1785
-
1786
- if model.pretransform is not None:
1787
- out_channels = model.pretransform.io_channels
1788
- self.audio_out_channels = out_channels
1789
-
1790
- if self.audio_out_channels == 2:
1791
- self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1792
- self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1793
-
1794
- # Add left and right channel reconstruction losses in addition to the sum and difference
1795
- loss_modules += [
1796
- AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05),
1797
- AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05),
1798
- ]
1799
-
1800
- else:
1801
- self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1802
-
1803
- loss_modules.append(
1804
- AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss
1805
- )
1806
-
1807
- self.losses = MultiLoss(loss_modules)
1808
-
1809
- self.prior_type = prior_type
1810
-
1811
- def configure_optimizers(self):
1812
- return optim.Adam([*self.diffusion.parameters()], lr=self.lr)
1813
-
1814
- def training_step(self, batch, batch_idx):
1815
- reals, metadata = batch
1816
-
1817
- if reals.ndim == 4 and reals.shape[0] == 1:
1818
- reals = reals[0]
1819
-
1820
- loss_info = {}
1821
-
1822
- loss_info["audio_reals"] = reals
1823
-
1824
- if self.prior_type == PriorType.MonoToStereo:
1825
- source = reals.mean(dim=1, keepdim=True).repeat(1, reals.shape[1], 1).to(self.device)
1826
- loss_info["audio_reals_mono"] = source
1827
- else:
1828
- raise ValueError(f"Unknown prior type {self.prior_type}")
1829
-
1830
- if self.diffusion.pretransform is not None:
1831
- with torch.no_grad():
- reals = self.diffusion.pretransform.encode(reals)
-
- if self.prior_type in [PriorType.MonoToStereo]:
- source = self.diffusion.pretransform.encode(source)
-
- if self.diffusion.conditioner is not None:
- with torch.amp.autocast('cuda'):
- conditioning = self.diffusion.conditioner(metadata, self.device)
- else:
- conditioning = {}
-
- loss_info["reals"] = reals
-
- # Draw uniformly distributed continuous timesteps
- t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
-
- # Calculate the noise schedule parameters for those timesteps
- alphas, sigmas = get_alphas_sigmas(t)
-
- # Combine the ground truth data and the noise
- alphas = alphas[:, None, None]
- sigmas = sigmas[:, None, None]
- noise = torch.randn_like(reals)
- noised_reals = reals * alphas + noise * sigmas
- targets = noise * alphas - reals * sigmas
-
- with torch.amp.autocast('cuda'):
-
- conditioning['source'] = [source]
-
- v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob = 0.1)
-
- loss_info.update({
- "v": v,
- "targets": targets
- })
-
- if self.use_reconstruction_loss:
- pred = noised_reals * alphas - v * sigmas
-
- loss_info["pred"] = pred
-
- if self.diffusion.pretransform is not None:
- pred = self.diffusion.pretransform.decode(pred)
- loss_info["audio_pred"] = pred
-
- if self.audio_out_channels == 2:
- loss_info["pred_left"] = pred[:, 0:1, :]
- loss_info["pred_right"] = pred[:, 1:2, :]
- loss_info["audio_reals_left"] = loss_info["audio_reals"][:, 0:1, :]
- loss_info["audio_reals_right"] = loss_info["audio_reals"][:, 1:2, :]
-
- loss, losses = self.losses(loss_info)
-
- if self.log_loss_info:
- # Loss debugging logs
- num_loss_buckets = 10
- bucket_size = 1 / num_loss_buckets
- loss_all = F.mse_loss(v, targets, reduction="none")
-
- sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
-
- # gather loss_all across all GPUs
- loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
-
- # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
- loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
-
- # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
- debug_log_dict = {
- f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
- }
-
- self.log_dict(debug_log_dict)
-
- log_dict = {
- 'train/loss': loss.detach(),
- 'train/std_data': reals.std()
- }
-
- for loss_name, loss_value in losses.items():
- log_dict[f"train/{loss_name}"] = loss_value.detach()
-
- self.log_dict(log_dict, prog_bar=True, on_step=True)
- return loss
-
- def on_before_zero_grad(self, *args, **kwargs):
- self.diffusion_ema.update()
-
- def export_model(self, path, use_safetensors=False):
-
- #model = self.diffusion_ema.ema_model
- model = self.diffusion
-
- if use_safetensors:
- save_file(model.state_dict(), path)
- else:
- torch.save({"state_dict": model.state_dict()}, path)
-
- class DiffusionPriorDemoCallback(Callback):
- def __init__(
- self,
- demo_dl,
- demo_every=2000,
- demo_steps=250,
- sample_size=65536,
- sample_rate=48000
- ):
- super().__init__()
-
- self.demo_every = demo_every
- self.demo_steps = demo_steps
- self.demo_samples = sample_size
- self.demo_dl = iter(demo_dl)
- self.sample_rate = sample_rate
- self.last_demo_step = -1
-
- @rank_zero_only
- @torch.no_grad()
- def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx):
- if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
- return
-
- self.last_demo_step = trainer.global_step
-
- demo_reals, metadata = next(self.demo_dl)
- # import ipdb
- # ipdb.set_trace()
- # Remove extra dimension added by WebDataset
- if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
- demo_reals = demo_reals[0]
-
- demo_reals = demo_reals.to(module.device)
-
- encoder_input = demo_reals
-
- if module.diffusion.conditioner is not None:
- with torch.amp.autocast('cuda'):
- conditioning_tensors = module.diffusion.conditioner(metadata, module.device)
-
- else:
- conditioning_tensors = {}
-
-
- with torch.no_grad() and torch.amp.autocast('cuda'):
- if module.prior_type == PriorType.MonoToStereo and encoder_input.shape[1] > 1:
- source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device)
-
- if module.diffusion.pretransform is not None:
- encoder_input = module.diffusion.pretransform.encode(encoder_input)
- source_input = module.diffusion.pretransform.encode(source)
- else:
- source_input = source
-
- conditioning_tensors['source'] = [source_input]
-
- fakes = sample(module.diffusion_ema.model, torch.randn_like(encoder_input), self.demo_steps, 0, cond=conditioning_tensors)
-
- if module.diffusion.pretransform is not None:
- fakes = module.diffusion.pretransform.decode(fakes)
-
- #Interleave reals and fakes
- reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
-
- # Put the demos together
- reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
-
- log_dict = {}
-
- filename = f'recon_mono_{trainer.global_step:08}.wav'
- reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu()
- torchaudio.save(filename, reals_fakes, self.sample_rate)
-
- log_dict[f'recon'] = wandb.Audio(filename,
- sample_rate=self.sample_rate,
- caption=f'Reconstructed')
-
- log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
-
- #Log the source
- filename = f'source_{trainer.global_step:08}.wav'
- source = rearrange(source, 'b d n -> d (b n)')
- source = source.to(torch.float32).mul(32767).to(torch.int16).cpu()
- torchaudio.save(filename, source, self.sample_rate)
-
- log_dict[f'source'] = wandb.Audio(filename,
- sample_rate=self.sample_rate,
- caption=f'Source')
-
- log_dict[f'source_melspec_left'] = wandb.Image(audio_spectrogram_image(source))
-
- trainer.logger.experiment.log(log_dict)
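
The core of the removed training_step is the v-objective: noise the latents with a schedule from get_alphas_sigmas, regress the "v" target, and (optionally) reconstruct the clean latent for a reconstruction loss. The sketch below is a minimal, self-contained illustration of that math only, not the deleted code itself; it assumes the conventional trigonometric schedule (alpha = cos(t·π/2), sigma = sin(t·π/2)) for get_alphas_sigmas, whose real definition lives in ..inference.sampling.

```python
# Minimal sketch of the v-objective used by the removed training_step.
# Assumption: get_alphas_sigmas follows the usual trigonometric schedule.
import torch

def get_alphas_sigmas(t: torch.Tensor):
    # alpha^2 + sigma^2 = 1 for every continuous timestep t in [0, 1]
    return torch.cos(t * torch.pi / 2), torch.sin(t * torch.pi / 2)

reals = torch.randn(4, 64, 256)        # (batch, latent channels, latent frames)
t = torch.rand(reals.shape[0])         # uniformly drawn continuous timesteps

alphas, sigmas = get_alphas_sigmas(t)
alphas = alphas[:, None, None]
sigmas = sigmas[:, None, None]

noise = torch.randn_like(reals)
noised_reals = reals * alphas + noise * sigmas   # forward noising
targets = noise * alphas - reals * sigmas        # v-prediction target

# Given a (here: perfect) v prediction, the clean latent is recovered exactly
# as in the reconstruction-loss branch of the removed code.
v = targets
pred = noised_reals * alphas - v * sigmas
assert torch.allclose(pred, reals, atol=1e-5)
```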
 
  from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
  from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper
  from ..models.autoencoders import DiffusionAutoencoder
  from .autoencoders import create_loss_modules_from_bottleneck
  from .losses import AuralossLoss, MSELoss, MultiLoss
  from .utils import create_optimizer_from_config, create_scheduler_from_config, mask_from_frac_lengths, generate_mask, generate_channel_mask

  def predict_step(self, batch, batch_idx):
  reals, metadata = batch
  ids = [item['id'] for item in metadata]
  batch_size, length = reals.shape[0], reals.shape[2]
+ print(f"Predicting {batch_size} samples with length {length} for ids: {ids}")
  with torch.amp.autocast('cuda'):
  conditioning = self.diffusion.conditioner(metadata, self.device)

  end_time = time.time()
  execution_time = end_time - start_time
  print(f"Execution time: {execution_time:.2f} s")
  if self.diffusion.pretransform is not None:
  fakes = self.diffusion.pretransform.decode(fakes)

  gc.collect()
  torch.cuda.empty_cache()
  module.train()
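
The surrounding context lines outline the inference path: build conditioning from the metadata, sample latents, time the call, then decode with the pretransform. Below is a hedged sketch of that path as one helper, assuming the renamed package is installed, `wrapper` is a training wrapper loaded as app.py does, and `model` is the denoiser to sample from (the removed demo callback used `wrapper.diffusion_ema.model`); the `sample(...)` call mirrors the positional pattern shown in that callback, and the int16 conversion mirrors app.py.

```python
# Hedged sketch of the sampling path shown above; `wrapper` and `model` are
# assumed to be prepared elsewhere (e.g. as in app.py), not defined here.
import time
import torch
from ThinkSound.inference.sampling import sample

@torch.no_grad()
def generate_audio(wrapper, model, metadata, latent_shape, steps=24):
    device = wrapper.device
    with torch.amp.autocast('cuda'):
        conditioning = wrapper.diffusion.conditioner(metadata, device)

    noise = torch.randn(latent_shape, device=device)

    start_time = time.time()
    # Same call pattern as the removed demo callback: sample(model, noise, steps, eta, cond=...)
    fakes = sample(model, noise, steps, 0, cond=conditioning)
    print(f"Execution time: {time.time() - start_time:.2f} s")

    # Decode latents back to a waveform, then peak-normalize to 16-bit PCM as app.py does
    if wrapper.diffusion.pretransform is not None:
        fakes = wrapper.diffusion.pretransform.decode(fakes)
    audio = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1)
    return audio.mul(32767).to(torch.int16).cpu()
```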
 
{think_sound → ThinkSound}/training/factory.py RENAMED
File without changes
{think_sound → ThinkSound}/training/losses/__init__.py RENAMED
File without changes
{think_sound → ThinkSound}/training/losses/auraloss.py RENAMED
File without changes
{think_sound → ThinkSound}/training/losses/losses.py RENAMED
File without changes
{think_sound → ThinkSound}/training/utils.py RENAMED
File without changes
app.py CHANGED
@@ -14,13 +14,12 @@ from lightning.pytorch.tuner import Tuner
  from lightning.pytorch import seed_everything
  import random
  from datetime import datetime
- # from think_sound.data.dataset import create_dataloader_from_config
- from think_sound.data.datamodule import DataModule
- from think_sound.models import create_model_from_config
- from think_sound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
- from think_sound.training import create_training_wrapper_from_config, create_demo_callback_from_config
- from think_sound.training.utils import copy_state_dict
- from think_sound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
  from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
  from torch.utils.data import Dataset
  from typing import Optional, Union
@@ -34,7 +33,7 @@ import tempfile
  import subprocess
  from huggingface_hub import hf_hub_download
  from moviepy.editor import VideoFileClip
- os.system("conda install -c conda-forge 'ffmpeg<7'")

  _CLIP_SIZE = 224
  _CLIP_FPS = 8.0
@@ -101,7 +100,7 @@ class VGGSound(Dataset):

  self.resampler = {}

- def sample(self, video_path,label):
  video_id = video_path

  reader = StreamingMediaDecoder(video_path)
@@ -156,7 +155,7 @@ class VGGSound(Dataset):
  # padding using the last frame, but no more than 2
  current_length = sync_chunk.shape[0]
  last_frame = sync_chunk[-1]
- # repeat the last frame for padding
  padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
  assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
  sync_chunk = torch.cat((sync_chunk, padding), dim=0)
@@ -170,6 +169,7 @@ class VGGSound(Dataset):
  data = {
  'id': video_id,
  'caption': label,
  # 'audio': audio_chunk,
  'clip_video': clip_chunk,
  'sync_video': sync_chunk,
@@ -187,17 +187,16 @@ else:

  print(f"load in device {device}")

- vae_ckpt = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="vae.ckpt",repo_type="model")
- synchformer_ckpt = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="synchformer_state_dict.pth",repo_type="model")
  feature_extractor = FeaturesUtils(
- vae_ckpt=vae_ckpt,
- vae_config='think_sound/configs/model_configs/autoencoders/stable_audio_2_0_vae.json',
  enable_conditions=True,
  synchformer_ckpt=synchformer_ckpt
  ).eval().to(extra_device)

-
-
  args = get_all_args()

  seed = 10086
@@ -206,7 +205,7 @@ seed_everything(seed, workers=True)

  #Get JSON config from args.model_config
- with open("think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json") as f:
  model_config = json.load(f)

  model = create_model_from_config(model_config)
@@ -229,7 +228,7 @@ model.pretransform.load_state_dict(load_vae_state)
  # Remove weight_norm from the pretransform if specified
  if args.remove_pretransform_weight_norm == "post_load":
  remove_weight_norm_from_model(model.pretransform)
- ckpt_path = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="thinksound.ckpt",repo_type="model")
  training_wrapper = create_training_wrapper_from_config(model_config, model)
  # choose map_location based on the device when loading the model weights
  training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
@@ -243,16 +242,17 @@ def get_video_duration(video_path):
  @spaces.GPU(duration=60)
  @torch.inference_mode()
  @torch.no_grad()
- def get_audio(video_path, caption):
- # allow caption to be empty
  if caption is None:
  caption = ''

  timer = Timer(duration="00:15:00:00")
  #get video duration
  duration_sec = get_video_duration(video_path)
  print(duration_sec)
  preprocesser = VGGSound(duration_sec=duration_sec)
- data = preprocesser.sample(video_path, caption)

@@ -261,7 +261,7 @@ def get_audio(video_path, caption):
  preprocessed_data['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu().squeeze(0)
  preprocessed_data['metaclip_text_features'] = metaclip_text_features.detach().cpu().squeeze(0)

- t5_features = feature_extractor.encode_t5_text(data['caption'])
  preprocessed_data['t5_features'] = t5_features.detach().cpu().squeeze(0)

  clip_features = feature_extractor.encode_video_with_clip(data['clip_video'].unsqueeze(0).to(extra_device))
@@ -305,56 +305,47 @@ def get_audio(video_path, caption):
  fakes = training_wrapper.diffusion.pretransform.decode(fakes)

  audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
- # save a temporary audio file
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
  torchaudio.save(tmp_audio.name, audios[0], 44100)
  audio_path = tmp_audio.name
  return audio_path

- def synthesize_video_with_audio(video_file, caption):
- # allow caption to be empty
- if caption is None:
- caption = ''
- audio_path = get_audio(video_file, caption)
  with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
  output_video_path = tmp_video.name
- # ffmpeg command: replace the original video's audio track with the new audio
  cmd = [
  'ffmpeg', '-y', '-i', video_file, '-i', audio_path,
  '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0',
  '-shortest', output_video_path
  ]
  subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  return output_video_path

- # Gradio UI
- with gr.Blocks() as demo:
- gr.Markdown(
- """
- # ThinkSound\n
- ThinkSound is a unified Any2Audio generation framework with flow matching guided by Chain-of-Thought (CoT) reasoning.
-
- Upload video and caption (optional), and get video with audio!
-
- """
- )
- with gr.Row():
- video_input = gr.Video(label="upload video")
- caption_input = gr.Textbox(label="caption(optional)", placeholder="can be empty", lines=1)
- output_video = gr.Video(label="output video")
- btn = gr.Button("start synthesize")
- btn.click(fn=synthesize_video_with_audio, inputs=[video_input, caption_input], outputs=output_video)
-
- gr.Examples(
- examples=[
- ["./examples/1_mute.mp4", "Playing Trumpet", "./examples/1.mp4"],
- ["./examples/2_mute.mp4", "Axe striking", "./examples/2.mp4"],
- ["./examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier", "./examples/3.mp4"],
- ["./examples/4_mute.mp4", "train passing by", "./examples/4.mp4"],
- ["./examples/5_mute.mp4", "Lighting Firecrackers", "./examples/5.mp4"]
- ],
- inputs=[video_input, caption_input,output_video],
- )
-
- demo.launch(share=True)
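
The removed synthesize_video_with_audio above muxes the generated wav back into the silent video with ffmpeg, and the same command survives unchanged on the new side of this diff. For reference, here is a hedged sketch of that step factored into a reusable helper (the helper name and the return-code check are additions for illustration; the app itself ignores ffmpeg errors).

```python
# Hedged sketch of the ffmpeg muxing step used by synthesize_video_with_audio.
import subprocess
import tempfile

def mux_audio_into_video(video_path: str, audio_path: str) -> str:
    """Replace the video's audio track with `audio_path` and return the new file path."""
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
        output_video_path = tmp_video.name
    cmd = [
        'ffmpeg', '-y', '-i', video_path, '-i', audio_path,
        '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0',  # copy video, take audio from the wav
        '-shortest', output_video_path,
    ]
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg failed: {result.stderr.decode(errors='ignore')}")
    return output_video_path
```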
 
  from lightning.pytorch import seed_everything
  import random
  from datetime import datetime
+ from ThinkSound.data.datamodule import DataModule
+ from ThinkSound.models import create_model_from_config
+ from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
+ from ThinkSound.training import create_training_wrapper_from_config, create_demo_callback_from_config
+ from ThinkSound.training.utils import copy_state_dict
+ from ThinkSound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
  from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
  from torch.utils.data import Dataset
  from typing import Optional, Union

  import subprocess
  from huggingface_hub import hf_hub_download
  from moviepy.editor import VideoFileClip
+ # os.system("conda install -c conda-forge 'ffmpeg<7'")

  _CLIP_SIZE = 224
  _CLIP_FPS = 8.0

  self.resampler = {}

+ def sample(self, video_path,label,cot):
  video_id = video_path

  reader = StreamingMediaDecoder(video_path)

  # padding using the last frame, but no more than 2
  current_length = sync_chunk.shape[0]
  last_frame = sync_chunk[-1]
+
  padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
  assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
  sync_chunk = torch.cat((sync_chunk, padding), dim=0)

  data = {
  'id': video_id,
  'caption': label,
+ 'caption_cot': cot,
  # 'audio': audio_chunk,
  'clip_video': clip_chunk,
  'sync_video': sync_chunk,

  print(f"load in device {device}")

+ vae_ckpt = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="vae.ckpt",repo_type="model")
+ synchformer_ckpt = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="synchformer_state_dict.pth",repo_type="model")
+
  feature_extractor = FeaturesUtils(
+ vae_ckpt=None,
+ vae_config='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json',
  enable_conditions=True,
  synchformer_ckpt=synchformer_ckpt
  ).eval().to(extra_device)

  args = get_all_args()

  seed = 10086

  #Get JSON config from args.model_config
+ with open("ThinkSound/configs/model_configs/thinksound.json") as f:
  model_config = json.load(f)

  model = create_model_from_config(model_config)

  # Remove weight_norm from the pretransform if specified
  if args.remove_pretransform_weight_norm == "post_load":
  remove_weight_norm_from_model(model.pretransform)
+ ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound.ckpt",repo_type="model")
  training_wrapper = create_training_wrapper_from_config(model_config, model)
  # choose map_location based on the device when loading the model weights
  training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])

  @spaces.GPU(duration=60)
  @torch.inference_mode()
  @torch.no_grad()
+ def get_audio(video_path, caption, cot):
  if caption is None:
  caption = ''
+ if cot is None:
+ cot = caption
  timer = Timer(duration="00:15:00:00")
  #get video duration
  duration_sec = get_video_duration(video_path)
  print(duration_sec)
  preprocesser = VGGSound(duration_sec=duration_sec)
+ data = preprocesser.sample(video_path, caption, cot)

  preprocessed_data['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu().squeeze(0)
  preprocessed_data['metaclip_text_features'] = metaclip_text_features.detach().cpu().squeeze(0)

+ t5_features = feature_extractor.encode_t5_text(data['caption_cot'])
  preprocessed_data['t5_features'] = t5_features.detach().cpu().squeeze(0)

  clip_features = feature_extractor.encode_video_with_clip(data['clip_video'].unsqueeze(0).to(extra_device))

  fakes = training_wrapper.diffusion.pretransform.decode(fakes)

  audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
  torchaudio.save(tmp_audio.name, audios[0], 44100)
  audio_path = tmp_audio.name
+
  return audio_path

+ def synthesize_video_with_audio(video_file, caption, cot):
+ audio_path = get_audio(video_file, caption, cot)
  with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
  output_video_path = tmp_video.name
+
  cmd = [
  'ffmpeg', '-y', '-i', video_file, '-i', audio_path,
  '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0',
  '-shortest', output_video_path
  ]
  subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
  return output_video_path

+ demo = gr.Interface(
+ fn=synthesize_video_with_audio,
+ inputs=[
+ gr.Video(label="Upload Video"),
+ gr.Textbox(label="Caption (optional)", placeholder="can be empty",),
+ gr.Textbox(label="CoT Description (optional)", lines=6, placeholder="can be empty",),
+ ],
+ outputs=[
+ gr.Video(label="Result"),
+ ],
+ title="ThinkSound Demo",
+ description="Upload a video, caption, or CoT to generate audio. For an enhanced experience, we automatically merge the generated audio with your original silent video. (Note: Flexible audio generation lengths are supported.:)",
+ examples=[
+ ["examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier", "Begin by creating a soft, steady background of light pacifier suckling. Add subtle, breathy rhythms to mimic a newborn's gentle mouth movements. Keep the sound smooth, natural, and soothing."],
+ ["examples/2_mute.mp4", "Printer Printing", "Generate a continuous printer printing sound with periodic beeps and paper movement, plus a cat pawing at the machine. Add subtle ambient room noise for authenticity, keeping the focus on printing, beeps, and the cat's interaction."],
+ ["examples/4_mute.mp4", "Plastic Debris Handling", "Begin with the sound of hands scooping up loose plastic debris, followed by the subtle cascading noise as the pieces fall and scatter back down. Include soft crinkling and rustling to emphasize the texture of the plastic. Add ambient factory background noise with distant machinery to create an industrial atmosphere."],
+ ["examples/5_mute.mp4", "Lighting Firecrackers", "Generate the sound of firecrackers lighting and exploding repeatedly on the ground, followed by fireworks bursting in the sky. Incorporate occasional subtle echoes to mimic an outdoor night ambiance, with no human voices present."]
+ ],
+ cache_examples=True
+ )
+
+ if __name__ == "__main__":
+ demo.launch(share=True)
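
The new interface exposes synthesize_video_with_audio(video_file, caption, cot); when cot is None, get_audio falls back to the plain caption, while the CoT text is what feeds the T5 branch (caption_cot → encode_t5_text). A hypothetical programmatic use, with the caption and CoT taken from the bundled examples, might look like the sketch below; importing app is an assumption and will download the checkpoints, and it needs the same GPU environment as the Space.

```python
# Hypothetical programmatic use of the new CoT-aware entry point defined in app.py.
from app import synthesize_video_with_audio

caption = "Gentle Sucking Sounds From the Pacifier"
cot = ("Begin by creating a soft, steady background of light pacifier suckling. "
       "Add subtle, breathy rhythms to mimic a newborn's gentle mouth movements. "
       "Keep the sound smooth, natural, and soothing.")

# Passing cot=None would make get_audio fall back to the plain caption.
output_path = synthesize_video_with_audio("examples/3_mute.mp4", caption, cot)
print(f"Video with generated audio written to {output_path}")
```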
cot_vgg_demo_caption.txt ADDED
@@ -0,0 +1 @@
+ demo.npz
data_utils/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (149 Bytes)
 
data_utils/__pycache__/utils.cpython-310.pyc DELETED
Binary file (4.56 kB)
 
data_utils/__pycache__/utils.cpython-39.pyc DELETED
Binary file (4.56 kB)
 
data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (243 Bytes)
 
data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (241 Bytes)
 
data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc DELETED
Binary file (12.7 kB)
 
data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc DELETED
Binary file (12.7 kB)
 
data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc DELETED
Binary file (1.91 kB)
 
data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc DELETED
Binary file (1.9 kB)
 
data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc DELETED
Binary file (3.97 kB)
 
data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc DELETED
Binary file (3.78 kB)