ford442 committed on
Commit
d27739c
·
verified ·
1 Parent(s): 6b50856

Update demos/musicgen_colab.py

Browse files
Files changed (1) hide show
  1. demos/musicgen_colab.py +8 -22
demos/musicgen_colab.py CHANGED
@@ -200,11 +200,11 @@ class Predictor:
200
  assert outputs_diffusion.shape[1] == 1 # output is mono
201
  outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
202
  outputs_diffusion = outputs_diffusion.detach().cpu()
203
- return output, outputs_diffusion #Return the task id. <-- Corrected return
204
  else:
205
- return output, None # <-- Corrected return
206
  except Exception as e:
207
- return task_id, e #Should never be used, kept for consistency.
208
  else:
209
  # Use the multiprocessing queue (multi-process mode)
210
  self.current_task_id += 1
@@ -238,9 +238,7 @@ _default_model_name = "facebook/musicgen-melody"
238
  def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
239
  # Initialize Predictor *INSIDE* the function
240
  predictor = Predictor(model, depth)
241
-
242
- # Call predict() - this will return either (wav, diffusion_wav) or a task_id
243
- prediction_result = predictor.predict(
244
  text=text,
245
  melody=melody,
246
  duration=duration,
@@ -250,16 +248,7 @@ def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk
250
  temperature=temperature,
251
  cfg_coef=cfg_coef,
252
  )
253
-
254
- # Handle daemon and non-daemon cases
255
- if predictor.is_daemon:
256
- wav, diffusion_wav = prediction_result # Direct unpacking (daemon mode)
257
- else:
258
- # Get the result using the task_id (multi-process mode)
259
- task_id = prediction_result
260
- wav, diffusion_wav = predictor.get_result(task_id)
261
-
262
- # Save and return audio files (rest of the function remains the same)
263
  wav_paths = []
264
  video_paths = []
265
  # Save standard output
@@ -274,7 +263,6 @@ def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk
274
  video_paths.append(video_path)
275
  file_cleaner.add(file.name)
276
  file_cleaner.add(video_path)
277
-
278
  # Save MBD output if used
279
  if diffusion_wav is not None:
280
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
@@ -288,15 +276,13 @@ def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk
288
  video_paths.append(video_path)
289
  file_cleaner.add(file.name)
290
  file_cleaner.add(video_path)
291
-
292
  # Shutdown predictor to prevent hanging processes!
293
- if not predictor.is_daemon: # Important!
294
  predictor.shutdown()
295
-
296
  if use_mbd:
297
- return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
298
  return video_paths[0], wav_paths[0], None, None
299
-
300
  def toggle_audio_src(choice):
301
  if choice == "mic":
302
  return gr.update(sources="microphone", value=None, label="Microphone")
 
200
  assert outputs_diffusion.shape[1] == 1 # output is mono
201
  outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
202
  outputs_diffusion = outputs_diffusion.detach().cpu()
203
+ return task_id, (output, outputs_diffusion) #Return the task id.
204
  else:
205
+ return task_id, (output, None)
206
  except Exception as e:
207
+ return task_id, e
208
  else:
209
  # Use the multiprocessing queue (multi-process mode)
210
  self.current_task_id += 1
 
238
  def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
239
  # Initialize Predictor *INSIDE* the function
240
  predictor = Predictor(model, depth)
241
+ task_id, (wav, diffusion_wav) = predictor.predict( # Unpack directly!
 
 
242
  text=text,
243
  melody=melody,
244
  duration=duration,
 
248
  temperature=temperature,
249
  cfg_coef=cfg_coef,
250
  )
251
+ # Save and return audio files
 
 
 
 
 
 
 
 
 
252
  wav_paths = []
253
  video_paths = []
254
  # Save standard output
 
263
  video_paths.append(video_path)
264
  file_cleaner.add(file.name)
265
  file_cleaner.add(video_path)
 
266
  # Save MBD output if used
267
  if diffusion_wav is not None:
268
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
 
276
  video_paths.append(video_path)
277
  file_cleaner.add(file.name)
278
  file_cleaner.add(video_path)
 
279
  # Shutdown predictor to prevent hanging processes!
280
+ if not predictor.is_daemon: # Important!
281
  predictor.shutdown()
 
282
  if use_mbd:
283
+ return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
284
  return video_paths[0], wav_paths[0], None, None
285
+
286
  def toggle_audio_src(choice):
287
  if choice == "mic":
288
  return gr.update(sources="microphone", value=None, label="Microphone")