Update demos/musicgen_app.py
Browse files- demos/musicgen_app.py +113 -56
demos/musicgen_app.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import spaces # <--- IMPORTANT: Add this import
|
2 |
-
|
3 |
import argparse
|
4 |
import logging
|
5 |
import os
|
@@ -20,7 +19,6 @@ from audiocraft.models.encodec import InterleaveStereoCompressionModel
|
|
20 |
from audiocraft.models import MusicGen, MultiBandDiffusion
|
21 |
import multiprocessing as mp
|
22 |
|
23 |
-
|
24 |
# --- Utility Functions and Classes ---
|
25 |
|
26 |
class FileCleaner: # Unchanged
|
@@ -51,20 +49,20 @@ def make_waveform(*args, **kwargs): # Unchanged
|
|
51 |
print("Make a video took", time.time() - be)
|
52 |
return out
|
53 |
|
54 |
-
# --- Worker Process ---
|
55 |
-
|
56 |
def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
|
57 |
"""
|
58 |
-
Persistent worker process
|
59 |
"""
|
60 |
try:
|
61 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
62 |
model = MusicGen.get_pretrained(model_name, device=device)
|
63 |
-
mbd = MultiBandDiffusion.get_mbd_musicgen(device=device)
|
64 |
|
65 |
while True:
|
66 |
task = task_queue.get()
|
67 |
-
if task is None:
|
68 |
break
|
69 |
|
70 |
task_id, text, melody, duration, use_diffusion, gen_params = task
|
@@ -104,75 +102,134 @@ def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
|
|
104 |
assert outputs_diffusion.shape[1] == 1 # output is mono
|
105 |
outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
|
106 |
outputs_diffusion = outputs_diffusion.detach().cpu()
|
107 |
-
result_queue.put((task_id, (output, outputs_diffusion)))
|
108 |
else:
|
109 |
-
result_queue.put((task_id, (output, None)))
|
110 |
|
111 |
except Exception as e:
|
112 |
-
result_queue.put((task_id, e))
|
113 |
|
114 |
except Exception as e:
|
115 |
-
result_queue.put((-1,e))
|
116 |
|
117 |
-
|
|
|
118 |
|
119 |
class Predictor:
|
120 |
-
#This stays the same, this is the intended design
|
121 |
def __init__(self, model_name: str):
|
122 |
-
self.
|
123 |
-
self.
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
def _check_initialization(self):
|
131 |
-
"""Check if the worker process initialized successfully."""
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
pass # Expected if model loads fast enough
|
142 |
|
143 |
def predict(self, text, melody, duration, use_diffusion, **gen_params):
|
144 |
-
"""
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
def get_result(self, task_id):
|
153 |
-
"""
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
def shutdown(self):
|
164 |
-
"""
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
self.task_queue.put(None) # Send sentinel value to stop the worker
|
169 |
-
self.process.join() # Wait for the process to terminate
|
170 |
|
171 |
-
# NO GLOBAL PREDICTOR ANYMORE
|
172 |
|
173 |
-
_default_model_name =
|
174 |
|
175 |
-
@spaces.GPU(duration=60) #
|
176 |
def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
|
177 |
# Initialize Predictor *INSIDE* the function
|
178 |
predictor = Predictor(model)
|
|
|
1 |
import spaces # <--- IMPORTANT: Add this import
|
|
|
2 |
import argparse
|
3 |
import logging
|
4 |
import os
|
|
|
19 |
from audiocraft.models import MusicGen, MultiBandDiffusion
|
20 |
import multiprocessing as mp
|
21 |
|
|
|
22 |
# --- Utility Functions and Classes ---
|
23 |
|
24 |
class FileCleaner: # Unchanged
|
|
|
49 |
print("Make a video took", time.time() - be)
|
50 |
return out
|
51 |
|
52 |
+
# --- Worker Process --- (Modified for conditional use)
|
53 |
+
|
54 |
def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
|
55 |
"""
|
56 |
+
Persistent worker process (used when NOT running as a daemon).
|
57 |
"""
|
58 |
try:
|
59 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
60 |
model = MusicGen.get_pretrained(model_name, device=device)
|
61 |
+
mbd = MultiBandDiffusion.get_mbd_musicgen(device=device)
|
62 |
|
63 |
while True:
|
64 |
task = task_queue.get()
|
65 |
+
if task is None:
|
66 |
break
|
67 |
|
68 |
task_id, text, melody, duration, use_diffusion, gen_params = task
|
|
|
102 |
assert outputs_diffusion.shape[1] == 1 # output is mono
|
103 |
outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
|
104 |
outputs_diffusion = outputs_diffusion.detach().cpu()
|
105 |
+
result_queue.put((task_id, (output, outputs_diffusion)))
|
106 |
else:
|
107 |
+
result_queue.put((task_id, (output, None)))
|
108 |
|
109 |
except Exception as e:
|
110 |
+
result_queue.put((task_id, e))
|
111 |
|
112 |
except Exception as e:
|
113 |
+
result_queue.put((-1, e))
|
114 |
|
115 |
+
|
116 |
+
# --- Predictor Class (Modified for conditional process creation) ---
|
117 |
|
118 |
class Predictor:
    """Run MusicGen inference either in-process or via a worker process.

    When constructed inside a daemonic process (e.g. a Spaces GPU worker,
    which may not spawn child processes), the model is loaded and run
    directly in this process.  Otherwise a persistent ``model_worker``
    subprocess is started and tasks/results are exchanged through
    multiprocessing queues.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        # Daemonic processes cannot spawn children, so fall back to
        # single-process mode in that case.
        self.is_daemon = mp.current_process().daemon
        if self.is_daemon:
            # Single-process mode: load the models right here.
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.model = MusicGen.get_pretrained(self.model_name, device=self.device)
            self.mbd = MultiBandDiffusion.get_mbd_musicgen(device=self.device)  # Load MBD here too
            self.current_task_id = 0  # Initialize task ID
        else:
            # Multi-process mode: hand the models to a persistent worker.
            self.task_queue = mp.Queue()
            self.result_queue = mp.Queue()
            self.process = mp.Process(
                target=model_worker, args=(self.model_name, self.task_queue, self.result_queue)
            )
            self.process.start()
            self.current_task_id = 0
            self._check_initialization()

    def _check_initialization(self):
        """Check if the worker process initialized successfully (only in non-daemon mode).

        Raises:
            RuntimeError: if the worker reported a model-loading failure.
        """
        if self.is_daemon:
            return
        time.sleep(2)
        try:
            task_id, result = self.result_queue.get(timeout=3)
        except Exception:
            # queue.Empty on timeout: no error was reported, which is the
            # expected outcome when the model loads cleanly (or is still
            # loading).
            return
        # BUGFIX: the raise must live OUTSIDE the try block above —
        # previously it sat inside it, so the bare `except: pass` silently
        # swallowed the very RuntimeError this check exists to surface.
        if task_id == -1 and isinstance(result, Exception):
            raise RuntimeError("Model loading failed in worker process.") from result

    def predict(self, text, melody, duration, use_diffusion, **gen_params):
        """Submit a prediction task.

        Returns:
            In daemon mode: a ``(task_id, payload)`` tuple where payload is
            ``(output, outputs_diffusion_or_None)`` or an Exception.
            In multi-process mode: the integer task id; fetch the payload
            with :meth:`get_result`.
        """
        if self.is_daemon:
            # Directly perform the prediction (single-process mode).
            self.current_task_id += 1
            task_id = self.current_task_id
            try:
                self.model.set_generation_params(duration=duration, **gen_params)
                target_sr = self.model.sample_rate
                target_ac = 1  # conditioning melody is converted to mono

                # Preprocess the optional (sample_rate, ndarray) melody from
                # the UI into a tensor trimmed to the requested duration.
                processed_melody = None
                if melody:
                    sr, melody_data = melody
                    melody_tensor = torch.from_numpy(melody_data).to(self.device).float().t()
                    if melody_tensor.ndim == 1:
                        melody_tensor = melody_tensor.unsqueeze(0)
                    melody_tensor = melody_tensor[..., :int(sr * duration)]
                    processed_melody = convert_audio(melody_tensor, sr, target_sr, target_ac)

                if processed_melody is not None:
                    output, tokens = self.model.generate_with_chroma(
                        descriptions=[text],
                        melody_wavs=[processed_melody],
                        melody_sample_rate=target_sr,
                        progress=True,
                        return_tokens=True
                    )
                else:
                    output, tokens = self.model.generate([text], progress=True, return_tokens=True)

                output = output.detach().cpu()

                if use_diffusion:
                    # Stereo models interleave left/right codes; split them so
                    # MBD decodes each channel, then re-interleave below.
                    if isinstance(self.model.compression_model, InterleaveStereoCompressionModel):
                        left, right = self.model.compression_model.get_left_right_codes(tokens)
                        tokens = torch.cat([left, right])
                    outputs_diffusion = self.mbd.tokens_to_wav(tokens)
                    if isinstance(self.model.compression_model, InterleaveStereoCompressionModel):
                        assert outputs_diffusion.shape[1] == 1  # MBD output is mono
                        outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
                    outputs_diffusion = outputs_diffusion.detach().cpu()
                    return task_id, (output, outputs_diffusion)  # Return the task id.
                else:
                    return task_id, (output, None)

            except Exception as e:
                # Mirror the worker protocol: errors travel as values and are
                # re-raised by get_result().
                return task_id, e

        else:
            # Use the multiprocessing queue (multi-process mode).
            self.current_task_id += 1
            task = (self.current_task_id, text, melody, duration, use_diffusion, gen_params)
            self.task_queue.put(task)
            return self.current_task_id

    def get_result(self, task_id):
        """Retrieve the result of a prediction task.

        In daemon mode ``task_id`` is actually the ``(task_id, result)``
        tuple returned by :meth:`predict`; in multi-process mode it is the
        integer id and the result is read from the worker's queue.

        Raises:
            Exception: whatever error the prediction produced.
        """
        if self.is_daemon:
            # BUGFIX: unpack the (task_id, result) tuple from predict().
            # Previously this read `result_id, result = task_id, task_id`,
            # binding `result` to the whole tuple — callers got the tuple
            # back and exceptions were never re-raised.
            result_id, result = task_id
        else:
            # Drain the queue until we find the matching task id
            # (multi-process mode).
            while True:
                result_task_id, result = self.result_queue.get()
                if result_task_id == task_id:
                    break  # Found the correct result

        if isinstance(result, Exception):
            raise result
        return result

    def shutdown(self):
        """Shut down the worker process (if one is running)."""
        if not self.is_daemon and self.process.is_alive():
            self.task_queue.put(None)  # Sentinel value stops the worker loop
            self.process.join()  # Wait for the process to terminate
|
|
|
|
|
228 |
|
|
|
229 |
|
230 |
+
_default_model_name = "facebook/musicgen-melody"
|
231 |
|
232 |
+
@spaces.GPU(duration=60) # Use the decorator for Spaces
|
233 |
def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
|
234 |
# Initialize Predictor *INSIDE* the function
|
235 |
predictor = Predictor(model)
|