Bils commited on
Commit
ee1b035
·
verified ·
1 Parent(s): b87869d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -207,9 +207,12 @@ def generate_music(prompt: str, audio_length: int):
207
  model_key = "facebook/musicgen-large"
208
  musicgen_model, musicgen_processor = get_musicgen_model(model_key)
209
  device = "cuda" if torch.cuda.is_available() else "cpu"
210
- inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
 
 
211
  with torch.inference_mode():
212
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
 
213
  audio_data = outputs[0, 0].cpu().numpy()
214
  normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
215
  output_path = os.path.join(tempfile.gettempdir(), "musicgen_large_generated_music.wav")
 
207
  model_key = "facebook/musicgen-large"
208
  musicgen_model, musicgen_processor = get_musicgen_model(model_key)
209
  device = "cuda" if torch.cuda.is_available() else "cpu"
210
+ # Process the input and move each tensor to the proper device
211
+ inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
212
+ inputs = {k: v.to(device) for k, v in inputs.items()}
213
  with torch.inference_mode():
214
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
215
+ # Post-process the output to create a WAV file
216
  audio_data = outputs[0, 0].cpu().numpy()
217
  normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
218
  output_path = os.path.join(tempfile.gettempdir(), "musicgen_large_generated_music.wav")