Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -207,9 +207,12 @@ def generate_music(prompt: str, audio_length: int):
|
|
207 |
model_key = "facebook/musicgen-large"
|
208 |
musicgen_model, musicgen_processor = get_musicgen_model(model_key)
|
209 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
210 |
-
|
|
|
|
|
211 |
with torch.inference_mode():
|
212 |
outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
|
|
|
213 |
audio_data = outputs[0, 0].cpu().numpy()
|
214 |
normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
|
215 |
output_path = os.path.join(tempfile.gettempdir(), "musicgen_large_generated_music.wav")
|
|
|
207 |
model_key = "facebook/musicgen-large"
|
208 |
musicgen_model, musicgen_processor = get_musicgen_model(model_key)
|
209 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
210 |
+
# Process the input and move each tensor to the proper device
|
211 |
+
inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
|
212 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
213 |
with torch.inference_mode():
|
214 |
outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
|
215 |
+
# Post-process the output to create a WAV file
|
216 |
audio_data = outputs[0, 0].cpu().numpy()
|
217 |
normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
|
218 |
output_path = os.path.join(tempfile.gettempdir(), "musicgen_large_generated_music.wav")
|