Update audiocraft/models/lm.py
audiocraft/models/lm.py  +94 -0
@@ -531,3 +531,97 @@ class LMModel(StreamingModule):
         # ensure the returned codes are all valid
         assert (out_codes >= 0).all() and (out_codes <= self.card).all()
         return out_codes
+    @torch.no_grad()
+    def generate_segment(self,
+                         segment: int,
+                         prompt_text: str,
+                         max_segment_len: int,
+                         seed: tp.Optional[int] = None,
+                         # Pass other generation params like temp, top_k, etc.
+                         **kwargs
+                         ) -> tp.Tuple[torch.Tensor, int]:
+        """
+        Generate audio segment by segment, saving state to the filesystem.
+        This mirrors the logic from the RealViz script for robust, persistent state.
+
+        Args:
+            segment (int): The segment number to generate (starts at 1).
+            prompt_text (str): The text description for the music.
+            max_segment_len (int): The number of tokens to generate in this segment.
+            seed (int, optional): The seed for generation. If None and segment is 1, a random seed is created.
+                For segment > 1, pass the seed returned by the previous call so the matching state file is found.
+            **kwargs: Additional generation parameters (temp, top_k, cfg_coef).
+
+        Returns:
+            A tuple containing:
+            - full_codes (torch.Tensor): The generated tokens for the ENTIRE song so far.
+            - seed (int): The seed used for the generation process.
+        """
+        # Ensure a consistent seed across all segments of a song.
+        if segment == 1:
+            if seed is None:
+                seed = random.randint(0, np.iinfo(np.int32).max)
+            print(f"Starting new generation with seed: {seed}")
+
+            # --- This block runs only for the very first segment ---
+            conditions = [ConditioningAttributes(text={'description': prompt_text})]
+            # Start with an empty prompt tensor.
+            prompt_codes = torch.zeros((1, self.num_codebooks, 0), dtype=torch.long, device=self.device)
+            self.reset_streaming()  # Ensure the model's streaming state is fresh.
+        else:
+            # --- This block runs for all subsequent segments ---
+            state_file = f"musicgen_state_{segment - 1}_{seed}.pt"
+            if not os.path.exists(state_file):
+                raise FileNotFoundError(f"State file not found! Cannot resume from segment {segment}. Please run segment {segment - 1} first.")
+
+            print(f"Resuming from state file: {state_file}")
+            state = torch.load(state_file, map_location=self.device)
+
+            # Restore all necessary components from the saved state.
+            seed = state['seed']
+            conditions = state['conditions']
+            # The prompt for the next segment is the full output from the previous one.
+            prompt_codes = state['generated_tokens']
+            # CRITICAL: restore the model's internal KV cache.
+            self.set_streaming_state(state['model_state'])
+
+        # --- This part runs for EVERY segment ---
+        # `generate` here is the original, non-chunking method; it is used to
+        # produce just one segment's worth of tokens.
+        # `remove_prompts=True` is vital so the input prompt is not returned again in the output.
+        newly_generated_codes = self.generate(
+            prompt=prompt_codes,
+            conditions=conditions,
+            max_gen_len=prompt_codes.shape[-1] + max_segment_len,  # Generate N more tokens.
+            remove_prompts=True,
+            **kwargs
+        )
+
+        # Combine the previously generated tokens with the new segment.
+        full_codes = torch.cat([prompt_codes, newly_generated_codes], dim=-1)
+
+        # --- Save the new state for the NEXT segment to use ---
+        print(f"Segment {segment} finished. Saving state...")
+        new_model_state = self.get_streaming_state()
+
+        # Move tensors to CPU before saving for portability (the streaming state is a dict of tensors).
+        new_model_state = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in new_model_state.items()}
+
+        new_state_to_save = {
+            'seed': seed,
+            'conditions': conditions,
+            'generated_tokens': full_codes.to('cpu'),
+            'model_state': new_model_state,
+        }
+
+        # Save the state dictionary to a file.
+        new_state_file = f"musicgen_state_{segment}_{seed}.pt"
+        torch.save(new_state_to_save, new_state_file)
+        print(f"State for resuming at segment {segment + 1} saved to {new_state_file}")
+
+        return full_codes, seed
+
+    # Add this device property to LMModel as well, if it is not already there.
+    @property
+    def device(self):
+        return next(self.parameters()).device
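
The new method relies on os, random, and numpy (as np). If audiocraft/models/lm.py does not already import them, the following imports would also need to be added to the import block at the top of the module (they are not part of the hunk above):

import os
import random

import numpy as np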
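
For reference, a minimal driver sketch of how generate_segment could be called from a script. It assumes a MusicGen checkpoint loaded with audiocraft's MusicGen.get_pretrained; the prompt text, segment length (500 tokens is roughly 10 seconds at MusicGen's 50 tokens per second), and sampling parameters are illustrative placeholders, not part of the change above.

from audiocraft.models import MusicGen

# Load a pretrained MusicGen wrapper; .lm is the LMModel patched above.
model = MusicGen.get_pretrained('facebook/musicgen-small')
lm = model.lm

# Segment 1: starts fresh, picks a random seed, and writes musicgen_state_1_<seed>.pt.
# Note: generate_segment records the seed in the state file but does not call
# torch.manual_seed itself, so seed the RNG separately if reproducibility matters.
full_codes, seed = lm.generate_segment(
    segment=1,
    prompt_text='lo-fi hip hop with warm piano chords',
    max_segment_len=500,
    temp=1.0, top_k=250, cfg_coef=3.0,
)

# Segment 2: resumes from musicgen_state_1_<seed>.pt and appends 500 more tokens.
full_codes, seed = lm.generate_segment(
    segment=2,
    prompt_text='lo-fi hip hop with warm piano chords',
    max_segment_len=500,
    seed=seed,  # must match segment 1 so the right state file is found
    temp=1.0, top_k=250, cfg_coef=3.0,
)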
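
To listen to what has been generated so far, the accumulated tokens can be decoded with the wrapper's compression model. A minimal sketch, continuing from the model and full_codes variables above and using torchaudio to write the file; the output filename is arbitrary.

import torch
import torchaudio

with torch.no_grad():
    # full_codes has shape [B, num_codebooks, T]; decode returns audio of shape [B, channels, samples].
    audio = model.compression_model.decode(full_codes.to(model.device))
torchaudio.save('song_so_far.wav', audio[0].cpu(), model.sample_rate)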