ghostai1 committed on
Commit
80cc045
·
verified ·
1 Parent(s): 1279448

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -78
app.py CHANGED
@@ -1,108 +1,213 @@
1
  #!/usr/bin/env python3
2
- # GhostAI Music Generator Hugging Face Spaces GPU-Compatible
3
-
4
- import spaces # <--- Must be imported FIRST before torch and CUDA
5
-
6
- import os
7
- import sys
8
- import gc
9
- import warnings
10
- import tempfile
11
- import torch
12
- import torchaudio
13
- import numpy as np
14
- import gradio as gr
15
  from pydub import AudioSegment
 
16
  from audiocraft.models import MusicGen
17
  from huggingface_hub import login
18
 
 
 
 
 
 
 
 
 
19
  warnings.filterwarnings("ignore")
 
20
 
21
- # Hugging Face token auth
 
 
22
  HF_TOKEN = os.getenv("HF_TOKEN")
23
  if not HF_TOKEN:
24
- sys.exit("ERROR: HF_TOKEN not set.")
25
  login(HF_TOKEN)
26
 
27
- # Device setup
 
 
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  print(f"Running on {device.upper()}")
30
-
31
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
32
-
33
- # Load MusicGen model explicitly on GPU
34
- musicgen = MusicGen.get_pretrained("medium")
35
- musicgen.lm.to(device)
36
- musicgen.set_generation_params(duration=10)
37
-
38
- def clean_resources():
39
  if device == "cuda":
40
  torch.cuda.empty_cache()
41
  gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- @spaces.GPU # <-- Correct GPU decorator for HF
44
- def generate_music(prompt, cfg, top_k, top_p, temp, total_len, chunk_len, crossfade):
45
  if not prompt.strip():
