ghostai1 commited on
Commit
c3aa73a
·
verified ·
1 Parent(s): 3cb6a0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +308 -252
app.py CHANGED
@@ -1,44 +1,51 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- GhostAI Music Generator — Zero-GPU Friendly
5
- Full Gradio application with conflict-free dependencies.
6
-
7
- Python : 3.10
8
- Torch : 2.1 CPU wheels
9
- Updated : 2025-05-29
10
- """
11
-
12
- import os, sys, gc, time, random, warnings, tempfile, psutil, numpy as np
13
- import torch, torchaudio, gradio as gr
14
  from pydub import AudioSegment
15
  from torch.cuda.amp import autocast
16
  from audiocraft.models import MusicGen
17
  from huggingface_hub import login
18
 
19
- # ------------------------------------------------------------------ #
20
- # Compatibility shim (torch < 2.3) #
21
- # ------------------------------------------------------------------ #
22
  if not hasattr(torch, "get_default_device"):
23
  torch.get_default_device = lambda: torch.device(
24
  "cuda" if torch.cuda.is_available() else "cpu"
25
  )
26
 
 
 
 
27
  warnings.filterwarnings("ignore")
28
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
29
 
30
- # ------------------------------------------------------------------ #
31
- # Hugging Face authentication #
32
- # ------------------------------------------------------------------ #
33
  HF_TOKEN = os.getenv("HF_TOKEN")
34
  if not HF_TOKEN:
35
- print("ERROR: set HF_TOKEN in the Space secrets.")
 
 
 
 
 
36
  sys.exit(1)
37
- login(HF_TOKEN)
38
 
39
- # ------------------------------------------------------------------ #
40
- # Device setup #
41
- # ------------------------------------------------------------------ #
42
  device = "cuda" if torch.cuda.is_available() else "cpu"
43
  print(f"Running on {device.upper()}")
44
  if device == "cuda":
@@ -53,262 +60,311 @@ def gpu_clean():
53
 
54
  gpu_clean()
55
 
56
- # ------------------------------------------------------------------ #
57
- # Load MusicGen #
58
- # ------------------------------------------------------------------ #
59
- print("Loading facebook/musicgen-medium …")
60
- musicgen = MusicGen.get_pretrained("facebook/musicgen-medium", device=device)
61
  musicgen.set_generation_params(duration=10, two_step_cfg=False)
62
 
63
- # ------------------------------------------------------------------ #
64
- # Resource monitor #
65
- # ------------------------------------------------------------------ #
66
- def log_resources(tag=""):
67
- if tag:
68
- print(f"-- {tag} --")
69
  if device == "cuda":
70
- alloc = torch.cuda.memory_allocated() / 1024 ** 3
71
- res = torch.cuda.memory_reserved() / 1024 ** 3
72
- print(f"GPU mem | alloc {alloc:.2f} GB reserved {res:.2f} GB")
73
- print(f"CPU mem | {psutil.virtual_memory().percent}% used")
74
 
75
- def vram_ok(th=3.5):
76
  if device != "cuda":
77
  return True
78
- total = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
79
- free = total - torch.cuda.memory_allocated() / 1024 ** 3
80
- if free < th:
81
- print(f"Only {free:.2f} GB VRAM free (<{th} GB)")
82
- return free >= th
83
-
84
- # ------------------------------------------------------------------ #
85
- # Prompt builders #
86
- # ------------------------------------------------------------------ #
87
- def _p(base,bpm,drum,synth,step,bass,gtr,def_bass,def_gtr,flow):
88
- step_txt = f" with {step}" if step!="none" else flow.format(bpm=bpm)
89
- drum_txt = f", {drum} drums" if drum != "none" else ""
90
- syn_txt = f", {synth} accents" if synth != "none" else ""
91
- bass_txt = f", {bass}" if bass != "none" else def_bass
92
- gtr_txt = f", {gtr} guitar riffs" if gtr != "none" else def_gtr
93
- return f"{base}{bass_txt}{gtr_txt}{drum_txt}{syn_txt}{step_txt} at {bpm} BPM."
94
-
95
- def set_red_hot_chili_peppers_prompt(bpm,dr,syn,st,bass,gtr):
96
- return _p("Instrumental funk rock",bpm,dr,syn,st,bass,gtr,
97
- ", groovy basslines",", syncopated guitar riffs",
98
- "{bpm} BPM funky flow" if bpm>120 else "groovy rhythmic flow")
99
-
100
- def set_nirvana_grunge_prompt(bpm,dr,syn,st,bass,gtr):
101
- return _p("Instrumental grunge",bpm,dr,syn,st,bass,gtr,
102
- ", melodic basslines",", raw distorted guitar riffs",
103
- "{bpm} BPM grungy pulse" if bpm>120 else "grungy rhythmic pulse")
104
-
105
- def set_pearl_jam_grunge_prompt(bpm,dr,syn,st,bass,gtr):
106
- return _p("Instrumental grunge",bpm,dr,syn,st,bass,gtr,
107
- ", deep bass",", soulful guitar leads",
108
- "{bpm} BPM driving flow" if bpm>120 else "driving rhythmic flow")
109
-
110
- def set_soundgarden_grunge_prompt(bpm,dr,syn,st,bass,gtr):
111
- return _p("Instrumental grunge",bpm,dr,syn,st,bass,gtr,
112
- "",", heavy sludgy guitar riffs",
113
- "{bpm} BPM heavy groove" if bpm>120 else "sludgy rhythmic groove")
114
-
115
- def set_foo_fighters_prompt(bpm,dr,syn,st,bass,gtr):
116
- styles=["anthemic","gritty","melodic","fast-paced","driving"]
117
- moods=["energetic","introspective","rebellious","uplifting"]
118
- return (_p("Instrumental alternative rock",bpm,dr,syn,st,bass,gtr,
119
- "",f", {random.choice(styles)} guitar riffs",
120
- "{bpm} BPM powerful groove" if bpm>120 else "catchy rhythmic groove")
121
- +f", Foo Fighters-inspired {random.choice(moods)} vibe")
122
-
123
- def set_smashing_pumpkins_prompt(bpm,dr,syn,st,bass,gtr):
124
- return _p("Instrumental alternative rock",bpm,dr,syn,st,bass,gtr,
125
- "",", dreamy guitar textures",
126
- "{bpm} BPM dynamic flow" if bpm>120 else "dreamy rhythmic flow")
127
-
128
- def set_radiohead_prompt(bpm,dr,syn,st,bass,gtr):
129
- return _p("Instrumental experimental rock",bpm,dr,syn,st,bass,gtr,
130
- "",", intricate guitar layers",
131
- "{bpm} BPM intricate pulse" if bpm>120 else "intricate rhythmic pulse")
132
-
133
- def set_classic_rock_prompt(bpm,dr,syn,st,bass,gtr):
134
- return _p("Instrumental classic rock",bpm,dr,syn,st,bass,gtr,
135
- ", groovy bass",", bluesy electric guitars",
136
- "{bpm} BPM bluesy steps" if bpm>120 else "steady rhythmic groove")
137
-
138
- def set_alternative_rock_prompt(bpm,dr,syn,st,bass,gtr):
139
- return _p("Instrumental alternative rock",bpm,dr,syn,st,bass,gtr,
140
- ", melodic basslines",", distorted guitar riffs",
141
- "{bpm} BPM quirky steps" if bpm>120 else "energetic rhythmic flow")
142
-
143
- def set_post_punk_prompt(bpm,dr,syn,st,bass,gtr):
144
- return _p("Instrumental post-punk",bpm,dr,syn,st,bass,gtr,
145
- ", driving basslines",", jangly guitars",
146
- "{bpm} BPM sharp steps" if bpm>120 else "moody rhythmic pulse")
147
-
148
- def set_indie_rock_prompt(bpm,dr,syn,st,bass,gtr):
149
- return _p("Instrumental indie rock",bpm,dr,syn,st,bass,gtr,
150
- "",", jangly guitars",
151
- "{bpm} BPM catchy steps" if bpm>120 else "jangly rhythmic flow")
152
-
153
- def set_funk_rock_prompt(bpm,dr,syn,st,bass,gtr):
154
- return _p("Instrumental funk rock",bpm,dr,syn,st,bass,gtr,
155
- ", slap bass",", funky guitar chords",
156
- "{bpm} BPM aggressive steps" if bpm>120 else "funky rhythmic groove")
157
-
158
- def set_detroit_techno_prompt(bpm,dr,syn,st,bass,gtr):
159
- return _p("Instrumental Detroit techno",bpm,dr,syn,st,bass,gtr,
160
- ", driving basslines","",
161
- "{bpm} BPM pulsing steps" if bpm>120 else "deep rhythmic groove")
162
-
163
- def set_deep_house_prompt(bpm,dr,syn,st,bass,gtr):
164
- return _p("Instrumental deep house",bpm,dr,syn,st,bass,gtr,
165
- ", deep basslines","",
166
- "{bpm} BPM soulful steps" if bpm>120 else "laid-back rhythmic flow")
167
-
168
- # ------------------------------------------------------------------ #
169
- # Audio post-processing #
170
- # ------------------------------------------------------------------ #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  def apply_eq(seg: AudioSegment):
172
  return seg.low_pass_filter(8000).high_pass_filter(80)
173
 
174
  def apply_fade(seg: AudioSegment, fin=1000, fout=1000):
175
  return seg.fade_in(fin).fade_out(fout)
176
 
177
- # ------------------------------------------------------------------ #
178
- # Core generation #
179
- # ------------------------------------------------------------------ #
180
- def generate_music(prompt,cfg,k,p,temp,
181
- total_dur,chunk_dur,xfade,
182
- bpm,drum,synth,step,bass,gtr):
183
-
 
184
  if not prompt.strip():
185
- return None,"⚠️ Prompt cannot be empty."
186
  if not vram_ok():
187
- return None,"⚠️ Not enough VRAM."
188
-
189
- total_dur = int(total_dur)
190
- chunk_dur = int(max(5,min(chunk_dur,15)))
191
- chunks = max(1,total_dur//chunk_dur)
192
- chunk_dur = total_dur/chunks
193
- overlap = min(1.0,xfade/1000.0)
194
- render_len = chunk_dur+overlap
195
  sr = musicgen.sample_rate
196
- parts = []
197
 
198
  torch.manual_seed(42)
199
  np.random.seed(42)
200
 
201
- t0=time.time()
202
- for i in range(chunks):
203
- log_resources(f"before chunk {i+1}")
204
  musicgen.set_generation_params(
205
- duration=render_len,use_sampling=True,
206
- top_k=k,top_p=p,temperature=temp,cfg_coef=cfg)
207
- with torch.no_grad(),autocast():
208
- audio=musicgen.generate([prompt],progress=False)[0]
209
-
210
- audio=audio.cpu().float()
211
- if audio.dim()==1:
212
- audio=torch.stack([audio,audio])
213
- elif audio.shape[0]==1:
214
- audio=torch.cat([audio,audio])
215
- elif audio.shape[0]!=2:
216
- audio=torch.cat([audio[:1],audio[:1]])
217
-
218
- with tempfile.NamedTemporaryFile(suffix=".wav",delete=False) as tmp:
219
- torchaudio.save(tmp.name,audio,sr,bits_per_sample=24)
220
- seg=AudioSegment.from_wav(tmp.name)
 
 
 
 
 
221
  os.unlink(tmp.name)
222
- parts.append(seg)
223
- gpu_clean()
224
- log_resources(f"after chunk {i+1}")
225
-
226
- track=parts[0]
227
- for seg in parts[1:]:
228
- track=track.append(seg+1,crossfade=xfade)
229
- track=track[:total_dur*1000]
230
- track=apply_fade(apply_eq(track).normalize(headroom=-9.0))
231
 
232
- out_path="output_cleaned.mp3"
233
- track.export(out_path,format="mp3",bitrate="128k",
234
- tags={"title":"GhostAI Instrumental","artist":"GhostAI"})
235
- log_resources("final")
236
- print(f"Time: {time.time()-t0:.1f}s")
237
- return out_path,"✅ Done!"
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  def clear_inputs():
240
- return ("",3.0,250,0.9,1.0,30,10,1000,
241
- 120,"none","none","none","none","none")
242
-
243
- # ------------------------------------------------------------------ #
244
- # UI #
245
- # ------------------------------------------------------------------ #
246
- css="""body{background:linear-gradient(135deg,#0A0A0A 0%,#1C2526 100%);color:#E0E0E0;font-family:'Arial',sans-serif}
247
- .header{text-align:center;padding:10px 20px;background:rgba(0,0,0,.9);border-bottom:1px solid #00FF9F}
 
 
 
 
 
 
 
 
 
 
248
  """
249
 
250
  with gr.Blocks(css=css) as demo:
251
- gr.HTML('<div class="header"><span style="font-size:32px">👻</span> <b>GhostAI Music Generator</b></div>')
252
-
253
- prompt_box=gr.Textbox(label="Prompt",lines=4)
254
- cfg_scale =gr.Slider(1,10,3,label="CFG")
255
- top_k =gr.Slider(10,500,250,step=10,label="Top-K")
256
- top_p =gr.Slider(0,1,0.9,step=0.05,label="Top-P")
257
- temp =gr.Slider(0.1,2,1,step=0.1,label="Temperature")
258
- total_dur =gr.Radio([30,60,90,120],value=30,label="Length (s)")
259
- chunk_dur =gr.Slider(5,15,10,step=1,label="Chunk Length (s)")
260
- xfade =gr.Slider(100,2000,1000,step=100,label="Cross-fade (ms)")
261
-
262
- bpm =gr.Slider(60,180,120,label="BPM")
263
- drum =gr.Dropdown(["none","standard rock","funk groove","techno kick","jazz swing"],value="none",label="Drum")
264
- synth =gr.Dropdown(["none","analog synth","digital pad","arpeggiated synth"],value="none",label="Synth")
265
- step =gr.Dropdown(["none","syncopated steps","steady steps","complex steps"],value="none",label="Steps")
266
- bass =gr.Dropdown(["none","slap bass","deep bass","melodic bass"],value="none",label="Bass")
267
- gtr =gr.Dropdown(["none","distorted","clean","jangle"],value="none",label="Guitar")
268
-
269
- # Genre buttons
270
- btns={
271
- "RHCP 🌶️":set_red_hot_chili_peppers_prompt,
272
- "Nirvana 🎸":set_nirvana_grunge_prompt,
273
- "Pearl Jam 🦪":set_pearl_jam_grunge_prompt,
274
- "Soundgarden 🌑":set_soundgarden_grunge_prompt,
275
- "Foo Fighters 🤘":set_foo_fighters_prompt,
276
- "Smashing Pumpkins 🎃":set_smashing_pumpkins_prompt,
277
- "Radiohead 🧠":set_radiohead_prompt,
278
- "Classic Rock 🎸":set_classic_rock_prompt,
279
- "Alt Rock 🎵":set_alternative_rock_prompt,
280
- "Post-Punk 🖤":set_post_punk_prompt,
281
- "Indie Rock 🎤":set_indie_rock_prompt,
282
- "Funk Rock 🕺":set_funk_rock_prompt,
283
- "Detroit Techno 🎛️":set_detroit_techno_prompt,
284
- "Deep House 🏠":set_deep_house_prompt
285
- }
286
  with gr.Row():
287
- for label,fn in btns.items():
288
- gr.Button(label).click(
289
- fn,
290
- inputs=[bpm,drum,synth,step,bass,gtr],
291
- outputs=prompt_box
292
- )
293
-
294
- gen=gr.Button("Generate 🎼")
295
- clr=gr.Button("Clear")
296
-
297
- audio_out=gr.Audio(type="filepath")
298
- status =gr.Textbox(interactive=False,label="Status")
299
-
300
- gen.click(generate_music,
301
- inputs=[prompt_box,cfg_scale,top_k,top_p,temp,
302
- total_dur,chunk_dur,xfade,
303
- bpm,drum,synth,step,bass,gtr],
304
- outputs=[audio_out,status])
305
- clr.click(clear_inputs,None,
306
- [prompt_box,cfg_scale,top_k,top_p,temp,
307
- total_dur,chunk_dur,xfade,
308
- bpm,drum,synth,step,bass,gtr])
309
-
310
- app=demo.launch(share=False,show_error=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  try:
312
- demo._server.app.docs_url=demo._server.app.redoc_url=demo._server.app.openapi_url=None
313
  except Exception:
314
  pass
 
1
+ import os
2
+ import sys
3
+ import gc
4
+ import time
5
+ import random
6
+ import warnings
7
+ import tempfile
8
+
9
+ import psutil
10
+ import numpy as np
11
+ import torch
12
+ import torchaudio
13
+ import gradio as gr
14
  from pydub import AudioSegment
15
  from torch.cuda.amp import autocast
16
  from audiocraft.models import MusicGen
17
  from huggingface_hub import login
18
 
19
+ # ----------------------------------------------------------------------
20
+ # Compatibility shim (torch < 2.3)
21
+ # ----------------------------------------------------------------------
22
  if not hasattr(torch, "get_default_device"):
23
  torch.get_default_device = lambda: torch.device(
24
  "cuda" if torch.cuda.is_available() else "cpu"
25
  )
26
 
27
+ # ----------------------------------------------------------------------
28
+ # Warnings & CUDA fragmentation tuning
29
+ # ----------------------------------------------------------------------
30
  warnings.filterwarnings("ignore")
31
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
32
 
33
+ # ----------------------------------------------------------------------
34
+ # Hugging Face authentication
35
+ # ----------------------------------------------------------------------
36
  HF_TOKEN = os.getenv("HF_TOKEN")
37
  if not HF_TOKEN:
38
+ print("ERROR: environment variable HF_TOKEN not set.")
39
+ sys.exit(1)
40
+ try:
41
+ login(HF_TOKEN)
42
+ except Exception as e:
43
+ print(f"ERROR: Hugging Face login failed: {e}")
44
  sys.exit(1)
 
45
 
46
+ # ----------------------------------------------------------------------
47
+ # Device setup & cleanup
48
+ # ----------------------------------------------------------------------
49
  device = "cuda" if torch.cuda.is_available() else "cpu"
50
  print(f"Running on {device.upper()}")
51
  if device == "cuda":
 
60
 
61
  gpu_clean()
62
 
63
+ # ----------------------------------------------------------------------
64
+ # Load MusicGen model (fixed checkpoint name)
65
+ # ----------------------------------------------------------------------
66
+ print("Loading MusicGen ‘medium checkpoint …")
67
+ musicgen = MusicGen.get_pretrained("medium", device=device)
68
  musicgen.set_generation_params(duration=10, two_step_cfg=False)
69
 
70
+ # ----------------------------------------------------------------------
71
+ # Resource monitoring
72
+ # ----------------------------------------------------------------------
73
+ def log_resources(stage=""):
74
+ if stage:
75
+ print(f"--- {stage} ---")
76
  if device == "cuda":
77
+ alloc = torch.cuda.memory_allocated() / 1024**3
78
+ resv = torch.cuda.memory_reserved() / 1024**3
79
+ print(f"GPU Mem | Alloc {alloc:.2f} GB Reserved {resv:.2f} GB")
80
+ print(f"CPU Mem | {psutil.virtual_memory().percent}% used")
81
 
82
+ def vram_ok(threshold=3.5):
83
  if device != "cuda":
84
  return True
85
+ total = torch.cuda.get_device_properties(0).total_memory / 1024**3
86
+ free = total - torch.cuda.memory_allocated() / 1024**3
87
+ if free < threshold:
88
+ print(f"WARNING: Only {free:.2f} GB VRAM free (<{threshold} GB).")
89
+ return free >= threshold
90
+
91
+ # ----------------------------------------------------------------------
92
+ # Prompt builders
93
+ # ----------------------------------------------------------------------
94
+ def _make_prompt(base, bpm, drum, synth, steps, bass, gtr, def_bass, def_gtr, flow):
95
+ step_txt = f" with {steps}" if steps != "none" else flow.format(bpm=bpm)
96
+ drum_txt = f", {drum} drums" if drum != "none" else ""
97
+ synth_txt = f", {synth} accents" if synth != "none" else ""
98
+ bass_txt = f", {bass}" if bass != "none" else def_bass
99
+ gtr_txt = f", {gtr} guitar riffs" if gtr != "none" else def_gtr
100
+ return f"{base}{bass_txt}{gtr_txt}{drum_txt}{synth_txt}{step_txt} at {bpm} BPM."
101
+
102
+ def set_red_hot_chili_peppers_prompt(bpm, drum, synth, steps, bass, gtr):
103
+ return _make_prompt(
104
+ "Instrumental funk rock", bpm, drum, synth, steps, bass, gtr,
105
+ ", groovy basslines", ", syncopated guitar riffs",
106
+ "{bpm} BPM funky flow" if bpm > 120 else "groovy rhythmic flow"
107
+ )
108
+
109
+ def set_nirvana_grunge_prompt(bpm, drum, synth, steps, bass, gtr):
110
+ return _make_prompt(
111
+ "Instrumental grunge", bpm, drum, synth, steps, bass, gtr,
112
+ ", melodic basslines", ", raw distorted guitar riffs",
113
+ "{bpm} BPM grungy pulse" if bpm > 120 else "grungy rhythmic pulse"
114
+ )
115
+
116
+ def set_pearl_jam_grunge_prompt(bpm, drum, synth, steps, bass, gtr):
117
+ return _make_prompt(
118
+ "Instrumental grunge", bpm, drum, synth, steps, bass, gtr,
119
+ ", deep bass", ", soulful guitar leads",
120
+ "{bpm} BPM driving flow" if bpm > 120 else "driving rhythmic flow"
121
+ )
122
+
123
+ def set_soundgarden_grunge_prompt(bpm, drum, synth, steps, bass, gtr):
124
+ return _make_prompt(
125
+ "Instrumental grunge", bpm, drum, synth, steps, bass, gtr,
126
+ "", ", heavy sludgy guitar riffs",
127
+ "{bpm} BPM heavy groove" if bpm > 120 else "sludgy rhythmic groove"
128
+ )
129
+
130
+ def set_foo_fighters_prompt(bpm, drum, synth, steps, bass, gtr):
131
+ styles = ["anthemic", "gritty", "melodic", "fast-paced", "driving"]
132
+ moods = ["energetic", "introspective", "rebellious", "uplifting"]
133
+ return (
134
+ _make_prompt(
135
+ "Instrumental alternative rock", bpm, drum, synth, steps, bass, gtr,
136
+ "", f", {random.choice(styles)} guitar riffs",
137
+ "{bpm} BPM powerful groove" if bpm > 120 else "catchy rhythmic groove"
138
+ )
139
+ + f", Foo Fighters–inspired {random.choice(moods)} vibe"
140
+ )
141
+
142
+ def set_smashing_pumpkins_prompt(bpm, drum, synth, steps, bass, gtr):
143
+ return _make_prompt(
144
+ "Instrumental alternative rock", bpm, drum, synth, steps, bass, gtr,
145
+ "", ", dreamy guitar textures",
146
+ "{bpm} BPM dynamic flow" if bpm > 120 else "dreamy rhythmic flow"
147
+ )
148
+
149
+ def set_radiohead_prompt(bpm, drum, synth, steps, bass, gtr):
150
+ return _make_prompt(
151
+ "Instrumental experimental rock", bpm, drum, synth, steps, bass, gtr,
152
+ "", ", intricate guitar layers",
153
+ "{bpm} BPM intricate pulse" if bpm > 120 else "intricate rhythmic pulse"
154
+ )
155
+
156
+ def set_classic_rock_prompt(bpm, drum, synth, steps, bass, gtr):
157
+ return _make_prompt(
158
+ "Instrumental classic rock", bpm, drum, synth, steps, bass, gtr,
159
+ ", groovy bass", ", bluesy electric guitars",
160
+ "{bpm} BPM bluesy steps" if bpm > 120 else "steady rhythmic groove"
161
+ )
162
+
163
+ def set_alternative_rock_prompt(bpm, drum, synth, steps, bass, gtr):
164
+ return _make_prompt(
165
+ "Instrumental alternative rock", bpm, drum, synth, steps, bass, gtr,
166
+ ", melodic basslines", ", distorted guitar riffs",
167
+ "{bpm} BPM quirky steps" if bpm > 120 else "energetic rhythmic flow"
168
+ )
169
+
170
+ def set_post_punk_prompt(bpm, drum, synth, steps, bass, gtr):
171
+ return _make_prompt(
172
+ "Instrumental post-punk", bpm, drum, synth, steps, bass, gtr,
173
+ ", driving basslines", ", jangly guitars",
174
+ "{bpm} BPM sharp steps" if bpm > 120 else "moody rhythmic pulse"
175
+ )
176
+
177
+ def set_indie_rock_prompt(bpm, drum, synth, steps, bass, gtr):
178
+ return _make_prompt(
179
+ "Instrumental indie rock", bpm, drum, synth, steps, bass, gtr,
180
+ "", ", jangly guitars",
181
+ "{bpm} BPM catchy steps" if bpm > 120 else "jangly rhythmic flow"
182
+ )
183
+
184
+ def set_funk_rock_prompt(bpm, drum, synth, steps, bass, gtr):
185
+ return _make_prompt(
186
+ "Instrumental funk rock", bpm, drum, synth, steps, bass, gtr,
187
+ ", slap bass", ", funky guitar chords",
188
+ "{bpm} BPM aggressive steps" if bpm > 120 else "funky rhythmic groove"
189
+ )
190
+
191
+ def set_detroit_techno_prompt(bpm, drum, synth, steps, bass, gtr):
192
+ return _make_prompt(
193
+ "Instrumental Detroit techno", bpm, drum, synth, steps, bass, gtr,
194
+ ", driving basslines", "",
195
+ "{bpm} BPM pulsing steps" if bpm > 120 else "deep rhythmic groove"
196
+ )
197
+
198
+ def set_deep_house_prompt(bpm, drum, synth, steps, bass, gtr):
199
+ return _make_prompt(
200
+ "Instrumental deep house", bpm, drum, synth, steps, bass, gtr,
201
+ ", deep basslines", "",
202
+ "{bpm} BPM soulful steps" if bpm > 120 else "laid-back rhythmic flow"
203
+ )
204
+
205
+ # ----------------------------------------------------------------------
206
+ # Audio post-processing
207
+ # ----------------------------------------------------------------------
208
  def apply_eq(seg: AudioSegment):
209
  return seg.low_pass_filter(8000).high_pass_filter(80)
210
 
211
  def apply_fade(seg: AudioSegment, fin=1000, fout=1000):
212
  return seg.fade_in(fin).fade_out(fout)
213
 
214
+ # ----------------------------------------------------------------------
215
+ # Core generation function
216
+ # ----------------------------------------------------------------------
217
+ def generate_music(
218
+ prompt, cfg, top_k, top_p, temp,
219
+ total_len, chunk_len, crossfade,
220
+ bpm, drum, synth, steps, bass, gtr
221
+ ):
222
  if not prompt.strip():
223
+ return None, "⚠️ Prompt cannot be empty."
224
  if not vram_ok():
225
+ return None, "⚠️ Insufficient VRAM."
226
+
227
+ total_len = int(total_len)
228
+ chunk_len = int(max(5, min(chunk_len, 15)))
229
+ n_chunks = max(1, total_len // chunk_len)
230
+ chunk_len = total_len / n_chunks
231
+ overlap = min(1.0, crossfade / 1000.0)
232
+ render_len = chunk_len + overlap
233
  sr = musicgen.sample_rate
234
+ segments = []
235
 
236
  torch.manual_seed(42)
237
  np.random.seed(42)
238
 
239
+ t0 = time.time()
240
+ for i in range(n_chunks):
241
+ log_resources(f"Before chunk {i+1}")
242
  musicgen.set_generation_params(
243
+ duration=render_len,
244
+ use_sampling=True,
245
+ top_k=top_k,
246
+ top_p=top_p,
247
+ temperature=temp,
248
+ cfg_coef=cfg
249
+ )
250
+ with torch.no_grad(), autocast():
251
+ audio = musicgen.generate([prompt], progress=False)[0]
252
+
253
+ audio = audio.cpu().to(torch.float32)
254
+ if audio.dim() == 1:
255
+ audio = torch.stack([audio, audio])
256
+ elif audio.shape[0] == 1:
257
+ audio = torch.cat([audio, audio], dim=0)
258
+ elif audio.shape[0] != 2:
259
+ audio = torch.cat([audio[:1], audio[:1]], dim=0)
260
+
261
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
262
+ torchaudio.save(tmp.name, audio, sr, bits_per_sample=24)
263
+ seg = AudioSegment.from_wav(tmp.name)
264
  os.unlink(tmp.name)
265
+ segments.append(seg)
 
 
 
 
 
 
 
 
266
 
267
+ gpu_clean()
268
+ log_resources(f"After chunk {i+1}")
269
+
270
+ final = segments[0]
271
+ for seg in segments[1:]:
272
+ final = final.append(seg + 1, crossfade=crossfade)
273
+ final = final[: total_len * 1000]
274
+ final = apply_fade(apply_eq(final).normalize(headroom=-9.0))
275
+
276
+ out_path = "output_cleaned.mp3"
277
+ final.export(
278
+ out_path,
279
+ format="mp3",
280
+ bitrate="128k",
281
+ tags={"title": "GhostAI Instrumental", "artist": "GhostAI"}
282
+ )
283
+ log_resources("After final")
284
+ print(f"Total generation time: {time.time() - t0:.2f}s")
285
+ return out_path, "✅ Done!"
286
 
287
  def clear_inputs():
288
+ return (
289
+ "", 3.0, 250, 0.9, 1.0,
290
+ 30, 10, 1000,
291
+ 120, "none", "none", "none", "none", "none"
292
+ )
293
+
294
+ # ----------------------------------------------------------------------
295
+ # Gradio UI
296
+ # ----------------------------------------------------------------------
297
+ css = """
298
+ body {
299
+ background: linear-gradient(135deg, #0A0A0A 0%, #1C2526 100%);
300
+ color: #E0E0E0; font-family: 'Orbitron', sans-serif;
301
+ }
302
+ .header {
303
+ text-align: center; padding: 10px; background: rgba(0,0,0,0.9);
304
+ border-bottom: 1px solid #00FF9F;
305
+ }
306
  """
307
 
308
  with gr.Blocks(css=css) as demo:
309
+ gr.HTML('<div class="header"><h1>👻 GhostAI Music Generator</h1></div>')
310
+
311
+ prompt_box = gr.Textbox(label="Instrumental Prompt ✍️", lines=4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  with gr.Row():
313
+ gr.Button("RHCP 🌶️").click(
314
+ set_red_hot_chili_peppers_prompt,
315
+ inputs=[gr.State(120), gr.State("none"), gr.State("none"),
316
+ gr.State("none"), gr.State("none"), gr.State("none")],
317
+ outputs=prompt_box
318
+ )
319
+ gr.Button("Nirvana 🎸").click(
320
+ set_nirvana_grunge_prompt,
321
+ inputs=[gr.State(120), gr.State("none"), gr.State("none"),
322
+ gr.State("none"), gr.State("none"), gr.State("none")],
323
+ outputs=prompt_box
324
+ )
325
+ # … add the other genre buttons in the same pattern …
326
+
327
+ with gr.Group():
328
+ cfg_scale = gr.Slider(1.0, 10.0, value=3.0, step=0.1, label="CFG Scale")
329
+ top_k = gr.Slider(10, 500, value=250, step=10, label="Top-K")
330
+ top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-P")
331
+ temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
332
+ total_len = gr.Radio([30, 60, 90, 120], value=30, label="Length (s)")
333
+ chunk_len = gr.Slider(5, 15, value=10, step=1, label="Chunk (s)")
334
+ crossfade = gr.Slider(100, 2000, value=1000, step=100, label="Crossfade (ms)")
335
+
336
+ bpm = gr.Slider(60, 180, value=120, label="Tempo (BPM)")
337
+ drum_beat = gr.Dropdown(["none","standard rock","funk groove","techno kick","jazz swing"], value="none", label="Drum Beat")
338
+ synthesizer = gr.Dropdown(["none","analog synth","digital pad","arpeggiated synth"], value="none", label="Synthesizer")
339
+ steps = gr.Dropdown(["none","syncopated steps","steady steps","complex steps"], value="none", label="Rhythmic Steps")
340
+ bass_style = gr.Dropdown(["none","slap bass","deep bass","melodic bass"], value="none", label="Bass Style")
341
+ guitar_style = gr.Dropdown(["none","distorted","clean","jangle"], value="none", label="Guitar Style")
342
+
343
+ gen_btn = gr.Button("Generate Music 🚀")
344
+ clr_btn = gr.Button("Clear 🧹")
345
+ out_audio= gr.Audio(label="Generated Track", type="filepath")
346
+ status = gr.Textbox(label="Status", interactive=False)
347
+
348
+ gen_btn.click(
349
+ generate_music,
350
+ inputs=[
351
+ prompt_box, cfg_scale, top_k, top_p, temperature,
352
+ total_len, chunk_len, crossfade,
353
+ bpm, drum_beat, synthesizer, steps, bass_style, guitar_style
354
+ ],
355
+ outputs=[out_audio, status]
356
+ )
357
+ clr_btn.click(
358
+ clear_inputs, None,
359
+ [
360
+ prompt_box, cfg_scale, top_k, top_p, temperature,
361
+ total_len, chunk_len, crossfade,
362
+ bpm, drum_beat, synthesizer, steps, bass_style, guitar_style
363
+ ]
364
+ )
365
+
366
+ app = demo.launch(share=False, show_error=True)
367
  try:
368
+ demo._server.app.docs_url = demo._server.app.redoc_url = demo._server.app.openapi_url = None
369
  except Exception:
370
  pass