Spaces:
Build error
Build error
initial implem
Browse files
audiocraft/models/musicgen.py
CHANGED
|
@@ -36,10 +36,12 @@ class MusicGen:
|
|
| 36 |
used to map audio to invertible discrete representations.
|
| 37 |
lm (LMModel): Language model over discrete representations.
|
| 38 |
"""
|
| 39 |
-
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel
|
|
|
|
| 40 |
self.name = name
|
| 41 |
self.compression_model = compression_model
|
| 42 |
self.lm = lm
|
|
|
|
| 43 |
self.device = next(iter(lm.parameters())).device
|
| 44 |
self.generation_params: dict = {}
|
| 45 |
self.set_generation_params(duration=15) # 15 seconds by default
|
|
@@ -113,11 +115,10 @@ class MusicGen:
|
|
| 113 |
should we extend the audio each time. Larger values will mean less context is
|
| 114 |
preserved, and shorter values will require extra computations.
|
| 115 |
"""
|
| 116 |
-
|
| 117 |
-
assert extend_stride <= 25, "Keep at least 5 seconds of overlap!"
|
| 118 |
self.extend_stride = extend_stride
|
|
|
|
| 119 |
self.generation_params = {
|
| 120 |
-
'max_gen_len': int(duration * self.frame_rate),
|
| 121 |
'use_sampling': use_sampling,
|
| 122 |
'temp': temperature,
|
| 123 |
'top_k': top_k,
|
|
@@ -268,8 +269,12 @@ class MusicGen:
|
|
| 268 |
Returns:
|
| 269 |
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
| 270 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
| 272 |
-
print(f'{generated_tokens: 6d} / {
|
| 273 |
|
| 274 |
if prompt_tokens is not None:
|
| 275 |
assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
|
|
@@ -279,9 +284,46 @@ class MusicGen:
|
|
| 279 |
if progress:
|
| 280 |
callback = _progress_callback
|
| 281 |
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
|
| 286 |
# generate audio
|
| 287 |
assert gen_tokens.dim() == 3
|
|
|
|
| 36 |
used to map audio to invertible discrete representations.
|
| 37 |
lm (LMModel): Language model over discrete representations.
|
| 38 |
"""
|
| 39 |
+
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
|
| 40 |
+
max_duration: float = 30):
|
| 41 |
self.name = name
|
| 42 |
self.compression_model = compression_model
|
| 43 |
self.lm = lm
|
| 44 |
+
self.max_duration = max_duration
|
| 45 |
self.device = next(iter(lm.parameters())).device
|
| 46 |
self.generation_params: dict = {}
|
| 47 |
self.set_generation_params(duration=15) # 15 seconds by default
|
|
|
|
| 115 |
should we extend the audio each time. Larger values will mean less context is
|
| 116 |
preserved, and shorter values will require extra computations.
|
| 117 |
"""
|
| 118 |
+
assert extend_stride <= self.max_duration - 5, "Keep at least 5 seconds of overlap!"
|
|
|
|
| 119 |
self.extend_stride = extend_stride
|
| 120 |
+
self.duration = duration
|
| 121 |
self.generation_params = {
|
|
|
|
| 122 |
'use_sampling': use_sampling,
|
| 123 |
'temp': temperature,
|
| 124 |
'top_k': top_k,
|
|
|
|
| 269 |
Returns:
|
| 270 |
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
| 271 |
"""
|
| 272 |
+
total_gen_len = int(self.duration * self.frame_rate)
|
| 273 |
+
|
| 274 |
+
current_gen_offset = 0
|
| 275 |
+
|
| 276 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
| 277 |
+
print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
| 278 |
|
| 279 |
if prompt_tokens is not None:
|
| 280 |
assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
|
|
|
|
| 284 |
if progress:
|
| 285 |
callback = _progress_callback
|
| 286 |
|
| 287 |
+
if self.duration <= self.max_duration:
|
| 288 |
+
# generate by sampling from LM, simple case.
|
| 289 |
+
with self.autocast:
|
| 290 |
+
gen_tokens = self.lm.generate(
|
| 291 |
+
prompt_tokens, attributes,
|
| 292 |
+
callback=callback, max_gen_len=total_gen_len, **self.generation_params)
|
| 293 |
+
|
| 294 |
+
else:
|
| 295 |
+
# now this gets a bit messier, we need to handle prompts,
|
| 296 |
+
# melody conditioning etc.
|
| 297 |
+
ref_wavs = [attr.wav['self_wav'] for attr in attributes]
|
| 298 |
+
all_tokens = []
|
| 299 |
+
if prompt_tokens is not None:
|
| 300 |
+
all_tokens.append(prompt_tokens)
|
| 301 |
+
|
| 302 |
+
for time_offset in range(0, self.duration, self.extend_stride):
|
| 303 |
+
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
| 304 |
+
max_gen_len = int(chunk_duration * self.frame_rate)
|
| 305 |
+
for attr, ref_wav in zip(attributes, ref_wavs):
|
| 306 |
+
wav_length = ref_wav.length.item()
|
| 307 |
+
if wav_length == 0:
|
| 308 |
+
continue
|
| 309 |
+
# We will extend the wav periodically if it is not long enough.
|
| 310 |
+
# we have to do it here before it is too late.
|
| 311 |
+
initial_position = int(time_offset * self.sample_rate)
|
| 312 |
+
wav_target_length = int(chunk_duration * self.sample_rate)
|
| 313 |
+
positions = torch.arange(initial_position,
|
| 314 |
+
initial_position + wav_target_length, device=self.device)
|
| 315 |
+
attr.wav['self_wav'] = ref_wav[:, positions % wav_length]
|
| 316 |
+
with self.autocast:
|
| 317 |
+
gen_tokens = self.lm.generate(
|
| 318 |
+
prompt_tokens, attributes,
|
| 319 |
+
callback=callback, max_gen_len=max_gen_len, **self.generation_params)
|
| 320 |
+
stride_tokens = int(self.frame_rate * self.extend_stride)
|
| 321 |
+
if prompt_tokens is None:
|
| 322 |
+
all_tokens.append(gen_tokens)
|
| 323 |
+
else:
|
| 324 |
+
all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
|
| 325 |
+
prompt_tokens = gen_tokens[:, :, stride_tokens]
|
| 326 |
+
gen_tokens = torch.cat(all_tokens, dim=-1)
|
| 327 |
|
| 328 |
# generate audio
|
| 329 |
assert gen_tokens.dim() == 3
|