46
- return None, "⚠️ Enter a valid prompt."
47
-
48
- sample_rate = musicgen.sample_rate
49
- segments = []
50
- chunks = max(1, total_len // chunk_len)
51
-
52
- for _ in range(chunks):
53
- with torch.no_grad():
54
- audio = musicgen.generate(
55
- [prompt],
56
- temperature=temp,
57
- cfg_coef=cfg,
58
- top_k=top_k,
59
- top_p=top_p,
60
- duration=chunk_len,
61
- progress=False
62
- )[0].cpu().float()
63
-
64
- if audio.dim() == 1:
65
- audio = audio.unsqueeze(0).repeat(2, 1)
66
- elif audio.shape[0] == 1:
67
- audio = audio.repeat(2, 1)
 
 
 
 
 
 
68
 
69
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
70
  torchaudio.save(tmp.name, audio, sample_rate)
71
- segment = AudioSegment.from_wav(tmp.name)
72
  os.unlink(tmp.name)
73
- segments.append(segment)
74
-
75
- clean_resources()
76
 
77
- final = segments[0]
78
  for seg in segments[1:]:
79
- final = final.append(seg, crossfade=crossfade)
80
-
81
- final = final[:total_len * 1000].fade_in(1000).fade_out(1000).normalize(headroom=-9.0)
82
 
83
  out_path = "output_cleaned.mp3"
84
- final.export(out_path, format="mp3", bitrate="128k", tags={"title": "GhostAI Track", "artist": "GhostAI"})
85
-
 
86
  return out_path, "βœ… Done!"
87
 
88
- demo = gr.Interface(
89
- fn=generate_music,
90
- inputs=[
91
- gr.Textbox(label="Instrumental Prompt"),
92
- gr.Slider(1.0, 10.0, value=3.0, step=0.1, label="CFG Scale"),
93
- gr.Slider(10, 500, value=250, step=10, label="Top-K"),
94
- gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-P"),
95
- gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature"),
96
- gr.Radio([30, 60, 90, 120], value=30, label="Length (seconds)"),
97
- gr.Slider(5, 15, value=10, step=1, label="Chunk Length (seconds)"),
98
- gr.Slider(100, 2000, value=1000, step=100, label="Crossfade (ms)")
99
- ],
100
- outputs=[
101
- gr.Audio(label="Generated Music", type="filepath"),
102
- gr.Textbox(label="Status")
103
- ],
104
- title="πŸ‘» GhostAI Music Generator",
105
- description="Generate instrumental music using MusicGen Medium model."
106
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  demo.launch(share=False)
 
1
  #!/usr/bin/env python3
2
+ # GhostAI Music Generator – HF-download version
3
+ import os, sys, gc, time, warnings, random, tempfile
4
+ import torch, torchaudio, numpy as np, gradio as gr, psutil
 
 
 
 
 
 
 
 
 
 
5
  from pydub import AudioSegment
6
+ from torch.cuda.amp import autocast
7
  from audiocraft.models import MusicGen
8
  from huggingface_hub import login
9
 
10
# ------------------------------------------------------------------ #
# Torch < 2.3 shim (adds get_default_device)                         #
# ------------------------------------------------------------------ #
if not hasattr(torch, "get_default_device"):

    def _default_device():
        """Return the CUDA device when CUDA is available, else the CPU device."""
        kind = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(kind)

    # Installed only when this torch build does not provide the attribute.
    torch.get_default_device = _default_device

warnings.filterwarnings("ignore")
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# ------------------------------------------------------------------ #
# Login to HF (model download)                                       #
# ------------------------------------------------------------------ #
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # Abort start-up: the script refuses to run without a token.
    sys.exit("ERROR: environment variable HF_TOKEN not set in the Space settings.")
login(HF_TOKEN)

# ------------------------------------------------------------------ #
# Device setup                                                       #
# ------------------------------------------------------------------ #
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Running on {device.upper()}")
if device == "cuda":
    print(f"GPU : {torch.cuda.get_device_name(0)}")
36
def clean():
    """Release memory that is no longer referenced.

    The Python garbage collector runs first so that objects kept alive only
    by reference cycles are destroyed; on CUDA the allocator cache is emptied
    afterwards, so blocks freed by that collection are released as well.
    On CPU only the garbage collection happens.
    """
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
40
clean()

# ------------------------------------------------------------------ #
# 📥 Load MusicGen from HF Hub                                        #
# ------------------------------------------------------------------ #
print("Downloading & loading facebook/musicgen-medium …")
musicgen_model = MusicGen.get_pretrained("facebook/musicgen-medium", device=device)
# Start-up defaults; generate_music() sets its own parameters per request.
musicgen_model.set_generation_params(duration=10, two_step_cfg=False)

# Used when the generated tensors are written out as WAV.
sample_rate = musicgen_model.sample_rate
50
+
51
# ------------------------------------------------------------------ #
# 📊 Helpers                                                          #
# ------------------------------------------------------------------ #
def vram_ok(req=3.5):
    """Return True when at least ``req`` GiB of GPU memory is still usable.

    On CPU the check is skipped and the result is always True.

    Usable memory is what the driver reports as free on the current device
    plus the part of this process's allocator cache that is reserved but not
    holding live tensors. Unlike ``total - memory_allocated()``, this also
    accounts for memory taken by other processes and by the CUDA context.

    Args:
        req: Required free memory in GiB.

    Returns:
        bool: True if enough memory is available, False otherwise (a warning
        is printed in that case).
    """
    if device != "cuda":
        return True

    gib = 1024 ** 3
    driver_free, _driver_total = torch.cuda.mem_get_info()
    # Cached blocks without live tensors can be reused without a new allocation.
    reusable = torch.cuda.memory_reserved() - torch.cuda.memory_allocated()
    free = (driver_free + max(reusable, 0)) / gib

    if free < req:
        print(f"⚠️ Only {free:.2f} GB free (< {req} GB).")
        return False
    return True
62
+
63
def log(stage=""):
    """Print a memory snapshot, optionally preceded by a stage header.

    Args:
        stage: Label printed as a header line; nothing is printed for the
            header when it is empty.

    On CUDA the allocated and reserved GPU memory (GiB) are printed; the
    system RAM usage percentage from ``psutil`` is printed in every case.
    """
    if stage:
        print(f"── {stage} ──")
    if device == "cuda":
        gib = 1024 ** 3
        allocated = torch.cuda.memory_allocated() / gib
        reserved = torch.cuda.memory_reserved() / gib
        print(f"GPU mem alloc {allocated:.2f} GB reserved {reserved:.2f} GB")
    print(f"CPU mem {psutil.virtual_memory().percent}% used")
70
+
71
+ # ------------------------------------------------------------------ #
72
+ # πŸŽ› Prompt builders (unchanged) #
73
+ # ------------------------------------------------------------------ #
74
+ def _p(base,bpm,dr,syn,st,bass,gtr,dflt_bass,dflt_gtr,vibe):
75
+ step = f" with {st}" if st!="none" else vibe.format(bpm=bpm)
76
+ dr = f", {dr} drums" if dr!="none" else ""
77
+ syn = f", {syn} accents" if syn!="none" else ""
78
+ bass = f", {bass}" if bass!="none" else dflt_bass
79
+ gtr = f", {gtr} guitar riffs" if gtr!="none" else dflt_gtr
80
+ return f"{base}{bass}{gtr}{dr}{syn}{step} at {bpm} BPM."
81
+
82
def set_red_hot_chili_peppers_prompt(bpm, dr, syn, st, bass, gtr):
    """Build the funk-rock prompt by delegating to ``_p``.

    The fallback vibe phrase depends on tempo: above 120 BPM the
    ``"{bpm} BPM funky flow"`` template is used, otherwise
    ``"groovy rhythmic flow"``. The remaining arguments are forwarded
    unchanged together with the genre's default bass and guitar fragments.
    """
    if bpm > 120:
        vibe = "{bpm} BPM funky flow"
    else:
        vibe = "groovy rhythmic flow"
    default_bass = ", groovy basslines"
    default_guitar = ", syncopated guitar riffs"
    return _p(
        "Instrumental funk rock", bpm, dr, syn, st, bass, gtr,
        default_bass, default_guitar, vibe,
    )
86
+ # … keep your other 17 prompt functions exactly as before …
87
+
88
# ------------------------------------------------------------------ #
# 🎚 Audio post-processing                                            #
# ------------------------------------------------------------------ #
def apply_eq(seg):
    """Return ``seg`` after ``low_pass_filter(8000)`` then ``high_pass_filter(80)``."""
    low_passed = seg.low_pass_filter(8000)
    band_limited = low_passed.high_pass_filter(80)
    return band_limited
92
def apply_fade(seg):
    """Return ``seg`` after ``fade_in(1000)`` followed by ``fade_out(1000)``."""
    faded_in = seg.fade_in(1000)
    return faded_in.fade_out(1000)
93
+
94
# ------------------------------------------------------------------ #
# 🚀 Generator                                                        #
# ------------------------------------------------------------------ #
def _tensor_to_segment(audio):
    """Convert one generated waveform tensor into a stereo pydub segment.

    The tensor is moved to CPU as float32, forced to two channels (a mono or
    1-D signal is duplicated; for any other channel count the first channel
    is duplicated), written to a temporary WAV file and read back with pydub.
    The temporary file is removed even when saving or loading fails.
    """
    audio = audio.detach().cpu().float()
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    if audio.shape[0] != 2:
        audio = audio[:1].repeat(2, 1)

    # Reserve a path, then close the handle so the writer can open it itself.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        wav_path = tmp.name
    try:
        torchaudio.save(wav_path, audio, sample_rate)
        return AudioSegment.from_wav(wav_path)
    finally:
        os.unlink(wav_path)


def generate_music(prompt, cfg, k, p, temp,
                   total_len, chunk_len, xfade,
                   bpm, dr, syn, step, bass, gtr):
    """Generate a track for ``prompt`` in chunks and export it as an MP3.

    The requested length is split into equal chunks, each chunk is generated
    separately, the chunks are joined with a cross-fade, and the result is
    trimmed, equalised, normalised, faded and written to
    ``output_cleaned.mp3``.

    Args:
        prompt: Text description of the music; empty or missing aborts early.
        cfg: Classifier-free guidance coefficient (``cfg_coef``).
        k: Top-k sampling value.
        p: Top-p sampling value.
        temp: Sampling temperature.
        total_len: Requested track length in seconds.
        chunk_len: Preferred chunk length in seconds; clamped to 5..15 and
            then adjusted so the chunks divide ``total_len`` evenly.
        xfade: Cross-fade between chunks in milliseconds.
        bpm, dr, syn, step, bass, gtr: Accepted because the Generate button
            passes them; they are not used by the generation itself.

    Returns:
        tuple: ``(path, status)`` where ``path`` is the MP3 path, or ``None``
        when nothing was generated, and ``status`` is a user-facing message.
    """
    if not prompt or not prompt.strip():
        return None, "❌ Prompt is empty."
    if not vram_ok():
        return None, "❌ Not enough VRAM."

    total_len = int(total_len)
    chunk_len = max(5, min(int(chunk_len), 15))
    n_chunks = max(1, total_len // chunk_len)
    chunk_len = total_len / n_chunks

    # Each join consumes the whole cross-fade, so every chunk is rendered
    # longer by exactly that amount; otherwise a cross-fade above one second
    # would leave the finished track shorter than requested.
    xfade_ms = max(0, int(xfade))
    render_len = chunk_len + xfade_ms / 1000.0

    # Fixed seeds: the same prompt and settings reproduce the same track.
    torch.manual_seed(42)
    np.random.seed(42)

    # The sampling settings are the same for every chunk, so set them once.
    musicgen_model.set_generation_params(
        duration=render_len, use_sampling=True,
        top_k=int(k), top_p=p, temperature=temp, cfg_coef=cfg)

    segments = []
    t0 = time.time()
    for i in range(n_chunks):
        log(f"before chunk {i+1}")
        # CUDA autocast is only meaningful on a CUDA device.
        with torch.no_grad(), autocast(enabled=(device == "cuda")):
            audio = musicgen_model.generate([prompt], progress=False)[0]
        segments.append(_tensor_to_segment(audio))
        clean()
        log(f"after chunk {i+1}")

    track = segments[0]
    for seg in segments[1:]:
        track = track.append(seg, crossfade=xfade_ms)
    track = track[: total_len * 1000]
    # pydub's headroom is the distance kept BELOW full scale; the previous
    # negative value asked for a peak above full scale, which clips.
    track = apply_fade(apply_eq(track).normalize(headroom=9.0))

    out_path = "output_cleaned.mp3"
    track.export(out_path, format="mp3", bitrate="128k",
                 tags={"title": "GhostAI Track", "artist": "GhostAI"})
    log("final")
    print(f"⏱ {time.time()-t0:.1f}s total")
    return out_path, "✅ Done!"
149
 
150
def clear_inputs():
    """Return the reset value for every input widget.

    The order matches the output list of the Clear button: prompt, the seven
    generation controls, BPM, then the five style dropdowns.
    """
    generation_defaults = ("", 3.0, 250, 0.9, 1.0, 30, 10, 1000)
    style_defaults = (120,) + ("none",) * 5
    return generation_defaults + style_defaults
153
+
154
# ------------------------------------------------------------------ #
# 🎨 Custom CSS                                                       #
# ------------------------------------------------------------------ #
# One rule per entry; the empty first and last entries reproduce the newline
# that follows the opening and precedes the closing of the stylesheet text.
css = "\n".join([
    "",
    "body{background:linear-gradient(135deg,#0A0A0A 0%,#1C2526 100%);color:#E0E0E0;font-family:'Orbitron',sans-serif}",
    ".header{padding:10px;text-align:center;background:rgba(0,0,0,.9);border-bottom:1px solid #00FF9F}",
    "",
])
161
+
162
# ------------------------------------------------------------------ #
# 🖼 Gradio Blocks UI                                                 #
# ------------------------------------------------------------------ #
with gr.Blocks(css=css) as demo:
    gr.HTML('<div class="header"><h1>👻 GhostAI Music Generator</h1></div>')
    prompt_box = gr.Textbox(label="Instrumental Prompt", lines=4)

    # Genre buttons (showing two; add the rest as needed). They are created
    # here for layout; their handlers are attached further down, once the
    # widgets they read from exist.
    with gr.Row():
        rhcp_btn = gr.Button("RHCP 🌶️")
        nirvana_btn = gr.Button("Nirvana 🎸")

    # Generation parameters.
    cfg_scale = gr.Slider(1, 10, 3, label="CFG Scale")
    top_k = gr.Slider(10, 500, 250, step=10, label="Top-K")
    top_p = gr.Slider(0, 1, 0.9, step=0.05, label="Top-P")
    temp = gr.Slider(0.1, 2, 1, step=0.1, label="Temperature")
    total_len = gr.Radio([30, 60, 90, 120], value=30, label="Length (s)")
    chunk_len = gr.Slider(5, 15, 10, step=1, label="Chunk (s)")
    crossfade = gr.Slider(100, 2000, 1000, step=100, label="Cross-fade (ms)")

    # Prompt-builder options; "none" means "not selected".
    bpm = gr.Slider(60, 180, 120, label="BPM")
    drum = gr.Dropdown(["none", "standard rock", "funk groove", "techno kick", "jazz swing"], value="none", label="Drum")
    synth = gr.Dropdown(["none", "analog synth", "digital pad", "arpeggiated synth"], value="none", label="Synth")
    step = gr.Dropdown(["none", "syncopated steps", "steady steps", "complex steps"], value="none", label="Steps")
    bass = gr.Dropdown(["none", "slap bass", "deep bass", "melodic bass"], value="none", label="Bass")
    gtr = gr.Dropdown(["none", "distorted", "clean", "jangle"], value="none", label="Guitar")

    gen = gr.Button("Generate 🎼")
    clr = gr.Button("Clear 🧹")
    audio_out = gr.Audio(type="filepath")
    status = gr.Textbox(interactive=False)

    builder_inputs = [bpm, drum, synth, step, bass, gtr]
    all_inputs = [prompt_box, cfg_scale, top_k, top_p, temp,
                  total_len, chunk_len, crossfade, *builder_inputs]

    # Event inputs must be components, not bare strings. Reading the
    # prompt-builder widgets gives the previous fixed values (120 / "none")
    # at their defaults and lets the user's selections shape the prompt.
    rhcp_btn.click(
        set_red_hot_chili_peppers_prompt,
        inputs=builder_inputs,
        outputs=prompt_box,
    )
    # NOTE(review): set_nirvana_grunge_prompt is not defined in the visible
    # part of this file; it is presumably one of the prompt functions kept
    # from the previous version. Confirm it exists before deploying.
    nirvana_btn.click(
        set_nirvana_grunge_prompt,
        inputs=builder_inputs,
        outputs=prompt_box,
    )

    gen.click(generate_music, inputs=all_inputs, outputs=[audio_out, status])
    clr.click(clear_inputs, None, all_inputs)

demo.launch(share=False)