nit
Browse files
audiocraft/models/musicgen.py
CHANGED
|
@@ -270,8 +270,7 @@ class MusicGen:
|
|
| 270 |
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
| 271 |
"""
|
| 272 |
total_gen_len = int(self.duration * self.frame_rate)
|
| 273 |
-
|
| 274 |
-
current_gen_offset = 0
|
| 275 |
|
| 276 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
| 277 |
print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
|
@@ -299,7 +298,7 @@ class MusicGen:
|
|
| 299 |
if prompt_tokens is not None:
|
| 300 |
all_tokens.append(prompt_tokens)
|
| 301 |
|
| 302 |
-
time_offset = 0
|
| 303 |
while time_offset < self.duration:
|
| 304 |
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
| 305 |
max_gen_len = int(chunk_duration * self.frame_rate)
|
|
@@ -308,12 +307,15 @@ class MusicGen:
|
|
| 308 |
if wav_length == 0:
|
| 309 |
continue
|
| 310 |
# We will extend the wav periodically if it not long enough.
|
| 311 |
-
# we have to do it here
|
|
|
|
| 312 |
initial_position = int(time_offset * self.sample_rate)
|
| 313 |
-
wav_target_length = int(
|
| 314 |
positions = torch.arange(initial_position,
|
| 315 |
initial_position + wav_target_length, device=self.device)
|
| 316 |
-
attr.wav['self_wav'] =
|
|
|
|
|
|
|
| 317 |
with self.autocast:
|
| 318 |
gen_tokens = self.lm.generate(
|
| 319 |
prompt_tokens, attributes,
|
|
|
|
| 270 |
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
| 271 |
"""
|
| 272 |
total_gen_len = int(self.duration * self.frame_rate)
|
| 273 |
+
current_gen_offset: int = 0
|
|
|
|
| 274 |
|
| 275 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
| 276 |
print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
|
|
|
| 298 |
if prompt_tokens is not None:
|
| 299 |
all_tokens.append(prompt_tokens)
|
| 300 |
|
| 301 |
+
time_offset = 0.
|
| 302 |
while time_offset < self.duration:
|
| 303 |
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
| 304 |
max_gen_len = int(chunk_duration * self.frame_rate)
|
|
|
|
| 307 |
if wav_length == 0:
|
| 308 |
continue
|
| 309 |
# We will extend the wav periodically if it not long enough.
|
| 310 |
+
# we have to do it here rather than in conditioners.py as otherwise
|
| 311 |
+
# we wouldn't have the full wav.
|
| 312 |
initial_position = int(time_offset * self.sample_rate)
|
| 313 |
+
wav_target_length = int(self.max_duration * self.sample_rate)
|
| 314 |
positions = torch.arange(initial_position,
|
| 315 |
initial_position + wav_target_length, device=self.device)
|
| 316 |
+
attr.wav['self_wav'] = WavCondition(
|
| 317 |
+
ref_wav[0][:, positions % wav_length],
|
| 318 |
+
torch.full_like(ref_wav[1], wav_target_length))
|
| 319 |
with self.autocast:
|
| 320 |
gen_tokens = self.lm.generate(
|
| 321 |
prompt_tokens, attributes,
|