quazim committed
Commit 3d157c8 · 1 Parent(s): 341afaa
Files changed (1)
  app.py +53 -25
app.py CHANGED
@@ -4,10 +4,14 @@ import gc
 import numpy as np
 import random
 import os
+import tempfile
+import soundfile as sf
+
 os.environ['ELASTIC_LOG_LEVEL'] = 'DEBUG'
 from transformers import AutoProcessor, pipeline
 from elastic_models.transformers import MusicgenForConditionalGeneration
 
+
 def set_seed(seed: int = 42):
     random.seed(seed)
     np.random.seed(seed)
@@ -17,6 +21,7 @@ def set_seed(seed: int = 42):
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.benchmark = False
 
+
 def cleanup_gpu():
     """Clean up GPU memory to avoid TensorRT conflicts."""
     if torch.cuda.is_available():
@@ -24,17 +29,33 @@ def cleanup_gpu():
         torch.cuda.synchronize()
         gc.collect()
 
+
+def cleanup_temp_files():
+    """Clean up old temporary audio files."""
+    import glob
+    import time
+    temp_dir = tempfile.gettempdir()
+    cutoff_time = time.time() - 3600
+    for temp_file in glob.glob(os.path.join(temp_dir, "tmp*.wav")):
+        try:
+            if os.path.getctime(temp_file) < cutoff_time:
+                os.remove(temp_file)
+                print(f"[CLEANUP] Removed old temp file: {temp_file}")
+        except OSError:
+            pass
+
+
 _generator = None
 _processor = None
 
+
 def load_model():
-    """Load the musicgen model and processor using pipeline approach"""
     global _generator, _processor
-
+
     if _generator is None:
         print("[MODEL] Starting model initialization...")
         cleanup_gpu()
-
+
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         print(f"[MODEL] Using device: {device}")
 
@@ -42,7 +63,7 @@ def load_model():
         _processor = AutoProcessor.from_pretrained(
             "facebook/musicgen-large"
         )
-
+
         print("[MODEL] Loading model...")
         model = MusicgenForConditionalGeneration.from_pretrained(
             "facebook/musicgen-large",
@@ -51,9 +72,9 @@ def load_model():
             mode="S",
             __paged=True,
         )
-
+
         model.eval()
-
+
         print("[MODEL] Creating pipeline...")
         _generator = pipeline(
             task="text-to-audio",
@@ -61,34 +82,36 @@ def load_model():
             tokenizer=_processor.tokenizer,
             device=device,
         )
-
+
         print("[MODEL] Model initialization completed successfully")
 
     return _generator, _processor
 
+
 def calculate_max_tokens(duration_seconds):
     token_rate = 50
     max_new_tokens = int(duration_seconds * token_rate)
     print(f"[MODEL] Duration: {duration_seconds}s -> Tokens: {max_new_tokens} (rate: {token_rate})")
     return max_new_tokens
 
+
 def generate_music(text_prompt, duration=10, guidance_scale=3.0):
     try:
         generator, processor = load_model()
-
+
         print(f"[GENERATION] Starting generation...")
        print(f"[GENERATION] Prompt: '{text_prompt}'")
         print(f"[GENERATION] Duration: {duration}s")
         print(f"[GENERATION] Guidance scale: {guidance_scale}")
-
+
         cleanup_gpu()
-
+
         import time
         set_seed(42)
         print(f"[GENERATION] Using seed: {42}")
-
+
         max_new_tokens = calculate_max_tokens(duration)
-
+
         generation_params = {
             'do_sample': True,
             'guidance_scale': guidance_scale,
@@ -96,39 +119,43 @@ def generate_music(text_prompt, duration=10, guidance_scale=3.0):
             'min_new_tokens': max_new_tokens,
             'cache_implementation': 'paged',
         }
-
+
         prompts = [text_prompt]
         outputs = generator(
             prompts,
             batch_size=1,
             generate_kwargs=generation_params
         )
-
+
         print(f"[GENERATION] Generation completed successfully")
-
+
         output = outputs[0]
         audio_data = output['audio']
         sample_rate = output['sampling_rate']
-
+
         print(f"[GENERATION] Audio shape: {audio_data.shape}")
         print(f"[GENERATION] Sample rate: {sample_rate}")
-
+
         if len(audio_data.shape) > 1:
-            # If stereo or multi-channel, take first channel
             audio_data = audio_data[0] if audio_data.shape[0] < audio_data.shape[1] else audio_data[:, 0]
-
+
         audio_data = audio_data.flatten()
-
+
         max_val = np.max(np.abs(audio_data))
         if max_val > 0:
             audio_data = audio_data / max_val * 0.95  # Scale to 95% to avoid clipping
-
+
         audio_data = audio_data.astype(np.float32)
-
+
         print(f"[GENERATION] Final audio shape: {audio_data.shape}")
         print(f"[GENERATION] Audio range: [{np.min(audio_data):.3f}, {np.max(audio_data):.3f}]")
 
-        return sample_rate, audio_data
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
+            sf.write(tmp_file.name, audio_data, sample_rate)
+            temp_path = tmp_file.name
+
+        print(f"[GENERATION] Audio saved to: {temp_path}")
+        return temp_path
 
     except Exception as e:
         print(f"[ERROR] Generation failed: {str(e)}")
@@ -139,7 +166,7 @@ def generate_music(text_prompt, duration=10, guidance_scale=3.0):
 with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
     gr.Markdown("# 🎵 MusicGen Large Music Generator")
     gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model with elastic compression.")
-
+
     with gr.Row():
         with gr.Column():
             text_input = gr.Textbox(
@@ -175,7 +202,7 @@ with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
                 format="wav",
                 interactive=False
             )
-
+
     with gr.Accordion("Tips", open=False):
         gr.Markdown("""
        - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
@@ -219,4 +246,5 @@ with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
         """)
 
 if __name__ == "__main__":
+    cleanup_temp_files()
     demo.launch()
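
The hunks above change generate_music's return value from a (sample_rate, audio_data) tuple to the path of a temporary .wav written with soundfile; the Gradio event wiring that consumes that path is not visible in this diff. Below is a minimal, self-contained sketch of the same pattern — the stub generator and component names are hypothetical illustrations, not code from app.py:

import tempfile

import gradio as gr
import numpy as np
import soundfile as sf


def fake_generate(prompt):
    # Stand-in for generate_music: a 2-second 440 Hz tone instead of MusicGen.
    sr = 32000
    t = np.linspace(0, 2, 2 * sr, endpoint=False)
    audio = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
    # Same pattern as the commit: write a temp .wav and return its path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        sf.write(f.name, audio, sr)
        return f.name


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    # type="filepath" lets gr.Audio play the returned path directly.
    out = gr.Audio(type="filepath", format="wav", interactive=False)
    gr.Button("Generate").click(fake_generate, inputs=prompt, outputs=out)

if __name__ == "__main__":
    demo.launch()

Because the temp files are created with delete=False and must outlive the request so Gradio can serve them, they accumulate on disk; that is presumably why the commit also adds cleanup_temp_files(), which prunes tmp*.wav files older than an hour at startup.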