quazim committed on
Commit 63eee27
·
1 Parent(s): e8054e6
Files changed (1)
  1. app.py +258 -56
app.py CHANGED
@@ -6,24 +6,38 @@ import random
 import os
 import tempfile
 import soundfile as sf
+import time
 
 os.environ['ELASTIC_LOG_LEVEL'] = 'DEBUG'
 from transformers import AutoProcessor, pipeline
 from elastic_models.transformers import MusicgenForConditionalGeneration
 
-
-def set_seed(seed: int = 42):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
+MODEL_CONFIG = {
+    'cost_per_hour': 1.8,  # $1.8 per hour on L40S
+    'cost_savings_1000h': {
+        'savings_dollars': 8.4,   # $8.4 saved per 1000 hours
+        'savings_percent': 74.9,  # 74.9% savings
+        'compressed_cost': 2.8,   # $2.8 for compressed
+        'original_cost': 11.3,    # $11.3 for original
+    },
+    'batch_mode': True,
+    'batch_size': 2  # Number of variants to generate (2, 4, 6, etc.)
+}
+
+original_time_cache = {"original_time": 22.57}
+
+
+# def set_seed(seed: int = 42):
+#     random.seed(seed)
+#     np.random.seed(seed)
+#     torch.manual_seed(seed)
+#     torch.cuda.manual_seed(seed)
+#     torch.cuda.manual_seed_all(seed)
+#     torch.backends.cudnn.deterministic = True
+#     torch.backends.cudnn.benchmark = False
 
 
 def cleanup_gpu():
-    """Clean up GPU memory to avoid TensorRT conflicts."""
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
@@ -31,7 +45,6 @@ def cleanup_gpu():
 
 
 def cleanup_temp_files():
-    """Clean up old temporary audio files."""
    import glob
    import time
    temp_dir = tempfile.gettempdir()
@@ -47,6 +60,8 @@ def cleanup_temp_files():
 
 _generator = None
 _processor = None
+_original_generator = None
+_original_processor = None
 
 
 def load_model():
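Both the compressed and the original pipelines are cached in module-level globals, so the expensive model load runs once per process and every later request reuses the warm pipeline. A minimal sketch of that lazy-singleton pattern (standalone illustration; `expensive_load` is a hypothetical stand-in for the real `AutoProcessor`/`pipeline(...)` construction):

```python
# Lazy-singleton loading, as used for _generator / _original_generator:
# the first call pays the full load cost, later calls return the cache.
_pipeline = None

def expensive_load():
    print("loading model weights ...")  # stands in for pipeline(...) setup
    return object()

def get_pipeline():
    global _pipeline
    if _pipeline is None:
        _pipeline = expensive_load()
    return _pipeline

assert get_pipeline() is get_pipeline()  # second call reuses the cached object
```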
@@ -88,6 +103,43 @@ def load_model():
    return _generator, _processor
 
 
+def load_original_model():
+    global _original_generator, _original_processor
+
+    if _original_generator is None:
+        print("[ORIGINAL MODEL] Starting original model initialization...")
+        cleanup_gpu()
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"[ORIGINAL MODEL] Using device: {device}")
+
+        print("[ORIGINAL MODEL] Loading processor...")
+        _original_processor = AutoProcessor.from_pretrained(
+            "facebook/musicgen-large"
+        )
+        from transformers import MusicgenForConditionalGeneration as HFMusicgenForConditionalGeneration
+
+        print("[ORIGINAL MODEL] Loading original model...")
+        model = HFMusicgenForConditionalGeneration.from_pretrained(
+            "facebook/musicgen-large",
+            torch_dtype=torch.float16,
+        ).to(device)
+
+        model.eval()
+
+        print("[ORIGINAL MODEL] Creating pipeline...")
+        _original_generator = pipeline(
+            task="text-to-audio",
+            model=model,
+            tokenizer=_original_processor.tokenizer,
+            device=device,
+        )
+
+        print("[ORIGINAL MODEL] Original model initialization completed successfully")
+
+    return _original_generator, _original_processor
+
+
 def calculate_max_tokens(duration_seconds):
    token_rate = 50
    max_new_tokens = int(duration_seconds * token_rate)
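`calculate_max_tokens` converts the requested duration into a decoder token budget. MusicGen's EnCodec tokenizer runs at roughly 50 frames per second of audio, which is where `token_rate = 50` comes from; the batch path below also passes `min_new_tokens = max_new_tokens`, pinning the clip length. A quick sanity check of the mapping (standalone sketch):

```python
# Duration -> token budget used by calculate_max_tokens above.
TOKEN_RATE = 50  # ~50 EnCodec frames (tokens per codebook) per second

def calculate_max_tokens(duration_seconds):
    return int(duration_seconds * TOKEN_RATE)

for duration in (5, 10, 30):  # slider min, default, max
    print(f"{duration:>2}s -> {calculate_max_tokens(duration)} tokens")
# 5s -> 250, 10s -> 500, 30s -> 1500
```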
@@ -107,7 +159,7 @@ def generate_music(text_prompt, duration=10, guidance_scale=3.0):
    cleanup_gpu()
 
    import time
-    set_seed(42)
+    # set_seed(42)
    print(f"[GENERATION] Using seed: {42}")
 
    max_new_tokens = calculate_max_tokens(duration)
@@ -160,9 +212,9 @@ def generate_music(text_prompt, duration=10, guidance_scale=3.0):
 
    max_val = np.max(np.abs(audio_data))
    if max_val > 0:
-        audio_data = audio_data / max_val * 0.95  # Scale to 95% to avoid clipping
+        audio_data = audio_data / max_val * 0.95
 
-    audio_data = (audio_data * 32767).astype(np.int16)  ###
+    audio_data = (audio_data * 32767).astype(np.int16)
 
    print(f"[GENERATION] Final audio shape: {audio_data.shape}")
    print(f"[GENERATION] Audio range: [{np.min(audio_data)}, {np.max(audio_data)}]")
@@ -180,6 +232,7 @@ def generate_music(text_prompt, duration=10, guidance_scale=3.0):
        print(f"[GENERATION] Audio saved to: {temp_path}")
        print(f"[GENERATION] File size: {file_size} bytes")
 
+        # Try returning numpy format instead
        print(f"[GENERATION] Returning numpy tuple: ({sample_rate}, audio_array)")
        return (sample_rate, audio_data)
    else:
@@ -192,56 +245,205 @@ def generate_music(text_prompt, duration=10, guidance_scale=3.0):
        return None
 
 
+def calculate_generation_cost(generation_time_seconds, mode='S'):
+    hours = generation_time_seconds / 3600
+    cost_per_hour = MODEL_CONFIG['cost_per_hour']
+    return hours * cost_per_hour
+
+
+def calculate_cost_savings(compressed_time, original_time):
+    compressed_cost = calculate_generation_cost(compressed_time, 'S')
+    original_cost = calculate_generation_cost(original_time, 'original')
+    savings = original_cost - compressed_cost
+    savings_percent = (savings / original_cost * 100) if original_cost > 0 else 0
+    return {
+        'compressed_cost': compressed_cost,
+        'original_cost': original_cost,
+        'savings': savings,
+        'savings_percent': savings_percent
+    }
+
+
+def get_fixed_savings_message():
+    config = MODEL_CONFIG['cost_savings_1000h']
+    return f"💰 **Cost Savings for generation batch size 4 on L40S (1000h)**: ${config['savings_dollars']:.1f}" \
+           f" ({config['savings_percent']:.1f}%) - Compressed: ${config['compressed_cost']:.1f} " \
+           f"vs Original: ${config['original_cost']:.1f}"
+
+
+def get_cache_key(prompt, duration, guidance_scale):
+    return f"{hash(prompt)}_{duration}_{guidance_scale}"
+
+
+def generate_music_batch(text_prompt, duration=10, guidance_scale=3.0, model_mode="compressed"):
+    try:
+        generator, processor = load_model()
+        model_name = "Compressed (S)"
+
+        print(f"[GENERATION] Starting generation using {model_name} model...")
+        print(f"[GENERATION] Prompt: '{text_prompt}'")
+        print(f"[GENERATION] Duration: {duration}s")
+        print(f"[GENERATION] Guidance scale: {guidance_scale}")
+        print(f"[GENERATION] Batch mode: {MODEL_CONFIG['batch_mode']}")
+        print(f"[GENERATION] Batch size: {MODEL_CONFIG['batch_size']}")
+
+        cleanup_gpu()
+        # set_seed(42)
+        print(f"[GENERATION] Using seed: {42}")
+
+        max_new_tokens = calculate_max_tokens(duration)
+
+        generation_params = {
+            'do_sample': True,
+            'guidance_scale': guidance_scale,
+            'max_new_tokens': max_new_tokens,
+            'min_new_tokens': max_new_tokens,
+            'cache_implementation': 'paged',
+        }
+
+        batch_size = MODEL_CONFIG['batch_size'] if MODEL_CONFIG['batch_mode'] else 1
+        prompts = [text_prompt] * batch_size
+
+        start_time = time.time()
+        outputs = generator(
+            prompts,
+            batch_size=batch_size,
+            generate_kwargs=generation_params
+        )
+        generation_time = time.time() - start_time
+
+        print(f"[GENERATION] Generation completed in {generation_time:.2f}s")
+
+        audio_variants = []
+        sample_rate = outputs[0]['sampling_rate']
+
+        for i, output in enumerate(outputs):
+            audio_data = output['audio']
+
+            print(f"[GENERATION] Processing variant {i + 1} audio shape: {audio_data.shape}")
+
+            if hasattr(audio_data, 'cpu'):
+                audio_data = audio_data.cpu().numpy()
+
+            if len(audio_data.shape) == 3:
+                audio_data = audio_data[0]
+
+            if len(audio_data.shape) == 2:
+                if audio_data.shape[0] < audio_data.shape[1]:
+                    audio_data = audio_data.T
+                if audio_data.shape[1] > 1:
+                    audio_data = audio_data[:, 0]
+                else:
+                    audio_data = audio_data.flatten()
+
+            audio_data = audio_data.flatten()
+
+            max_val = np.max(np.abs(audio_data))
+            if max_val > 0:
+                audio_data = audio_data / max_val * 0.95
+
+            audio_data = (audio_data * 32767).astype(np.int16)
+            audio_variants.append((sample_rate, audio_data))
+
+            print(f"[GENERATION] Variant {i + 1} final shape: {audio_data.shape}")
+
+        while len(audio_variants) < 6:
+            audio_variants.append(None)
+
+        savings_message = get_fixed_savings_message()
+
+        variants_text = "audio"
+        generation_info = f"✅ Generated {variants_text} in {generation_time:.2f}s\n{savings_message}"
+
+        return audio_variants[0], audio_variants[1], audio_variants[2], audio_variants[3], audio_variants[4], audio_variants[5], generation_info
+
+    except Exception as e:
+        print(f"[ERROR] Batch generation failed: {str(e)}")
+        cleanup_gpu()
+        error_msg = f"❌ Generation failed: {str(e)}"
+        return None, None, None, None, None, None, error_msg
+
+
 with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator")
-    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model with elastic compression.")
-
-    with gr.Row():
-        with gr.Column():
-            text_input = gr.Textbox(
-                label="Music Description",
-                placeholder="Enter a description of the music you want to generate",
-                lines=3,
-                value="A groovy funk bassline with a tight drum beat"
+
+    gr.Markdown(
+        f"Generate music from text descriptions using Facebook's MusicGen "
+        f"Large model accelerated by TheStage for 2.3x faster performance.")
+
+    with gr.Column():
+        text_input = gr.Textbox(
+            label="Music Description",
+            placeholder="Enter a description of the music you want to generate",
+            lines=3,
+            value="A groovy funk bassline with a tight drum beat"
+        )
+
+        with gr.Row():
+            duration = gr.Slider(
+                minimum=5,
+                maximum=30,
+                value=10,
+                step=1,
+                label="Duration (seconds)"
            )
-
-            with gr.Row():
-                duration = gr.Slider(
-                    minimum=5,
-                    maximum=30,
-                    value=10,
-                    step=1,
-                    label="Duration (seconds)"
-                )
-                guidance_scale = gr.Slider(
-                    minimum=1.0,
-                    maximum=10.0,
-                    value=3.0,
-                    step=0.5,
-                    label="Guidance Scale",
-                    info="Higher values follow prompt more closely"
-                )
-
-            generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")
-
-        with gr.Column():
-            audio_output = gr.Audio(
-                label="Generated Music",
-                type="numpy"
+            guidance_scale = gr.Slider(
+                minimum=1.0,
+                maximum=10.0,
+                value=3.0,
+                step=0.5,
+                label="Guidance Scale",
+                info="Higher values follow prompt more closely"
            )
-
-    with gr.Accordion("Tips", open=False):
-        gr.Markdown("""
-        - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
-        - Higher guidance scale = follows prompt more closely
-        - Lower guidance scale = more creative/varied results
-        - Duration is limited to 30 seconds for faster generation
-        """)
+
+        generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")
+
+        generation_info = gr.Markdown("Ready to generate music with elastic acceleration")
+
+        audio_section_title = "### Generated Music"
+        gr.Markdown(audio_section_title)
+
+        actual_outputs = MODEL_CONFIG['batch_size'] if MODEL_CONFIG['batch_mode'] else 1
+
+        audio_outputs = []
+
+        with gr.Row():
+            audio_output1 = gr.Audio(label="Variant 1", type="numpy", visible=actual_outputs >= 1)
+            audio_output2 = gr.Audio(label="Variant 2", type="numpy", visible=actual_outputs >= 2)
+            audio_outputs.extend([audio_output1, audio_output2])
+
+        with gr.Row():
+            audio_output3 = gr.Audio(label="Variant 3", type="numpy", visible=actual_outputs >= 3)
+            audio_output4 = gr.Audio(label="Variant 4", type="numpy", visible=actual_outputs >= 4)
+            audio_outputs.extend([audio_output3, audio_output4])
+
+        with gr.Row():
+            audio_output5 = gr.Audio(label="Variant 5", type="numpy", visible=actual_outputs >= 5)
+            audio_output6 = gr.Audio(label="Variant 6", type="numpy", visible=actual_outputs >= 6)
+            audio_outputs.extend([audio_output5, audio_output6])
+
+        savings_banner = gr.Markdown(get_fixed_savings_message())
+
+        with gr.Accordion("💡 Tips & Information", open=False):
+            gr.Markdown(f"""
+            **Generation Tips:**
+            - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
+            - Higher guidance scale = follows prompt more closely
+            - Lower guidance scale = more creative/varied results
+            - Duration is limited to 30 seconds for faster generation
+
+            **Performance:**
+            - Accelerated by TheStage elastic compression
+            - L40S GPU pricing: $1.8/hour
+            """)
+
+    def generate_simple(text_prompt, duration, guidance_scale):
+        return generate_music_batch(text_prompt, duration, guidance_scale, "compressed")
 
    generate_btn.click(
-        fn=generate_music,
+        fn=generate_simple,
        inputs=[text_input, duration, guidance_scale],
-        outputs=audio_output,
+        outputs=[audio_output1, audio_output2, audio_output3, audio_output4, audio_output5, audio_output6, generation_info],
        show_progress=True
    )
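For reference, the cost helpers added in this commit are straight time-based arithmetic: seconds are converted to hours and multiplied by the flat $1.8/h L40S rate, so `calculate_cost_savings` reduces to comparing wall-clock times (the `mode` argument is currently unused). A worked example using the cached 22.57 s original time and a hypothetical compressed time of 22.57 / 2.3 ≈ 9.81 s, matching the advertised 2.3x speedup:

```python
# Worked example of the added cost model (standalone sketch; the 9.81 s
# compressed time is hypothetical, derived from the claimed 2.3x speedup).
COST_PER_HOUR = 1.8  # L40S rate from MODEL_CONFIG

def generation_cost(seconds):
    return seconds / 3600 * COST_PER_HOUR

original = generation_cost(22.57)    # ~$0.01129 per generation
compressed = generation_cost(9.81)   # ~$0.00491 per generation
savings = original - compressed
print(f"${savings:.5f} saved per run ({savings / original:.1%})")
# -> $0.00638 saved per run (56.5%); the banner's 74.9% figure comes from
#    the hard-coded MODEL_CONFIG['cost_savings_1000h'] measurements instead.
```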