quazim committed
Commit f94241a · 1 Parent(s): 8eded56
Files changed (2)
  1. app.py +209 -13
  2. requirements.txt +2 -1
app.py CHANGED
@@ -6,12 +6,20 @@ import random
 import os
 import tempfile
 import soundfile as sf
+import time

 os.environ['ELASTIC_LOG_LEVEL'] = 'DEBUG'
 from transformers import AutoProcessor, pipeline
 from elastic_models.transformers import MusicgenForConditionalGeneration


+MODEL_CONFIG = {
+    'cost_per_hour': 1.8,  # $1.8 per hour
+}
+
+original_time_cache = {}
+
+
 def set_seed(seed: int = 42):
     random.seed(seed)
     np.random.seed(seed)
@@ -23,7 +31,6 @@ def set_seed(seed: int = 42):


 def cleanup_gpu():
-    """Clean up GPU memory to avoid TensorRT conflicts."""
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
@@ -31,7 +38,6 @@ def cleanup_gpu():


 def cleanup_temp_files():
-    """Clean up old temporary audio files."""
     import glob
     import time
     temp_dir = tempfile.gettempdir()
@@ -47,6 +53,8 @@ def cleanup_temp_files():

 _generator = None
 _processor = None
+_original_generator = None
+_original_processor = None


 def load_model():
@@ -88,6 +96,44 @@ def load_model():
     return _generator, _processor


+def load_original_model():
+    global _original_generator, _original_processor
+
+    if _original_generator is None:
+        print("[ORIGINAL MODEL] Starting original model initialization...")
+        cleanup_gpu()
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"[ORIGINAL MODEL] Using device: {device}")
+
+        print("[ORIGINAL MODEL] Loading processor...")
+        _original_processor = AutoProcessor.from_pretrained(
+            "facebook/musicgen-large"
+        )
+        from transformers import MusicgenForConditionalGeneration as HFMusicgenForConditionalGeneration
+
+        print("[ORIGINAL MODEL] Loading original model...")
+        model = HFMusicgenForConditionalGeneration.from_pretrained(
+            "facebook/musicgen-large",
+            torch_dtype=torch.float16,
+        )
+        model = model.to(device)  # from_pretrained() has no `device` kwarg; move the model explicitly
+
+        model.eval()
+
+        print("[ORIGINAL MODEL] Creating pipeline...")
+        _original_generator = pipeline(
+            task="text-to-audio",
+            model=model,
+            tokenizer=_original_processor.tokenizer,
+            device=device,
+        )
+
+        print("[ORIGINAL MODEL] Original model initialization completed successfully")
+
+    return _original_generator, _original_processor
+
+
 def calculate_max_tokens(duration_seconds):
     token_rate = 50
     max_new_tokens = int(duration_seconds * token_rate)
@@ -160,9 +206,9 @@ def generate_music(text_prompt, duration=10, guidance_scale=3.0):

         max_val = np.max(np.abs(audio_data))
         if max_val > 0:
-            audio_data = audio_data / max_val * 0.95  # Scale to 95% to avoid clipping
+            audio_data = audio_data / max_val * 0.95

-        audio_data = (audio_data * 32767).astype(np.int16)  ###
+        audio_data = (audio_data * 32767).astype(np.int16)

         print(f"[GENERATION] Final audio shape: {audio_data.shape}")
         print(f"[GENERATION] Audio range: [{np.min(audio_data)}, {np.max(audio_data)}]")
@@ -180,6 +226,7 @@ def generate_music(text_prompt, duration=10, guidance_scale=3.0):
             print(f"[GENERATION] Audio saved to: {temp_path}")
             print(f"[GENERATION] File size: {file_size} bytes")

+            # Try returning numpy format instead
             print(f"[GENERATION] Returning numpy tuple: ({sample_rate}, audio_array)")
             return (sample_rate, audio_data)
         else:
@@ -192,9 +239,150 @@ def generate_music(text_prompt, duration=10, guidance_scale=3.0):
         return None


+def calculate_generation_cost(generation_time_seconds, mode='S'):
+    hours = generation_time_seconds / 3600
+    cost_per_hour = MODEL_CONFIG['cost_per_hour']
+    return hours * cost_per_hour
+
+
+def calculate_cost_savings(compressed_time, original_time):
+    compressed_cost = calculate_generation_cost(compressed_time, 'S')
+    original_cost = calculate_generation_cost(original_time, 'original')
+    savings = original_cost - compressed_cost
+    savings_percent = (savings / original_cost * 100) if original_cost > 0 else 0
+    return {
+        'compressed_cost': compressed_cost,
+        'original_cost': original_cost,
+        'savings': savings,
+        'savings_percent': savings_percent
+    }
+
+
+def get_cache_key(prompt, duration, guidance_scale):
+    return f"{hash(prompt)}_{duration}_{guidance_scale}"
+
+
+def generate_music_batch(text_prompt, duration=10, guidance_scale=3.0, model_mode="compressed"):
+    try:
+        cache_key = get_cache_key(text_prompt, duration, guidance_scale)
+
+        generator, processor = load_model()
+        model_name = "Compressed (S)"
+
+        print(f"[GENERATION] Starting batch generation using {model_name} model...")
+        print(f"[GENERATION] Prompt: '{text_prompt}'")
+        print(f"[GENERATION] Duration: {duration}s")
+        print(f"[GENERATION] Guidance scale: {guidance_scale}")
+
+        cleanup_gpu()
+        set_seed(42)
+        print(f"[GENERATION] Using seed: 42")
+
+        max_new_tokens = calculate_max_tokens(duration)
+
+        generation_params = {
+            'do_sample': True,
+            'guidance_scale': guidance_scale,
+            'max_new_tokens': max_new_tokens,
+            'min_new_tokens': max_new_tokens,
+            'cache_implementation': 'paged',
+        }
+
+        prompts = [text_prompt] * 4
+        start_time = time.time()
+        outputs = generator(
+            prompts,
+            batch_size=4,
+            generate_kwargs=generation_params
+        )
+        generation_time = time.time() - start_time
+
+        print(f"[GENERATION] Batch generation completed in {generation_time:.2f}s")
+
+        audio_variants = []
+        sample_rate = outputs[0]['sampling_rate']
+
+        for i, output in enumerate(outputs):
+            audio_data = output['audio']
+
+            print(f"[GENERATION] Processing variant {i+1} audio shape: {audio_data.shape}")
+
+            if hasattr(audio_data, 'cpu'):
+                audio_data = audio_data.cpu().numpy()
+
+            if len(audio_data.shape) == 3:
+                audio_data = audio_data[0]
+
+            if len(audio_data.shape) == 2:
+                if audio_data.shape[0] < audio_data.shape[1]:
+                    audio_data = audio_data.T
+                if audio_data.shape[1] > 1:
+                    audio_data = audio_data[:, 0]
+            else:
+                audio_data = audio_data.flatten()
+
+            audio_data = audio_data.flatten()
+
+            max_val = np.max(np.abs(audio_data))
+            if max_val > 0:
+                audio_data = audio_data / max_val * 0.95
+
+            audio_data = (audio_data * 32767).astype(np.int16)
+            audio_variants.append((sample_rate, audio_data))
+
+            print(f"[GENERATION] Variant {i+1} final shape: {audio_data.shape}")
+
+        comparison_message = ""
+
+        if cache_key in original_time_cache:
+            original_time = original_time_cache[cache_key]
+            cost_info = calculate_cost_savings(generation_time, original_time)
+
+            comparison_message = f"💰 Cost Savings: ${cost_info['savings']:.4f} ({cost_info['savings_percent']:.1f}%) - Compressed: ${cost_info['compressed_cost']:.4f} vs Original: ${cost_info['original_cost']:.4f}"
+            print(f"[COST] Savings: ${cost_info['savings']:.4f} ({cost_info['savings_percent']:.1f}%)")
+        else:
+            try:
+                print(f"[TIMING] Measuring original model speed for comparison...")
+                original_generator, original_processor = load_original_model()
+
+                original_start = time.time()
+                original_outputs = original_generator(
+                    prompts,
+                    batch_size=4,
+                    generate_kwargs=generation_params
+                )
+                original_time = time.time() - original_start
+
+                original_time_cache[cache_key] = original_time
+
+                cost_info = calculate_cost_savings(generation_time, original_time)
+                comparison_message = f"💰 Cost Savings: ${cost_info['savings']:.4f} ({cost_info['savings_percent']:.1f}%) - Compressed: ${cost_info['compressed_cost']:.4f} vs Original: ${cost_info['original_cost']:.4f}"
+                print(f"[COST] First comparison - Savings: ${cost_info['savings']:.4f} ({cost_info['savings_percent']:.1f}%)")
+                print(f"[TIMING] Original: {original_time:.2f}s, Compressed: {generation_time:.2f}s")
+
+                del original_generator, original_processor
+                cleanup_gpu()
+                print(f"[CLEANUP] Original model cleaned up after timing measurement")
+
+            except Exception as e:
+                print(f"[WARNING] Could not measure original timing: {e}")
+                compressed_cost = calculate_generation_cost(generation_time, 'S')
+                comparison_message = f"💸 Compressed Cost: ${compressed_cost:.4f} (could not compare with original)"
+
+        generation_info = f"✅ Generated 4 variants in {generation_time:.2f}s\n{comparison_message}"
+
+        return audio_variants[0], audio_variants[1], audio_variants[2], audio_variants[3], generation_info
+
+    except Exception as e:
+        print(f"[ERROR] Batch generation failed: {str(e)}")
+        cleanup_gpu()
+        error_msg = f"❌ Generation failed: {str(e)}"
+        return None, None, None, None, error_msg
+
+
 with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
     gr.Markdown("# 🎵 MusicGen Large Music Generator")
-    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model with elastic compression.")
+    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model accelerated by TheStage for 2.3x faster performance")

     with gr.Row():
         with gr.Column():
@@ -204,7 +392,7 @@ with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
                 lines=3,
                 value="A groovy funk bassline with a tight drum beat"
             )
-
+
             with gr.Row():
                 duration = gr.Slider(
                     minimum=5,
@@ -222,13 +410,18 @@ with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
                     info="Higher values follow prompt more closely"
                 )

-            generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")
+            generate_btn = gr.Button("🎵 Generate 4 Music Variants", variant="primary", size="lg")

         with gr.Column():
-            audio_output = gr.Audio(
-                label="Generated Music",
-                type="numpy"
-            )
+            generation_info = gr.Markdown("Ready to generate music variants with cost comparison vs original model")
+
+            with gr.Row():
+                audio_output1 = gr.Audio(label="Variant 1", type="numpy")
+                audio_output2 = gr.Audio(label="Variant 2", type="numpy")
+
+            with gr.Row():
+                audio_output3 = gr.Audio(label="Variant 3", type="numpy")
+                audio_output4 = gr.Audio(label="Variant 4", type="numpy")

         with gr.Accordion("Tips", open=False):
             gr.Markdown("""
@@ -238,10 +431,13 @@ with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
        - Duration is limited to 30 seconds for faster generation
        """)

+    def generate_simple(text_prompt, duration, guidance_scale):
+        return generate_music_batch(text_prompt, duration, guidance_scale, "compressed")
+
     generate_btn.click(
-        fn=generate_music,
+        fn=generate_simple,
         inputs=[text_input, duration, guidance_scale],
-        outputs=audio_output,
+        outputs=[audio_output1, audio_output2, audio_output3, audio_output4, generation_info],
         show_progress=True
     )

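A note on the cost readout this commit adds: `calculate_generation_cost` is plain rate arithmetic (dollars = seconds / 3600 * hourly price, with `MODEL_CONFIG['cost_per_hour']` set to $1.8), and `calculate_cost_savings` just compares the compressed and original runs. A self-contained sketch of that math, using made-up 12 s and 28 s timings rather than measurements from this Space:

    COST_PER_HOUR = 1.8  # same rate as MODEL_CONFIG['cost_per_hour']

    def cost(seconds):
        # dollars per generation = hours spent generating * hourly GPU price
        return seconds / 3600 * COST_PER_HOUR

    compressed_s, original_s = 12.0, 28.0  # hypothetical timings for illustration
    savings = cost(original_s) - cost(compressed_s)
    print(f"compressed ${cost(compressed_s):.4f} vs original ${cost(original_s):.4f}, "
          f"save ${savings:.4f} ({savings / cost(original_s) * 100:.1f}%)")
    # compressed $0.0060 vs original $0.0140, save $0.0080 (57.1%)
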
requirements.txt CHANGED
@@ -4,6 +4,7 @@

 torch
 thestage
-elastic_models[nvidia]
+# elastic_models[nvidia]
 scipy
 transformers
+soundfile
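
Pinning `soundfile` matches the `import soundfile as sf` in app.py, which backs the temporary WAV export the generation logs reference. A minimal sketch of that save path, with a silent clip standing in for model output:

    import numpy as np
    import soundfile as sf
    import tempfile

    sample_rate = 32000  # MusicGen outputs 32 kHz audio
    audio = np.zeros(sample_rate, dtype=np.int16)  # 1 s of silence as stand-in data

    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    sf.write(tmp.name, audio, sample_rate)  # WAV format inferred from the suffix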