root committed on
Commit 98a0e3b · 1 Parent(s): 167c6ec

add lowmem mode

codeclm/models/builders.py CHANGED
@@ -29,13 +29,29 @@ def get_audio_tokenizer_model(checkpoint_path: str, cfg: omegaconf.DictConfig):
         return None
     if checkpoint_path.startswith('//pretrained/'):
         name = checkpoint_path.split('/', 3)[-1]
-        return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode)
+        return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cuda', mode=cfg.mode)
     elif checkpoint_path == "":
         return None
     else:
         name = checkpoint_path
-        return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode)
+        return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cuda', mode=cfg.mode)
 
+
+def get_audio_tokenizer_model_cpu(checkpoint_path: str, cfg: omegaconf.DictConfig):
+    from codeclm.tokenizer.audio_tokenizer import AudioTokenizer
+    """Instantiate a compression model."""
+    if checkpoint_path is None:
+        return None
+    if checkpoint_path.startswith('//pretrained/'):
+        name = checkpoint_path.split('/', 3)[-1]
+        return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode, tango_device='cpu')
+    elif checkpoint_path == "":
+        return None
+    else:
+        name = checkpoint_path
+        return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode, tango_device='cpu')
+
+
 def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
     """Instantiate a LM."""
     lm_kwargs = dict_from_config(getattr(cfg, 'lm'))
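A hedged usage sketch of the new CPU-side builder (this call site is not part of the commit itself; `cfg` is assumed to be the loaded config.yaml with an `audio_tokenizer_checkpoint` entry):

from codeclm.models import builders

# Build the tokenizer with its Tango decoder kept on the CPU, then move it to the GPU
# only when audio actually needs to be encoded or decoded (low-memory pattern).
tokenizer = builders.get_audio_tokenizer_model_cpu(cfg.audio_tokenizer_checkpoint, cfg)
tokenizer = tokenizer.eval()
tokenizer = tokenizer.cuda()  # the new AudioTokenizer.cuda()/to() overrides also move the inner Tango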
codeclm/models/codeclm.py CHANGED
@@ -271,21 +271,24 @@ class CodecLM:
         return gen_tokens
 
     @torch.no_grad()
-    def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None, chunked=False, gen_type="all"):
+    def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None, chunked=False, chunk_size=128, gen_type='mixed'):
         """Generate Audio from tokens"""
         assert gen_tokens.dim() == 3
         if self.seperate_tokenizer is not None:
             gen_tokens_song = gen_tokens[:, [0], :]
             gen_tokens_vocal = gen_tokens[:, [1], :]
             gen_tokens_bgm = gen_tokens[:, [2], :]
-            if gen_type == "bgm":
+            if gen_type == 'bgm':
                 gen_tokens_vocal = torch.full_like(gen_tokens_vocal, 3142)
-                vocal_prompt = None
-            elif gen_type == "vocal":
+                if vocal_prompt is not None:
+                    vocal_prompt = torch.zeros_like(vocal_prompt)
+            elif gen_type == 'vocal':
                 gen_tokens_bgm = torch.full_like(gen_tokens_bgm, 9670)
-                bgm_prompt = None
-            # gen_audio_song = self.audiotokenizer.decode(gen_tokens_song, prompt)
-            gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt, chunked=chunked)
+                if bgm_prompt is not None:
+                    bgm_prompt = torch.zeros_like(bgm_prompt)
+            else:
+                assert gen_type == 'mixed', f"gen_type {gen_type} not supported"
+            gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt, chunked=chunked, chunk_size=chunk_size)
             return gen_audio_seperate
         else:
             gen_audio = self.audiotokenizer.decode(gen_tokens, prompt)
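A small usage sketch of the revised signature, mirroring how the reworked generate.py below calls it; `tokens` is assumed to be the [B, 3, T] output of model.generate(..., return_tokens=True):

wav_mixed = model.generate_audio(tokens, chunked=True, gen_type='mixed')   # full mix
wav_vocal = model.generate_audio(tokens, chunked=True, gen_type='vocal')   # vocal stem only
wav_bgm   = model.generate_audio(tokens, chunked=True, gen_type='bgm')     # accompaniment only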
codeclm/tokenizer/Flow1dVAE/generate_1rvq.py CHANGED
@@ -46,7 +46,6 @@ class Tango:
 
         self.model.eval()
         self.model.init_device_dtype(torch.device(device), torch.float32)
-        print("scaling factor: ", self.model.normfeat.std)
 
         # self.scheduler = DDIMScheduler.from_pretrained( \
         #     scheduler_name, subfolder="scheduler")
@@ -281,3 +280,11 @@ class Tango:
             else:
                 output = torch.cat([output, cur_output], -1)
         return output
+
+    def to(self, device=None, dtype=None, non_blocking=False):
+        if device is not None:
+            self.device = device
+            self.model.device = device
+        self.vae = self.vae.to(device, dtype, non_blocking)
+        self.model = self.model.to(device, dtype, non_blocking)
+        return self
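Tango here is a plain wrapper rather than an nn.Module, so this to() forwards the move to its vae and model members. A minimal round-trip sketch (the tango object and its construction are assumed from the class above):

tango = tango.to('cpu')      # park the VAE and diffusion model in host memory
torch.cuda.empty_cache()     # release the freed VRAM
# ... later, just before decoding:
tango = tango.to('cuda')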
codeclm/tokenizer/Flow1dVAE/generate_2rvq.py CHANGED
@@ -51,7 +51,6 @@ class Tango:
 
         self.model.eval()
         self.model.init_device_dtype(torch.device(device), torch.float32)
-        print("scaling factor: ", self.model.normfeat.std)
 
         # self.scheduler = DDIMScheduler.from_pretrained( \
         #     scheduler_name, subfolder="scheduler")
codeclm/tokenizer/Flow1dVAE/generate_4rvq.py CHANGED
@@ -50,7 +50,6 @@ class Tango:
 
         self.model.eval()
         self.model.init_device_dtype(torch.device(device), torch.float32)
-        print("scaling factor: ", self.model.normfeat.std)
 
         # self.scheduler = DDIMScheduler.from_pretrained( \
         #     scheduler_name, subfolder="scheduler")
codeclm/tokenizer/Flow1dVAE/generate_septoken.py CHANGED
@@ -102,7 +102,6 @@ class Tango:
 
         self.model.eval()
         self.model.init_device_dtype(torch.device(device), torch.float32)
-        print("scaling factor: ", self.model.normfeat.std)
 
         # self.scheduler = DDIMScheduler.from_pretrained( \
         #     scheduler_name, subfolder="scheduler")
@@ -173,7 +172,7 @@ class Tango:
         return codes_vocal, codes_bgm
 
     @torch.no_grad()
-    def code2sound(self, codes, prompt_vocal=None, prompt_bgm=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False, chunked=False):
+    def code2sound(self, codes, prompt_vocal=None, prompt_bgm=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False, chunked=False, chunk_size=128):
         codes_vocal,codes_bgm = codes
         codes_vocal = codes_vocal.to(self.device)
         codes_bgm = codes_bgm.to(self.device)
@@ -188,7 +187,7 @@ class Tango:
         first_latent_codes_length = 0
 
 
-        if (isinstance(prompt_vocal, torch.Tensor)) and (isinstance(prompt_bgm, torch.Tensor)):
+        if(isinstance(prompt_vocal, torch.Tensor) and isinstance(prompt_bgm, torch.Tensor)):
             # prepare prompt
             prompt_vocal = prompt_vocal.to(self.device)
             prompt_bgm = prompt_bgm.to(self.device)
@@ -273,7 +272,7 @@ class Tango:
         output = None
         for i in range(len(latent_list)):
             latent = latent_list[i]
-            cur_output = self.vae.decode_audio(latent, chunked=chunked)[0].detach().cpu()
+            cur_output = self.vae.decode_audio(latent, chunked=chunked, chunk_size=chunk_size)[0].detach().cpu()
 
             if output is None:
                 output = cur_output
@@ -301,3 +300,11 @@ class Tango:
         codes=[codes_vocal, codes_bgm]
         wave = self.code2sound(codes, prompt_vocal,prompt_bgm, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
         return wave
+
+    def to(self, device=None, dtype=None, non_blocking=False):
+        if device is not None:
+            self.device = device
+            self.model.device = device
+        self.vae = self.vae.to(device, dtype, non_blocking)
+        self.model = self.model.to(device, dtype, non_blocking)
+        return self
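The chunked/chunk_size plumbing lets decode_audio render the latent in fixed-size windows, trading a little speed for a lower peak-memory footprint. A hedged call sketch (token shapes are assumed to match sound2code's output):

# codes_vocal, codes_bgm: [B, N, T] token tensors from Tango.sound2code (assumed shapes)
wav = tango.code2sound([codes_vocal, codes_bgm],
                       prompt_vocal=None, prompt_bgm=None,
                       chunked=True, chunk_size=64)  # smaller chunk_size -> lower peak VRAM during VAE decode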
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/musicfm_25hz.py CHANGED
@@ -78,7 +78,6 @@ class MusicFM25Hz(nn.Module):
             with open(stat_path, "r") as f:
                 self.stat = json.load(f)
         else:
-            print("No stats file found at `{}`, use default from msd.".format(stat_path))
             self.stat = {"spec_256_cnt": 14394344256, "spec_256_mean": -23.34296658431829, "spec_256_std": 26.189295587132637, "spec_512_cnt": 28677104448, "spec_512_mean": -21.31267396860235, "spec_512_std": 26.52644536245769, "spec_1024_cnt": 57242624832, "spec_1024_mean": -18.852271129208273, "spec_1024_std": 26.443154583585663, "spec_2048_cnt": 114373665600, "spec_2048_mean": -15.638743433896792, "spec_2048_std": 26.115825961611545, "spec_4096_cnt": 228635747136, "spec_4096_mean": -11.715532502794836, "spec_4096_std": 25.763972210234062, "melspec_256_cnt": 14282760192, "melspec_256_mean": -26.962600400166156, "melspec_256_std": 36.13614100912126, "melspec_512_cnt": 14282760192, "melspec_512_mean": -9.108344167718862, "melspec_512_std": 24.71910937988429, "melspec_1024_cnt": 14282760192, "melspec_1024_mean": 0.37302579246531126, "melspec_1024_std": 18.684082325919388, "melspec_2048_cnt": 14282760192, "melspec_2048_mean": 6.768444971712967, "melspec_2048_std": 18.417922652295623, "melspec_4096_cnt": 14282760192, "melspec_4096_mean": 13.617164614990036, "melspec_4096_std": 18.08552130124525, "cqt_cnt": 9373061376, "cqt_mean": 0.46341379757927165, "cqt_std": 0.9543998080910191, "mfcc_256_cnt": 1339008768, "mfcc_256_mean": -11.681755459447485, "mfcc_256_std": 29.183186444668316, "mfcc_512_cnt": 1339008768, "mfcc_512_mean": -2.540581461792183, "mfcc_512_std": 31.93752185832081, "mfcc_1024_cnt": 1339008768, "mfcc_1024_mean": 6.606636263169779, "mfcc_1024_std": 34.151644801729624, "mfcc_2048_cnt": 1339008768, "mfcc_2048_mean": 5.281600844245184, "mfcc_2048_std": 33.12784541220003, "mfcc_4096_cnt": 1339008768, "mfcc_4096_mean": 4.7616569480166095, "mfcc_4096_std": 32.61458906894133, "chromagram_256_cnt": 1339008768, "chromagram_256_mean": 55.15596556703181, "chromagram_256_std": 73.91858278719991, "chromagram_512_cnt": 1339008768, "chromagram_512_mean": 175.73092252759895, "chromagram_512_std": 248.48485148525953, "chromagram_1024_cnt": 1339008768, "chromagram_1024_mean": 589.2947481634608, "chromagram_1024_std": 913.857929063196, "chromagram_2048_cnt": 1339008768, "chromagram_2048_mean": 2062.286388327397, "chromagram_2048_std": 3458.92657915397, "chromagram_4096_cnt": 1339008768, "chromagram_4096_mean": 7673.039107997085, "chromagram_4096_std": 13009.883158267234}
 
         # feature extractor
@@ -90,40 +89,6 @@ class MusicFM25Hz(nn.Module):
         self.use_rvq_target = use_rvq_target
 
         seed = 142
-        if use_rvq_target:
-            try:
-                from .rvq_musicfm import ResidualVectorQuantize
-
-            except:
-                import sys, os
-                sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-                from rvq_musicfm import ResidualVectorQuantize
-
-            self.rvq = ResidualVectorQuantize(
-                input_dim = 128*4,
-                n_codebooks = 8,
-                codebook_size = 1024,
-                codebook_dim = 16,
-                quantizer_dropout = 0.0,
-            )
-            import os
-            if rvq_ckpt_path is not None and os.path.exists(rvq_ckpt_path):
-                state_dict = torch.load(rvq_ckpt_path, map_location="cpu")
-                self.rvq.load_state_dict(state_dict)
-            else:
-                print(f'Checkpoint for rvq `{rvq_ckpt_path}` not found. Using random initialization.')
-
-        else:
-            for feature in self.features:
-                for i in range(num_codebooks):
-                    setattr(
-                        self,
-                        f"quantizer_{feature}", # _{i}
-                        RandomProjectionQuantizer(
-                            n_mels * 4, codebook_dim, codebook_size, seed=seed + i
-                        ),
-                    )
-
         # two residual convolution layers + one projection layer
         self.conv = Conv2dSubsampling(
             1, conv_dim, encoder_dim, strides=[2, 2], n_bands=n_mels
@@ -247,16 +212,8 @@ class MusicFM25Hz(nn.Module):
     @torch.no_grad()
     def tokenize(self, x):
         out = {}
-        for key in x.keys():
-            if self.use_rvq_target:
-                self.rvq.eval()
-                quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(x[key].permute((0, 2, 1)))
-                out[key] = torch.cat([codes[:, idx, :] for idx in range(int(self.codebook_size//1024))], dim=-1)
-            else:
-                layer = getattr(self, "quantizer_%s" % key)
-                out[key] = layer(x[key])
-        return out
-
+        raise NotImplementedError("tokenize is not implemented")
+
     def get_targets(self, x):
         x = self.preprocessing(x, features=self.features) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
         x = self.normalize(x)
codeclm/tokenizer/audio_tokenizer.py CHANGED
@@ -78,7 +78,8 @@ class AudioTokenizer(ABC, nn.Module):
         vae_config: str,
         vae_model: str,
         device: tp.Union[torch.device, str] = 'cpu',
-        mode='extract'
+        mode='extract',
+        tango_device:str='cuda'
     ) -> 'AudioTokenizer':
         """Instantiate a AudioTokenizer model from a given pretrained model.
 
@@ -91,11 +92,11 @@ class AudioTokenizer(ABC, nn.Module):
         if name.split('_')[0] == 'Flow1dVAESeparate':
             model_type = name.split('_', 1)[1]
             logger.info("Getting pretrained compression model from semantic model %s", model_type)
-            model = Flow1dVAESeparate(model_type, vae_config, vae_model)
+            model = Flow1dVAESeparate(model_type, vae_config, vae_model, tango_device=tango_device)
         elif name.split('_')[0] == 'Flow1dVAE1rvq':
             model_type = name.split('_', 1)[1]
             logger.info("Getting pretrained compression model from semantic model %s", model_type)
-            model = Flow1dVAE1rvq(model_type, vae_config, vae_model)
+            model = Flow1dVAE1rvq(model_type, vae_config, vae_model, tango_device=tango_device)
         else:
             raise NotImplementedError("{} is not implemented in models/audio_tokenizer.py".format(
                 name))
@@ -108,12 +109,13 @@ class Flow1dVAE1rvq(AudioTokenizer):
         model_type: str = "model_2_fixed.safetensors",
         vae_config: str = "",
         vae_model: str = "",
+        tango_device: str = "cuda"
     ):
         super().__init__()
 
         from codeclm.tokenizer.Flow1dVAE.generate_1rvq import Tango
         model_path = model_type
-        self.model = Tango(model_path=model_path, vae_config=vae_config, vae_model=vae_model, device='cuda')
+        self.model = Tango(model_path=model_path, vae_config=vae_config, vae_model=vae_model, device=tango_device)
         print ("Successfully loaded checkpoint from:", model_path)
 
 
@@ -176,6 +178,15 @@ class Flow1dVAE1rvq(AudioTokenizer):
         assert n <= self.total_codebooks
         self.n_quantizers = n
 
+    def to(self, device=None, dtype=None, non_blocking=False):
+        self = super(Flow1dVAE1rvq, self).to(device, dtype, non_blocking)
+        self.model = self.model.to(device, dtype, non_blocking)
+        return self
+
+    def cuda(self, device=None):
+        if device is None:
+            device = 'cuda:0'
+        return super(Flow1dVAE1rvq, self).cuda(device)
 
 class Flow1dVAESeparate(AudioTokenizer):
     def __init__(
@@ -183,12 +194,13 @@ class Flow1dVAESeparate(AudioTokenizer):
         model_type: str = "model_2.safetensors",
         vae_config: str = "",
         vae_model: str = "",
+        tango_device: str = "cuda"
     ):
         super().__init__()
 
         from codeclm.tokenizer.Flow1dVAE.generate_septoken import Tango
         model_path = model_type
-        self.model = Tango(model_path=model_path, vae_config=vae_config, vae_model=vae_model, device='cuda')
+        self.model = Tango(model_path=model_path, vae_config=vae_config, vae_model=vae_model, device=tango_device)
         print ("Successfully loaded checkpoint from:", model_path)
 
 
@@ -208,9 +220,9 @@ class Flow1dVAESeparate(AudioTokenizer):
         return codes_vocal, codes_bgm
 
     @torch.no_grad()
-    def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None, chunked=False):
+    def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None, chunked=False, chunk_size=128):
         wav = self.model.code2sound(codes, prompt_vocal=prompt_vocal, prompt_bgm=prompt_bgm, guidance_scale=1.5,
-                                    num_steps=50, disable_progress=False, chunked=chunked) # [B,N,T] -> [B,T]
+                                    num_steps=50, disable_progress=False, chunked=chunked, chunk_size=chunk_size) # [B,N,T] -> [B,T]
         return wav[None]
 
 
@@ -251,3 +263,14 @@ class Flow1dVAESeparate(AudioTokenizer):
         assert n >= 1
         assert n <= self.total_codebooks
         self.n_quantizers = n
+
+    def to(self, device=None, dtype=None, non_blocking=False):
+        self = super(Flow1dVAESeparate, self).to(device, dtype, non_blocking)
+        self.model = self.model.to(device, dtype, non_blocking)
+        return self
+
+    def cuda(self, device=None):
+        if device is None:
+            device = 'cuda:0'
+        self = super(Flow1dVAESeparate, self).cuda(device)
+        return self
codeclm/trainer/codec_song_pl.py CHANGED
@@ -49,9 +49,7 @@ class CodecLM_PL(pl.LightningModule):
         # 3) Load pretrained checkpoint (if any)
         checkpoint = torch.load(ckpt_path, map_location='cpu')
         missing, unexpected = self.load_state_dict(checkpoint, strict=False)
-        print(f'-------------Missing--------------\n{missing}')
-        print(f'-------------Unexpected--------------\n{unexpected}')
-        print("successfully load deepspeed pretrained model {}".format(ckpt_path))
+        print("successfully load pretrained model {}".format(ckpt_path))
         # 4) Build metrics
         self.val_steps = []
         self.train_slide_acc = []
@@ -70,7 +68,6 @@ class CodecLM_PL(pl.LightningModule):
         ) for _ in range(self.audiolm.code_depth)])
 
         self.epoch = 0
-        print("++++++++++++++++ training <song> +++++++++++++++++")
 
     # TODO: move this part to loader
     def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384):
codeclm/utils/offload_profiler.py ADDED
@@ -0,0 +1,505 @@
+import torch
+from torch.func import functional_call
+import queue
+import threading
+from typing import Dict, List, Any
+import omegaconf
+from pydantic import BaseModel, validator
+from typing import Optional
+from functools import wraps
+
+def _callable_once(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        method_called_flag = f"_called_once_{func.__name__}"
+        if getattr(self, method_called_flag, False):
+            raise RuntimeError(f"{func.__name__} can only be called once.")
+        setattr(self, method_called_flag, True)
+        return func(self, *args, **kwargs)
+    return wrapper
+
+class OffloadCleanCacheWrapperParam(BaseModel):
+    module: Any
+    method_name: str
+    diff_mem_gb_thre: float
+
+class OffloadParam(BaseModel):
+    offload_module: Any
+    cpu_mem_gb: float
+    pre_copy_step: Optional[int] = None
+    clean_cache_after_forward: Optional[bool] = None
+    dtype: Optional[str] = None
+    offload_layer_dict: Dict[str, int] = {}
+    ignore_layer_list: List[str] = []
+    clean_cache_wrapper: Optional[OffloadCleanCacheWrapperParam] = None
+    debug: Optional[bool] = None
+
+    @validator('dtype')
+    def parse_dtype(cls, value):
+        if value is None:
+            return None
+        dtype_map = {
+            'torch.float16': torch.float16,
+            'torch.float32': torch.float32,
+            'torch.float64': torch.float64,
+            'torch.int64': torch.int64,
+        }
+        if value not in dtype_map:
+            raise ValueError(f"Unsupported dtype: {value}")
+        return dtype_map[value]
+
+    def init_param_dict(self):
+        param_dict = {}
+        param_dict['cpu_mem_gb'] = self.cpu_mem_gb
+        if self.pre_copy_step is not None:
+            param_dict['pre_copy_step'] = self.pre_copy_step
+        if self.clean_cache_after_forward is not None:
+            param_dict['clean_cache_after_forward'] = self.clean_cache_after_forward
+        if self.debug is not None:
+            param_dict['debug'] = self.debug
+
+        return param_dict
+
+    def offload_layer_param_dict(self):
+        param_dict = {}
+        param_dict['module'] = self.offload_module
+        param_dict['offload_layer_dict'] = self.offload_layer_dict
+        param_dict['ignore_layer_list'] = self.ignore_layer_list
+        param_dict['dtype'] = self.dtype
+
+        return param_dict
+
+    def clean_cache_param_dict(self):
+        param_dict = {}
+        if self.clean_cache_wrapper is not None:
+            param_dict['module'] = self.clean_cache_wrapper.module
+            param_dict['method_name'] = self.clean_cache_wrapper.method_name
+            param_dict['diff_mem_gb_thre'] = self.clean_cache_wrapper.diff_mem_gb_thre
+
+        return param_dict
+
+    @staticmethod
+    def recursive_print(model, indent=0):
+        for field_name, field_info in model.__fields__.items():
+            field_value = getattr(model, field_name)
+            print(" " * indent + f"{field_name}:")
+
+            if issubclass(type(field_value), BaseModel):
+                print(" " * (indent + 2) + f"--- Nested model: {field_value.__class__.__name__}")
+                OffloadParam.recursive_print(field_value, indent + 4)
+            else:
+                print(" " * (indent + 2) + f"class: {field_value.__class__.__name__}")
+                if isinstance(field_value, torch.nn.Module):
+                    pass
+                else:
+                    print(" " * (indent + 2) + f"value: {field_value}")
+
+    def show(self):
+        print("-"*20 + "[OffloadParam]" + "-"*20)
+        OffloadParam.recursive_print(self)
+        print("-"*40)
+
+
+class OffloadParamParse:
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def _get_model(root_model: torch.nn.Module, model_dir: str):
+        assert(model_dir.startswith("self")), f"model_dir {model_dir} must startswith `self`"
+        model = root_model
+        for layer in model_dir.split('.'):
+            if layer == "self":
+                continue
+            assert(hasattr(model, layer)), f"model not has layer [{layer}]!"
+            model = getattr(model, layer)
+        return model
+
+    @staticmethod
+    def parse_config(root_model: torch.nn.Module, cfg: omegaconf.DictConfig)->OffloadParam:
+        assert(hasattr(cfg, "offload_module") and hasattr(cfg, "cpu_mem_gb") and hasattr(cfg, "dtype"))
+
+        offload_module = OffloadParamParse._get_model(root_model, cfg.offload_module)
+        cpu_mem_gb = cfg.cpu_mem_gb
+        dtype = cfg.dtype
+
+        pre_copy_step = cfg.pre_copy_step \
+            if hasattr(cfg, "pre_copy_step") else None
+
+        clean_cache_after_forward = cfg.clean_cache_after_forward \
+            if hasattr(cfg, "clean_cache_after_forward") else None
+
+        offload_layer_dict = {k: v for k, v in cfg.offload_layer_dict.items()} \
+            if hasattr(cfg, "offload_layer_dict") else {}
+
+        ignore_layer_list = cfg.ignore_layer_list \
+            if hasattr(cfg, "ignore_layer_list") else []
+
+        debug = cfg.debug if hasattr(cfg, "debug") else None
+
+        clean_cache_wrapper = None
+        if hasattr(cfg, "clean_cache_wrapper"):
+            clean_cache_cfg = cfg.clean_cache_wrapper
+            cc_module = OffloadParamParse._get_model(root_model, clean_cache_cfg.module)
+            cc_method_name = clean_cache_cfg.method_name
+            diff_mem_gb_thre = clean_cache_cfg.diff_mem_gb_thre
+            clean_cache_wrapper = OffloadCleanCacheWrapperParam(
+                module=cc_module,
+                method_name=cc_method_name,
+                diff_mem_gb_thre=diff_mem_gb_thre)
+
+        return OffloadParam(
+            offload_module=offload_module,
+            cpu_mem_gb=cpu_mem_gb,
+            pre_copy_step=pre_copy_step,
+            clean_cache_after_forward=clean_cache_after_forward,
+            dtype=dtype,
+            offload_layer_dict=offload_layer_dict,
+            ignore_layer_list=ignore_layer_list,
+            clean_cache_wrapper=clean_cache_wrapper,
+            debug=debug
+        )
+
+
+class LayerParamStruct:
+    def __init__(self):
+        self.count = 0
+        self.device_state = None
+
+
+class OffloadProfiler:
+    def __init__(self, device_index=0, cpu_mem_gb=-1, pre_copy_step=1, clean_cache_after_forward=False, debug=False):
+        self.clean_cache_after_forward = clean_cache_after_forward
+        self.cpu_mem_gb = cpu_mem_gb
+        self.cpu_mem_b_count = 0
+        self.device_index = device_index
+        self.execution_order = []
+        self.execution_order_idx = {}
+        self.pin_memory = False
+        test_data = torch.rand(1,1, device='cpu')
+        pin_data = test_data.pin_memory()
+        self.pin_memory = pin_data.is_pinned()
+        print(f"pin:{self.pin_memory}")
+        self.copy_stream = torch.cuda.Stream()
+        self.copy_queue = queue.Queue()
+        self.layer_param:Dict[str, LayerParamStruct] = {}
+        self.model_map = {}
+        self.stop_flag = False
+        self.copy_condition = threading.Condition()
+        self.queue_condition = threading.Condition()
+        self.mem_line_b = 0
+
+        self.copy_thread = threading.Thread(target=self._copy_thread_fun)
+        self.copy_thread.daemon = True
+        self.copy_thread.start()
+
+        self.cur_copy_idx = 0
+        self.execute_over = False
+        self.pre_copy_step = pre_copy_step
+
+        self.tmp_state_list = []
+        self.tmp_state_idx = 0
+        for i in range(pre_copy_step + 2):
+            self.tmp_state_list.append(None)
+
+        self.debug = debug
+
+    def stop(self):
+        self.stop_flag = True
+        with self.queue_condition:
+            self.queue_condition.notify()
+        self.copy_thread.join()
+
+        del self.layer_param
+        del self.model_map
+        del self.copy_stream
+
+    def _copy_thread_fun(self):
+        while self.stop_flag == False:
+            layer_name = "--"
+            with self.queue_condition:
+                while self.copy_queue.qsize() == 0 and self.stop_flag == False:
+                    self.queue_condition.wait()
+                if self.stop_flag == True:
+                    break
+                layer_name = self.copy_queue.get()
+            with torch.cuda.stream(self.copy_stream):
+                if layer_name in self.model_map:
+                    model = self.model_map[layer_name]
+                    self.tmp_state_list[self.tmp_state_idx] = {
+                        k: v.to(torch.device(f"cuda:{self.device_index}"), non_blocking=False)
+                        for k, v in model.state_dict().items()
+                    }
+                    self.copy_stream.synchronize()
+
+                    device_state = self.tmp_state_list[self.tmp_state_idx]
+                    self.tmp_state_idx = (self.tmp_state_idx + 1) % len(self.tmp_state_list)
+
+                    with self.copy_condition:
+                        if layer_name in self.layer_param:
+                            self.layer_param[layer_name].count += 1
+                        else:
+                            self.layer_param[layer_name] = LayerParamStruct()
+                            self.layer_param[layer_name].count = 1
+                        self.layer_param[layer_name].device_state = device_state
+                        self.copy_condition.notify()
+                else:
+                    print(f"get model error! {layer_name}")
+        print("copy thread stop..")
+
+    def _get_new_step_copy_begin_end(self, tag_name):
+
+        pre_copy_step = self.pre_copy_step
+        pre_copy_step = min(pre_copy_step, len(self.execution_order) // 2)
+
+        cur_exe_idx = self.execution_order_idx[tag_name]
+        copy_begin = self.cur_copy_idx
+        copy_end = cur_exe_idx + pre_copy_step + 1
+        if copy_end - copy_begin > len(self.execution_order):
+            copy_end %= len(self.execution_order)
+        if copy_end - copy_begin > pre_copy_step + 1 or copy_end - copy_begin < 0:
+            # jump
+            self.cur_copy_idx = cur_exe_idx
+            copy_begin, copy_end = self._get_new_step_copy_begin_end(tag_name=tag_name)
+        return copy_begin, copy_end
+
+    def make_forward_wrapper(self, module, tag_name, ignore_layer_list=[]):
+        original_forward = module.forward
+        layer_param_size = 0
+        for name, param in module.named_parameters():
+            layer_param_size += param.data.numel() * param.data.element_size() / 1024 / 1024 #MB
+
+        taget_cpu_mem_b = self.cpu_mem_gb * 1024 * 1024 * 1024
+        offload = False
+        for name, param in module.named_parameters():
+            p_name = f"{tag_name}.{name}" if tag_name else name
+            for i_layer in ignore_layer_list:
+                if p_name.startswith(i_layer):
+                    if self.debug:
+                        print(f"ignore layer param: {p_name}")
+                    continue
+
+            if taget_cpu_mem_b >= 0 and self.cpu_mem_b_count >= taget_cpu_mem_b:
+                break
+            cpu_data = torch.empty_strided(size=param.data.size(),
+                                           stride=param.data.stride(),
+                                           dtype=param.data.dtype,
+                                           layout=param.data.layout,
+                                           device='cpu',
+                                           pin_memory=self.pin_memory)
+            cpu_data.copy_(param.data)
+            param.data = cpu_data
+
+            param_size = param.data.numel() * param.data.element_size()
+            self.cpu_mem_b_count += param_size
+            offload = True
+        if self.debug:
+            print(f"layer: {tag_name}, type: {module.__class__.__name__}, size(MB): {layer_param_size}, offload: {offload}, sum_offload_size(MB): {self.cpu_mem_b_count/1024/1024}")
+
+        if offload:
+            copy_condition = self.copy_condition
+            queue_condition = self.queue_condition
+            copy_queue = self.copy_queue
+            layer_param = self.layer_param
+            def forward_wrapper(*args, **kwargs):
+                module.forward = original_forward
+
+                execute_over = False if tag_name not in self.execution_order_idx else True
+                if execute_over == False:
+                    self.model_map[tag_name] = module
+                    self.execution_order.append(tag_name)
+                    self.execution_order_idx[tag_name] = len(self.execution_order) - 1
+                    copy_queue.put(tag_name)
+                    with queue_condition:
+                        queue_condition.notify()
+                else:
+
+                    copy_begin, copy_end = self._get_new_step_copy_begin_end(tag_name=tag_name)
+                    if copy_end > copy_begin:
+                        for idx in range(copy_begin, copy_end):
+                            idx = idx % len(self.execution_order)
+                            copy_tag_name = self.execution_order[idx]
+                            copy_queue.put(copy_tag_name)
+                        with queue_condition:
+                            queue_condition.notify()
+
+                        self.cur_copy_idx = copy_end % len(self.execution_order)
+
+                run_state = None
+                with self.copy_condition:
+                    while tag_name not in self.layer_param:
+                        copy_condition.wait()
+                    run_state = self.layer_param[tag_name].device_state
+                    self.layer_param[tag_name].count -= 1
+
+                module.eval()
+                with torch.no_grad():
+                    output = functional_call(module, run_state, args=args, kwargs=kwargs)
+                with self.copy_condition:
+                    if self.layer_param[tag_name].count == 0:
+                        del self.layer_param[tag_name]
+                diff_mem_b_thre = 1 * (1024 ** 3)
+                if self.clean_cache_after_forward:
+                    reserved = torch.cuda.memory_reserved()
+                    if reserved > self.mem_line_b:
+                        torch.cuda.empty_cache()
+                        cur_reserved = torch.cuda.memory_reserved()
+                        diff_mem = reserved - cur_reserved
+                        if diff_mem > diff_mem_b_thre:
+                            self.mem_line_b = cur_reserved + (reserved - cur_reserved) / 2 + 10
+                        else:
+                            self.mem_line_b = reserved + 10
+                        if self.debug:
+                            print(f"child mem line update, clean cache:{reserved/1024/1024}, cur mem: {cur_reserved/1024/1024} new limit: {self.mem_line_b / 1024 / 1024}, child name: {tag_name}")
+
+                module.forward = forward_wrapper
+                return output
+            module.forward = forward_wrapper
+
+        torch.cuda.empty_cache()
+        return module
+
+    def reset_empty_cache_mem_line(self):
+        self.mem_line_b = 0
+        torch.cuda.empty_cache()
+
+    def clean_cache_wrapper(self, module, method_name='', diff_mem_gb_thre=1):
+        if not hasattr(module, method_name) or not callable(getattr(module, method_name)):
+            print(f"no this method {method_name}")
+            return module
+
+        original_fun = getattr(module, method_name)
+        diff_mem_b_thre = diff_mem_gb_thre * (1024 ** 3)
+        self.reset_empty_cache_mem_line()
+
+        def clean_wrapper(*args, **kwargs):
+            setattr(module, method_name, original_fun)
+            output = original_fun(*args, **kwargs)
+            reserved = torch.cuda.memory_reserved()
+            if reserved > self.mem_line_b:
+                torch.cuda.empty_cache()
+                cur_reserved = torch.cuda.memory_reserved()
+                diff_mem = reserved - cur_reserved
+                if diff_mem > diff_mem_b_thre:
+                    self.mem_line_b = cur_reserved + (reserved - cur_reserved) / 2 + 10
+                else:
+                    self.mem_line_b = reserved + 10
+
+                if self.debug:
+                    print(f"mem line update, clean cache:{reserved/1024/1024}, cur mem: {cur_reserved/1024/1024} new limit: {self.mem_line_b / 1024 / 1024}")
+            setattr(module, method_name, clean_wrapper)
+            return output
+
+        setattr(module, method_name, clean_wrapper)
+        return module
+
+    @_callable_once
+    def offload_layer(self, module, offload_layer_dict={}, ignore_layer_list=[], dtype:torch.dtype = None):
+        return self._offload_layer(
+            module=module,
+            tag="",
+            offload_layer_dict=offload_layer_dict,
+            ignore_layer_list=ignore_layer_list,
+            dtype=dtype
+        )
+
+    def _offload_layer(self, module, tag="", offload_layer_dict={}, ignore_layer_list=[], dtype:torch.dtype = None):
+        """
+        Offload specific layers of a PyTorch model to a specified depth.
+        A model can only be offloaded once.
+
+        Args:
+            module (torch.nn.Module):
+                The PyTorch model containing the layers to offload. This is the model that will be modified in place.
+
+            tag (str, optional):
+                A string identifier for the model.
+                Default is an empty string.
+
+            offload_layer_dict (dict, optional):
+                A dictionary where keys are layer names and values represent the depth at which the offloading should occur.
+                For example,
+                ```offload_layer_dict = {'cfm_wrapper': 5, 'hubert': 4}``` means that the `cfm_wrapper` layer should
+                be offloaded at depth 5, and the `hubert` layer should be offloaded at depth 4.
+                Default is an empty dictionary.
+
+            ignore_layer_list (list, optional):
+                A list of layer names or parameter identifiers to be ignored during the offloading process.
+                Layers in this list will not be offloaded, even if they are present in the `offload_layer_dict`.
+                For example,
+                ```ignore_layer_list = ['cfm_wrapper.estimator.h', 'cfm_wrapper.estimator.adaln_single']```
+                means that layers starting with `cfm_wrapper.estimator.h` or 'cfm_wrapper.estimator.adaln_single' will not be offload.
+                Default is an empty list.
+
+            dtype (torch.dtype, optional):
+                The data type (e.g., `torch.float16`, `torch.float32`) to which the offloaded layers should be converted.
+                If `None`, the data type of the layers will remain unchanged. Default is `None`.
+
+        Returns:
+            None
+        """
+        for p in module._parameters.values():
+            if p is not None:
+                p.data = p.data.to(torch.device(f"cuda:{self.device_index}"))
+                if dtype is not None:
+                    p.data = p.data.to(dtype)
+        for b in module._buffers.values():
+            if b is not None:
+                b.data = b.data.to(torch.device(f"cuda:{self.device_index}"))
+                if dtype is not None:
+                    b.data = b.data.to(dtype)
+        for attr_name, attr in module.__dict__.items():
+            if isinstance(attr, torch.Tensor) and not attr_name.startswith('_'):
+                attr.data = attr.data.to(torch.device(f"cuda:{self.device_index}"))
+                if dtype is not None:
+                    attr.data = attr.data.to(dtype)
+
+        for name, child in module.named_children():
+            current_tag = f"{tag}.{name}" if tag else name
+            child = child.to(torch.device(f"cuda:{self.device_index}"))
+            if dtype is not None:
+                child = child.to(dtype)
+
+            torch.cuda.empty_cache()
+            setattr(module, name, child)
+            pre_name = current_tag.split('.')[0]
+            if pre_name not in offload_layer_dict:
+                param_size = 0
+                for p in child.parameters():
+                    param_size += p.data.numel() * p.data.element_size()
+                param_size = param_size / 1024 / 1024
+                if self.debug:
+                    print(f"not offload layer {current_tag}, size: {param_size}MB")
+                continue
+
+            has_children = any(child.named_children())
+            layer_count = current_tag.count('.') + 1
+
+            layer_deep = offload_layer_dict[pre_name]
+            if layer_count >= layer_deep:
+                has_children = False
+
+            if has_children:
+                self._offload_layer(module=child,
+                                    tag=current_tag,
+                                    offload_layer_dict=offload_layer_dict,
+                                    ignore_layer_list=ignore_layer_list,
+                                    dtype=dtype)
+                continue
+
+            ignore = False
+            for i_layer in ignore_layer_list:
+                if current_tag.startswith(i_layer):
+                    ignore = True
+                    if self.debug:
+                        print(f"ignore layer offload: {current_tag}")
+                    break
+
+            if hasattr(child, "forward") and not ignore:
+                child = self.make_forward_wrapper(
+                    child, current_tag, ignore_layer_list=ignore_layer_list
+                )
+        return module
+
+    def get_execution_order(self):
+        return self.execution_order
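For orientation, a hedged sketch of how this profiler could be driven from a config. The keys mirror OffloadParamParse.parse_config above; the module path and layer name ('self.cfm_wrapper', 'estimator') are illustrative placeholders, not values from this repo's configs:

import omegaconf
from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse

# root_model is assumed to be a torch.nn.Module whose attribute path below exists.
offload_cfg = omegaconf.OmegaConf.create({
    'offload_module': 'self.cfm_wrapper',     # dotted path resolved by _get_model (illustrative)
    'cpu_mem_gb': 16.0,                       # budget of weights to park in pinned host memory
    'dtype': 'torch.float16',
    'pre_copy_step': 2,                       # prefetch this many layers ahead of execution
    'clean_cache_after_forward': True,
    'offload_layer_dict': {'estimator': 5},   # offload children under `estimator` down to depth 5 (illustrative)
    'ignore_layer_list': [],
})
param = OffloadParamParse.parse_config(root_model, offload_cfg)
param.show()
profiler = OffloadProfiler(device_index=0, **param.init_param_dict())
profiler.offload_layer(**param.offload_layer_param_dict())
# ... run inference; wrapped layers are streamed to the GPU on demand ...
profiler.stop()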
download.py CHANGED
@@ -7,7 +7,7 @@ def download_model(local_dir):
     downloaded_path = snapshot_download(
         repo_id=repo_id,
         local_dir=local_dir,
-        revision="0c80d30",
+        revision="647f0a5",
         token=os.environ.get("HF_TOKEN"),
         ignore_patterns=['.git*']
     )
generate.py CHANGED
@@ -1,5 +1,7 @@
 
1
  import sys
2
  import os
 
3
 
4
  import time
5
  import json
@@ -7,11 +9,13 @@ import torch
7
  import torchaudio
8
  import numpy as np
9
  from omegaconf import OmegaConf
10
-
 
11
  from codeclm.trainer.codec_song_pl import CodecLM_PL
12
  from codeclm.models import CodecLM
13
  from third_party.demucs.models.pretrained import get_model_from_yaml
14
 
 
15
  auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
16
 
17
  class Separator:
@@ -34,8 +38,6 @@ class Separator:
34
  a = torchaudio.functional.resample(a, fs, 48000)
35
  if a.shape[-1] >= 48000*10:
36
  a = a[..., :48000*10]
37
- else:
38
- a = torch.cat([a, a], -1)
39
  return a[:, 0:48000*10]
40
 
41
  def run(self, audio_path, output_dir='tmp', ext=".flac"):
@@ -59,38 +61,146 @@ class Separator:
59
  return full_audio, vocal_audio, bgm_audio
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- if __name__ == "__main__":
64
- torch.backends.cudnn.enabled = False
65
- OmegaConf.register_new_resolver("eval", lambda x: eval(x))
66
- OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
67
- OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
68
- OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
69
- np.random.seed(int(time.time()))
70
- ckpt_path = sys.argv[1]
71
- input_jsonl = sys.argv[2]
72
- save_dir = sys.argv[3]
73
- gen_type = sys.argv[4] if len(sys.argv) > 4 else "all"
74
  cfg_path = os.path.join(ckpt_path, 'config.yaml')
75
  ckpt_path = os.path.join(ckpt_path, 'model.pt')
76
  cfg = OmegaConf.load(cfg_path)
 
 
77
  cfg.mode = 'inference'
78
  max_duration = cfg.max_dur
 
79
 
80
- # Define model or load pretrained model
81
- model_light = CodecLM_PL(cfg, ckpt_path)
82
 
83
- model_light = model_light.eval().cuda()
84
- model_light.audiolm.cfg = cfg
85
- model = CodecLM(name = "tmp",
86
- lm = model_light.audiolm,
87
- audiotokenizer = model_light.audio_tokenizer,
88
- max_duration = max_duration,
89
- seperate_tokenizer = model_light.seperate_tokenizer,
90
- )
91
  separator = Separator()
92
  auto_prompt = torch.load('ckpt/prompt.pt')
 
 
93
  merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  cfg_coef = 1.5 #25
95
  temp = 0.9
96
  top_k = 50
@@ -104,21 +214,135 @@ if __name__ == "__main__":
104
  os.makedirs(save_dir + "/audios", exist_ok=True)
105
  os.makedirs(save_dir + "/jsonl", exist_ok=True)
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  with open(input_jsonl, "r") as fp:
108
  lines = fp.readlines()
109
-
 
 
 
 
 
 
 
 
 
 
110
  new_items = []
111
  for line in lines:
112
  item = json.loads(line)
113
  target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
114
- lyric = item["gt_lyric"]
115
- descriptions = item["descriptions"] if "descriptions" in item else None
116
  # get prompt audio
117
  if "prompt_audio_path" in item:
118
  assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
119
  assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
120
- pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
121
- melody_is_wav = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  elif "auto_prompt_audio_type" in item:
123
  assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
124
  if item["auto_prompt_audio_type"] == "Auto":
@@ -134,6 +358,86 @@ if __name__ == "__main__":
134
  vocal_wav = None
135
  bgm_wav = None
136
  melody_is_wav = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  generate_inp = {
139
  'lyrics': [lyric.replace(" ", " ")],
@@ -143,25 +447,119 @@ if __name__ == "__main__":
143
  'bgm_wavs': bgm_wav,
144
  'melody_is_wav': melody_is_wav,
145
  }
146
- start_time = time.time()
147
  with torch.autocast(device_type="cuda", dtype=torch.float16):
148
- tokens = model.generate(**generate_inp, return_tokens=True)
149
- mid_time = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
 
151
  with torch.no_grad():
152
- if melody_is_wav:
153
- wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav, gen_type=gen_type)
 
 
 
 
 
 
 
 
 
 
154
  else:
155
- wav_seperate = model.generate_audio(tokens, gen_type=gen_type)
156
- end_time = time.time()
157
- torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
158
- print(f"process{item['idx']} {gen_type}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}")
159
-
160
- item["idx"] = f"{item['idx']}"
161
- item["wav_path"] = target_wav_name
162
- new_items.append(item)
 
 
 
 
 
 
 
 
 
 
 
163
 
 
 
 
164
  src_jsonl_name = os.path.split(input_jsonl)[-1]
165
  with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
166
  for item in new_items:
167
  fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from hmac import new
2
  import sys
3
  import os
4
+ import argparse
5
 
6
  import time
7
  import json
 
9
  import torchaudio
10
  import numpy as np
11
  from omegaconf import OmegaConf
12
+ from codeclm.models import builders
13
+ import gc
14
  from codeclm.trainer.codec_song_pl import CodecLM_PL
15
  from codeclm.models import CodecLM
16
  from third_party.demucs.models.pretrained import get_model_from_yaml
17
 
18
+
19
  auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
20
 
21
  class Separator:
 
38
  a = torchaudio.functional.resample(a, fs, 48000)
39
  if a.shape[-1] >= 48000*10:
40
  a = a[..., :48000*10]
 
 
41
  return a[:, 0:48000*10]
42
 
43
  def run(self, audio_path, output_dir='tmp', ext=".flac"):
 
61
  return full_audio, vocal_audio, bgm_audio
62
 
63
 
64
+ def parse_args():
65
+ parser = argparse.ArgumentParser(description='Song Generation Script')
66
+
67
+ # 必需参数
68
+ parser.add_argument('--ckpt_path', type=str, required=True,
69
+ help='Path to the checkpoint directory containing config.yaml and model.pt')
70
+ parser.add_argument('--input_jsonl', type=str, required=True,
71
+ help='Path to input JSONL file containing generation tasks')
72
+ parser.add_argument('--save_dir', type=str, required=True,
73
+ help='Directory to save generated audio files and results')
74
+ # 可选参数
75
+ parser.add_argument('--generate_type', type=str, default='mixed',
76
+ help='Type of generation: "vocal" or "bgm" or "separate" or "mixed" (default: "mixed")')
77
+ parser.add_argument('--use_flash_attn', action='store_true',
78
+ help='Whether to use flash attention (default: False)')
79
+ parser.add_argument('--low_mem', action='store_true',
80
+ help='Whether to use low memory mode (default: False)')
81
+ return parser.parse_args()
82
 
83
+ def generate(args):
84
+ ckpt_path = args.ckpt_path
85
+ input_jsonl = args.input_jsonl
86
+ save_dir = args.save_dir
 
 
 
 
 
 
 
87
  cfg_path = os.path.join(ckpt_path, 'config.yaml')
88
  ckpt_path = os.path.join(ckpt_path, 'model.pt')
89
  cfg = OmegaConf.load(cfg_path)
90
+ cfg.lm.use_flash_attn_2 = args.use_flash_attn
91
+ print(f"use_flash_attn: {args.use_flash_attn}")
92
  cfg.mode = 'inference'
93
  max_duration = cfg.max_dur
94
+ gen_type = args.generate_type
95
 
 
 
96
 
 
 
 
 
 
 
 
 
97
  separator = Separator()
98
  auto_prompt = torch.load('ckpt/prompt.pt')
99
+ audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
100
+ audio_tokenizer = audio_tokenizer.eval().cuda()
101
  merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
102
+ with open(input_jsonl, "r") as fp:
103
+ lines = fp.readlines()
104
+
105
+
106
+ new_items = []
107
+ for line in lines:
108
+ item = json.loads(line)
109
+ target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
110
+ # get prompt audio
111
+ if "prompt_audio_path" in item:
112
+ assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
113
+ assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
114
+ with torch.no_grad():
115
+ pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
116
+ item['raw_pmt_wav'] = pmt_wav
117
+ item['raw_vocal_wav'] = vocal_wav
118
+ item['raw_bgm_wav'] = bgm_wav
119
+ if pmt_wav.dim() == 2:
120
+ pmt_wav = pmt_wav[None]
121
+ if pmt_wav.dim() != 3:
122
+ raise ValueError("Melody wavs should have a shape [B, C, T].")
123
+ pmt_wav = list(pmt_wav)
124
+ if vocal_wav.dim() == 2:
125
+ vocal_wav = vocal_wav[None]
126
+ if vocal_wav.dim() != 3:
127
+ raise ValueError("Vocal wavs should have a shape [B, C, T].")
128
+ vocal_wav = list(vocal_wav)
129
+ if bgm_wav.dim() == 2:
130
+ bgm_wav = bgm_wav[None]
131
+ if bgm_wav.dim() != 3:
132
+ raise ValueError("BGM wavs should have a shape [B, C, T].")
133
+ bgm_wav = list(bgm_wav)
134
+ if type(pmt_wav) == list:
135
+ pmt_wav = torch.stack(pmt_wav, dim=0)
136
+ if type(vocal_wav) == list:
137
+ vocal_wav = torch.stack(vocal_wav, dim=0)
138
+ if type(bgm_wav) == list:
139
+ bgm_wav = torch.stack(bgm_wav, dim=0)
140
+ pmt_wav = pmt_wav
141
+ vocal_wav = vocal_wav
142
+ bgm_wav = bgm_wav
143
+ with torch.no_grad():
144
+ pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
145
+ melody_is_wav = False
146
+ elif "auto_prompt_audio_type" in item:
147
+ assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
148
+ if item["auto_prompt_audio_type"] == "Auto":
149
+ prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
150
+ else:
151
+ prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
152
+ pmt_wav = prompt_token[:,[0],:]
153
+ vocal_wav = prompt_token[:,[1],:]
154
+ bgm_wav = prompt_token[:,[2],:]
155
+ melody_is_wav = False
156
+ else:
157
+ pmt_wav = None
158
+ vocal_wav = None
159
+ bgm_wav = None
160
+ melody_is_wav = True
161
+ item['pmt_wav'] = pmt_wav
162
+ item['vocal_wav'] = vocal_wav
163
+ item['bgm_wav'] = bgm_wav
164
+ item['melody_is_wav'] = melody_is_wav
165
+ item["idx"] = f"{item['idx']}"
166
+ item["wav_path"] = target_wav_name
167
+ new_items.append(item)
168
+
169
+ del audio_tokenizer
170
+ del separator
171
+
172
+ torch.cuda.empty_cache()
173
+
174
+ if "audio_tokenizer_checkpoint_sep" in cfg.keys():
175
+ seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
176
+ else:
177
+ seperate_tokenizer = None
178
+
179
+ if seperate_tokenizer is not None:
180
+ seperate_tokenizer = seperate_tokenizer.eval().cuda()
181
+
182
+ for item in new_items:
183
+ if "prompt_audio_path" in item:
184
+ with torch.no_grad():
185
+ vocal_wav, bgm_wav = seperate_tokenizer.encode(item['vocal_wav'].cuda(), item['bgm_wav'].cuda())
186
+ item['vocal_wav'] = vocal_wav
187
+ item['bgm_wav'] = bgm_wav
188
+
189
+ torch.cuda.empty_cache()
190
+ audiolm = builders.get_lm_model(cfg)
191
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
192
+ audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
193
+ audiolm.load_state_dict(audiolm_state_dict, strict=False)
194
+ audiolm = audiolm.eval()
195
+ audiolm = audiolm.cuda().to(torch.float16)
196
+
197
+ model = CodecLM(name = "tmp",
198
+ lm = audiolm,
199
+ audiotokenizer = None,
200
+ max_duration = max_duration,
201
+ seperate_tokenizer = seperate_tokenizer,
202
+ )
203
+
204
  cfg_coef = 1.5 #25
205
  temp = 0.9
206
  top_k = 50
 
214
  os.makedirs(save_dir + "/audios", exist_ok=True)
215
  os.makedirs(save_dir + "/jsonl", exist_ok=True)
216
 
217
+ for item in new_items:
218
+ lyric = item["gt_lyric"]
219
+ descriptions = item["descriptions"] if "descriptions" in item else None
220
+ pmt_wav = item['pmt_wav']
221
+ vocal_wav = item['vocal_wav']
222
+ bgm_wav = item['bgm_wav']
223
+ melody_is_wav = item['melody_is_wav']
224
+ target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
225
+
226
+
227
+ generate_inp = {
228
+ 'lyrics': [lyric.replace("  ", " ")],
229
+ 'descriptions': [descriptions],
230
+ 'melody_wavs': pmt_wav,
231
+ 'vocal_wavs': vocal_wav,
232
+ 'bgm_wavs': bgm_wav,
233
+ 'melody_is_wav': melody_is_wav,
234
+ }
235
+ start_time = time.time()
236
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
237
+ with torch.no_grad():
238
+ tokens = model.generate(**generate_inp, return_tokens=True)
239
+ mid_time = time.time()
240
+
241
+ with torch.no_grad():
242
+ if 'raw_pmt_wav' in item:
243
+ if gen_type == 'separate':
244
+ wav_seperate = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='mixed')
245
+ wav_vocal = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='vocal')
246
+ wav_bgm = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='bgm')
247
+ elif gen_type == 'mixed':
248
+ wav_seperate = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type=gen_type)
249
+ else:
250
+ wav_seperate = model.generate_audio(tokens,chunked=True, gen_type=gen_type)
251
+ del item['raw_pmt_wav']
252
+ del item['raw_vocal_wav']
253
+ del item['raw_bgm_wav']
254
+ else:
255
+ if gen_type == 'separate':
256
+ wav_vocal = model.generate_audio(tokens, chunked=True, gen_type='vocal')
257
+ wav_bgm = model.generate_audio(tokens, chunked=True, gen_type='bgm')
258
+ wav_seperate = model.generate_audio(tokens, chunked=True, gen_type='mixed')
259
+ else:
260
+ wav_seperate = model.generate_audio(tokens, chunked=True, gen_type=gen_type)
261
+ del item['pmt_wav']
262
+ del item['vocal_wav']
263
+ del item['bgm_wav']
264
+ del item['melody_is_wav']
265
+ end_time = time.time()
266
+ if gen_type == 'separate':
267
+ torchaudio.save(target_wav_name.replace('.flac', '_vocal.flac'), wav_vocal[0].cpu().float(), cfg.sample_rate)
268
+ torchaudio.save(target_wav_name.replace('.flac', '_bgm.flac'), wav_bgm[0].cpu().float(), cfg.sample_rate)
269
+ torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
270
+ else:
271
+ torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
272
+
273
+ print(f"process{item['idx']}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}")
274
+ item["idx"] = f"{item['idx']}"
275
+ item["wav_path"] = target_wav_name
276
+
277
+ src_jsonl_name = os.path.split(input_jsonl)[-1]
278
+ with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
279
+ for item in new_items:
280
+ fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
281
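For reference, the fields consumed above imply that each line of the input JSONL is a JSON object along these lines (a hypothetical sketch; the idx values, lyric text and paths are illustrative, descriptions is optional, and prompt_audio_path and auto_prompt_audio_type are mutually exclusive):

{"idx": "demo_001", "gt_lyric": "[verse] ... ; [chorus] ...", "descriptions": "uplifting pop, female vocal", "auto_prompt_audio_type": "Pop"}
{"idx": "demo_002", "gt_lyric": "[verse] ... ; [chorus] ...", "prompt_audio_path": "sample/prompt.wav"}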
+
282
+ def generate_lowmem(args):
283
+ ckpt_path = args.ckpt_path
284
+ input_jsonl = args.input_jsonl
285
+ save_dir = args.save_dir
286
+ cfg_path = os.path.join(ckpt_path, 'config.yaml')
287
+ ckpt_path = os.path.join(ckpt_path, 'model.pt')
288
+ cfg = OmegaConf.load(cfg_path)
289
+ cfg.lm.use_flash_attn_2 = args.use_flash_attn
290
+ print(f"use_flash_attn: {args.use_flash_attn}")
291
+ cfg.mode = 'inference'
292
+ max_duration = cfg.max_dur
293
+ gen_type = args.generate_type
294
+ chunk_size = 128
295
+ use_audio_tokenizer = False
296
  with open(input_jsonl, "r") as fp:
297
  lines = fp.readlines()
298
+ for line in lines:
299
+ item = json.loads(line)
300
+ if "prompt_audio_path" in item:
301
+ use_audio_tokenizer = True
302
+ break
303
+ if use_audio_tokenizer:
304
+ separator = Separator()
305
+ audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
306
+ audio_tokenizer = audio_tokenizer.eval().cuda()
307
+ auto_prompt = torch.load('ckpt/prompt.pt')
308
+ merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
309
  new_items = []
310
  for line in lines:
311
  item = json.loads(line)
312
  target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
 
 
313
  # get prompt audio
314
  if "prompt_audio_path" in item:
315
  assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
316
  assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
317
+ with torch.no_grad():
318
+ pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
319
+ item['raw_pmt_wav'] = pmt_wav
320
+ item['raw_vocal_wav'] = vocal_wav
321
+ item['raw_bgm_wav'] = bgm_wav
322
+ if pmt_wav.dim() == 2:
323
+ pmt_wav = pmt_wav[None]
324
+ if pmt_wav.dim() != 3:
325
+ raise ValueError("Melody wavs should have a shape [B, C, T].")
326
+ pmt_wav = list(pmt_wav)
327
+ if vocal_wav.dim() == 2:
328
+ vocal_wav = vocal_wav[None]
329
+ if vocal_wav.dim() != 3:
330
+ raise ValueError("Vocal wavs should have a shape [B, C, T].")
331
+ vocal_wav = list(vocal_wav)
332
+ if bgm_wav.dim() == 2:
333
+ bgm_wav = bgm_wav[None]
334
+ if bgm_wav.dim() != 3:
335
+ raise ValueError("BGM wavs should have a shape [B, C, T].")
336
+ bgm_wav = list(bgm_wav)
337
+ if type(pmt_wav) == list:
338
+ pmt_wav = torch.stack(pmt_wav, dim=0)
339
+ if type(vocal_wav) == list:
340
+ vocal_wav = torch.stack(vocal_wav, dim=0)
341
+ if type(bgm_wav) == list:
342
+ bgm_wav = torch.stack(bgm_wav, dim=0)
343
+ with torch.no_grad():
344
+ pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
345
+ melody_is_wav = False
346
  elif "auto_prompt_audio_type" in item:
347
  assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
348
  if item["auto_prompt_audio_type"] == "Auto":
 
358
  vocal_wav = None
359
  bgm_wav = None
360
  melody_is_wav = True
361
+ item['pmt_wav'] = pmt_wav
362
+ item['vocal_wav'] = vocal_wav
363
+ item['bgm_wav'] = bgm_wav
364
+ item['melody_is_wav'] = melody_is_wav
365
+ item["idx"] = f"{item['idx']}"
366
+ item["wav_path"] = target_wav_name
367
+ new_items.append(item)
368
+
369
+ if use_audio_tokenizer:
370
+ del audio_tokenizer
371
+ del separator
372
+
373
+ torch.cuda.empty_cache()
374
+
375
+ if "audio_tokenizer_checkpoint_sep" in cfg.keys() and use_audio_tokenizer:
376
+ seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
377
+ else:
378
+ seperate_tokenizer = None
379
+
380
+ if seperate_tokenizer is not None:
381
+ seperate_tokenizer = seperate_tokenizer.eval().cuda()
382
+
383
+ for item in new_items:
384
+ if "prompt_audio_path" in item:
385
+ with torch.no_grad():
386
+ vocal_wav, bgm_wav = seperate_tokenizer.encode(item['vocal_wav'].cuda(), item['bgm_wav'].cuda())
387
+ item['vocal_wav'] = vocal_wav
388
+ item['bgm_wav'] = bgm_wav
389
+
390
+ if use_audio_tokenizer:
391
+ del seperate_tokenizer
392
+
393
+ torch.cuda.empty_cache()
394
+
395
+ # Define model or load pretrained model
396
+ audiolm = builders.get_lm_model(cfg)
397
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
398
+ audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
399
+ audiolm.load_state_dict(audiolm_state_dict, strict=False)
400
+ audiolm = audiolm.eval()
401
+
402
+ offload_audiolm = True if 'offload' in cfg.keys() and 'audiolm' in cfg.offload else False
403
+ if offload_audiolm:
404
+ audiolm_offload_param = OffloadParamParse.parse_config(audiolm, cfg.offload.audiolm)
405
+ audiolm_offload_param.show()
406
+ offload_profiler = OffloadProfiler(device_index=0, **(audiolm_offload_param.init_param_dict()))
407
+ offload_profiler.offload_layer(**(audiolm_offload_param.offload_layer_param_dict()))
408
+ offload_profiler.clean_cache_wrapper(**(audiolm_offload_param.clean_cache_param_dict()))
409
+ else:
410
+ audiolm = audiolm.cuda().to(torch.float16)
411
+
412
+ model = CodecLM(name = "tmp",
413
+ lm = audiolm,
414
+ audiotokenizer = None,
415
+ max_duration = max_duration,
416
+ seperate_tokenizer = None,
417
+ )
418
+
419
+ cfg_coef = 1.5 #25
420
+ temp = 0.9
421
+ top_k = 50
422
+ top_p = 0.0
423
+ record_tokens = True
424
+ record_window = 50
425
+
426
+
427
+ model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
428
+ top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
429
+ os.makedirs(save_dir, exist_ok=True)
430
+ os.makedirs(save_dir + "/audios", exist_ok=True)
431
+ os.makedirs(save_dir + "/jsonl", exist_ok=True)
432
+
433
+
434
+ for item in new_items:
435
+ lyric = item["gt_lyric"]
436
+ descriptions = item["descriptions"] if "descriptions" in item else None
437
+ pmt_wav = item['pmt_wav']
438
+ vocal_wav = item['vocal_wav']
439
+ bgm_wav = item['bgm_wav']
440
+ melody_is_wav = item['melody_is_wav']
441
 
442
  generate_inp = {
443
  'lyrics': [lyric.replace("  ", " ")],
 
447
  'bgm_wavs': bgm_wav,
448
  'melody_is_wav': melody_is_wav,
449
  }
 
450
  with torch.autocast(device_type="cuda", dtype=torch.float16):
451
+ with torch.no_grad():
452
+ tokens = model.generate(**generate_inp, return_tokens=True)
453
+ if offload_audiolm:
454
+ offload_profiler.reset_empty_cache_mem_line()
455
+ item['tokens'] = tokens
456
+ if offload_audiolm:
457
+ offload_profiler.stop()
458
+ del offload_profiler
459
+ del audiolm_offload_param
460
+ del model
461
+ audiolm = audiolm.cpu()
462
+ del audiolm
463
+ del checkpoint
464
+ gc.collect()
465
+ torch.cuda.empty_cache()
466
+
467
+ seperate_tokenizer = builders.get_audio_tokenizer_model_cpu(cfg.audio_tokenizer_checkpoint_sep, cfg)
468
+ device = "cuda:0"
469
+ seperate_tokenizer.model.device = device
470
+ seperate_tokenizer.model.vae = seperate_tokenizer.model.vae.to(device)
471
+ seperate_tokenizer.model.model.device = torch.device(device)
472
+ seperate_tokenizer = seperate_tokenizer.eval()
473
+
474
+ offload_wav_tokenizer_diffusion = True if 'offload' in cfg.keys() and 'wav_tokenizer_diffusion' in cfg.offload else False
475
+ if offload_wav_tokenizer_diffusion:
476
+ sep_offload_param = OffloadParamParse.parse_config(seperate_tokenizer, cfg.offload.wav_tokenizer_diffusion)
477
+ sep_offload_param.show()
478
+ sep_offload_profiler = OffloadProfiler(device_index=0, **(sep_offload_param.init_param_dict()))
479
+ sep_offload_profiler.offload_layer(**(sep_offload_param.offload_layer_param_dict()))
480
+ sep_offload_profiler.clean_cache_wrapper(**(sep_offload_param.clean_cache_param_dict()))
481
+ else:
482
+ seperate_tokenizer.model.model = seperate_tokenizer.model.model.to(device)
483
+
484
+ model = CodecLM(name = "tmp",
485
+ lm = None,
486
+ audiotokenizer = None,
487
+ max_duration = max_duration,
488
+ seperate_tokenizer = seperate_tokenizer,
489
+ )
490
 
491
+ for item in new_items:
492
  with torch.no_grad():
493
+ if 'raw_pmt_wav' in item:
494
+ if gen_type == 'separate':
495
+ wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type='mixed')
496
+ wav_vocal = model.generate_audio(item['tokens'],chunked=True, gen_type='vocal')
497
+ wav_bgm = model.generate_audio(item['tokens'], chunked=True, gen_type='bgm')
498
+ elif gen_type == 'mixed':
499
+ wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type=gen_type)
500
+ else:
501
+ wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type=gen_type)
502
+ del item['raw_pmt_wav']
503
+ del item['raw_vocal_wav']
504
+ del item['raw_bgm_wav']
505
  else:
506
+ if gen_type == 'separate':
507
+ wav_vocal = model.generate_audio(item['tokens'], chunked=True, gen_type='vocal')
508
+ wav_bgm = model.generate_audio(item['tokens'], chunked=True, gen_type='bgm')
509
+ wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type='mixed')
510
+ else:
511
+ wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type=gen_type)
512
+ if gen_type == 'separate':
513
+ torchaudio.save(item['wav_path'].replace('.flac', '_vocal.flac'), wav_vocal[0].cpu().float(), cfg.sample_rate)
514
+ torchaudio.save(item['wav_path'].replace('.flac', '_bgm.flac'), wav_bgm[0].cpu().float(), cfg.sample_rate)
515
+ torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
516
+ else:
517
+ torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
518
+ del item['tokens']
519
+ del item['pmt_wav']
520
+ del item['vocal_wav']
521
+ del item['bgm_wav']
522
+ del item['melody_is_wav']
523
+ if offload_wav_tokenizer_diffusion:
524
+ sep_offload_profiler.reset_empty_cache_mem_line()
525
 
526
+ if offload_wav_tokenizer_diffusion:
527
+ sep_offload_profiler.stop()
528
+ torch.cuda.empty_cache()
529
  src_jsonl_name = os.path.split(input_jsonl)[-1]
530
  with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
531
  for item in new_items:
532
  fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
533
+
534
+
535
+ if __name__ == "__main__":
536
+ torch.backends.cudnn.enabled = False
537
+ OmegaConf.register_new_resolver("eval", lambda x: eval(x))
538
+ OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
539
+ OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
540
+ OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
541
+ np.random.seed(int(time.time()))
542
+ # Parse command-line arguments
543
+ args = parse_args()
544
+ if torch.cuda.is_available():
545
+ device = torch.cuda.current_device()
546
+ reserved = torch.cuda.memory_reserved(device)
547
+ total = torch.cuda.get_device_properties(device).total_memory
548
+ res_mem = (total - reserved) / 1024 / 1024 / 1024
549
+ print(f"reserved memory: {res_mem}GB")
550
+
551
+ model_name = args.ckpt_path.split("/")[-1]
552
+ assert model_name in ['songgeneration_base'], f'{model_name} is not supported, currently only songgeneration_base is supported'
553
+ if model_name == 'songgeneration_base':
554
+ if res_mem > 24 and not args.low_mem:
555
+ print("use generate")
556
+ generate(args)
557
+ else:
558
+ from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
559
+ print("use generate_lowmem")
560
+ generate_lowmem(args)
561
+
562
+ else:
563
+ print("CUDA is not available")
564
+ exit()
565
+
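Both generate() and generate_lowmem() write the same output layout under save_dir; the stem files are only produced when the generate type is 'separate'. Illustratively (the idx values are hypothetical):

save_dir/
  audios/<idx>.flac              mixed song (always written)
  audios/<idx>_vocal.flac        vocal stem, only for generate_type 'separate'
  audios/<idx>_bgm.flac          bgm stem, only for generate_type 'separate'
  jsonl/<input basename>.jsonl   input records re-written with wav_path filled in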
generate.sh CHANGED
@@ -7,5 +7,66 @@ export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer
7
  CKPT_PATH=$1
8
  JSONL=$2
9
  SAVE_DIR=$3
10
- GEN_TYEP=$4
11
- python3 generate.py $CKPT_PATH $JSONL $SAVE_DIR $GEN_TYEP
 
7
  CKPT_PATH=$1
8
  JSONL=$2
9
  SAVE_DIR=$3
10
+ USE_FLASH_ATTN="True"
11
+ LOW_MEM="False"
12
+ GENERATE_TYPE="mixed"
13
+ for arg in "$@"; do
14
+ if [[ $arg == "--not_use_flash_attn" ]]; then
15
+ USE_FLASH_ATTN="False"
16
+ fi
17
+ done
18
+ for arg in "$@"; do
19
+ if [[ $arg == "--low_mem" ]]; then
20
+ LOW_MEM="True"
21
+ fi
22
+ done
23
+ for arg in "$@"; do
24
+ if [[ $arg == "--separate" ]]; then
25
+ GENERATE_TYPE="separate"
26
+ fi
27
+ done
28
+ for arg in "$@"; do
29
+ if [[ $arg == "--bgm" ]]; then
30
+ GENERATE_TYPE="bgm"
31
+ fi
32
+ done
33
+ for arg in "$@"; do
34
+ if [[ $arg == "--vocal" ]]; then
35
+ GENERATE_TYPE="vocal"
36
+ fi
37
+ done
38
+
39
+
40
+ if [ "$USE_FLASH_ATTN" == "True" ] && [ "$LOW_MEM" == "True" ]; then
41
+ echo "Use Flash Attention + Low Memory Mode"
42
+ python3 generate.py \
43
+ --ckpt_path $CKPT_PATH \
44
+ --input_jsonl $JSONL \
45
+ --save_dir $SAVE_DIR \
46
+ --generate_type $GENERATE_TYPE \
47
+ --use_flash_attn \
48
+ --low_mem
49
+ elif [ "$USE_FLASH_ATTN" == "True" ] && [ "$LOW_MEM" == "False" ]; then
50
+ echo "Use Flash Attention + Auto Memory Mode"
51
+ python3 generate.py \
52
+ --ckpt_path $CKPT_PATH \
53
+ --input_jsonl $JSONL \
54
+ --save_dir $SAVE_DIR \
55
+ --generate_type $GENERATE_TYPE \
56
+ --use_flash_attn
57
+ elif [ "$USE_FLASH_ATTN" == "False" ] && [ "$LOW_MEM" == "False" ]; then
58
+ echo "Not Use Flash Attention + Auto Memory Mode"
59
+ python3 generate.py \
60
+ --ckpt_path $CKPT_PATH \
61
+ --input_jsonl $JSONL \
62
+ --generate_type $GENERATE_TYPE \
63
+ --save_dir $SAVE_DIR
64
+ elif [ "$USE_FLASH_ATTN" == "False" ] && [ "$LOW_MEM" == "True" ]; then
65
+ echo "Not Use Flash Attention + Low Memory Mode"
66
+ python3 generate.py \
67
+ --ckpt_path $CKPT_PATH \
68
+ --input_jsonl $JSONL \
69
+ --save_dir $SAVE_DIR \
70
+ --generate_type $GENERATE_TYPE \
71
+ --low_mem
72
+ fi
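A usage sketch for the updated launcher above (the checkpoint, JSONL and output paths are illustrative; the optional flags can be combined in any order):

bash generate.sh ckpt/songgeneration_base sample/test.jsonl output
bash generate.sh ckpt/songgeneration_base sample/test.jsonl output --low_mem
bash generate.sh ckpt/songgeneration_base sample/test.jsonl output --separate
bash generate.sh ckpt/songgeneration_base sample/test.jsonl output --not_use_flash_attn --low_mem

Without --low_mem, the Python entry point still falls back to generate_lowmem() automatically when less than roughly 24 GB of GPU memory is free.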
generate_lowmem.py DELETED
@@ -1,241 +0,0 @@
1
- import sys
2
- import os
3
-
4
- import time
5
- import json
6
- import torch
7
- import torchaudio
8
- import numpy as np
9
- from omegaconf import OmegaConf
10
- from codeclm.models import builders
11
-
12
- from codeclm.trainer.codec_song_pl import CodecLM_PL
13
- from codeclm.models import CodecLM
14
- from third_party.demucs.models.pretrained import get_model_from_yaml
15
-
16
- auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
17
-
18
- class Separator:
19
- def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
20
- if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
21
- self.device = torch.device(f"cuda:{gpu_id}")
22
- else:
23
- self.device = torch.device("cpu")
24
- self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
25
-
26
- def init_demucs_model(self, model_path, config_path):
27
- model = get_model_from_yaml(config_path, model_path)
28
- model.to(self.device)
29
- model.eval()
30
- return model
31
-
32
- def load_audio(self, f):
33
- a, fs = torchaudio.load(f)
34
- if (fs != 48000):
35
- a = torchaudio.functional.resample(a, fs, 48000)
36
- if a.shape[-1] >= 48000*10:
37
- a = a[..., :48000*10]
38
- else:
39
- a = torch.cat([a, a], -1)
40
- return a[:, 0:48000*10]
41
-
42
- def run(self, audio_path, output_dir='tmp', ext=".flac"):
43
- os.makedirs(output_dir, exist_ok=True)
44
- name, _ = os.path.splitext(os.path.split(audio_path)[-1])
45
- output_paths = []
46
-
47
- for stem in self.demucs_model.sources:
48
- output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
49
- if os.path.exists(output_path):
50
- output_paths.append(output_path)
51
- if len(output_paths) == 1: # 4
52
- vocal_path = output_paths[0]
53
- else:
54
- drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
55
- for path in [drums_path, bass_path, other_path]:
56
- os.remove(path)
57
- full_audio = self.load_audio(audio_path)
58
- vocal_audio = self.load_audio(vocal_path)
59
- bgm_audio = full_audio - vocal_audio
60
- return full_audio, vocal_audio, bgm_audio
61
-
62
-
63
-
64
- if __name__ == "__main__":
65
- torch.backends.cudnn.enabled = False
66
- OmegaConf.register_new_resolver("eval", lambda x: eval(x))
67
- OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
68
- OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
69
- OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
70
- np.random.seed(int(time.time()))
71
- ckpt_path = sys.argv[1]
72
- input_jsonl = sys.argv[2]
73
- save_dir = sys.argv[3]
74
- gen_type = sys.argv[4] if len(sys.argv) > 4 else "all"
75
- cfg_path = os.path.join(ckpt_path, 'config.yaml')
76
- ckpt_path = os.path.join(ckpt_path, 'model.pt')
77
- cfg = OmegaConf.load(cfg_path)
78
- cfg.mode = 'inference'
79
- max_duration = cfg.max_dur
80
-
81
- separator = Separator()
82
- auto_prompt = torch.load('ckpt/prompt.pt')
83
- audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
84
- if "audio_tokenizer_checkpoint_sep" in cfg.keys():
85
- seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
86
- else:
87
- seperate_tokenizer = None
88
- audio_tokenizer = audio_tokenizer.eval().cuda()
89
- if seperate_tokenizer is not None:
90
- seperate_tokenizer = seperate_tokenizer.eval().cuda()
91
-
92
- merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
93
- with open(input_jsonl, "r") as fp:
94
- lines = fp.readlines()
95
- new_items = []
96
- for line in lines:
97
- item = json.loads(line)
98
- target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
99
- # get prompt audio
100
- if "prompt_audio_path" in item:
101
- assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
102
- assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
103
- pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
104
- item['raw_pmt_wav'] = pmt_wav
105
- item['raw_vocal_wav'] = vocal_wav
106
- item['raw_bgm_wav'] = bgm_wav
107
- if pmt_wav.dim() == 2:
108
- pmt_wav = pmt_wav[None]
109
- if pmt_wav.dim() != 3:
110
- raise ValueError("Melody wavs should have a shape [B, C, T].")
111
- pmt_wav = list(pmt_wav)
112
- if vocal_wav.dim() == 2:
113
- vocal_wav = vocal_wav[None]
114
- if vocal_wav.dim() != 3:
115
- raise ValueError("Vocal wavs should have a shape [B, C, T].")
116
- vocal_wav = list(vocal_wav)
117
- if bgm_wav.dim() == 2:
118
- bgm_wav = bgm_wav[None]
119
- if bgm_wav.dim() != 3:
120
- raise ValueError("BGM wavs should have a shape [B, C, T].")
121
- bgm_wav = list(bgm_wav)
122
- if type(pmt_wav) == list:
123
- pmt_wav = torch.stack(pmt_wav, dim=0)
124
- if type(vocal_wav) == list:
125
- vocal_wav = torch.stack(vocal_wav, dim=0)
126
- if type(bgm_wav) == list:
127
- bgm_wav = torch.stack(bgm_wav, dim=0)
128
- pmt_wav = pmt_wav.cuda()
129
- vocal_wav = vocal_wav.cuda()
130
- bgm_wav = bgm_wav.cuda()
131
- pmt_wav, _ = audio_tokenizer.encode(pmt_wav)
132
- vocal_wav, bgm_wav = seperate_tokenizer.encode(vocal_wav, bgm_wav)
133
- melody_is_wav = False
134
- elif "auto_prompt_audio_type" in item:
135
- assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
136
- if item["auto_prompt_audio_type"] == "Auto":
137
- prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
138
- else:
139
- prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
140
- pmt_wav = prompt_token[:,[0],:]
141
- vocal_wav = prompt_token[:,[1],:]
142
- bgm_wav = prompt_token[:,[2],:]
143
- melody_is_wav = False
144
- else:
145
- pmt_wav = None
146
- vocal_wav = None
147
- bgm_wav = None
148
- melody_is_wav = True
149
- item['pmt_wav'] = pmt_wav
150
- item['vocal_wav'] = vocal_wav
151
- item['bgm_wav'] = bgm_wav
152
- item['melody_is_wav'] = melody_is_wav
153
- item["idx"] = f"{item['idx']}"
154
- item["wav_path"] = target_wav_name
155
- new_items.append(item)
156
-
157
- del audio_tokenizer
158
- del seperate_tokenizer
159
- del separator
160
-
161
- # Define model or load pretrained model
162
- model_light = CodecLM_PL(cfg, ckpt_path)
163
- model_light = model_light.eval()
164
- model_light.audiolm.cfg = cfg
165
- model = CodecLM(name = "tmp",
166
- lm = model_light.audiolm,
167
- audiotokenizer = None,
168
- max_duration = max_duration,
169
- seperate_tokenizer = None,
170
- )
171
- del model_light
172
- model.lm = model.lm.cuda().to(torch.float16)
173
-
174
- cfg_coef = 1.5 #25
175
- temp = 0.9
176
- top_k = 50
177
- top_p = 0.0
178
- record_tokens = True
179
- record_window = 50
180
-
181
- model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
182
- top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
183
- os.makedirs(save_dir, exist_ok=True)
184
- os.makedirs(save_dir + "/audios", exist_ok=True)
185
- os.makedirs(save_dir + "/jsonl", exist_ok=True)
186
-
187
-
188
- for item in new_items:
189
- lyric = item["gt_lyric"]
190
- descriptions = item["descriptions"] if "descriptions" in item else None
191
- pmt_wav = item['pmt_wav']
192
- vocal_wav = item['vocal_wav']
193
- bgm_wav = item['bgm_wav']
194
- melody_is_wav = item['melody_is_wav']
195
-
196
- generate_inp = {
197
- 'lyrics': [lyric.replace(" ", " ")],
198
- 'descriptions': [descriptions],
199
- 'melody_wavs': pmt_wav,
200
- 'vocal_wavs': vocal_wav,
201
- 'bgm_wavs': bgm_wav,
202
- 'melody_is_wav': melody_is_wav,
203
- }
204
- with torch.autocast(device_type="cuda", dtype=torch.float16):
205
- tokens = model.generate(**generate_inp, return_tokens=True)
206
- item['tokens'] = tokens
207
-
208
- del model
209
- torch.cuda.empty_cache()
210
-
211
-
212
- seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
213
- seperate_tokenizer = seperate_tokenizer.eval().cuda()
214
-
215
- model = CodecLM(name = "tmp",
216
- lm = None,
217
- audiotokenizer = None,
218
- max_duration = max_duration,
219
- seperate_tokenizer = seperate_tokenizer,
220
- )
221
- for item in new_items:
222
- with torch.no_grad():
223
- if 'raw_pmt_wav' in item:
224
- wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type=gen_type)
225
- del item['raw_pmt_wav']
226
- del item['raw_vocal_wav']
227
- del item['raw_bgm_wav']
228
- else:
229
- wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type=gen_type)
230
- torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
231
- del item['tokens']
232
- del item['pmt_wav']
233
- del item['vocal_wav']
234
- del item['bgm_wav']
235
- del item['melody_is_wav']
236
-
237
- torch.cuda.empty_cache()
238
- src_jsonl_name = os.path.split(input_jsonl)[-1]
239
- with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
240
- for item in new_items:
241
- fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
 
generate_lowmem.sh DELETED
@@ -1,11 +0,0 @@
1
- export USER=root
2
- export PYTHONDONTWRITEBYTECODE=1
3
- export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
4
- export NCCL_HOME=/usr/local/tccl
5
- export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
6
-
7
- CKPT_PATH=$1
8
- JSONL=$2
9
- SAVE_DIR=$3
10
- GEN_TYEP=$4
11
- python3 generate_lowmem.py $CKPT_PATH $JSONL $SAVE_DIR $GEN_TYEP
 
tools/gradio/app.py CHANGED
@@ -49,7 +49,7 @@ with open(op.join(APP_DIR, 'conf/vocab.yaml'), 'r', encoding='utf-8') as file:
49
  STRUCTS = yaml.safe_load(file)
50
 
51
 
52
- def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_coef=None, temperature=None, top_k=None, gen_type="all", progress=gr.Progress(track_tqdm=True)):
53
  global MODEL
54
  global STRUCTS
55
  params = {'cfg_coef':cfg_coef, 'temperature':temperature, 'top_k':top_k}
@@ -240,4 +240,3 @@ lyrics
240
  # Launch the application
241
  if __name__ == "__main__":
242
  demo.launch(server_name="0.0.0.0", server_port=8081)
243
-
 
49
  STRUCTS = yaml.safe_load(file)
50
 
51
 
52
+ def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_coef=None, temperature=None, top_k=None, gen_type="mixed", progress=gr.Progress(track_tqdm=True)):
53
  global MODEL
54
  global STRUCTS
55
  params = {'cfg_coef':cfg_coef, 'temperature':temperature, 'top_k':top_k}
 
240
  # Launch the application
241
  if __name__ == "__main__":
242
  demo.launch(server_name="0.0.0.0", server_port=8081)
 
tools/gradio/levo_inference.py CHANGED
@@ -62,7 +62,7 @@ class LeVoInference(torch.nn.Module):
62
 
63
  self.model.set_generation_params(**self.default_params)
64
 
65
- def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "all", params = dict()):
66
  params = {**self.default_params, **params}
67
  self.model.set_generation_params(**params)
68
 
 
62
 
63
  self.model.set_generation_params(**self.default_params)
64
 
65
+ def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params = dict()):
66
  params = {**self.default_params, **params}
67
  self.model.set_generation_params(**params)
68
 
tools/gradio/levo_inference_lowmem.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import sys
3
 
4
  import torch
@@ -12,6 +13,7 @@ from codeclm.models import CodecLM
12
  from codeclm.models import builders
13
 
14
  from separator import Separator
 
15
 
16
 
17
  class LeVoInference(torch.nn.Module):
@@ -40,24 +42,28 @@ class LeVoInference(torch.nn.Module):
40
  )
41
 
42
 
43
- def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "all", params = dict()):
44
  if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
45
  separator = Separator()
46
  audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
47
  audio_tokenizer = audio_tokenizer.eval().cuda()
48
- seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
49
- seperate_tokenizer = seperate_tokenizer.eval().cuda()
50
  pmt_wav, vocal_wav, bgm_wav = separator.run(prompt_audio_path)
51
  pmt_wav = pmt_wav.cuda()
52
  vocal_wav = vocal_wav.cuda()
53
  bgm_wav = bgm_wav.cuda()
54
- pmt_wav, _ = audio_tokenizer.encode(pmt_wav)
55
- vocal_wav, bgm_wav = seperate_tokenizer.encode(vocal_wav, bgm_wav)
56
- melody_is_wav = False
57
- melody_is_wav = False
58
  del audio_tokenizer
59
- del seperate_tokenizer
60
  del separator
61
  elif genre is not None and auto_prompt_path is not None:
62
  auto_prompt = torch.load(auto_prompt_path)
63
  merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
@@ -75,17 +81,28 @@ class LeVoInference(torch.nn.Module):
75
  bgm_wav = None
76
  melody_is_wav = True
77
 
78
- model_light = CodecLM_PL(self.cfg, self.pt_path)
79
- model_light = model_light.eval()
80
- model_light.audiolm.cfg = self.cfg
81
  model = CodecLM(name = "tmp",
82
- lm = model_light.audiolm,
83
  audiotokenizer = None,
84
  max_duration = self.max_duration,
85
  seperate_tokenizer = None,
86
  )
87
- del model_light
88
- model.lm = model.lm.cuda().to(torch.float16)
89
  params = {**self.default_params, **params}
90
  model.set_generation_params(**params)
91
 
@@ -99,28 +116,53 @@ class LeVoInference(torch.nn.Module):
99
  }
100
 
101
  with torch.autocast(device_type="cuda", dtype=torch.float16):
102
- tokens = model.generate(**generate_inp, return_tokens=True)
103
-
 
 
 
 
 
104
  del model
 
 
 
 
105
  torch.cuda.empty_cache()
106
 
107
- seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
108
- seperate_tokenizer = seperate_tokenizer.eval().cuda()
109
  model = CodecLM(name = "tmp",
110
  lm = None,
111
  audiotokenizer = None,
112
  max_duration = self.max_duration,
113
  seperate_tokenizer = seperate_tokenizer,
114
  )
115
-
116
  with torch.no_grad():
117
  if melody_is_wav:
118
- wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav, gen_type=gen_type)
119
  else:
120
- wav_seperate = model.generate_audio(tokens, gen_type=gen_type)
121
 
122
- del seperate_tokenizer
123
- del model
 
124
  torch.cuda.empty_cache()
125
 
126
  return wav_seperate[0]
 
1
  import os
2
+ import gc
3
  import sys
4
 
5
  import torch
 
13
  from codeclm.models import builders
14
 
15
  from separator import Separator
16
+ from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
17
 
18
 
19
  class LeVoInference(torch.nn.Module):
 
42
  )
43
 
44
 
45
+ def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params = dict()):
46
  if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
47
  separator = Separator()
48
  audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
49
  audio_tokenizer = audio_tokenizer.eval().cuda()
 
 
50
  pmt_wav, vocal_wav, bgm_wav = separator.run(prompt_audio_path)
51
  pmt_wav = pmt_wav.cuda()
52
  vocal_wav = vocal_wav.cuda()
53
  bgm_wav = bgm_wav.cuda()
54
+ with torch.no_grad():
55
+ pmt_wav, _ = audio_tokenizer.encode(pmt_wav)
 
 
56
  del audio_tokenizer
 
57
  del separator
58
+ torch.cuda.empty_cache()
59
+
60
+ seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
61
+ seperate_tokenizer = seperate_tokenizer.eval().cuda()
62
+ with torch.no_grad():
63
+ vocal_wav, bgm_wav = seperate_tokenizer.encode(vocal_wav, bgm_wav)
64
+ del seperate_tokenizer
65
+ melody_is_wav = False
66
+ torch.cuda.empty_cache()
67
  elif genre is not None and auto_prompt_path is not None:
68
  auto_prompt = torch.load(auto_prompt_path)
69
  merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
 
81
  bgm_wav = None
82
  melody_is_wav = True
83
 
84
+ audiolm = builders.get_lm_model(self.cfg)
85
+ checkpoint = torch.load(self.pt_path, map_location='cpu')
86
+ audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
87
+ audiolm.load_state_dict(audiolm_state_dict, strict=False)
88
+ audiolm = audiolm.eval()
89
+
90
+ offload_audiolm = True if 'offload' in self.cfg.keys() and 'audiolm' in self.cfg.offload else False
91
+ if offload_audiolm:
92
+ audiolm_offload_param = OffloadParamParse.parse_config(audiolm, self.cfg.offload.audiolm)
93
+ audiolm_offload_param.show()
94
+ offload_profiler = OffloadProfiler(device_index=0, **(audiolm_offload_param.init_param_dict()))
95
+ offload_profiler.offload_layer(**(audiolm_offload_param.offload_layer_param_dict()))
96
+ offload_profiler.clean_cache_wrapper(**(audiolm_offload_param.clean_cache_param_dict()))
97
+ else:
98
+ audiolm = audiolm.cuda().to(torch.float16)
99
+
100
  model = CodecLM(name = "tmp",
101
+ lm = audiolm,
102
  audiotokenizer = None,
103
  max_duration = self.max_duration,
104
  seperate_tokenizer = None,
105
  )
 
 
106
  params = {**self.default_params, **params}
107
  model.set_generation_params(**params)
108
 
 
116
  }
117
 
118
  with torch.autocast(device_type="cuda", dtype=torch.float16):
119
+ with torch.no_grad():
120
+ tokens = model.generate(**generate_inp, return_tokens=True)
121
+ if offload_audiolm:
122
+ offload_profiler.reset_empty_cache_mem_line()
123
+ offload_profiler.stop()
124
+ del offload_profiler
125
+ del audiolm_offload_param
126
  del model
127
+ audiolm = audiolm.cpu()
128
+ del audiolm
129
+ del checkpoint
130
+ gc.collect()
131
  torch.cuda.empty_cache()
132
 
133
+ seperate_tokenizer = builders.get_audio_tokenizer_model_cpu(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
134
+ device = "cuda:0"
135
+ seperate_tokenizer.model.device = device
136
+ seperate_tokenizer.model.vae = seperate_tokenizer.model.vae.to(device)
137
+ seperate_tokenizer.model.model.device = torch.device(device)
138
+ seperate_tokenizer = seperate_tokenizer.eval()
139
+
140
+ offload_wav_tokenizer_diffusion = True if 'offload' in self.cfg.keys() and 'wav_tokenizer_diffusion' in self.cfg.offload else False
141
+ if offload_wav_tokenizer_diffusion:
142
+ sep_offload_param = OffloadParamParse.parse_config(seperate_tokenizer, self.cfg.offload.wav_tokenizer_diffusion)
143
+ sep_offload_param.show()
144
+ sep_offload_profiler = OffloadProfiler(device_index=0, **(sep_offload_param.init_param_dict()))
145
+ sep_offload_profiler.offload_layer(**(sep_offload_param.offload_layer_param_dict()))
146
+ sep_offload_profiler.clean_cache_wrapper(**(sep_offload_param.clean_cache_param_dict()))
147
+ else:
148
+ seperate_tokenizer.model.model = seperate_tokenizer.model.model.to(device)
149
+
150
  model = CodecLM(name = "tmp",
151
  lm = None,
152
  audiotokenizer = None,
153
  max_duration = self.max_duration,
154
  seperate_tokenizer = seperate_tokenizer,
155
  )
156
+
157
  with torch.no_grad():
158
  if melody_is_wav:
159
+ wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav, gen_type=gen_type, chunked=True)
160
  else:
161
+ wav_seperate = model.generate_audio(tokens, gen_type=gen_type, chunked=True)
162
 
163
+ if offload_wav_tokenizer_diffusion:
164
+ sep_offload_profiler.reset_empty_cache_mem_line()
165
+ sep_offload_profiler.stop()
166
  torch.cuda.empty_cache()
167
 
168
  return wav_seperate[0]
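A minimal usage sketch for the low-memory wrapper above, assuming LeVoInference is constructed from the checkpoint directory as the gradio app does; the lyric, genre and output path are illustrative, not taken from the repo:

import torchaudio
from levo_inference_lowmem import LeVoInference

# Assumed constructor argument: the checkpoint directory containing config.yaml and model.pt.
model = LeVoInference("ckpt/songgeneration_base")
wav = model(
    lyric="[verse] ... ; [chorus] ...",   # illustrative structured lyric
    genre="Pop",                          # one of the auto-prompt genres
    auto_prompt_path="ckpt/prompt.pt",
    gen_type="mixed",
)
torchaudio.save("song.flac", wav.cpu().float(), model.cfg.sample_rate)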