ghostai1 committed on
Commit e05183d · verified · 1 Parent(s): 983e1af

Update app.py

Files changed (1)
  1. app.py +43 -269
app.py CHANGED
@@ -1,301 +1,75 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- GhostAI Music Generator — Zero-GPU & GPU friendly
- Python : 3.10
- Torch : 2.1 CPU wheels (or CUDA 11.8/12.1)
- Gradio : 5.31.0
- Last updated: 2025-05-29
- """

- import os
- import sys
- import gc
- import time
- import random
- import warnings
- import tempfile
-
- import psutil
- import numpy as np
- import torch
- import torchaudio
- import gradio as gr
  from pydub import AudioSegment
- from torch.cuda.amp import autocast
  from audiocraft.models import MusicGen
  from huggingface_hub import login

- # ----------------------------------------------------------------------
- # Compatibility shim (torch < 2.3)
- # ----------------------------------------------------------------------
- if not hasattr(torch, "get_default_device"):
-     torch.get_default_device = lambda: torch.device(
-         "cuda" if torch.cuda.is_available() else "cpu"
-     )
-
- # ----------------------------------------------------------------------
- # Silence warnings & CUDA fragmentation tuning
- # ----------------------------------------------------------------------
  warnings.filterwarnings("ignore")
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

- # ----------------------------------------------------------------------
- # Hugging Face authentication
- # ----------------------------------------------------------------------
  HF_TOKEN = os.getenv("HF_TOKEN")
  if not HF_TOKEN:
-     print("ERROR: environment variable HF_TOKEN not set.")
-     sys.exit(1)
- try:
-     login(HF_TOKEN)
- except Exception as e:
-     print(f"ERROR: Hugging Face login failed: {e}")
-     sys.exit(1)

- # ----------------------------------------------------------------------
- # Device setup & cleanup
- # ----------------------------------------------------------------------
  device = "cuda" if torch.cuda.is_available() else "cpu"
  print(f"Running on {device.upper()}")
- if device == "cuda":
-     print(f"GPU: {torch.cuda.get_device_name(0)}")
-
- def gpu_clean():
-     if device == "cuda":
-         torch.cuda.empty_cache()
-         gc.collect()
-         torch.cuda.ipc_collect()
-         torch.cuda.synchronize()

- gpu_clean()
-
- # ----------------------------------------------------------------------
- # Load MusicGen model (fixed checkpoint name)
- # ----------------------------------------------------------------------
- print("Loading MusicGen ‘medium’ checkpoint …")
  musicgen = MusicGen.get_pretrained("medium", device=device)
  musicgen.set_generation_params(duration=10, two_step_cfg=False)

- # ----------------------------------------------------------------------
- # Resource monitoring
- # ----------------------------------------------------------------------
- def log_resources(stage=""):
-     if stage:
-         print(f"--- {stage} ---")
-     if device == "cuda":
-         alloc = torch.cuda.memory_allocated() / 1024**3
-         resv = torch.cuda.memory_reserved() / 1024**3
-         print(f"GPU Mem | Alloc {alloc:.2f} GB Reserved {resv:.2f} GB")
-     print(f"CPU Mem | {psutil.virtual_memory().percent}% used")
-
- def vram_ok(threshold=3.5):
-     if device != "cuda":
-         return True
-     total = torch.cuda.get_device_properties(0).total_memory / 1024**3
-     free = total - torch.cuda.memory_allocated() / 1024**3
-     if free < threshold:
-         print(f"WARNING: Only {free:.2f} GB VRAM free (<{threshold} GB).")
-     return free >= threshold
-
- # ----------------------------------------------------------------------
- # Prompt builders
- # ----------------------------------------------------------------------
- def _make_prompt(base, bpm, drum, synth, steps, bass, gtr, def_bass, def_gtr, flow):
-     step_txt = f" with {steps}" if steps != "none" else flow.format(bpm=bpm)
-     drum_txt = f", {drum} drums" if drum != "none" else ""
-     synth_txt = f", {synth} accents" if synth != "none" else ""
-     bass_txt = f", {bass}" if bass != "none" else def_bass
-     gtr_txt = f", {gtr} guitar riffs" if gtr != "none" else def_gtr
-     return f"{base}{bass_txt}{gtr_txt}{drum_txt}{synth_txt}{step_txt} at {bpm} BPM."
-
- def set_red_hot_chili_peppers_prompt(bpm, drum, synth, steps, bass, gtr):
-     return _make_prompt(
-         "Instrumental funk rock", bpm, drum, synth, steps, bass, gtr,
-         ", groovy basslines", ", syncopated guitar riffs",
-         "{bpm} BPM funky flow" if bpm > 120 else "groovy rhythmic flow"
-     )
-
- # … include the other set_*_prompt functions exactly as before …
-
- # ----------------------------------------------------------------------
- # Audio post-processing
- # ----------------------------------------------------------------------
- def apply_eq(seg: AudioSegment):
-     return seg.low_pass_filter(8000).high_pass_filter(80)
-
- def apply_fade(seg: AudioSegment, fin=1000, fout=1000):
-     return seg.fade_in(fin).fade_out(fout)
-
- # ----------------------------------------------------------------------
- # Core generation
- # ----------------------------------------------------------------------
- def generate_music(
-     prompt, cfg, top_k, top_p, temp,
-     total_len, chunk_len, crossfade,
-     bpm, drum, synth, steps, bass, gtr
- ):
      if not prompt.strip():
-         return None, "⚠️ Prompt cannot be empty."
-     if not vram_ok():
-         return None, "⚠️ Insufficient VRAM."

-     total_len = int(total_len)
-     chunk_len = int(max(5, min(chunk_len, 15)))
-     n_chunks = max(1, total_len // chunk_len)
-     chunk_len = total_len / n_chunks
-     overlap = min(1.0, crossfade / 1000.0)
-     render_len = chunk_len + overlap
-     sr = musicgen.sample_rate
-     segments = []

-     torch.manual_seed(42)
-     np.random.seed(42)
-
-     start = time.time()
-     for i in range(n_chunks):
-         log_resources(f"Before chunk {i+1}")
-         musicgen.set_generation_params(
-             duration=render_len,
-             use_sampling=True,
-             top_k=top_k,
-             top_p=top_p,
-             temperature=temp,
-             cfg_coef=cfg
-         )
-         with torch.no_grad(), autocast():
-             audio = musicgen.generate([prompt], progress=False)[0]
-
-         audio = audio.cpu().to(torch.float32)
-         if audio.dim() == 1:
-             audio = torch.stack([audio, audio])
-         elif audio.shape[0] == 1:
-             audio = torch.cat([audio, audio], dim=0)
-         elif audio.shape[0] != 2:
-             audio = torch.cat([audio[:1], audio[:1]], dim=0)

          with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
-             torchaudio.save(tmp.name, audio, sr, bits_per_sample=24)
              seg = AudioSegment.from_wav(tmp.name)
              os.unlink(tmp.name)
              segments.append(seg)
-
-         gpu_clean()
-         log_resources(f"After chunk {i+1}")

      final = segments[0]
      for seg in segments[1:]:
-         final = final.append(seg + 1, crossfade=crossfade)
-     final = final[: total_len * 1000]
-     final = apply_fade(apply_eq(final).normalize(headroom=-9.0))

      out_path = "output_cleaned.mp3"
-     final.export(
-         out_path,
-         format="mp3",
-         bitrate="128k",
-         tags={"title": "GhostAI Instrumental", "artist": "GhostAI"}
-     )
-     log_resources("After final")
-     print(f"Total generation time: {time.time() - start:.2f}s")
      return out_path, "✅ Done!"

- def clear_inputs():
-     return (
-         "", 3.0, 250, 0.9, 1.0,
-         30, 10, 1000,
-         120, "none", "none", "none", "none", "none"
-     )
-
- # ----------------------------------------------------------------------
- # Gradio UI
- # ----------------------------------------------------------------------
- css = """
- body {
-   background: linear-gradient(135deg, #0A0A0A 0%, #1C2526 100%);
-   color: #E0E0E0; font-family: 'Orbitron', sans-serif;
- }
- .header {
-   text-align: center; padding: 10px; background: rgba(0,0,0,0.9);
-   border-bottom: 1px solid #00FF9F;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     gr.HTML('<div class="header"><h1>👻 GhostAI Music Generator</h1></div>')
-
-     prompt_box = gr.Textbox(label="Instrumental Prompt ✍️", lines=4)
-     with gr.Row():
-         gr.Button("RHCP 🌶️").click(
-             set_red_hot_chili_peppers_prompt,
-             inputs=[gr.State(120), gr.State("none"), gr.State("none"),
-                     gr.State("none"), gr.State("none"), gr.State("none")],
-             outputs=prompt_box
-         )
-         gr.Button("Nirvana 🎸").click(
-             set_nirvana_grunge_prompt,
-             inputs=[gr.State(120), gr.State("none"), gr.State("none"),
-                     gr.State("none"), gr.State("none"), gr.State("none")],
-             outputs=prompt_box
-         )
-         # … add the other genre buttons in the same pattern …
-
-     with gr.Group():
-         cfg_scale = gr.Slider(1.0, 10.0, value=3.0, step=0.1, label="CFG Scale")
-         top_k = gr.Slider(10, 500, value=250, step=10, label="Top-K")
-         top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-P")
-         temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
-         total_len = gr.Radio([30, 60, 90, 120], value=30, label="Length (s)")
-         chunk_len = gr.Slider(5, 15, value=10, step=1, label="Chunk (s)")
-         crossfade = gr.Slider(100, 2000, value=1000, step=100, label="Crossfade (ms)")
-
-     bpm = gr.Slider(60, 180, value=120, label="Tempo (BPM)")
-     drum_beat = gr.Dropdown(
-         ["none","standard rock","funk groove","techno kick","jazz swing"],
-         value="none", label="Drum Beat"
-     )
-     synthesizer = gr.Dropdown(
-         ["none","analog synth","digital pad","arpeggiated synth"],
-         value="none", label="Synthesizer"
-     )
-     steps = gr.Dropdown(
-         ["none","syncopated steps","steady steps","complex steps"],
-         value="none", label="Rhythmic Steps"
-     )
-     bass_style = gr.Dropdown(
-         ["none","slap bass","deep bass","melodic bass"],
-         value="none", label="Bass Style"
-     )
-     guitar_style = gr.Dropdown(
-         ["none","distorted","clean","jangle"],
-         value="none", label="Guitar Style"
-     )
-
-     gen_btn = gr.Button("Generate Music 🚀")
-     clr_btn = gr.Button("Clear 🧹")
-     out_audio = gr.Audio(label="Generated Track", type="filepath")
-     status = gr.Textbox(label="Status", interactive=False)
-
-     gen_btn.click(
-         generate_music,
-         inputs=[
-             prompt_box, cfg_scale, top_k, top_p, temperature,
-             total_len, chunk_len, crossfade,
-             bpm, drum_beat, synthesizer, steps, bass_style, guitar_style
-         ],
-         outputs=[out_audio, status]
-     )
-     clr_btn.click(
-         clear_inputs, None,
-         [
-             prompt_box, cfg_scale, top_k, top_p, temperature,
-             total_len, chunk_len, crossfade,
-             bpm, drum_beat, synthesizer, steps, bass_style, guitar_style
-         ]
-     )
-
- app = demo.launch(share=False, show_error=True)
- try:
-     demo._server.app.docs_url = demo._server.app.redoc_url = demo._server.app.openapi_url = None
- except Exception:
-     pass

+ import os, sys, gc, time, warnings, tempfile
+ import torch, torchaudio, numpy as np, gradio as gr
  from pydub import AudioSegment
  from audiocraft.models import MusicGen
  from huggingface_hub import login

  warnings.filterwarnings("ignore")
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

+ # Auth setup
  HF_TOKEN = os.getenv("HF_TOKEN")
  if not HF_TOKEN:
+     sys.exit("ERROR: HF_TOKEN not set.")
+ login(HF_TOKEN)

+ # Device setup
  device = "cuda" if torch.cuda.is_available() else "cpu"
  print(f"Running on {device.upper()}")
+ torch.cuda.empty_cache(); gc.collect()

+ # Load model correctly
  musicgen = MusicGen.get_pretrained("medium", device=device)
  musicgen.set_generation_params(duration=10, two_step_cfg=False)

+ def generate_music(prompt, cfg, top_k, top_p, temp, total_len, chunk_len, crossfade):
      if not prompt.strip():
+         return None, "⚠️ Enter a valid prompt."

+     sample_rate = musicgen.sample_rate
+     segments = []

+     for _ in range(total_len // chunk_len):
+         with torch.no_grad():
+             audio = musicgen.generate([prompt])[0].cpu().float()

+         audio = audio if audio.dim() == 2 else audio.repeat(2, 1)
          with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+             torchaudio.save(tmp.name, audio, sample_rate)
              seg = AudioSegment.from_wav(tmp.name)
              os.unlink(tmp.name)
              segments.append(seg)
+         torch.cuda.empty_cache(); gc.collect()

      final = segments[0]
      for seg in segments[1:]:
+         final = final.append(seg, crossfade=crossfade)
+     final = final[:total_len * 1000].fade_in(1000).fade_out(1000).normalize(-9.0)

      out_path = "output_cleaned.mp3"
+     final.export(out_path, format="mp3", bitrate="128k", tags={"title": "GhostAI Track", "artist": "GhostAI"})
      return out_path, "✅ Done!"

+ # Gradio Interface
+ demo = gr.Interface(
+     fn=generate_music,
+     inputs=[
+         gr.Textbox(label="Instrumental Prompt"),
+         gr.Slider(1.0, 10.0, value=3.0, label="CFG Scale"),
+         gr.Slider(10, 500, value=250, label="Top-K"),
+         gr.Slider(0.0, 1.0, value=0.9, label="Top-P"),
+         gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
+         gr.Radio([30, 60, 90, 120], value=30, label="Length (s)"),
+         gr.Slider(5, 15, value=10, label="Chunk Length (s)"),
+         gr.Slider(100, 2000, value=1000, label="Crossfade (ms)")
+     ],
+     outputs=[
+         gr.Audio(label="Generated Music"),
+         gr.Textbox(label="Status")
+     ],
+     title="👻 GhostAI Music Generator",
+     description="Generate instrumental music using MusicGen Medium model."
+ )
+
+ demo.launch(share=False)
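
Note on the new generate_music: the cfg, top_k, top_p, temp and chunk_len arguments are accepted but never reach the model, and total_len // chunk_len can be a float when the values come from gr.Slider. A minimal sketch of how the chunk loop could wire those controls through, reusing the set_generation_params keywords from the removed version above (the casts and per-chunk duration are assumptions for illustration, not part of this commit):

    # Sketch only: re-apply the UI sampling controls before each chunk,
    # as the removed version did. Keyword names follow the old code above;
    # int()/float() casts guard against Gradio sliders returning floats.
    n_chunks = max(1, int(total_len) // int(chunk_len))
    for _ in range(n_chunks):
        musicgen.set_generation_params(
            duration=int(chunk_len),
            use_sampling=True,
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=float(temp),
            cfg_coef=float(cfg),
        )
        with torch.no_grad():
            audio = musicgen.generate([prompt])[0].cpu().float()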