ford442 committed on
Commit
8fb7e18
·
verified ·
1 Parent(s): 0b90fab

Update demos/musicgen_app.py

Browse files
Files changed (1) hide show
  1. demos/musicgen_app.py +3 -220
demos/musicgen_app.py CHANGED
@@ -1,5 +1,4 @@
1
  import argparse
2
- from concurrent.futures import ProcessPoolExecutor
3
  import logging
4
  import os
5
  from pathlib import Path
@@ -7,7 +6,7 @@ import subprocess as sp
7
  import sys
8
  import time
9
  import typing as tp
10
- import warnings
11
 
12
  from einops import rearrange
13
  import torch
@@ -21,7 +20,7 @@ import multiprocessing as mp
21
 
22
  # --- Utility Functions and Classes ---
23
 
24
- class FileCleaner: # Unchanged from previous example, included for completeness
25
  def __init__(self, file_lifetime: float = 3600):
26
  self.file_lifetime = file_lifetime
27
  self.files = []
@@ -154,220 +153,4 @@ class Predictor:
154
  result_task_id, result = self.result_queue.get()
155
  if result_task_id == task_id:
156
  if isinstance(result, Exception):
157
- raise result # Re-raise the exception in the main process
158
- return result # (wav, diffusion_wav) or (wav, None)
159
-
160
- def shutdown(self):
161
- """
162
- Shuts down the worker process.
163
- """
164
- self.task_queue.put(None) # Send sentinel value to stop the worker
165
- self.process.join() # Wait for the process to terminate
166
-
167
-
168
- # Global predictor instance
169
- _predictor = None
170
-
171
- def get_predictor(model_name:str = 'facebook/musicgen-melody'):
172
- global _predictor
173
- if _predictor is None:
174
- _predictor = Predictor(model_name)
175
- return _predictor
176
-
177
- def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
178
-
179
- predictor = get_predictor(model)
180
- task_id = predictor.predict(
181
- text=text,
182
- melody=melody,
183
- duration=duration,
184
- use_diffusion=use_mbd,
185
- top_k=topk,
186
- top_p=topp,
187
- temperature=temperature,
188
- cfg_coef=cfg_coef,
189
- )
190
-
191
- wav, diffusion_wav = predictor.get_result(task_id)
192
-
193
- # Save and return audio files
194
- wav_paths = []
195
- video_paths = []
196
-
197
- # Save standard output
198
- with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
199
- audio_write(
200
- file.name, wav[0], 32000, strategy="loudness", #hardcoded sample rate
201
- loudness_headroom_db=16, loudness_compressor=True, add_suffix=False
202
- )
203
- wav_paths.append(file.name)
204
- video_paths.append(make_waveform(file.name)) # Make and clean up video
205
- file_cleaner.add(file.name)
206
-
207
- # Save MBD output if used
208
- if diffusion_wav is not None:
209
- with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
210
- audio_write(
211
- file.name, diffusion_wav[0], 32000, strategy="loudness", #hardcoded sample rate
212
- loudness_headroom_db=16, loudness_compressor=True, add_suffix=False
213
- )
214
- wav_paths.append(file.name)
215
- video_paths.append(make_waveform(file.name)) # Make and clean up video
216
- file_cleaner.add(file.name)
217
-
218
- if use_mbd:
219
- return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
220
- return video_paths[0], wav_paths[0], None, None
221
-
222
-
223
- def toggle_audio_src(choice):
224
- if choice == "mic":
225
- return gr.update(sources="microphone", value=None, label="Microphone")
226
- else:
227
- return gr.update(sources="upload", value=None, label="File")
228
-
229
-
230
- def toggle_diffusion(choice):
231
- if choice == "MultiBand_Diffusion":
232
- return [gr.update(visible=True)] * 2
233
- else:
234
- return [gr.update(visible=False)] * 2
235
- # --- Gradio UI ---
236
-
237
- def ui_full(launch_kwargs):
238
- with gr.Blocks() as interface:
239
- gr.Markdown(
240
- """
241
- # MusicGen
242
- This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
243
- a simple and controllable model for music generation
244
- presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
245
- """
246
- )
247
- with gr.Row():
248
- with gr.Column():
249
- with gr.Row():
250
- text = gr.Text(label="Input Text", interactive=True)
251
- with gr.Column():
252
- radio = gr.Radio(["file", "mic"], value="file",
253
- label="Condition on a melody (optional) File or Mic")
254
- melody = gr.Audio(sources="upload", type="numpy", label="File",
255
- interactive=True, elem_id="melody-input")
256
- with gr.Row():
257
- submit = gr.Button("Submit")
258
- # _ = gr.Button("Interrupt").click(fn=interrupt, queue=False) # Interrupt is now handled implicitly
259
- with gr.Row():
260
- model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small",
261
- "facebook/musicgen-large", "facebook/musicgen-melody-large",
262
- "facebook/musicgen-stereo-small", "facebook/musicgen-stereo-medium",
263
- "facebook/musicgen-stereo-melody", "facebook/musicgen-stereo-large",
264
- "facebook/musicgen-stereo-melody-large"],
265
- label="Model", value="facebook/musicgen-melody", interactive=True)
266
- model_path = gr.Text(label="Model Path (custom models)", interactive=False, visible=False) # Keep, but hide
267
- with gr.Row():
268
- decoder = gr.Radio(["Default", "MultiBand_Diffusion"],
269
- label="Decoder", value="Default", interactive=True)
270
- with gr.Row():
271
- duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
272
- with gr.Row():
273
- topk = gr.Number(label="Top-k", value=250, interactive=True)
274
- topp = gr.Number(label="Top-p", value=0, interactive=True)
275
- temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
276
- cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
277
- with gr.Column():
278
- output = gr.Video(label="Generated Music")
279
- audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
280
- diffusion_output = gr.Video(label="MultiBand Diffusion Decoder", visible=False)
281
- audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath', visible=False)
282
-
283
- submit.click(
284
- toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False
285
- ).then(
286
- predict_full,
287
- inputs=[model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef],
288
- outputs=[output, audio_output, diffusion_output, audio_diffusion]
289
- )
290
- radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
291
-
292
- gr.Examples(
293
- fn=predict_full,
294
- examples=[
295
- [
296
- "An 80s driving pop song with heavy drums and synth pads in the background",
297
- "./assets/bach.mp3",
298
- "facebook/musicgen-melody",
299
- "Default"
300
- ],
301
- [
302
- "A cheerful country song with acoustic guitars",
303
- "./assets/bolero_ravel.mp3",
304
- "facebook/musicgen-melody",
305
- "Default"
306
- ],
307
- [
308
- "90s rock song with electric guitar and heavy drums",
309
- None,
310
- "facebook/musicgen-medium",
311
- "Default"
312
- ],
313
- [
314
- "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
315
- "./assets/bach.mp3",
316
- "facebook/musicgen-melody",
317
- "Default"
318
- ],
319
- [
320
- "lofi slow bpm electro chill with organic samples",
321
- None,
322
- "facebook/musicgen-medium",
323
- "Default"
324
- ],
325
- [
326
- "Punk rock with loud drum and power guitar",
327
- None,
328
- "facebook/musicgen-medium",
329
- "MultiBand_Diffusion"
330
- ],
331
- ],
332
- inputs=[text, melody, model, decoder],
333
- outputs=[output]
334
- )
335
- gr.Markdown(
336
- """
337
- ### More details
338
-
339
- The model will generate a short music extract based on the description you provided.
340
- The model can generate up to 30 seconds of audio in one pass.
341
-
342
- The model was trained with descriptions from a stock music catalog; descriptions that will work best
343
- should include some level of details on the instruments present, along with some intended use case
344
- (e.g. adding "perfect for a commercial" can somehow help).
345
-
346
- Using one of the `melody` models (e.g. `musicgen-melody-*`), you can optionally provide a reference audio
347
- from which a broad melody will be extracted.
348
- The model will then try to follow both the description and melody provided.
349
- For best results, the melody should be 30 seconds long (I know, the samples we provide are not...)
350
-
351
- It is now possible to extend the generation by feeding back the end of the previous chunk of audio.
352
- This can take a long time, and the model might lose consistency. The model might also
353
- decide at arbitrary positions that the song ends.
354
-
355
- **WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min).
356
- An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds
357
- are generated each time.
358
-
359
- We present 10 model variations:
360
- 1. facebook/musicgen-melody -- a music generation model capable of generating music conditioned
361
- on text and melody inputs. **Note**, you can also use text only.
362
- 2. facebook/musicgen-small -- a 300M transformer decoder conditioned on text only.
363
- 3. facebook/musicgen-medium -- a 1.5B transformer decoder conditioned on text only.
364
- 4. facebook/musicgen-large -- a 3.3B transformer decoder conditioned on text only.
365
- 5. facebook/musicgen-melody-large -- a 3.3B transformer decoder conditioned on text and melody.
366
- 6. facebook/musicgen-stereo-*: same as the previous models but fine tuned to output stereo audio.
367
-
368
- We also present two ways of decoding the audio tokens:
369
- 1. Use the default GAN based compression model. It can suffer from artifacts especially
370
- for crashes, snares etc.
371
- 2. Use [MultiBand Diffusion](https://arxiv.org/abs/2308.02560). Should improve the audio quality,
372
- at an extra computational cost. When this is selected, we provide both the GAN based decoded
373
- audio, and the one obtained with MBD.
 
1
  import argparse
 
2
  import logging
3
  import os
4
  from pathlib import Path
 
6
  import sys
7
  import time
8
  import typing as tp
9
+ from tempfile import NamedTemporaryFile
10
 
11
  from einops import rearrange
12
  import torch
 
20
 
21
  # --- Utility Functions and Classes ---
22
 
23
+ class FileCleaner: # Unchanged
24
  def __init__(self, file_lifetime: float = 3600):
25
  self.file_lifetime = file_lifetime
26
  self.files = []
 
153
  result_task_id, result = self.result_queue.get()
154
  if result_task_id == task_id:
155
  if isinstance(result, Exception):
156
+ raise result