ford442 committed on
Commit efa2898 · verified · 1 Parent(s): 590092f

Update demos/musicgen_app.py

Files changed (1): demos/musicgen_app.py +113 -56
demos/musicgen_app.py CHANGED
@@ -1,5 +1,4 @@
 import spaces # <--- IMPORTANT: Add this import
-
 import argparse
 import logging
 import os
@@ -20,7 +19,6 @@ from audiocraft.models.encodec import InterleaveStereoCompressionModel
 from audiocraft.models import MusicGen, MultiBandDiffusion
 import multiprocessing as mp
 
-
 # --- Utility Functions and Classes ---
 
 class FileCleaner: # Unchanged
@@ -51,20 +49,20 @@ def make_waveform(*args, **kwargs): # Unchanged
     print("Make a video took", time.time() - be)
     return out
 
-# --- Worker Process ---
-#This stays the same, since the worker is designed for this purpose
+# --- Worker Process --- (Modified for conditional use)
+
 def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
     """
-    Persistent worker process that loads the model and handles prediction tasks.
+    Persistent worker process (used when NOT running as a daemon).
     """
     try:
         device = 'cuda' if torch.cuda.is_available() else 'cpu'
         model = MusicGen.get_pretrained(model_name, device=device)
-        mbd = MultiBandDiffusion.get_mbd_musicgen(device=device) # Load MBD here too
+        mbd = MultiBandDiffusion.get_mbd_musicgen(device=device)
 
         while True:
             task = task_queue.get()
-            if task is None: # Sentinel value to exit
+            if task is None:
                 break
 
             task_id, text, melody, duration, use_diffusion, gen_params = task
@@ -104,75 +102,134 @@ def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
                         assert outputs_diffusion.shape[1] == 1 # output is mono
                         outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
                     outputs_diffusion = outputs_diffusion.detach().cpu()
-                    result_queue.put((task_id, (output, outputs_diffusion))) # Send BOTH results.
+                    result_queue.put((task_id, (output, outputs_diffusion)))
                 else:
-                    result_queue.put((task_id, (output, None))) # Send back the result
+                    result_queue.put((task_id, (output, None)))
 
             except Exception as e:
-                result_queue.put((task_id, e)) # Send back the exception
+                result_queue.put((task_id, e))
 
     except Exception as e:
-        result_queue.put((-1,e)) #Fatal error on loading.
+        result_queue.put((-1, e))
 
-# --- Gradio Interface Functions ---
+
+# --- Predictor Class (Modified for conditional process creation) ---
 
 class Predictor:
-    #This stays the same, this is the intended design
     def __init__(self, model_name: str):
-        self.task_queue = mp.Queue()
-        self.result_queue = mp.Queue()
-        self.process = mp.Process(target=model_worker, args=(model_name, self.task_queue, self.result_queue))
-        self.process.start()
-        self.current_task_id = 0
-        self._check_initialization()
-
+        self.model_name = model_name
+        self.is_daemon = mp.current_process().daemon
+        if self.is_daemon:
+            # Running in a daemonic process (e.g., on Spaces): load in-process.
+            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+            self.model = MusicGen.get_pretrained(self.model_name, device=self.device)
+            self.mbd = MultiBandDiffusion.get_mbd_musicgen(device=self.device) # Load MBD here too
+            self.current_task_id = 0 # Initialize task ID
+        else:
+            # Running in a non-daemonic process (e.g., locally): use a worker.
+            self.task_queue = mp.Queue()
+            self.result_queue = mp.Queue()
+            self.process = mp.Process(
+                target=model_worker, args=(self.model_name, self.task_queue, self.result_queue)
+            )
+            self.process.start()
+            self.current_task_id = 0
+            self._check_initialization()
 
     def _check_initialization(self):
-        """Check if the worker process initialized successfully."""
-        # Give it some time to either load or report failure.
-        time.sleep(2)
-        try:
-            task_id, result = self.result_queue.get(timeout=3) # Get result from model_worker
-
-            if isinstance(result, Exception):
-                if task_id == -1:
-                    raise RuntimeError("Model loading failed in worker process.") from result
-        except:
-            pass # Expected if model loads fast enough
+        """Check if the worker process initialized successfully (only in non-daemon mode)."""
+        if not self.is_daemon:
+            time.sleep(2)
+            try:
+                task_id, result = self.result_queue.get(timeout=3)
+            except Exception:
+                return # Queue still empty: assume the model is loading normally
+            if isinstance(result, Exception) and task_id == -1:
+                raise RuntimeError("Model loading failed in worker process.") from result
 
     def predict(self, text, melody, duration, use_diffusion, **gen_params):
-        """
-        Submits a prediction task to the worker process.
-        """
-        self.current_task_id += 1
-        task = (self.current_task_id, text, melody, duration, use_diffusion, gen_params)
-        self.task_queue.put(task)
-        return self.current_task_id
+        """Submits a prediction task."""
+        if self.is_daemon:
+            # Directly perform the prediction (single-process mode).
+            self.current_task_id += 1
+            task_id = self.current_task_id
+            try:
+                self.model.set_generation_params(duration=duration, **gen_params)
+                target_sr = self.model.sample_rate
+                target_ac = 1
+                processed_melody = None
+                if melody:
+                    sr, melody_data = melody
+                    melody_tensor = torch.from_numpy(melody_data).to(self.device).float().t()
+                    if melody_tensor.ndim == 1:
+                        melody_tensor = melody_tensor.unsqueeze(0)
+                    melody_tensor = melody_tensor[..., :int(sr * duration)]
+                    processed_melody = convert_audio(melody_tensor, sr, target_sr, target_ac)
+
+                if processed_melody is not None:
+                    output, tokens = self.model.generate_with_chroma(
+                        descriptions=[text],
+                        melody_wavs=[processed_melody],
+                        melody_sample_rate=target_sr,
+                        progress=True,
+                        return_tokens=True
+                    )
+                else:
+                    output, tokens = self.model.generate([text], progress=True, return_tokens=True)
+
+                output = output.detach().cpu()
+
+                if use_diffusion:
+                    if isinstance(self.model.compression_model, InterleaveStereoCompressionModel):
+                        left, right = self.model.compression_model.get_left_right_codes(tokens)
+                        tokens = torch.cat([left, right])
+                    outputs_diffusion = self.mbd.tokens_to_wav(tokens)
+                    if isinstance(self.model.compression_model, InterleaveStereoCompressionModel):
+                        assert outputs_diffusion.shape[1] == 1
+                        outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
+                    outputs_diffusion = outputs_diffusion.detach().cpu()
+                    return task_id, (output, outputs_diffusion) # Return the task id with the result.
+                else:
+                    return task_id, (output, None)
+
+            except Exception as e:
+                return task_id, e
+
+        else:
+            # Use the multiprocessing queue (multi-process mode).
+            self.current_task_id += 1
+            task = (self.current_task_id, text, melody, duration, use_diffusion, gen_params)
+            self.task_queue.put(task)
+            return self.current_task_id
 
     def get_result(self, task_id):
-        """
-        Retrieves the result of a prediction task. Blocks until the result is available.
-        """
-        while True: # Loop to get the correct task
-            result_task_id, result = self.result_queue.get()
-            if result_task_id == task_id:
-                if isinstance(result, Exception):
-                    raise result # Re-raise the exception in the main process
-                return result # (wav, diffusion_wav) or (wav, None)
+        """Retrieves the result of a prediction task."""
+        if self.is_daemon:
+            # In daemon mode 'predict' already returned (task_id, result); unpack it.
+            result_task_id, result = task_id
+        else:
+            # Get the matching result from the queue (multi-process mode).
+            while True:
+                result_task_id, result = self.result_queue.get()
+                if result_task_id == task_id:
+                    break # Found the correct result
+
+        if isinstance(result, Exception):
+            raise result
+        return result # (wav, diffusion_wav) or (wav, None)
 
     def shutdown(self):
-        """
-        Shuts down the worker process.
-        """
-        if self.process.is_alive():
-            self.task_queue.put(None) # Send sentinel value to stop the worker
-            self.process.join() # Wait for the process to terminate
+        """Shuts down the worker process (if running)."""
+        if not self.is_daemon and self.process.is_alive():
+            self.task_queue.put(None)
+            self.process.join()
 
-# NO GLOBAL PREDICTOR ANYMORE
 
-_default_model_name = 'facebook/musicgen-melody'
+_default_model_name = "facebook/musicgen-melody"
 
-@spaces.GPU(duration=60) # <--- IMPORTANT: Add this decorator
+@spaces.GPU(duration=60) # Use the decorator for Spaces
 def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
     # Initialize Predictor *INSIDE* the function
     predictor = Predictor(model)
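Note: the mp.current_process().daemon check above exists because Python's multiprocessing forbids daemonic processes from spawning children, so the queue-and-worker path cannot run inside a Spaces GPU worker. A minimal standalone sketch (not part of the app) demonstrating the restriction:

import multiprocessing as mp

def child():
    print("child ran")

def daemon_body():
    # Inside a daemonic process this flag is True...
    print("daemonic:", mp.current_process().daemon)
    try:
        mp.Process(target=child).start()
    except AssertionError as e:
        # ...and starting a child process fails with this error.
        print("cannot spawn:", e)

if __name__ == "__main__":
    p = mp.Process(target=daemon_body, daemon=True)
    p.start()
    p.join()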
 
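The model_worker function follows the standard persistent-worker pattern: do the expensive setup (model loading) once, then loop over a task queue until a None sentinel arrives, shipping results and exceptions back through a result queue. The same pattern reduced to a generic sketch, with trivial work standing in for the model:

import multiprocessing as mp

def worker(task_queue, result_queue):
    # Expensive one-time setup (e.g., model loading) would happen here.
    while True:
        task = task_queue.get()
        if task is None: # Sentinel value: exit the loop
            break
        task_id, payload = task
        try:
            result_queue.put((task_id, payload.upper())) # Stand-in for real work
        except Exception as e:
            result_queue.put((task_id, e)) # Ship exceptions back to the parent

if __name__ == "__main__":
    tasks, results = mp.Queue(), mp.Queue()
    p = mp.Process(target=worker, args=(tasks, results))
    p.start()
    tasks.put((1, "hello"))
    print(results.get()) # (1, 'HELLO')
    tasks.put(None) # Sentinel shuts the worker down
    p.join()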
 
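On ZeroGPU Spaces, the @spaces.GPU(duration=60) decorator attaches a GPU to the process only while the decorated function runs, which is why the Predictor is now constructed inside predict_full instead of at module scope. A minimal sketch of that usage (assuming the standard spaces package behavior; the generate function here is illustrative):

import spaces
import torch

@spaces.GPU(duration=60) # GPU is available only while this call runs
def generate(prompt: str) -> str:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Heavy CUDA work belongs here; CUDA state created at module
    # scope may not see the GPU on ZeroGPU Spaces.
    return f"ran on {device}: {prompt}"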
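In daemon mode, predict preprocesses the Gradio melody input, a (sample_rate, numpy_array) tuple, into a mono tensor at the model's sample rate before calling generate_with_chroma. The same steps as a standalone helper (a sketch; prepare_melody is a hypothetical name, and convert_audio is assumed to come from audiocraft.data.audio_utils):

import torch
from audiocraft.data.audio_utils import convert_audio

def prepare_melody(melody, duration, target_sr, target_ac=1, device='cpu'):
    """Turn a Gradio (sr, np.ndarray) tuple into a (channels, time) tensor at target_sr."""
    sr, data = melody
    wav = torch.from_numpy(data).to(device).float().t() # (time, ch) -> (ch, time)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0) # ensure a channel dimension
    wav = wav[..., :int(sr * duration)] # trim to the requested duration
    return convert_audio(wav, sr, target_sr, target_ac) # resample and downmix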
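The rearrange pattern '(s b) c t -> b (s c) t' used after MultiBandDiffusion decoding turns two mono batches stacked along the batch axis (all left channels first, then all right channels) back into a batch of stereo clips. A small shape demo:

import torch
from einops import rearrange

b, t = 3, 8
left = torch.zeros(b, 1, t) # mono left channels
right = torch.ones(b, 1, t) # mono right channels
stacked = torch.cat([left, right]) # (2*b, 1, t), as after torch.cat([left, right])
stereo = rearrange(stacked, '(s b) c t -> b (s c) t', s=2)
print(stereo.shape) # torch.Size([3, 2, 8]): a batch of stereo clips
print(stereo[0, :, 0]) # tensor([0., 1.]): channel 0 is left, channel 1 is right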
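With this change a caller can drive both modes the same way: predict returns a handle (an integer task id in worker mode, a (task_id, result) tuple in daemon mode) and get_result resolves either one. A usage sketch (the prompt and generation parameters are illustrative):

predictor = Predictor('facebook/musicgen-melody')
handle = predictor.predict(
    text="upbeat electronic track",
    melody=None,
    duration=10,
    use_diffusion=False,
    top_k=250, top_p=0.0, temperature=1.0, cfg_coef=3.0,
)
wav, diffusion_wav = predictor.get_result(handle) # blocks in worker mode
predictor.shutdown() # no-op in daemon mode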