diff --git a/think_sound/__init__.py b/ThinkSound/__init__.py similarity index 100% rename from think_sound/__init__.py rename to ThinkSound/__init__.py diff --git a/think_sound/configs/model_configs/autoencoders/stable_audio_2_0_vae.json b/ThinkSound/configs/model_configs/stable_audio_2_0_vae.json similarity index 100% rename from think_sound/configs/model_configs/autoencoders/stable_audio_2_0_vae.json rename to ThinkSound/configs/model_configs/stable_audio_2_0_vae.json diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json b/ThinkSound/configs/model_configs/thinksound.json similarity index 99% rename from think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json rename to ThinkSound/configs/model_configs/thinksound.json index 5d466bef5a76c21149b14c8c998c55e737a4845a..1458b0d43350e928b9c68cc8619242b8ff8f87c1 100644 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json +++ b/ThinkSound/configs/model_configs/thinksound.json @@ -85,7 +85,7 @@ "clip_dim":1024, "sync_dim":768, "text_dim":2048, - "hidden_dim":1024 , + "hidden_dim":1024, "depth":21, "fused_depth":14, "num_heads":16, diff --git a/ThinkSound/configs/multimodal_dataset_demo.json b/ThinkSound/configs/multimodal_dataset_demo.json new file mode 100644 index 0000000000000000000000000000000000000000..54e059db04525ce2334593ffce5809295ab6ffba --- /dev/null +++ b/ThinkSound/configs/multimodal_dataset_demo.json @@ -0,0 +1,53 @@ +{ + "dataset_type": "multimodal_dir", + "video_datasets": [ + { + "id": "vggsound", + "path": "dataset/vggsound/video_latents_t5_clip_npz/train", + "split_path": "dataset/vggsound/split_txt/train_cot.txt" + } + ], + "audio_datasets": [ + { + "id": "audiostock", + "path": "dataset/Laion-Audio-630k/audiostock_latents_npz", + "split_path": "dataset/Laion-Audio-630k/split_txt/cot_audiostock_1.txt" + }, + { + "id": "freesound_no_overlap", + "path": "dataset/Laion-Audio-630k/freesound_no_overlap_latents_npz", + "split_path": "dataset/Laion-Audio-630k/split_txt/cot_freesound.txt" + }, + { + "id": "audioset_sl", + "path": "dataset/wavcaps/audioset_sl_latents_npz", + "split_path": "dataset/wavcaps/split_txt/cot_audio_sl_1.txt" + }, + { + "id": "audiocaps", + "path": "dataset/1_audiocaps/audiocaps_latents_npz", + "split_path": "dataset/1_audiocaps/split_txt/train_cot.txt" + }, + { + "id": "bbc", + "path": "dataset/Laion-Audio-630k/bbc_latents_npz", + "split_path": "dataset/Laion-Audio-630k/split_txt/cot_bbc_1.txt" + } + ], + "val_datasets": [ + { + "id": "vggsound", + "path": "dataset/vggsound/video_latents_t5_clip_npz/test", + "split_path": "dataset/vggsound/split_txt/test_cot.txt" + } + ], + "test_datasets": [ + { + "id": "vggsound", + "path": "cot_coarse", + "split_path": "cot_vgg_demo_caption.txt" + } + ], + "random_crop": true, + "input_type": "prompt" +} \ No newline at end of file diff --git a/data_utils/__init__.py b/ThinkSound/data/__init__.py similarity index 100% rename from data_utils/__init__.py rename to ThinkSound/data/__init__.py diff --git a/think_sound/data/datamodule.py b/ThinkSound/data/datamodule.py similarity index 98% rename from think_sound/data/datamodule.py rename to ThinkSound/data/datamodule.py index 8a4733d3a8ffceb8ddd8e769a231a9f3a0019a20..330789ad4518fbd5c7d5eb29cf3e87d7840aea78 100644 --- a/think_sound/data/datamodule.py +++ b/ThinkSound/data/datamodule.py @@ -33,13 +33,14 @@ def get_configs(audio_configs): return configs class DataModule(L.LightningDataModule): - def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4,repeat_num=5): + def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4,repeat_num=5,latent_length=194): super().__init__() dataset_type = dataset_config.get("dataset_type", None) self.batch_size = batch_size self.num_workers = num_workers self.test_batch_size = test_batch_size self.repeat_num = repeat_num + self.latent_length = latent_length assert dataset_type is not None, "Dataset type must be specified in dataset config" if audio_channels == 1: @@ -140,7 +141,8 @@ class DataModule(L.LightningDataModule): random_crop=random_crop, input_type=self.input_type, fps=self.input_type, - force_channels=self.force_channels + force_channels=self.force_channels, + latent_length=self.latent_length ) if stage == 'fit': diff --git a/think_sound/data/dataset.py b/ThinkSound/data/dataset.py similarity index 98% rename from think_sound/data/dataset.py rename to ThinkSound/data/dataset.py index f4e4fadcf9ed0951928f29109214ac2bde5c773a..dccb339a9b783b740509d58648e2e68395782576 100644 --- a/think_sound/data/dataset.py +++ b/ThinkSound/data/dataset.py @@ -342,8 +342,7 @@ class LatentDataset(torch.utils.data.Dataset): info = {} audio, video = self.load_file(audio_filename, info) info["path"] = audio_filename - assert audio.shape == (64,194), f'{audio.shape} input error, id: {id}' - assert video.shape == (72,1024), f'{video.shape} input error, id: {id}' + info['id'] = Path(audio_filename).stem for root_path in self.root_paths: if root_path in audio_filename: @@ -434,8 +433,7 @@ class AudioDataset(torch.utils.data.Dataset): info = {} audio, video = self.load_file(audio_filename, info) info["path"] = audio_filename - assert audio.shape == (64,194), f'{audio.shape} input error, id: {id}' - assert video.shape == (72,1024), f'{video.shape} input error, id: {id}' + info['id'] = Path(audio_filename).stem for root_path in self.root_paths: if root_path in audio_filename: @@ -454,8 +452,9 @@ class VideoDataset(torch.utils.data.Dataset): input_type="prompt", fps=4, force_channels="stereo", + latent_length=194, # default latent length for video dataset ): - + self.latent_length = latent_length super().__init__() self.filenames = [] print(f'configs: {configs[0]}') @@ -523,7 +522,7 @@ class VideoDataset(torch.utils.data.Dataset): if 'latent' in data.keys(): audio = data['latent'] else: - audio = torch.zeros(64,194) + audio = torch.zeros(64,self.latent_length) info['video_exist'] = self.video_exist # except: # print(f'error load file: {filename}') @@ -540,8 +539,7 @@ class VideoDataset(torch.utils.data.Dataset): info = {} audio, video = self.load_file(audio_filename, info) info["path"] = audio_filename - assert audio is None or audio.shape == (64,194), f'{audio.shape} input error, id: {id}' - assert video.shape == (72,1024), f'{video.shape} input error, id: {id}' + info['id'] = Path(audio_filename).stem for root_path in self.root_paths: if root_path in audio_filename: diff --git a/think_sound/data/utils.py b/ThinkSound/data/utils.py similarity index 100% rename from think_sound/data/utils.py rename to ThinkSound/data/utils.py diff --git a/think_sound/data/__init__.py b/ThinkSound/inference/__init__.py similarity index 100% rename from think_sound/data/__init__.py rename to ThinkSound/inference/__init__.py diff --git a/think_sound/inference/generation.py b/ThinkSound/inference/generation.py similarity index 100% rename from think_sound/inference/generation.py rename to ThinkSound/inference/generation.py diff --git a/think_sound/inference/sampling.py b/ThinkSound/inference/sampling.py similarity index 100% rename from think_sound/inference/sampling.py rename to ThinkSound/inference/sampling.py diff --git a/think_sound/inference/utils.py b/ThinkSound/inference/utils.py similarity index 100% rename from think_sound/inference/utils.py rename to ThinkSound/inference/utils.py diff --git a/think_sound/models/__init__.py b/ThinkSound/models/__init__.py similarity index 100% rename from think_sound/models/__init__.py rename to ThinkSound/models/__init__.py diff --git a/think_sound/models/autoencoders.py b/ThinkSound/models/autoencoders.py similarity index 100% rename from think_sound/models/autoencoders.py rename to ThinkSound/models/autoencoders.py diff --git a/think_sound/models/blocks.py b/ThinkSound/models/blocks.py similarity index 78% rename from think_sound/models/blocks.py rename to ThinkSound/models/blocks.py index 3c827fd2441e643717d123847236d3d6c003ef4f..7743694eb319b59cfc17c169a089cf5ae652f9bd 100644 --- a/think_sound/models/blocks.py +++ b/ThinkSound/models/blocks.py @@ -336,4 +336,95 @@ class SnakeBeta(nn.Module): beta = torch.exp(beta) x = snake_beta(x, alpha, beta) - return x \ No newline at end of file + return x + +class ChannelLastConv1d(nn.Conv1d): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 1) + x = super().forward(x) + x = x.permute(0, 2, 1) + return x + + +# https://github.com/Stability-AI/sd3-ref +class MLP(nn.Module): + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class ConvMLP(nn.Module): + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + kernel_size: int = 3, + padding: int = 1, + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ChannelLastConv1d(dim, + hidden_dim, + bias=False, + kernel_size=kernel_size, + padding=padding) + self.w2 = ChannelLastConv1d(hidden_dim, + dim, + bias=False, + kernel_size=kernel_size, + padding=padding) + self.w3 = ChannelLastConv1d(dim, + hidden_dim, + bias=False, + kernel_size=kernel_size, + padding=padding) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) diff --git a/think_sound/models/bottleneck.py b/ThinkSound/models/bottleneck.py similarity index 100% rename from think_sound/models/bottleneck.py rename to ThinkSound/models/bottleneck.py diff --git a/think_sound/models/codebook_patterns.py b/ThinkSound/models/codebook_patterns.py similarity index 100% rename from think_sound/models/codebook_patterns.py rename to ThinkSound/models/codebook_patterns.py diff --git a/think_sound/models/conditioners.py b/ThinkSound/models/conditioners.py similarity index 99% rename from think_sound/models/conditioners.py rename to ThinkSound/models/conditioners.py index 915c6fd753d69ed16e0366b9b8051df231d7e5ef..985b74cf6b31174c73972e29671714017aa4cc7d 100644 --- a/think_sound/models/conditioners.py +++ b/ThinkSound/models/conditioners.py @@ -7,7 +7,6 @@ import typing as tp import gc from typing import Literal, Optional import os -from .adp import NumberEmbedder from ..inference.utils import set_audio_channels from .factory import create_pretransform_from_config from .pretransforms import Pretransform diff --git a/think_sound/models/diffusion.py b/ThinkSound/models/diffusion.py similarity index 99% rename from think_sound/models/diffusion.py rename to ThinkSound/models/diffusion.py index 7364e87040ac0fe59a14486b442d194abf5e0f71..42863139e7dbd585fb2cc5ac99ae5c51cdafafb4 100644 --- a/think_sound/models/diffusion.py +++ b/ThinkSound/models/diffusion.py @@ -7,14 +7,12 @@ import typing as tp from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config -from .dit import DiffusionTransformer +# from .dit import DiffusionTransformer from .mmdit import MMAudio from .factory import create_pretransform_from_config from .pretransforms import Pretransform from ..inference.generation import generate_diffusion_cond -from .adp import UNetCFG1d, UNet1d - from time import time class Profiler: diff --git a/think_sound/models/dit.py b/ThinkSound/models/dit.py similarity index 100% rename from think_sound/models/dit.py rename to ThinkSound/models/dit.py diff --git a/think_sound/models/mmmodules/model/embeddings.py b/ThinkSound/models/embeddings.py similarity index 55% rename from think_sound/models/mmmodules/model/embeddings.py rename to ThinkSound/models/embeddings.py index d0f5f0d76fbe6d52766c8f24c89d14c755f985f0..81a2d4b3dc3510827f74425668d4ffd28134623c 100644 --- a/think_sound/models/mmmodules/model/embeddings.py +++ b/ThinkSound/models/embeddings.py @@ -3,6 +3,42 @@ import torch.nn as nn # https://github.com/facebookresearch/DiT +from typing import Union + +import torch +from einops import rearrange +from torch import Tensor + +# Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py +# Ref: https://github.com/lucidrains/rotary-embedding-torch + + +def compute_rope_rotations(length: int, + dim: int, + theta: int, + *, + freq_scaling: float = 1.0, + device: Union[torch.device, str] = 'cpu') -> Tensor: + assert dim % 2 == 0 + + with torch.amp.autocast(device_type='cuda', enabled=False): + pos = torch.arange(length, dtype=torch.float32, device=device) + freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + freqs *= freq_scaling + + rot = torch.einsum('..., f -> ... f', pos, freqs) + rot = torch.stack([torch.cos(rot), -torch.sin(rot), torch.sin(rot), torch.cos(rot)], dim=-1) + rot = rearrange(rot, 'n d (i j) -> 1 n d i j', i=2, j=2) + return rot + + +def apply_rope(x: Tensor, rot: Tensor) -> tuple[Tensor, Tensor]: + with torch.amp.autocast(device_type='cuda', enabled=False): + _x = x.float() + _x = _x.view(*_x.shape[:-1], -1, 1, 2) + x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1] + return x_out.reshape(*x.shape).to(dtype=x.dtype) + class TimestepEmbedder(nn.Module): """ diff --git a/think_sound/models/factory.py b/ThinkSound/models/factory.py similarity index 100% rename from think_sound/models/factory.py rename to ThinkSound/models/factory.py diff --git a/think_sound/models/local_attention.py b/ThinkSound/models/local_attention.py similarity index 100% rename from think_sound/models/local_attention.py rename to ThinkSound/models/local_attention.py diff --git a/think_sound/models/mmdit.py b/ThinkSound/models/mmdit.py similarity index 94% rename from think_sound/models/mmdit.py rename to ThinkSound/models/mmdit.py index 4ec310ab510079e2bbd0fa89d89f0bf55ea7526e..1c004d76073dfdb4a2ea1cbcbd24b764b357731a 100644 --- a/think_sound/models/mmdit.py +++ b/ThinkSound/models/mmdit.py @@ -6,10 +6,10 @@ import torch import torch.nn as nn import torch.nn.functional as F import sys -from .mmmodules.ext.rotary_embeddings import compute_rope_rotations -from .mmmodules.model.embeddings import TimestepEmbedder -from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP -from .mmmodules.model.transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock) +from .embeddings import compute_rope_rotations +from .embeddings import TimestepEmbedder +from .blocks import MLP, ChannelLastConv1d, ConvMLP +from .transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock) from .utils import resample log = logging.getLogger() @@ -24,7 +24,6 @@ class PreprocessedConditions: text_f_c: torch.Tensor -# Partially from https://github.com/facebookresearch/DiT class MMAudio(nn.Module): def __init__(self, @@ -94,7 +93,6 @@ class MMAudio(nn.Module): nn.Linear(hidden_dim * 4, hidden_dim, bias=False), nn.Sigmoid() ) - # 初始化最后一层权重为零,促进初始均匀融合 nn.init.zeros_(self.gated_mlp_v[3].weight) nn.init.zeros_(self.gated_mlp_t[3].weight) if v2: @@ -441,9 +439,9 @@ class MMAudio(nn.Module): # clip_f = torch.cat([clip_f,empty_clip_f], dim=0) # sync_f = torch.cat([sync_f,empty_sync_f], dim=0) # text_f = torch.cat([text_f,empty_text_f], dim=0) - clip_f = torch.cat([clip_f,self.get_empty_clip_sequence(bsz)], dim=0) - sync_f = torch.cat([sync_f,self.get_empty_sync_sequence(bsz)], dim=0) - text_f = torch.cat([text_f,self.get_empty_string_sequence(bsz)], dim=0) + clip_f = safe_cat(clip_f,self.get_empty_clip_sequence(bsz), dim=0, match_dim=1) + sync_f = safe_cat(sync_f,self.get_empty_sync_sequence(bsz), dim=0, match_dim=1) + text_f = safe_cat(text_f,self.get_empty_string_sequence(bsz), dim=0, match_dim=1) if t5_features is not None: empty_t5_features = torch.zeros_like(t5_features, device=latent.device) # t5_features = torch.cat([t5_features,empty_t5_features], dim=0) @@ -529,3 +527,52 @@ class MMAudio(nn.Module): def sync_seq_len(self) -> int: return self._sync_seq_len + + + + + + + + + + + + + + + + +def truncate_to_target(tensor, target_size, dim=1): + current_size = tensor.size(dim) + if current_size > target_size: + slices = [slice(None)] * tensor.dim() + slices[dim] = slice(0, target_size) + return tensor[slices] + return tensor + +def pad_to_target(tensor, target_size, dim=1, pad_value=0): + current_size = tensor.size(dim) + if current_size < target_size: + pad_size = target_size - current_size + + pad_config = [0, 0] * tensor.dim() + pad_index = 2 * (tensor.dim() - dim - 1) + 1 + pad_config[pad_index] = pad_size + + return torch.nn.functional.pad(tensor, pad_config, value=pad_value) + return tensor + + +def safe_cat(tensor1, tensor2, dim=0, match_dim=1): + + target_size = tensor2.size(match_dim) + + if tensor1.size(match_dim) > target_size: + tensor1 = truncate_to_target(tensor1, target_size, match_dim) + + else: + tensor1 = pad_to_target(tensor1, target_size, match_dim) + + return torch.cat([tensor1, tensor2], dim=dim) + diff --git a/think_sound/models/pretrained.py b/ThinkSound/models/pretrained.py similarity index 100% rename from think_sound/models/pretrained.py rename to ThinkSound/models/pretrained.py diff --git a/think_sound/models/pretransforms.py b/ThinkSound/models/pretransforms.py similarity index 100% rename from think_sound/models/pretransforms.py rename to ThinkSound/models/pretransforms.py diff --git a/think_sound/models/transformer.py b/ThinkSound/models/transformer.py similarity index 100% rename from think_sound/models/transformer.py rename to ThinkSound/models/transformer.py diff --git a/think_sound/models/mmmodules/model/transformer_layers.py b/ThinkSound/models/transformer_layers.py similarity index 98% rename from think_sound/models/mmmodules/model/transformer_layers.py rename to ThinkSound/models/transformer_layers.py index 6b06bc9850543f87ca9eb4217899674609d1620b..7174b07c4a21d75abde1a00485b2670fcc5bc77a 100644 --- a/think_sound/models/mmmodules/model/transformer_layers.py +++ b/ThinkSound/models/transformer_layers.py @@ -6,8 +6,8 @@ import torch.nn.functional as F from einops import rearrange from einops.layers.torch import Rearrange -from ..ext.rotary_embeddings import apply_rope -from ..model.low_level import MLP, ChannelLastConv1d, ConvMLP +from .embeddings import apply_rope +from .blocks import MLP, ChannelLastConv1d, ConvMLP try: from flash_attn import flash_attn_func, flash_attn_kvpacked_func print('flash_attn installed, using Flash Attention') diff --git a/think_sound/models/utils.py b/ThinkSound/models/utils.py similarity index 100% rename from think_sound/models/utils.py rename to ThinkSound/models/utils.py diff --git a/think_sound/training/__init__.py b/ThinkSound/training/__init__.py similarity index 100% rename from think_sound/training/__init__.py rename to ThinkSound/training/__init__.py diff --git a/think_sound/training/autoencoders.py b/ThinkSound/training/autoencoders.py similarity index 99% rename from think_sound/training/autoencoders.py rename to ThinkSound/training/autoencoders.py index f215393a12d446a64a8ebb84b0fabe49ce52258b..e97943e7fe11790cc4404d36ca90787420cb80bd 100644 --- a/think_sound/training/autoencoders.py +++ b/ThinkSound/training/autoencoders.py @@ -9,7 +9,6 @@ from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss, import lightning as L from lightning.pytorch.callbacks import Callback from ..models.autoencoders import AudioAutoencoder -from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss from .utils import create_optimizer_from_config, create_scheduler_from_config diff --git a/think_sound/training/diffusion.py b/ThinkSound/training/diffusion.py similarity index 54% rename from think_sound/training/diffusion.py rename to ThinkSound/training/diffusion.py index 067a1630591fdeff36188abedfd42c6cac13e6ee..d99c4772ff67d4453e212fef7b67186e6e5927a4 100644 --- a/think_sound/training/diffusion.py +++ b/ThinkSound/training/diffusion.py @@ -20,7 +20,6 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_only from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper from ..models.autoencoders import DiffusionAutoencoder -from ..models.diffusion_prior import PriorType from .autoencoders import create_loss_modules_from_bottleneck from .losses import AuralossLoss, MSELoss, MultiLoss from .utils import create_optimizer_from_config, create_scheduler_from_config, mask_from_frac_lengths, generate_mask, generate_channel_mask @@ -846,10 +845,9 @@ class DiffusionCondTrainingWrapper(L.LightningModule): def predict_step(self, batch, batch_idx): reals, metadata = batch - # import ipdb - # ipdb.set_trace() ids = [item['id'] for item in metadata] batch_size, length = reals.shape[0], reals.shape[2] + print(f"Predicting {batch_size} samples with length {length} for ids: {ids}") with torch.amp.autocast('cuda'): conditioning = self.diffusion.conditioner(metadata, self.device) @@ -878,7 +876,6 @@ class DiffusionCondTrainingWrapper(L.LightningModule): end_time = time.time() execution_time = end_time - start_time print(f"执行时间: {execution_time:.2f} 秒") - breakpoint() if self.diffusion.pretransform is not None: fakes = self.diffusion.pretransform.decode(fakes) @@ -1077,947 +1074,3 @@ class DiffusionCondDemoCallback(Callback): gc.collect() torch.cuda.empty_cache() module.train() - -class DiffusionCondInpaintTrainingWrapper(L.LightningModule): - ''' - Wrapper for training a conditional audio diffusion model. - ''' - def __init__( - self, - model: ConditionedDiffusionModelWrapper, - lr: float = 1e-4, - max_mask_segments = 10, - log_loss_info: bool = False, - optimizer_configs: dict = None, - use_ema: bool = True, - pre_encoded: bool = False, - cfg_dropout_prob = 0.1, - timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", - ): - super().__init__() - - self.diffusion = model - - self.use_ema = use_ema - - if self.use_ema: - self.diffusion_ema = EMA( - self.diffusion.model, - beta=0.9999, - power=3/4, - update_every=1, - update_after_step=1, - include_online_model=False - ) - else: - self.diffusion_ema = None - - self.cfg_dropout_prob = cfg_dropout_prob - - self.lr = lr - self.max_mask_segments = max_mask_segments - - self.rng = torch.quasirandom.SobolEngine(1, scramble=True) - - self.timestep_sampler = timestep_sampler - - self.diffusion_objective = model.diffusion_objective - - self.loss_modules = [ - MSELoss("output", - "targets", - weight=1.0, - name="mse_loss" - ) - ] - - self.losses = MultiLoss(self.loss_modules) - - self.log_loss_info = log_loss_info - - assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" - - if optimizer_configs is None: - optimizer_configs = { - "diffusion": { - "optimizer": { - "type": "Adam", - "config": { - "lr": lr - } - } - } - } - else: - if lr is not None: - print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") - - self.optimizer_configs = optimizer_configs - - self.pre_encoded = pre_encoded - - def configure_optimizers(self): - diffusion_opt_config = self.optimizer_configs['diffusion'] - opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) - - if "scheduler" in diffusion_opt_config: - sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) - sched_diff_config = { - "scheduler": sched_diff, - "interval": "step" - } - return [opt_diff], [sched_diff_config] - - return [opt_diff] - - def random_mask(self, sequence, max_mask_length): - b, _, sequence_length = sequence.size() - - # Create a mask tensor for each batch element - masks = [] - - for i in range(b): - mask_type = random.randint(0, 2) - - if mask_type == 0: # Random mask with multiple segments - num_segments = random.randint(1, self.max_mask_segments) - max_segment_length = max_mask_length // num_segments - - segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments) - - mask = torch.ones((1, 1, sequence_length)) - for length in segment_lengths: - mask_start = random.randint(0, sequence_length - length) - mask[:, :, mask_start:mask_start + length] = 0 - - elif mask_type == 1: # Full mask - mask = torch.zeros((1, 1, sequence_length)) - - elif mask_type == 2: # Causal mask - mask = torch.ones((1, 1, sequence_length)) - mask_length = random.randint(1, max_mask_length) - mask[:, :, -mask_length:] = 0 - - mask = mask.to(sequence.device) - masks.append(mask) - - # Concatenate the mask tensors into a single tensor - mask = torch.cat(masks, dim=0).to(sequence.device) - - # Apply the mask to the sequence tensor for each batch element - masked_sequence = sequence * mask - - return masked_sequence, mask - - def training_step(self, batch, batch_idx): - reals, metadata = batch - - p = Profiler() - - if reals.ndim == 4 and reals.shape[0] == 1: - reals = reals[0] - - loss_info = {} - - diffusion_input = reals - - if not self.pre_encoded: - loss_info["audio_reals"] = diffusion_input - - p.tick("setup") - - with torch.amp.autocast('cuda'): - conditioning = self.diffusion.conditioner(metadata, self.device) - - p.tick("conditioning") - - if self.diffusion.pretransform is not None: - self.diffusion.pretransform.to(self.device) - - if not self.pre_encoded: - with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): - diffusion_input = self.diffusion.pretransform.encode(diffusion_input) - p.tick("pretransform") - - # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input - # if use_padding_mask: - # padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool() - else: - # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run - if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: - diffusion_input = diffusion_input / self.diffusion.pretransform.scale - - # Max mask size is the full sequence length - max_mask_length = diffusion_input.shape[2] - - # Create a mask of random length for a random slice of the input - masked_input, mask = self.random_mask(diffusion_input, max_mask_length) - - # conditioning['inpaint_mask'] = [mask] - conditioning['inpaint_masked_input'] = [masked_input] - - if self.timestep_sampler == "uniform": - # Draw uniformly distributed continuous timesteps - t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) - elif self.timestep_sampler == "logit_normal": - t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) - - # Calculate the noise schedule parameters for those timesteps - if self.diffusion_objective == "v": - alphas, sigmas = get_alphas_sigmas(t) - elif self.diffusion_objective == "rectified_flow": - alphas, sigmas = 1-t, t - - # Combine the ground truth data and the noise - alphas = alphas[:, None, None] - sigmas = sigmas[:, None, None] - noise = torch.randn_like(diffusion_input) - noised_inputs = diffusion_input * alphas + noise * sigmas - - if self.diffusion_objective == "v": - targets = noise * alphas - diffusion_input * sigmas - elif self.diffusion_objective == "rectified_flow": - targets = noise - diffusion_input - - p.tick("noise") - - extra_args = {} - - with torch.amp.autocast('cuda'): - p.tick("amp") - output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args) - p.tick("diffusion") - - loss_info.update({ - "output": output, - "targets": targets, - }) - - loss, losses = self.losses(loss_info) - - if self.log_loss_info: - # Loss debugging logs - num_loss_buckets = 10 - bucket_size = 1 / num_loss_buckets - loss_all = F.mse_loss(output, targets, reduction="none") - - sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() - - # gather loss_all across all GPUs - loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") - - # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size - loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) - - # Log bucketed losses with corresponding sigma bucket values, if it's not NaN - debug_log_dict = { - f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) - } - - self.log_dict(debug_log_dict) - - log_dict = { - 'train/loss': loss.detach(), - 'train/std_data': diffusion_input.std(), - 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] - } - - for loss_name, loss_value in losses.items(): - log_dict[f"train/{loss_name}"] = loss_value.detach() - - self.log_dict(log_dict, prog_bar=True, on_step=True) - p.tick("log") - #print(f"Profiler: {p}") - return loss - - def on_before_zero_grad(self, *args, **kwargs): - if self.diffusion_ema is not None: - self.diffusion_ema.update() - - def export_model(self, path, use_safetensors=False): - if self.diffusion_ema is not None: - self.diffusion.model = self.diffusion_ema.ema_model - - if use_safetensors: - save_file(self.diffusion.state_dict(), path) - else: - torch.save({"state_dict": self.diffusion.state_dict()}, path) - -class DiffusionCondInpaintDemoCallback(Callback): - def __init__( - self, - demo_dl, - demo_every=2000, - demo_steps=250, - sample_size=65536, - sample_rate=48000, - demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7] - ): - super().__init__() - self.demo_every = demo_every - self.demo_steps = demo_steps - self.demo_samples = sample_size - self.demo_dl = iter(demo_dl) - self.sample_rate = sample_rate - self.demo_cfg_scales = demo_cfg_scales - self.last_demo_step = -1 - - @rank_zero_only - @torch.no_grad() - def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): - if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: - return - - self.last_demo_step = trainer.global_step - - try: - log_dict = {} - - demo_reals, metadata = next(self.demo_dl) - - # Remove extra dimension added by WebDataset - if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: - demo_reals = demo_reals[0] - - demo_reals = demo_reals.to(module.device) - - if not module.pre_encoded: - # Log the real audio - log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu())) - # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals") - - if module.diffusion.pretransform is not None: - module.diffusion.pretransform.to(module.device) - with torch.amp.autocast('cuda'): - demo_reals = module.diffusion.pretransform.encode(demo_reals) - - demo_samples = demo_reals.shape[2] - - # Get conditioning - conditioning = module.diffusion.conditioner(metadata, module.device) - - masked_input, mask = module.random_mask(demo_reals, demo_reals.shape[2]) - - conditioning['inpaint_mask'] = [mask] - conditioning['inpaint_masked_input'] = [masked_input] - - if module.diffusion.pretransform is not None: - log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(masked_input.cpu())) - else: - log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu())) - - cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) - - noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device) - - trainer.logger.experiment.log(log_dict) - - for cfg_scale in self.demo_cfg_scales: - model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model - print(f"Generating demo for cfg scale {cfg_scale}") - - if module.diffusion_objective == "v": - fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) - elif module.diffusion_objective == "rectified_flow": - fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) - - if module.diffusion.pretransform is not None: - with torch.amp.autocast('cuda'): - fakes = module.diffusion.pretransform.decode(fakes) - - # Put the demos together - fakes = rearrange(fakes, 'b d n -> d (b n)') - - log_dict = {} - - filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' - fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() - torchaudio.save(filename, fakes, self.sample_rate) - - log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Reconstructed') - - log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) - - trainer.logger.experiment.log(log_dict) - except Exception as e: - print(f'{type(e).__name__}: {e}') - raise e - -class DiffusionAutoencoderTrainingWrapper(L.LightningModule): - ''' - Wrapper for training a diffusion autoencoder - ''' - def __init__( - self, - model: DiffusionAutoencoder, - lr: float = 1e-4, - ema_copy = None, - use_reconstruction_loss: bool = False - ): - super().__init__() - - self.diffae = model - - self.diffae_ema = EMA( - self.diffae, - ema_model=ema_copy, - beta=0.9999, - power=3/4, - update_every=1, - update_after_step=1, - include_online_model=False - ) - - self.lr = lr - - self.rng = torch.quasirandom.SobolEngine(1, scramble=True) - - loss_modules = [ - MSELoss("v", - "targets", - weight=1.0, - name="mse_loss" - ) - ] - - if model.bottleneck is not None: - # TODO: Use loss config for configurable bottleneck weights and reconstruction losses - loss_modules += create_loss_modules_from_bottleneck(model.bottleneck, {}) - - self.use_reconstruction_loss = use_reconstruction_loss - - if use_reconstruction_loss: - scales = [2048, 1024, 512, 256, 128, 64, 32] - hop_sizes = [] - win_lengths = [] - overlap = 0.75 - for s in scales: - hop_sizes.append(int(s * (1 - overlap))) - win_lengths.append(s) - - sample_rate = model.sample_rate - - stft_loss_args = { - "fft_sizes": scales, - "hop_sizes": hop_sizes, - "win_lengths": win_lengths, - "perceptual_weighting": True - } - - out_channels = model.out_channels - - if model.pretransform is not None: - out_channels = model.pretransform.io_channels - - if out_channels == 2: - self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) - else: - self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) - - loss_modules.append( - AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss - ) - - self.losses = MultiLoss(loss_modules) - - def configure_optimizers(self): - return optim.Adam([*self.diffae.parameters()], lr=self.lr) - - def training_step(self, batch, batch_idx): - reals = batch[0] - - if reals.ndim == 4 and reals.shape[0] == 1: - reals = reals[0] - - loss_info = {} - - loss_info["audio_reals"] = reals - - if self.diffae.pretransform is not None: - with torch.no_grad(): - reals = self.diffae.pretransform.encode(reals) - - loss_info["reals"] = reals - - #Encode reals, skipping the pretransform since it was already applied - latents, encoder_info = self.diffae.encode(reals, return_info=True, skip_pretransform=True) - - loss_info["latents"] = latents - loss_info.update(encoder_info) - - if self.diffae.decoder is not None: - latents = self.diffae.decoder(latents) - - # Upsample latents to match diffusion length - if latents.shape[2] != reals.shape[2]: - latents = F.interpolate(latents, size=reals.shape[2], mode='nearest') - - loss_info["latents_upsampled"] = latents - - # Draw uniformly distributed continuous timesteps - t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) - - # Calculate the noise schedule parameters for those timesteps - alphas, sigmas = get_alphas_sigmas(t) - - # Combine the ground truth data and the noise - alphas = alphas[:, None, None] - sigmas = sigmas[:, None, None] - noise = torch.randn_like(reals) - noised_reals = reals * alphas + noise * sigmas - targets = noise * alphas - reals * sigmas - - with torch.amp.autocast('cuda'): - v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents) - - loss_info.update({ - "v": v, - "targets": targets - }) - - if self.use_reconstruction_loss: - pred = noised_reals * alphas - v * sigmas - - loss_info["pred"] = pred - - if self.diffae.pretransform is not None: - pred = self.diffae.pretransform.decode(pred) - loss_info["audio_pred"] = pred - - loss, losses = self.losses(loss_info) - - log_dict = { - 'train/loss': loss.detach(), - 'train/std_data': reals.std(), - 'train/latent_std': latents.std(), - } - - for loss_name, loss_value in losses.items(): - log_dict[f"train/{loss_name}"] = loss_value.detach() - - self.log_dict(log_dict, prog_bar=True, on_step=True) - return loss - - def on_before_zero_grad(self, *args, **kwargs): - self.diffae_ema.update() - - def export_model(self, path, use_safetensors=False): - - model = self.diffae_ema.ema_model - - if use_safetensors: - save_file(model.state_dict(), path) - else: - torch.save({"state_dict": model.state_dict()}, path) - -class DiffusionAutoencoderDemoCallback(Callback): - def __init__( - self, - demo_dl, - demo_every=2000, - demo_steps=250, - sample_size=65536, - sample_rate=48000 - ): - super().__init__() - self.demo_every = demo_every - self.demo_steps = demo_steps - self.demo_samples = sample_size - self.demo_dl = iter(demo_dl) - self.sample_rate = sample_rate - self.last_demo_step = -1 - - @rank_zero_only - @torch.no_grad() - def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): - if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: - return - - self.last_demo_step = trainer.global_step - - demo_reals, _ = next(self.demo_dl) - - # Remove extra dimension added by WebDataset - if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: - demo_reals = demo_reals[0] - - encoder_input = demo_reals - - encoder_input = encoder_input.to(module.device) - - demo_reals = demo_reals.to(module.device) - - with torch.no_grad() and torch.amp.autocast('cuda'): - latents = module.diffae_ema.ema_model.encode(encoder_input).float() - fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps) - - #Interleave reals and fakes - reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') - - # Put the demos together - reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') - - log_dict = {} - - filename = f'recon_{trainer.global_step:08}.wav' - reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() - torchaudio.save(filename, reals_fakes, self.sample_rate) - - log_dict[f'recon'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Reconstructed') - - log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents) - log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents)) - - log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) - - if module.diffae_ema.ema_model.pretransform is not None: - with torch.no_grad() and torch.amp.autocast('cuda'): - initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input) - first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents) - first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)') - first_stage_fakes = first_stage_fakes.to(torch.float32).mul(32767).to(torch.int16).cpu() - first_stage_filename = f'first_stage_{trainer.global_step:08}.wav' - torchaudio.save(first_stage_filename, first_stage_fakes, self.sample_rate) - - log_dict[f'first_stage_latents'] = wandb.Image(tokens_spectrogram_image(initial_latents)) - - log_dict[f'first_stage'] = wandb.Audio(first_stage_filename, - sample_rate=self.sample_rate, - caption=f'First Stage Reconstructed') - - log_dict[f'first_stage_melspec_left'] = wandb.Image(audio_spectrogram_image(first_stage_fakes)) - - - trainer.logger.experiment.log(log_dict) - -def create_source_mixture(reals, num_sources=2): - # Create a fake mixture source by mixing elements from the training batch together with random offsets - source = torch.zeros_like(reals) - for i in range(reals.shape[0]): - sources_added = 0 - - js = list(range(reals.shape[0])) - random.shuffle(js) - for j in js: - if i == j or (i != j and sources_added < num_sources): - # Randomly offset the mixed element between 0 and the length of the source - seq_len = reals.shape[2] - offset = random.randint(0, seq_len-1) - source[i, :, offset:] += reals[j, :, :-offset] - if i == j: - # If this is the real one, shift the reals as well to ensure alignment - new_reals = torch.zeros_like(reals[i]) - new_reals[:, offset:] = reals[i, :, :-offset] - reals[i] = new_reals - sources_added += 1 - - return source - -class DiffusionPriorTrainingWrapper(L.LightningModule): - ''' - Wrapper for training a diffusion prior for inverse problems - Prior types: - mono_stereo: The prior is conditioned on a mono version of the audio to generate a stereo version - ''' - def __init__( - self, - model: ConditionedDiffusionModelWrapper, - lr: float = 1e-4, - ema_copy = None, - prior_type: PriorType = PriorType.MonoToStereo, - use_reconstruction_loss: bool = False, - log_loss_info: bool = False, - ): - super().__init__() - - self.diffusion = model - - self.diffusion_ema = EMA( - self.diffusion, - ema_model=ema_copy, - beta=0.9999, - power=3/4, - update_every=1, - update_after_step=1, - include_online_model=False - ) - - self.lr = lr - - self.rng = torch.quasirandom.SobolEngine(1, scramble=True) - - self.log_loss_info = log_loss_info - - loss_modules = [ - MSELoss("v", - "targets", - weight=1.0, - name="mse_loss" - ) - ] - - self.use_reconstruction_loss = use_reconstruction_loss - - if use_reconstruction_loss: - scales = [2048, 1024, 512, 256, 128, 64, 32] - hop_sizes = [] - win_lengths = [] - overlap = 0.75 - for s in scales: - hop_sizes.append(int(s * (1 - overlap))) - win_lengths.append(s) - - sample_rate = model.sample_rate - - stft_loss_args = { - "fft_sizes": scales, - "hop_sizes": hop_sizes, - "win_lengths": win_lengths, - "perceptual_weighting": True - } - - out_channels = model.io_channels - - - if model.pretransform is not None: - out_channels = model.pretransform.io_channels - self.audio_out_channels = out_channels - - if self.audio_out_channels == 2: - self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) - self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) - - # Add left and right channel reconstruction losses in addition to the sum and difference - loss_modules += [ - AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05), - AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05), - ] - - else: - self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) - - loss_modules.append( - AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss - ) - - self.losses = MultiLoss(loss_modules) - - self.prior_type = prior_type - - def configure_optimizers(self): - return optim.Adam([*self.diffusion.parameters()], lr=self.lr) - - def training_step(self, batch, batch_idx): - reals, metadata = batch - - if reals.ndim == 4 and reals.shape[0] == 1: - reals = reals[0] - - loss_info = {} - - loss_info["audio_reals"] = reals - - if self.prior_type == PriorType.MonoToStereo: - source = reals.mean(dim=1, keepdim=True).repeat(1, reals.shape[1], 1).to(self.device) - loss_info["audio_reals_mono"] = source - else: - raise ValueError(f"Unknown prior type {self.prior_type}") - - if self.diffusion.pretransform is not None: - with torch.no_grad(): - reals = self.diffusion.pretransform.encode(reals) - - if self.prior_type in [PriorType.MonoToStereo]: - source = self.diffusion.pretransform.encode(source) - - if self.diffusion.conditioner is not None: - with torch.amp.autocast('cuda'): - conditioning = self.diffusion.conditioner(metadata, self.device) - else: - conditioning = {} - - loss_info["reals"] = reals - - # Draw uniformly distributed continuous timesteps - t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) - - # Calculate the noise schedule parameters for those timesteps - alphas, sigmas = get_alphas_sigmas(t) - - # Combine the ground truth data and the noise - alphas = alphas[:, None, None] - sigmas = sigmas[:, None, None] - noise = torch.randn_like(reals) - noised_reals = reals * alphas + noise * sigmas - targets = noise * alphas - reals * sigmas - - with torch.amp.autocast('cuda'): - - conditioning['source'] = [source] - - v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob = 0.1) - - loss_info.update({ - "v": v, - "targets": targets - }) - - if self.use_reconstruction_loss: - pred = noised_reals * alphas - v * sigmas - - loss_info["pred"] = pred - - if self.diffusion.pretransform is not None: - pred = self.diffusion.pretransform.decode(pred) - loss_info["audio_pred"] = pred - - if self.audio_out_channels == 2: - loss_info["pred_left"] = pred[:, 0:1, :] - loss_info["pred_right"] = pred[:, 1:2, :] - loss_info["audio_reals_left"] = loss_info["audio_reals"][:, 0:1, :] - loss_info["audio_reals_right"] = loss_info["audio_reals"][:, 1:2, :] - - loss, losses = self.losses(loss_info) - - if self.log_loss_info: - # Loss debugging logs - num_loss_buckets = 10 - bucket_size = 1 / num_loss_buckets - loss_all = F.mse_loss(v, targets, reduction="none") - - sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() - - # gather loss_all across all GPUs - loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") - - # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size - loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) - - # Log bucketed losses with corresponding sigma bucket values, if it's not NaN - debug_log_dict = { - f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) - } - - self.log_dict(debug_log_dict) - - log_dict = { - 'train/loss': loss.detach(), - 'train/std_data': reals.std() - } - - for loss_name, loss_value in losses.items(): - log_dict[f"train/{loss_name}"] = loss_value.detach() - - self.log_dict(log_dict, prog_bar=True, on_step=True) - return loss - - def on_before_zero_grad(self, *args, **kwargs): - self.diffusion_ema.update() - - def export_model(self, path, use_safetensors=False): - - #model = self.diffusion_ema.ema_model - model = self.diffusion - - if use_safetensors: - save_file(model.state_dict(), path) - else: - torch.save({"state_dict": model.state_dict()}, path) - -class DiffusionPriorDemoCallback(Callback): - def __init__( - self, - demo_dl, - demo_every=2000, - demo_steps=250, - sample_size=65536, - sample_rate=48000 - ): - super().__init__() - - self.demo_every = demo_every - self.demo_steps = demo_steps - self.demo_samples = sample_size - self.demo_dl = iter(demo_dl) - self.sample_rate = sample_rate - self.last_demo_step = -1 - - @rank_zero_only - @torch.no_grad() - def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): - if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: - return - - self.last_demo_step = trainer.global_step - - demo_reals, metadata = next(self.demo_dl) - # import ipdb - # ipdb.set_trace() - # Remove extra dimension added by WebDataset - if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: - demo_reals = demo_reals[0] - - demo_reals = demo_reals.to(module.device) - - encoder_input = demo_reals - - if module.diffusion.conditioner is not None: - with torch.amp.autocast('cuda'): - conditioning_tensors = module.diffusion.conditioner(metadata, module.device) - - else: - conditioning_tensors = {} - - - with torch.no_grad() and torch.amp.autocast('cuda'): - if module.prior_type == PriorType.MonoToStereo and encoder_input.shape[1] > 1: - source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device) - - if module.diffusion.pretransform is not None: - encoder_input = module.diffusion.pretransform.encode(encoder_input) - source_input = module.diffusion.pretransform.encode(source) - else: - source_input = source - - conditioning_tensors['source'] = [source_input] - - fakes = sample(module.diffusion_ema.model, torch.randn_like(encoder_input), self.demo_steps, 0, cond=conditioning_tensors) - - if module.diffusion.pretransform is not None: - fakes = module.diffusion.pretransform.decode(fakes) - - #Interleave reals and fakes - reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') - - # Put the demos together - reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') - - log_dict = {} - - filename = f'recon_mono_{trainer.global_step:08}.wav' - reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() - torchaudio.save(filename, reals_fakes, self.sample_rate) - - log_dict[f'recon'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Reconstructed') - - log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) - - #Log the source - filename = f'source_{trainer.global_step:08}.wav' - source = rearrange(source, 'b d n -> d (b n)') - source = source.to(torch.float32).mul(32767).to(torch.int16).cpu() - torchaudio.save(filename, source, self.sample_rate) - - log_dict[f'source'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Source') - - log_dict[f'source_melspec_left'] = wandb.Image(audio_spectrogram_image(source)) - - trainer.logger.experiment.log(log_dict) \ No newline at end of file diff --git a/think_sound/training/factory.py b/ThinkSound/training/factory.py similarity index 100% rename from think_sound/training/factory.py rename to ThinkSound/training/factory.py diff --git a/think_sound/training/losses/__init__.py b/ThinkSound/training/losses/__init__.py similarity index 100% rename from think_sound/training/losses/__init__.py rename to ThinkSound/training/losses/__init__.py diff --git a/think_sound/training/losses/auraloss.py b/ThinkSound/training/losses/auraloss.py similarity index 100% rename from think_sound/training/losses/auraloss.py rename to ThinkSound/training/losses/auraloss.py diff --git a/think_sound/training/losses/losses.py b/ThinkSound/training/losses/losses.py similarity index 100% rename from think_sound/training/losses/losses.py rename to ThinkSound/training/losses/losses.py diff --git a/think_sound/training/utils.py b/ThinkSound/training/utils.py similarity index 100% rename from think_sound/training/utils.py rename to ThinkSound/training/utils.py diff --git a/app.py b/app.py index ef2be28d734f721d4d9b087ab27cdc1e36722d9a..dc3519b9eecfa1ff4bfc85cfa27ec33e30e00186 100644 --- a/app.py +++ b/app.py @@ -14,13 +14,12 @@ from lightning.pytorch.tuner import Tuner from lightning.pytorch import seed_everything import random from datetime import datetime -# from think_sound.data.dataset import create_dataloader_from_config -from think_sound.data.datamodule import DataModule -from think_sound.models import create_model_from_config -from think_sound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model -from think_sound.training import create_training_wrapper_from_config, create_demo_callback_from_config -from think_sound.training.utils import copy_state_dict -from think_sound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler +from ThinkSound.data.datamodule import DataModule +from ThinkSound.models import create_model_from_config +from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model +from ThinkSound.training import create_training_wrapper_from_config, create_demo_callback_from_config +from ThinkSound.training.utils import copy_state_dict +from ThinkSound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils from torch.utils.data import Dataset from typing import Optional, Union @@ -34,7 +33,7 @@ import tempfile import subprocess from huggingface_hub import hf_hub_download from moviepy.editor import VideoFileClip -os.system("conda install -c conda-forge 'ffmpeg<7'") +# os.system("conda install -c conda-forge 'ffmpeg<7'") _CLIP_SIZE = 224 _CLIP_FPS = 8.0 @@ -101,7 +100,7 @@ class VGGSound(Dataset): self.resampler = {} - def sample(self, video_path,label): + def sample(self, video_path,label,cot): video_id = video_path reader = StreamingMediaDecoder(video_path) @@ -156,7 +155,7 @@ class VGGSound(Dataset): # padding using the last frame, but no more than 2 current_length = sync_chunk.shape[0] last_frame = sync_chunk[-1] - # 重复最后一帧以进行填充 + padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1) assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}' sync_chunk = torch.cat((sync_chunk, padding), dim=0) @@ -170,6 +169,7 @@ class VGGSound(Dataset): data = { 'id': video_id, 'caption': label, + 'caption_cot': cot, # 'audio': audio_chunk, 'clip_video': clip_chunk, 'sync_video': sync_chunk, @@ -187,17 +187,16 @@ else: print(f"load in device {device}") -vae_ckpt = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="vae.ckpt",repo_type="model") -synchformer_ckpt = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="synchformer_state_dict.pth",repo_type="model") +vae_ckpt = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="vae.ckpt",repo_type="model") +synchformer_ckpt = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="synchformer_state_dict.pth",repo_type="model") + feature_extractor = FeaturesUtils( - vae_ckpt=vae_ckpt, - vae_config='think_sound/configs/model_configs/autoencoders/stable_audio_2_0_vae.json', + vae_ckpt=None, + vae_config='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', enable_conditions=True, synchformer_ckpt=synchformer_ckpt ).eval().to(extra_device) - - args = get_all_args() seed = 10086 @@ -206,7 +205,7 @@ seed_everything(seed, workers=True) #Get JSON config from args.model_config -with open("think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json") as f: +with open("ThinkSound/configs/model_configs/thinksound.json") as f: model_config = json.load(f) model = create_model_from_config(model_config) @@ -229,7 +228,7 @@ model.pretransform.load_state_dict(load_vae_state) # Remove weight_norm from the pretransform if specified if args.remove_pretransform_weight_norm == "post_load": remove_weight_norm_from_model(model.pretransform) -ckpt_path = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="thinksound.ckpt",repo_type="model") +ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound.ckpt",repo_type="model") training_wrapper = create_training_wrapper_from_config(model_config, model) # 加载模型权重时根据设备选择map_location training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict']) @@ -243,16 +242,17 @@ def get_video_duration(video_path): @spaces.GPU(duration=60) @torch.inference_mode() @torch.no_grad() -def get_audio(video_path, caption): - # 允许caption为空 +def get_audio(video_path, caption, cot): if caption is None: caption = '' + if cot is None: + cot = caption timer = Timer(duration="00:15:00:00") #get video duration duration_sec = get_video_duration(video_path) print(duration_sec) preprocesser = VGGSound(duration_sec=duration_sec) - data = preprocesser.sample(video_path, caption) + data = preprocesser.sample(video_path, caption, cot) @@ -261,7 +261,7 @@ def get_audio(video_path, caption): preprocessed_data['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu().squeeze(0) preprocessed_data['metaclip_text_features'] = metaclip_text_features.detach().cpu().squeeze(0) - t5_features = feature_extractor.encode_t5_text(data['caption']) + t5_features = feature_extractor.encode_t5_text(data['caption_cot']) preprocessed_data['t5_features'] = t5_features.detach().cpu().squeeze(0) clip_features = feature_extractor.encode_video_with_clip(data['clip_video'].unsqueeze(0).to(extra_device)) @@ -305,56 +305,47 @@ def get_audio(video_path, caption): fakes = training_wrapper.diffusion.pretransform.decode(fakes) audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() - # 保存临时音频文件 with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio: torchaudio.save(tmp_audio.name, audios[0], 44100) audio_path = tmp_audio.name + return audio_path -def synthesize_video_with_audio(video_file, caption): - # 允许caption为空 - if caption is None: - caption = '' - audio_path = get_audio(video_file, caption) +def synthesize_video_with_audio(video_file, caption, cot): + audio_path = get_audio(video_file, caption, cot) with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video: output_video_path = tmp_video.name - # ffmpeg命令:用新音频替换原视频音轨 + cmd = [ 'ffmpeg', '-y', '-i', video_file, '-i', audio_path, '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0', '-shortest', output_video_path ] subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + return output_video_path -# Gradio界面 -with gr.Blocks() as demo: - gr.Markdown( - """ -# ThinkSound\n -ThinkSound is a unified Any2Audio generation framework with flow matching guided by Chain-of-Thought (CoT) reasoning. - -Upload video and caption (optional), and get video with audio! - -""" - ) - with gr.Row(): - video_input = gr.Video(label="upload video") - caption_input = gr.Textbox(label="caption(optional)", placeholder="can be empty", lines=1) - output_video = gr.Video(label="output video") - btn = gr.Button("start synthesize") - btn.click(fn=synthesize_video_with_audio, inputs=[video_input, caption_input], outputs=output_video) - - gr.Examples( - examples=[ - ["./examples/1_mute.mp4", "Playing Trumpet", "./examples/1.mp4"], - ["./examples/2_mute.mp4", "Axe striking", "./examples/2.mp4"], - ["./examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier", "./examples/3.mp4"], - ["./examples/4_mute.mp4", "train passing by", "./examples/4.mp4"], - ["./examples/5_mute.mp4", "Lighting Firecrackers", "./examples/5.mp4"] - ], - inputs=[video_input, caption_input,output_video], - ) - -demo.launch(share=True) +demo = gr.Interface( + fn=synthesize_video_with_audio, + inputs=[ + gr.Video(label="Upload Video"), + gr.Textbox(label="Caption (optional)", placeholder="can be empty",), + gr.Textbox(label="CoT Description (optional)", lines=6, placeholder="can be empty",), + ], + outputs=[ + gr.Video(label="Result"), + ], + title="ThinkSound Demo", + description="Upload a video, caption, or CoT to generate audio. For an enhanced experience, we automatically merge the generated audio with your original silent video. (Note: Flexible audio generation lengths are supported.:)", + examples=[ + ["examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier", "Begin by creating a soft, steady background of light pacifier suckling. Add subtle, breathy rhythms to mimic a newborn's gentle mouth movements. Keep the sound smooth, natural, and soothing."], + ["examples/2_mute.mp4", "Printer Printing", "Generate a continuous printer printing sound with periodic beeps and paper movement, plus a cat pawing at the machine. Add subtle ambient room noise for authenticity, keeping the focus on printing, beeps, and the cat's interaction."], + ["examples/4_mute.mp4", "Plastic Debris Handling", "Begin with the sound of hands scooping up loose plastic debris, followed by the subtle cascading noise as the pieces fall and scatter back down. Include soft crinkling and rustling to emphasize the texture of the plastic. Add ambient factory background noise with distant machinery to create an industrial atmosphere."], + ["examples/5_mute.mp4", "Lighting Firecrackers", "Generate the sound of firecrackers lighting and exploding repeatedly on the ground, followed by fireworks bursting in the sky. Incorporate occasional subtle echoes to mimic an outdoor night ambiance, with no human voices present."] + ], + cache_examples=True +) + +if __name__ == "__main__": + demo.launch(share=True) diff --git a/cot_vgg_demo_caption.txt b/cot_vgg_demo_caption.txt new file mode 100644 index 0000000000000000000000000000000000000000..545693370faeb841cceb4f4fba1fa7fc6b91a496 --- /dev/null +++ b/cot_vgg_demo_caption.txt @@ -0,0 +1 @@ +demo.npz \ No newline at end of file diff --git a/data_utils/__pycache__/__init__.cpython-310.pyc b/data_utils/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 6649c3a8b1e3891099f9dcc6544cfe78b2fd0819..0000000000000000000000000000000000000000 Binary files a/data_utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/data_utils/__pycache__/utils.cpython-310.pyc b/data_utils/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index ce286899e4b3e6d8ef183749a5428ec34618bb19..0000000000000000000000000000000000000000 Binary files a/data_utils/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/data_utils/__pycache__/utils.cpython-39.pyc b/data_utils/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index 46823f149a2bbd0d09481eff2ee6d1ad86fe151b..0000000000000000000000000000000000000000 Binary files a/data_utils/__pycache__/utils.cpython-39.pyc and /dev/null differ diff --git a/data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc b/data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 0ce33b394f89c7e27b9d3ce6128905759ae97585..0000000000000000000000000000000000000000 Binary files a/data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc b/data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 43e5f377eba172a01f2d255bd1892d94fe4eaf1c..0000000000000000000000000000000000000000 Binary files a/data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc b/data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc deleted file mode 100644 index 410b25263432859b32e10180b22403ba2ed6b511..0000000000000000000000000000000000000000 Binary files a/data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc and /dev/null differ diff --git a/data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc b/data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc deleted file mode 100644 index af8c9377841c9c77b52249865cbc42d98c14278f..0000000000000000000000000000000000000000 Binary files a/data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc and /dev/null differ diff --git a/data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc b/data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc deleted file mode 100644 index a0df0dd022851746361ae2410b5d9508764e1bbd..0000000000000000000000000000000000000000 Binary files a/data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc and /dev/null differ diff --git a/data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc b/data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc deleted file mode 100644 index cca707887580daa686ca15539a1fe0dd93b73884..0000000000000000000000000000000000000000 Binary files a/data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc and /dev/null differ diff --git a/data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc b/data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index 5880bf322ab6106b6e97b8d3b5c5c835e114b424..0000000000000000000000000000000000000000 Binary files a/data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc b/data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index eae02f00ddcb30ca46ebf32b5debb946d6618d81..0000000000000000000000000000000000000000 Binary files a/data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc and /dev/null differ diff --git a/data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-310.pyc b/data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-310.pyc deleted file mode 100644 index fa6ddadb705a39423b4768966a3a0ea41e76b648..0000000000000000000000000000000000000000 Binary files a/data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-310.pyc and /dev/null differ diff --git a/data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-39.pyc b/data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-39.pyc deleted file mode 100644 index 10ce47b92644955ec4a037d3703ba997e7162429..0000000000000000000000000000000000000000 Binary files a/data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-39.pyc and /dev/null differ diff --git a/data_utils/ext/synchformer/__pycache__/vit_helper.cpython-310.pyc b/data_utils/ext/synchformer/__pycache__/vit_helper.cpython-310.pyc deleted file mode 100644 index 15ebefdf8dbdb0c826516b654134f0db46f00282..0000000000000000000000000000000000000000 Binary files a/data_utils/ext/synchformer/__pycache__/vit_helper.cpython-310.pyc and /dev/null differ diff --git a/data_utils/ext/synchformer/__pycache__/vit_helper.cpython-39.pyc b/data_utils/ext/synchformer/__pycache__/vit_helper.cpython-39.pyc deleted file mode 100644 index 62566399384b6ab2d9d619a9e7ff0364821a0a51..0000000000000000000000000000000000000000 Binary files a/data_utils/ext/synchformer/__pycache__/vit_helper.cpython-39.pyc and /dev/null differ diff --git a/data_utils/utils.py b/data_utils/utils.py deleted file mode 100644 index 68cb806d9a44d9f6fa891f715b4d64f6534c3bd1..0000000000000000000000000000000000000000 --- a/data_utils/utils.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Utility functions.""" -import contextlib -import csv -import json -import os -import pathlib -import warnings - -import numpy as np - - -def save_args(filename, args): - """Save the command-line arguments.""" - args_dict = {} - for key, value in vars(args).items(): - if isinstance(value, pathlib.Path): - args_dict[key] = str(value) - else: - args_dict[key] = value - save_json(filename, args_dict) - - -def inverse_dict(d): - """Return the inverse dictionary.""" - return {v: k for k, v in d.items()} - - -def save_txt(filename, data): - """Save a list to a TXT file.""" - with open(filename, "w", encoding="utf8") as f: - for item in data: - f.write(f"{item}\n") - - -def load_txt(filename): - """Load a TXT file as a list.""" - with open(filename, encoding="utf8") as f: - return [line.strip() for line in f] - - -def save_json(filename, data): - """Save data as a JSON file.""" - with open(filename, "w", encoding="utf8") as f: - json.dump(data, f) - - -def load_json(filename): - """Load data from a JSON file.""" - with open(filename, encoding="utf8") as f: - return json.load(f) - - -def save_csv(filename, data, header=""): - """Save data as a CSV file.""" - np.savetxt( - filename, data, fmt="%d", delimiter=",", header=header, comments="" - ) - - -def load_csv(filename, skiprows=1): - """Load data from a CSV file.""" - return np.loadtxt(filename, dtype=int, delimiter=",", skiprows=skiprows) - - -def load_csv_text(filename, headerless=True): - """Read a CSV file into a list of dictionaries or lists.""" - with open(filename) as f: - if headerless: - return [row for row in csv.reader(f)] - reader = csv.DictReader(f) - return [ - {field: row[field] for field in reader.fieldnames} - for row in reader - ] - - -def ignore_exceptions(func): - """Decorator that ignores all errors and warnings.""" - - def inner(*args, **kwargs): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - try: - return func(*args, **kwargs) - except Exception: - return None - - return inner - - -def suppress_outputs(func): - """Decorator that suppresses writing to stdout and stderr.""" - - def inner(*args, **kwargs): - devnull = open(os.devnull, "w") - with contextlib.redirect_stdout(devnull): - with contextlib.redirect_stderr(devnull): - return func(*args, **kwargs) - - return inner - - -def resolve_paths(func): - """Decorator that resolves all paths.""" - - def inner(*args, **kwargs): - parsed = func(*args, **kwargs) - for key in vars(parsed).keys(): - if isinstance(getattr(parsed, key), pathlib.Path): - setattr( - parsed, key, getattr(parsed, key).expanduser().resolve() - ) - return parsed - - return inner diff --git a/data_utils/v2a_utils/__pycache__/__init__.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 10d232e6b7d73a07d1d4b7f08ac3a581b9110a89..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-310.pyc deleted file mode 100644 index be016caa49d423f4b0a52f66ac3b210b52423ce1..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-310.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-38.pyc b/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-38.pyc deleted file mode 100644 index c38b83aea9638b59c45ec959b4c1cc5281f0f89e..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-38.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-39.pyc deleted file mode 100644 index b80d86ad166b7715cf5e7e0b8aef626c262b78f4..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-39.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/audioset_224.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/audioset_224.cpython-39.pyc deleted file mode 100644 index 1996d974e0dd8489b9ee54b64093d893a8bb8c6f..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/audioset_224.cpython-39.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/audioset_video_224.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/audioset_video_224.cpython-39.pyc deleted file mode 100644 index fa209c10ef36183642053047c60abdfa9b715796..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/audioset_video_224.cpython-39.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/feature_utils.cpython-310.pyc deleted file mode 100644 index 471714af83c05d9612ddda5bb09dfab70813ec34..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/feature_utils.cpython-310.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/feature_utils.cpython-39.pyc deleted file mode 100644 index 2e133a860468cbc37baa9acea75ec1a2b7c5c456..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/feature_utils.cpython-39.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-310.pyc deleted file mode 100644 index c33953aafdaa73f59beeced690bee714dcee756a..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-310.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-39.pyc deleted file mode 100644 index 6577c2bb22d4d84bbbc2013c74a74e30ab89b71d..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-39.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-310.pyc deleted file mode 100644 index b6565541958c14e9294afcfc29f9ce764fe04e83..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-310.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-38.pyc b/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-38.pyc deleted file mode 100644 index fcc14fe670bd8fe46989ecba1e5bdc9a4e47f937..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-38.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-39.pyc deleted file mode 100644 index e62be420b89e4fcf0d698b6f0ff6aac91cb8889e..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-39.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils_224_no_sync.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/feature_utils_224_no_sync.cpython-39.pyc deleted file mode 100644 index 34cd3500ab00860c8727a17891e7ccbfd084fa76..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/feature_utils_224_no_sync.cpython-39.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/vggsound.cpython-310.pyc deleted file mode 100644 index 174ca8efd7276497218351d82f28f9a139fa5d16..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/vggsound.cpython-310.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/vggsound.cpython-39.pyc deleted file mode 100644 index 510649119739eb6dab8d71fdecec9429910c06b7..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/vggsound.cpython-39.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound_224.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/vggsound_224.cpython-310.pyc deleted file mode 100644 index 2f414d5b164c9b3eb63b322f1d46061bdcfd94c9..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/vggsound_224.cpython-310.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound_224.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/vggsound_224.cpython-39.pyc deleted file mode 100644 index b6f8d690d16e6cc85cc4b9d946ca2a7d7e02d54f..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/vggsound_224.cpython-39.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound_224_no_audio.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/vggsound_224_no_audio.cpython-310.pyc deleted file mode 100644 index d15160ffbd28c168fd85ba1cc0107ac9a05cd830..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/vggsound_224_no_audio.cpython-310.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound_224_no_sync.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/vggsound_224_no_sync.cpython-39.pyc deleted file mode 100644 index e68e23e6f0e5f2487d985741a09cfab54638f231..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/vggsound_224_no_sync.cpython-39.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound_text.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/vggsound_text.cpython-39.pyc deleted file mode 100644 index 9ac3fd63b2d590d77c06801489338005bc10c953..0000000000000000000000000000000000000000 Binary files a/data_utils/v2a_utils/__pycache__/vggsound_text.cpython-39.pyc and /dev/null differ diff --git a/data_utils/v2a_utils/audio_text_dataset.py b/data_utils/v2a_utils/audio_text_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5a692380ef635f1ea44902b8f13a9eac5a7b7de4 --- /dev/null +++ b/data_utils/v2a_utils/audio_text_dataset.py @@ -0,0 +1,173 @@ +import os +from pathlib import Path +from typing import Optional, Union +from PIL import Image + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image +from transformers import AutoProcessor +import torch.nn.functional as F +import numpy as np + +import logging +log = logging.getLogger() + +_CLIP_SIZE = 224 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + + +class Audio_Text(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + sample_rate: int = 44_100, + duration_sec: float = 9.0, + audio_samples: Optional[int] = 397312, + normalize_audio: bool = False, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + self.normalize_audio = normalize_audio + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.labels = [] + self.videos = [] + self.cots = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + for record in df_list: + id = record['id'] + if os.path.exists(f'{save_dir}/{id}.pth'): continue + label = record['caption'] + # if id in videos: + self.labels.append(label) + # print(label,'debug1!!!!!!!!!') + self.cots.append(record['caption_cot']) + # self.labels[id] = label + self.videos.append(id) + # else: + # missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = self.audio_samples + self.resampler = {} + + def sample(self, idx: int): + video_id = self.videos[idx] + label = self.labels[idx] + cot = self.cots[idx] + audio_path = os.path.join(self.root, f'{video_id}.wav') + if not os.path.exists(audio_path): + audio_path = os.path.join(self.root, f'{video_id}.flac') + if not os.path.exists(audio_path): + raise RuntimeError(f'Audio is not exist {audio_path}') + audio_chunk, sample_rate = torchaudio.load(audio_path) + if len(audio_chunk.shape) != 2: + raise RuntimeError(f'error audio shape {video_id}') + + abs_max = audio_chunk[0].abs().max() + + if abs_max <= 1e-6: + if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6: + audio_chunk = audio_chunk[1:2] + else: + raise RuntimeError(f'Audio is silent {video_id}') + + # ensure the stereo audio + if audio_chunk.shape[0] < 2: + audio_chunk = audio_chunk.repeat(2, 1) + elif audio_chunk.shape[0] > 2: + audio_chunk = audio_chunk[:2] + + # resample + if sample_rate == self.sample_rate: + audio_chunk = audio_chunk + else: + if sample_rate not in self.resampler: + # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best + self.resampler[sample_rate] = torchaudio.transforms.Resample( + sample_rate, + self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + audio_chunk = self.resampler[sample_rate](audio_chunk) + + if audio_chunk.shape[1] < self.expected_audio_length: + # zero-padding audio + padding_length = self.expected_audio_length - audio_chunk.shape[1] + # 创建 padding 张量,大小为 [batch_size, padding_length],值为0 + padding = torch.zeros(audio_chunk.shape[0], padding_length) + # 将原始音频和 padding 沿第 1 维度拼接在一起 + audio_chunk = torch.cat((audio_chunk, padding), dim=1) + # raise RuntimeError(f'Audio too short {video_id}') + audio_chunk = audio_chunk[:,:self.expected_audio_length] + assert audio_chunk.shape == (2, 397312), f'error shape:{video_id},{audio_chunk.shape}' + # print(label,'debug2!!!!!!!!!') + data = { + 'id': video_id, + 'caption': label, + 'caption_cot': cot, + 'audio': audio_chunk, + } + + return data + + def __getitem__(self, idx: int): + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = VGGSound( +# root="data/vggsound/video/train", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_224_latents_text/train" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/audioset_224.py b/data_utils/v2a_utils/audioset_224.py new file mode 100644 index 0000000000000000000000000000000000000000..f363d157ebc72d3b929937f6de031eec63315cf4 --- /dev/null +++ b/data_utils/v2a_utils/audioset_224.py @@ -0,0 +1,315 @@ +import os +from pathlib import Path +from typing import Optional, Union +from PIL import Image + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image +from transformers import AutoProcessor +import torch.nn.functional as F +import numpy as np + +import logging +log = logging.getLogger() + +_CLIP_SIZE = 224 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + +def save_tensor_as_image(tensor, save_path): + """ + 将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。 + + :param tensor: 输入的 NumPy 数组 (1, 3, H, W)。 + :param save_path: 图片保存路径。 + """ + # # 移除批次维度,变成 (3, H, W) + # tensor = tensor.squeeze(0) + + # 交换轴顺序,变为 (H, W, 3) + image_array = np.transpose(tensor, (1, 2, 0)) + + # 检查数组是否为合适的数据类型 + if image_array.dtype != np.uint8: + # 如果不是 uint8,首先标准化,然后转换 + image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255 + image_array = image_array.astype(np.uint8) + + # 创建图像对象 + image = Image.fromarray(image_array) + + # 保存图片 + image.save(save_path) + print(f"Image saved to {save_path}") + +def pad_to_square(video_tensor): + # 验证输入的形状 + if len(video_tensor.shape) != 4: + raise ValueError("Input tensor must have shape (l, c, h, w)") + + l, c, h, w = video_tensor.shape + max_side = max(h, w) + + # 计算每一维度需要的填充量:(left, right, top, bottom) + pad_h = max_side - h + pad_w = max_side - w + + # 创建padding tuple (left, right, top, bottom) + # 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充 + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + + # 使用F.pad对视频张量进行填充操作 + # 填充参数为 (left, right, top, bottom) + video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) + + return video_padded + +class Audioset(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + sample_rate: int = 44_100, + duration_sec: float = 9.0, + audio_samples: Optional[int] = 397312, + normalize_audio: bool = False, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + self.normalize_audio = normalize_audio + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.labels = [] + self.videos = [] + self.caption_t5s = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + + for record in df_list: + id = record['id'] + if os.path.exists(f'{save_dir}/{id}.pth'): continue + label = record['label'] + caption_t5 = record['caption_t5'] + # if id in videos: + self.labels.append(label) + # self.labels[id] = label + self.videos.append(id) + self.caption_t5s.append(caption_t5) + # else: + # missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = self.audio_samples + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Lambda(pad_to_square), # 先填充为正方形 + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge") + self.sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + label = self.labels[idx] + caption_t5 = self.caption_t5s[idx] + + reader = StreamingMediaDecoder(self.root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * self.duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + # reader.add_basic_audio_stream(frames_per_chunk=2**30,) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + audio_path = os.path.join("dataset/3_Audioset/audios/sound",video_id+'.wav') + assert os.path.exists(audio_path), f'{audio_path} not exists' + audio_chunk, sr = torchaudio.load(audio_path) + # audio_chunk = data_chunk[2] + if len(audio_chunk.shape) != 2: + raise RuntimeError(f'error audio shape {video_id}') + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + sample_rate = int(sr) + # audio_chunk = audio_chunk.transpose(0, 1) + abs_max = audio_chunk[0].abs().max() + # audio_chunk = audio_chunk.mean(dim=0) # mono + # if self.normalize_audio: + # abs_max = audio_chunk.abs().max() + # audio_chunk = audio_chunk / abs_max * 0.95 + if abs_max <= 1e-6: + if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6: + audio_chunk = audio_chunk[1:2] + else: + raise RuntimeError(f'Audio is silent {video_id}') + + # ensure the stereo audio + if audio_chunk.shape[0] < 2: + audio_chunk = audio_chunk.repeat(2, 1) + + # resample + if sample_rate == self.sample_rate: + audio_chunk = audio_chunk + else: + if sample_rate not in self.resampler: + # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best + self.resampler[sample_rate] = torchaudio.transforms.Resample( + sample_rate, + self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + audio_chunk = self.resampler[sample_rate](audio_chunk) + + if audio_chunk.shape[1] < self.expected_audio_length: + # zero-padding audio + padding_length = self.expected_audio_length - audio_chunk.shape[1] + # 创建 padding 张量,大小为 [batch_size, padding_length],值为0 + padding = torch.zeros(audio_chunk.shape[0], padding_length) + # 将原始音频和 padding 沿第 1 维度拼接在一起 + audio_chunk = torch.cat((audio_chunk, padding), dim=1) + # raise RuntimeError(f'Audio too short {video_id}') + audio_chunk = audio_chunk[:,:self.expected_audio_length] + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + # import ipdb + # ipdb.set_trace() + if clip_chunk.shape[0] != self.clip_expected_length: + current_length = clip_chunk.shape[0] + padding_needed = self.clip_expected_length - current_length + + # Check that padding needed is no more than 2 + assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed' + + # If assertion passes, proceed with padding + if padding_needed > 0: + last_frame = clip_chunk[-1] + log.info(last_frame.shape) + # Repeat the last frame to reach the expected length + padding = last_frame.repeat(padding_needed, 1, 1, 1) + clip_chunk = torch.cat((clip_chunk, padding), dim=0) + # raise RuntimeError(f'CLIP video wrong length {video_id}, ' + # f'expected {self.clip_expected_length}, ' + # f'got {clip_chunk.shape[0]}') + + # save_image(clip_chunk[0] / 255.0,'ori.png') + clip_chunk = pad_to_square(clip_chunk) + # save_image(clip_chunk[0] / 255.0,'square.png') + # clip_chunk = self.clip_transform(clip_chunk) + # import ipdb + # ipdb.set_trace() + clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"] + # log.info(clip_chunk.shape) + # save_tensor_as_image(clip_chunk[0].numpy(),'scale.png') + # log.info(clip_chunk[0]) + # clip_chunk = outputs + # text_ids = outputs["input_ids"] + # temp_img = clip_chunk[0].permute(1, 2, 0) * 255 + # save_image(clip_chunk[0],'scale.png') + sync_chunk = sync_chunk[:self.sync_expected_length] + if sync_chunk.shape[0] != self.sync_expected_length: + # padding using the last frame, but no more than 2 + current_length = sync_chunk.shape[0] + last_frame = sync_chunk[-1] + # 重复最后一帧以进行填充 + padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1) + assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}' + sync_chunk = torch.cat((sync_chunk, padding), dim=0) + # raise RuntimeError(f'Sync video wrong length {video_id}, ' + # f'expected {self.sync_expected_length}, ' + # f'got {sync_chunk.shape[0]}') + + sync_chunk = self.sync_transform(sync_chunk) + assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \ + and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape' + data = { + 'id': video_id, + 'caption': label, + 'caption_t5': caption_t5, + 'audio': audio_chunk, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = Audioset( +# root="dataset/3_Audioset/video/sound", +# tsv_path="dataset/3_Audioset/split_txt/unbalanced_sound_filtered_aligned_novgg_noout.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="dataset/3_Audioset/video_text_latents/" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/audioset_video_224.py b/data_utils/v2a_utils/audioset_video_224.py new file mode 100644 index 0000000000000000000000000000000000000000..ecc9f94332c74b5bd03790954f539e6ab570deb9 --- /dev/null +++ b/data_utils/v2a_utils/audioset_video_224.py @@ -0,0 +1,268 @@ +import os +from pathlib import Path +from typing import Optional, Union +from PIL import Image + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image +from transformers import AutoProcessor +import torch.nn.functional as F +import numpy as np + +import logging +log = logging.getLogger() + +_CLIP_SIZE = 224 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + +def save_tensor_as_image(tensor, save_path): + """ + 将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。 + + :param tensor: 输入的 NumPy 数组 (1, 3, H, W)。 + :param save_path: 图片保存路径。 + """ + # # 移除批次维度,变成 (3, H, W) + # tensor = tensor.squeeze(0) + + # 交换轴顺序,变为 (H, W, 3) + image_array = np.transpose(tensor, (1, 2, 0)) + + # 检查数组是否为合适的数据类型 + if image_array.dtype != np.uint8: + # 如果不是 uint8,首先标准化,然后转换 + image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255 + image_array = image_array.astype(np.uint8) + + # 创建图像对象 + image = Image.fromarray(image_array) + + # 保存图片 + image.save(save_path) + print(f"Image saved to {save_path}") + +def pad_to_square(video_tensor): + # 验证输入的形状 + if len(video_tensor.shape) != 4: + raise ValueError("Input tensor must have shape (l, c, h, w)") + + l, c, h, w = video_tensor.shape + max_side = max(h, w) + + # 计算每一维度需要的填充量:(left, right, top, bottom) + pad_h = max_side - h + pad_w = max_side - w + + # 创建padding tuple (left, right, top, bottom) + # 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充 + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + + # 使用F.pad对视频张量进行填充操作 + # 填充参数为 (left, right, top, bottom) + video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) + + return video_padded + +class Audioset(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + duration_sec: float = 10.0, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.captions = [] + self.videos = [] + self.caption_t5s = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + with open(tsv_path.replace('.csv','.txt')) as file: + paths = file.readlines() + for record, path in zip(df_list,paths): + id = Path(record['id']).stem + # if os.path.exists(f'{save_dir}/{id}.pth'): continue + caption = record['caption'] + caption_t5 = record['caption_t5'] + path = path.strip() + part = Path(path).parent + video_id = Path(path).stem[1:] + video_path = os.path.join('dataset/3_Audioset/video',part,f'{video_id}.mp4') + assert os.path.exists(video_path), 'video must exist' + # if id in videos: + self.captions.append(caption) + self.caption_t5s.append(caption_t5) + # self.labels[id] = label + self.videos.append(video_path) + # else: + # missing_videos.append(id) + assert len(self.captions) == len(self.caption_t5s) and len(self.captions) == len(self.videos), 'error length' + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.duration_sec = duration_sec + + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Lambda(pad_to_square), # 先填充为正方形 + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge") + self.sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_path = self.videos[idx] + video_id = 'Y'+str(Path(video_path).stem) + caption = self.captions[idx] + caption_t5 = self.caption_t5s[idx] + + reader = StreamingMediaDecoder(video_path) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * self.duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + # if clip_chunk.shape[0] < self.clip_expected_length: + # raise RuntimeError( + # f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}' + # ) + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + # if sync_chunk.shape[0] < self.sync_expected_length: + # raise RuntimeError( + # f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}' + # ) + + + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + # import ipdb + # ipdb.set_trace() + if clip_chunk.shape[0] != self.clip_expected_length: + current_length = clip_chunk.shape[0] + padding_needed = self.clip_expected_length - current_length + + # Check that padding needed is no more than 2 + assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed' + + # If assertion passes, proceed with padding + if padding_needed > 0: + last_frame = clip_chunk[-1] + log.info(clip_chunk.shape) + # Repeat the last frame to reach the expected length + padding = last_frame.repeat(padding_needed, 1, 1, 1) + clip_chunk = torch.cat((clip_chunk, padding), dim=0) + # raise RuntimeError(f'CLIP video wrong length {video_id}, ' + # f'expected {self.clip_expected_length}, ' + # f'got {clip_chunk.shape[0]}') + + # save_image(clip_chunk[0] / 255.0,'ori.png') + clip_chunk = pad_to_square(clip_chunk) + # save_image(clip_chunk[0] / 255.0,'square.png') + # clip_chunk = self.clip_transform(clip_chunk) + # import ipdb + # ipdb.set_trace() + clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"] + # log.info(clip_chunk.shape) + # save_tensor_as_image(clip_chunk[0].numpy(),'scale.png') + # log.info(clip_chunk[0]) + # clip_chunk = outputs + # text_ids = outputs["input_ids"] + # temp_img = clip_chunk[0].permute(1, 2, 0) * 255 + # save_image(clip_chunk[0],'scale.png') + sync_chunk = sync_chunk[:self.sync_expected_length] + if sync_chunk.shape[0] != self.sync_expected_length: + # padding using the last frame, but no more than 2 + current_length = sync_chunk.shape[0] + last_frame = sync_chunk[-1] + # 重复最后一帧以进行填充 + padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1) + assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}' + sync_chunk = torch.cat((sync_chunk, padding), dim=0) + # raise RuntimeError(f'Sync video wrong length {video_id}, ' + # f'expected {self.sync_expected_length}, ' + # f'got {sync_chunk.shape[0]}') + + sync_chunk = self.sync_transform(sync_chunk) + assert clip_chunk.shape[0] == self.clip_expected_length and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape' + data = { + 'id': video_id, + 'caption': caption, + 'caption_t5': caption_t5, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.captions) + + +# dataset = VGGSound( +# root="data/vggsound/video/train", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_224_latents_text/train" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/feature_utils_224.py b/data_utils/v2a_utils/feature_utils_224.py index 520d51f30ff47813b482742de59a50e004a0420a..c7bf47dc1a2479a23fbdbc14d74eaecbd96393ca 100644 --- a/data_utils/v2a_utils/feature_utils_224.py +++ b/data_utils/v2a_utils/feature_utils_224.py @@ -7,9 +7,9 @@ import torch.nn.functional as F from einops import rearrange from open_clip import create_model_from_pretrained from torchvision.transforms import Normalize -from think_sound.models.factory import create_model_from_config -from think_sound.models.utils import load_ckpt_state_dict -from think_sound.training.utils import copy_state_dict +from ThinkSound.models.factory import create_model_from_config +from ThinkSound.models.utils import load_ckpt_state_dict +from ThinkSound.training.utils import copy_state_dict from transformers import AutoModel from transformers import AutoProcessor from transformers import T5EncoderModel, AutoTokenizer diff --git a/data_utils/v2a_utils/vggsound.py b/data_utils/v2a_utils/vggsound.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ae03f7af3eb263c649f7d8855b9069db988b2b --- /dev/null +++ b/data_utils/v2a_utils/vggsound.py @@ -0,0 +1,259 @@ +import logging +import os +from pathlib import Path +from typing import Optional, Union + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image + +log = logging.getLogger() + +_CLIP_SIZE = 384 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + + +class VGGSound(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + sample_rate: int = 44_100, + duration_sec: float = 9.0, + audio_samples: Optional[int] = 397312, + normalize_audio: bool = False, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + self.normalize_audio = normalize_audio + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + videos = sorted(os.listdir(self.root)) + videos = set([Path(v).stem for v in videos]) # remove extensions + # videos = [] + self.labels = [] + self.videos = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + + for record in df_list: + id = record['id'] + if os.path.exists(f'{save_dir}/{id}.pth'): continue + label = record['caption'] + if id in videos: + # self.labels.append(label) + self.labels[id] = label + self.videos.append(id) + else: + missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = self.audio_samples + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + + self.sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + label = self.labels[idx] + + reader = StreamingMediaDecoder(self.root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * self.duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + reader.add_basic_audio_stream(frames_per_chunk=2**30,) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + audio_chunk = data_chunk[2] + if len(audio_chunk.shape) != 2: + raise RuntimeError(f'error audio shape {video_id}') + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + # if clip_chunk.shape[0] < self.clip_expected_length: + # raise RuntimeError( + # f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}' + # ) + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + # if sync_chunk.shape[0] < self.sync_expected_length: + # raise RuntimeError( + # f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}' + # ) + # import ipdb + # ipdb.set_trace() + # process audio + sample_rate = int(reader.get_out_stream_info(2).sample_rate) + audio_chunk = audio_chunk.transpose(0, 1) + abs_max = audio_chunk[0].abs().max() + # audio_chunk = audio_chunk.mean(dim=0) # mono + # if self.normalize_audio: + # abs_max = audio_chunk.abs().max() + # audio_chunk = audio_chunk / abs_max * 0.95 + if abs_max <= 1e-6: + if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6: + audio_chunk = audio_chunk[1:2] + else: + raise RuntimeError(f'Audio is silent {video_id}') + + + # if abs_max <= 1e-6: + # raise RuntimeError(f'Audio is silent {video_id}') + + # ensure the stereo audio + if audio_chunk.shape[0] < 2: + audio_chunk = audio_chunk.repeat(2, 1) + + # resample + if sample_rate == self.sample_rate: + audio_chunk = audio_chunk + else: + if sample_rate not in self.resampler: + # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best + self.resampler[sample_rate] = torchaudio.transforms.Resample( + sample_rate, + self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + audio_chunk = self.resampler[sample_rate](audio_chunk) + + if audio_chunk.shape[1] < self.expected_audio_length: + # zero-padding audio + padding_length = self.expected_audio_length - audio_chunk.shape[1] + # 创建 padding 张量,大小为 [batch_size, padding_length],值为0 + padding = torch.zeros(audio_chunk.shape[0], padding_length) + # 将原始音频和 padding 沿第 1 维度拼接在一起 + audio_chunk = torch.cat((audio_chunk, padding), dim=1) + # raise RuntimeError(f'Audio too short {video_id}') + audio_chunk = audio_chunk[:,:self.expected_audio_length] + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + # import ipdb + # ipdb.set_trace() + if clip_chunk.shape[0] != self.clip_expected_length: + current_length = clip_chunk.shape[0] + padding_needed = self.clip_expected_length - current_length + + # Check that padding needed is no more than 2 + assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed' + + # If assertion passes, proceed with padding + if padding_needed > 0: + last_frame = clip_chunk[-1] + log.info(last_frame.shape) + # Repeat the last frame to reach the expected length + padding = last_frame.repeat(padding_needed, 1, 1, 1) + clip_chunk = torch.cat((clip_chunk, padding), dim=0) + # raise RuntimeError(f'CLIP video wrong length {video_id}, ' + # f'expected {self.clip_expected_length}, ' + # f'got {clip_chunk.shape[0]}') + # save_image(clip_chunk[0] / 255.0,'ori.png') + clip_chunk = self.clip_transform(clip_chunk) + # temp_img = clip_chunk[0].permute(1, 2, 0) * 255 + # save_image(clip_chunk[0],'scale.png') + sync_chunk = sync_chunk[:self.sync_expected_length] + if sync_chunk.shape[0] != self.sync_expected_length: + # padding using the last frame, but no more than 2 + current_length = sync_chunk.shape[0] + last_frame = sync_chunk[-1] + # 重复最后一帧以进行填充 + padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1) + assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}' + sync_chunk = torch.cat((sync_chunk, padding), dim=0) + # raise RuntimeError(f'Sync video wrong length {video_id}, ' + # f'expected {self.sync_expected_length}, ' + # f'got {sync_chunk.shape[0]}') + + sync_chunk = self.sync_transform(sync_chunk) + assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \ + and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape' + data = { + 'id': video_id, + 'caption': label, + 'audio': audio_chunk, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = VGGSound( +# root="data/vggsound/video/test", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_latents_text/test" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/vggsound_224.py b/data_utils/v2a_utils/vggsound_224.py new file mode 100644 index 0000000000000000000000000000000000000000..f076fc3017b42b42710b6dbad7144fd4b93c23d1 --- /dev/null +++ b/data_utils/v2a_utils/vggsound_224.py @@ -0,0 +1,320 @@ +import os +from pathlib import Path +from typing import Optional, Union +from PIL import Image + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image +from transformers import AutoProcessor +import torch.nn.functional as F +import numpy as np + +import logging +log = logging.getLogger() + +_CLIP_SIZE = 224 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + +def save_tensor_as_image(tensor, save_path): + """ + 将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。 + + :param tensor: 输入的 NumPy 数组 (1, 3, H, W)。 + :param save_path: 图片保存路径。 + """ + # # 移除批次维度,变成 (3, H, W) + # tensor = tensor.squeeze(0) + + # 交换轴顺序,变为 (H, W, 3) + image_array = np.transpose(tensor, (1, 2, 0)) + + # 检查数组是否为合适的数据类型 + if image_array.dtype != np.uint8: + # 如果不是 uint8,首先标准化,然后转换 + image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255 + image_array = image_array.astype(np.uint8) + + # 创建图像对象 + image = Image.fromarray(image_array) + + # 保存图片 + image.save(save_path) + print(f"Image saved to {save_path}") + +def pad_to_square(video_tensor): + # 验证输入的形状 + if len(video_tensor.shape) != 4: + raise ValueError("Input tensor must have shape (l, c, h, w)") + + l, c, h, w = video_tensor.shape + max_side = max(h, w) + + # 计算每一维度需要的填充量:(left, right, top, bottom) + pad_h = max_side - h + pad_w = max_side - w + + # 创建padding tuple (left, right, top, bottom) + # 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充 + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + + # 使用F.pad对视频张量进行填充操作 + # 填充参数为 (left, right, top, bottom) + video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) + + return video_padded + +class VGGSound(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + sample_rate: int = 44_100, + duration_sec: float = 9.0, + audio_samples: Optional[int] = 397312, + normalize_audio: bool = False, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + self.normalize_audio = normalize_audio + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.labels = [] + self.videos = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + + for record in df_list: + id = record['id'] + if os.path.exists(f'{save_dir}/{id}.pth'): continue + label = record['label'] + # if id in videos: + self.labels.append(label) + # self.labels[id] = label + self.videos.append(id) + # else: + # missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = self.audio_samples + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Lambda(pad_to_square), # 先填充为正方形 + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b") + self.sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + label = self.labels[idx] + + reader = StreamingMediaDecoder(self.root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * self.duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + reader.add_basic_audio_stream(frames_per_chunk=2**30,) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + audio_chunk = data_chunk[2] + if len(audio_chunk.shape) != 2: + raise RuntimeError(f'error audio shape {video_id}') + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + # if clip_chunk.shape[0] < self.clip_expected_length: + # raise RuntimeError( + # f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}' + # ) + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + # if sync_chunk.shape[0] < self.sync_expected_length: + # raise RuntimeError( + # f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}' + # ) + # import ipdb + # ipdb.set_trace() + # process audio + # import ipdb + # ipdb.set_trace() + sample_rate = int(reader.get_out_stream_info(2).sample_rate) + audio_chunk = audio_chunk.transpose(0, 1) + abs_max = audio_chunk[0].abs().max() + # audio_chunk = audio_chunk.mean(dim=0) # mono + # if self.normalize_audio: + # abs_max = audio_chunk.abs().max() + # audio_chunk = audio_chunk / abs_max * 0.95 + if abs_max <= 1e-6: + if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6: + audio_chunk = audio_chunk[1:2] + else: + raise RuntimeError(f'Audio is silent {video_id}') + + # ensure the stereo audio + if audio_chunk.shape[0] < 2: + audio_chunk = audio_chunk.repeat(2, 1) + + # resample + if sample_rate == self.sample_rate: + audio_chunk = audio_chunk + else: + if sample_rate not in self.resampler: + # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best + self.resampler[sample_rate] = torchaudio.transforms.Resample( + sample_rate, + self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + audio_chunk = self.resampler[sample_rate](audio_chunk) + + if audio_chunk.shape[1] < self.expected_audio_length: + # zero-padding audio + padding_length = self.expected_audio_length - audio_chunk.shape[1] + # 创建 padding 张量,大小为 [batch_size, padding_length],值为0 + padding = torch.zeros(audio_chunk.shape[0], padding_length) + # 将原始音频和 padding 沿第 1 维度拼接在一起 + audio_chunk = torch.cat((audio_chunk, padding), dim=1) + # raise RuntimeError(f'Audio too short {video_id}') + audio_chunk = audio_chunk[:,:self.expected_audio_length] + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + # import ipdb + # ipdb.set_trace() + if clip_chunk.shape[0] != self.clip_expected_length: + current_length = clip_chunk.shape[0] + padding_needed = self.clip_expected_length - current_length + + # Check that padding needed is no more than 2 + assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed' + + # If assertion passes, proceed with padding + if padding_needed > 0: + last_frame = clip_chunk[-1] + log.info(last_frame.shape) + # Repeat the last frame to reach the expected length + padding = last_frame.repeat(padding_needed, 1, 1, 1) + clip_chunk = torch.cat((clip_chunk, padding), dim=0) + # raise RuntimeError(f'CLIP video wrong length {video_id}, ' + # f'expected {self.clip_expected_length}, ' + # f'got {clip_chunk.shape[0]}') + + # save_image(clip_chunk[0] / 255.0,'ori.png') + clip_chunk = pad_to_square(clip_chunk) + # save_image(clip_chunk[0] / 255.0,'square.png') + # clip_chunk = self.clip_transform(clip_chunk) + # import ipdb + # ipdb.set_trace() + clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"] + # log.info(clip_chunk.shape) + # save_tensor_as_image(clip_chunk[0].numpy(),'scale.png') + # log.info(clip_chunk[0]) + # clip_chunk = outputs + # text_ids = outputs["input_ids"] + # temp_img = clip_chunk[0].permute(1, 2, 0) * 255 + # save_image(clip_chunk[0],'scale.png') + sync_chunk = sync_chunk[:self.sync_expected_length] + if sync_chunk.shape[0] != self.sync_expected_length: + # padding using the last frame, but no more than 2 + current_length = sync_chunk.shape[0] + last_frame = sync_chunk[-1] + # 重复最后一帧以进行填充 + padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1) + assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}' + sync_chunk = torch.cat((sync_chunk, padding), dim=0) + # raise RuntimeError(f'Sync video wrong length {video_id}, ' + # f'expected {self.sync_expected_length}, ' + # f'got {sync_chunk.shape[0]}') + + sync_chunk = self.sync_transform(sync_chunk) + assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \ + and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape' + data = { + 'id': video_id, + 'caption': label, + 'audio': audio_chunk, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = VGGSound( +# root="data/vggsound/video/train", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_224_latents_text/train" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/vggsound_224_no_audio.py b/data_utils/v2a_utils/vggsound_224_no_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..51b6ad69f3afca5d10764cc9eebc0daf848f810a --- /dev/null +++ b/data_utils/v2a_utils/vggsound_224_no_audio.py @@ -0,0 +1,275 @@ +import os +from pathlib import Path +from typing import Optional, Union +from PIL import Image + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image +from transformers import AutoProcessor +import torch.nn.functional as F +import numpy as np + +import logging +log = logging.getLogger() + +_CLIP_SIZE = 224 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + +def save_tensor_as_image(tensor, save_path): + """ + 将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。 + + :param tensor: 输入的 NumPy 数组 (1, 3, H, W)。 + :param save_path: 图片保存路径。 + """ + # # 移除批次维度,变成 (3, H, W) + # tensor = tensor.squeeze(0) + + # 交换轴顺序,变为 (H, W, 3) + image_array = np.transpose(tensor, (1, 2, 0)) + + # 检查数组是否为合适的数据类型 + if image_array.dtype != np.uint8: + # 如果不是 uint8,首先标准化,然后转换 + image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255 + image_array = image_array.astype(np.uint8) + + # 创建图像对象 + image = Image.fromarray(image_array) + + # 保存图片 + image.save(save_path) + print(f"Image saved to {save_path}") + +def pad_to_square(video_tensor): + # 验证输入的形状 + if len(video_tensor.shape) != 4: + raise ValueError("Input tensor must have shape (l, c, h, w)") + + l, c, h, w = video_tensor.shape + max_side = max(h, w) + + # 计算每一维度需要的填充量:(left, right, top, bottom) + pad_h = max_side - h + pad_w = max_side - w + + # 创建padding tuple (left, right, top, bottom) + # 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充 + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + + # 使用F.pad对视频张量进行填充操作 + # 填充参数为 (left, right, top, bottom) + video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) + + return video_padded + +class VGGSound(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + sample_rate: int = 44_100, + duration_sec: float = 9.0, + audio_samples: Optional[int] = 397312, + normalize_audio: bool = False, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + self.normalize_audio = normalize_audio + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.labels = [] + self.videos = [] + self.caption_cot = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + + for record in df_list: + id = record['id'] + if os.path.exists(f'{save_dir}/{id}.pth'): continue + label = record['caption'] + caption_cot = record['caption_cot'] + # if id in videos: + self.labels.append(label) + # self.labels[id] = label + self.videos.append(id) + self.caption_cot.append(caption_cot) + # else: + # missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = self.audio_samples + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Lambda(pad_to_square), # 先填充为正方形 + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b") + self.sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + label = self.labels[idx] + caption_cot = self.caption_cot[idx] + + reader = StreamingMediaDecoder(self.root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * self.duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + # reader.add_basic_audio_stream(frames_per_chunk=2**30,) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + # audio_chunk = data_chunk[2] + # if len(audio_chunk.shape) != 2: + # raise RuntimeError(f'error audio shape {video_id}') + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + # if clip_chunk.shape[0] < self.clip_expected_length: + # raise RuntimeError( + # f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}' + # ) + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + # import ipdb + # ipdb.set_trace() + if clip_chunk.shape[0] != self.clip_expected_length: + current_length = clip_chunk.shape[0] + padding_needed = self.clip_expected_length - current_length + + # Check that padding needed is no more than 2 + # assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed' + + # If assertion passes, proceed with padding + if padding_needed > 0: + last_frame = clip_chunk[-1] + log.info(last_frame.shape) + # Repeat the last frame to reach the expected length + padding = last_frame.repeat(padding_needed, 1, 1, 1) + clip_chunk = torch.cat((clip_chunk, padding), dim=0) + # raise RuntimeError(f'CLIP video wrong length {video_id}, ' + # f'expected {self.clip_expected_length}, ' + # f'got {clip_chunk.shape[0]}') + + # save_image(clip_chunk[0] / 255.0,'ori.png') + clip_chunk = pad_to_square(clip_chunk) + # save_image(clip_chunk[0] / 255.0,'square.png') + # clip_chunk = self.clip_transform(clip_chunk) + # import ipdb + # ipdb.set_trace() + clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"] + # log.info(clip_chunk.shape) + # save_tensor_as_image(clip_chunk[0].numpy(),'scale.png') + # log.info(clip_chunk[0]) + # clip_chunk = outputs + # text_ids = outputs["input_ids"] + # temp_img = clip_chunk[0].permute(1, 2, 0) * 255 + # save_image(clip_chunk[0],'scale.png') + sync_chunk = sync_chunk[:self.sync_expected_length] + if sync_chunk.shape[0] != self.sync_expected_length: + # padding using the last frame, but no more than 2 + current_length = sync_chunk.shape[0] + last_frame = sync_chunk[-1] + # 重复最后一帧以进行填充 + padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1) + # assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}' + sync_chunk = torch.cat((sync_chunk, padding), dim=0) + # raise RuntimeError(f'Sync video wrong length {video_id}, ' + # f'expected {self.sync_expected_length}, ' + # f'got {sync_chunk.shape[0]}') + + sync_chunk = self.sync_transform(sync_chunk) + # assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \ + # and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape' + data = { + 'id': video_id, + 'caption': label, + # 'audio': audio_chunk, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + 'caption_cot': caption_cot, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = VGGSound( +# root="data/vggsound/video/train", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_224_latents_text/train" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/vggsound_224_no_sync.py b/data_utils/v2a_utils/vggsound_224_no_sync.py new file mode 100644 index 0000000000000000000000000000000000000000..87b301e8e1f4e09ced92a01614c78b93435cc327 --- /dev/null +++ b/data_utils/v2a_utils/vggsound_224_no_sync.py @@ -0,0 +1,223 @@ +import os +from pathlib import Path +from typing import Optional, Union +from PIL import Image + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image +from transformers import AutoProcessor +import torch.nn.functional as F +import numpy as np + +import logging +log = logging.getLogger() + +_CLIP_SIZE = 224 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + +def save_tensor_as_image(tensor, save_path): + """ + 将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。 + + :param tensor: 输入的 NumPy 数组 (1, 3, H, W)。 + :param save_path: 图片保存路径。 + """ + # # 移除批次维度,变成 (3, H, W) + # tensor = tensor.squeeze(0) + + # 交换轴顺序,变为 (H, W, 3) + image_array = np.transpose(tensor, (1, 2, 0)) + + # 检查数组是否为合适的数据类型 + if image_array.dtype != np.uint8: + # 如果不是 uint8,首先标准化,然后转换 + image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255 + image_array = image_array.astype(np.uint8) + + # 创建图像对象 + image = Image.fromarray(image_array) + + # 保存图片 + image.save(save_path) + print(f"Image saved to {save_path}") + +def pad_to_square(video_tensor): + # 验证输入的形状 + if len(video_tensor.shape) != 4: + raise ValueError("Input tensor must have shape (l, c, h, w)") + + l, c, h, w = video_tensor.shape + max_side = max(h, w) + + # 计算每一维度需要的填充量:(left, right, top, bottom) + pad_h = max_side - h + pad_w = max_side - w + + # 创建padding tuple (left, right, top, bottom) + # 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充 + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + + # 使用F.pad对视频张量进行填充操作 + # 填充参数为 (left, right, top, bottom) + video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) + + return video_padded + +class VGGSound(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + sample_rate: int = 44_100, + duration_sec: float = 9.0, + audio_samples: Optional[int] = 397312, + normalize_audio: bool = False, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + self.normalize_audio = normalize_audio + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.labels = [] + self.videos = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + + for record in df_list: + id = record['id'] + if os.path.exists(f'{save_dir}/{id}.pth'): continue + label = record['label'] + # if id in videos: + self.labels.append(label) + # self.labels[id] = label + self.videos.append(id) + # else: + # missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = self.audio_samples + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Lambda(pad_to_square), # 先填充为正方形 + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge") + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + label = self.labels[idx] + + reader = StreamingMediaDecoder(self.root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + + + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + # import ipdb + # ipdb.set_trace() + if clip_chunk.shape[0] != self.clip_expected_length: + current_length = clip_chunk.shape[0] + padding_needed = self.clip_expected_length - current_length + + # Check that padding needed is no more than 2 + assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed' + + # If assertion passes, proceed with padding + if padding_needed > 0: + last_frame = clip_chunk[-1] + log.info(last_frame.shape) + # Repeat the last frame to reach the expected length + padding = last_frame.repeat(padding_needed, 1, 1, 1) + clip_chunk = torch.cat((clip_chunk, padding), dim=0) + # raise RuntimeError(f'CLIP video wrong length {video_id}, ' + # f'expected {self.clip_expected_length}, ' + # f'got {clip_chunk.shape[0]}') + + # save_image(clip_chunk[0] / 255.0,'ori.png') + clip_chunk = pad_to_square(clip_chunk) + # save_image(clip_chunk[0] / 255.0,'square.png') + # clip_chunk = self.clip_transform(clip_chunk) + # import ipdb + # ipdb.set_trace() + clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"] + + data = { + 'id': video_id, + 'caption': label, + 'clip_video': clip_chunk, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = VGGSound( +# root="data/vggsound/video/train", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_224_latents_text/train" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/vggsound_text.py b/data_utils/v2a_utils/vggsound_text.py new file mode 100644 index 0000000000000000000000000000000000000000..8f097f19136b3084df1434d4f2b2c6e5ddec88bf --- /dev/null +++ b/data_utils/v2a_utils/vggsound_text.py @@ -0,0 +1,109 @@ +import logging +import os +from pathlib import Path +from typing import Optional, Union + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image + +log = logging.getLogger() + +_CLIP_SIZE = 384 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + + +class VGGSound(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.labels = [] + self.cots = [] + self.videos = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + + for record in df_list: + id = record['id'] + # if os.path.exists(f'{save_dir}/{id}.pth'): + # continue + # try: + # torch.load(f'{save_dir}/{id}.pth') + # continue + # except: + # print(f'error load file: {save_dir}/{id}.pth') + # os.system(f'rm -f {save_dir}/{id}.pth') + label = record['caption'] + # if id in videos: + self.labels.append(label) + self.cots.append(record['caption_cot']) + # self.labels[id] = label + self.videos.append(id) + # else: + # missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + + + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + label = self.labels[idx] + cot = self.cots[idx] + data = { + 'id': video_id, + 'caption': label, + 'caption_cot': cot + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = VGGSound( +# root="data/vggsound/video/test", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_latents_text/test" +# ) +# dataset[0] \ No newline at end of file diff --git a/defaults.ini b/defaults.ini index 2b7cf194a68b3bab4d28eb3aade5cf71c01a7065..48bc46c720507ef52033b8e861e0cdce7de001be 100644 --- a/defaults.ini +++ b/defaults.ini @@ -2,14 +2,14 @@ [DEFAULTS] #name of the run -name = think_sound +name = stable_audio_tools # the batch size batch_size = 8 -test_batch_size = 32 +test_batch_size = 1 # predict ckpt directory -ckpt_dir = "" +ckpt_dir = "ckpts" # number of GPUs to use for training num_gpus = 1 @@ -61,4 +61,8 @@ remove_pretransform_weight_norm = '' compile = False -repeat_num = 5 \ No newline at end of file +repeat_num = 5 + +duration_sec = '9' + +results_dir = 'results' \ No newline at end of file diff --git a/demo_test.csv b/demo_test.csv new file mode 100644 index 0000000000000000000000000000000000000000..ffb66451b92e8c48e8bcb49fa9dd707e8246a24e --- /dev/null +++ b/demo_test.csv @@ -0,0 +1,17 @@ +id,caption,caption_cot +W1nb2hIeDKc_000021,striking bowling,"Start with a background of ambient music, then add consistent sounds of bowling balls striking pins to emphasize the action. Include occasional subtle sounds of pins rattling and settling. Keep human voices or other noises minimal or absent for authenticity." +YYRdv32TJnc_000184,plastic bottle crushing,"Start with the sound of crushing plastic bottles, including crinkling and crunching. Add background noise resembling a factory environment, with machinery sounds. Incorporate subtle rustling and paper crinkling to suggest manipulation of plastic items." +Rp39_WnX5Fk_000380,"subway, metro, underground","Generate subway sounds including ambient station noise, train doors opening and closing, engine hum, wheels on tracks, and conductor announcements to produce an accurate underground train environment." +-KqXcm-I2zY_000087,playing tennis,"Generate sounds of tennis hitting a racket, the ball bouncing, and the girl’s grunts, with distant tennis court ambient noise. Avoid unrelated sounds like horses, basketballs, or indoor voices. Focus on clear tennis scene with realistic audio cues." +0W_wPc-zV3I_000101,hedge trimmer running,"Generate the sound of a hedge trimmer running steadily, focusing on consistent motor noise and cutting sounds. Ensure minimal background noise or voices, capturing the primary sound of the trimmer in operation. Avoid including any chainsaw or unrelated sounds for accuracy." +_Betmm6FaWo_000096,writing on blackboard with chalk,"The audio should feature consistent sounds of chalk scratching the blackboard, including occasional voice instructions, encouragement, and children’s chatter, with background music playing softly or fading in/out to match the scene's atmosphere. The sounds of laughter and chatter should be lively but balanced with the primary chalk and voice sounds for clarity. Overall, the audio combines educational sounds with background activity to reflect a classroom or play environment." +xmTfE3F2huE_000854,chopping food,"Generate rhythmic chopping sounds consistent with meat or food being sliced, incorporating occasional rustling noises like a plastic bag. Avoid adding human voices or train sounds to match the correct audio descriptions, ensuring a focused, realistic kitchen chopping scene." +ZaUaqnLdg6k_000030,skateboarding,"Generate the audio featuring skateboarding sounds with wheels rolling on various surfaces, including ramps, rails, and sidewalks, capturing the sound of tricks and landings. Include subtle ambient background noise to suggest an outdoor setting, avoiding any human voices or singing. Focus on realistic skateboarding sounds, emphasizing wheel contact, impacts, and movement." +_ZC6yk5iE1I_000026,playing trumpet,"Generate a continuous trumpet sound with melodic variations, mimicking the sound of a person playing the trumpet idealy in a musical setting, ensuring clarity and realistic tone. Avoid extraneous noise or background sounds to reflect the focus on trumpet playing. The audio should resemble a skilled player producing expressive, melodious trumpet notes." +55L7peYRB_Q_000120,using sewing machines,"Generate ambient sewing room sounds with consistent sewing machine hum, minimal background noise, and no human voices, focusing on characteristic machine noise to match the correct descriptions." +4p8n4Zf-WMM_000190,lighting firecrackers,"Generate the sound of firecrackers lighting and exploding repeatedly, mixed with distant background sounds of crickets chirping. Incorporate occasional subtle echoes to mimic outdoor night ambiance, with no human voices present. End with a series of sharp cracker bursts to create a lively, festive atmosphere." +yLazKv68TeA_000078,people eating crisps,"Create audio with consistent crisp sounds of people eating chips, including crinkling paper and breathing. Include subtle chewing noises to match the activity. Avoid background music or voices for clarity." +_XyxrZDZ36E_000034,hammering nails,"Generate audio with consistent hammering sounds, featuring a rhythmic pattern of nails being driven into a surface, with occasional ambient background sounds like birds chirping and distant traffic. Avoid human voices, focusing on realistic hammer strikes and natural outdoor environment sounds. Ensure the hammering tone is steady and clear, matching the description of continuous nail hammering." +1u1orBeV4xI_000428,ripping paper,"Start with a subtle tearing sound of paper being ripped, emphasizing a continuous, consistent noise. Ensure the sound has slight variations to mimic real tearing. No background or additional noises are needed, focusing solely on the tearing action." +JFG4YvcJ3bo_000228,playing bongo,"Generate a lively percussion track featuring rhythmic djembe beats, with a melodic guitar strumming softly in the background to enhance the musical atmosphere. Ensure no human voice is included, focusing on the percussive and guitar sounds. Maintain a natural, well-balanced stereo mix to highlight the instruments' interplay." +1pViEqMXJH0_000030,printer printing,"Generate a continuous printer printing sound with periodic beeps, resembling typical printer noise, including paper movement and occasional beeps for realism. Add subtle ambient background noise, like faint room sounds, to enhance authenticity. Ensure the primary focus remains on the printing and beeping sounds, consistent with the correct audio descriptions." diff --git a/examples/1.mp4 b/examples/1.mp4 deleted file mode 100644 index 4c08710287c138274986f31555b715529faed0b1..0000000000000000000000000000000000000000 --- a/examples/1.mp4 +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8884c466292b46510c298a9ee88d8a584c86cb750afb558108c0850413e21e51 -size 634576 diff --git a/examples/1_mute.mp4 b/examples/1_mute.mp4 deleted file mode 100644 index db41cb4f700bddeca0636c1ecb3b7d8af6316c41..0000000000000000000000000000000000000000 --- a/examples/1_mute.mp4 +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b0ca4223b15583d8099d023bac3e86725bfd5cfbbad771ef67d31c1ad953bdc3 -size 482981 diff --git a/examples/2.mp4 b/examples/2.mp4 deleted file mode 100644 index 84999ec9f13b1df23fe29776b41098787a6b4d30..0000000000000000000000000000000000000000 --- a/examples/2.mp4 +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f26d23553c926913533c9ad1c3573d8a7dcfb5ecf9ae14ade3e7aea108966a45 -size 1655358 diff --git a/examples/2_mute.mp4 b/examples/2_mute.mp4 deleted file mode 100644 index 798337c4a9b187536095dbf5536b1401f8613b7a..0000000000000000000000000000000000000000 --- a/examples/2_mute.mp4 +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5bf89328627748411b3b43e61b9e3248af1dea7ace2c0440b5bb94f513640302 -size 1503089 diff --git a/examples/3.mp4 b/examples/3.mp4 deleted file mode 100644 index 6753254714e21981844ebc877689423e090d5046..0000000000000000000000000000000000000000 --- a/examples/3.mp4 +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9dccddd67a954b12d34d481107c499460c69bebd92913b8f092724fcdf1c5baf -size 1716778 diff --git a/examples/3_mute.mp4 b/examples/3_mute.mp4 deleted file mode 100644 index 188b28129495d6a4d38b9373464f832883f67bdd..0000000000000000000000000000000000000000 --- a/examples/3_mute.mp4 +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:103ed8d5e4fbe8d6954dde463c2a997acd0e21c13895404706dbbaab39e2b086 -size 1564981 diff --git a/examples/4.mp4 b/examples/4.mp4 deleted file mode 100644 index 9582bfd6a76e7011c98c61bd8f3ac97dfd29b7cc..0000000000000000000000000000000000000000 --- a/examples/4.mp4 +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e8398d2cd7fab1ccbaeb1b97642e634ec642dbfff181e33e6f017225ae91ea57 -size 1382246 diff --git a/examples/4_mute.mp4 b/examples/4_mute.mp4 deleted file mode 100644 index 86097985a2d81ec8bcd02492f2f4c0a15f0c7317..0000000000000000000000000000000000000000 --- a/examples/4_mute.mp4 +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:65756fa3c9b4a40addbd7bd2dde26cad71a5e379d214cc826e4bfa3597963a19 -size 1230521 diff --git a/examples/5.mp4 b/examples/5.mp4 deleted file mode 100644 index 0587cbf725cc24f8b0f1534dfd0e032ace456d89..0000000000000000000000000000000000000000 --- a/examples/5.mp4 +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:189a6cfcac18470a3f18285013f63f46c7b5996f31e0d3ecf617d9f7f91fdfeb -size 738718 diff --git a/examples/5_mute.mp4 b/examples/5_mute.mp4 deleted file mode 100644 index 993fb7db0431f2b64e347653eda17112de4bbfb4..0000000000000000000000000000000000000000 --- a/examples/5_mute.mp4 +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:103c2b1517d9cbecd81b394c93ebe36f166e95f635ceee280c28073233f08173 -size 586982 diff --git a/extract_latents.py b/extract_latents.py new file mode 100644 index 0000000000000000000000000000000000000000..417664f6cff7acc62fc9ce87e2d65069fd187e0a --- /dev/null +++ b/extract_latents.py @@ -0,0 +1,128 @@ +import argparse +import os +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader +from tqdm import tqdm +import logging +from data_utils.v2a_utils.vggsound_224_no_audio import VGGSound +from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils +import torchaudio +from einops import rearrange +from torch.utils.data.dataloader import default_collate +import numpy as np +from huggingface_hub import hf_hub_download +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +def setup(rank, world_size): + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + +def cleanup(): + dist.destroy_process_group() + +def error_avoidance_collate(batch): + batch = list(filter(lambda x: x is not None, batch)) + return default_collate(batch) + +def main(args): + + print(f"Using root: {args.root}, tsv_path: {args.tsv_path}, save_dir: {args.save_dir}") + dataset = VGGSound( + root=args.root, + tsv_path=args.tsv_path, + sample_rate=args.sample_rate, + duration_sec=args.duration_sec, + audio_samples=args.audio_samples, + start_row=args.start_row, + end_row=args.end_row, + save_dir=args.save_dir + ) + save_dir = args.save_dir + os.makedirs(save_dir, exist_ok=True) + + dataloader = DataLoader(dataset, batch_size=2, num_workers=8, drop_last=False,collate_fn=error_avoidance_collate) + + print(f"Dataset length: {len(dataset)}") + feature_extractor = FeaturesUtils( + vae_ckpt=None, + vae_config=args.vae_config, + enable_conditions=True, + synchformer_ckpt=args.synchformer_ckpt + ).eval().cuda() + + feature_extractor = feature_extractor + + for i, data in enumerate(tqdm(dataloader, desc="Processing", unit="batch")): + ids = data['id'] + with torch.no_grad(): + # audio = data['audio'].cuda(rank, non_blocking=True) + output = { + 'caption': str(data['caption']), + 'caption_cot': str(data['caption_cot']) + } + print(output) + + # latent = feature_extractor.module.encode_audio(audio) + # output['latent'] = latent.detach().cpu() + + clip_video = data['clip_video'].cuda() + clip_features = feature_extractor.encode_video_with_clip(clip_video) + output['metaclip_features'] = clip_features.detach().cpu() + + sync_video = data['sync_video'].cuda() + sync_features = feature_extractor.encode_video_with_sync(sync_video) + output['sync_features'] = sync_features.detach().cpu() + + caption = data['caption'] + metaclip_global_text_features, metaclip_text_features = feature_extractor.encode_text(caption) + output['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu() + output['metaclip_text_features'] = metaclip_text_features.detach().cpu() + + caption_cot = data['caption_cot'] + t5_features = feature_extractor.encode_t5_text(caption_cot) + output['t5_features'] = t5_features.detach().cpu() + + for j in range(len(ids)): + sample_output = { + 'id': ids[j], + 'caption': output['caption'][j], + 'caption_cot': output['caption_cot'][j], + # 'latent': output['latent'][j], + 'metaclip_features': output['metaclip_features'][j], + 'sync_features': output['sync_features'][j], + 'metaclip_global_text_features': output['metaclip_global_text_features'][j], + 'metaclip_text_features': output['metaclip_text_features'][j], + 't5_features': output['t5_features'][j], + } + # torch.save(sample_output, f'{save_dir}/{ids[j]}.pth') + np.savez(f'{save_dir}/demo.npz', **sample_output) + + ## test the sync between videos and audios + # torchaudio.save(f'input_{i}.wav',data['audio'],sample_rate=44100) + # recon_audio = feature_extractor.decode_audio(latent) + # recon_audio = rearrange(recon_audio, "b d n -> d (b n)") + # id = data['id'] + # torchaudio.save(f'recon_{i}.wav',recon_audio.cpu(),sample_rate=44100) + # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i recon_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest out_{i}.mp4') + # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i input_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest input_{i}.mp4') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Extract Video Training Latents') + parser.add_argument('--root', type=str, default='videos', help='Root directory of the video dataset') + parser.add_argument('--tsv_path', type=str, default='cot_coarse/cot.csv', help='Path to the TSV file') + parser.add_argument('--save-dir', type=str, default='results', help='Save Directory') + parser.add_argument('--sample_rate', type=int, default=44100, help='Sample rate of the audio') + parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds') + parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint') + parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file') + parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint') + parser.add_argument('--start-row', type=int, default=0, help='start row') + parser.add_argument('--end-row', type=int, default=None, help='end row') + + args = parser.parse_args() + args.audio_samples = int(args.sample_rate * args.duration_sec) + + main(args=args) + diff --git a/predict.py b/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..f63a0df1f12d734faec61de8b1e453599f7fa76c --- /dev/null +++ b/predict.py @@ -0,0 +1,214 @@ +from prefigure.prefigure import get_all_args, push_wandb_config +import json +import os +import re +import torch +import torchaudio +# import pytorch_lightning as pl +import lightning as L +from lightning.pytorch.callbacks import Timer, ModelCheckpoint, BasePredictionWriter +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.tuner import Tuner +from lightning.pytorch import seed_everything +import random +from datetime import datetime + +from ThinkSound.data.datamodule import DataModule +from ThinkSound.models import create_model_from_config +from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model +from ThinkSound.training import create_training_wrapper_from_config, create_demo_callback_from_config +from ThinkSound.training.utils import copy_state_dict +from huggingface_hub import hf_hub_download + +class ExceptionCallback(Callback): + def on_exception(self, trainer, module, err): + print(f'{type(err).__name__}: {err}') + +class ModelConfigEmbedderCallback(Callback): + def __init__(self, model_config): + self.model_config = model_config + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + checkpoint["model_config"] = self.model_config + +class CustomWriter(BasePredictionWriter): + + def __init__(self, output_dir, write_interval='batch', batch_size=32): + super().__init__(write_interval) + self.output_dir = output_dir + self.batch_size = batch_size + + def write_on_batch_end(self, trainer, pl_module, predictions, batch_indices, batch, batch_idx, dataloader_idx): + + audios = predictions + ids = [item['id'] for item in batch[1]] + current_date = datetime.now() + + formatted_date = current_date.strftime('%m%d') + os.makedirs(os.path.join(self.output_dir, f'{formatted_date}_batch_size{self.batch_size}'),exist_ok=True) + for audio, id in zip(audios, ids): + save_path = os.path.join(self.output_dir, f'{formatted_date}_batch_size{self.batch_size}', f'{id}.wav') + torchaudio.save(save_path, audio, 44100) + +def main(): + + args = get_all_args() + + + # args.pretransform_ckpt_path = hf_hub_download( + # repo_id="liuhuadai/ThinkSound", + # filename="vae.ckpt" + # ) + + args.pretransform_ckpt_path = "./ckpts/vae.ckpt" + + + seed = 10086 + + # Set a different seed for each process if using SLURM + if os.environ.get("SLURM_PROCID") is not None: + seed += int(os.environ.get("SLURM_PROCID")) + + # random.seed(seed) + # torch.manual_seed(seed) + seed_everything(seed, workers=True) + + #Get JSON config from args.model_config + with open(args.model_config) as f: + model_config = json.load(f) + + with open(args.dataset_config) as f: + dataset_config = json.load(f) + + for td in dataset_config["test_datasets"]: + td["path"] = args.results_dir + + # train_dl = create_dataloader_from_config( + # dataset_config, + # batch_size=args.batch_size, + # num_workers=args.num_workers, + # sample_rate=model_config["sample_rate"], + # sample_size=model_config["sample_size"], + # audio_channels=model_config.get("audio_channels", 2), + # ) + + + duration=(float)(args.duration_sec) + + dm = DataModule( + dataset_config, + batch_size=args.batch_size, + test_batch_size=args.test_batch_size, + num_workers=args.num_workers, + sample_rate=model_config["sample_rate"], + sample_size=(float)(args.duration_sec) * model_config["sample_rate"], + audio_channels=model_config.get("audio_channels", 2), + latent_length=round(44100/64/32*duration), + ) + + model_config["sample_size"] = duration * model_config["sample_rate"] + model_config["model"]["diffusion"]["config"]["sync_seq_len"] = 24*int(duration) + model_config["model"]["diffusion"]["config"]["clip_seq_len"] = 8*int(duration) + model_config["model"]["diffusion"]["config"]["latent_seq_len"] = round(44100/64/32*duration) + + model = create_model_from_config(model_config) + + ## speed by torch.compile + if args.compile: + model = torch.compile(model) + + if args.pretrained_ckpt_path: + copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion. + + if args.remove_pretransform_weight_norm == "pre_load": + remove_weight_norm_from_model(model.pretransform) + # import ipdb + # ipdb.set_trace() + if args.pretransform_ckpt_path: + load_vae_state = load_ckpt_state_dict(args.pretransform_ckpt_path, prefix='autoencoder.') + # new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")} + model.pretransform.load_state_dict(load_vae_state) + + # Remove weight_norm from the pretransform if specified + if args.remove_pretransform_weight_norm == "post_load": + remove_weight_norm_from_model(model.pretransform) + + training_wrapper = create_training_wrapper_from_config(model_config, model) + + # wandb_logger = L.pytorch.loggers.WandbLogger(project=args.name) + # wandb_logger.watch(training_wrapper) + + exc_callback = ExceptionCallback() + + # if args.save_dir and isinstance(wandb_logger.experiment.id, str): + # checkpoint_dir = os.path.join(args.save_dir, wandb_logger.experiment.project, wandb_logger.experiment.id, "checkpoints") + # else: + # checkpoint_dir = None + + # ckpt_callback = ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, monitor='val_loss', mode='min', save_top_k=10) + save_model_config_callback = ModelConfigEmbedderCallback(model_config) + audio_dir = args.results_dir + pred_writer = CustomWriter(output_dir=audio_dir, write_interval="batch", batch_size=args.test_batch_size) + timer = Timer(duration="00:15:00:00") + demo_callback = create_demo_callback_from_config(model_config, demo_dl=dm) + + #Combine args and config dicts + args_dict = vars(args) + args_dict.update({"model_config": model_config}) + args_dict.update({"dataset_config": dataset_config}) + # push_wandb_config(wandb_logger, args_dict) + + #Set multi-GPU strategy if specified + if args.strategy: + if args.strategy == "deepspeed": + from pytorch_lightning.strategies import DeepSpeedStrategy + strategy = DeepSpeedStrategy(stage=2, + contiguous_gradients=True, + overlap_comm=True, + reduce_scatter=True, + reduce_bucket_size=5e8, + allgather_bucket_size=5e8, + load_full_weights=True + ) + else: + strategy = args.strategy + else: + strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto" + + trainer = L.Trainer( + devices=args.num_gpus, + accelerator="gpu", + num_nodes = args.num_nodes, + strategy=strategy, + precision=args.precision, + accumulate_grad_batches=args.accum_batches, + callbacks=[demo_callback, exc_callback, save_model_config_callback, timer, pred_writer], + log_every_n_steps=1, + max_epochs=1000, + default_root_dir=args.save_dir, + gradient_clip_val=args.gradient_clip_val, + reload_dataloaders_every_n_epochs = 0, + check_val_every_n_epoch=2, + ) + + + + # ckpt_path = hf_hub_download( + # repo_id="liuhuadai/ThinkSound", + # filename="thinksound.ckpt" + # ) + ckpt_path = 'ckpts/thinksound.ckpt' + + + + current_date = datetime.now() + formatted_date = current_date.strftime('%m%d') + + audio_dir = f'{formatted_date}_step68k_batch_size'+str(args.test_batch_size) + metrics_path = os.path.join(args.ckpt_dir, 'audios',audio_dir,'cache',"output_metrics.json") + # if os.path.exists(metrics_path): continue + + trainer.predict(training_wrapper, dm, return_predictions=False,ckpt_path=ckpt_path) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..7fd26b970b848120ad4f6ab398e1864ab0111e60 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d358f83a37a29610fe532c39d7036f32d454a280..eb27cd60e38a1b078ba7e6d46b9ca39d076087b8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,14 @@ +modelscope absl-py==2.2.2 -accelerate==1.7.0 +accelerate==1.6.0 aeiou==0.0.20 aiobotocore==2.22.0 -aiofiles==24.1.0 +aiofiles==23.2.1 aiohappyeyeballs==2.6.1 aiohttp==3.11.18 aioitertools==0.12.0 aiosignal==1.3.2 alias-free-torch==0.0.6 -altair==5.5.0 annotated-types==0.7.0 antlr4-python3-runtime==4.9.3 anyio==4.9.0 @@ -17,6 +17,7 @@ argbind==0.3.9 asttokens==3.0.0 async-timeout==5.0.1 attrs==25.3.0 +audiobox_aesthetics==0.0.2 audioread==3.0.1 auraloss==0.4.0 av==14.4.0 @@ -24,9 +25,10 @@ bleach==6.2.0 bokeh==3.7.3 botocore==1.37.3 braceexpand==0.1.7 +Brotli==1.1.0 certifi==2025.4.26 cffi==1.17.1 -charset-normalizer==3.4.2 +charset-normalizer==3.4.1 clean-fid==0.1.35 click==8.1.8 clip-anytorch==2.6.0 @@ -41,14 +43,13 @@ dctorch==0.1.2 decorator==4.4.2 decord==0.6.0 descript-audio-codec==1.0.0 -descript-audiotools==0.7.2 docker-pycreds==0.4.0 docstring_parser==0.16 einops==0.7.0 einops-exts==0.0.4 ema-pytorch==0.2.3 encodec==0.1.1 -exceptiongroup==1.3.0 +exceptiongroup==1.2.2 executing==2.2.0 fastapi==0.115.12 fastcore==1.8.2 @@ -62,29 +63,33 @@ frozenlist==1.6.0 fsspec==2025.5.0 ftfy==6.3.1 future==1.0.0 +fvcore==0.1.5.post20221221 gin-config==0.5.0 gitdb==4.0.12 GitPython==3.1.44 -gradio==5.31.0 -gradio_client==1.10.1 +gradio==3.50.0 +gradio_client==0.6.1 groovy==0.1.2 grpcio==1.71.0 h11==0.16.0 +hf_xet h5py==3.13.0 +hjson==3.1.0 holoviews==1.20.2 httpcore==1.0.9 httpx==0.28.1 -huggingface-hub==0.31.4 -huggingface_hub[hf_xet] +huggingface-hub==0.30.2 hydra-colorlog==1.2.0 hydra-core==1.3.2 idna==3.10 imageio==2.37.0 +imageio-ffmpeg==0.4.9 importlib-resources==5.12.0 importlib_metadata==8.7.0 +iopath==0.1.10 ipython==8.36.0 jedi==0.19.2 -Jinja2==3.1.6 +Jinja2==3.1.0 jmespath==1.0.1 joblib==1.5.0 jsonmerge==1.9.2 @@ -96,6 +101,7 @@ kiwisolver==1.4.8 kornia==0.8.1 kornia_rs==0.1.9 laion-clap==1.1.4 +latex2mathml==3.77.0 lazy_loader==0.4 librosa==0.9.2 lightning==2.5.1.post0 @@ -106,35 +112,42 @@ local-attention==1.8.6 Markdown==3.8 markdown-it-py==3.0.0 markdown2==2.5.3 -MarkupSafe==3.0.2 +MarkupSafe==2.1.5 matplotlib==3.10.3 matplotlib-inline==0.1.7 mdit-py-plugins==0.4.2 mdurl==0.1.2 +moviepy==1.0.3 mpmath==1.3.0 multidict==6.4.4 +multiprocessing-logging==0.2.4 +mutagen==1.47.0 narwhals==1.40.0 networkx==3.4.2 +ninja==1.11.1.3 nitrous_ema==0.0.1 numba==0.60.0 numpy==1.23.5 omegaconf==2.3.0 open_clip_torch==2.32.0 +openai==1.33.0 opencv-python==4.11.0.86 orjson==3.10.18 -packaging==24.2 +pafy==0.5.3.1 pandas==2.0.2 panel==1.7.0 param==2.2.0 +parameterized==0.9.0 parso==0.8.4 pathtools==0.1.2 pedalboard==0.7.4 pexpect==4.9.0 -pillow==11.2.1 +pillow platformdirs==4.3.8 plotly==6.1.1 pooch==1.8.2 prefigure==0.0.9 +proglog==0.1.10 progressbar==2.5 prompt_toolkit==3.0.51 propcache==0.3.1 @@ -142,6 +155,7 @@ protobuf==3.19.6 psutil==7.0.0 ptyprocess==0.7.0 pure_eval==0.2.3 +py-cpuinfo==9.0.0 pycparser==2.22 pydantic==2.11.5 pydantic_core==2.33.2 @@ -149,12 +163,15 @@ pydub==0.25.1 Pygments==2.19.1 pyloudnorm==0.1.1 pynndescent==0.5.13 +pynvml==12.0.0 pyparsing==3.2.3 pystoi==0.4.1 +pysubs2==1.8.0 python-dateutil==2.9.0.post0 -python-dotenv==1.1.0 +python-dotenv==1.0.1 python-multipart==0.0.20 pytorch-lightning==2.5.1.post0 +pytorchvideo==0.1.5 pytz==2025.2 pyviz_comms==3.0.4 PyWavelets==1.4.1 @@ -170,6 +187,7 @@ ruff==0.11.11 s3fs==2025.5.0 safehttpx==0.1.6 safetensors==0.5.3 +scenedetect==0.6.3 scikit-image==0.24.0 scikit-learn==1.6.1 scipy==1.15.3 @@ -178,16 +196,19 @@ sentencepiece==0.1.99 sentry-sdk==2.29.1 setproctitle==1.3.6 shellingham==1.5.4 +shortuuid==1.0.13 six==1.17.0 smmap==5.0.2 sniffio==1.3.1 SoundFile==0.10.2 +sox==1.3.0 stack-data==0.6.3 starlette==0.46.2 +submitit==1.5.2 +svgwrite==1.4.3 sympy==1.13.1 -tensorboard==2.19.0 +tabulate==0.9.0 tensorboard-data-server==0.7.2 -tensordict==0.8.3 termcolor==3.1.0 threadpoolctl==3.6.0 tifffile==2025.5.10 @@ -203,6 +224,7 @@ torchmetrics==0.11.4 torchsde==0.2.6 torchvision==0.19.0 tornado==6.5.1 +git+https://github.com/patrick-kidger/torchcubicspline.git tqdm==4.67.1 traitlets==5.14.3 trampoline==0.1.2 @@ -210,7 +232,7 @@ transformers==4.43 triton==3.0.0 typer==0.15.4 typing-inspection==0.4.1 -typing_extensions==4.13.2 +typing_extensions==4.12.2 tzdata==2025.2 uc-micro-py==1.0.3 umap-learn==0.5.7 @@ -218,17 +240,15 @@ urllib3==2.4.0 uvicorn==0.34.2 v-diffusion-pytorch==0.0.2 vector-quantize-pytorch==1.9.14 -wandb==0.15.4 wcwidth==0.2.13 webdataset==0.2.48 webencodings==0.5.1 -websockets==15.0.1 Werkzeug==3.1.3 wget==3.2 wrapt==1.17.2 x-transformers==1.26.6 xyzservices==2025.4.0 +yacs==0.1.8 yarl==1.20.0 zipp==3.21.0 -git+https://github.com/patrick-kidger/torchcubicspline.git -moviepy==1.0.3 \ No newline at end of file +altair==5.5.0 diff --git a/scripts/demo.sh b/scripts/demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..98212d00ff0822dce932ef0108783eb91d2085cb --- /dev/null +++ b/scripts/demo.sh @@ -0,0 +1,82 @@ +#!/bin/bash + +# Check number of arguments +if [ "$#" -ne 3 ]; then + echo "Usage: $0 <description>" + exit 1 +fi + +VIDEO_PATH="$1" +TITLE="$2" +DESCRIPTION="$3" + +# Generate unique ID +UNIQUE_ID=$(uuidgen | cut -c 1-8) + +# Create necessary directories +mkdir -p videos cot_coarse results + +# Get video filename and extension +VIDEO_FILE=$(basename "$VIDEO_PATH") +VIDEO_EXT="${VIDEO_FILE##*.}" +VIDEO_ID="${VIDEO_FILE%.*}" +TEMP_VIDEO_PATH="videos/${VIDEO_ID}_${UNIQUE_ID}.mp4" + +# Convert video to MP4 format if needed +if [ "${VIDEO_EXT,,}" != "mp4" ]; then + echo "⏳ Converting video to MP4 format..." + ffmpeg -y -i "$VIDEO_PATH" -c:v libx264 -preset fast -c:a aac -strict experimental "$TEMP_VIDEO_PATH" >/dev/null 2>&1 + if [ $? -ne 0 ]; then + echo "❌ Video conversion failed" + exit 2 + fi +else + cp "$VIDEO_PATH" "$TEMP_VIDEO_PATH" +fi + +# Calculate video duration +DURATION=$(ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 "$TEMP_VIDEO_PATH") +DURATION_SEC=${DURATION%.*} +echo "Duration is: $DURATION_SEC" + +# Create cot.csv file +CAPTION_COT=$(echo "$DESCRIPTION" | tr '"' "'") +CSV_PATH="cot_coarse/cot.csv" +echo "id,caption,caption_cot" > "$CSV_PATH" +echo "${VIDEO_ID}_${UNIQUE_ID},$TITLE,\"$CAPTION_COT\"" >> "$CSV_PATH" + +# Run feature extraction +echo "⏳ Extracting features..." +python extract_latents.py --duration_sec "$DURATION_SEC" 2>&1 +if [ $? -ne 0 ]; then + echo "❌ Feature extraction failed" + rm -f "$TEMP_VIDEO_PATH" + exit 3 +fi + +# Run inference +echo "⏳ Running model inference..." +bash scripts/infer.sh --duration-sec "$DURATION_SEC" 2>&1 +if [ $? -ne 0 ]; then + echo "❌ Inference failed" + rm -f "$TEMP_VIDEO_PATH" + exit 4 +fi + +# Get generated audio file +CURRENT_DATE=$(date +"%m%d") +AUDIO_PATH="results/${CURRENT_DATE}_batch_size1/demo.wav" + +# Check if audio file exists +if [ ! -f "$AUDIO_PATH" ]; then + echo "❌ Generated audio file not found" + rm -f "$TEMP_VIDEO_PATH" + exit 5 +fi + +# Clean up temporary video file +rm -f "$TEMP_VIDEO_PATH" + + +echo "✅ Audio generated successfully!" +echo "Audio file path: $AUDIO_PATH" \ No newline at end of file diff --git a/scripts/infer.sh b/scripts/infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..46b1ecf3f095a50da4705aa0a041579bd4dc36db --- /dev/null +++ b/scripts/infer.sh @@ -0,0 +1,76 @@ +#!/bin/bash + +# 变量定义 +ckpt_dir="ckpts/thinksound.ckpt" +test_batch_size=1 +dataset_config="ThinkSound/configs/multimodal_dataset_demo.json" +model_config="ThinkSound/configs/model_configs/thinksound.json" +pretransform_ckpt_path="ckpts/vae.ckpt" +# 默认值 +debug_mode="true" +node_rank=0 + +result_path="results" + +while [[ $# -gt 0 ]]; do + case "$1" in + --duration-sec) + if [[ -n "$2" && "$2" != --* ]]; then + duration_sec="$2" + shift 2 + else + echo "❌ Argument --duration-sec requires a value" + exit 1 + fi + ;; + --result-path) + if [[ -n "$2" && "$2" != --* ]]; then + result_path="$2" + shift 2 + else + echo "❌ Argument --result-path requires a path" + exit 1 + fi + ;; + *) + echo "❌ Unknown argument: $1" + exit 1 + ;; + esac +done + +export NODE_RANK=$node_rank +export RANK=$node_rank + +num_gpus=1 +num_nodes=1 + +export WORLD_SIZE=$((num_gpus * num_nodes)) +# 打印配置信息 +echo "Training Configuration:" +echo "Checkpoint Directory: $ckpt_dir" +echo "Dataset Config: $dataset_config" +echo "Model Config: $model_config" +echo "Pretransform Checkpoint Path: $pretransform_ckpt_path" +echo "Num GPUs: $num_gpus" +echo "Num Nodes: $num_nodes" +echo "Test Batch Size: $test_batch_size" +echo "Num Workers: 20" +echo "Node Rank: $node_rank" +echo "WORLD SIZE: $WORLD_SIZE" + + +python predict.py \ + --dataset-config "$dataset_config" \ + --model-config "$model_config" \ + --ckpt-dir "$ckpt_dir" \ + --pretransform-ckpt-path "$pretransform_ckpt_path" \ + --checkpoint-every 2000 \ + --num-gpus "$num_gpus" \ + --num-nodes "$num_nodes" \ + --batch-size 1 \ + --test-batch-size $test_batch_size \ + --num-workers 32 \ + --duration-sec $duration_sec \ + --results-dir $result_path \ + diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..cb8a051fac9ab9407519d0b12bdd66f860bf7a41 --- /dev/null +++ b/setup.py @@ -0,0 +1,44 @@ +from setuptools import setup, find_packages + +setup( + name='thinksound', + version='0.0.16', + url='https://github.com/liuhuadai/thinksound.git', + author='liuhuadai', + description='a unified Any2Audio generation framework guided by Chain-of-Thought (CoT) reasoning', + packages=find_packages(), + install_requires=[ + 'aeiou==0.0.20', + 'alias-free-torch==0.0.6', + 'auraloss==0.4.0', + 'descript-audio-codec==1.0.0', + 'einops==0.7.0', + 'einops-exts==0.0.4', + 'ema-pytorch==0.2.3', + 'encodec==0.1.1', + # 'gradio>=3.42.0', + 'huggingface_hub', + 'importlib-resources==5.12.0', + 'k-diffusion==0.1.1', + 'laion-clap==1.1.4', + 'local-attention==1.8.6', + 'pandas==2.0.2', + 'pedalboard==0.7.4', + 'prefigure==0.0.9', + 'pytorch_lightning==2.1.0', + 'PyWavelets==1.4.1', + 'safetensors', + 'sentencepiece==0.1.99', + 's3fs', + 'torch>=2.0.1', + 'torchaudio>=2.0.2', + 'torchmetrics==0.11.4', + 'tqdm', + 'transformers', + 'v-diffusion-pytorch==0.0.2', + 'vector-quantize-pytorch==1.9.14', + 'wandb==0.15.4', + 'webdataset==0.2.48', + 'x-transformers<1.27.0' + ], +) \ No newline at end of file diff --git a/think_sound/__pycache__/__init__.cpython-310.pyc b/think_sound/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 3122a65e35c929cf1509f27a8c66e31d56299281..0000000000000000000000000000000000000000 Binary files a/think_sound/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/think_sound/__pycache__/__init__.cpython-38.pyc b/think_sound/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 0050985c5802bad6bcb67109e9c81e99c54603d0..0000000000000000000000000000000000000000 Binary files a/think_sound/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/think_sound/__pycache__/__init__.cpython-39.pyc b/think_sound/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index b045e33239dbb5291b54962bfa6855d0a29d0c2b..0000000000000000000000000000000000000000 Binary files a/think_sound/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/think_sound/configs/model_configs/audiossl/flow_audio_channel_ssl.json b/think_sound/configs/model_configs/audiossl/flow_audio_channel_ssl.json deleted file mode 100644 index 5e3b3cc543c1da956020a441514594308cf804b0..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/audiossl/flow_audio_channel_ssl.json +++ /dev/null @@ -1,98 +0,0 @@ -{ - "model_type": "diffusion_infill", - "sample_size": 441000, - "sample_rate": 44100, - "audio_channels": 4, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "requires_grad": false, - "config": { - "in_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 4 - } - }, - "diffusion": { - "input_concat_ids": ["x_ctx"], - "type": "dit", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "project_cond_tokens": false, - "input_concat_dim": 64, - "transformer_type": "continuous_transformer", - "ctx_drop": 0.1 - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "timestep_sampler": "logit_normal", - "diffusion_objective": "rectified_flow", - "frac_lengths_mask": [0.7, 1.0], - "min_span_len": 10, - "ctx_drop": 0.1, - "r_drop": 0.2, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-4, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 64, - "num_demos": 1, - "demo_cond": [ - {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} - ], - "demo_cfg_scales": [7] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/audiossl/flow_audio_ssl.json b/think_sound/configs/model_configs/audiossl/flow_audio_ssl.json deleted file mode 100644 index 9a79e2be2e5cf0150505155909ce127277ddefef..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/audiossl/flow_audio_ssl.json +++ /dev/null @@ -1,97 +0,0 @@ -{ - "model_type": "diffusion_infill", - "sample_size": 440320, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "requires_grad": false, - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "diffusion": { - "input_concat_ids": ["x_ctx"], - "type": "dit", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "project_cond_tokens": false, - "input_concat_dim": 64, - "transformer_type": "continuous_transformer", - "ctx_drop": 0.1 - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "timestep_sampler": "logit_normal", - "diffusion_objective": "rectified_flow", - "frac_lengths_mask": [0.7, 1.0], - "min_span_len": 10, - "ctx_drop": 0.1, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 100, - "num_demos": 1, - "demo_cond": [ - {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} - ], - "demo_cfg_scales": [7] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl.json b/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl.json deleted file mode 100644 index f9ea6f38ebbd73b6dab02edfd8259b8efcf999c9..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl.json +++ /dev/null @@ -1,98 +0,0 @@ -{ - "model_type": "diffusion_infill", - "sample_size": 441000, - "sample_rate": 44100, - "audio_channels": 4, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "requires_grad": false, - "config": { - "in_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 4 - } - }, - "diffusion": { - "input_concat_ids": ["x_ctx"], - "type": "dit", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "project_cond_tokens": false, - "input_concat_dim": 64, - "transformer_type": "continuous_transformer", - "ctx_drop": 0.1 - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "timestep_sampler": "logit_normal", - "diffusion_objective": "rectified_flow", - "frac_lengths_mask": [0.7, 1.0], - "min_span_len": 10, - "ctx_drop": 0.1, - "r_drop": 0.0, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 64, - "num_demos": 1, - "demo_cond": [ - {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} - ], - "demo_cfg_scales": [7] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_drop0.0.json b/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_drop0.0.json deleted file mode 100644 index f269e1545a8f9c451c460fdb9fbce23c9d6cf7d3..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_drop0.0.json +++ /dev/null @@ -1,98 +0,0 @@ -{ - "model_type": "diffusion_infill", - "sample_size": 441000, - "sample_rate": 44100, - "audio_channels": 4, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "requires_grad": false, - "config": { - "in_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 4 - } - }, - "diffusion": { - "input_concat_ids": ["x_ctx"], - "type": "dit", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "project_cond_tokens": false, - "input_concat_dim": 64, - "transformer_type": "continuous_transformer", - "ctx_drop": 0.1 - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "timestep_sampler": "logit_normal", - "diffusion_objective": "rectified_flow", - "frac_lengths_mask": [0.7, 1.0], - "min_span_len": 10, - "ctx_drop": 0.0, - "r_drop": 0.0, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 64, - "num_demos": 1, - "demo_cond": [ - {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} - ], - "demo_cfg_scales": [7] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_drop0.2.json b/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_drop0.2.json deleted file mode 100644 index 12a7d852975536a2f57ee7db3c42e291e1d140c0..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_drop0.2.json +++ /dev/null @@ -1,98 +0,0 @@ -{ - "model_type": "diffusion_infill", - "sample_size": 441000, - "sample_rate": 44100, - "audio_channels": 4, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "requires_grad": false, - "config": { - "in_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 4 - } - }, - "diffusion": { - "input_concat_ids": ["x_ctx"], - "type": "dit", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "project_cond_tokens": false, - "input_concat_dim": 64, - "transformer_type": "continuous_transformer", - "ctx_drop": 0.1 - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "timestep_sampler": "logit_normal", - "diffusion_objective": "rectified_flow", - "frac_lengths_mask": [0.7, 1.0], - "min_span_len": 10, - "ctx_drop": 0.2, - "r_drop": 0.0, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 64, - "num_demos": 1, - "demo_cond": [ - {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} - ], - "demo_cfg_scales": [7] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_lr.json b/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_lr.json deleted file mode 100644 index abe1e45098f507bfa832c6410e0e35711dc243f1..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_lr.json +++ /dev/null @@ -1,98 +0,0 @@ -{ - "model_type": "diffusion_infill", - "sample_size": 441000, - "sample_rate": 44100, - "audio_channels": 4, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "requires_grad": false, - "config": { - "in_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 4 - } - }, - "diffusion": { - "input_concat_ids": ["x_ctx"], - "type": "dit", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "project_cond_tokens": false, - "input_concat_dim": 64, - "transformer_type": "continuous_transformer", - "ctx_drop": 0.1 - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "timestep_sampler": "logit_normal", - "diffusion_objective": "rectified_flow", - "frac_lengths_mask": [0.7, 1.0], - "min_span_len": 10, - "ctx_drop": 0.1, - "r_drop": 0.0, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 100000, - "power": 0.5, - "warmup": 0.999 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 64, - "num_demos": 1, - "demo_cond": [ - {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} - ], - "demo_cfg_scales": [7] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/dac_2048_32_vae.json b/think_sound/configs/model_configs/autoencoders/dac_2048_32_vae.json deleted file mode 100644 index 25457472a9d4b0d096abc1f7b197d6f4fb8a7fa7..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/autoencoders/dac_2048_32_vae.json +++ /dev/null @@ -1,71 +0,0 @@ -{ - "model_type": "autoencoder", - "sample_size": 65536, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "encoder": { - "type": "dac", - "config": { - "latent_dim": 64, - "d_model": 128, - "strides": [4, 8, 8, 8] - } - }, - "decoder": { - "type": "dac", - "config": { - "latent_dim": 32, - "channels": 1536, - "rates": [8, 8, 8, 4] - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 32, - "downsampling_ratio": 2048, - "io_channels": 1 - }, - "training": { - "learning_rate": 1e-4, - "warmup_steps": 0, - "use_ema": false, - "loss_configs": { - "discriminator": { - "type": "encodec", - "config": { - "filters": 32, - "n_ffts": [2048, 1024, 512, 256, 128, 64, 32], - "hop_lengths": [512, 256, 128, 64, 32, 16, 8], - "win_lengths": [2048, 1024, 512, 256, 128, 64, 32] - }, - "weights": { - "adversarial": 0.1, - "feature_matching": 5.0 - } - }, - "spectral": { - "type": "mrstft", - "config": { - "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], - "hop_sizes": [512, 256, 128, 64, 32, 16, 8], - "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], - "perceptual_weighting": true - }, - "weights": { - "mrstft": 1.0 - } - }, - "time": { - "type": "l1", - "weights": { - "l1": 0.0 - } - } - }, - "demo": { - "demo_every": 2000 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/encodec_musicgen_rvq.json b/think_sound/configs/model_configs/autoencoders/encodec_musicgen_rvq.json deleted file mode 100644 index e76bd3d9a12ae028f3038562ce8082b8eadca116..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/autoencoders/encodec_musicgen_rvq.json +++ /dev/null @@ -1,88 +0,0 @@ -{ - "model_type": "autoencoder", - "sample_size": 32000, - "sample_rate": 32000, - "audio_channels": 1, - "model": { - "encoder": { - "type": "seanet", - "config": { - "channels": 1, - "dimension": 128, - "n_filters": 64, - "ratios": [4, 4, 5, 8], - "n_residual_layers": 1, - "dilation_base": 2, - "lstm": 2, - "norm": "weight_norm" - } - }, - "decoder": { - "type": "seanet", - "config": { - "channels": 1, - "dimension": 128, - "n_filters": 64, - "ratios": [4, 4, 5, 8], - "n_residual_layers": 1, - "dilation_base": 2, - "lstm": 2, - "norm": "weight_norm" - } - }, - "bottleneck": { - "type": "rvq", - "config": { - "num_quantizers": 4, - "codebook_size": 2048, - "dim": 128, - "decay": 0.99, - "threshold_ema_dead_code": 2 - } - }, - "latent_dim": 128, - "downsampling_ratio": 640, - "io_channels": 1 - }, - "training": { - "learning_rate": 1e-4, - "warmup_steps": 0, - "use_ema": true, - "loss_configs": { - "discriminator": { - "type": "encodec", - "config": { - "filters": 32, - "n_ffts": [2048, 1024, 512, 256, 128], - "hop_lengths": [512, 256, 128, 64, 32], - "win_lengths": [2048, 1024, 512, 256, 128] - }, - "weights": { - "adversarial": 0.1, - "feature_matching": 5.0 - } - }, - "spectral": { - "type": "mrstft", - "config": { - "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], - "hop_sizes": [512, 256, 128, 64, 32, 16, 8], - "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], - "perceptual_weighting": true - }, - "weights": { - "mrstft": 1.0 - } - }, - "time": { - "type": "l1", - "weights": { - "l1": 0.0 - } - } - }, - "demo": { - "demo_every": 2000 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/foa_audio_vae.json b/think_sound/configs/model_configs/autoencoders/foa_audio_vae.json deleted file mode 100644 index f1128edad22c618120c7e28c2bf7a68bd1a015e9..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/autoencoders/foa_audio_vae.json +++ /dev/null @@ -1,122 +0,0 @@ -{ - "model_type": "autoencoder", - "sample_size": 65536, - "sample_rate": 44100, - "audio_channels": 4, - "model": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 4 - }, - "training": { - "learning_rate": 3e-5, - "warmup_steps": 0, - "use_ema": true, - "optimizer_configs": { - "autoencoder": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 3e-5, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - }, - "discriminator": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 6e-5, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - } - }, - "loss_configs": { - "discriminator": { - "type": "encodec", - "config": { - "filters": 64, - "n_ffts": [2048, 1024, 512, 256, 128], - "hop_lengths": [512, 256, 128, 64, 32], - "win_lengths": [2048, 1024, 512, 256, 128] - }, - "weights": { - "adversarial": 0.1, - "feature_matching": 5.0 - } - }, - "spectral": { - "type": "mrstft", - "config": { - "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], - "hop_sizes": [512, 256, 128, 64, 32, 16, 8], - "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], - "perceptual_weighting": true - }, - "weights": { - "mrstft": 1.0 - } - }, - "time": { - "type": "l1", - "weights": { - "l1": 0.0 - } - }, - "bottleneck": { - "type": "kl", - "weights": { - "kl": 1e-4 - } - } - }, - "demo": { - "demo_every": 10000 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/foa_audio_vae_256.json b/think_sound/configs/model_configs/autoencoders/foa_audio_vae_256.json deleted file mode 100644 index 5feb255de582b9aae0024daafd2ec5ad77edbb11..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/autoencoders/foa_audio_vae_256.json +++ /dev/null @@ -1,122 +0,0 @@ -{ - "model_type": "autoencoder", - "sample_size": 65536, - "sample_rate": 44100, - "audio_channels": 4, - "model": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 4 - }, - "training": { - "learning_rate": 1.5e-4, - "warmup_steps": 0, - "use_ema": true, - "optimizer_configs": { - "autoencoder": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 1.5e-4, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - }, - "discriminator": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 3e-4, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - } - }, - "loss_configs": { - "discriminator": { - "type": "encodec", - "config": { - "filters": 64, - "n_ffts": [2048, 1024, 512, 256, 128], - "hop_lengths": [512, 256, 128, 64, 32], - "win_lengths": [2048, 1024, 512, 256, 128] - }, - "weights": { - "adversarial": 0.1, - "feature_matching": 5.0 - } - }, - "spectral": { - "type": "mrstft", - "config": { - "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], - "hop_sizes": [512, 256, 128, 64, 32, 16, 8], - "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], - "perceptual_weighting": true - }, - "weights": { - "mrstft": 1.0 - } - }, - "time": { - "type": "l1", - "weights": { - "l1": 0.0 - } - }, - "bottleneck": { - "type": "kl", - "weights": { - "kl": 1e-4 - } - } - }, - "demo": { - "demo_every": 10000 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/foa_audio_vae_256_lr1e-4.json b/think_sound/configs/model_configs/autoencoders/foa_audio_vae_256_lr1e-4.json deleted file mode 100644 index 063d69ff1ce1a371e60dd1fc176434d50acbd788..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/autoencoders/foa_audio_vae_256_lr1e-4.json +++ /dev/null @@ -1,122 +0,0 @@ -{ - "model_type": "autoencoder", - "sample_size": 65536, - "sample_rate": 44100, - "audio_channels": 4, - "model": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 4, - "channels": 256, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 256, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 4, - "channels": 256, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 128, - "downsampling_ratio": 2048, - "io_channels": 4 - }, - "training": { - "learning_rate": 1.5e-4, - "warmup_steps": 0, - "use_ema": true, - "optimizer_configs": { - "autoencoder": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 1e-4, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - }, - "discriminator": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 2e-4, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - } - }, - "loss_configs": { - "discriminator": { - "type": "encodec", - "config": { - "filters": 64, - "n_ffts": [2048, 1024, 512, 256, 128], - "hop_lengths": [512, 256, 128, 64, 32], - "win_lengths": [2048, 1024, 512, 256, 128] - }, - "weights": { - "adversarial": 0.1, - "feature_matching": 5.0 - } - }, - "spectral": { - "type": "mrstft", - "config": { - "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], - "hop_sizes": [512, 256, 128, 64, 32, 16, 8], - "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], - "perceptual_weighting": true - }, - "weights": { - "mrstft": 1.0 - } - }, - "time": { - "type": "l1", - "weights": { - "l1": 0.0 - } - }, - "bottleneck": { - "type": "kl", - "weights": { - "kl": 1e-4 - } - } - }, - "demo": { - "demo_every": 10000 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/foa_audio_vae_decoder.json b/think_sound/configs/model_configs/autoencoders/foa_audio_vae_decoder.json deleted file mode 100644 index bb604d623300dac618e00f98f73f7eab5fd04228..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/autoencoders/foa_audio_vae_decoder.json +++ /dev/null @@ -1,124 +0,0 @@ -{ - "model_type": "autoencoder", - "sample_size": 65536, - "sample_rate": 44100, - "audio_channels": 4, - "model": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 4, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 4 - }, - "training": { - "learning_rate": 3e-5, - "warmup_steps": 0, - "latent_mask_ratio": 0.1, - "encoder_freeze_on_warmup": true, - "use_ema": true, - "optimizer_configs": { - "autoencoder": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 3e-5, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - }, - "discriminator": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 6e-5, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - } - }, - "loss_configs": { - "discriminator": { - "type": "encodec", - "config": { - "filters": 64, - "n_ffts": [2048, 1024, 512, 256, 128], - "hop_lengths": [512, 256, 128, 64, 32], - "win_lengths": [2048, 1024, 512, 256, 128] - }, - "weights": { - "adversarial": 0.1, - "feature_matching": 5.0 - } - }, - "spectral": { - "type": "mrstft", - "config": { - "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], - "hop_sizes": [512, 256, 128, 64, 32, 16, 8], - "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], - "perceptual_weighting": true - }, - "weights": { - "mrstft": 1.0 - } - }, - "time": { - "type": "l1", - "weights": { - "l1": 0.0 - } - }, - "bottleneck": { - "type": "kl", - "weights": { - "kl": 1e-4 - } - } - }, - "demo": { - "demo_every": 10000 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/ori.json b/think_sound/configs/model_configs/autoencoders/ori.json deleted file mode 100644 index 3aa762f2a4bb3ff631fd53401c5ec22e524e9bf2..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/autoencoders/ori.json +++ /dev/null @@ -1,122 +0,0 @@ -{ - "model_type": "autoencoder", - "sample_size": 65536, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - }, - "training": { - "learning_rate": 1.5e-4, - "warmup_steps": 0, - "use_ema": true, - "optimizer_configs": { - "autoencoder": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 1.5e-4, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - }, - "discriminator": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 3e-4, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - } - }, - "loss_configs": { - "discriminator": { - "type": "encodec", - "config": { - "filters": 64, - "n_ffts": [2048, 1024, 512, 256, 128], - "hop_lengths": [512, 256, 128, 64, 32], - "win_lengths": [2048, 1024, 512, 256, 128] - }, - "weights": { - "adversarial": 0.1, - "feature_matching": 5.0 - } - }, - "spectral": { - "type": "mrstft", - "config": { - "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], - "hop_sizes": [512, 256, 128, 64, 32, 16, 8], - "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], - "perceptual_weighting": true - }, - "weights": { - "mrstft": 1.0 - } - }, - "time": { - "type": "l1", - "weights": { - "l1": 0.0 - } - }, - "bottleneck": { - "type": "kl", - "weights": { - "kl": 1e-4 - } - } - }, - "demo": { - "demo_every": 2000 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/speech_vae.json b/think_sound/configs/model_configs/autoencoders/speech_vae.json deleted file mode 100644 index 7d8a9a7b15ca48be660b55d48a87244b3d53f27a..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/autoencoders/speech_vae.json +++ /dev/null @@ -1,122 +0,0 @@ -{ - "model_type": "autoencoder", - "sample_size": 65536, - "sample_rate": 16000, - "audio_channels": 1, - "model": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 1, - "channels": 64, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 1, - "channels": 64, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 1 - }, - "training": { - "learning_rate": 1.5e-4, - "warmup_steps": 0, - "use_ema": true, - "optimizer_configs": { - "autoencoder": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 1.5e-4, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - }, - "discriminator": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 3e-4, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - } - }, - "loss_configs": { - "discriminator": { - "type": "encodec", - "config": { - "filters": 64, - "n_ffts": [2048, 1024, 512, 256, 128], - "hop_lengths": [512, 256, 128, 64, 32], - "win_lengths": [2048, 1024, 512, 256, 128] - }, - "weights": { - "adversarial": 0.1, - "feature_matching": 5.0 - } - }, - "spectral": { - "type": "mrstft", - "config": { - "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], - "hop_sizes": [512, 256, 128, 64, 32, 16, 8], - "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], - "perceptual_weighting": true - }, - "weights": { - "mrstft": 1.0 - } - }, - "time": { - "type": "l1", - "weights": { - "l1": 0.0 - } - }, - "bottleneck": { - "type": "kl", - "weights": { - "kl": 1e-4 - } - } - }, - "demo": { - "demo_every": 10000 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/speech_vae_44k.json b/think_sound/configs/model_configs/autoencoders/speech_vae_44k.json deleted file mode 100644 index 6f77e2e17823517af3d5cede126d40acf8b5f5dc..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/autoencoders/speech_vae_44k.json +++ /dev/null @@ -1,122 +0,0 @@ -{ - "model_type": "autoencoder", - "sample_size": 65536, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - }, - "training": { - "learning_rate": 1.5e-4, - "warmup_steps": 0, - "use_ema": true, - "optimizer_configs": { - "autoencoder": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 1.5e-5, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - }, - "discriminator": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 3e-5, - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 200000, - "power": 0.5, - "warmup": 0.999 - } - } - } - }, - "loss_configs": { - "discriminator": { - "type": "encodec", - "config": { - "filters": 64, - "n_ffts": [2048, 1024, 512, 256, 128], - "hop_lengths": [512, 256, 128, 64, 32], - "win_lengths": [2048, 1024, 512, 256, 128] - }, - "weights": { - "adversarial": 0.1, - "feature_matching": 5.0 - } - }, - "spectral": { - "type": "mrstft", - "config": { - "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], - "hop_sizes": [512, 256, 128, 64, 32, 16, 8], - "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], - "perceptual_weighting": true - }, - "weights": { - "mrstft": 1.0 - } - }, - "time": { - "type": "l1", - "weights": { - "l1": 0.0 - } - }, - "bottleneck": { - "type": "kl", - "weights": { - "kl": 1e-4 - } - } - }, - "demo": { - "demo_every": 10000 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/stable_audio_1_0_vae.json b/think_sound/configs/model_configs/autoencoders/stable_audio_1_0_vae.json deleted file mode 100644 index 26dcb25f3322e79422c7ab288aace9f23e711768..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/autoencoders/stable_audio_1_0_vae.json +++ /dev/null @@ -1,111 +0,0 @@ -{ - "model_type": "autoencoder", - "sample_size": 65536, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "encoder": { - "type": "dac", - "config": { - "in_channels": 2, - "latent_dim": 128, - "d_model": 128, - "strides": [4, 4, 8, 8] - } - }, - "decoder": { - "type": "dac", - "config": { - "out_channels": 2, - "latent_dim": 64, - "channels": 1536, - "rates": [8, 8, 4, 4] - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 1024, - "io_channels": 2 - }, - "training": { - "learning_rate": 1e-4, - "warmup_steps": 0, - "use_ema": true, - "optimizer_configs": { - "autoencoder": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 1e-4 - } - }, - "scheduler": { - "type": "ExponentialLR", - "config": { - "gamma": 0.999996 - } - } - }, - "discriminator": { - "optimizer": { - "type": "AdamW", - "config": { - "betas": [0.8, 0.99], - "lr": 1e-4 - } - }, - "scheduler": { - "type": "ExponentialLR", - "config": { - "gamma": 0.999996 - } - } - } - }, - "loss_configs": { - "discriminator": { - "type": "encodec", - "config": { - "filters": 32, - "n_ffts": [2048, 1024, 512, 256, 128], - "hop_lengths": [512, 256, 128, 64, 32], - "win_lengths": [2048, 1024, 512, 256, 128] - }, - "weights": { - "adversarial": 0.1, - "feature_matching": 5.0 - } - }, - "spectral": { - "type": "mrstft", - "config": { - "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], - "hop_sizes": [512, 256, 128, 64, 32, 16, 8], - "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], - "perceptual_weighting": true - }, - "weights": { - "mrstft": 1.0 - } - }, - "time": { - "type": "l1", - "weights": { - "l1": 0.0 - } - }, - "bottleneck": { - "type": "kl", - "weights": { - "kl": 1e-6 - } - } - }, - "demo": { - "demo_every": 2000 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base.json b/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base.json deleted file mode 100644 index a57f9e4abc99157128f505c6f5e5188101808f9b..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "model_type": "diffusion_uncond", - "sample_size": 65536, - "sample_rate": 48000, - "model": { - "type": "DAU1d", - "config": { - "n_attn_layers": 5 - } - }, - "training": { - "learning_rate": 1e-4, - "demo": { - "demo_every": 2000, - "demo_steps": 250 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json b/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json deleted file mode 100644 index 4319a56731f981d2de1a294c2727e087475d1633..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "model_type": "diffusion_uncond", - "sample_size": 65536, - "sample_rate": 16000, - "model": { - "type": "DAU1d", - "config": { - "n_attn_layers": 5 - } - }, - "training": { - "learning_rate": 1e-4, - "demo": { - "demo_every": 2000, - "demo_steps": 250 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json b/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json deleted file mode 100644 index fedb83fa3c741d7c1d4a7215e909862a81730805..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "model_type": "diffusion_uncond", - "sample_size": 65536, - "sample_rate": 44100, - "model": { - "type": "DAU1d", - "config": { - "n_attn_layers": 5 - } - }, - "training": { - "learning_rate": 4e-5, - "demo": { - "demo_every": 2000, - "demo_steps": 250 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_large.json b/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_large.json deleted file mode 100644 index f9f96a455ad9e40b4ea624bda4b9c209fea4bcca..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_large.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "model_type": "diffusion_uncond", - "sample_size": 131072, - "sample_rate": 48000, - "model": { - "type": "DAU1d", - "config": { - "n_attn_layers": 5 - } - }, - "training": { - "learning_rate": 1e-4, - "demo": { - "demo_every": 2000, - "demo_steps": 250 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/mono2audio/flow_audio_mono.json b/think_sound/configs/model_configs/mono2audio/flow_audio_mono.json deleted file mode 100644 index bd8ae0d8aff26b0f4cf437c161ed2c6840b94f4f..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/mono2audio/flow_audio_mono.json +++ /dev/null @@ -1,102 +0,0 @@ -{ - "model_type": "diffusion_prior", - "sample_size": 440320, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "prior_type": "mono_stereo", - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "video", - "type": "video_linear", - "config": { - "dim": 1024, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "cross_attention_cond_ids": ["video"], - "input_concat_ids": ["source"], - "type": "dit", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "cond_token_dim": 768, - "input_concat_dim": 64, - "project_cond_tokens": false, - "transformer_type": "continuous_transformer" - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "learning_rate": 5e-4, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2, - "demo_steps": 100 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/mono2audio/flow_audio_mono_aural_loss.json b/think_sound/configs/model_configs/mono2audio/flow_audio_mono_aural_loss.json deleted file mode 100644 index 6b19cf7e750f69766c2a5a0185178ce1964c3787..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/mono2audio/flow_audio_mono_aural_loss.json +++ /dev/null @@ -1,103 +0,0 @@ -{ - "model_type": "diffusion_prior", - "sample_size": 440320, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "prior_type": "mono_stereo", - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "video", - "type": "video_linear", - "config": { - "dim": 1024, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "cross_attention_cond_ids": ["video"], - "input_concat_ids": ["source"], - "type": "dit", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "cond_token_dim": 768, - "input_concat_dim": 64, - "project_cond_tokens": false, - "transformer_type": "continuous_transformer" - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "learning_rate": 5e-4, - "use_reconstruction_loss": true, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 500, - "demo_steps": 100 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/mono2audio/flow_audio_mono_lr1e4.json b/think_sound/configs/model_configs/mono2audio/flow_audio_mono_lr1e4.json deleted file mode 100644 index bd8ae0d8aff26b0f4cf437c161ed2c6840b94f4f..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/mono2audio/flow_audio_mono_lr1e4.json +++ /dev/null @@ -1,102 +0,0 @@ -{ - "model_type": "diffusion_prior", - "sample_size": 440320, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "prior_type": "mono_stereo", - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "video", - "type": "video_linear", - "config": { - "dim": 1024, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "cross_attention_cond_ids": ["video"], - "input_concat_ids": ["source"], - "type": "dit", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "cond_token_dim": 768, - "input_concat_dim": 64, - "project_cond_tokens": false, - "transformer_type": "continuous_transformer" - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "learning_rate": 5e-4, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2, - "demo_steps": 100 - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/txt2audio/stable_audio_1_0.json b/think_sound/configs/model_configs/txt2audio/stable_audio_1_0.json deleted file mode 100644 index 22db891d8529f894a26a0c7f7d173ef2ae84b744..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/txt2audio/stable_audio_1_0.json +++ /dev/null @@ -1,107 +0,0 @@ -{ - "model_type": "diffusion_cond", - "sample_size": 4194304, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "dac", - "config": { - "in_channels": 2, - "latent_dim": 128, - "d_model": 128, - "strides": [4, 4, 8, 8] - } - }, - "decoder": { - "type": "dac", - "config": { - "out_channels": 2, - "latent_dim": 64, - "channels": 1536, - "rates": [8, 8, 4, 4] - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 1024, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "prompt", - "type": "clap_text", - "config": { - "audio_model_type": "HTSAT-base", - "enable_fusion": true, - "clap_ckpt_path": "/path/to/clap.ckpt", - "use_text_features": true, - "feature_layer_ix": -2 - } - }, - { - "id": "seconds_start", - "type": "int", - "config": { - "min_val": 0, - "max_val": 512 - } - }, - { - "id": "seconds_total", - "type": "int", - "config": { - "min_val": 0, - "max_val": 512 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "type": "adp_cfg_1d", - "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"], - "config": { - "in_channels": 64, - "context_embedding_features": 768, - "context_embedding_max_length": 79, - "channels": 256, - "resnet_groups": 16, - "kernel_multiplier_downsample": 2, - "multipliers": [4, 4, 4, 5, 5], - "factors": [1, 2, 2, 4], - "num_blocks": [2, 2, 2, 2], - "attentions": [1, 3, 3, 3, 3], - "attention_heads": 16, - "attention_multiplier": 4, - "use_nearest_upsample": false, - "use_skip_scale": true, - "use_context_time": true - } - }, - "io_channels": 64 - }, - "training": { - "learning_rate": 4e-5, - "demo": { - "demo_every": 2000, - "demo_steps": 250, - "num_demos": 4, - "demo_cond": [ - {"prompt": "A beautiful piano arpeggio", "seconds_start": 0, "seconds_total": 95}, - {"prompt": "A tropical house track with upbeat melodies, a driving bassline, and cheery vibes", "seconds_start": 0, "seconds_total": 90}, - {"prompt": "A cool 80s glam rock song with driving drums and distorted guitars", "seconds_start": 0, "seconds_total": 180}, - {"prompt": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle.", "seconds_start": 0, "seconds_total": 60} - ], - "demo_cfg_scales": [3, 6, 9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/txt2audio/stable_audio_2_0.json b/think_sound/configs/model_configs/txt2audio/stable_audio_2_0.json deleted file mode 100644 index 933e24307d322bcc2d5aa9a5d41c308706849791..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/txt2audio/stable_audio_2_0.json +++ /dev/null @@ -1,124 +0,0 @@ -{ - "model_type": "diffusion_cond", - "sample_size": 362496, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "prompt", - "type": "clap_text", - "config": { - "audio_model_type": "HTSAT-base", - "enable_fusion": true, - "clap_ckpt_path": "useful_ckpts/clap-htsat-fused/pytorch_model.bin", - "use_text_features": true, - "feature_layer_ix": -2 - } - }, - { - "id": "seconds_start", - "type": "number", - "config": { - "min_val": 0, - "max_val": 512 - } - }, - { - "id": "seconds_total", - "type": "number", - "config": { - "min_val": 0, - "max_val": 512 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"], - "global_cond_ids": ["seconds_start", "seconds_total"], - "type": "dit", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "cond_token_dim": 768, - "global_cond_dim": 1536, - "project_cond_tokens": false, - "transformer_type": "continuous_transformer" - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 1, - "demo_steps": 100, - "num_demos": 1, - "demo_cond": [ - {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} - ], - "demo_cfg_scales": [7] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit.json deleted file mode 100644 index 6b18684a37f77e9193165ff24fe63dd9194783ef..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit.json +++ /dev/null @@ -1,137 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow.json deleted file mode 100644 index 4f9b8704492402b92b4eb2a6c0929755b2c4a8e2..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow.json +++ /dev/null @@ -1,138 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "timestep_sampler": "uniform", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": false - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-3, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 100000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit.json deleted file mode 100644 index e724a8533f25332324a4a87b5715ec0ea7df3a11..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit.json +++ /dev/null @@ -1,138 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_addvideo.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_addvideo.json deleted file mode 100644 index 34676ef598ac32a83269097c0404fe389aaade52..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_addvideo.json +++ /dev/null @@ -1,140 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "cross_attend": false, - "add_video": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_cross.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_cross.json deleted file mode 100644 index a6b208b02695845e892920d29c4c450f0b574dc4..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_cross.json +++ /dev/null @@ -1,139 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "cross_attend": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_cross_gated.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_cross_gated.json deleted file mode 100644 index 3f0b545d3a7e906b68491e8e82a8a096e6ecd116..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_cross_gated.json +++ /dev/null @@ -1,141 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "cross_attend": true, - "add_video": true, - "gated_video": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_text_latents/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_text_latents/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_text_latents/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_text_latents/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_text_latents/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_text_latents/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_text_latents/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_text_latents/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_text_latents/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_text_latents/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_gatedvideo.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_gatedvideo.json deleted file mode 100644 index 32cae7b5fd5c922af7e35fe0c85a197499654215..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_gatedvideo.json +++ /dev/null @@ -1,141 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "cross_attend": false, - "add_video": true, - "gated_video": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_inpaint.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_inpaint.json deleted file mode 100644 index 687836326151c5a23a57b00e2bb7619e383db34c..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_inpaint.json +++ /dev/null @@ -1,140 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "use_inpaint": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "max_mask_segments": 10, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_text_latents/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_text_latents/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_text_latents/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_text_latents/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_text_latents/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_text_latents/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_text_latents/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_text_latents/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_text_latents/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_text_latents/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5.json deleted file mode 100644 index 46e5073708d33ee804d27c57aac7ee56e8580c3f..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5.json +++ /dev/null @@ -1,146 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_cross.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_cross.json deleted file mode 100644 index 87d5edb0be880070e3875ae70365c65612b47c55..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_cross.json +++ /dev/null @@ -1,148 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "add_video": true, - "cross_attend": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 4000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated.json deleted file mode 100644 index b40d5a04c5e38ae9ef734ada069512df3ad887f7..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated.json +++ /dev/null @@ -1,148 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "add_video": true, - "gated_video": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 4000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross.json deleted file mode 100644 index c1e1c1dd37c57f47b0b07d1ce5d0b38d527d649f..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross.json +++ /dev/null @@ -1,149 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "add_video": true, - "gated_video": true, - "cross_attend": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 100000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross_lr1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross_lr1.json deleted file mode 100644 index 3b1e8c61ec9dde8d78998865f84a94a4b71fed39..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross_lr1.json +++ /dev/null @@ -1,150 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "add_video": true, - "gated_video": true, - "cross_attend": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.95], - "weight_decay": 1e-6, - "fused": true, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross_lr2.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross_lr2.json deleted file mode 100644 index 3db01645031f702b73d2b8cc2a218bf2142128e9..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross_lr2.json +++ /dev/null @@ -1,150 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "add_video": true, - "gated_video": true, - "cross_attend": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.95], - "weight_decay": 1e-6, - "fused": true, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 100000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_lr1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_lr1.json deleted file mode 100644 index 86f0c2dd1c52132155801da4d2b7c7a86bccb8cc..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_lr1.json +++ /dev/null @@ -1,149 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "add_video": true, - "gated_video": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.95], - "weight_decay": 1e-6, - "fused": true, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global.json deleted file mode 100644 index d001485e091baf2bed5222c3e8b81af8646f9b94..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global.json +++ /dev/null @@ -1,154 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_global_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features", "metaclip_global_text_features", "t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_latents_t5_clip/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_latents_t5_clip/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_latents_t5_clip/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_latents_t5_clip/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_latents_t5_clip/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_latents_t5_clip/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_latents_t5_clip/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_latents_t5_clip/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_latents_t5_clip/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global_gated.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global_gated.json deleted file mode 100644 index 133793bfac6acf03e467f0f21097a3439f4e2bbd..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global_gated.json +++ /dev/null @@ -1,156 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_global_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features", "metaclip_global_text_features", "t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "add_video": true, - "gated_video": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global_gated_cross.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global_gated_cross.json deleted file mode 100644 index b713ba4bb782f451df268db799e8bc425beaf503..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global_gated_cross.json +++ /dev/null @@ -1,157 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_global_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features", "metaclip_global_text_features", "t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "add_video": true, - "gated_video": true, - "cross_attend": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_latents_t5_clip/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_latents_t5_clip/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_latents_t5_clip/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_latents_t5_clip/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_latents_t5_clip/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_latents_t5_clip/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_latents_t5_clip/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_latents_t5_clip/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_latents_t5_clip/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_inpaint.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_inpaint.json deleted file mode 100644 index 03f4b9e9fdff78c510d41b7d6d8d6b9a530fb620..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_inpaint.json +++ /dev/null @@ -1,148 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "use_inpaint": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "max_mask_segments": 10, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size1.json deleted file mode 100644 index beb2ee1864048b10f28fa43a20a9dbead84f644e..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size1.json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "kernel_size": 1 - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_medium.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_medium.json deleted file mode 100644 index 5de7b392fc926c1b47367536c5f2602a7ad85b4c..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_medium.json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":768 , - "depth":21, - "fused_depth":14, - "num_heads":12, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "kernel_size": 3 - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_small.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_small.json deleted file mode 100644 index ef13d4538940a0c37f35e57c897a2dbda4700863..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_small.json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":768 , - "depth":18, - "fused_depth":12, - "num_heads":12, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "kernel_size": 3 - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_sync_kernel3.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_sync_kernel3.json deleted file mode 100644 index 7085d6f0e3cef6b7bb95cf166269e02c89d380b2..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_sync_kernel3.json +++ /dev/null @@ -1,148 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "kernel_size": 3, - "sync_kernel": 3 - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_sync_kernel5.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_sync_kernel5.json deleted file mode 100644 index daf6e561b1cd4fc3d55f120afd24858cdb477b7d..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_sync_kernel5.json +++ /dev/null @@ -1,148 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "kernel_size": 3, - "sync_kernel": 5 - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_uniform.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_uniform.json deleted file mode 100644 index cc9027f71a4a3b3fea1ff44ac216f4acfbbb58a4..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_uniform.json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "kernel_size": 3 - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "uniform", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size5.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size5.json deleted file mode 100644 index 0f862850f543c3fc4c722770d86cec304f53be08..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size5.json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "kernel_size": 5 - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr1.json deleted file mode 100644 index 2c5e14bd5bd36ca93e37f105b293a2768c63a424..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr1.json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.95], - "weight_decay": 1e-6, - "fused": true, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr1_drop0_1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr1_drop0_1.json deleted file mode 100644 index d464136ceab1c1e36980f6bdd7dd2e3e1417553c..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr1_drop0_1.json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.1, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.95], - "weight_decay": 1e-6, - "fused": true, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr2.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr2.json deleted file mode 100644 index 460ac2ccae0d45e1d45c12e149cea199396a9384..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr2.json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.95], - "weight_decay": 1e-6, - "fused": true, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 100000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_mlp.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_mlp.json deleted file mode 100644 index 7ea85ea36d77d13028cb3ddf90a23d6d692a7fa7..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_mlp.json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - }, - { - "id": "t5_features", - "type": "mm_unchang", - "config": { - "dim": 2048, - "output_dim": 2048 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":2048, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "use_mlp": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_t5_clip/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_latents_t5_clip/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_latents_t5_clip/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_latents_t5_clip/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_latents_t5_clip/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_latents_t5_clip/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_latents_t5_clip/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_latents_t5_clip/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_latents_t5_clip/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_latents_t5_clip/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_triplegated.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_triplegated.json deleted file mode 100644 index ef2c134f7c3c8f985a2d213dc2e477f9e2ea2ac2..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_triplegated.json +++ /dev/null @@ -1,142 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true, - "cross_attend": false, - "add_video": true, - "gated_video": false, - "triple_fusion": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_text_latents/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_text_latents/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_text_latents/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_text_latents/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_text_latents/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_text_latents/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_text_latents/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_text_latents/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_text_latents/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_text_latents/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_large.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_large.json deleted file mode 100644 index 2e0ae59455750209382b915acbc1b9d2df2e8914..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_large.json +++ /dev/null @@ -1,137 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1280 , - "depth":24, - "fused_depth":16, - "num_heads":20, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr1.json deleted file mode 100644 index 7049af0756c0e46125ac98fb8ce71c6b0a883e41..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr1.json +++ /dev/null @@ -1,137 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-3, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 100000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr2.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr2.json deleted file mode 100644 index 964e276c547328461bd9e5b6f2d3afcf97406c6f..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr2.json +++ /dev/null @@ -1,137 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.95], - "weight_decay": 1e-3, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 100000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr3.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr3.json deleted file mode 100644 index 902ee308129b20354af81f4ea68930d4560e0d4d..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr3.json +++ /dev/null @@ -1,137 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr4.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr4.json deleted file mode 100644 index a989ec63b012163f887c2d152199d6b54a782837..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr4.json +++ /dev/null @@ -1,138 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.95], - "weight_decay": 1e-6, - "fused": true, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_old.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_old.json deleted file mode 100644 index 3a464c0c08b2dd5814b9ef7cc4b187ad3ad37a9c..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_old.json +++ /dev/null @@ -1,138 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":896 , - "depth":21, - "fused_depth":14, - "num_heads":14, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.95], - "weight_decay": 1e-3, - "fused": true, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 100000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_oldlr.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_oldlr.json deleted file mode 100644 index a989ec63b012163f887c2d152199d6b54a782837..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_oldlr.json +++ /dev/null @@ -1,138 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.95], - "weight_decay": 1e-6, - "fused": true, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_v.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_v.json deleted file mode 100644 index c5743ed295023a20afe335ffd0748a05ae53fd05..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_v.json +++ /dev/null @@ -1,137 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "v", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": false - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-3, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 100000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_v1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_v1.json deleted file mode 100644 index a3b5400c23cf62568f84294fba5427f7caf19720..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_v1.json +++ /dev/null @@ -1,138 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": false - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_vae.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_vae.json deleted file mode 100644 index cf32673a444143f40f2fe21e890eb317f01ad544..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_vae.json +++ /dev/null @@ -1,138 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": false, - "timestep_sampler": "logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-4, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0Cu33yBwAPg_000060.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/bmKtI808DsU_000009.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/VC0c22cJTbM_000424.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/F3gsbUTdc2U_000090.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/WatvT8A8iug_000100.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0nvBTp-q7tU_000112.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/3-PFuDkTM48_000080.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/luSAuu-BoPs_000232.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/__8UJxW0aOQ_000002.npz", - "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_weight.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_weight.json deleted file mode 100644 index 5a3458f86a3cd4bd5103e74bc03233b553416c2a..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_weight.json +++ /dev/null @@ -1,137 +0,0 @@ -{ - "model_type": "mm_diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "metaclip_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "metaclip_text_features", - "type": "mm_unchang", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "mm_unchang", - "config": { - "dim": 768, - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], - "type": "mmdit", - "diffusion_objective": "rectified_flow", - "config": { - "latent_dim":64, - "clip_dim":1024, - "sync_dim":768, - "text_dim":1024, - "hidden_dim":1024 , - "depth":21, - "fused_depth":14, - "num_heads":16, - "latent_seq_len":194, - "clip_seq_len":72, - "sync_seq_len":216, - "v2": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": true, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.95], - "weight_decay": 1e-3, - "eps": 1e-6 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/multimodal_clip.json b/think_sound/configs/model_configs/vt2audio/multimodal_clip.json deleted file mode 100644 index 0e47ef726ebca26a0c31848773e718c515164192..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/multimodal_clip.json +++ /dev/null @@ -1,125 +0,0 @@ -{ - "model_type": "diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "clip_features", - "type": "video_linear", - "config": { - "dim": 1024, - "output_dim": 1536 - } - }, - { - "id": "caption_t5", - "type": "t5", - "config": { - "t5_model_name": "t5-v1_1-xl", - "output_dim": 1536 - } - } - ], - "cond_dim": 1536 - }, - "diffusion": { - "add_cond_ids": ["clip_features"], - "cross_attention_cond_ids": ["caption_t5"], - "type": "dit", - "diffusion_objective": "rectified_flow", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "cond_token_dim": 1536, - "global_cond_dim": 1536, - "project_cond_tokens": false, - "transformer_type": "continuous_transformer" - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": false, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 64, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/multimodal_clip_text_clip.json b/think_sound/configs/model_configs/vt2audio/multimodal_clip_text_clip.json deleted file mode 100644 index 3502aa0b9145ac5886742d3a6c96727eef69e209..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/multimodal_clip_text_clip.json +++ /dev/null @@ -1,124 +0,0 @@ -{ - "model_type": "diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "clip_features", - "type": "video_linear", - "config": { - "dim": 1024, - "output_dim": 1536 - } - }, - { - "id": "caption", - "type": "clip_text", - "config": { - "output_dim": 768 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "add_cond_ids": ["clip_features"], - "cross_attention_cond_ids": ["caption"], - "type": "dit", - "diffusion_objective": "rectified_flow", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "cond_token_dim": 768, - "global_cond_dim": 1536, - "project_cond_tokens": false, - "transformer_type": "continuous_transformer" - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": false, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 64, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/multimodal_clip_text_metaclip.json b/think_sound/configs/model_configs/vt2audio/multimodal_clip_text_metaclip.json deleted file mode 100644 index d9cdd9376a75d5bbcade31db3aed8f255c7c73da..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/multimodal_clip_text_metaclip.json +++ /dev/null @@ -1,141 +0,0 @@ -{ - "model_type": "diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "clip_features", - "type": "video_linear", - "config": { - "dim": 1024, - "output_dim": 1536 - } - }, - { - "id": "caption", - "type": "metaclip_text", - "config": { - "output_dim": 768 - } - }, - { - "id": "seconds_start", - "type": "number", - "config": { - "min_val": 0, - "max_val": 512 - } - }, - { - "id": "seconds_total", - "type": "number", - "config": { - "min_val": 0, - "max_val": 512 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "add_cond_ids": ["clip_features"], - "cross_attention_cond_ids": ["caption", "seconds_start", "seconds_total"], - "global_cond_ids": ["seconds_start", "seconds_total"], - "type": "dit", - "diffusion_objective": "rectified_flow", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "cond_token_dim": 768, - "global_cond_dim": 1536, - "project_cond_tokens": false, - "transformer_type": "continuous_transformer" - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "pre_encoded": false, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 64, - "num_demos": 10, - "demo_cond": [ - "dataset/vggsound/video_latents_text/test/0Cu33yBwAPg_000060.pth", - "dataset/vggsound/video_latents_text/test/bmKtI808DsU_000009.pth", - "dataset/vggsound/video_latents_text/test/VC0c22cJTbM_000424.pth", - "dataset/vggsound/video_latents_text/test/F3gsbUTdc2U_000090.pth", - "dataset/vggsound/video_latents_text/test/WatvT8A8iug_000100.pth", - "dataset/vggsound/video_latents_text/test/0nvBTp-q7tU_000112.pth", - "dataset/vggsound/video_latents_text/test/3-PFuDkTM48_000080.pth", - "dataset/vggsound/video_latents_text/test/luSAuu-BoPs_000232.pth", - "dataset/vggsound/video_latents_text/test/__8UJxW0aOQ_000002.pth", - "dataset/vggsound/video_latents_text/test/_0m_YMpQayA_000168.pth" - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/vt2a_stable_clip_load.json b/think_sound/configs/model_configs/vt2audio/vt2a_stable_clip_load.json deleted file mode 100644 index 2015675c4576bc88310e9f0f8b2a941584e7c8a7..0000000000000000000000000000000000000000 --- a/think_sound/configs/model_configs/vt2audio/vt2a_stable_clip_load.json +++ /dev/null @@ -1,139 +0,0 @@ -{ - "model_type": "diffusion_cond", - "sample_size": 441000, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "video", - "type": "video_linear", - "config": { - "dim": 1024, - "output_dim": 1536 - } - }, - { - "id": "prompt", - "type": "t5", - "config": { - "t5_model_name": "t5-v1_1-xl", - "output_dim": 768 - } - }, - { - "id": "seconds_start", - "type": "number", - "config": { - "min_val": 0, - "max_val": 512 - } - }, - { - "id": "seconds_total", - "type": "number", - "config": { - "min_val": 0, - "max_val": 512 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"], - "global_cond_ids": ["seconds_start", "seconds_total"], - "add_cond_ids": ["video"], - "type": "dit", - "diffusion_objective": "rectified_flow", - "config": { - "io_channels": 64, - "embed_dim": 1536, - "depth": 24, - "num_heads": 24, - "cond_token_dim": 768, - "global_cond_dim": 1536, - "project_cond_tokens": false, - "transformer_type": "continuous_transformer" - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.2, - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 5e-5, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 1000000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 2000, - "demo_steps": 64, - "num_demos": 8, - "demo_cond": [ - {"video": "data/VGGSOUND/MetaClip-Huge/test/0Cu33yBwAPg_000060.npy", "prompt": "church bell ringing", "seconds_start": 0, "seconds_total": 10}, - {"video": "data/VGGSOUND/MetaClip-Huge/test/bmKtI808DsU_000009.npy", "prompt": "lions growling", "seconds_start": 0, "seconds_total": 10}, - {"video": "data/VGGSOUND/MetaClip-Huge/test/VC0c22cJTbM_000424.npy", "prompt": "writing on blackboard with chalk", "seconds_start": 0, "seconds_total": 10}, - {"video": "data/VGGSOUND/MetaClip-Huge/test/F3gsbUTdc2U_000090.npy", "prompt": "wind chime", "seconds_start": 0, "seconds_total": 10}, - {"video": "data/VGGSOUND/MetaClip-Huge/test/WatvT8A8iug_000100.npy", "prompt": "car passing by", "seconds_start": 0, "seconds_total": 10}, - {"video": "data/VGGSOUND/MetaClip-Huge/test/0nvBTp-q7tU_000112.npy", "prompt": "driving snowmobile", "seconds_start": 0, "seconds_total": 10}, - {"video": "data/VGGSOUND/MetaClip-Huge/test/3-PFuDkTM48_000080.npy", "prompt": "playing accordion", "seconds_start": 0, "seconds_total": 10}, - {"video": "data/VGGSOUND/MetaClip-Huge/test/luSAuu-BoPs_000232.npy", "prompt": "squishing water", "seconds_start": 0, "seconds_total": 10} - ], - "demo_cfg_scales": [3,6,9] - } - } -} \ No newline at end of file diff --git a/think_sound/data/__pycache__/__init__.cpython-310.pyc b/think_sound/data/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index ef50ce6b886a66ebe29fa889b95c859c2155374f..0000000000000000000000000000000000000000 Binary files a/think_sound/data/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/think_sound/data/__pycache__/datamodule.cpython-310.pyc b/think_sound/data/__pycache__/datamodule.cpython-310.pyc deleted file mode 100644 index dcb3ed109244948edcca0ebfe0c32808ab02e88a..0000000000000000000000000000000000000000 Binary files a/think_sound/data/__pycache__/datamodule.cpython-310.pyc and /dev/null differ diff --git a/think_sound/data/__pycache__/dataset.cpython-310.pyc b/think_sound/data/__pycache__/dataset.cpython-310.pyc deleted file mode 100644 index 070073611ef8a8915f747bd35074d502b5c57fbb..0000000000000000000000000000000000000000 Binary files a/think_sound/data/__pycache__/dataset.cpython-310.pyc and /dev/null differ diff --git a/think_sound/data/__pycache__/utils.cpython-310.pyc b/think_sound/data/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index 037168b106b682650443f9d4399af0ce37889375..0000000000000000000000000000000000000000 Binary files a/think_sound/data/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/think_sound/inference/__init__.py b/think_sound/inference/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/think_sound/inference/__pycache__/__init__.cpython-310.pyc b/think_sound/inference/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 5023ae21e83129fdc7f23e439e637f624e7221a9..0000000000000000000000000000000000000000 Binary files a/think_sound/inference/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/think_sound/inference/__pycache__/__init__.cpython-38.pyc b/think_sound/inference/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 269dca61e7b70ba3213ad89ca48564e2e4b54d11..0000000000000000000000000000000000000000 Binary files a/think_sound/inference/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/think_sound/inference/__pycache__/__init__.cpython-39.pyc b/think_sound/inference/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index ad93a24373bcdf798ecd55b60e833a270e18fb7f..0000000000000000000000000000000000000000 Binary files a/think_sound/inference/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/think_sound/inference/__pycache__/generation.cpython-310.pyc b/think_sound/inference/__pycache__/generation.cpython-310.pyc deleted file mode 100644 index 988d166a673c290635ee8dc5d50368e173a94500..0000000000000000000000000000000000000000 Binary files a/think_sound/inference/__pycache__/generation.cpython-310.pyc and /dev/null differ diff --git a/think_sound/inference/__pycache__/generation.cpython-38.pyc b/think_sound/inference/__pycache__/generation.cpython-38.pyc deleted file mode 100644 index 2e8fa0fd739358e1117f01430fa5a77777b9869f..0000000000000000000000000000000000000000 Binary files a/think_sound/inference/__pycache__/generation.cpython-38.pyc and /dev/null differ diff --git a/think_sound/inference/__pycache__/generation.cpython-39.pyc b/think_sound/inference/__pycache__/generation.cpython-39.pyc deleted file mode 100644 index a8aa87aa381c697f459aff50c5f5fdb4a358e579..0000000000000000000000000000000000000000 Binary files a/think_sound/inference/__pycache__/generation.cpython-39.pyc and /dev/null differ diff --git a/think_sound/inference/__pycache__/sampling.cpython-310.pyc b/think_sound/inference/__pycache__/sampling.cpython-310.pyc deleted file mode 100644 index 3118560f043620b53fedaa9780b95c0a171f9924..0000000000000000000000000000000000000000 Binary files a/think_sound/inference/__pycache__/sampling.cpython-310.pyc and /dev/null differ diff --git a/think_sound/inference/__pycache__/sampling.cpython-38.pyc b/think_sound/inference/__pycache__/sampling.cpython-38.pyc deleted file mode 100644 index 5f97d01246f3cd254535df0889cc5baa10512190..0000000000000000000000000000000000000000 Binary files a/think_sound/inference/__pycache__/sampling.cpython-38.pyc and /dev/null differ diff --git a/think_sound/inference/__pycache__/sampling.cpython-39.pyc b/think_sound/inference/__pycache__/sampling.cpython-39.pyc deleted file mode 100644 index 2bf7c849f2e03b8d63fd51f085a4857040d28f6c..0000000000000000000000000000000000000000 Binary files a/think_sound/inference/__pycache__/sampling.cpython-39.pyc and /dev/null differ diff --git a/think_sound/inference/__pycache__/utils.cpython-310.pyc b/think_sound/inference/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index 9aee8d08753eb0ed4929cb1750a212e4d6a0f53c..0000000000000000000000000000000000000000 Binary files a/think_sound/inference/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/think_sound/inference/__pycache__/utils.cpython-38.pyc b/think_sound/inference/__pycache__/utils.cpython-38.pyc deleted file mode 100644 index c7ae3b2b62d814ed92d69bc8e8269bf6a6ee56fe..0000000000000000000000000000000000000000 Binary files a/think_sound/inference/__pycache__/utils.cpython-38.pyc and /dev/null differ diff --git a/think_sound/inference/__pycache__/utils.cpython-39.pyc b/think_sound/inference/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index 3e84bec0b8fb14faae453194b6c6572f0d45a6c5..0000000000000000000000000000000000000000 Binary files a/think_sound/inference/__pycache__/utils.cpython-39.pyc and /dev/null differ diff --git a/think_sound/interface/__init__.py b/think_sound/interface/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/think_sound/interface/__pycache__/__init__.cpython-38.pyc b/think_sound/interface/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 9585914d732a148c23ea9b8eed93b5b383238bbe..0000000000000000000000000000000000000000 Binary files a/think_sound/interface/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/think_sound/interface/__pycache__/__init__.cpython-39.pyc b/think_sound/interface/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index c791e8cb9797b74d6f44dadc2706615fc07e6d61..0000000000000000000000000000000000000000 Binary files a/think_sound/interface/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/think_sound/interface/__pycache__/gradio.cpython-38.pyc b/think_sound/interface/__pycache__/gradio.cpython-38.pyc deleted file mode 100644 index eae1bc0b0e4fe71946d0042d73b0e4857089f3b1..0000000000000000000000000000000000000000 Binary files a/think_sound/interface/__pycache__/gradio.cpython-38.pyc and /dev/null differ diff --git a/think_sound/interface/__pycache__/gradio.cpython-39.pyc b/think_sound/interface/__pycache__/gradio.cpython-39.pyc deleted file mode 100644 index 8cb07f162f318c799c027a72e8be8a92f940da2a..0000000000000000000000000000000000000000 Binary files a/think_sound/interface/__pycache__/gradio.cpython-39.pyc and /dev/null differ diff --git a/think_sound/interface/gradio.py b/think_sound/interface/gradio.py deleted file mode 100644 index f38468bc34b88ec6bbe5451a8b11b998430888f8..0000000000000000000000000000000000000000 --- a/think_sound/interface/gradio.py +++ /dev/null @@ -1,700 +0,0 @@ -import gc -import platform - -import numpy as np -import gradio as gr -import json -import torch -import torchaudio - -from aeiou.viz import audio_spectrogram_image -from einops import rearrange -from safetensors.torch import load_file -from torch.nn import functional as F -from torchaudio import transforms as T - -from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond -from ..models.factory import create_model_from_config -from ..models.pretrained import get_pretrained_model -from ..models.utils import load_ckpt_state_dict -from ..inference.utils import prepare_audio -from ..training.utils import copy_state_dict - -model = None -sample_rate = 32000 -sample_size = 1920000 - -def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False): - global model, sample_rate, sample_size - - if pretrained_name is not None: - print(f"Loading pretrained model {pretrained_name}") - model, model_config = get_pretrained_model(pretrained_name) - - elif model_config is not None and model_ckpt_path is not None: - print(f"Creating model from config") - model = create_model_from_config(model_config) - - print(f"Loading model checkpoint from {model_ckpt_path}") - # Load checkpoint - copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path)) - #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) - - sample_rate = model_config["sample_rate"] - sample_size = model_config["sample_size"] - - if pretransform_ckpt_path is not None: - print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}") - model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False) - print(f"Done loading pretransform") - - model.to(device).eval().requires_grad_(False) - - if model_half: - model.to(torch.float16) - - print(f"Done loading model") - - return model, model_config - -def generate_cond( - prompt, - negative_prompt=None, - seconds_start=0, - seconds_total=30, - cfg_scale=6.0, - steps=250, - preview_every=None, - seed=-1, - sampler_type="dpmpp-3m-sde", - sigma_min=0.03, - sigma_max=1000, - cfg_rescale=0.0, - use_init=False, - init_audio=None, - init_noise_level=1.0, - mask_cropfrom=None, - mask_pastefrom=None, - mask_pasteto=None, - mask_maskstart=None, - mask_maskend=None, - mask_softnessL=None, - mask_softnessR=None, - mask_marination=None, - batch_size=1 - ): - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - print(f"Prompt: {prompt}") - - global preview_images - preview_images = [] - if preview_every == 0: - preview_every = None - - # Return fake stereo audio - conditioning = [{"prompt": prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size - - if negative_prompt: - negative_conditioning = [{"prompt": negative_prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size - else: - negative_conditioning = None - - #Get the device from the model - device = next(model.parameters()).device - - seed = int(seed) - - if not use_init: - init_audio = None - - input_sample_size = sample_size - - if init_audio is not None: - in_sr, init_audio = init_audio - # Turn into torch tensor, converting from int16 to float32 - init_audio = torch.from_numpy(init_audio).float().div(32767) - - if init_audio.dim() == 1: - init_audio = init_audio.unsqueeze(0) # [1, n] - elif init_audio.dim() == 2: - init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n] - - if in_sr != sample_rate: - resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device) - init_audio = resample_tf(init_audio) - - audio_length = init_audio.shape[-1] - - if audio_length > sample_size: - - input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length - - init_audio = (sample_rate, init_audio) - - def progress_callback(callback_info): - global preview_images - denoised = callback_info["denoised"] - current_step = callback_info["i"] - sigma = callback_info["sigma"] - - if (current_step - 1) % preview_every == 0: - if model.pretransform is not None: - denoised = model.pretransform.decode(denoised) - denoised = rearrange(denoised, "b d n -> d (b n)") - denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu() - audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate) - preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})")) - - # If inpainting, send mask args - # This will definitely change in the future - if mask_cropfrom is not None: - mask_args = { - "cropfrom": mask_cropfrom, - "pastefrom": mask_pastefrom, - "pasteto": mask_pasteto, - "maskstart": mask_maskstart, - "maskend": mask_maskend, - "softnessL": mask_softnessL, - "softnessR": mask_softnessR, - "marination": mask_marination, - } - else: - mask_args = None - - # Do the audio generation - audio = generate_diffusion_cond( - model, - conditioning=conditioning, - negative_conditioning=negative_conditioning, - steps=steps, - cfg_scale=cfg_scale, - batch_size=batch_size, - sample_size=input_sample_size, - sample_rate=sample_rate, - seed=seed, - device=device, - sampler_type=sampler_type, - sigma_min=sigma_min, - sigma_max=sigma_max, - init_audio=init_audio, - init_noise_level=init_noise_level, - mask_args = mask_args, - callback = progress_callback if preview_every is not None else None, - scale_phi = cfg_rescale - ) - - # Convert to WAV file - audio = rearrange(audio, "b d n -> d (b n)") - audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() - torchaudio.save("output.wav", audio, sample_rate) - - # Let's look at a nice spectrogram too - audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) - - return ("output.wav", [audio_spectrogram, *preview_images]) - -def generate_uncond( - steps=250, - seed=-1, - sampler_type="dpmpp-3m-sde", - sigma_min=0.03, - sigma_max=1000, - use_init=False, - init_audio=None, - init_noise_level=1.0, - batch_size=1, - preview_every=None - ): - - global preview_images - - preview_images = [] - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - #Get the device from the model - device = next(model.parameters()).device - - seed = int(seed) - - if not use_init: - init_audio = None - - input_sample_size = sample_size - - if init_audio is not None: - in_sr, init_audio = init_audio - # Turn into torch tensor, converting from int16 to float32 - init_audio = torch.from_numpy(init_audio).float().div(32767) - - if init_audio.dim() == 1: - init_audio = init_audio.unsqueeze(0) # [1, n] - elif init_audio.dim() == 2: - init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n] - - if in_sr != sample_rate: - resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device) - init_audio = resample_tf(init_audio) - - audio_length = init_audio.shape[-1] - - if audio_length > sample_size: - - input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length - - init_audio = (sample_rate, init_audio) - - def progress_callback(callback_info): - global preview_images - denoised = callback_info["denoised"] - current_step = callback_info["i"] - sigma = callback_info["sigma"] - - if (current_step - 1) % preview_every == 0: - - if model.pretransform is not None: - denoised = model.pretransform.decode(denoised) - - denoised = rearrange(denoised, "b d n -> d (b n)") - - denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu() - - audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate) - - preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})")) - - audio = generate_diffusion_uncond( - model, - steps=steps, - batch_size=batch_size, - sample_size=input_sample_size, - seed=seed, - device=device, - sampler_type=sampler_type, - sigma_min=sigma_min, - sigma_max=sigma_max, - init_audio=init_audio, - init_noise_level=init_noise_level, - callback = progress_callback if preview_every is not None else None - ) - - audio = rearrange(audio, "b d n -> d (b n)") - - audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() - - torchaudio.save("output.wav", audio, sample_rate) - - audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) - - return ("output.wav", [audio_spectrogram, *preview_images]) - -def generate_lm( - temperature=1.0, - top_p=0.95, - top_k=0, - batch_size=1, - ): - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - #Get the device from the model - device = next(model.parameters()).device - - audio = model.generate_audio( - batch_size=batch_size, - max_gen_len = sample_size//model.pretransform.downsampling_ratio, - conditioning=None, - temp=temperature, - top_p=top_p, - top_k=top_k, - use_cache=True - ) - - audio = rearrange(audio, "b d n -> d (b n)") - - audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() - - torchaudio.save("output.wav", audio, sample_rate) - - audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) - - return ("output.wav", [audio_spectrogram]) - - -def create_uncond_sampling_ui(model_config): - generate_button = gr.Button("Generate", variant='primary', scale=1) - - with gr.Row(equal_height=False): - with gr.Column(): - with gr.Row(): - # Steps slider - steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") - - with gr.Accordion("Sampler params", open=False): - - # Seed - seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1") - - # Sampler params - with gr.Row(): - sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde") - sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") - sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max") - - with gr.Accordion("Init audio", open=False): - init_audio_checkbox = gr.Checkbox(label="Use init audio") - init_audio_input = gr.Audio(label="Init audio") - init_noise_level_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.01, value=0.1, label="Init noise level") - - with gr.Column(): - audio_output = gr.Audio(label="Output audio", interactive=False) - audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) - send_to_init_button = gr.Button("Send to init audio", scale=1) - send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input]) - - generate_button.click(fn=generate_uncond, - inputs=[ - steps_slider, - seed_textbox, - sampler_type_dropdown, - sigma_min_slider, - sigma_max_slider, - init_audio_checkbox, - init_audio_input, - init_noise_level_slider, - ], - outputs=[ - audio_output, - audio_spectrogram_output - ], - api_name="generate") - -def create_sampling_ui(model_config, inpainting=False): - with gr.Row(): - with gr.Column(scale=6): - prompt = gr.Textbox(show_label=False, placeholder="Prompt") - negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt") - generate_button = gr.Button("Generate", variant='primary', scale=1) - - model_conditioning_config = model_config["model"].get("conditioning", None) - - has_seconds_start = False - has_seconds_total = False - - if model_conditioning_config is not None: - for conditioning_config in model_conditioning_config["configs"]: - if conditioning_config["id"] == "seconds_start": - has_seconds_start = True - if conditioning_config["id"] == "seconds_total": - has_seconds_total = True - - with gr.Row(equal_height=False): - with gr.Column(): - with gr.Row(visible = has_seconds_start or has_seconds_total): - # Timing controls - seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Seconds start", visible=has_seconds_start) - seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total) - - with gr.Row(): - # Steps slider - steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") - - # Preview Every slider - preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every") - - # CFG scale - cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG scale") - - with gr.Accordion("Sampler params", open=False): - - # Seed - seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1") - - # Sampler params - with gr.Row(): - sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde") - sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") - sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max") - cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG rescale amount") - - if inpainting: - # Inpainting Tab - with gr.Accordion("Inpainting", open=False): - sigma_max_slider.maximum=1000 - - init_audio_checkbox = gr.Checkbox(label="Do inpainting") - init_audio_input = gr.Audio(label="Init audio") - init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.1, value=80, label="Init audio noise level", visible=False) # hide this - - mask_cropfrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Crop From %") - mask_pastefrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Paste From %") - mask_pasteto_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Paste To %") - - mask_maskstart_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=50, label="Mask Start %") - mask_maskend_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Mask End %") - mask_softnessL_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Left Crossfade Length %") - mask_softnessR_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Right Crossfade Length %") - mask_marination_slider = gr.Slider(minimum=0.0, maximum=1, step=0.0001, value=0, label="Marination level", visible=False) # still working on the usefulness of this - - inputs = [prompt, - negative_prompt, - seconds_start_slider, - seconds_total_slider, - cfg_scale_slider, - steps_slider, - preview_every_slider, - seed_textbox, - sampler_type_dropdown, - sigma_min_slider, - sigma_max_slider, - cfg_rescale_slider, - init_audio_checkbox, - init_audio_input, - init_noise_level_slider, - mask_cropfrom_slider, - mask_pastefrom_slider, - mask_pasteto_slider, - mask_maskstart_slider, - mask_maskend_slider, - mask_softnessL_slider, - mask_softnessR_slider, - mask_marination_slider - ] - else: - # Default generation tab - with gr.Accordion("Init audio", open=False): - init_audio_checkbox = gr.Checkbox(label="Use init audio") - init_audio_input = gr.Audio(label="Init audio") - init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init noise level") - - inputs = [prompt, - negative_prompt, - seconds_start_slider, - seconds_total_slider, - cfg_scale_slider, - steps_slider, - preview_every_slider, - seed_textbox, - sampler_type_dropdown, - sigma_min_slider, - sigma_max_slider, - cfg_rescale_slider, - init_audio_checkbox, - init_audio_input, - init_noise_level_slider - ] - - with gr.Column(): - audio_output = gr.Audio(label="Output audio", interactive=False) - audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) - send_to_init_button = gr.Button("Send to init audio", scale=1) - send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input]) - - generate_button.click(fn=generate_cond, - inputs=inputs, - outputs=[ - audio_output, - audio_spectrogram_output - ], - api_name="generate") - - -def create_txt2audio_ui(model_config): - with gr.Blocks() as ui: - with gr.Tab("Generation"): - create_sampling_ui(model_config) - with gr.Tab("Inpainting"): - create_sampling_ui(model_config, inpainting=True) - return ui - -def create_diffusion_uncond_ui(model_config): - with gr.Blocks() as ui: - create_uncond_sampling_ui(model_config) - - return ui - -def autoencoder_process(audio, latent_noise, n_quantizers): - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - #Get the device from the model - device = next(model.parameters()).device - - in_sr, audio = audio - - audio = torch.from_numpy(audio).float().div(32767).to(device) - - if audio.dim() == 1: - audio = audio.unsqueeze(0) - else: - audio = audio.transpose(0, 1) - - audio = model.preprocess_audio_for_encoder(audio, in_sr) - # Note: If you need to do chunked encoding, to reduce VRAM, - # then add these arguments to encode_audio and decode_audio: chunked=True, overlap=32, chunk_size=128 - # To turn it off, do chunked=False - # Optimal overlap and chunk_size values will depend on the model. - # See encode_audio & decode_audio in autoencoders.py for more info - # Get dtype of model - dtype = next(model.parameters()).dtype - - audio = audio.to(dtype) - - if n_quantizers > 0: - latents = model.encode_audio(audio, chunked=False, n_quantizers=n_quantizers) - else: - latents = model.encode_audio(audio, chunked=False) - - if latent_noise > 0: - latents = latents + torch.randn_like(latents) * latent_noise - - audio = model.decode_audio(latents, chunked=False) - - audio = rearrange(audio, "b d n -> d (b n)") - - audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() - - torchaudio.save("output.wav", audio, sample_rate) - - return "output.wav" - -def create_autoencoder_ui(model_config): - - is_dac_rvq = "model" in model_config and "bottleneck" in model_config["model"] and model_config["model"]["bottleneck"]["type"] in ["dac_rvq","dac_rvq_vae"] - - if is_dac_rvq: - n_quantizers = model_config["model"]["bottleneck"]["config"]["n_codebooks"] - else: - n_quantizers = 0 - - with gr.Blocks() as ui: - input_audio = gr.Audio(label="Input audio") - output_audio = gr.Audio(label="Output audio", interactive=False) - n_quantizers_slider = gr.Slider(minimum=1, maximum=n_quantizers, step=1, value=n_quantizers, label="# quantizers", visible=is_dac_rvq) - latent_noise_slider = gr.Slider(minimum=0.0, maximum=10.0, step=0.001, value=0.0, label="Add latent noise") - process_button = gr.Button("Process", variant='primary', scale=1) - process_button.click(fn=autoencoder_process, inputs=[input_audio, latent_noise_slider, n_quantizers_slider], outputs=output_audio, api_name="process") - - return ui - -def diffusion_prior_process(audio, steps, sampler_type, sigma_min, sigma_max): - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - #Get the device from the model - device = next(model.parameters()).device - - in_sr, audio = audio - - audio = torch.from_numpy(audio).float().div(32767).to(device) - - if audio.dim() == 1: - audio = audio.unsqueeze(0) # [1, n] - elif audio.dim() == 2: - audio = audio.transpose(0, 1) # [n, 2] -> [2, n] - - audio = audio.unsqueeze(0) - - audio = model.stereoize(audio, in_sr, steps, sampler_kwargs={"sampler_type": sampler_type, "sigma_min": sigma_min, "sigma_max": sigma_max}) - - audio = rearrange(audio, "b d n -> d (b n)") - - audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() - - torchaudio.save("output.wav", audio, sample_rate) - - return "output.wav" - -def create_diffusion_prior_ui(model_config): - with gr.Blocks() as ui: - input_audio = gr.Audio(label="Input audio") - output_audio = gr.Audio(label="Output audio", interactive=False) - # Sampler params - with gr.Row(): - steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") - sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde") - sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") - sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max") - process_button = gr.Button("Process", variant='primary', scale=1) - process_button.click(fn=diffusion_prior_process, inputs=[input_audio, steps_slider, sampler_type_dropdown, sigma_min_slider, sigma_max_slider], outputs=output_audio, api_name="process") - - return ui - -def create_lm_ui(model_config): - with gr.Blocks() as ui: - output_audio = gr.Audio(label="Output audio", interactive=False) - audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) - - # Sampling params - with gr.Row(): - temperature_slider = gr.Slider(minimum=0, maximum=5, step=0.01, value=1.0, label="Temperature") - top_p_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="Top p") - top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Top k") - - generate_button = gr.Button("Generate", variant='primary', scale=1) - generate_button.click( - fn=generate_lm, - inputs=[ - temperature_slider, - top_p_slider, - top_k_slider - ], - outputs=[output_audio, audio_spectrogram_output], - api_name="generate" - ) - - return ui - -def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False): - - assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both" - - if model_config_path is not None: - # Load config from json file - with open(model_config_path) as f: - model_config = json.load(f) - else: - model_config = None - - try: - has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available() - except Exception: - # In case this version of Torch doesn't even have `torch.backends.mps`... - has_mps = False - - if has_mps: - device = torch.device("mps") - elif torch.cuda.is_available(): - device = torch.device("cuda") - else: - device = torch.device("cpu") - - print("Using device:", device) - - _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device) - - model_type = model_config["model_type"] - - if model_type == "diffusion_cond": - ui = create_txt2audio_ui(model_config) - elif model_type == "diffusion_uncond": - ui = create_diffusion_uncond_ui(model_config) - elif model_type == "autoencoder" or model_type == "diffusion_autoencoder": - ui = create_autoencoder_ui(model_config) - elif model_type == "diffusion_prior": - ui = create_diffusion_prior_ui(model_config) - elif model_type == "lm": - ui = create_lm_ui(model_config) - - return ui diff --git a/think_sound/models/__pycache__/__init__.cpython-310.pyc b/think_sound/models/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 616ae52c684375967146708cda52f0a6f0c85f02..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/__init__.cpython-38.pyc b/think_sound/models/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 122d8b0c38a5a0e883935b14e41a460b3fca5ca8..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/__init__.cpython-39.pyc b/think_sound/models/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 86ac628c21b54b80ad6388a7afc2c2df45622c72..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/adp.cpython-310.pyc b/think_sound/models/__pycache__/adp.cpython-310.pyc deleted file mode 100644 index 02fa73364c0c4c262771e0bf1a0d52e19e3c1907..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/adp.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/adp.cpython-39.pyc b/think_sound/models/__pycache__/adp.cpython-39.pyc deleted file mode 100644 index 47dc1f3adf73733fe498818acc8cd912331c5ce3..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/adp.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/autoencoders.cpython-310.pyc b/think_sound/models/__pycache__/autoencoders.cpython-310.pyc deleted file mode 100644 index 219b7c304c4503e3c009a2f26ba19fdf579b7863..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/autoencoders.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/autoencoders.cpython-39.pyc b/think_sound/models/__pycache__/autoencoders.cpython-39.pyc deleted file mode 100644 index bbffc7fd9bf57cd75207ce4a3ba11c5e7734680a..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/autoencoders.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/blocks.cpython-310.pyc b/think_sound/models/__pycache__/blocks.cpython-310.pyc deleted file mode 100644 index 8d214e7a11f490047824643fb61f4e19db5e9f9c..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/blocks.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/blocks.cpython-39.pyc b/think_sound/models/__pycache__/blocks.cpython-39.pyc deleted file mode 100644 index 89e01add4d2c8e00362289d661ebf222a6f23368..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/blocks.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/bottleneck.cpython-310.pyc b/think_sound/models/__pycache__/bottleneck.cpython-310.pyc deleted file mode 100644 index 7dc1f1046ede83141e434151263840cab1e6cdf7..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/bottleneck.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/bottleneck.cpython-39.pyc b/think_sound/models/__pycache__/bottleneck.cpython-39.pyc deleted file mode 100644 index ef6de8baed483afe3835bcfafacb748b5f19d4d2..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/bottleneck.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/conditioners.cpython-310.pyc b/think_sound/models/__pycache__/conditioners.cpython-310.pyc deleted file mode 100644 index 7de25095434277738e4a7258841f51a9b0d9a589..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/conditioners.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/conditioners.cpython-39.pyc b/think_sound/models/__pycache__/conditioners.cpython-39.pyc deleted file mode 100644 index 9caf5becb3b8e2c68f9f16ee5f1c677bfa902539..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/conditioners.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/diffusion.cpython-310.pyc b/think_sound/models/__pycache__/diffusion.cpython-310.pyc deleted file mode 100644 index 16ed10dfb4658fc2b07dd5f174ccdefa1ce3eb1f..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/diffusion.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/diffusion.cpython-39.pyc b/think_sound/models/__pycache__/diffusion.cpython-39.pyc deleted file mode 100644 index 98a49fdad3e599119289e4c31de61d3241649daf..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/diffusion.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/diffusion_prior.cpython-310.pyc b/think_sound/models/__pycache__/diffusion_prior.cpython-310.pyc deleted file mode 100644 index 350be5114f0b6715caaa7d380f13e6613ca510b0..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/diffusion_prior.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/diffusion_prior.cpython-39.pyc b/think_sound/models/__pycache__/diffusion_prior.cpython-39.pyc deleted file mode 100644 index 77de8bff884189b7ed28a05feab048e9a45a0105..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/diffusion_prior.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/discriminators.cpython-310.pyc b/think_sound/models/__pycache__/discriminators.cpython-310.pyc deleted file mode 100644 index 8d922769aa15e38fc6f549952d0f7d1e2e42ff60..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/discriminators.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/discriminators.cpython-39.pyc b/think_sound/models/__pycache__/discriminators.cpython-39.pyc deleted file mode 100644 index 40ffe9b8e621f52e0ffe87a3a4f374a7ccdc3080..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/discriminators.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/dit.cpython-310.pyc b/think_sound/models/__pycache__/dit.cpython-310.pyc deleted file mode 100644 index 36c7f1f80b7d678b8f188219b14ae5b3d39c2a38..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/dit.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/dit.cpython-39.pyc b/think_sound/models/__pycache__/dit.cpython-39.pyc deleted file mode 100644 index 6af671dd9cb9e4b70c16f37344078e9ef68248f1..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/dit.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/factory.cpython-310.pyc b/think_sound/models/__pycache__/factory.cpython-310.pyc deleted file mode 100644 index ac5e8e82742679dcb4462a9cf5853c34ddbaaf21..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/factory.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/factory.cpython-38.pyc b/think_sound/models/__pycache__/factory.cpython-38.pyc deleted file mode 100644 index 772c0458f56f40df0255cc677b39a5d13a63c1ee..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/factory.cpython-38.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/factory.cpython-39.pyc b/think_sound/models/__pycache__/factory.cpython-39.pyc deleted file mode 100644 index 2ceb64b5a60a53e792ddb4c383e705755b8fb061..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/factory.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/mmdit.cpython-310.pyc b/think_sound/models/__pycache__/mmdit.cpython-310.pyc deleted file mode 100644 index 0c15fe5dfcc45d568faa9c708cb0e6a35884013c..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/mmdit.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/mmdit.cpython-39.pyc b/think_sound/models/__pycache__/mmdit.cpython-39.pyc deleted file mode 100644 index ab736975de43034ba2bdcfec33676d349758693e..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/mmdit.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/pretrained.cpython-310.pyc b/think_sound/models/__pycache__/pretrained.cpython-310.pyc deleted file mode 100644 index 835bcd5ebdd0eabe77b77a2acd2224246d665a03..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/pretrained.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/pretrained.cpython-38.pyc b/think_sound/models/__pycache__/pretrained.cpython-38.pyc deleted file mode 100644 index e660e52b6bac33295abcb24fd59fbed479f6d182..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/pretrained.cpython-38.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/pretrained.cpython-39.pyc b/think_sound/models/__pycache__/pretrained.cpython-39.pyc deleted file mode 100644 index 81ee0fd4e357f8c4c605eb2377912c37ddaa1055..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/pretrained.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/pretransforms.cpython-310.pyc b/think_sound/models/__pycache__/pretransforms.cpython-310.pyc deleted file mode 100644 index a46248ae84f97f324fe3e59c03816189e367621f..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/pretransforms.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/pretransforms.cpython-39.pyc b/think_sound/models/__pycache__/pretransforms.cpython-39.pyc deleted file mode 100644 index 13240d647ee8132e85353b99bb32c3cdc5886558..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/pretransforms.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/transformer.cpython-310.pyc b/think_sound/models/__pycache__/transformer.cpython-310.pyc deleted file mode 100644 index f4661416a8bb59f98a72f0baf10ac14496c1986b..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/transformer.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/transformer.cpython-39.pyc b/think_sound/models/__pycache__/transformer.cpython-39.pyc deleted file mode 100644 index 8957557e758f2d304a0f1ebf04f76e5a2b547c1f..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/transformer.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/utils.cpython-310.pyc b/think_sound/models/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index 1eb7c4445470a6b94a542334f399b147e0a12d66..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/utils.cpython-38.pyc b/think_sound/models/__pycache__/utils.cpython-38.pyc deleted file mode 100644 index a2785b42d22f3342618a5251cd7a86a7a2c81d0d..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/utils.cpython-38.pyc and /dev/null differ diff --git a/think_sound/models/__pycache__/utils.cpython-39.pyc b/think_sound/models/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index 3bffbdbbc37c5cdf27e9e02467b7b2a4ce215a50..0000000000000000000000000000000000000000 Binary files a/think_sound/models/__pycache__/utils.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/adp.py b/think_sound/models/adp.py deleted file mode 100644 index 49eb526ab02d16eb4952d346401b1ad2b7e5cb7c..0000000000000000000000000000000000000000 --- a/think_sound/models/adp.py +++ /dev/null @@ -1,1588 +0,0 @@ -# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License -# License can be found in LICENSES/LICENSE_ADP.txt - -import math -from inspect import isfunction -from math import ceil, floor, log, pi, log2 -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union -from packaging import version - -import torch -import torch.nn as nn -from einops import rearrange, reduce, repeat -from einops.layers.torch import Rearrange -from einops_exts import rearrange_many -from torch import Tensor, einsum -from torch.backends.cuda import sdp_kernel -from torch.nn import functional as F -from dac.nn.layers import Snake1d - -""" -Utils -""" - - -class ConditionedSequential(nn.Module): - def __init__(self, *modules): - super().__init__() - self.module_list = nn.ModuleList(*modules) - - def forward(self, x: Tensor, mapping: Optional[Tensor] = None): - for module in self.module_list: - x = module(x, mapping) - return x - -T = TypeVar("T") - -def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: - if exists(val): - return val - return d() if isfunction(d) else d - -def exists(val: Optional[T]) -> T: - return val is not None - -def closest_power_2(x: float) -> int: - exponent = log2(x) - distance_fn = lambda z: abs(x - 2 ** z) # noqa - exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) - return 2 ** int(exponent_closest) - -def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: - return_dicts: Tuple[Dict, Dict] = ({}, {}) - for key in d.keys(): - no_prefix = int(not key.startswith(prefix)) - return_dicts[no_prefix][key] = d[key] - return return_dicts - -def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: - kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) - if keep_prefix: - return kwargs_with_prefix, kwargs - kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} - return kwargs_no_prefix, kwargs - -""" -Convolutional Blocks -""" -import typing as tp - -# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License -# License available in LICENSES/LICENSE_META.txt - -def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, - padding_total: int = 0) -> int: - """See `pad_for_conv1d`.""" - length = x.shape[-1] - n_frames = (length - kernel_size + padding_total) / stride + 1 - ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) - return ideal_length - length - - -def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): - """Pad for a convolution to make sure that the last window is full. - Extra padding is added at the end. This is required to ensure that we can rebuild - an output of the same length, as otherwise, even with padding, some time steps - might get removed. - For instance, with total padding = 4, kernel size = 4, stride = 2: - 0 0 1 2 3 4 5 0 0 # (0s are padding) - 1 2 3 # (output frames of a convolution, last 0 is never used) - 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) - 1 2 3 4 # once you removed padding, we are missing one time step ! - """ - extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) - return F.pad(x, (0, extra_padding)) - - -def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): - """Tiny wrapper around F.pad, just to allow for reflect padding on small input. - If this is the case, we insert extra 0 padding to the right before the reflection happen. - """ - length = x.shape[-1] - padding_left, padding_right = paddings - assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - if mode == 'reflect': - max_pad = max(padding_left, padding_right) - extra_pad = 0 - if length <= max_pad: - extra_pad = max_pad - length + 1 - x = F.pad(x, (0, extra_pad)) - padded = F.pad(x, paddings, mode, value) - end = padded.shape[-1] - extra_pad - return padded[..., :end] - else: - return F.pad(x, paddings, mode, value) - - -def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): - """Remove padding from x, handling properly zero padding. Only for 1d!""" - padding_left, padding_right = paddings - assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - assert (padding_left + padding_right) <= x.shape[-1] - end = x.shape[-1] - padding_right - return x[..., padding_left: end] - - -class Conv1d(nn.Conv1d): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, x: Tensor, causal=False) -> Tensor: - kernel_size = self.kernel_size[0] - stride = self.stride[0] - dilation = self.dilation[0] - kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations - padding_total = kernel_size - stride - extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) - if causal: - # Left padding for causal - x = pad1d(x, (padding_total, extra_padding)) - else: - # Asymmetric padding required for odd strides - padding_right = padding_total // 2 - padding_left = padding_total - padding_right - x = pad1d(x, (padding_left, padding_right + extra_padding)) - return super().forward(x) - -class ConvTranspose1d(nn.ConvTranspose1d): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, x: Tensor, causal=False) -> Tensor: - kernel_size = self.kernel_size[0] - stride = self.stride[0] - padding_total = kernel_size - stride - - y = super().forward(x) - - # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be - # removed at the very end, when keeping only the right length for the output, - # as removing it here would require also passing the length at the matching layer - # in the encoder. - if causal: - padding_right = ceil(padding_total) - padding_left = padding_total - padding_right - y = unpad1d(y, (padding_left, padding_right)) - else: - # Asymmetric padding required for odd strides - padding_right = padding_total // 2 - padding_left = padding_total - padding_right - y = unpad1d(y, (padding_left, padding_right)) - return y - - -def Downsample1d( - in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 -) -> nn.Module: - assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" - - return Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=factor * kernel_multiplier + 1, - stride=factor - ) - - -def Upsample1d( - in_channels: int, out_channels: int, factor: int, use_nearest: bool = False -) -> nn.Module: - - if factor == 1: - return Conv1d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3 - ) - - if use_nearest: - return nn.Sequential( - nn.Upsample(scale_factor=factor, mode="nearest"), - Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3 - ), - ) - else: - return ConvTranspose1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=factor * 2, - stride=factor - ) - - -class ConvBlock1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - *, - kernel_size: int = 3, - stride: int = 1, - dilation: int = 1, - num_groups: int = 8, - use_norm: bool = True, - use_snake: bool = False - ) -> None: - super().__init__() - - self.groupnorm = ( - nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) - if use_norm - else nn.Identity() - ) - - if use_snake: - self.activation = Snake1d(in_channels) - else: - self.activation = nn.SiLU() - - self.project = Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - ) - - def forward( - self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False - ) -> Tensor: - x = self.groupnorm(x) - if exists(scale_shift): - scale, shift = scale_shift - x = x * (scale + 1) + shift - x = self.activation(x) - return self.project(x, causal=causal) - - -class MappingToScaleShift(nn.Module): - def __init__( - self, - features: int, - channels: int, - ): - super().__init__() - - self.to_scale_shift = nn.Sequential( - nn.SiLU(), - nn.Linear(in_features=features, out_features=channels * 2), - ) - - def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]: - scale_shift = self.to_scale_shift(mapping) - scale_shift = rearrange(scale_shift, "b c -> b c 1") - scale, shift = scale_shift.chunk(2, dim=1) - return scale, shift - - -class ResnetBlock1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - *, - kernel_size: int = 3, - stride: int = 1, - dilation: int = 1, - use_norm: bool = True, - use_snake: bool = False, - num_groups: int = 8, - context_mapping_features: Optional[int] = None, - ) -> None: - super().__init__() - - self.use_mapping = exists(context_mapping_features) - - self.block1 = ConvBlock1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - use_norm=use_norm, - num_groups=num_groups, - use_snake=use_snake - ) - - if self.use_mapping: - assert exists(context_mapping_features) - self.to_scale_shift = MappingToScaleShift( - features=context_mapping_features, channels=out_channels - ) - - self.block2 = ConvBlock1d( - in_channels=out_channels, - out_channels=out_channels, - use_norm=use_norm, - num_groups=num_groups, - use_snake=use_snake - ) - - self.to_out = ( - Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) - if in_channels != out_channels - else nn.Identity() - ) - - def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: - assert_message = "context mapping required if context_mapping_features > 0" - assert not (self.use_mapping ^ exists(mapping)), assert_message - - h = self.block1(x, causal=causal) - - scale_shift = None - if self.use_mapping: - scale_shift = self.to_scale_shift(mapping) - - h = self.block2(h, scale_shift=scale_shift, causal=causal) - - return h + self.to_out(x) - - -class Patcher(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - patch_size: int, - context_mapping_features: Optional[int] = None, - use_snake: bool = False, - ): - super().__init__() - assert_message = f"out_channels must be divisible by patch_size ({patch_size})" - assert out_channels % patch_size == 0, assert_message - self.patch_size = patch_size - - self.block = ResnetBlock1d( - in_channels=in_channels, - out_channels=out_channels // patch_size, - num_groups=1, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - - def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: - x = self.block(x, mapping, causal=causal) - x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size) - return x - - -class Unpatcher(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - patch_size: int, - context_mapping_features: Optional[int] = None, - use_snake: bool = False - ): - super().__init__() - assert_message = f"in_channels must be divisible by patch_size ({patch_size})" - assert in_channels % patch_size == 0, assert_message - self.patch_size = patch_size - - self.block = ResnetBlock1d( - in_channels=in_channels // patch_size, - out_channels=out_channels, - num_groups=1, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - - def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: - x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size) - x = self.block(x, mapping, causal=causal) - return x - - -""" -Attention Components -""" -def FeedForward(features: int, multiplier: int) -> nn.Module: - mid_features = features * multiplier - return nn.Sequential( - nn.Linear(in_features=features, out_features=mid_features), - nn.GELU(), - nn.Linear(in_features=mid_features, out_features=features), - ) - -def add_mask(sim: Tensor, mask: Tensor) -> Tensor: - b, ndim = sim.shape[0], mask.ndim - if ndim == 3: - mask = rearrange(mask, "b n m -> b 1 n m") - if ndim == 2: - mask = repeat(mask, "n m -> b 1 n m", b=b) - max_neg_value = -torch.finfo(sim.dtype).max - sim = sim.masked_fill(~mask, max_neg_value) - return sim - -def causal_mask(q: Tensor, k: Tensor) -> Tensor: - b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device - mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1) - mask = repeat(mask, "n m -> b n m", b=b) - return mask - -class AttentionBase(nn.Module): - def __init__( - self, - features: int, - *, - head_features: int, - num_heads: int, - out_features: Optional[int] = None, - ): - super().__init__() - self.scale = head_features**-0.5 - self.num_heads = num_heads - mid_features = head_features * num_heads - out_features = default(out_features, features) - - self.to_out = nn.Linear( - in_features=mid_features, out_features=out_features - ) - - self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') - - if not self.use_flash: - return - - device_properties = torch.cuda.get_device_properties(torch.device('cuda')) - - if device_properties.major == 8 and device_properties.minor == 0: - # Use flash attention for A100 GPUs - self.sdp_kernel_config = (True, False, False) - else: - # Don't use flash attention for other GPUs - self.sdp_kernel_config = (False, True, True) - - def forward( - self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False - ) -> Tensor: - # Split heads - q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads) - - if not self.use_flash: - if is_causal and not mask: - # Mask out future tokens for causal attention - mask = causal_mask(q, k) - - # Compute similarity matrix and add eventual mask - sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale - sim = add_mask(sim, mask) if exists(mask) else sim - - # Get attention matrix with softmax - attn = sim.softmax(dim=-1, dtype=torch.float32) - - # Compute values - out = einsum("... n m, ... m d -> ... n d", attn, v) - else: - with sdp_kernel(*self.sdp_kernel_config): - out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal) - - out = rearrange(out, "b h n d -> b n (h d)") - return self.to_out(out) - -class Attention(nn.Module): - def __init__( - self, - features: int, - *, - head_features: int, - num_heads: int, - out_features: Optional[int] = None, - context_features: Optional[int] = None, - causal: bool = False, - ): - super().__init__() - self.context_features = context_features - self.causal = causal - mid_features = head_features * num_heads - context_features = default(context_features, features) - - self.norm = nn.LayerNorm(features) - self.norm_context = nn.LayerNorm(context_features) - self.to_q = nn.Linear( - in_features=features, out_features=mid_features, bias=False - ) - self.to_kv = nn.Linear( - in_features=context_features, out_features=mid_features * 2, bias=False - ) - self.attention = AttentionBase( - features, - num_heads=num_heads, - head_features=head_features, - out_features=out_features, - ) - - def forward( - self, - x: Tensor, # [b, n, c] - context: Optional[Tensor] = None, # [b, m, d] - context_mask: Optional[Tensor] = None, # [b, m], false is masked, - causal: Optional[bool] = False, - ) -> Tensor: - assert_message = "You must provide a context when using context_features" - assert not self.context_features or exists(context), assert_message - # Use context if provided - context = default(context, x) - # Normalize then compute q from input and k,v from context - x, context = self.norm(x), self.norm_context(context) - - q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) - - if exists(context_mask): - # Mask out cross-attention for padding tokens - mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1]) - k, v = k * mask, v * mask - - # Compute and return attention - return self.attention(q, k, v, is_causal=self.causal or causal) - - -def FeedForward(features: int, multiplier: int) -> nn.Module: - mid_features = features * multiplier - return nn.Sequential( - nn.Linear(in_features=features, out_features=mid_features), - nn.GELU(), - nn.Linear(in_features=mid_features, out_features=features), - ) - -""" -Transformer Blocks -""" - - -class TransformerBlock(nn.Module): - def __init__( - self, - features: int, - num_heads: int, - head_features: int, - multiplier: int, - context_features: Optional[int] = None, - ): - super().__init__() - - self.use_cross_attention = exists(context_features) and context_features > 0 - - self.attention = Attention( - features=features, - num_heads=num_heads, - head_features=head_features - ) - - if self.use_cross_attention: - self.cross_attention = Attention( - features=features, - num_heads=num_heads, - head_features=head_features, - context_features=context_features - ) - - self.feed_forward = FeedForward(features=features, multiplier=multiplier) - - def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor: - x = self.attention(x, causal=causal) + x - if self.use_cross_attention: - x = self.cross_attention(x, context=context, context_mask=context_mask) + x - x = self.feed_forward(x) + x - return x - - -""" -Transformers -""" - - -class Transformer1d(nn.Module): - def __init__( - self, - num_layers: int, - channels: int, - num_heads: int, - head_features: int, - multiplier: int, - context_features: Optional[int] = None, - ): - super().__init__() - - self.to_in = nn.Sequential( - nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True), - Conv1d( - in_channels=channels, - out_channels=channels, - kernel_size=1, - ), - Rearrange("b c t -> b t c"), - ) - - self.blocks = nn.ModuleList( - [ - TransformerBlock( - features=channels, - head_features=head_features, - num_heads=num_heads, - multiplier=multiplier, - context_features=context_features, - ) - for i in range(num_layers) - ] - ) - - self.to_out = nn.Sequential( - Rearrange("b t c -> b c t"), - Conv1d( - in_channels=channels, - out_channels=channels, - kernel_size=1, - ), - ) - - def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor: - x = self.to_in(x) - for block in self.blocks: - x = block(x, context=context, context_mask=context_mask, causal=causal) - x = self.to_out(x) - return x - - -""" -Time Embeddings -""" - - -class SinusoidalEmbedding(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.dim = dim - - def forward(self, x: Tensor) -> Tensor: - device, half_dim = x.device, self.dim // 2 - emb = torch.tensor(log(10000) / (half_dim - 1), device=device) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") - return torch.cat((emb.sin(), emb.cos()), dim=-1) - - -class LearnedPositionalEmbedding(nn.Module): - """Used for continuous time""" - - def __init__(self, dim: int): - super().__init__() - assert (dim % 2) == 0 - half_dim = dim // 2 - self.weights = nn.Parameter(torch.randn(half_dim)) - - def forward(self, x: Tensor) -> Tensor: - x = rearrange(x, "b -> b 1") - freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi - fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) - fouriered = torch.cat((x, fouriered), dim=-1) - return fouriered - - -def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: - return nn.Sequential( - LearnedPositionalEmbedding(dim), - nn.Linear(in_features=dim + 1, out_features=out_features), - ) - - -""" -Encoder/Decoder Components -""" - - -class DownsampleBlock1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - *, - factor: int, - num_groups: int, - num_layers: int, - kernel_multiplier: int = 2, - use_pre_downsample: bool = True, - use_skip: bool = False, - use_snake: bool = False, - extract_channels: int = 0, - context_channels: int = 0, - num_transformer_blocks: int = 0, - attention_heads: Optional[int] = None, - attention_features: Optional[int] = None, - attention_multiplier: Optional[int] = None, - context_mapping_features: Optional[int] = None, - context_embedding_features: Optional[int] = None, - ): - super().__init__() - self.use_pre_downsample = use_pre_downsample - self.use_skip = use_skip - self.use_transformer = num_transformer_blocks > 0 - self.use_extract = extract_channels > 0 - self.use_context = context_channels > 0 - - channels = out_channels if use_pre_downsample else in_channels - - self.downsample = Downsample1d( - in_channels=in_channels, - out_channels=out_channels, - factor=factor, - kernel_multiplier=kernel_multiplier, - ) - - self.blocks = nn.ModuleList( - [ - ResnetBlock1d( - in_channels=channels + context_channels if i == 0 else channels, - out_channels=channels, - num_groups=num_groups, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - for i in range(num_layers) - ] - ) - - if self.use_transformer: - assert ( - (exists(attention_heads) or exists(attention_features)) - and exists(attention_multiplier) - ) - - if attention_features is None and attention_heads is not None: - attention_features = channels // attention_heads - - if attention_heads is None and attention_features is not None: - attention_heads = channels // attention_features - - self.transformer = Transformer1d( - num_layers=num_transformer_blocks, - channels=channels, - num_heads=attention_heads, - head_features=attention_features, - multiplier=attention_multiplier, - context_features=context_embedding_features - ) - - if self.use_extract: - num_extract_groups = min(num_groups, extract_channels) - self.to_extracted = ResnetBlock1d( - in_channels=out_channels, - out_channels=extract_channels, - num_groups=num_extract_groups, - use_snake=use_snake - ) - - def forward( - self, - x: Tensor, - *, - mapping: Optional[Tensor] = None, - channels: Optional[Tensor] = None, - embedding: Optional[Tensor] = None, - embedding_mask: Optional[Tensor] = None, - causal: Optional[bool] = False - ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]: - - if self.use_pre_downsample: - x = self.downsample(x) - - if self.use_context and exists(channels): - x = torch.cat([x, channels], dim=1) - - skips = [] - for block in self.blocks: - x = block(x, mapping=mapping, causal=causal) - skips += [x] if self.use_skip else [] - - if self.use_transformer: - x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) - skips += [x] if self.use_skip else [] - - if not self.use_pre_downsample: - x = self.downsample(x) - - if self.use_extract: - extracted = self.to_extracted(x) - return x, extracted - - return (x, skips) if self.use_skip else x - - -class UpsampleBlock1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - *, - factor: int, - num_layers: int, - num_groups: int, - use_nearest: bool = False, - use_pre_upsample: bool = False, - use_skip: bool = False, - use_snake: bool = False, - skip_channels: int = 0, - use_skip_scale: bool = False, - extract_channels: int = 0, - num_transformer_blocks: int = 0, - attention_heads: Optional[int] = None, - attention_features: Optional[int] = None, - attention_multiplier: Optional[int] = None, - context_mapping_features: Optional[int] = None, - context_embedding_features: Optional[int] = None, - ): - super().__init__() - - self.use_extract = extract_channels > 0 - self.use_pre_upsample = use_pre_upsample - self.use_transformer = num_transformer_blocks > 0 - self.use_skip = use_skip - self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0 - - channels = out_channels if use_pre_upsample else in_channels - - self.blocks = nn.ModuleList( - [ - ResnetBlock1d( - in_channels=channels + skip_channels, - out_channels=channels, - num_groups=num_groups, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - for _ in range(num_layers) - ] - ) - - if self.use_transformer: - assert ( - (exists(attention_heads) or exists(attention_features)) - and exists(attention_multiplier) - ) - - if attention_features is None and attention_heads is not None: - attention_features = channels // attention_heads - - if attention_heads is None and attention_features is not None: - attention_heads = channels // attention_features - - self.transformer = Transformer1d( - num_layers=num_transformer_blocks, - channels=channels, - num_heads=attention_heads, - head_features=attention_features, - multiplier=attention_multiplier, - context_features=context_embedding_features, - ) - - self.upsample = Upsample1d( - in_channels=in_channels, - out_channels=out_channels, - factor=factor, - use_nearest=use_nearest, - ) - - if self.use_extract: - num_extract_groups = min(num_groups, extract_channels) - self.to_extracted = ResnetBlock1d( - in_channels=out_channels, - out_channels=extract_channels, - num_groups=num_extract_groups, - use_snake=use_snake - ) - - def add_skip(self, x: Tensor, skip: Tensor) -> Tensor: - return torch.cat([x, skip * self.skip_scale], dim=1) - - def forward( - self, - x: Tensor, - *, - skips: Optional[List[Tensor]] = None, - mapping: Optional[Tensor] = None, - embedding: Optional[Tensor] = None, - embedding_mask: Optional[Tensor] = None, - causal: Optional[bool] = False - ) -> Union[Tuple[Tensor, Tensor], Tensor]: - - if self.use_pre_upsample: - x = self.upsample(x) - - for block in self.blocks: - x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x - x = block(x, mapping=mapping, causal=causal) - - if self.use_transformer: - x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) - - if not self.use_pre_upsample: - x = self.upsample(x) - - if self.use_extract: - extracted = self.to_extracted(x) - return x, extracted - - return x - - -class BottleneckBlock1d(nn.Module): - def __init__( - self, - channels: int, - *, - num_groups: int, - num_transformer_blocks: int = 0, - attention_heads: Optional[int] = None, - attention_features: Optional[int] = None, - attention_multiplier: Optional[int] = None, - context_mapping_features: Optional[int] = None, - context_embedding_features: Optional[int] = None, - use_snake: bool = False, - ): - super().__init__() - self.use_transformer = num_transformer_blocks > 0 - - self.pre_block = ResnetBlock1d( - in_channels=channels, - out_channels=channels, - num_groups=num_groups, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - - if self.use_transformer: - assert ( - (exists(attention_heads) or exists(attention_features)) - and exists(attention_multiplier) - ) - - if attention_features is None and attention_heads is not None: - attention_features = channels // attention_heads - - if attention_heads is None and attention_features is not None: - attention_heads = channels // attention_features - - self.transformer = Transformer1d( - num_layers=num_transformer_blocks, - channels=channels, - num_heads=attention_heads, - head_features=attention_features, - multiplier=attention_multiplier, - context_features=context_embedding_features, - ) - - self.post_block = ResnetBlock1d( - in_channels=channels, - out_channels=channels, - num_groups=num_groups, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - - def forward( - self, - x: Tensor, - *, - mapping: Optional[Tensor] = None, - embedding: Optional[Tensor] = None, - embedding_mask: Optional[Tensor] = None, - causal: Optional[bool] = False - ) -> Tensor: - x = self.pre_block(x, mapping=mapping, causal=causal) - if self.use_transformer: - x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) - x = self.post_block(x, mapping=mapping, causal=causal) - return x - - -""" -UNet -""" - - -class UNet1d(nn.Module): - def __init__( - self, - in_channels: int, - channels: int, - multipliers: Sequence[int], - factors: Sequence[int], - num_blocks: Sequence[int], - attentions: Sequence[int], - patch_size: int = 1, - resnet_groups: int = 8, - use_context_time: bool = True, - kernel_multiplier_downsample: int = 2, - use_nearest_upsample: bool = False, - use_skip_scale: bool = True, - use_snake: bool = False, - use_stft: bool = False, - use_stft_context: bool = False, - out_channels: Optional[int] = None, - context_features: Optional[int] = None, - context_features_multiplier: int = 4, - context_channels: Optional[Sequence[int]] = None, - context_embedding_features: Optional[int] = None, - **kwargs, - ): - super().__init__() - out_channels = default(out_channels, in_channels) - context_channels = list(default(context_channels, [])) - num_layers = len(multipliers) - 1 - use_context_features = exists(context_features) - use_context_channels = len(context_channels) > 0 - context_mapping_features = None - - attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True) - - self.num_layers = num_layers - self.use_context_time = use_context_time - self.use_context_features = use_context_features - self.use_context_channels = use_context_channels - self.use_stft = use_stft - self.use_stft_context = use_stft_context - - self.context_features = context_features - context_channels_pad_length = num_layers + 1 - len(context_channels) - context_channels = context_channels + [0] * context_channels_pad_length - self.context_channels = context_channels - self.context_embedding_features = context_embedding_features - - if use_context_channels: - has_context = [c > 0 for c in context_channels] - self.has_context = has_context - self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] - - assert ( - len(factors) == num_layers - and len(attentions) >= num_layers - and len(num_blocks) == num_layers - ) - - if use_context_time or use_context_features: - context_mapping_features = channels * context_features_multiplier - - self.to_mapping = nn.Sequential( - nn.Linear(context_mapping_features, context_mapping_features), - nn.GELU(), - nn.Linear(context_mapping_features, context_mapping_features), - nn.GELU(), - ) - - if use_context_time: - assert exists(context_mapping_features) - self.to_time = nn.Sequential( - TimePositionalEmbedding( - dim=channels, out_features=context_mapping_features - ), - nn.GELU(), - ) - - if use_context_features: - assert exists(context_features) and exists(context_mapping_features) - self.to_features = nn.Sequential( - nn.Linear( - in_features=context_features, out_features=context_mapping_features - ), - nn.GELU(), - ) - - if use_stft: - stft_kwargs, kwargs = groupby("stft_", kwargs) - assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True" - stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2 - in_channels *= stft_channels - out_channels *= stft_channels - context_channels[0] *= stft_channels if use_stft_context else 1 - assert exists(in_channels) and exists(out_channels) - self.stft = STFT(**stft_kwargs) - - assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" - - self.to_in = Patcher( - in_channels=in_channels + context_channels[0], - out_channels=channels * multipliers[0], - patch_size=patch_size, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - - self.downsamples = nn.ModuleList( - [ - DownsampleBlock1d( - in_channels=channels * multipliers[i], - out_channels=channels * multipliers[i + 1], - context_mapping_features=context_mapping_features, - context_channels=context_channels[i + 1], - context_embedding_features=context_embedding_features, - num_layers=num_blocks[i], - factor=factors[i], - kernel_multiplier=kernel_multiplier_downsample, - num_groups=resnet_groups, - use_pre_downsample=True, - use_skip=True, - use_snake=use_snake, - num_transformer_blocks=attentions[i], - **attention_kwargs, - ) - for i in range(num_layers) - ] - ) - - self.bottleneck = BottleneckBlock1d( - channels=channels * multipliers[-1], - context_mapping_features=context_mapping_features, - context_embedding_features=context_embedding_features, - num_groups=resnet_groups, - num_transformer_blocks=attentions[-1], - use_snake=use_snake, - **attention_kwargs, - ) - - self.upsamples = nn.ModuleList( - [ - UpsampleBlock1d( - in_channels=channels * multipliers[i + 1], - out_channels=channels * multipliers[i], - context_mapping_features=context_mapping_features, - context_embedding_features=context_embedding_features, - num_layers=num_blocks[i] + (1 if attentions[i] else 0), - factor=factors[i], - use_nearest=use_nearest_upsample, - num_groups=resnet_groups, - use_skip_scale=use_skip_scale, - use_pre_upsample=False, - use_skip=True, - use_snake=use_snake, - skip_channels=channels * multipliers[i + 1], - num_transformer_blocks=attentions[i], - **attention_kwargs, - ) - for i in reversed(range(num_layers)) - ] - ) - - self.to_out = Unpatcher( - in_channels=channels * multipliers[0], - out_channels=out_channels, - patch_size=patch_size, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - - def get_channels( - self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0 - ) -> Optional[Tensor]: - """Gets context channels at `layer` and checks that shape is correct""" - use_context_channels = self.use_context_channels and self.has_context[layer] - if not use_context_channels: - return None - assert exists(channels_list), "Missing context" - # Get channels index (skipping zero channel contexts) - channels_id = self.channels_ids[layer] - # Get channels - channels = channels_list[channels_id] - message = f"Missing context for layer {layer} at index {channels_id}" - assert exists(channels), message - # Check channels - num_channels = self.context_channels[layer] - message = f"Expected context with {num_channels} channels at idx {channels_id}" - assert channels.shape[1] == num_channels, message - # STFT channels if requested - channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa - return channels - - def get_mapping( - self, time: Optional[Tensor] = None, features: Optional[Tensor] = None - ) -> Optional[Tensor]: - """Combines context time features and features into mapping""" - items, mapping = [], None - # Compute time features - if self.use_context_time: - assert_message = "use_context_time=True but no time features provided" - assert exists(time), assert_message - items += [self.to_time(time)] - # Compute features - if self.use_context_features: - assert_message = "context_features exists but no features provided" - assert exists(features), assert_message - items += [self.to_features(features)] - # Compute joint mapping - if self.use_context_time or self.use_context_features: - mapping = reduce(torch.stack(items), "n b m -> b m", "sum") - mapping = self.to_mapping(mapping) - return mapping - - def forward( - self, - x: Tensor, - time: Optional[Tensor] = None, - *, - features: Optional[Tensor] = None, - channels_list: Optional[Sequence[Tensor]] = None, - embedding: Optional[Tensor] = None, - embedding_mask: Optional[Tensor] = None, - causal: Optional[bool] = False, - ) -> Tensor: - channels = self.get_channels(channels_list, layer=0) - # Apply stft if required - x = self.stft.encode1d(x) if self.use_stft else x # type: ignore - # Concat context channels at layer 0 if provided - x = torch.cat([x, channels], dim=1) if exists(channels) else x - # Compute mapping from time and features - mapping = self.get_mapping(time, features) - x = self.to_in(x, mapping, causal=causal) - skips_list = [x] - - for i, downsample in enumerate(self.downsamples): - channels = self.get_channels(channels_list, layer=i + 1) - x, skips = downsample( - x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal - ) - skips_list += [skips] - - x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) - - for i, upsample in enumerate(self.upsamples): - skips = skips_list.pop() - x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) - - x += skips_list.pop() - x = self.to_out(x, mapping, causal=causal) - x = self.stft.decode1d(x) if self.use_stft else x - - return x - - -""" Conditioning Modules """ - - -class FixedEmbedding(nn.Module): - def __init__(self, max_length: int, features: int): - super().__init__() - self.max_length = max_length - self.embedding = nn.Embedding(max_length, features) - - def forward(self, x: Tensor) -> Tensor: - batch_size, length, device = *x.shape[0:2], x.device - assert_message = "Input sequence length must be <= max_length" - assert length <= self.max_length, assert_message - position = torch.arange(length, device=device) - fixed_embedding = self.embedding(position) - fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) - return fixed_embedding - - -def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: - if proba == 1: - return torch.ones(shape, device=device, dtype=torch.bool) - elif proba == 0: - return torch.zeros(shape, device=device, dtype=torch.bool) - else: - return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) - - -class UNetCFG1d(UNet1d): - - """UNet1d with Classifier-Free Guidance""" - - def __init__( - self, - context_embedding_max_length: int, - context_embedding_features: int, - use_xattn_time: bool = False, - **kwargs, - ): - super().__init__( - context_embedding_features=context_embedding_features, **kwargs - ) - - self.use_xattn_time = use_xattn_time - - if use_xattn_time: - assert exists(context_embedding_features) - self.to_time_embedding = nn.Sequential( - TimePositionalEmbedding( - dim=kwargs["channels"], out_features=context_embedding_features - ), - nn.GELU(), - ) - - context_embedding_max_length += 1 # Add one for time embedding - - self.fixed_embedding = FixedEmbedding( - max_length=context_embedding_max_length, features=context_embedding_features - ) - - def forward( # type: ignore - self, - x: Tensor, - time: Tensor, - *, - embedding: Tensor, - embedding_mask: Optional[Tensor] = None, - embedding_scale: float = 1.0, - embedding_mask_proba: float = 0.0, - batch_cfg: bool = False, - rescale_cfg: bool = False, - scale_phi: float = 0.4, - negative_embedding: Optional[Tensor] = None, - negative_embedding_mask: Optional[Tensor] = None, - **kwargs, - ) -> Tensor: - b, device = embedding.shape[0], embedding.device - - if self.use_xattn_time: - embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1) - - if embedding_mask is not None: - embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1) - - fixed_embedding = self.fixed_embedding(embedding) - - if embedding_mask_proba > 0.0: - # Randomly mask embedding - batch_mask = rand_bool( - shape=(b, 1, 1), proba=embedding_mask_proba, device=device - ) - embedding = torch.where(batch_mask, fixed_embedding, embedding) - - if embedding_scale != 1.0: - if batch_cfg: - batch_x = torch.cat([x, x], dim=0) - batch_time = torch.cat([time, time], dim=0) - - if negative_embedding is not None: - if negative_embedding_mask is not None: - negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2) - - negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding) - - batch_embed = torch.cat([embedding, negative_embedding], dim=0) - - else: - batch_embed = torch.cat([embedding, fixed_embedding], dim=0) - - batch_mask = None - if embedding_mask is not None: - batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0) - - batch_features = None - features = kwargs.pop("features", None) - if self.use_context_features: - batch_features = torch.cat([features, features], dim=0) - - batch_channels = None - channels_list = kwargs.pop("channels_list", None) - if self.use_context_channels: - batch_channels = [] - for channels in channels_list: - batch_channels += [torch.cat([channels, channels], dim=0)] - - # Compute both normal and fixed embedding outputs - batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs) - out, out_masked = batch_out.chunk(2, dim=0) - - else: - # Compute both normal and fixed embedding outputs - out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) - out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs) - - out_cfg = out_masked + (out - out_masked) * embedding_scale - - if rescale_cfg: - - out_std = out.std(dim=1, keepdim=True) - out_cfg_std = out_cfg.std(dim=1, keepdim=True) - - return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg - - else: - - return out_cfg - - else: - return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) - - -class UNetNCCA1d(UNet1d): - - """UNet1d with Noise Channel Conditioning Augmentation""" - - def __init__(self, context_features: int, **kwargs): - super().__init__(context_features=context_features, **kwargs) - self.embedder = NumberEmbedder(features=context_features) - - def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor: - x = x if torch.is_tensor(x) else torch.tensor(x) - return x.expand(shape) - - def forward( # type: ignore - self, - x: Tensor, - time: Tensor, - *, - channels_list: Sequence[Tensor], - channels_augmentation: Union[ - bool, Sequence[bool], Sequence[Sequence[bool]], Tensor - ] = False, - channels_scale: Union[ - float, Sequence[float], Sequence[Sequence[float]], Tensor - ] = 0, - **kwargs, - ) -> Tensor: - b, n = x.shape[0], len(channels_list) - channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x) - channels_scale = self.expand(channels_scale, shape=(b, n)).to(x) - - # Augmentation (for each channel list item) - for i in range(n): - scale = channels_scale[:, i] * channels_augmentation[:, i] - scale = rearrange(scale, "b -> b 1 1") - item = channels_list[i] - channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa - - # Scale embedding (sum reduction if more than one channel list item) - channels_scale_emb = self.embedder(channels_scale) - channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum") - - return super().forward( - x=x, - time=time, - channels_list=channels_list, - features=channels_scale_emb, - **kwargs, - ) - - -class UNetAll1d(UNetCFG1d, UNetNCCA1d): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, *args, **kwargs): # type: ignore - return UNetCFG1d.forward(self, *args, **kwargs) - - -def XUNet1d(type: str = "base", **kwargs) -> UNet1d: - if type == "base": - return UNet1d(**kwargs) - elif type == "all": - return UNetAll1d(**kwargs) - elif type == "cfg": - return UNetCFG1d(**kwargs) - elif type == "ncca": - return UNetNCCA1d(**kwargs) - else: - raise ValueError(f"Unknown XUNet1d type: {type}") - -class NumberEmbedder(nn.Module): - def __init__( - self, - features: int, - dim: int = 256, - ): - super().__init__() - self.features = features - self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) - - def forward(self, x: Union[List[float], Tensor]) -> Tensor: - if not torch.is_tensor(x): - device = next(self.embedding.parameters()).device - x = torch.tensor(x, device=device) - assert isinstance(x, Tensor) - shape = x.shape - x = rearrange(x, "... -> (...)") - embedding = self.embedding(x) - x = embedding.view(*shape, self.features) - return x # type: ignore - - -""" -Audio Transforms -""" - - -class STFT(nn.Module): - """Helper for torch stft and istft""" - - def __init__( - self, - num_fft: int = 1023, - hop_length: int = 256, - window_length: Optional[int] = None, - length: Optional[int] = None, - use_complex: bool = False, - ): - super().__init__() - self.num_fft = num_fft - self.hop_length = default(hop_length, floor(num_fft // 4)) - self.window_length = default(window_length, num_fft) - self.length = length - self.register_buffer("window", torch.hann_window(self.window_length)) - self.use_complex = use_complex - - def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: - b = wave.shape[0] - wave = rearrange(wave, "b c t -> (b c) t") - - stft = torch.stft( - wave, - n_fft=self.num_fft, - hop_length=self.hop_length, - win_length=self.window_length, - window=self.window, # type: ignore - return_complex=True, - normalized=True, - ) - - if self.use_complex: - # Returns real and imaginary - stft_a, stft_b = stft.real, stft.imag - else: - # Returns magnitude and phase matrices - magnitude, phase = torch.abs(stft), torch.angle(stft) - stft_a, stft_b = magnitude, phase - - return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b) - - def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: - b, l = stft_a.shape[0], stft_a.shape[-1] # noqa - length = closest_power_2(l * self.hop_length) - - stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l") - - if self.use_complex: - real, imag = stft_a, stft_b - else: - magnitude, phase = stft_a, stft_b - real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase) - - stft = torch.stack([real, imag], dim=-1) - - wave = torch.istft( - stft, - n_fft=self.num_fft, - hop_length=self.hop_length, - win_length=self.window_length, - window=self.window, # type: ignore - length=default(self.length, length), - normalized=True, - ) - - return rearrange(wave, "(b c) t -> b c t", b=b) - - def encode1d( - self, wave: Tensor, stacked: bool = True - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: - stft_a, stft_b = self.encode(wave) - stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") - return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) - - def decode1d(self, stft_pair: Tensor) -> Tensor: - f = self.num_fft // 2 + 1 - stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) - stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) - return self.decode(stft_a, stft_b) diff --git a/think_sound/models/diffusion_prior.py b/think_sound/models/diffusion_prior.py deleted file mode 100644 index 0cb15258d7656fb85ee763910dc9b500331de603..0000000000000000000000000000000000000000 --- a/think_sound/models/diffusion_prior.py +++ /dev/null @@ -1,82 +0,0 @@ -from enum import Enum -import typing as tp - -from .diffusion import ConditionedDiffusionModelWrapper -from ..inference.generation import generate_diffusion_cond -from ..inference.utils import prepare_audio - -import torch -from torch.nn import functional as F -from torchaudio import transforms as T - -# Define prior types enum -class PriorType(Enum): - MonoToStereo = 1 - -class DiffusionPrior(ConditionedDiffusionModelWrapper): - def __init__(self, *args, prior_type: PriorType=None, **kwargs): - super().__init__(*args, **kwargs) - self.prior_type = prior_type - -class MonoToStereoDiffusionPrior(DiffusionPrior): - def __init__(self, *args, **kwargs): - super().__init__(*args, prior_type=PriorType.MonoToStereo, **kwargs) - - def stereoize( - self, - audio: torch.Tensor, # (batch, channels, time) - video: torch.Tensor, - in_sr: int, - steps: int, - sampler_kwargs: dict = {}, - ): - """ - Generate stereo audio from mono audio using a pre-trained diffusion prior - - Args: - audio: The mono audio to convert to stereo - in_sr: The sample rate of the input audio - steps: The number of diffusion steps to run - sampler_kwargs: Keyword arguments to pass to the diffusion sampler - """ - - device = audio.device - - sample_rate = self.sample_rate - - # Resample input audio if necessary - if in_sr != sample_rate: - resample_tf = T.Resample(in_sr, sample_rate).to(audio.device) - audio = resample_tf(audio) - - audio_length = audio.shape[-1] - - # # Pad input audio to be compatible with the model - # min_length = self.min_input_length - # padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length - - # # Pad input audio to be compatible with the model - # if padded_input_length > audio_length: - # audio = F.pad(audio, (0, padded_input_length - audio_length)) - - # Make audio mono, duplicate to stereo - dual_mono = audio.mean(1, keepdim=True).repeat(1, 2, 1) - - if self.pretransform is not None: - dual_mono = self.pretransform.encode(dual_mono) - - conditioning = self.conditioner([{'video':video}], device) - # Return fake stereo audio - conditioning["source"] = [dual_mono] - stereo_audio = generate_diffusion_cond( - self, - conditioning_tensors=conditioning, - steps=steps, - sample_size=audio_length, - sample_rate=sample_rate, - device=device, - cfg_scale=1, - **sampler_kwargs, - ) - - return stereo_audio \ No newline at end of file diff --git a/think_sound/models/discriminators.py b/think_sound/models/discriminators.py deleted file mode 100644 index b593168df965bb1f57881ea79edbc2f66478c6c2..0000000000000000000000000000000000000000 --- a/think_sound/models/discriminators.py +++ /dev/null @@ -1,546 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from functools import reduce -import typing as tp -from einops import rearrange -from audiotools import AudioSignal, STFTParams -from dac.model.discriminator import WNConv1d, WNConv2d - -def get_hinge_losses(score_real, score_fake): - gen_loss = -score_fake.mean() - dis_loss = torch.relu(1 - score_real).mean() + torch.relu(1 + score_fake).mean() - return dis_loss, gen_loss - -class EncodecDiscriminator(nn.Module): - - def __init__(self, *args, **kwargs): - super().__init__() - - from encodec.msstftd import MultiScaleSTFTDiscriminator - - self.discriminators = MultiScaleSTFTDiscriminator(*args, **kwargs) - - def forward(self, x): - logits, features = self.discriminators(x) - return logits, features - - def loss(self, x, y): - feature_matching_distance = 0. - logits_true, feature_true = self.forward(x) - logits_fake, feature_fake = self.forward(y) - - dis_loss = torch.tensor(0.) - adv_loss = torch.tensor(0.) - - for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)): - - feature_matching_distance = feature_matching_distance + sum( - map( - lambda x, y: abs(x - y).mean(), - scale_true, - scale_fake, - )) / len(scale_true) - - _dis, _adv = get_hinge_losses( - logits_true[i], - logits_fake[i], - ) - - dis_loss = dis_loss + _dis - adv_loss = adv_loss + _adv - - return dis_loss, adv_loss, feature_matching_distance - -# Discriminators from oobleck - -IndividualDiscriminatorOut = tp.Tuple[torch.Tensor, tp.Sequence[torch.Tensor]] - -TensorDict = tp.Dict[str, torch.Tensor] - -class SharedDiscriminatorConvNet(nn.Module): - - def __init__( - self, - in_size: int, - convolution: tp.Union[nn.Conv1d, nn.Conv2d], - out_size: int = 1, - capacity: int = 32, - n_layers: int = 4, - kernel_size: int = 15, - stride: int = 4, - activation: tp.Callable[[], nn.Module] = lambda: nn.SiLU(), - normalization: tp.Callable[[nn.Module], nn.Module] = torch.nn.utils.weight_norm, - ) -> None: - super().__init__() - channels = [in_size] - channels += list(capacity * 2**np.arange(n_layers)) - - if isinstance(stride, int): - stride = n_layers * [stride] - - net = [] - for i in range(n_layers): - if isinstance(kernel_size, int): - pad = kernel_size // 2 - s = stride[i] - else: - pad = kernel_size[0] // 2 - s = (stride[i], 1) - - net.append( - normalization( - convolution( - channels[i], - channels[i + 1], - kernel_size, - stride=s, - padding=pad, - ))) - net.append(activation()) - - net.append(convolution(channels[-1], out_size, 1)) - - self.net = nn.ModuleList(net) - - def forward(self, x) -> IndividualDiscriminatorOut: - features = [] - for layer in self.net: - x = layer(x) - if isinstance(layer, nn.modules.conv._ConvNd): - features.append(x) - score = x.reshape(x.shape[0], -1).mean(-1) - return score, features - - -class MultiScaleDiscriminator(nn.Module): - - def __init__(self, - in_channels: int, - n_scales: int, - **conv_kwargs) -> None: - super().__init__() - layers = [] - for _ in range(n_scales): - layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv1d, **conv_kwargs)) - self.layers = nn.ModuleList(layers) - - def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut: - score = 0 - features = [] - for layer in self.layers: - s, f = layer(x) - score = score + s - features.extend(f) - x = nn.functional.avg_pool1d(x, 2) - return score, features - -class MultiPeriodDiscriminator(nn.Module): - - def __init__(self, - in_channels: int, - periods: tp.Sequence[int], - **conv_kwargs) -> None: - super().__init__() - layers = [] - self.periods = periods - - for _ in periods: - layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv2d, **conv_kwargs)) - - self.layers = nn.ModuleList(layers) - - def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut: - score = 0 - features = [] - for layer, n in zip(self.layers, self.periods): - s, f = layer(self.fold(x, n)) - score = score + s - features.extend(f) - return score, features - - def fold(self, x: torch.Tensor, n: int) -> torch.Tensor: - pad = (n - (x.shape[-1] % n)) % n - x = nn.functional.pad(x, (0, pad)) - return x.reshape(*x.shape[:2], -1, n) - - -class MultiDiscriminator(nn.Module): - """ - Individual discriminators should take a single tensor as input (NxB C T) and - return a tuple composed of a score tensor (NxB) and a Sequence of Features - Sequence[NxB C' T']. - """ - - def __init__(self, discriminator_list: tp.Sequence[nn.Module], - keys: tp.Sequence[str]) -> None: - super().__init__() - self.discriminators = nn.ModuleList(discriminator_list) - self.keys = keys - - def unpack_tensor_to_dict(self, features: torch.Tensor) -> TensorDict: - features = features.chunk(len(self.keys), 0) - return {k: features[i] for i, k in enumerate(self.keys)} - - @staticmethod - def concat_dicts(dict_a, dict_b): - out_dict = {} - keys = set(list(dict_a.keys()) + list(dict_b.keys())) - for k in keys: - out_dict[k] = [] - if k in dict_a: - if isinstance(dict_a[k], list): - out_dict[k].extend(dict_a[k]) - else: - out_dict[k].append(dict_a[k]) - if k in dict_b: - if isinstance(dict_b[k], list): - out_dict[k].extend(dict_b[k]) - else: - out_dict[k].append(dict_b[k]) - return out_dict - - @staticmethod - def sum_dicts(dict_a, dict_b): - out_dict = {} - keys = set(list(dict_a.keys()) + list(dict_b.keys())) - for k in keys: - out_dict[k] = 0. - if k in dict_a: - out_dict[k] = out_dict[k] + dict_a[k] - if k in dict_b: - out_dict[k] = out_dict[k] + dict_b[k] - return out_dict - - def forward(self, inputs: TensorDict) -> TensorDict: - discriminator_input = torch.cat([inputs[k] for k in self.keys], 0) - all_scores = [] - all_features = [] - - for discriminator in self.discriminators: - score, features = discriminator(discriminator_input) - scores = self.unpack_tensor_to_dict(score) - scores = {f"score_{k}": scores[k] for k in scores.keys()} - all_scores.append(scores) - - features = map(self.unpack_tensor_to_dict, features) - features = reduce(self.concat_dicts, features) - features = {f"features_{k}": features[k] for k in features.keys()} - all_features.append(features) - - all_scores = reduce(self.sum_dicts, all_scores) - all_features = reduce(self.concat_dicts, all_features) - - inputs.update(all_scores) - inputs.update(all_features) - - return inputs - -class OobleckDiscriminator(nn.Module): - - def __init__( - self, - in_channels=1, - ): - super().__init__() - - multi_scale_discriminator = MultiScaleDiscriminator( - in_channels=in_channels, - n_scales=3, - ) - - multi_period_discriminator = MultiPeriodDiscriminator( - in_channels=in_channels, - periods=[2, 3, 5, 7, 11] - ) - - # multi_resolution_discriminator = MultiScaleSTFTDiscriminator( - # filters=32, - # in_channels = in_channels, - # out_channels = 1, - # n_ffts = [2048, 1024, 512, 256, 128], - # hop_lengths = [512, 256, 128, 64, 32], - # win_lengths = [2048, 1024, 512, 256, 128] - # ) - - self.multi_discriminator = MultiDiscriminator( - [multi_scale_discriminator, multi_period_discriminator], #, multi_resolution_discriminator], - ["reals", "fakes"] - ) - - def loss(self, reals, fakes): - inputs = { - "reals": reals, - "fakes": fakes, - } - - inputs = self.multi_discriminator(inputs) - - scores_real = inputs["score_reals"] - scores_fake = inputs["score_fakes"] - - features_real = inputs["features_reals"] - features_fake = inputs["features_fakes"] - - dis_loss, gen_loss = get_hinge_losses(scores_real, scores_fake) - - feature_matching_distance = torch.tensor(0.) - - for _, (scale_real, scale_fake) in enumerate(zip(features_real, features_fake)): - - feature_matching_distance = feature_matching_distance + sum( - map( - lambda real, fake: abs(real - fake).mean(), - scale_real, - scale_fake, - )) / len(scale_real) - - return dis_loss, gen_loss, feature_matching_distance - - -## Discriminators from Descript Audio Codec repo -## Copied and modified under MIT license, see LICENSES/LICENSE_DESCRIPT.txt -class MPD(nn.Module): - def __init__(self, period, channels=1): - super().__init__() - - self.period = period - self.convs = nn.ModuleList( - [ - WNConv2d(channels, 32, (5, 1), (3, 1), padding=(2, 0)), - WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), - WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), - WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), - WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), - ] - ) - self.conv_post = WNConv2d( - 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False - ) - - def pad_to_period(self, x): - t = x.shape[-1] - x = F.pad(x, (0, self.period - t % self.period), mode="reflect") - return x - - def forward(self, x): - fmap = [] - - x = self.pad_to_period(x) - x = rearrange(x, "b c (l p) -> b c l p", p=self.period) - - for layer in self.convs: - x = layer(x) - fmap.append(x) - - x = self.conv_post(x) - fmap.append(x) - - return fmap - - -class MSD(nn.Module): - def __init__(self, rate: int = 1, sample_rate: int = 44100, channels=1): - super().__init__() - - self.convs = nn.ModuleList( - [ - WNConv1d(channels, 16, 15, 1, padding=7), - WNConv1d(16, 64, 41, 4, groups=4, padding=20), - WNConv1d(64, 256, 41, 4, groups=16, padding=20), - WNConv1d(256, 1024, 41, 4, groups=64, padding=20), - WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), - WNConv1d(1024, 1024, 5, 1, padding=2), - ] - ) - self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) - self.sample_rate = sample_rate - self.rate = rate - - def forward(self, x): - x = AudioSignal(x, self.sample_rate) - x.resample(self.sample_rate // self.rate) - x = x.audio_data - - fmap = [] - - for l in self.convs: - x = l(x) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - - return fmap - - -BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] - - -class MRD(nn.Module): - def __init__( - self, - window_length: int, - hop_factor: float = 0.25, - sample_rate: int = 44100, - bands: list = BANDS, - channels: int = 1 - ): - """Complex multi-band spectrogram discriminator. - Parameters - ---------- - window_length : int - Window length of STFT. - hop_factor : float, optional - Hop factor of the STFT, defaults to ``0.25 * window_length``. - sample_rate : int, optional - Sampling rate of audio in Hz, by default 44100 - bands : list, optional - Bands to run discriminator over. - """ - super().__init__() - - self.window_length = window_length - self.hop_factor = hop_factor - self.sample_rate = sample_rate - self.stft_params = STFTParams( - window_length=window_length, - hop_length=int(window_length * hop_factor), - match_stride=True, - ) - - self.channels = channels - - n_fft = window_length // 2 + 1 - bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] - self.bands = bands - - ch = 32 - convs = lambda: nn.ModuleList( - [ - WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), - WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), - WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), - WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), - WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), - ] - ) - self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) - self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) - - def spectrogram(self, x): - x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) - x = torch.view_as_real(x.stft()) - x = rearrange(x, "b ch f t c -> (b ch) c t f", ch=self.channels) - # Split into bands - x_bands = [x[..., b[0] : b[1]] for b in self.bands] - return x_bands - - def forward(self, x): - x_bands = self.spectrogram(x) - fmap = [] - - x = [] - for band, stack in zip(x_bands, self.band_convs): - for layer in stack: - band = layer(band) - fmap.append(band) - x.append(band) - - x = torch.cat(x, dim=-1) - x = self.conv_post(x) - fmap.append(x) - - return fmap - - -class DACDiscriminator(nn.Module): - def __init__( - self, - channels: int = 1, - rates: list = [], - periods: list = [2, 3, 5, 7, 11], - fft_sizes: list = [2048, 1024, 512], - sample_rate: int = 44100, - bands: list = BANDS, - ): - """Discriminator that combines multiple discriminators. - - Parameters - ---------- - rates : list, optional - sampling rates (in Hz) to run MSD at, by default [] - If empty, MSD is not used. - periods : list, optional - periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] - fft_sizes : list, optional - Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] - sample_rate : int, optional - Sampling rate of audio in Hz, by default 44100 - bands : list, optional - Bands to run MRD at, by default `BANDS` - """ - super().__init__() - discs = [] - discs += [MPD(p, channels=channels) for p in periods] - discs += [MSD(r, sample_rate=sample_rate, channels=channels) for r in rates] - discs += [MRD(f, sample_rate=sample_rate, bands=bands, channels=channels) for f in fft_sizes] - self.discriminators = nn.ModuleList(discs) - - def preprocess(self, y): - # Remove DC offset - y = y - y.mean(dim=-1, keepdims=True) - # Peak normalize the volume of input audio - y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) - return y - - def forward(self, x): - x = self.preprocess(x) - fmaps = [d(x) for d in self.discriminators] - return fmaps - -class DACGANLoss(nn.Module): - """ - Computes a discriminator loss, given a discriminator on - generated waveforms/spectrograms compared to ground truth - waveforms/spectrograms. Computes the loss for both the - discriminator and the generator in separate functions. - """ - - def __init__(self, **discriminator_kwargs): - super().__init__() - self.discriminator = DACDiscriminator(**discriminator_kwargs) - - def forward(self, fake, real): - d_fake = self.discriminator(fake) - d_real = self.discriminator(real) - return d_fake, d_real - - def discriminator_loss(self, fake, real): - d_fake, d_real = self.forward(fake.clone().detach(), real) - - loss_d = 0 - for x_fake, x_real in zip(d_fake, d_real): - loss_d += torch.mean(x_fake[-1] ** 2) - loss_d += torch.mean((1 - x_real[-1]) ** 2) - return loss_d - - def generator_loss(self, fake, real): - d_fake, d_real = self.forward(fake, real) - - loss_g = 0 - for x_fake in d_fake: - loss_g += torch.mean((1 - x_fake[-1]) ** 2) - - loss_feature = 0 - - for i in range(len(d_fake)): - for j in range(len(d_fake[i]) - 1): - loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) - return loss_g, loss_feature - - def loss(self, fake, real): - gen_loss, feature_distance = self.generator_loss(fake, real) - dis_loss = self.discriminator_loss(fake, real) - - return dis_loss, gen_loss, feature_distance \ No newline at end of file diff --git a/think_sound/models/lm.py b/think_sound/models/lm.py deleted file mode 100644 index 1897fa72ab716f69e0c6d71236e47cc50f78592e..0000000000000000000000000000000000000000 --- a/think_sound/models/lm.py +++ /dev/null @@ -1,541 +0,0 @@ -from dataclasses import dataclass -import torch -from tqdm.auto import trange -import typing as tp -from einops import rearrange -from torch import nn - -from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config -from .factory import create_pretransform_from_config -from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone -from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform -from .utils import multinomial, sample_top_k, sample_top_p - -from .codebook_patterns import ( - CodebooksPatternProvider, - DelayedPatternProvider, - MusicLMPattern, - ParallelPatternProvider, - UnrolledPatternProvider -) - -# Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license -# License can be found in LICENSES/LICENSE_META.txt - -@dataclass -class LMOutput: - # The logits are already re-aligned with the input codes - # hence no extra shift is required, e.g. when computing CE - logits: torch.Tensor # [B, K, T, card] - mask: torch.Tensor # [B, K, T] - -# Wrapper for a multi-codebook language model -# Handles patterns and quantizer heads -class AudioLanguageModel(nn.Module): - def __init__( - self, - pattern_provider: CodebooksPatternProvider, - backbone: AudioLMBackbone, - num_quantizers: int, - codebook_size: int - ): - super().__init__() - - self.pattern_provider = pattern_provider - self.backbone = backbone - self.num_quantizers = num_quantizers - self.codebook_size = codebook_size - - self.masked_token_id = codebook_size - - # Per-quantizer embedders - # Add one for the mask embed - self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)]) - - # Per-quantizer output heads - self.quantizer_heads = nn.ModuleList([ - nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers) - ]) - - def forward(self, - sequence: torch.Tensor, #[batch, seq_len, - prepend_cond=None, #[batch, seq, channels] - prepend_cond_mask=None, - cross_attn_cond=None, #[batch, seq, channels], - **kwargs - ): - - batch, num_quantizers, seq_len = sequence.shape - - assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model" - - backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)]) # [batch, seq_len, embed_dim] - - dtype = next(self.parameters()).dtype - - if cross_attn_cond is not None: - cross_attn_cond = cross_attn_cond.to(dtype) - - if prepend_cond is not None: - prepend_cond = prepend_cond.to(dtype) - - if prepend_cond_mask is not None: - prepend_cond_mask = prepend_cond_mask.to(dtype) - - backbone_input = backbone_input.to(dtype) - - output = self.backbone( - backbone_input, - cross_attn_cond=cross_attn_cond, - prepend_cond=prepend_cond, - prepend_cond_mask=prepend_cond_mask, - **kwargs - ) # [batch, seq_len, embed_dim] - - # Run output through quantizer heads - logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1) # [batch, num_quantizers, seq_len, codebook_size] - - return logits - - def compute_logits( - self, - codes, #[batch, num_quantizers, seq_len] - **kwargs): - """ - Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning - Handles translation between input sequence and pattern-shifted sequence - Only used during training - """ - - batch, _, seq_len = codes.shape - - pattern = self.pattern_provider.get_pattern(seq_len) - - # Apply the token pattern to the codes, shifting the codes as needed and masking out invalid steps - shifted_codes, _, _ = pattern.build_pattern_sequence( - codes, - self.masked_token_id, - keep_only_valid_steps=True - ) - - # Run the model to get logits for each quantizer [batch, num_quantizers, seq_len, codebook_size] - logits = self(shifted_codes, **kwargs) - - # Rearrange logits to prepare to revert pattern - logits = rearrange(logits, "b n s c -> b c n s") - - # Revert sequence logits back to original sequence length, removing masked steps - logits, _, logits_mask = pattern.revert_pattern_logits( - logits, float('nan'), keep_only_valid_steps=True - ) - - logits = rearrange(logits, "b c n t -> b n t c") - - logits_mask = logits_mask[None, :, :].expand(batch, -1, -1) # [batch, num_quantizers, seq_len] - - return LMOutput(logits=logits, mask=logits_mask) - -# Conditioning and generation wrapper for a multi-codebook language model -# Handles conditioning, CFG, generation, and encoding/decoding -class AudioLanguageModelWrapper(nn.Module): - def __init__( - self, - pretransform: Pretransform, - lm: AudioLanguageModel, - sample_rate: int, - min_input_length: int, - conditioner: MultiConditioner = None, - cross_attn_cond_ids: tp.List[str] = [], - prepend_cond_ids: tp.List[str] = [], - global_cond_ids: tp.List[str] = [] - ): - super().__init__() - - assert pretransform.is_discrete, "Pretransform must be discrete" - self.pretransform = pretransform - - self.pretransform.requires_grad_(False) - self.pretransform.eval() - - if isinstance(self.pretransform, AutoencoderPretransform): - self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers - self.codebook_size = self.pretransform.model.bottleneck.codebook_size - elif isinstance(self.pretransform, PretrainedDACPretransform): - self.num_quantizers = self.pretransform.model.num_quantizers - self.codebook_size = self.pretransform.model.codebook_size - elif isinstance(self.pretransform, AudiocraftCompressionPretransform): - self.num_quantizers = self.pretransform.num_quantizers - self.codebook_size = self.pretransform.codebook_size - else: - raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}") - - self.conditioner = conditioner - - self.lm = lm - - self.sample_rate = sample_rate - self.min_input_length = min_input_length - - self.cross_attn_cond_ids = cross_attn_cond_ids - self.prepend_cond_ids = prepend_cond_ids - self.global_cond_ids = global_cond_ids - - def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False): - cross_attention_input = None - prepend_cond = None - prepend_cond_mask = None - global_cond = None - - if len(self.cross_attn_cond_ids) > 0: - # Concatenate all cross-attention inputs over the sequence dimension - # Assumes that the cross-attention inputs are of shape (batch, seq, channels) - cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1) - - if len(self.prepend_cond_ids) > 0: - # Concatenate all prepend conditioning inputs over the sequence dimension - # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels) - prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1) - prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1) - - if len(self.global_cond_ids) > 0: - # Concatenate all global conditioning inputs over the channel dimension - # Assumes that the global conditioning inputs are of shape (batch, channels) - global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1) - if len(global_cond.shape) == 3: - global_cond = global_cond.squeeze(1) - - if negative: - return { - "negative_cross_attn_cond": cross_attention_input, - "negative_prepend_cond": prepend_cond, - "negative_prepend_cond_mask": prepend_cond_mask, - "negative_global_cond": global_cond - } - else: - return { - "cross_attn_cond": cross_attention_input, - "prepend_cond": prepend_cond, - "prepend_cond_mask": prepend_cond_mask, - "global_cond": global_cond - } - - def compute_logits( - self, - codes, - condition_tensors=None, - cfg_dropout_prob=0.0, - **kwargs - ): - """ - Compute logits for a batch of codes, and translates from conditioning inputs to model inputs - Handles CFG dropout - """ - - if condition_tensors is None: - condition_tensors = {} - - conditioning_inputs = self.get_conditioning_inputs(condition_tensors) - - cross_attn_cond = conditioning_inputs["cross_attn_cond"] - prepend_cond = conditioning_inputs["prepend_cond"] - prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] - global_cond = conditioning_inputs["global_cond"] - - if cfg_dropout_prob > 0.0: - if cross_attn_cond is not None: - null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) - dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) - cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) - - if prepend_cond is not None: - null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) - dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) - prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) - - if global_cond is not None: - null_embed = torch.zeros_like(global_cond, device=global_cond.device) - dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool) - global_cond = torch.where(dropout_mask, null_embed, global_cond) - - return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) - - def _sample_next_token( - self, - sequence, #[batch, num_quantizers, seq_len] - conditioning_tensors=None, - cross_attn_use_cfg=True, - prepend_use_cfg=True, - global_use_cfg=True, - cfg_scale=1.0, - top_k=250, - top_p=0.0, - temp=1.0, - **kwargs - ): - """ - Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs - Handles CFG inference - """ - - if conditioning_tensors is None: - conditioning_tensors = {} - - conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors) - - cross_attn_cond = conditioning_inputs["cross_attn_cond"] - prepend_cond = conditioning_inputs["prepend_cond"] - prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] - global_cond = conditioning_inputs["global_cond"] - - if cfg_scale != 1.0: - - # Batch size is doubled to account for negative samples - sequence = torch.cat([sequence, sequence], dim=0) - - if cross_attn_cond is not None and cross_attn_use_cfg: - null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) - - cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0) - - if prepend_cond is not None and prepend_use_cfg: - null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) - - prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) - - if prepend_cond_mask is not None: - prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) - - if global_cond is not None and global_use_cfg: - null_embed = torch.zeros_like(global_cond, device=global_cond.device) - - global_cond = torch.cat([global_cond, null_embed], dim=0) - - logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) - - if cfg_scale != 1.0: - cond_logits, uncond_logits = logits.chunk(2, dim=0) - - logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale - - logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len] - - # Grab the logits for the last step - logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size] - - # Apply top-k or top-p sampling - - if temp > 0: - probs = torch.softmax(logits / temp, dim=-1) - - if top_p > 0.0: - next_token = sample_top_p(probs, p=top_p) - elif top_k > 0: - next_token = sample_top_k(probs, k=top_k) - else: - next_token = multinomial(probs, num_samples=1) - - else: - next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1] - - return next_token - - @torch.no_grad() - def generate( - self, - max_gen_len: int = 256, - batch_size: tp.Optional[int] = None, - init_data: tp.Optional[torch.Tensor] = None, - conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, - conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None, - callback: tp.Optional[tp.Callable[[int, int], None]] = None, - use_cache: bool = True, - cfg_scale: float = 1.0, - **kwargs - ): - device = next(self.parameters()).device - - if conditioning_tensors is None and conditioning is not None: - # Convert conditioning inputs to conditioning tensors - conditioning_tensors = self.conditioner(conditioning, device) - - # Check that batch size is consistent across inputs - possible_batch_sizes = [] - - if batch_size is not None: - possible_batch_sizes.append(batch_size) - elif init_data is not None: - possible_batch_sizes.append(init_data.shape[0]) - elif conditioning_tensors is not None: - # Assume that the first conditioning tensor has the batch dimension - possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0]) - else: - possible_batch_sizes.append(1) - - assert [x == possible_batch_sizes[0] for x in possible_batch_sizes], "Batch size must be consistent across inputs" - - batch_size = possible_batch_sizes[0] - - if init_data is None: - # Initialize with zeros - assert batch_size > 0 - init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long) - - batch_size, num_quantizers, seq_len = init_data.shape - - start_offset = seq_len - assert start_offset < max_gen_len, "init data longer than max gen length" - - pattern = self.lm.pattern_provider.get_pattern(max_gen_len) - - unknown_token = -1 - - # Initialize the generated codes with the init data, padded with unknown tokens - gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long) - gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len] - - gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len] - - start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) - assert start_offset_sequence is not None - - # Generation - prev_offset = 0 - gen_sequence_len = gen_sequence.shape[-1] - - # Reset generation cache - if use_cache and self.lm.backbone.use_generation_cache: - self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2) - - for offset in trange(start_offset_sequence, gen_sequence_len): - - # Get the full sequence up to the current offset - curr_sequence = gen_sequence[..., prev_offset:offset] - - next_token = self._sample_next_token( - curr_sequence, - conditioning_tensors=conditioning_tensors, - use_cache=use_cache, - cfg_scale=cfg_scale, - **kwargs - ) - - valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1) - next_token[~valid_mask] = self.lm.masked_token_id - - # Update the generated sequence with the next token - gen_sequence[..., offset:offset+1] = torch.where( - gen_sequence[..., offset:offset+1] == unknown_token, - next_token, - gen_sequence[..., offset:offset+1] - ) - - if use_cache and self.lm.backbone.use_generation_cache: - # Only update the offset if caching is being used - prev_offset = offset - - self.lm.backbone.update_generation_cache(offset) - - if callback is not None: - # Callback to report progress - # Pass in the offset relative to the start of the sequence, and the length of the current sequence - callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) - - assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence" - - out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) - - # sanity checks over the returned codes and corresponding masks - assert (out_codes[..., :max_gen_len] != unknown_token).all() - assert (out_mask[..., :max_gen_len] == 1).all() - - #out_codes = out_codes[..., 0:max_gen_len] - - return out_codes - - - def generate_audio( - self, - **kwargs - ): - """ - Generate audio from a batch of codes - """ - - codes = self.generate(**kwargs) - - audio = self.pretransform.decode_tokens(codes) - - return audio - - -def create_audio_lm_from_config(config): - model_config = config.get('model', None) - assert model_config is not None, 'model config must be specified in config' - - sample_rate = config.get('sample_rate', None) - assert sample_rate is not None, "Must specify sample_rate in config" - - lm_config = model_config.get('lm', None) - assert lm_config is not None, 'lm config must be specified in model config' - - codebook_pattern = lm_config.get("codebook_pattern", "delay") - - pattern_providers = { - 'parallel': ParallelPatternProvider, - 'delay': DelayedPatternProvider, - 'unroll': UnrolledPatternProvider, - 'musiclm': MusicLMPattern, - } - - pretransform_config = model_config.get("pretransform", None) - - pretransform = create_pretransform_from_config(pretransform_config, sample_rate) - - assert pretransform.is_discrete, "Pretransform must be discrete" - - min_input_length = pretransform.downsampling_ratio - - pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers) - - conditioning_config = model_config.get('conditioning', None) - - conditioner = None - if conditioning_config is not None: - conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) - - cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', []) - prepend_cond_ids = lm_config.get('prepend_cond_ids', []) - global_cond_ids = lm_config.get('global_cond_ids', []) - - lm_type = lm_config.get("type", None) - lm_model_config = lm_config.get("config", None) - - assert lm_type is not None, "Must specify lm type in lm config" - assert lm_model_config is not None, "Must specify lm model config in lm config" - - if lm_type == "x-transformers": - backbone = XTransformersAudioLMBackbone(**lm_model_config) - elif lm_type == "continuous_transformer": - backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config) - else: - raise NotImplementedError(f"Unrecognized lm type {lm_type}") - - lm = AudioLanguageModel( - pattern_provider=pattern_provider, - backbone=backbone, - num_quantizers=pretransform.num_quantizers, - codebook_size=pretransform.codebook_size - ) - - model = AudioLanguageModelWrapper( - pretransform=pretransform, - lm=lm, - conditioner=conditioner, - sample_rate=sample_rate, - min_input_length=min_input_length, - cross_attn_cond_ids=cross_attn_cond_ids, - prepend_cond_ids=prepend_cond_ids, - global_cond_ids=global_cond_ids - ) - - return model \ No newline at end of file diff --git a/think_sound/models/lm_backbone.py b/think_sound/models/lm_backbone.py deleted file mode 100644 index c80cce60b06d9b367b114188444b0890a1990b61..0000000000000000000000000000000000000000 --- a/think_sound/models/lm_backbone.py +++ /dev/null @@ -1,159 +0,0 @@ -from torch import nn -from x_transformers import ContinuousTransformerWrapper, Decoder - -from .transformer import ContinuousTransformer - -# Interface for backbone of a language model -# Handles conditioning and cross-attention -# Does not have to deal with patterns or quantizer heads -class AudioLMBackbone(nn.Module): - def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs): - super().__init__() - - self.embed_dim = embed_dim - self.use_generation_cache = use_generation_cache - - def forward( - self, - x, - cross_attn_cond=None, - prepend_cond=None, - prepend_cond_mask=None, - global_cond=None, - use_cache=False, - **kwargs - ): - raise NotImplementedError - - def reset_generation_cache( - self, - max_seq_len, - batch_size, - dtype=None - ): - pass - - def update_generation_cache( - self, - seqlen_offset - ): - pass - -class XTransformersAudioLMBackbone(AudioLMBackbone): - def __init__(self, - embed_dim: int, - cross_attn_cond_dim: int = 0, - prepend_cond_dim: int = 0, - **kwargs): - super().__init__(embed_dim=embed_dim) - - # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer - self.model = ContinuousTransformerWrapper( - dim_in=embed_dim, - dim_out=embed_dim, - max_seq_len=0, #Not relevant without absolute positional embeds, - attn_layers=Decoder( - dim=embed_dim, - attn_flash = True, - cross_attend = cross_attn_cond_dim > 0, - zero_init_branch_output=True, - use_abs_pos_emb = False, - rotary_pos_emb=True, - ff_swish = True, - ff_glu = True, - **kwargs - ) - ) - - if prepend_cond_dim > 0: - # Prepend conditioning - self.to_prepend_embed = nn.Sequential( - nn.Linear(prepend_cond_dim, embed_dim, bias=False), - nn.SiLU(), - nn.Linear(embed_dim, embed_dim, bias=False) - ) - - if cross_attn_cond_dim > 0: - # Cross-attention conditioning - self.to_cross_attn_embed = nn.Sequential( - nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), - nn.SiLU(), - nn.Linear(embed_dim, embed_dim, bias=False) - ) - - def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): - - prepend_length = 0 - if prepend_cond is not None: - # Project the prepend conditioning to the embedding dimension - prepend_cond = self.to_prepend_embed(prepend_cond) - prepend_length = prepend_cond.shape[1] - - if prepend_cond_mask is not None: - # Cast mask to bool - prepend_cond_mask = prepend_cond_mask.bool() - - if cross_attn_cond is not None: - # Project the cross-attention conditioning to the embedding dimension - cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) - - return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] - -class ContinuousTransformerAudioLMBackbone(AudioLMBackbone): - def __init__(self, - embed_dim: int, - cross_attn_cond_dim: int = 0, - prepend_cond_dim: int = 0, - project_cross_attn_cond: bool = False, - **kwargs): - super().__init__(embed_dim=embed_dim) - - # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer - self.model = ContinuousTransformer( - dim=embed_dim, - dim_in=embed_dim, - dim_out=embed_dim, - cross_attend = cross_attn_cond_dim > 0, - cond_token_dim = embed_dim if project_cross_attn_cond else cross_attn_cond_dim, - causal=True, - **kwargs - ) - - if prepend_cond_dim > 0: - # Prepend conditioning - self.to_prepend_embed = nn.Sequential( - nn.Linear(prepend_cond_dim, embed_dim, bias=False), - nn.SiLU(), - nn.Linear(embed_dim, embed_dim, bias=False) - ) - - if cross_attn_cond_dim > 0 and project_cross_attn_cond: - # Cross-attention conditioning - self.to_cross_attn_embed = nn.Sequential( - nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), - nn.SiLU(), - nn.Linear(embed_dim, embed_dim, bias=False) - ) - else: - self.to_cross_attn_embed = nn.Identity() - - def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): - - prepend_length = 0 - if prepend_cond is not None: - # Project the prepend conditioning to the embedding dimension - prepend_cond = self.to_prepend_embed(prepend_cond) - prepend_length = prepend_cond.shape[1] - - if prepend_cond_mask is not None: - # Cast mask to bool - prepend_cond_mask = prepend_cond_mask.bool() - - if cross_attn_cond is not None: - # Cast cross_attn_cond to same dtype as self.to_cross_attn_embed - cross_attn_cond = cross_attn_cond.to(self.to_cross_attn_embed[0].weight.dtype) - - # Project the cross-attention conditioning to the embedding dimension - cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) - - return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] \ No newline at end of file diff --git a/think_sound/models/lm_continuous.py b/think_sound/models/lm_continuous.py deleted file mode 100644 index 469bb49f32492794345cf76dafbb377778eca81e..0000000000000000000000000000000000000000 --- a/think_sound/models/lm_continuous.py +++ /dev/null @@ -1,525 +0,0 @@ -from dataclasses import dataclass -import torch -from tqdm.auto import trange -import typing as tp -from einops import rearrange -from torch import nn - -from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config -from .factory import create_pretransform_from_config -from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone -from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform -from .utils import multinomial, sample_top_k, sample_top_p -from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper, create_diffusion_cond_from_config - -from .codebook_patterns import ( - CodebooksPatternProvider, - DelayedPatternProvider, - MusicLMPattern, - ParallelPatternProvider, - UnrolledPatternProvider -) - -# Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license -# License can be found in LICENSES/LICENSE_META.txt - -@dataclass -class LMContinuousOutput: - # The logits are already re-aligned with the input codes - # hence no extra shift is required, e.g. when computing CE - logits: torch.Tensor # [B, K, T, card] - mask: torch.Tensor # [B, K, T] - -# Wrapper for a multi-codebook language model -# Handles patterns and quantizer heads -class AudioLMContinuousModel(nn.Module): - def __init__( - self, - backbone: AudioLMBackbone, - ): - super().__init__() - - self.backbone = backbone - - def sample_orders(self, bsz): - # generate a batch of random generation orders - orders = [] - for _ in range(bsz): - order = np.array(list(range(self.seq_len))) - np.random.shuffle(order) - orders.append(order) - orders = torch.Tensor(np.array(orders)).cuda().long() - return orders - - def random_masking(self, x, orders): - # generate token mask - bsz, seq_len, embed_dim = x.shape - mask_rate = self.mask_ratio_generator.rvs(1)[0] - num_masked_tokens = int(np.ceil(seq_len * mask_rate)) - mask = torch.zeros(bsz, seq_len, device=x.device) - mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens], - src=torch.ones(bsz, seq_len, device=x.device)) - return mask - - def forward(self, - sequence: torch.Tensor, #[batch, seq_len, - prepend_cond=None, #[batch, seq, channels] - prepend_cond_mask=None, - cross_attn_cond=None, #[batch, seq, channels], - **kwargs - ): - - - batch, seq_len, dim = sequence.shape - - dtype = next(self.parameters()).dtype - - if cross_attn_cond is not None: - cross_attn_cond = cross_attn_cond.to(dtype) - - if prepend_cond is not None: - prepend_cond = prepend_cond.to(dtype) - - if prepend_cond_mask is not None: - prepend_cond_mask = prepend_cond_mask.to(dtype) - - x = sequence.to(dtype) - orders = self.sample_orders(bsz=batch) - mask = self.random_masking(x, orders) - - output = self.backbone( - x, - mask = mask, - cross_attn_cond=cross_attn_cond, - prepend_cond=prepend_cond, - prepend_cond_mask=prepend_cond_mask, - **kwargs - ) # [batch, seq_len, embed_dim] - - - return output - -# Conditioning and generation wrapper for a multi-codebook language model -# Handles conditioning, CFG, generation, and encoding/decoding -class AudioLanguageModelWrapper(nn.Module): - def __init__( - self, - pretransform: Pretransform, - lm: AudioLanguageModel, - diff: ConditionedDiffusionModelWrapper, - sample_rate: int, - min_input_length: int, - conditioner: MultiConditioner = None, - diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", - cross_attn_cond_ids: tp.List[str] = [], - prepend_cond_ids: tp.List[str] = [], - global_cond_ids: tp.List[str] = [] - ): - super().__init__() - - assert pretransform.is_discrete, "Pretransform must be discrete" - self.pretransform = pretransform - - self.pretransform.requires_grad_(False) - self.pretransform.eval() - self.diffusion_objective = diffusion_objective - print(f'Training in the {diffusion_objective} formulation') - if isinstance(self.pretransform, AutoencoderPretransform): - self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers - self.codebook_size = self.pretransform.model.bottleneck.codebook_size - elif isinstance(self.pretransform, PretrainedDACPretransform): - self.num_quantizers = self.pretransform.model.num_quantizers - self.codebook_size = self.pretransform.model.codebook_size - elif isinstance(self.pretransform, AudiocraftCompressionPretransform): - self.num_quantizers = self.pretransform.num_quantizers - self.codebook_size = self.pretransform.codebook_size - else: - raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}") - - self.conditioner = conditioner - - self.lm = lm - - self.sample_rate = sample_rate - self.min_input_length = min_input_length - - self.cross_attn_cond_ids = cross_attn_cond_ids - self.prepend_cond_ids = prepend_cond_ids - self.global_cond_ids = global_cond_ids - - def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False): - cross_attention_input = None - prepend_cond = None - prepend_cond_mask = None - global_cond = None - - if len(self.cross_attn_cond_ids) > 0: - # Concatenate all cross-attention inputs over the sequence dimension - # Assumes that the cross-attention inputs are of shape (batch, seq, channels) - cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1) - - if len(self.prepend_cond_ids) > 0: - # Concatenate all prepend conditioning inputs over the sequence dimension - # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels) - prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1) - prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1) - - if len(self.global_cond_ids) > 0: - # Concatenate all global conditioning inputs over the channel dimension - # Assumes that the global conditioning inputs are of shape (batch, channels) - global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1) - if len(global_cond.shape) == 3: - global_cond = global_cond.squeeze(1) - - if negative: - return { - "negative_cross_attn_cond": cross_attention_input, - "negative_prepend_cond": prepend_cond, - "negative_prepend_cond_mask": prepend_cond_mask, - "negative_global_cond": global_cond - } - else: - return { - "cross_attn_cond": cross_attention_input, - "prepend_cond": prepend_cond, - "prepend_cond_mask": prepend_cond_mask, - "global_cond": global_cond - } - - def compute_logits( - self, - audios, - condition_tensors=None, - cfg_dropout_prob=0.0, - **kwargs - ): - """ - Compute logits for a batch of codes, and translates from conditioning inputs to model inputs - Handles CFG dropout - """ - - if condition_tensors is None: - condition_tensors = {} - - conditioning_inputs = self.get_conditioning_inputs(condition_tensors) - - cross_attn_cond = conditioning_inputs["cross_attn_cond"] - prepend_cond = conditioning_inputs["prepend_cond"] - prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] - global_cond = conditioning_inputs["global_cond"] - - if cfg_dropout_prob > 0.0: - if cross_attn_cond is not None: - null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) - dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) - cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) - - if prepend_cond is not None: - null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) - dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) - prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) - - if global_cond is not None: - null_embed = torch.zeros_like(global_cond, device=global_cond.device) - dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool) - global_cond = torch.where(dropout_mask, null_embed, global_cond) - - return self.lm.forward(audios, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) - - def _sample_next_token( - self, - sequence, #[batch, num_quantizers, seq_len] - conditioning_tensors=None, - cross_attn_use_cfg=True, - prepend_use_cfg=True, - global_use_cfg=True, - cfg_scale=1.0, - top_k=250, - top_p=0.0, - temp=1.0, - **kwargs - ): - """ - Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs - Handles CFG inference - """ - - if conditioning_tensors is None: - conditioning_tensors = {} - - conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors) - - cross_attn_cond = conditioning_inputs["cross_attn_cond"] - prepend_cond = conditioning_inputs["prepend_cond"] - prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] - global_cond = conditioning_inputs["global_cond"] - - if cfg_scale != 1.0: - - # Batch size is doubled to account for negative samples - sequence = torch.cat([sequence, sequence], dim=0) - - if cross_attn_cond is not None and cross_attn_use_cfg: - null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) - - cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0) - - if prepend_cond is not None and prepend_use_cfg: - null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) - - prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) - - if prepend_cond_mask is not None: - prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) - - if global_cond is not None and global_use_cfg: - null_embed = torch.zeros_like(global_cond, device=global_cond.device) - - global_cond = torch.cat([global_cond, null_embed], dim=0) - - logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) - - if cfg_scale != 1.0: - cond_logits, uncond_logits = logits.chunk(2, dim=0) - - logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale - - logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len] - - # Grab the logits for the last step - logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size] - - # Apply top-k or top-p sampling - - if temp > 0: - probs = torch.softmax(logits / temp, dim=-1) - - if top_p > 0.0: - next_token = sample_top_p(probs, p=top_p) - elif top_k > 0: - next_token = sample_top_k(probs, k=top_k) - else: - next_token = multinomial(probs, num_samples=1) - - else: - next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1] - - return next_token - - @torch.no_grad() - def generate( - self, - max_gen_len: int = 256, - batch_size: tp.Optional[int] = None, - init_data: tp.Optional[torch.Tensor] = None, - conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, - conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None, - callback: tp.Optional[tp.Callable[[int, int], None]] = None, - use_cache: bool = True, - cfg_scale: float = 1.0, - **kwargs - ): - device = next(self.parameters()).device - - if conditioning_tensors is None and conditioning is not None: - # Convert conditioning inputs to conditioning tensors - conditioning_tensors = self.conditioner(conditioning, device) - - # Check that batch size is consistent across inputs - possible_batch_sizes = [] - - if batch_size is not None: - possible_batch_sizes.append(batch_size) - elif init_data is not None: - possible_batch_sizes.append(init_data.shape[0]) - elif conditioning_tensors is not None: - # Assume that the first conditioning tensor has the batch dimension - possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0]) - else: - possible_batch_sizes.append(1) - - assert [x == possible_batch_sizes[0] for x in possible_batch_sizes], "Batch size must be consistent across inputs" - - batch_size = possible_batch_sizes[0] - - if init_data is None: - # Initialize with zeros - assert batch_size > 0 - init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long) - - batch_size, num_quantizers, seq_len = init_data.shape - - start_offset = seq_len - assert start_offset < max_gen_len, "init data longer than max gen length" - - pattern = self.lm.pattern_provider.get_pattern(max_gen_len) - - unknown_token = -1 - - # Initialize the generated codes with the init data, padded with unknown tokens - gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long) - gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len] - - gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len] - - start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) - assert start_offset_sequence is not None - - # Generation - prev_offset = 0 - gen_sequence_len = gen_sequence.shape[-1] - - # Reset generation cache - if use_cache and self.lm.backbone.use_generation_cache: - self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2) - - for offset in trange(start_offset_sequence, gen_sequence_len): - - # Get the full sequence up to the current offset - curr_sequence = gen_sequence[..., prev_offset:offset] - - next_token = self._sample_next_token( - curr_sequence, - conditioning_tensors=conditioning_tensors, - use_cache=use_cache, - cfg_scale=cfg_scale, - **kwargs - ) - - valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1) - next_token[~valid_mask] = self.lm.masked_token_id - - # Update the generated sequence with the next token - gen_sequence[..., offset:offset+1] = torch.where( - gen_sequence[..., offset:offset+1] == unknown_token, - next_token, - gen_sequence[..., offset:offset+1] - ) - - if use_cache and self.lm.backbone.use_generation_cache: - # Only update the offset if caching is being used - prev_offset = offset - - self.lm.backbone.update_generation_cache(offset) - - if callback is not None: - # Callback to report progress - # Pass in the offset relative to the start of the sequence, and the length of the current sequence - callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) - - assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence" - - out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) - - # sanity checks over the returned codes and corresponding masks - assert (out_codes[..., :max_gen_len] != unknown_token).all() - assert (out_mask[..., :max_gen_len] == 1).all() - - #out_codes = out_codes[..., 0:max_gen_len] - - return out_codes - - - def generate_audio( - self, - **kwargs - ): - """ - Generate audio from a batch of codes - """ - - codes = self.generate(**kwargs) - - audio = self.pretransform.decode_tokens(codes) - - return audio - - -def create_audio_lm_continuous_from_config(config): - model_config = config.get('model', None) - assert model_config is not None, 'model config must be specified in config' - - sample_rate = config.get('sample_rate', None) - assert sample_rate is not None, "Must specify sample_rate in config" - - lm_config = model_config.get('lm', None) - assert lm_config is not None, 'lm config must be specified in model config' - - - - pretransform_config = model_config.get("pretransform", None) - - if pretransform is not None: - pretransform = create_pretransform_from_config(pretransform, sample_rate) - min_input_length = pretransform.downsampling_ratio - else: - min_input_length = 1 - - - conditioning_config = model_config.get('conditioning', None) - - conditioner = None - if conditioning_config is not None: - conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) - - cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', []) - prepend_cond_ids = lm_config.get('prepend_cond_ids', []) - global_cond_ids = lm_config.get('global_cond_ids', []) - - lm_type = lm_config.get("type", None) - lm_model_config = lm_config.get("config", None) - - assert lm_type is not None, "Must specify lm type in lm config" - assert lm_model_config is not None, "Must specify lm model config in lm config" - - if lm_type == "x-transformers": - backbone = XTransformersAudioLMBackbone(**lm_model_config) - elif lm_type == "continuous_transformer": - backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config) - else: - raise NotImplementedError(f"Unrecognized lm type {lm_type}") - - lm = AudioLanguageModel( - pattern_provider=pattern_provider, - backbone=backbone, - num_quantizers=pretransform.num_quantizers, - codebook_size=pretransform.codebook_size - ) - - diff_config = model_config.get("diffusion", None) - diffusion_model = DiTWrapper(**diff_config) - - cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', []) - add_cond_ids = diffusion_config.get('add_cond_ids', []) - global_cond_ids = diffusion_config.get('global_cond_ids', []) - input_concat_ids = diffusion_config.get('input_concat_ids', []) - prepend_cond_ids = diffusion_config.get('prepend_cond_ids', []) - - diff = ConditionedDiffusionModelWrapper( - diffusion_model, - conditioner=None, - min_input_length=min_input_length, - sample_rate=sample_rate, - cross_attn_cond_ids=cross_attention_ids, - global_cond_ids=global_cond_ids, - input_concat_ids=input_concat_ids, - prepend_cond_ids=prepend_cond_ids, - add_cond_ids=add_cond_ids, - pretransform=pretransform, - io_channels=2, - ) - - - model = AudioLanguageModelWrapper( - pretransform=pretransform, - lm=lm, - diff=diff, - conditioner=conditioner, - sample_rate=sample_rate, - min_input_length=min_input_length, - cross_attn_cond_ids=cross_attn_cond_ids, - prepend_cond_ids=prepend_cond_ids, - global_cond_ids=global_cond_ids - ) - - return model \ No newline at end of file diff --git a/think_sound/models/mmmodules/__init__.py b/think_sound/models/mmmodules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/think_sound/models/mmmodules/__pycache__/__init__.cpython-310.pyc b/think_sound/models/mmmodules/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index d9c6dc76969b971214d33f622aef5e6c05c78e23..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/__pycache__/__init__.cpython-39.pyc b/think_sound/models/mmmodules/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index f4bde87364e1f4176f489cc3015cbd1cdcf7446c..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/ext/__init__.py b/think_sound/models/mmmodules/ext/__init__.py deleted file mode 100644 index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000 --- a/think_sound/models/mmmodules/ext/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/think_sound/models/mmmodules/ext/__pycache__/__init__.cpython-310.pyc b/think_sound/models/mmmodules/ext/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index b85fdf9a28490bf9b437551f7a1634092204900e..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/ext/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/ext/__pycache__/__init__.cpython-39.pyc b/think_sound/models/mmmodules/ext/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 42fe0fad8c356a3962d193a463180a237f127c90..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/ext/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/ext/__pycache__/rotary_embeddings.cpython-310.pyc b/think_sound/models/mmmodules/ext/__pycache__/rotary_embeddings.cpython-310.pyc deleted file mode 100644 index ae3f4750b54aa26378a8a542c4d11355c1024790..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/ext/__pycache__/rotary_embeddings.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/ext/__pycache__/rotary_embeddings.cpython-39.pyc b/think_sound/models/mmmodules/ext/__pycache__/rotary_embeddings.cpython-39.pyc deleted file mode 100644 index 7a19ebe09f3893f90cb4f1ddfc74af285a8a1cb1..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/ext/__pycache__/rotary_embeddings.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/ext/rotary_embeddings.py b/think_sound/models/mmmodules/ext/rotary_embeddings.py deleted file mode 100644 index 1ea9d56278cb68b7577ed13148227c30ed98fd02..0000000000000000000000000000000000000000 --- a/think_sound/models/mmmodules/ext/rotary_embeddings.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Union - -import torch -from einops import rearrange -from torch import Tensor - -# Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py -# Ref: https://github.com/lucidrains/rotary-embedding-torch - - -def compute_rope_rotations(length: int, - dim: int, - theta: int, - *, - freq_scaling: float = 1.0, - device: Union[torch.device, str] = 'cpu') -> Tensor: - assert dim % 2 == 0 - - with torch.amp.autocast(device_type='cuda', enabled=False): - pos = torch.arange(length, dtype=torch.float32, device=device) - freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) - freqs *= freq_scaling - - rot = torch.einsum('..., f -> ... f', pos, freqs) - rot = torch.stack([torch.cos(rot), -torch.sin(rot), torch.sin(rot), torch.cos(rot)], dim=-1) - rot = rearrange(rot, 'n d (i j) -> 1 n d i j', i=2, j=2) - return rot - - -def apply_rope(x: Tensor, rot: Tensor) -> tuple[Tensor, Tensor]: - with torch.amp.autocast(device_type='cuda', enabled=False): - _x = x.float() - _x = _x.view(*_x.shape[:-1], -1, 1, 2) - x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1] - return x_out.reshape(*x.shape).to(dtype=x.dtype) diff --git a/think_sound/models/mmmodules/ext/stft_converter.py b/think_sound/models/mmmodules/ext/stft_converter.py deleted file mode 100644 index 62922067ef3b1d3b8727ec39e7d664ccb304d9fe..0000000000000000000000000000000000000000 --- a/think_sound/models/mmmodules/ext/stft_converter.py +++ /dev/null @@ -1,183 +0,0 @@ -# Reference: # https://github.com/bytedance/Make-An-Audio-2 - -import torch -import torch.nn as nn -import torchaudio -from einops import rearrange -from librosa.filters import mel as librosa_mel_fn - - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, norm_fn=torch.log10): - return norm_fn(torch.clamp(x, min=clip_val) * C) - - -def spectral_normalize_torch(magnitudes, norm_fn): - output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn) - return output - - -class STFTConverter(nn.Module): - - def __init__( - self, - *, - sampling_rate: float = 16_000, - n_fft: int = 1024, - num_mels: int = 128, - hop_size: int = 256, - win_size: int = 1024, - fmin: float = 0, - fmax: float = 8_000, - norm_fn=torch.log, - ): - super().__init__() - self.sampling_rate = sampling_rate - self.n_fft = n_fft - self.num_mels = num_mels - self.hop_size = hop_size - self.win_size = win_size - self.fmin = fmin - self.fmax = fmax - self.norm_fn = norm_fn - - mel = librosa_mel_fn(sr=self.sampling_rate, - n_fft=self.n_fft, - n_mels=self.num_mels, - fmin=self.fmin, - fmax=self.fmax) - mel_basis = torch.from_numpy(mel).float() - hann_window = torch.hann_window(self.win_size) - - self.register_buffer('mel_basis', mel_basis) - self.register_buffer('hann_window', hann_window) - - @property - def device(self): - return self.hann_window.device - - def forward(self, waveform: torch.Tensor) -> torch.Tensor: - # input: batch_size * length - bs = waveform.shape[0] - waveform = waveform.clamp(min=-1., max=1.) - - spec = torch.stft(waveform, - self.n_fft, - hop_length=self.hop_size, - win_length=self.win_size, - window=self.hann_window, - center=True, - pad_mode='reflect', - normalized=False, - onesided=True, - return_complex=True) - - spec = torch.view_as_real(spec) - # print('After stft', spec.shape, spec.min(), spec.max(), spec.mean()) - - power = spec.pow(2).sum(-1) - angle = torch.atan2(spec[..., 1], spec[..., 0]) - - print('power', power.shape, power.min(), power.max(), power.mean()) - print('angle', angle.shape, angle.min(), angle.max(), angle.mean()) - - # print('mel', self.mel_basis.shape, self.mel_basis.min(), self.mel_basis.max(), - # self.mel_basis.mean()) - - # spec = rearrange(spec, 'b f t c -> (b c) f t') - - # spec = self.mel_transform(spec) - - # spec = torch.matmul(self.mel_basis, spec) - - # print('After mel', spec.shape, spec.min(), spec.max(), spec.mean()) - - # spec = spectral_normalize_torch(spec, self.norm_fn) - - # print('After norm', spec.shape, spec.min(), spec.max(), spec.mean()) - - # compute magnitude - # magnitude = torch.sqrt((spec**2).sum(-1)) - # normalize by magnitude - # scaled_magnitude = torch.log10(magnitude.clamp(min=1e-5)) * 10 - # spec = spec / magnitude.unsqueeze(-1) * scaled_magnitude.unsqueeze(-1) - - # power = torch.log10(power.clamp(min=1e-5)) * 10 - power = torch.log10(power.clamp(min=1e-5)) - - print('After scaling', power.shape, power.min(), power.max(), power.mean()) - - spec = torch.stack([power, angle], dim=-1) - - # spec = rearrange(spec, '(b c) f t -> b c f t', b=bs) - spec = rearrange(spec, 'b f t c -> b c f t', b=bs) - - # spec[:, :, 400:] = 0 - - return spec - - def invert(self, spec: torch.Tensor, length: int) -> torch.Tensor: - bs = spec.shape[0] - - # spec = rearrange(spec, 'b c f t -> (b c) f t') - # print(spec.shape, self.mel_basis.shape) - # spec = torch.linalg.lstsq(self.mel_basis.unsqueeze(0), spec).solution - # spec = torch.linalg.pinv(self.mel_basis.unsqueeze(0)) @ spec - - # spec = self.invmel_transform(spec) - - spec = rearrange(spec, 'b c f t -> b f t c', b=bs).contiguous() - - # spec[..., 0] = 10**(spec[..., 0] / 10) - - power = spec[..., 0] - power = 10**power - - # print('After unscaling', spec[..., 0].shape, spec[..., 0].min(), spec[..., 0].max(), - # spec[..., 0].mean()) - - unit_vector = torch.stack([ - torch.cos(spec[..., 1]), - torch.sin(spec[..., 1]), - ], dim=-1) - - spec = torch.sqrt(power) * unit_vector - - # spec = rearrange(spec, '(b c) f t -> b f t c', b=bs).contiguous() - spec = torch.view_as_complex(spec) - - waveform = torch.istft( - spec, - self.n_fft, - length=length, - hop_length=self.hop_size, - win_length=self.win_size, - window=self.hann_window, - center=True, - normalized=False, - onesided=True, - return_complex=False, - ) - - return waveform - - -if __name__ == '__main__': - - converter = STFTConverter(sampling_rate=16000) - - signal = torchaudio.load('./output/ZZ6GRocWW38_000090.wav')[0] - # resample signal at 44100 Hz - # signal = torchaudio.transforms.Resample(16_000, 44_100)(signal) - - L = signal.shape[1] - print('Input signal', signal.shape) - spec = converter(signal) - - print('Final spec', spec.shape) - - signal_recon = converter.invert(spec, length=L) - print('Output signal', signal_recon.shape, signal_recon.min(), signal_recon.max(), - signal_recon.mean()) - - print('MSE', torch.nn.functional.mse_loss(signal, signal_recon)) - torchaudio.save('./output/ZZ6GRocWW38_000090_recon.wav', signal_recon, 16000) diff --git a/think_sound/models/mmmodules/ext/stft_converter_mel.py b/think_sound/models/mmmodules/ext/stft_converter_mel.py deleted file mode 100644 index f6b32d4cb9a23cd74f723e7d8307fd82fa1abba0..0000000000000000000000000000000000000000 --- a/think_sound/models/mmmodules/ext/stft_converter_mel.py +++ /dev/null @@ -1,234 +0,0 @@ -# Reference: # https://github.com/bytedance/Make-An-Audio-2 - -import torch -import torch.nn as nn -import torchaudio -from einops import rearrange -from librosa.filters import mel as librosa_mel_fn - - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, norm_fn=torch.log10): - return norm_fn(torch.clamp(x, min=clip_val) * C) - - -def spectral_normalize_torch(magnitudes, norm_fn): - output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn) - return output - - -class STFTConverter(nn.Module): - - def __init__( - self, - *, - sampling_rate: float = 16_000, - n_fft: int = 1024, - num_mels: int = 128, - hop_size: int = 256, - win_size: int = 1024, - fmin: float = 0, - fmax: float = 8_000, - norm_fn=torch.log, - ): - super().__init__() - self.sampling_rate = sampling_rate - self.n_fft = n_fft - self.num_mels = num_mels - self.hop_size = hop_size - self.win_size = win_size - self.fmin = fmin - self.fmax = fmax - self.norm_fn = norm_fn - - mel = librosa_mel_fn(sr=self.sampling_rate, - n_fft=self.n_fft, - n_mels=self.num_mels, - fmin=self.fmin, - fmax=self.fmax) - mel_basis = torch.from_numpy(mel).float() - hann_window = torch.hann_window(self.win_size) - - self.register_buffer('mel_basis', mel_basis) - self.register_buffer('hann_window', hann_window) - - @property - def device(self): - return self.hann_window.device - - def forward(self, waveform: torch.Tensor) -> torch.Tensor: - # input: batch_size * length - bs = waveform.shape[0] - waveform = waveform.clamp(min=-1., max=1.) - - spec = torch.stft(waveform, - self.n_fft, - hop_length=self.hop_size, - win_length=self.win_size, - window=self.hann_window, - center=True, - pad_mode='reflect', - normalized=False, - onesided=True, - return_complex=True) - - spec = torch.view_as_real(spec) - # print('After stft', spec.shape, spec.min(), spec.max(), spec.mean()) - - power = (spec.pow(2).sum(-1))**(0.5) - angle = torch.atan2(spec[..., 1], spec[..., 0]) - - print('power 1', power.shape, power.min(), power.max(), power.mean()) - print('angle 1', angle.shape, angle.min(), angle.max(), angle.mean(), angle[:, :2, :2]) - - # print('mel', self.mel_basis.shape, self.mel_basis.min(), self.mel_basis.max(), - # self.mel_basis.mean()) - - # spec = self.mel_transform(spec) - - # power = torch.matmul(self.mel_basis, power) - - spec = rearrange(spec, 'b f t c -> (b c) f t') - spec = self.mel_basis.unsqueeze(0) @ spec - spec = rearrange(spec, '(b c) f t -> b f t c', b=bs) - - power = (spec.pow(2).sum(-1))**(0.5) - angle = torch.atan2(spec[..., 1], spec[..., 0]) - - print('power', power.shape, power.min(), power.max(), power.mean()) - print('angle', angle.shape, angle.min(), angle.max(), angle.mean(), angle[:, :2, :2]) - - # print('After mel', spec.shape, spec.min(), spec.max(), spec.mean()) - - # spec = spectral_normalize_torch(spec, self.norm_fn) - - # print('After norm', spec.shape, spec.min(), spec.max(), spec.mean()) - - # compute magnitude - # magnitude = torch.sqrt((spec**2).sum(-1)) - # normalize by magnitude - # scaled_magnitude = torch.log10(magnitude.clamp(min=1e-5)) * 10 - # spec = spec / magnitude.unsqueeze(-1) * scaled_magnitude.unsqueeze(-1) - - # power = torch.log10(power.clamp(min=1e-5)) * 10 - power = torch.log10(power.clamp(min=1e-8)) - - print('After scaling', power.shape, power.min(), power.max(), power.mean()) - - # spec = torch.stack([power, angle], dim=-1) - - # spec = rearrange(spec, '(b c) f t -> b c f t', b=bs) - # spec = rearrange(spec, 'b f t c -> b c f t', b=bs) - - # spec[:, :, 400:] = 0 - - return power, angle - # return spec[..., 0], spec[..., 1] - - def invert(self, spec: torch.Tensor, length: int) -> torch.Tensor: - - power, angle = spec - - bs = power.shape[0] - - # spec = rearrange(spec, 'b c f t -> (b c) f t') - # print(spec.shape, self.mel_basis.shape) - # spec = torch.linalg.lstsq(self.mel_basis.unsqueeze(0), spec).solution - # spec = torch.linalg.pinv(self.mel_basis.unsqueeze(0)) @ spec - - # spec = self.invmel_transform(spec) - - # spec = rearrange(spec, 'b c f t -> b f t c', b=bs).contiguous() - - # spec[..., 0] = 10**(spec[..., 0] / 10) - - # power = spec[..., 0] - power = 10**power - - # print('After unscaling', spec[..., 0].shape, spec[..., 0].min(), spec[..., 0].max(), - # spec[..., 0].mean()) - - unit_vector = torch.stack([ - torch.cos(angle), - torch.sin(angle), - ], dim=-1) - - spec = power.unsqueeze(-1) * unit_vector - - # power = torch.linalg.lstsq(self.mel_basis.unsqueeze(0), power).solution - spec = rearrange(spec, 'b f t c -> (b c) f t') - spec = torch.linalg.pinv(self.mel_basis.unsqueeze(0)) @ spec - # spec = torch.linalg.lstsq(self.mel_basis.unsqueeze(0), spec).solution - spec = rearrange(spec, '(b c) f t -> b f t c', b=bs).contiguous() - - power = (spec.pow(2).sum(-1))**(0.5) - angle = torch.atan2(spec[..., 1], spec[..., 0]) - - print('power 2', power.shape, power.min(), power.max(), power.mean()) - print('angle 2', angle.shape, angle.min(), angle.max(), angle.mean(), angle[:, :2, :2]) - - # spec = rearrange(spec, '(b c) f t -> b f t c', b=bs).contiguous() - spec = torch.view_as_complex(spec) - - waveform = torch.istft( - spec, - self.n_fft, - length=length, - hop_length=self.hop_size, - win_length=self.win_size, - window=self.hann_window, - center=True, - normalized=False, - onesided=True, - return_complex=False, - ) - - return waveform - - -if __name__ == '__main__': - - converter = STFTConverter(sampling_rate=16000) - - signal = torchaudio.load('./output/ZZ6GRocWW38_000090.wav')[0] - # resample signal at 44100 Hz - # signal = torchaudio.transforms.Resample(16_000, 44_100)(signal) - - L = signal.shape[1] - print('Input signal', signal.shape) - spec = converter(signal) - - power, angle = spec - - # print(power.shape, angle.shape) - # print(power, power.min(), power.max(), power.mean()) - # power = power.clamp(-1, 1) - # angle = angle.clamp(-1, 1) - - import matplotlib.pyplot as plt - - # Visualize power - plt.figure() - plt.imshow(power[0].detach().numpy(), aspect='auto', origin='lower') - plt.colorbar() - plt.title('Power') - plt.xlabel('Time') - plt.ylabel('Frequency') - plt.savefig('./output/power.png') - - # Visualize angle - plt.figure() - plt.imshow(angle[0].detach().numpy(), aspect='auto', origin='lower') - plt.colorbar() - plt.title('Angle') - plt.xlabel('Time') - plt.ylabel('Frequency') - plt.savefig('./output/angle.png') - - # print('Final spec', spec.shape) - - signal_recon = converter.invert(spec, length=L) - print('Output signal', signal_recon.shape, signal_recon.min(), signal_recon.max(), - signal_recon.mean()) - - print('MSE', torch.nn.functional.mse_loss(signal, signal_recon)) - torchaudio.save('./output/ZZ6GRocWW38_000090_recon.wav', signal_recon, 16000) diff --git a/think_sound/models/mmmodules/model/__init__.py b/think_sound/models/mmmodules/model/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/think_sound/models/mmmodules/model/__pycache__/__init__.cpython-310.pyc b/think_sound/models/mmmodules/model/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 4401799caf1e7efffce5ba20eee3e4e6c4cf18ed..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/model/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/model/__pycache__/__init__.cpython-39.pyc b/think_sound/models/mmmodules/model/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index fe6460de12d696b4d8a4f8f9f6af4713a12b8d2e..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/model/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/model/__pycache__/embeddings.cpython-310.pyc b/think_sound/models/mmmodules/model/__pycache__/embeddings.cpython-310.pyc deleted file mode 100644 index d8182934cbcf7d20c9145b17a9295f62bf0b03bf..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/model/__pycache__/embeddings.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/model/__pycache__/embeddings.cpython-39.pyc b/think_sound/models/mmmodules/model/__pycache__/embeddings.cpython-39.pyc deleted file mode 100644 index f9f111166db2dd4442d2227103547d9cf522f6d7..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/model/__pycache__/embeddings.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/model/__pycache__/low_level.cpython-310.pyc b/think_sound/models/mmmodules/model/__pycache__/low_level.cpython-310.pyc deleted file mode 100644 index ca349e5522944bf0a9fe9a1d162fe9bd3cbb3eb5..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/model/__pycache__/low_level.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/model/__pycache__/low_level.cpython-39.pyc b/think_sound/models/mmmodules/model/__pycache__/low_level.cpython-39.pyc deleted file mode 100644 index fb6982d89c721052d73cef49f074fbb58456d388..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/model/__pycache__/low_level.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/model/__pycache__/transformer_layers.cpython-310.pyc b/think_sound/models/mmmodules/model/__pycache__/transformer_layers.cpython-310.pyc deleted file mode 100644 index 89842e8912466aa6e1659f3a4c77cc975da0f0f8..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/model/__pycache__/transformer_layers.cpython-310.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/model/__pycache__/transformer_layers.cpython-39.pyc b/think_sound/models/mmmodules/model/__pycache__/transformer_layers.cpython-39.pyc deleted file mode 100644 index d1aa3d468c975f0e977391704a453303d82305f7..0000000000000000000000000000000000000000 Binary files a/think_sound/models/mmmodules/model/__pycache__/transformer_layers.cpython-39.pyc and /dev/null differ diff --git a/think_sound/models/mmmodules/model/flow_matching.py b/think_sound/models/mmmodules/model/flow_matching.py deleted file mode 100644 index e7c65dece6dec746db999092606f4384d084d119..0000000000000000000000000000000000000000 --- a/think_sound/models/mmmodules/model/flow_matching.py +++ /dev/null @@ -1,71 +0,0 @@ -import logging -from typing import Callable, Optional - -import torch -from torchdiffeq import odeint - -log = logging.getLogger() - - -# Partially from https://github.com/gle-bellier/flow-matching -class FlowMatching: - - def __init__(self, min_sigma: float = 0.0, inference_mode='euler', num_steps: int = 25): - # inference_mode: 'euler' or 'adaptive' - # num_steps: number of steps in the euler inference mode - super().__init__() - self.min_sigma = min_sigma - self.inference_mode = inference_mode - self.num_steps = num_steps - - # self.fm = ExactOptimalTransportConditionalFlowMatcher(sigma=min_sigma) - - assert self.inference_mode in ['euler', 'adaptive'] - if self.inference_mode == 'adaptive' and num_steps > 0: - log.info('The number of steps is ignored in adaptive inference mode ') - - def get_conditional_flow(self, x0: torch.Tensor, x1: torch.Tensor, - t: torch.Tensor) -> torch.Tensor: - # which is psi_t(x), eq 22 in flow matching for generative models - t = t[:, None, None].expand_as(x0) - return (1 - (1 - self.min_sigma) * t) * x0 + t * x1 - - def loss(self, predicted_v: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: - # return the mean error without reducing the batch dimension - reduce_dim = list(range(1, len(predicted_v.shape))) - target_v = x1 - (1 - self.min_sigma) * x0 - return (predicted_v - target_v).pow(2).mean(dim=reduce_dim) - - def get_x0_xt_c( - self, - x1: torch.Tensor, - t: torch.Tensor, - Cs: list[torch.Tensor], - generator: Optional[torch.Generator] = None - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - x0 = torch.empty_like(x1).normal_(generator=generator) - - xt = self.get_conditional_flow(x0, x1, t) - return x0, x1, xt, Cs - - def to_prior(self, fn: Callable, x1: torch.Tensor) -> torch.Tensor: - return self.run_t0_to_t1(fn, x1, 1, 0) - - def to_data(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor: - return self.run_t0_to_t1(fn, x0, 0, 1) - - def run_t0_to_t1(self, fn: Callable, x0: torch.Tensor, t0: float, t1: float) -> torch.Tensor: - # fn: a function that takes (t, x) and returns the direction x0->x1 - - if self.inference_mode == 'adaptive': - return odeint(fn, x0, torch.tensor([t0, t1], device=x0.device, dtype=x0.dtype)) - elif self.inference_mode == 'euler': - x = x0 - steps = torch.linspace(t0, t1 - self.min_sigma, self.num_steps + 1) - for ti, t in enumerate(steps[:-1]): - flow = fn(t, x) - next_t = steps[ti + 1] - dt = next_t - t - x = x + dt * flow - - return x diff --git a/think_sound/models/mmmodules/model/low_level.py b/think_sound/models/mmmodules/model/low_level.py deleted file mode 100644 index c8326a8bec99f1be08b92e76fda4b59e777b39d2..0000000000000000000000000000000000000000 --- a/think_sound/models/mmmodules/model/low_level.py +++ /dev/null @@ -1,95 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F - - -class ChannelLastConv1d(nn.Conv1d): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.permute(0, 2, 1) - x = super().forward(x) - x = x.permute(0, 2, 1) - return x - - -# https://github.com/Stability-AI/sd3-ref -class MLP(nn.Module): - - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int = 256, - ): - """ - Initialize the FeedForward module. - - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple of this value. - - Attributes: - w1 (ColumnParallelLinear): Linear transformation for the first layer. - w2 (RowParallelLinear): Linear transformation for the second layer. - w3 (ColumnParallelLinear): Linear transformation for the third layer. - - """ - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - -class ConvMLP(nn.Module): - - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int = 256, - kernel_size: int = 3, - padding: int = 1, - ): - """ - Initialize the FeedForward module. - - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple of this value. - - Attributes: - w1 (ColumnParallelLinear): Linear transformation for the first layer. - w2 (RowParallelLinear): Linear transformation for the second layer. - w3 (ColumnParallelLinear): Linear transformation for the third layer. - - """ - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = ChannelLastConv1d(dim, - hidden_dim, - bias=False, - kernel_size=kernel_size, - padding=padding) - self.w2 = ChannelLastConv1d(hidden_dim, - dim, - bias=False, - kernel_size=kernel_size, - padding=padding) - self.w3 = ChannelLastConv1d(dim, - hidden_dim, - bias=False, - kernel_size=kernel_size, - padding=padding) - - def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) diff --git a/think_sound/models/mmmodules/model/networks.py b/think_sound/models/mmmodules/model/networks.py deleted file mode 100644 index 8272a896b358f5db681d1462c4189d671b916d76..0000000000000000000000000000000000000000 --- a/think_sound/models/mmmodules/model/networks.py +++ /dev/null @@ -1,470 +0,0 @@ -import logging -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from mmaudio.ext.rotary_embeddings import compute_rope_rotations -from mmaudio.model.embeddings import TimestepEmbedder -from mmaudio.model.low_level import MLP, ChannelLastConv1d, ConvMLP -from mmaudio.model.transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock) - -log = logging.getLogger() - - -@dataclass -class PreprocessedConditions: - clip_f: torch.Tensor - sync_f: torch.Tensor - text_f: torch.Tensor - clip_f_c: torch.Tensor - text_f_c: torch.Tensor - - -# Partially from https://github.com/facebookresearch/DiT -class MMAudio(nn.Module): - - def __init__(self, - *, - latent_dim: int, - clip_dim: int, - sync_dim: int, - text_dim: int, - hidden_dim: int, - depth: int, - fused_depth: int, - num_heads: int, - mlp_ratio: float = 4.0, - latent_seq_len: int, - clip_seq_len: int, - sync_seq_len: int, - text_seq_len: int = 77, - latent_mean: Optional[torch.Tensor] = None, - latent_std: Optional[torch.Tensor] = None, - empty_string_feat: Optional[torch.Tensor] = None, - v2: bool = False) -> None: - super().__init__() - - self.v2 = v2 - self.latent_dim = latent_dim - self._latent_seq_len = latent_seq_len - self._clip_seq_len = clip_seq_len - self._sync_seq_len = sync_seq_len - self._text_seq_len = text_seq_len - self.hidden_dim = hidden_dim - self.num_heads = num_heads - - if v2: - self.audio_input_proj = nn.Sequential( - ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3), - nn.SiLU(), - ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3), - ) - - self.clip_input_proj = nn.Sequential( - nn.Linear(clip_dim, hidden_dim), - nn.SiLU(), - ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), - ) - - self.sync_input_proj = nn.Sequential( - ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=7, padding=3), - nn.SiLU(), - ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), - ) - - self.text_input_proj = nn.Sequential( - nn.Linear(text_dim, hidden_dim), - nn.SiLU(), - MLP(hidden_dim, hidden_dim * 4), - ) - else: - self.audio_input_proj = nn.Sequential( - ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3), - nn.SELU(), - ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3), - ) - - self.clip_input_proj = nn.Sequential( - nn.Linear(clip_dim, hidden_dim), - ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), - ) - - self.sync_input_proj = nn.Sequential( - ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=7, padding=3), - nn.SELU(), - ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), - ) - - self.text_input_proj = nn.Sequential( - nn.Linear(text_dim, hidden_dim), - MLP(hidden_dim, hidden_dim * 4), - ) - - self.clip_cond_proj = nn.Linear(hidden_dim, hidden_dim) - self.text_cond_proj = nn.Linear(hidden_dim, hidden_dim) - self.global_cond_mlp = MLP(hidden_dim, hidden_dim * 4) - # each synchformer output segment has 8 feature frames - self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, sync_dim))) - - self.final_layer = FinalBlock(hidden_dim, latent_dim) - - if v2: - self.t_embed = TimestepEmbedder(hidden_dim, - frequency_embedding_size=hidden_dim, - max_period=1) - else: - self.t_embed = TimestepEmbedder(hidden_dim, - frequency_embedding_size=256, - max_period=10000) - self.joint_blocks = nn.ModuleList([ - JointBlock(hidden_dim, - num_heads, - mlp_ratio=mlp_ratio, - pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth) - ]) - - self.fused_blocks = nn.ModuleList([ - MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1) - for i in range(fused_depth) - ]) - - if latent_mean is None: - # these values are not meant to be used - # if you don't provide mean/std here, we should load them later from a checkpoint - assert latent_std is None - latent_mean = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan')) - latent_std = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan')) - else: - assert latent_std is not None - assert latent_mean.numel() == latent_dim, f'{latent_mean.numel()=} != {latent_dim=}' - if empty_string_feat is None: - empty_string_feat = torch.zeros((text_seq_len, text_dim)) - self.latent_mean = nn.Parameter(latent_mean.view(1, 1, -1), requires_grad=False) - self.latent_std = nn.Parameter(latent_std.view(1, 1, -1), requires_grad=False) - - self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False) - self.empty_clip_feat = nn.Parameter(torch.zeros(1, clip_dim), requires_grad=True) - self.empty_sync_feat = nn.Parameter(torch.zeros(1, sync_dim), requires_grad=True) - - self.initialize_weights() - self.initialize_rotations() - - def initialize_rotations(self): - base_freq = 1.0 - latent_rot = compute_rope_rotations(self._latent_seq_len, - self.hidden_dim // self.num_heads, - 10000, - freq_scaling=base_freq, - device=self.device) - clip_rot = compute_rope_rotations(self._clip_seq_len, - self.hidden_dim // self.num_heads, - 10000, - freq_scaling=base_freq * self._latent_seq_len / - self._clip_seq_len, - device=self.device) - - self.latent_rot = nn.Buffer(latent_rot, persistent=False) - self.clip_rot = nn.Buffer(clip_rot, persistent=False) - - def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None: - self._latent_seq_len = latent_seq_len - self._clip_seq_len = clip_seq_len - self._sync_seq_len = sync_seq_len - self.initialize_rotations() - - def initialize_weights(self): - - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - self.apply(_basic_init) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02) - - # Zero-out adaLN modulation layers in DiT blocks: - for block in self.joint_blocks: - nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0) - nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0) - nn.init.constant_(block.clip_block.adaLN_modulation[-1].weight, 0) - nn.init.constant_(block.clip_block.adaLN_modulation[-1].bias, 0) - nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0) - nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0) - for block in self.fused_blocks: - nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.conv.weight, 0) - nn.init.constant_(self.final_layer.conv.bias, 0) - - # empty string feat shall be initialized by a CLIP encoder - nn.init.constant_(self.sync_pos_emb, 0) - nn.init.constant_(self.empty_clip_feat, 0) - nn.init.constant_(self.empty_sync_feat, 0) - - def normalize(self, x: torch.Tensor) -> torch.Tensor: - # return (x - self.latent_mean) / self.latent_std - return x.sub_(self.latent_mean).div_(self.latent_std) - - def unnormalize(self, x: torch.Tensor) -> torch.Tensor: - # return x * self.latent_std + self.latent_mean - return x.mul_(self.latent_std).add_(self.latent_mean) - - def preprocess_conditions(self, clip_f: torch.Tensor, sync_f: torch.Tensor, - text_f: torch.Tensor) -> PreprocessedConditions: - """ - cache computations that do not depend on the latent/time step - i.e., the features are reused over steps during inference - """ - assert clip_f.shape[1] == self._clip_seq_len, f'{clip_f.shape=} {self._clip_seq_len=}' - assert sync_f.shape[1] == self._sync_seq_len, f'{sync_f.shape=} {self._sync_seq_len=}' - assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}' - - bs = clip_f.shape[0] - - # B * num_segments (24) * 8 * 768 - num_sync_segments = self._sync_seq_len // 8 - sync_f = sync_f.view(bs, num_sync_segments, 8, -1) + self.sync_pos_emb - sync_f = sync_f.flatten(1, 2) # (B, VN, D) - - # extend vf to match x - clip_f = self.clip_input_proj(clip_f) # (B, VN, D) - sync_f = self.sync_input_proj(sync_f) # (B, VN, D) - text_f = self.text_input_proj(text_f) # (B, VN, D) - - # upsample the sync features to match the audio - sync_f = sync_f.transpose(1, 2) # (B, D, VN) - sync_f = F.interpolate(sync_f, size=self._latent_seq_len, mode='nearest-exact') - sync_f = sync_f.transpose(1, 2) # (B, N, D) - - # get conditional features from the clip side - clip_f_c = self.clip_cond_proj(clip_f.mean(dim=1)) # (B, D) - text_f_c = self.text_cond_proj(text_f.mean(dim=1)) # (B, D) - - return PreprocessedConditions(clip_f=clip_f, - sync_f=sync_f, - text_f=text_f, - clip_f_c=clip_f_c, - text_f_c=text_f_c) - - def predict_flow(self, latent: torch.Tensor, t: torch.Tensor, - conditions: PreprocessedConditions) -> torch.Tensor: - """ - for non-cacheable computations - """ - assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}' - - clip_f = conditions.clip_f - sync_f = conditions.sync_f - text_f = conditions.text_f - clip_f_c = conditions.clip_f_c - text_f_c = conditions.text_f_c - - latent = self.audio_input_proj(latent) # (B, N, D) - global_c = self.global_cond_mlp(clip_f_c + text_f_c) # (B, D) - - global_c = self.t_embed(t).unsqueeze(1) + global_c.unsqueeze(1) # (B, D) - extended_c = global_c + sync_f - - for block in self.joint_blocks: - latent, clip_f, text_f = block(latent, clip_f, text_f, global_c, extended_c, - self.latent_rot, self.clip_rot) # (B, N, D) - - for block in self.fused_blocks: - latent = block(latent, extended_c, self.latent_rot) - - # should be extended_c; this is a minor implementation error #55 - flow = self.final_layer(latent, extended_c) # (B, N, out_dim), remove t - return flow - - def forward(self, latent: torch.Tensor, clip_f: torch.Tensor, sync_f: torch.Tensor, - text_f: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - """ - latent: (B, N, C) - vf: (B, T, C_V) - t: (B,) - """ - conditions = self.preprocess_conditions(clip_f, sync_f, text_f) - flow = self.predict_flow(latent, t, conditions) - return flow - - def get_empty_string_sequence(self, bs: int) -> torch.Tensor: - return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1) - - def get_empty_clip_sequence(self, bs: int) -> torch.Tensor: - return self.empty_clip_feat.unsqueeze(0).expand(bs, self._clip_seq_len, -1) - - def get_empty_sync_sequence(self, bs: int) -> torch.Tensor: - return self.empty_sync_feat.unsqueeze(0).expand(bs, self._sync_seq_len, -1) - - def get_empty_conditions( - self, - bs: int, - *, - negative_text_features: Optional[torch.Tensor] = None) -> PreprocessedConditions: - if negative_text_features is not None: - empty_text = negative_text_features - else: - empty_text = self.get_empty_string_sequence(1) - - empty_clip = self.get_empty_clip_sequence(1) - empty_sync = self.get_empty_sync_sequence(1) - conditions = self.preprocess_conditions(empty_clip, empty_sync, empty_text) - conditions.clip_f = conditions.clip_f.expand(bs, -1, -1) - conditions.sync_f = conditions.sync_f.expand(bs, -1, -1) - conditions.clip_f_c = conditions.clip_f_c.expand(bs, -1) - if negative_text_features is None: - conditions.text_f = conditions.text_f.expand(bs, -1, -1) - conditions.text_f_c = conditions.text_f_c.expand(bs, -1) - - return conditions - - def ode_wrapper(self, t: torch.Tensor, latent: torch.Tensor, conditions: PreprocessedConditions, - empty_conditions: PreprocessedConditions, cfg_strength: float) -> torch.Tensor: - t = t * torch.ones(len(latent), device=latent.device, dtype=latent.dtype) - - if cfg_strength < 1.0: - return self.predict_flow(latent, t, conditions) - else: - return (cfg_strength * self.predict_flow(latent, t, conditions) + - (1 - cfg_strength) * self.predict_flow(latent, t, empty_conditions)) - - def load_weights(self, src_dict) -> None: - if 't_embed.freqs' in src_dict: - del src_dict['t_embed.freqs'] - if 'latent_rot' in src_dict: - del src_dict['latent_rot'] - if 'clip_rot' in src_dict: - del src_dict['clip_rot'] - - self.load_state_dict(src_dict, strict=True) - - @property - def device(self) -> torch.device: - return self.latent_mean.device - - @property - def latent_seq_len(self) -> int: - return self._latent_seq_len - - @property - def clip_seq_len(self) -> int: - return self._clip_seq_len - - @property - def sync_seq_len(self) -> int: - return self._sync_seq_len - - -def small_16k(**kwargs) -> MMAudio: - num_heads = 7 - return MMAudio(latent_dim=20, - clip_dim=1024, - sync_dim=768, - text_dim=1024, - hidden_dim=64 * num_heads, - depth=12, - fused_depth=8, - num_heads=num_heads, - latent_seq_len=250, - clip_seq_len=64, - sync_seq_len=192, - **kwargs) - - -def small_44k(**kwargs) -> MMAudio: - num_heads = 7 - return MMAudio(latent_dim=40, - clip_dim=1024, - sync_dim=768, - text_dim=1024, - hidden_dim=64 * num_heads, - depth=12, - fused_depth=8, - num_heads=num_heads, - latent_seq_len=345, - clip_seq_len=64, - sync_seq_len=192, - **kwargs) - - -def medium_44k(**kwargs) -> MMAudio: - num_heads = 14 - return MMAudio(latent_dim=40, - clip_dim=1024, - sync_dim=768, - text_dim=1024, - hidden_dim=64 * num_heads, - depth=12, - fused_depth=8, - num_heads=num_heads, - latent_seq_len=345, - clip_seq_len=64, - sync_seq_len=192, - **kwargs) - - -def large_44k(**kwargs) -> MMAudio: - num_heads = 14 - return MMAudio(latent_dim=40, - clip_dim=1024, - sync_dim=768, - text_dim=1024, - hidden_dim=64 * num_heads, - depth=21, - fused_depth=14, - num_heads=num_heads, - latent_seq_len=345, - clip_seq_len=64, - sync_seq_len=192, - **kwargs) - - -def large_44k_v2(**kwargs) -> MMAudio: - num_heads = 14 - return MMAudio(latent_dim=40, - clip_dim=1024, - sync_dim=768, - text_dim=1024, - hidden_dim=64 * num_heads, - depth=21, - fused_depth=14, - num_heads=num_heads, - latent_seq_len=345, - clip_seq_len=64, - sync_seq_len=192, - v2=True, - **kwargs) - - -def get_my_mmaudio(name: str, **kwargs) -> MMAudio: - if name == 'small_16k': - return small_16k(**kwargs) - if name == 'small_44k': - return small_44k(**kwargs) - if name == 'medium_44k': - return medium_44k(**kwargs) - if name == 'large_44k': - return large_44k(**kwargs) - if name == 'large_44k_v2': - return large_44k_v2(**kwargs) - - raise ValueError(f'Unknown model name: {name}') - - -if __name__ == '__main__': - network = get_my_mmaudio('small_16k') - - # print the number of parameters in terms of millions - num_params = sum(p.numel() for p in network.parameters()) / 1e6 - print(f'Number of parameters: {num_params:.2f}M') diff --git a/think_sound/models/mmmodules/model/sequence_config.py b/think_sound/models/mmmodules/model/sequence_config.py deleted file mode 100644 index 14269014dc401b4751d172466813a935fddda6c1..0000000000000000000000000000000000000000 --- a/think_sound/models/mmmodules/model/sequence_config.py +++ /dev/null @@ -1,58 +0,0 @@ -import dataclasses -import math - - -@dataclasses.dataclass -class SequenceConfig: - # general - duration: float - - # audio - sampling_rate: int - spectrogram_frame_rate: int - latent_downsample_rate: int = 2 - - # visual - clip_frame_rate: int = 8 - sync_frame_rate: int = 25 - sync_num_frames_per_segment: int = 16 - sync_step_size: int = 8 - sync_downsample_rate: int = 2 - - @property - def num_audio_frames(self) -> int: - # we need an integer number of latents - return self.latent_seq_len * self.spectrogram_frame_rate * self.latent_downsample_rate - - @property - def latent_seq_len(self) -> int: - return int( - math.ceil(self.duration * self.sampling_rate / self.spectrogram_frame_rate / - self.latent_downsample_rate)) - - @property - def clip_seq_len(self) -> int: - return int(self.duration * self.clip_frame_rate) - - @property - def sync_seq_len(self) -> int: - num_frames = self.duration * self.sync_frame_rate - num_segments = (num_frames - self.sync_num_frames_per_segment) // self.sync_step_size + 1 - return int(num_segments * self.sync_num_frames_per_segment / self.sync_downsample_rate) - - -CONFIG_16K = SequenceConfig(duration=8.0, sampling_rate=16000, spectrogram_frame_rate=256) -CONFIG_44K = SequenceConfig(duration=8.0, sampling_rate=44100, spectrogram_frame_rate=512) - -if __name__ == '__main__': - assert CONFIG_16K.latent_seq_len == 250 - assert CONFIG_16K.clip_seq_len == 64 - assert CONFIG_16K.sync_seq_len == 192 - assert CONFIG_16K.num_audio_frames == 128000 - - assert CONFIG_44K.latent_seq_len == 345 - assert CONFIG_44K.clip_seq_len == 64 - assert CONFIG_44K.sync_seq_len == 192 - assert CONFIG_44K.num_audio_frames == 353280 - - print('Passed') diff --git a/think_sound/models/mmmodules/runner.py b/think_sound/models/mmmodules/runner.py deleted file mode 100644 index 755ee76bea7de3f31a14a5512710c39743dc9239..0000000000000000000000000000000000000000 --- a/think_sound/models/mmmodules/runner.py +++ /dev/null @@ -1,609 +0,0 @@ -""" -trainer.py - wrapper and utility functions for network training -Compute loss, back-prop, update parameters, logging, etc. -""" -import os -from pathlib import Path -from typing import Optional, Union - -import torch -import torch.distributed -import torch.optim as optim -from av_bench.evaluate import evaluate -from av_bench.extract import extract -from nitrous_ema import PostHocEMA -from omegaconf import DictConfig -from torch.nn.parallel import DistributedDataParallel as DDP - -from mmaudio.model.flow_matching import FlowMatching -from mmaudio.model.networks import get_my_mmaudio -from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K -from mmaudio.model.utils.features_utils import FeaturesUtils -from mmaudio.model.utils.parameter_groups import get_parameter_groups -from mmaudio.model.utils.sample_utils import log_normal_sample -from mmaudio.utils.dist_utils import (info_if_rank_zero, local_rank, string_if_rank_zero) -from mmaudio.utils.log_integrator import Integrator -from mmaudio.utils.logger import TensorboardLogger -from mmaudio.utils.time_estimator import PartialTimeEstimator, TimeEstimator -from mmaudio.utils.video_joiner import VideoJoiner - - -class Runner: - - def __init__(self, - cfg: DictConfig, - log: TensorboardLogger, - run_path: Union[str, Path], - for_training: bool = True, - latent_mean: Optional[torch.Tensor] = None, - latent_std: Optional[torch.Tensor] = None): - self.exp_id = cfg.exp_id - self.use_amp = cfg.amp - self.enable_grad_scaler = cfg.enable_grad_scaler - self.for_training = for_training - self.cfg = cfg - - if cfg.model.endswith('16k'): - self.seq_cfg = CONFIG_16K - mode = '16k' - elif cfg.model.endswith('44k'): - self.seq_cfg = CONFIG_44K - mode = '44k' - else: - raise ValueError(f'Unknown model: {cfg.model}') - - self.sample_rate = self.seq_cfg.sampling_rate - self.duration_sec = self.seq_cfg.duration - - # setting up the model - empty_string_feat = torch.load('./ext_weights/empty_string.pth', weights_only=True)[0] - self.network = DDP(get_my_mmaudio(cfg.model, - latent_mean=latent_mean, - latent_std=latent_std, - empty_string_feat=empty_string_feat).cuda(), - device_ids=[local_rank], - broadcast_buffers=False) - if cfg.compile: - # NOTE: though train_fn and val_fn are very similar - # (early on they are implemented as a single function) - # keeping them separate and compiling them separately are CRUCIAL for high performance - self.train_fn = torch.compile(self.train_fn) - self.val_fn = torch.compile(self.val_fn) - - self.fm = FlowMatching(cfg.sampling.min_sigma, - inference_mode=cfg.sampling.method, - num_steps=cfg.sampling.num_steps) - - # ema profile - if for_training and cfg.ema.enable and local_rank == 0: - self.ema = PostHocEMA(self.network.module, - sigma_rels=cfg.ema.sigma_rels, - update_every=cfg.ema.update_every, - checkpoint_every_num_steps=cfg.ema.checkpoint_every, - checkpoint_folder=cfg.ema.checkpoint_folder, - step_size_correction=True).cuda() - self.ema_start = cfg.ema.start - else: - self.ema = None - - self.rng = torch.Generator(device='cuda') - self.rng.manual_seed(cfg['seed'] + local_rank) - - # setting up feature extractors and VAEs - if mode == '16k': - self.features = FeaturesUtils( - tod_vae_ckpt=cfg['vae_16k_ckpt'], - bigvgan_vocoder_ckpt=cfg['bigvgan_vocoder_ckpt'], - synchformer_ckpt=cfg['synchformer_ckpt'], - enable_conditions=True, - mode=mode, - need_vae_encoder=False, - ) - elif mode == '44k': - self.features = FeaturesUtils( - tod_vae_ckpt=cfg['vae_44k_ckpt'], - synchformer_ckpt=cfg['synchformer_ckpt'], - enable_conditions=True, - mode=mode, - need_vae_encoder=False, - ) - self.features = self.features.cuda().eval() - - if cfg.compile: - self.features.compile() - - # hyperparameters - self.log_normal_sampling_mean = cfg.sampling.mean - self.log_normal_sampling_scale = cfg.sampling.scale - self.null_condition_probability = cfg.null_condition_probability - self.cfg_strength = cfg.cfg_strength - - # setting up logging - self.log = log - self.run_path = Path(run_path) - vgg_cfg = cfg.data.VGGSound - if for_training: - self.val_video_joiner = VideoJoiner(vgg_cfg.root, self.run_path / 'val-sampled-videos', - self.sample_rate, self.duration_sec) - else: - self.test_video_joiner = VideoJoiner(vgg_cfg.root, - self.run_path / 'test-sampled-videos', - self.sample_rate, self.duration_sec) - string_if_rank_zero(self.log, 'model_size', - f'{sum([param.nelement() for param in self.network.parameters()])}') - string_if_rank_zero( - self.log, 'number_of_parameters_that_require_gradient: ', - str( - sum([ - param.nelement() - for param in filter(lambda p: p.requires_grad, self.network.parameters()) - ]))) - info_if_rank_zero(self.log, 'torch version: ' + torch.__version__) - self.train_integrator = Integrator(self.log, distributed=True) - self.val_integrator = Integrator(self.log, distributed=True) - - # setting up optimizer and loss - if for_training: - self.enter_train() - parameter_groups = get_parameter_groups(self.network, cfg, print_log=(local_rank == 0)) - self.optimizer = optim.AdamW(parameter_groups, - lr=cfg['learning_rate'], - weight_decay=cfg['weight_decay'], - betas=[0.9, 0.95], - eps=1e-6 if self.use_amp else 1e-8, - fused=True) - if self.enable_grad_scaler: - self.scaler = torch.amp.GradScaler(init_scale=2048) - self.clip_grad_norm = cfg['clip_grad_norm'] - - # linearly warmup learning rate - linear_warmup_steps = cfg['linear_warmup_steps'] - - def warmup(currrent_step: int): - return (currrent_step + 1) / (linear_warmup_steps + 1) - - warmup_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup) - - # setting up learning rate scheduler - if cfg['lr_schedule'] == 'constant': - next_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda _: 1) - elif cfg['lr_schedule'] == 'poly': - total_num_iter = cfg['iterations'] - next_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, - lr_lambda=lambda x: - (1 - (x / total_num_iter))**0.9) - elif cfg['lr_schedule'] == 'step': - next_scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, - cfg['lr_schedule_steps'], - cfg['lr_schedule_gamma']) - else: - raise NotImplementedError - - self.scheduler = optim.lr_scheduler.SequentialLR(self.optimizer, - [warmup_scheduler, next_scheduler], - [linear_warmup_steps]) - - # Logging info - self.log_text_interval = cfg['log_text_interval'] - self.log_extra_interval = cfg['log_extra_interval'] - self.save_weights_interval = cfg['save_weights_interval'] - self.save_checkpoint_interval = cfg['save_checkpoint_interval'] - self.save_copy_iterations = cfg['save_copy_iterations'] - self.num_iterations = cfg['num_iterations'] - if cfg['debug']: - self.log_text_interval = self.log_extra_interval = 1 - - # update() is called when we log metrics, within the logger - self.log.batch_timer = TimeEstimator(self.num_iterations, self.log_text_interval) - # update() is called every iteration, in this script - self.log.data_timer = PartialTimeEstimator(self.num_iterations, 1, ema_alpha=0.9) - else: - self.enter_val() - - def train_fn( - self, - clip_f: torch.Tensor, - sync_f: torch.Tensor, - text_f: torch.Tensor, - a_mean: torch.Tensor, - a_std: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # sample - a_randn = torch.empty_like(a_mean).normal_(generator=self.rng) - x1 = a_mean + a_std * a_randn - bs = x1.shape[0] # batch_size * seq_len * num_channels - - # normalize the latents - x1 = self.network.module.normalize(x1) - - t = log_normal_sample(x1, - generator=self.rng, - m=self.log_normal_sampling_mean, - s=self.log_normal_sampling_scale) - x0, x1, xt, (clip_f, sync_f, text_f) = self.fm.get_x0_xt_c(x1, - t, - Cs=[clip_f, sync_f, text_f], - generator=self.rng) - - # classifier-free training - samples = torch.rand(bs, device=x1.device, generator=self.rng) - null_video = (samples < self.null_condition_probability) - clip_f[null_video] = self.network.module.empty_clip_feat - sync_f[null_video] = self.network.module.empty_sync_feat - - samples = torch.rand(bs, device=x1.device, generator=self.rng) - null_text = (samples < self.null_condition_probability) - text_f[null_text] = self.network.module.empty_string_feat - - pred_v = self.network(xt, clip_f, sync_f, text_f, t) - loss = self.fm.loss(pred_v, x0, x1) - mean_loss = loss.mean() - return x1, loss, mean_loss, t - - def val_fn( - self, - clip_f: torch.Tensor, - sync_f: torch.Tensor, - text_f: torch.Tensor, - x1: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - bs = x1.shape[0] # batch_size * seq_len * num_channels - # normalize the latents - x1 = self.network.module.normalize(x1) - t = log_normal_sample(x1, - generator=self.rng, - m=self.log_normal_sampling_mean, - s=self.log_normal_sampling_scale) - x0, x1, xt, (clip_f, sync_f, text_f) = self.fm.get_x0_xt_c(x1, - t, - Cs=[clip_f, sync_f, text_f], - generator=self.rng) - - # classifier-free training - samples = torch.rand(bs, device=x1.device, generator=self.rng) - # null mask is for when a video is provided but we decided to ignore it - null_video = (samples < self.null_condition_probability) - # complete mask is for when a video is not provided or we decided to ignore it - clip_f[null_video] = self.network.module.empty_clip_feat - sync_f[null_video] = self.network.module.empty_sync_feat - - samples = torch.rand(bs, device=x1.device, generator=self.rng) - null_text = (samples < self.null_condition_probability) - text_f[null_text] = self.network.module.empty_string_feat - - pred_v = self.network(xt, clip_f, sync_f, text_f, t) - - loss = self.fm.loss(pred_v, x0, x1) - mean_loss = loss.mean() - return loss, mean_loss, t - - def train_pass(self, data, it: int = 0): - - if not self.for_training: - raise ValueError('train_pass() should not be called when not training.') - - self.enter_train() - with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16): - clip_f = data['clip_features'].cuda(non_blocking=True) - sync_f = data['sync_features'].cuda(non_blocking=True) - text_f = data['text_features'].cuda(non_blocking=True) - video_exist = data['video_exist'].cuda(non_blocking=True) - text_exist = data['text_exist'].cuda(non_blocking=True) - a_mean = data['a_mean'].cuda(non_blocking=True) - a_std = data['a_std'].cuda(non_blocking=True) - - # these masks are for non-existent data; masking for CFG training is in train_fn - clip_f[~video_exist] = self.network.module.empty_clip_feat - sync_f[~video_exist] = self.network.module.empty_sync_feat - text_f[~text_exist] = self.network.module.empty_string_feat - - self.log.data_timer.end() - if it % self.log_extra_interval == 0: - unmasked_clip_f = clip_f.clone() - unmasked_sync_f = sync_f.clone() - unmasked_text_f = text_f.clone() - x1, loss, mean_loss, t = self.train_fn(clip_f, sync_f, text_f, a_mean, a_std) - - self.train_integrator.add_dict({'loss': mean_loss}) - - if it % self.log_text_interval == 0 and it != 0: - self.train_integrator.add_scalar('lr', self.scheduler.get_last_lr()[0]) - self.train_integrator.add_binned_tensor('binned_loss', loss, t) - self.train_integrator.finalize('train', it) - self.train_integrator.reset_except_hooks() - - # Backward pass - self.optimizer.zero_grad(set_to_none=True) - if self.enable_grad_scaler: - self.scaler.scale(mean_loss).backward() - self.scaler.unscale_(self.optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(), - self.clip_grad_norm) - self.scaler.step(self.optimizer) - self.scaler.update() - else: - mean_loss.backward() - grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(), - self.clip_grad_norm) - self.optimizer.step() - - if self.ema is not None and it >= self.ema_start: - self.ema.update() - self.scheduler.step() - self.integrator.add_scalar('grad_norm', grad_norm) - - self.enter_val() - with torch.amp.autocast('cuda', enabled=self.use_amp, - dtype=torch.bfloat16), torch.inference_mode(): - try: - if it % self.log_extra_interval == 0: - # save GT audio - # unnormalize the latents - x1 = self.network.module.unnormalize(x1[0:1]) - mel = self.features.decode(x1) - audio = self.features.vocode(mel).cpu()[0] # 1 * num_samples - self.log.log_spectrogram('train', f'spec-gt-r{local_rank}', mel.cpu()[0], it) - self.log.log_audio('train', - f'audio-gt-r{local_rank}', - audio, - it, - sample_rate=self.sample_rate) - - # save audio from sampling - x0 = torch.empty_like(x1[0:1]).normal_(generator=self.rng) - clip_f = unmasked_clip_f[0:1] - sync_f = unmasked_sync_f[0:1] - text_f = unmasked_text_f[0:1] - conditions = self.network.module.preprocess_conditions(clip_f, sync_f, text_f) - empty_conditions = self.network.module.get_empty_conditions(x0.shape[0]) - cfg_ode_wrapper = lambda t, x: self.network.module.ode_wrapper( - t, x, conditions, empty_conditions, self.cfg_strength) - x1_hat = self.fm.to_data(cfg_ode_wrapper, x0) - x1_hat = self.network.module.unnormalize(x1_hat) - mel = self.features.decode(x1_hat) - audio = self.features.vocode(mel).cpu()[0] - self.log.log_spectrogram('train', f'spec-r{local_rank}', mel.cpu()[0], it) - self.log.log_audio('train', - f'audio-r{local_rank}', - audio, - it, - sample_rate=self.sample_rate) - except Exception as e: - self.log.warning(f'Error in extra logging: {e}') - if self.cfg.debug: - raise - - # Save network weights and checkpoint if needed - save_copy = it in self.save_copy_iterations - - if (it % self.save_weights_interval == 0 and it != 0) or save_copy: - self.save_weights(it) - - if it % self.save_checkpoint_interval == 0 and it != 0: - self.save_checkpoint(it, save_copy=save_copy) - - self.log.data_timer.start() - - @torch.inference_mode() - def validation_pass(self, data, it: int = 0): - self.enter_val() - with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16): - clip_f = data['clip_features'].cuda(non_blocking=True) - sync_f = data['sync_features'].cuda(non_blocking=True) - text_f = data['text_features'].cuda(non_blocking=True) - video_exist = data['video_exist'].cuda(non_blocking=True) - text_exist = data['text_exist'].cuda(non_blocking=True) - a_mean = data['a_mean'].cuda(non_blocking=True) - a_std = data['a_std'].cuda(non_blocking=True) - - clip_f[~video_exist] = self.network.module.empty_clip_feat - sync_f[~video_exist] = self.network.module.empty_sync_feat - text_f[~text_exist] = self.network.module.empty_string_feat - a_randn = torch.empty_like(a_mean).normal_(generator=self.rng) - x1 = a_mean + a_std * a_randn - - self.log.data_timer.end() - loss, mean_loss, t = self.val_fn(clip_f.clone(), sync_f.clone(), text_f.clone(), x1) - - self.val_integrator.add_binned_tensor('binned_loss', loss, t) - self.val_integrator.add_dict({'loss': mean_loss}) - - self.log.data_timer.start() - - @torch.inference_mode() - def inference_pass(self, - data, - it: int, - data_cfg: DictConfig, - *, - save_eval: bool = True) -> Path: - self.enter_val() - with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16): - clip_f = data['clip_features'].cuda(non_blocking=True) - sync_f = data['sync_features'].cuda(non_blocking=True) - text_f = data['text_features'].cuda(non_blocking=True) - video_exist = data['video_exist'].cuda(non_blocking=True) - text_exist = data['text_exist'].cuda(non_blocking=True) - a_mean = data['a_mean'].cuda(non_blocking=True) # for the shape only - - clip_f[~video_exist] = self.network.module.empty_clip_feat - sync_f[~video_exist] = self.network.module.empty_sync_feat - text_f[~text_exist] = self.network.module.empty_string_feat - - # sample - x0 = torch.empty_like(a_mean).normal_(generator=self.rng) - conditions = self.network.module.preprocess_conditions(clip_f, sync_f, text_f) - empty_conditions = self.network.module.get_empty_conditions(x0.shape[0]) - cfg_ode_wrapper = lambda t, x: self.network.module.ode_wrapper( - t, x, conditions, empty_conditions, self.cfg_strength) - x1_hat = self.fm.to_data(cfg_ode_wrapper, x0) - x1_hat = self.network.module.unnormalize(x1_hat) - mel = self.features.decode(x1_hat) - audio = self.features.vocode(mel).cpu() - for i in range(audio.shape[0]): - video_id = data['id'][i] - if (not self.for_training) and i == 0: - # save very few videos - self.test_video_joiner.join(video_id, f'{video_id}', audio[i].transpose(0, 1)) - - if data_cfg.output_subdir is not None: - # validation - if save_eval: - iter_naming = f'{it:09d}' - else: - iter_naming = 'val-cache' - audio_dir = self.log.log_audio(iter_naming, - f'{video_id}', - audio[i], - it=None, - sample_rate=self.sample_rate, - subdir=Path(data_cfg.output_subdir)) - if save_eval and i == 0: - self.val_video_joiner.join(video_id, f'{iter_naming}-{video_id}', - audio[i].transpose(0, 1)) - else: - # full test set, usually - audio_dir = self.log.log_audio(f'{data_cfg.tag}-sampled', - f'{video_id}', - audio[i], - it=None, - sample_rate=self.sample_rate) - - return Path(audio_dir) - - @torch.inference_mode() - def eval(self, audio_dir: Path, it: int, data_cfg: DictConfig) -> dict[str, float]: - with torch.amp.autocast('cuda', enabled=False): - if local_rank == 0: - extract(audio_path=audio_dir, - output_path=audio_dir / 'cache', - device='cuda', - batch_size=32, - audio_length=8) - output_metrics = evaluate(gt_audio_cache=Path(data_cfg.gt_cache), - pred_audio_cache=audio_dir / 'cache') - for k, v in output_metrics.items(): - # pad k to 10 characters - # pad v to 10 decimal places - self.log.log_scalar(f'{data_cfg.tag}/{k}', v, it) - self.log.info(f'{data_cfg.tag}/{k:<10}: {v:.10f}') - else: - output_metrics = None - - return output_metrics - - def save_weights(self, it, save_copy=False): - if local_rank != 0: - return - - os.makedirs(self.run_path, exist_ok=True) - if save_copy: - model_path = self.run_path / f'{self.exp_id}_{it}.pth' - torch.save(self.network.module.state_dict(), model_path) - self.log.info(f'Network weights saved to {model_path}.') - - # if last exists, move it to a shadow copy - model_path = self.run_path / f'{self.exp_id}_last.pth' - if model_path.exists(): - shadow_path = model_path.with_name(model_path.name.replace('last', 'shadow')) - model_path.replace(shadow_path) - self.log.info(f'Network weights shadowed to {shadow_path}.') - - torch.save(self.network.module.state_dict(), model_path) - self.log.info(f'Network weights saved to {model_path}.') - - def save_checkpoint(self, it, save_copy=False): - if local_rank != 0: - return - - checkpoint = { - 'it': it, - 'weights': self.network.module.state_dict(), - 'optimizer': self.optimizer.state_dict(), - 'scheduler': self.scheduler.state_dict(), - 'ema': self.ema.state_dict() if self.ema is not None else None, - } - - os.makedirs(self.run_path, exist_ok=True) - if save_copy: - model_path = self.run_path / f'{self.exp_id}_ckpt_{it}.pth' - torch.save(checkpoint, model_path) - self.log.info(f'Checkpoint saved to {model_path}.') - - # if ckpt_last exists, move it to a shadow copy - model_path = self.run_path / f'{self.exp_id}_ckpt_last.pth' - if model_path.exists(): - shadow_path = model_path.with_name(model_path.name.replace('last', 'shadow')) - model_path.replace(shadow_path) # moves the file - self.log.info(f'Checkpoint shadowed to {shadow_path}.') - - torch.save(checkpoint, model_path) - self.log.info(f'Checkpoint saved to {model_path}.') - - def get_latest_checkpoint_path(self): - ckpt_path = self.run_path / f'{self.exp_id}_ckpt_last.pth' - if not ckpt_path.exists(): - info_if_rank_zero(self.log, f'No checkpoint found at {ckpt_path}.') - return None - return ckpt_path - - def get_latest_weight_path(self): - weight_path = self.run_path / f'{self.exp_id}_last.pth' - if not weight_path.exists(): - self.log.info(f'No weight found at {weight_path}.') - return None - return weight_path - - def get_final_ema_weight_path(self): - weight_path = self.run_path / f'{self.exp_id}_ema_final.pth' - if not weight_path.exists(): - self.log.info(f'No weight found at {weight_path}.') - return None - return weight_path - - def load_checkpoint(self, path): - # This method loads everything and should be used to resume training - map_location = 'cuda:%d' % local_rank - checkpoint = torch.load(path, map_location={'cuda:0': map_location}, weights_only=True) - - it = checkpoint['it'] - weights = checkpoint['weights'] - optimizer = checkpoint['optimizer'] - scheduler = checkpoint['scheduler'] - if self.ema is not None: - self.ema.load_state_dict(checkpoint['ema']) - self.log.info(f'EMA states loaded from step {self.ema.step}') - - map_location = 'cuda:%d' % local_rank - self.network.module.load_state_dict(weights) - self.optimizer.load_state_dict(optimizer) - self.scheduler.load_state_dict(scheduler) - - self.log.info(f'Global iteration {it} loaded.') - self.log.info('Network weights, optimizer states, and scheduler states loaded.') - - return it - - def load_weights_in_memory(self, src_dict): - self.network.module.load_weights(src_dict) - self.log.info('Network weights loaded from memory.') - - def load_weights(self, path): - # This method loads only the network weight and should be used to load a pretrained model - map_location = 'cuda:%d' % local_rank - src_dict = torch.load(path, map_location={'cuda:0': map_location}, weights_only=True) - - self.log.info(f'Importing network weights from {path}...') - self.load_weights_in_memory(src_dict) - - def weights(self): - return self.network.module.state_dict() - - def enter_train(self): - self.integrator = self.train_integrator - self.network.train() - return self - - def enter_val(self): - self.network.eval() - return self diff --git a/think_sound/models/mmmodules/sample.py b/think_sound/models/mmmodules/sample.py deleted file mode 100644 index 72b83389d7dbb55bed02991f51731b0d1e346a2b..0000000000000000000000000000000000000000 --- a/think_sound/models/mmmodules/sample.py +++ /dev/null @@ -1,90 +0,0 @@ -import json -import logging -import os -import random - -import numpy as np -import torch -from hydra.core.hydra_config import HydraConfig -from omegaconf import DictConfig, open_dict -from tqdm import tqdm - -from mmaudio.data.data_setup import setup_test_datasets -from mmaudio.runner import Runner -from mmaudio.utils.dist_utils import info_if_rank_zero -from mmaudio.utils.logger import TensorboardLogger - -local_rank = int(os.environ['LOCAL_RANK']) -world_size = int(os.environ['WORLD_SIZE']) - - -def sample(cfg: DictConfig): - # initial setup - num_gpus = world_size - run_dir = HydraConfig.get().run.dir - - # wrap python logger with a tensorboard logger - log = TensorboardLogger(cfg.exp_id, - run_dir, - logging.getLogger(), - is_rank0=(local_rank == 0), - enable_email=cfg.enable_email and not cfg.debug) - - info_if_rank_zero(log, f'All configuration: {cfg}') - info_if_rank_zero(log, f'Number of GPUs detected: {num_gpus}') - - # cuda setup - torch.cuda.set_device(local_rank) - torch.backends.cudnn.benchmark = cfg.cudnn_benchmark - - # number of dataloader workers - info_if_rank_zero(log, f'Number of dataloader workers (per GPU): {cfg.num_workers}') - - # Set seeds to ensure the same initialization - torch.manual_seed(cfg.seed) - np.random.seed(cfg.seed) - random.seed(cfg.seed) - - # setting up configurations - info_if_rank_zero(log, f'Configuration: {cfg}') - info_if_rank_zero(log, f'Batch size (per GPU): {cfg.batch_size}') - - # construct the trainer - runner = Runner(cfg, log=log, run_path=run_dir, for_training=False).enter_val() - - # load the last weights if needed - if cfg['weights'] is not None: - info_if_rank_zero(log, f'Loading weights from the disk: {cfg["weights"]}') - runner.load_weights(cfg['weights']) - cfg['weights'] = None - else: - weights = runner.get_final_ema_weight_path() - if weights is not None: - info_if_rank_zero(log, f'Automatically finding weight: {weights}') - runner.load_weights(weights) - - # setup datasets - dataset, sampler, loader = setup_test_datasets(cfg) - data_cfg = cfg.data.ExtractedVGG_test - with open_dict(data_cfg): - if cfg.output_name is not None: - # append to the tag - data_cfg.tag = f'{data_cfg.tag}-{cfg.output_name}' - - # loop - audio_path = None - for curr_iter, data in enumerate(tqdm(loader)): - new_audio_path = runner.inference_pass(data, curr_iter, data_cfg) - if audio_path is None: - audio_path = new_audio_path - else: - assert audio_path == new_audio_path, 'Different audio path detected' - - info_if_rank_zero(log, f'Inference completed. Audio path: {audio_path}') - output_metrics = runner.eval(audio_path, curr_iter, data_cfg) - - if local_rank == 0: - # write the output metrics to run_dir - output_metrics_path = os.path.join(run_dir, f'{data_cfg.tag}-output_metrics.json') - with open(output_metrics_path, 'w') as f: - json.dump(output_metrics, f, indent=4) diff --git a/think_sound/models/pqmf.py b/think_sound/models/pqmf.py deleted file mode 100644 index 007fdb51ec797554c1cdd4d9363894d743d970bf..0000000000000000000000000000000000000000 --- a/think_sound/models/pqmf.py +++ /dev/null @@ -1,393 +0,0 @@ -import math -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange -from scipy.optimize import fmin -from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord - -class PQMF(nn.Module): - """ - Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction. - Uses polyphase representation which is computationally more efficient for real-time. - - Parameters: - - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB. - - num_bands (int): Number of desired frequency bands. It must be a power of 2. - """ - - def __init__(self, attenuation, num_bands): - super(PQMF, self).__init__() - - # Ensure num_bands is a power of 2 - is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands))) - assert is_power_of_2, "'num_bands' must be a power of 2." - - # Create the prototype filter - prototype_filter = design_prototype_filter(attenuation, num_bands) - filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands) - padded_filter_bank = pad_to_nearest_power_of_two(filter_bank) - - # Register filters and settings - self.register_buffer("filter_bank", padded_filter_bank) - self.register_buffer("prototype", prototype_filter) - self.num_bands = num_bands - - def forward(self, signal): - """Decompose the signal into multiple frequency bands.""" - # If signal is not a pytorch tensor of Batch x Channels x Length, convert it - signal = prepare_signal_dimensions(signal) - # The signal length must be a multiple of num_bands. Pad it with zeros. - signal = pad_signal(signal, self.num_bands) - # run it - signal = polyphase_analysis(signal, self.filter_bank) - return apply_alias_cancellation(signal) - - def inverse(self, bands): - """Reconstruct the original signal from the frequency bands.""" - bands = apply_alias_cancellation(bands) - return polyphase_synthesis(bands, self.filter_bank) - - -def prepare_signal_dimensions(signal): - """ - Rearrange signal into Batch x Channels x Length. - - Parameters - ---------- - signal : torch.Tensor or numpy.ndarray - The input signal. - - Returns - ------- - torch.Tensor - Preprocessed signal tensor. - """ - # Convert numpy to torch tensor - if isinstance(signal, np.ndarray): - signal = torch.from_numpy(signal) - - # Ensure tensor - if not isinstance(signal, torch.Tensor): - raise ValueError("Input should be either a numpy array or a PyTorch tensor.") - - # Modify dimension of signal to Batch x Channels x Length - if signal.dim() == 1: - # This is just a mono signal. Unsqueeze to 1 x 1 x Length - signal = signal.unsqueeze(0).unsqueeze(0) - elif signal.dim() == 2: - # This is a multi-channel signal (e.g. stereo) - # Rearrange so that larger dimension (Length) is last - if signal.shape[0] > signal.shape[1]: - signal = signal.T - # Unsqueeze to 1 x Channels x Length - signal = signal.unsqueeze(0) - return signal - -def pad_signal(signal, num_bands): - """ - Pads the signal to make its length divisible by the given number of bands. - - Parameters - ---------- - signal : torch.Tensor - The input signal tensor, where the last dimension represents the signal length. - - num_bands : int - The number of bands by which the signal length should be divisible. - - Returns - ------- - torch.Tensor - The padded signal tensor. If the original signal length was already divisible - by num_bands, returns the original signal unchanged. - """ - remainder = signal.shape[-1] % num_bands - if remainder > 0: - padding_size = num_bands - remainder - signal = nn.functional.pad(signal, (0, padding_size)) - return signal - -def generate_modulated_filter_bank(prototype_filter, num_bands): - """ - Generate a QMF bank of cosine modulated filters based on a given prototype filter. - - Parameters - ---------- - prototype_filter : torch.Tensor - The prototype filter used as the basis for modulation. - num_bands : int - The number of desired subbands or filters. - - Returns - ------- - torch.Tensor - A bank of cosine modulated filters. - """ - - # Initialize indices for modulation. - subband_indices = torch.arange(num_bands).reshape(-1, 1) - - # Calculate the length of the prototype filter. - filter_length = prototype_filter.shape[-1] - - # Generate symmetric time indices centered around zero. - time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1) - - # Calculate phase offsets to ensure orthogonality between subbands. - phase_offsets = (-1)**subband_indices * np.pi / 4 - - # Compute the cosine modulation function. - modulation = torch.cos( - (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets - ) - - # Apply modulation to the prototype filter. - modulated_filters = 2 * prototype_filter * modulation - - return modulated_filters - - -def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None): - """ - Design a lowpass filter using the Kaiser window. - - Parameters - ---------- - angular_cutoff : float - The angular frequency cutoff of the filter. - attenuation : float - The desired stopband attenuation in decibels (dB). - filter_length : int, optional - Desired length of the filter. If not provided, it's computed based on the given specs. - - Returns - ------- - ndarray - The designed lowpass filter coefficients. - """ - - estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi) - - # Ensure the estimated length is odd. - estimated_length = 2 * (estimated_length // 2) + 1 - - if filter_length is None: - filter_length = estimated_length - - return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi) - - -def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length): - """ - Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427 - - Parameters - ---------- - angular_cutoff : float - Angular frequency cutoff of the filter. - attenuation : float - Desired stopband attenuation in dB. - num_bands : int - Number of bands for the multiband filter system. - filter_length : int, optional - Desired length of the filter. - - Returns - ------- - float - The computed objective (loss) value for the given filter specs. - """ - - filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length) - convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full") - - return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:])) - - -def design_prototype_filter(attenuation, num_bands, filter_length=None): - """ - Design the optimal prototype filter for a multiband system given the desired specs. - - Parameters - ---------- - attenuation : float - The desired stopband attenuation in dB. - num_bands : int - Number of bands for the multiband filter system. - filter_length : int, optional - Desired length of the filter. If not provided, it's computed based on the given specs. - - Returns - ------- - ndarray - The optimal prototype filter coefficients. - """ - - optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length), - 1 / num_bands, disp=0)[0] - - prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length) - return torch.tensor(prototype_filter, dtype=torch.float32) - -def pad_to_nearest_power_of_two(x): - """ - Pads the input tensor 'x' on both sides such that its last dimension - becomes the nearest larger power of two. - - Parameters: - ----------- - x : torch.Tensor - The input tensor to be padded. - - Returns: - -------- - torch.Tensor - The padded tensor. - """ - current_length = x.shape[-1] - target_length = 2**math.ceil(math.log2(current_length)) - - total_padding = target_length - current_length - left_padding = total_padding // 2 - right_padding = total_padding - left_padding - - return nn.functional.pad(x, (left_padding, right_padding)) - -def apply_alias_cancellation(x): - """ - Applies alias cancellation by inverting the sign of every - second element of every second row, starting from the second - row's first element in a tensor. - - This operation helps ensure that the aliasing introduced in - each band during the decomposition will be counteracted during - the reconstruction. - - Parameters: - ----------- - x : torch.Tensor - The input tensor. - - Returns: - -------- - torch.Tensor - Tensor with specific elements' sign inverted for alias cancellation. - """ - - # Create a mask of the same shape as 'x', initialized with all ones - mask = torch.ones_like(x) - - # Update specific elements in the mask to -1 to perform inversion - mask[..., 1::2, ::2] = -1 - - # Apply the mask to the input tensor 'x' - return x * mask - -def ensure_odd_length(tensor): - """ - Pads the last dimension of a tensor to ensure its size is odd. - - Parameters: - ----------- - tensor : torch.Tensor - Input tensor whose last dimension might need padding. - - Returns: - -------- - torch.Tensor - The original tensor if its last dimension was already odd, - or the padded tensor with an odd-sized last dimension. - """ - - last_dim_size = tensor.shape[-1] - - if last_dim_size % 2 == 0: - tensor = nn.functional.pad(tensor, (0, 1)) - - return tensor - -def polyphase_analysis(signal, filter_bank): - """ - Applies the polyphase method to efficiently analyze the signal using a filter bank. - - Parameters: - ----------- - signal : torch.Tensor - Input signal tensor with shape (Batch x Channels x Length). - - filter_bank : torch.Tensor - Filter bank tensor with shape (Bands x Length). - - Returns: - -------- - torch.Tensor - Signal split into sub-bands. (Batch x Channels x Bands x Length) - """ - - num_bands = filter_bank.shape[0] - num_channels = signal.shape[1] - - # Rearrange signal for polyphase processing. - # Also combine Batch x Channel into one dimension for now. - #signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands) - signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands) - - # Rearrange the filter bank for matching signal shape - filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands) - - # Apply convolution with appropriate padding to maintain spatial dimensions - padding = filter_bank.shape[-1] // 2 - filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding) - - # Truncate the last dimension post-convolution to adjust the output shape - filtered_signal = filtered_signal[..., :-1] - # Rearrange the first dimension back into Batch x Channels - filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels) - - return filtered_signal - -def polyphase_synthesis(signal, filter_bank): - """ - Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal. - - Parameters - ---------- - signal : torch.Tensor - Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length). - - filter_bank : torch.Tensor - Analysis filter bank (shape: Bands x Length). - - should_rearrange : bool, optional - Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True. - - Returns - ------- - torch.Tensor - Reconstructed signal (shape: Batch x Channels X Length) - """ - - num_bands = filter_bank.shape[0] - num_channels = signal.shape[1] - - # Rearrange the filter bank - filter_bank = filter_bank.flip(-1) - filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands) - - # Combine Batch x Channels into one dimension for now. - signal = rearrange(signal, "b c n t -> (b c) n t") - - # Apply convolution with appropriate padding - padding_amount = filter_bank.shape[-1] // 2 + 1 - reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount)) - - # Scale the result - reconstructed_signal = reconstructed_signal[..., :-1] * num_bands - - # Reorganize the output and truncate - reconstructed_signal = reconstructed_signal.flip(1) - reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands) - reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:] - - return reconstructed_signal \ No newline at end of file diff --git a/think_sound/models/wavelets.py b/think_sound/models/wavelets.py deleted file mode 100644 index a359e39110c168aab960d3f79262b464a660e55e..0000000000000000000000000000000000000000 --- a/think_sound/models/wavelets.py +++ /dev/null @@ -1,82 +0,0 @@ -"""The 1D discrete wavelet transform for PyTorch.""" - -from einops import rearrange -import pywt -import torch -from torch import nn -from torch.nn import functional as F -from typing import Literal - - -def get_filter_bank(wavelet): - filt = torch.tensor(pywt.Wavelet(wavelet).filter_bank) - if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0): - filt = filt[:, 1:] - return filt - -class WaveletEncode1d(nn.Module): - def __init__(self, - channels, - levels, - wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): - super().__init__() - self.wavelet = wavelet - self.channels = channels - self.levels = levels - filt = get_filter_bank(wavelet) - assert filt.shape[-1] % 2 == 1 - kernel = filt[:2, None] - kernel = torch.flip(kernel, dims=(-1,)) - index_i = torch.repeat_interleave(torch.arange(2), channels) - index_j = torch.tile(torch.arange(channels), (2,)) - kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) - kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] - self.register_buffer("kernel", kernel_final) - - def forward(self, x): - for i in range(self.levels): - low, rest = x[:, : self.channels], x[:, self.channels :] - pad = self.kernel.shape[-1] // 2 - low = F.pad(low, (pad, pad), "reflect") - low = F.conv1d(low, self.kernel, stride=2) - rest = rearrange( - rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels - ) - x = torch.cat([low, rest], dim=1) - return x - - -class WaveletDecode1d(nn.Module): - def __init__(self, - channels, - levels, - wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): - super().__init__() - self.wavelet = wavelet - self.channels = channels - self.levels = levels - filt = get_filter_bank(wavelet) - assert filt.shape[-1] % 2 == 1 - kernel = filt[2:, None] - index_i = torch.repeat_interleave(torch.arange(2), channels) - index_j = torch.tile(torch.arange(channels), (2,)) - kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) - kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] - self.register_buffer("kernel", kernel_final) - - def forward(self, x): - for i in range(self.levels): - low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :] - pad = self.kernel.shape[-1] // 2 + 2 - low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2) - low = F.pad(low, (pad, pad), "reflect") - low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2) - low = F.conv_transpose1d( - low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2 - ) - low = low[..., pad - 1 : -pad] - rest = rearrange( - rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels - ) - x = torch.cat([low, rest], dim=1) - return x \ No newline at end of file diff --git a/think_sound/training/__pycache__/__init__.cpython-310.pyc b/think_sound/training/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 5341ed35294363d4664819b34658372d1d88e9f5..0000000000000000000000000000000000000000 Binary files a/think_sound/training/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/think_sound/training/__pycache__/__init__.cpython-39.pyc b/think_sound/training/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index eac873918dd6c84cad0a8ac58daa7733992010ec..0000000000000000000000000000000000000000 Binary files a/think_sound/training/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/think_sound/training/__pycache__/autoencoders.cpython-310.pyc b/think_sound/training/__pycache__/autoencoders.cpython-310.pyc deleted file mode 100644 index 2138ff51a7251c02db411132dd507928a12e5329..0000000000000000000000000000000000000000 Binary files a/think_sound/training/__pycache__/autoencoders.cpython-310.pyc and /dev/null differ diff --git a/think_sound/training/__pycache__/autoencoders.cpython-39.pyc b/think_sound/training/__pycache__/autoencoders.cpython-39.pyc deleted file mode 100644 index 2b2a6bedcd0cc3066e42776776c849564b079d1b..0000000000000000000000000000000000000000 Binary files a/think_sound/training/__pycache__/autoencoders.cpython-39.pyc and /dev/null differ diff --git a/think_sound/training/__pycache__/diffusion.cpython-310.pyc b/think_sound/training/__pycache__/diffusion.cpython-310.pyc deleted file mode 100644 index f65fd143aa6d3379f48252d9084461fef88ce66c..0000000000000000000000000000000000000000 Binary files a/think_sound/training/__pycache__/diffusion.cpython-310.pyc and /dev/null differ diff --git a/think_sound/training/__pycache__/diffusion.cpython-39.pyc b/think_sound/training/__pycache__/diffusion.cpython-39.pyc deleted file mode 100644 index 11b337395b997f0d788cd045371b466af68e9a41..0000000000000000000000000000000000000000 Binary files a/think_sound/training/__pycache__/diffusion.cpython-39.pyc and /dev/null differ diff --git a/think_sound/training/__pycache__/factory.cpython-310.pyc b/think_sound/training/__pycache__/factory.cpython-310.pyc deleted file mode 100644 index 7978655e702af1a76fe29e1775d5fbf89ca7d853..0000000000000000000000000000000000000000 Binary files a/think_sound/training/__pycache__/factory.cpython-310.pyc and /dev/null differ diff --git a/think_sound/training/__pycache__/factory.cpython-39.pyc b/think_sound/training/__pycache__/factory.cpython-39.pyc deleted file mode 100644 index 3f2a6d0238e0fc918e9f36b81647ebfcf2fdd419..0000000000000000000000000000000000000000 Binary files a/think_sound/training/__pycache__/factory.cpython-39.pyc and /dev/null differ diff --git a/think_sound/training/__pycache__/utils.cpython-310.pyc b/think_sound/training/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index 439cab9e3454c17accbacf85f947a62cac1bfd90..0000000000000000000000000000000000000000 Binary files a/think_sound/training/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/think_sound/training/__pycache__/utils.cpython-39.pyc b/think_sound/training/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index c815222a7df75bafc58254b3fcac4294baf607fc..0000000000000000000000000000000000000000 Binary files a/think_sound/training/__pycache__/utils.cpython-39.pyc and /dev/null differ diff --git a/think_sound/training/lm.py b/think_sound/training/lm.py deleted file mode 100644 index e1fa9f71c805f8d4083919d5c46422c5b7eeb4a8..0000000000000000000000000000000000000000 --- a/think_sound/training/lm.py +++ /dev/null @@ -1,267 +0,0 @@ -import pytorch_lightning as pl -import sys, gc -import random -import torch -import torchaudio -import typing as tp -import wandb - -from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image -from ema_pytorch import EMA -from einops import rearrange -from safetensors.torch import save_file -from torch import optim -from torch.nn import functional as F -from pytorch_lightning.utilities.rank_zero import rank_zero_only - -from ..models.lm import AudioLanguageModelWrapper -from .utils import create_optimizer_from_config, create_scheduler_from_config - -class AudioLanguageModelTrainingWrapper(pl.LightningModule): - def __init__( - self, - model: AudioLanguageModelWrapper, - lr = 1e-4, - use_ema=False, - ema_copy=None, - optimizer_configs: dict = None, - pre_encoded=False - ): - super().__init__() - - self.model = model - - self.model.pretransform.requires_grad_(False) - - self.model_ema = None - if use_ema: - self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10) - - assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" - - if optimizer_configs is None: - optimizer_configs = { - "lm": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": lr, - "betas": (0.9, 0.95), - "weight_decay": 0.1 - } - } - } - } - else: - if lr is not None: - print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") - - self.optimizer_configs = optimizer_configs - - self.pre_encoded = pre_encoded - - def configure_optimizers(self): - lm_opt_config = self.optimizer_configs['lm'] - opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters()) - - if "scheduler" in lm_opt_config: - sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm) - sched_lm_config = { - "scheduler": sched_lm, - "interval": "step" - } - return [opt_lm], [sched_lm_config] - - return [opt_lm] - - # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license - # License can be found in LICENSES/LICENSE_META.txt - - def _compute_cross_entropy( - self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor - ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: - """Compute cross entropy between multi-codebook targets and model's logits. - The cross entropy is computed per codebook to provide codebook-level cross entropy. - Valid timesteps for each of the codebook are pulled from the mask, where invalid - timesteps are set to 0. - - Args: - logits (torch.Tensor): Model's logits of shape [B, K, T, card]. - targets (torch.Tensor): Target codes, of shape [B, K, T]. - mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. - Returns: - ce (torch.Tensor): Cross entropy averaged over the codebooks - ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). - """ - B, K, T = targets.shape - assert logits.shape[:-1] == targets.shape - assert mask.shape == targets.shape - ce = torch.zeros([], device=targets.device) - ce_per_codebook: tp.List[torch.Tensor] = [] - for k in range(K): - logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] - targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] - mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] - ce_targets = targets_k[mask_k] - ce_logits = logits_k[mask_k] - q_ce = F.cross_entropy(ce_logits, ce_targets) - ce += q_ce - ce_per_codebook.append(q_ce.detach()) - # average cross entropy across codebooks - ce = ce / K - return ce, ce_per_codebook - - def training_step(self, batch, batch_idx): - reals, metadata = batch - - if reals.ndim == 4 and reals.shape[0] == 1: - reals = reals[0] - - if not self.pre_encoded: - codes = self.model.pretransform.tokenize(reals) - else: - codes = reals - - padding_masks = [] - for md in metadata: - if md["padding_mask"].ndim == 1: - padding_masks.append(md["padding_mask"]) - else: - padding_masks.append(md["padding_mask"][0]) - - padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length) - - # Interpolate padding masks to the same length as the codes - padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool() - - condition_tensors = None - - # If the model is conditioned, get the conditioning tensors - if self.model.conditioner is not None: - condition_tensors = self.model.conditioner(metadata, self.device) - - lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1) - - logits = lm_output.logits # [b, k, t, c] - logits_mask = lm_output.mask # [b, k, t] - - logits_mask = logits_mask & padding_masks - - cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask) - - loss = cross_entropy - - log_dict = { - 'train/loss': loss.detach(), - 'train/cross_entropy': cross_entropy.detach(), - 'train/perplexity': torch.exp(cross_entropy).detach(), - 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] - } - - for k, ce_q in enumerate(cross_entropy_per_codebook): - log_dict[f'cross_entropy_q{k + 1}'] = ce_q - log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q) - - self.log_dict(log_dict, prog_bar=True, on_step=True) - return loss - - def on_before_zero_grad(self, *args, **kwargs): - if self.model_ema is not None: - self.model_ema.update() - - def export_model(self, path, use_safetensors=False): - - model = self.model_ema.ema_model if self.model_ema is not None else self.model - - if use_safetensors: - save_file(model.state_dict(), path) - else: - torch.save({"state_dict": model.state_dict()}, path) - - -class AudioLanguageModelDemoCallback(pl.Callback): - def __init__(self, - demo_every=2000, - num_demos=8, - sample_size=65536, - sample_rate=48000, - demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, - demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], - **kwargs - ): - super().__init__() - - self.demo_every = demo_every - self.num_demos = num_demos - self.demo_samples = sample_size - self.sample_rate = sample_rate - self.last_demo_step = -1 - self.demo_conditioning = demo_conditioning - self.demo_cfg_scales = demo_cfg_scales - - @rank_zero_only - @torch.no_grad() - def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx): - - if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: - return - - module.eval() - - print(f"Generating demo") - self.last_demo_step = trainer.global_step - - demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio - - #demo_reals = batch[0][:self.num_demos] - - # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: - # demo_reals = demo_reals[0] - - #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals) - - ##Limit to first 50 tokens - #demo_reals_tokens = demo_reals_tokens[:, :, :50] - - try: - print("Getting conditioning") - - for cfg_scale in self.demo_cfg_scales: - - model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model - - print(f"Generating demo for cfg scale {cfg_scale}") - fakes = model.generate_audio( - batch_size=self.num_demos, - max_gen_len=demo_length_tokens, - conditioning=self.demo_conditioning, - #init_data = demo_reals_tokens, - cfg_scale=cfg_scale, - temp=1.0, - top_p=0.95 - ) - - # Put the demos together - fakes = rearrange(fakes, 'b d n -> d (b n)') - - log_dict = {} - - filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' - fakes = fakes / fakes.abs().max() - fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu() - torchaudio.save(filename, fakes, self.sample_rate) - - log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Reconstructed') - - log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) - - trainer.logger.experiment.log(log_dict) - - except Exception as e: - raise e - finally: - gc.collect() - torch.cuda.empty_cache() - module.train() \ No newline at end of file diff --git a/think_sound/training/lm_continuous.py b/think_sound/training/lm_continuous.py deleted file mode 100644 index 0ecc1a92336a0623f3f9b1c455a1f8198e4cacb8..0000000000000000000000000000000000000000 --- a/think_sound/training/lm_continuous.py +++ /dev/null @@ -1,294 +0,0 @@ -import pytorch_lightning as pl -import sys, gc -import random -import torch -import torchaudio -import typing as tp -import wandb - -from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image -from ema_pytorch import EMA -from einops import rearrange -from safetensors.torch import save_file -from torch import optim -from torch.nn import functional as F -from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler -from pytorch_lightning.utilities.rank_zero import rank_zero_only -from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper - -from ..models.lm import AudioLMContinuousModelWrapper -from .utils import create_optimizer_from_config, create_scheduler_from_config - -class AudioLMContinuousModelTrainingWrapper(pl.LightningModule): - def __init__( - self, - model: AudioLanguageModelWrapper, - lr = 1e-4, - diffusion_objective: tp.Literal["rectified_flow", "v"] = "v", - timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", - use_ema=False, - ema_copy=None, - optimizer_configs: dict = None, - diffusion_batch_mul=4, - pre_encoded=False - ): - super().__init__() - - self.model = model - self.diffusion = diffusion - self.rng = torch.quasirandom.SobolEngine(1, scramble=True) - - self.model.pretransform.requires_grad_(False) - - self.timestep_sampler = timestep_sampler - - self.diffusion_objective = model.diffusion_objective - - loss_modules = [ - MSELoss("v", - "targets", - weight=1.0, - name="mse_loss" - ) - ] - - self.losses = MultiLoss(loss_modules) - - self.model_ema = None - if use_ema: - self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10) - - assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" - - if optimizer_configs is None: - optimizer_configs = { - "lm": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": lr, - "betas": (0.9, 0.95), - "weight_decay": 0.1 - } - } - } - } - else: - if lr is not None: - print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") - - - self.optimizer_configs = optimizer_configs - - self.diffusion_batch_mul = diffusion_batch_mul - - self.pre_encoded = pre_encoded - - def configure_optimizers(self): - lm_opt_config = self.optimizer_configs['lm'] - opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters()) - - if "scheduler" in lm_opt_config: - sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm) - sched_lm_config = { - "scheduler": sched_lm, - "interval": "step" - } - return [opt_lm], [sched_lm_config] - - return [opt_lm] - - # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license - # License can be found in LICENSES/LICENSE_META.txt - - - def training_step(self, batch, batch_idx): - reals, metadata = batch - - if reals.ndim == 4 and reals.shape[0] == 1: - reals = reals[0] - - diffusion_input = reals - - loss_info = {} - - if not self.pre_encoded: - loss_info["audio_reals"] = diffusion_input - - if self.diffusion.pretransform is not None: - if not self.pre_encoded: - with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): - diffusion_input = self.diffusion.pretransform.encode(diffusion_input) - else: - # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run - if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: - diffusion_input = diffusion_input / self.diffusion.pretransform.scale - - loss_info["reals"] = diffusion_input - - padding_masks = [] - for md in metadata: - if md["padding_mask"].ndim == 1: - padding_masks.append(md["padding_mask"]) - else: - padding_masks.append(md["padding_mask"][0]) - - padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length) - - condition_tensors = None - - # If the model is conditioned, get the conditioning tensors - if self.model.conditioner is not None: - with torch.cuda.amp.autocast(): - condition_tensors = self.model.conditioner(metadata, self.device) - - z = self.model.compute_logits(diffusion_input, condition_tensors=condition_tensors, cfg_dropout_prob=0.1) - bsz, seq_len, _ = z.shape - gt_inputs = diffusion_input.clone().detach() - gt_inputs = gt_inputs.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1) - z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1) - mask = mask.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul) - if self.timestep_sampler == "uniform": - # Draw uniformly distributed continuous timesteps - t = self.rng.draw(z.shape[0])[:, 0].to(self.device) - elif self.timestep_sampler == "logit_normal": - t = torch.sigmoid(torch.randn(z.shape[0], device=self.device)) - - # Calculate the noise schedule parameters for those timesteps - if self.diffusion_objective == "v": - alphas, sigmas = get_alphas_sigmas(t) - elif self.diffusion_objective == "rectified_flow": - alphas, sigmas = 1-t, t - - # Combine the ground truth data and the noise - alphas = alphas[:, None] - sigmas = sigmas[:, None] - - noise = torch.randn_like(gt_inputs) - noised_inputs = gt_inputs * alphas + noise * sigmas - if self.diffusion_objective == "v": - targets = noise * alphas - gt_inputs * sigmas - elif self.diffusion_objective == "rectified_flow": - targets = noise - gt_inputs - cond = {} - cond['z'] = z - with torch.cuda.amp.autocast(): - v = self.diffusion(noised_inputs, t, cond=cond) - - loss_info.update({ - "v": v, - "targets": targets - }) - - loss, losses = self.losses() - - log_dict = { - 'train/loss': loss.detach(), - 'train/std_data': diffusion_input.std(), - 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] - } - - - self.log_dict(log_dict, prog_bar=True, on_step=True) - return loss - - def on_before_zero_grad(self, *args, **kwargs): - if self.model_ema is not None: - self.model_ema.update() - - def export_model(self, path, use_safetensors=False): - - model = self.model_ema.ema_model if self.model_ema is not None else self.model - - if use_safetensors: - save_file(model.state_dict(), path) - else: - torch.save({"state_dict": model.state_dict()}, path) - - -class AudioLanguageModelDemoCallback(pl.Callback):loss_info - def __init__(self, - demo_every=2000, - num_demos=8, - sample_size=65536, - sample_rate=48000, - demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, - demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], - **kwargs - ): - super().__init__() - - self.demo_every = demo_every - self.num_demos = num_demos - self.demo_samples = sample_size - self.sample_rate = sample_rate - self.last_demo_step = -1 - self.demo_conditioning = demo_conditioning - self.demo_cfg_scales = demo_cfg_scales - - @rank_zero_only - @torch.no_grad() - def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx): - - if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: - return - - module.eval() - - print(f"Generating demo") - self.last_demo_step = trainer.global_step - - demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio - - #demo_reals = batch[0][:self.num_demos] - - # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: - # demo_reals = demo_reals[0] - - #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals) - - ##Limit to first 50 tokens - #demo_reals_tokens = demo_reals_tokens[:, :, :50] - - try: - print("Getting conditioning") - - for cfg_scale in self.demo_cfg_scales: - - model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model - - print(f"Generating demo for cfg scale {cfg_scale}") - fakes = model.generate_audio( - batch_size=self.num_demos, - max_gen_len=demo_length_tokens, - conditioning=self.demo_conditioning, - #init_data = demo_reals_tokens, - cfg_scale=cfg_scale, - temp=1.0, - top_p=0.95 - ) - - # Put the demos together - fakes = rearrange(fakes, 'b d n -> d (b n)') - - log_dict = {} - - filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' - fakes = fakes / fakes.abs().max() - fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu() - torchaudio.save(filename, fakes, self.sample_rate) - - log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Reconstructed') - - log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) - - trainer.logger.experiment.log(log_dict) - - except Exception as e: - raise e - finally: - gc.collect() - torch.cuda.empty_cache() - module.train() \ No newline at end of file diff --git a/think_sound/training/losses/__pycache__/__init__.cpython-310.pyc b/think_sound/training/losses/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index f1a2eac94cf919a5e241ffc55f7749b6d32fce3a..0000000000000000000000000000000000000000 Binary files a/think_sound/training/losses/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/think_sound/training/losses/__pycache__/__init__.cpython-39.pyc b/think_sound/training/losses/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index f9d511b8392c6e8906ef62a166e5948fb28f3dfc..0000000000000000000000000000000000000000 Binary files a/think_sound/training/losses/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/think_sound/training/losses/__pycache__/auraloss.cpython-310.pyc b/think_sound/training/losses/__pycache__/auraloss.cpython-310.pyc deleted file mode 100644 index f736432cc8ac2062b5fdb0d6e6ed21620d06231a..0000000000000000000000000000000000000000 Binary files a/think_sound/training/losses/__pycache__/auraloss.cpython-310.pyc and /dev/null differ diff --git a/think_sound/training/losses/__pycache__/auraloss.cpython-39.pyc b/think_sound/training/losses/__pycache__/auraloss.cpython-39.pyc deleted file mode 100644 index c4ef7ab4b2d3d8c80493dd9208a8469b6000a307..0000000000000000000000000000000000000000 Binary files a/think_sound/training/losses/__pycache__/auraloss.cpython-39.pyc and /dev/null differ diff --git a/think_sound/training/losses/__pycache__/losses.cpython-310.pyc b/think_sound/training/losses/__pycache__/losses.cpython-310.pyc deleted file mode 100644 index 330d5bcce476194ab113cc90782747dbb2779346..0000000000000000000000000000000000000000 Binary files a/think_sound/training/losses/__pycache__/losses.cpython-310.pyc and /dev/null differ diff --git a/think_sound/training/losses/__pycache__/losses.cpython-39.pyc b/think_sound/training/losses/__pycache__/losses.cpython-39.pyc deleted file mode 100644 index 6af0513ae4e07e8877b2253e4f1df0153a0ba20e..0000000000000000000000000000000000000000 Binary files a/think_sound/training/losses/__pycache__/losses.cpython-39.pyc and /dev/null differ