ford442 committed
Commit 1153954 · verified · 1 Parent(s): 137ed0b

Update audiocraft/models/lm.py

Files changed (1)
  1. audiocraft/models/lm.py +94 -0
audiocraft/models/lm.py CHANGED
@@ -531,3 +531,97 @@ class LMModel(StreamingModule):
         # ensure the returned codes are all valid
         assert (out_codes >= 0).all() and (out_codes <= self.card).all()
         return out_codes
+    @torch.no_grad()
+    def generate_segment(self,
+                         segment: int,
+                         prompt_text: str,
+                         max_segment_len: int,
+                         seed: tp.Optional[int] = None,
+                         # Pass other generation params like temp, top_k, etc.
+                         **kwargs
+                         ) -> tp.Tuple[torch.Tensor, int]:
+        """
+        Generates audio token codes segment by segment, saving state to the filesystem.
+        This mirrors the logic from the RealViz script for robust, persistent state.
+
+        Args:
+            segment (int): The segment number to generate (starts at 1).
+            prompt_text (str): The text description for the music.
+            max_segment_len (int): The number of tokens to generate in this segment.
+            seed (int, optional): The seed for generation. If None and segment is 1,
+                a random seed is created.
+            **kwargs: Additional generation parameters (temp, top_k, cfg_coef).
+
+        Returns:
+            A tuple containing:
+                - full_codes (torch.Tensor): The generated tokens for the ENTIRE song so far.
+                - seed (int): The seed used for the generation process.
+        """
+        # Ensure a consistent seed across all segments of a song
+        if segment == 1:
+            if seed is None:
+                seed = random.randint(0, np.iinfo(np.int32).max)
+            print(f"Starting new generation with Seed: {seed}")
+
+            # --- This block runs only for the very first segment ---
+            conditions = [ConditioningAttributes(text={'description': prompt_text})]
+            # Start with an empty prompt tensor
+            prompt_codes = torch.zeros((1, self.num_codebooks, 0), dtype=torch.long, device=self.device)
+            self.reset_streaming()  # ensure the streaming (KV cache) state is fresh
+        else:
+            # --- This block runs for all subsequent segments; pass the same seed that segment 1 returned ---
+            state_file = f"musicgen_state_{segment-1}_{seed}.pt"
+            if not os.path.exists(state_file):
+                raise FileNotFoundError(f"State file not found! Cannot resume from segment {segment}. Please run segment {segment-1} first.")
+
+            print(f"Resuming from state file: {state_file}")
+            state = torch.load(state_file, map_location=self.device)
+
+            # Restore all necessary components from the saved state
+            seed = state['seed']
+            conditions = state['conditions']
+            # The prompt for the next segment is the full output from the previous one
+            prompt_codes = state['generated_tokens']
+            # CRITICAL: Restore the model's internal KV cache
+            self.set_streaming_state(state['model_state'])
+
+        # --- This part runs for EVERY segment ---
+        # The `generate` call below is the original, non-chunking method; we use it
+        # to produce just one segment's worth of tokens.
+        # `remove_prompts=True` makes it return only the newly generated tokens, not the prompt.
+        newly_generated_codes = self.generate(
+            prompt=prompt_codes,
+            conditions=conditions,
+            max_gen_len=prompt_codes.shape[-1] + max_segment_len,  # generate `max_segment_len` more tokens
+            remove_prompts=True,
+            **kwargs
+        )
+
+        # Combine the previous tokens with the new segment
+        full_codes = torch.cat([prompt_codes, newly_generated_codes], dim=-1)
+
+        # --- Save the new state for the NEXT segment to use ---
+        print(f"Segment {segment} finished. Saving state...")
+        new_model_state = self.get_streaming_state()
+
+        # Move tensors to CPU before saving for portability (the streaming state is a flat dict of tensors)
+        new_model_state = {k: v.to('cpu') for k, v in new_model_state.items()}
+
+        new_state_to_save = {
+            'seed': seed,
+            'conditions': conditions,
+            'generated_tokens': full_codes.to('cpu'),
+            'model_state': new_model_state,
+        }
+
+        # Save the state dictionary to a file
+        new_state_file = f"musicgen_state_{segment}_{seed}.pt"
+        torch.save(new_state_to_save, new_state_file)
+        print(f"State for resuming at segment {segment + 1} saved to {new_state_file}")
+
+        return full_codes, seed
+
+    # Convenience property used above; include it only if LMModel does not already define `device`.
+    @property
+    def device(self):
+        return next(self.parameters()).device
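
For reference, a minimal driver sketch (not part of this commit) showing how generate_segment could be called segment by segment. It assumes the method is reached through MusicGen.lm and that lm.py also imports os, random, and numpy as np, which the upstream module does not import by default; the model name, prompt, and segment length below are illustrative.

from audiocraft.models import MusicGen

model = MusicGen.get_pretrained('facebook/musicgen-small')   # illustrative checkpoint
lm = model.lm                                                # the LMModel patched above

prompt = "lo-fi hip hop with warm piano chords"              # illustrative prompt
segment_len = 500    # roughly 10 s per segment at the 32 kHz models' 50 tokens/s frame rate

seed = None
full_codes = None
for segment in range(1, 4):                                  # three segments of tokens
    full_codes, seed = lm.generate_segment(
        segment=segment,
        prompt_text=prompt,
        max_segment_len=segment_len,
        seed=seed,            # segment 1 picks a seed; later segments must reuse it
        temp=1.0,
        top_k=250,            # forwarded to LMModel.generate via **kwargs
    )
    print(f"segment {segment}: {full_codes.shape[-1]} tokens so far")

# full_codes can then be decoded to a waveform with the compression model,
# e.g. audio = model.compression_model.decode(full_codes).

Each call writes musicgen_state_<segment>_<seed>.pt to the current working directory, which is what allows successive segments to be produced in separate runs or processes.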