Update demos/musicgen_colab.py
Browse files- demos/musicgen_colab.py +8 -22
demos/musicgen_colab.py
CHANGED
@@ -200,11 +200,11 @@ class Predictor:
|
|
200 |
assert outputs_diffusion.shape[1] == 1 # output is mono
|
201 |
outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
|
202 |
outputs_diffusion = outputs_diffusion.detach().cpu()
|
203 |
-
return output, outputs_diffusion #Return the task id.
|
204 |
else:
|
205 |
-
return output, None
|
206 |
except Exception as e:
|
207 |
-
return task_id, e
|
208 |
else:
|
209 |
# Use the multiprocessing queue (multi-process mode)
|
210 |
self.current_task_id += 1
|
@@ -238,9 +238,7 @@ _default_model_name = "facebook/musicgen-melody"
|
|
238 |
def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
|
239 |
# Initialize Predictor *INSIDE* the function
|
240 |
predictor = Predictor(model, depth)
|
241 |
-
|
242 |
-
# Call predict() - this will return either (wav, diffusion_wav) or a task_id
|
243 |
-
prediction_result = predictor.predict(
|
244 |
text=text,
|
245 |
melody=melody,
|
246 |
duration=duration,
|
@@ -250,16 +248,7 @@ def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk
|
|
250 |
temperature=temperature,
|
251 |
cfg_coef=cfg_coef,
|
252 |
)
|
253 |
-
|
254 |
-
# Handle daemon and non-daemon cases
|
255 |
-
if predictor.is_daemon:
|
256 |
-
wav, diffusion_wav = prediction_result # Direct unpacking (daemon mode)
|
257 |
-
else:
|
258 |
-
# Get the result using the task_id (multi-process mode)
|
259 |
-
task_id = prediction_result
|
260 |
-
wav, diffusion_wav = predictor.get_result(task_id)
|
261 |
-
|
262 |
-
# Save and return audio files (rest of the function remains the same)
|
263 |
wav_paths = []
|
264 |
video_paths = []
|
265 |
# Save standard output
|
@@ -274,7 +263,6 @@ def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk
|
|
274 |
video_paths.append(video_path)
|
275 |
file_cleaner.add(file.name)
|
276 |
file_cleaner.add(video_path)
|
277 |
-
|
278 |
# Save MBD output if used
|
279 |
if diffusion_wav is not None:
|
280 |
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
@@ -288,15 +276,13 @@ def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk
|
|
288 |
video_paths.append(video_path)
|
289 |
file_cleaner.add(file.name)
|
290 |
file_cleaner.add(video_path)
|
291 |
-
|
292 |
# Shutdown predictor to prevent hanging processes!
|
293 |
-
if not predictor.is_daemon:
|
294 |
predictor.shutdown()
|
295 |
-
|
296 |
if use_mbd:
|
297 |
-
|
298 |
return video_paths[0], wav_paths[0], None, None
|
299 |
-
|
300 |
def toggle_audio_src(choice):
|
301 |
if choice == "mic":
|
302 |
return gr.update(sources="microphone", value=None, label="Microphone")
|
|
|
200 |
assert outputs_diffusion.shape[1] == 1 # output is mono
|
201 |
outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
|
202 |
outputs_diffusion = outputs_diffusion.detach().cpu()
|
203 |
+
return task_id, (output, outputs_diffusion)  # Return the task id together with the (output, outputs_diffusion) pair.
|
204 |
else:
|
205 |
+
return task_id, (output, None)
|
206 |
except Exception as e:
|
207 |
+
return task_id, e
|
208 |
else:
|
209 |
# Use the multiprocessing queue (multi-process mode)
|
210 |
self.current_task_id += 1
|
|
|
238 |
def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
|
239 |
# Initialize Predictor *INSIDE* the function
|
240 |
predictor = Predictor(model, depth)
|
241 |
+
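task_id, (wav, diffusion_wav) = predictor.predict(  # Unpack (task_id, (wav, diffusion_wav)) directly. NOTE(review): the except branch shown earlier returns (task_id, e); if that reaches this call, unpacking the exception would raise TypeError - confirm the error path is handled.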
|
|
|
|
|
242 |
text=text,
|
243 |
melody=melody,
|
244 |
duration=duration,
|
|
|
248 |
temperature=temperature,
|
249 |
cfg_coef=cfg_coef,
|
250 |
)
|
251 |
+
# Save and return audio files
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
wav_paths = []
|
253 |
video_paths = []
|
254 |
# Save standard output
|
|
|
263 |
video_paths.append(video_path)
|
264 |
file_cleaner.add(file.name)
|
265 |
file_cleaner.add(video_path)
|
|
|
266 |
# Save MBD output if used
|
267 |
if diffusion_wav is not None:
|
268 |
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
|
|
276 |
video_paths.append(video_path)
|
277 |
file_cleaner.add(file.name)
|
278 |
file_cleaner.add(video_path)
|
|
|
279 |
# Shutdown predictor to prevent hanging processes!
|
280 |
+
if not predictor.is_daemon: # Important!
|
281 |
predictor.shutdown()
|
|
|
282 |
if use_mbd:
|
283 |
+
return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
|
284 |
return video_paths[0], wav_paths[0], None, None
|
285 |
+
|
286 |
def toggle_audio_src(choice):
|
287 |
if choice == "mic":
|
288 |
return gr.update(sources="microphone", value=None, label="Microphone")
|