#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
GhostAI Music Generator — Zero-GPU & GPU friendly
Python  : 3.10
Torch   : 2.1 CPU wheels (or CUDA 11.8/12.1)
Gradio  : 5.31.0
Last updated: 2025-05-29
"""

import os
import sys
import gc
import time
import random
import warnings
import tempfile

import psutil
import numpy as np
import torch
import torchaudio
import gradio as gr
from pydub import AudioSegment
from torch.cuda.amp import autocast
from audiocraft.models import MusicGen
from huggingface_hub import login

# ----------------------------------------------------------------------
# Compatibility shim (torch < 2.3)
# ----------------------------------------------------------------------
if not hasattr(torch, "get_default_device"):
    torch.get_default_device = lambda: torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

# ----------------------------------------------------------------------
# Silence warnings & CUDA fragmentation tuning
# ----------------------------------------------------------------------
warnings.filterwarnings("ignore")
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# ----------------------------------------------------------------------
# Hugging Face authentication
# ----------------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("ERROR: environment variable HF_TOKEN not set.")
    sys.exit(1)
try:
    login(HF_TOKEN)
except Exception as e:
    print(f"ERROR: Hugging Face login failed: {e}")
    sys.exit(1)

# ----------------------------------------------------------------------
# Device setup & cleanup
# ----------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on {device.upper()}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

def gpu_clean():
    if device == "cuda":
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

gpu_clean()

# ----------------------------------------------------------------------
# Load MusicGen model (fixed checkpoint name)
# ----------------------------------------------------------------------
print("Loading MusicGen ‘medium’ checkpoint …")
musicgen = MusicGen.get_pretrained("medium", device=device)
musicgen.set_generation_params(duration=10, two_step_cfg=False)

# ----------------------------------------------------------------------
# Resource monitoring
# ----------------------------------------------------------------------
def log_resources(stage=""):
    if stage:
        print(f"--- {stage} ---")
    if device == "cuda":
        alloc = torch.cuda.memory_allocated() / 1024**3
        resv  = torch.cuda.memory_reserved()  / 1024**3
        print(f"GPU Mem | Alloc {alloc:.2f} GB  Reserved {resv:.2f} GB")
    print(f"CPU Mem | {psutil.virtual_memory().percent}% used")

def vram_ok(threshold=3.5):
    if device != "cuda":
        return True
    total = torch.cuda.get_device_properties(0).total_memory / 1024**3
    free  = total - torch.cuda.memory_allocated() / 1024**3
    if free < threshold:
        print(f"WARNING: Only {free:.2f} GB VRAM free (<{threshold} GB).")
    return free >= threshold

# ----------------------------------------------------------------------
# Prompt builders
# ----------------------------------------------------------------------
def _make_prompt(base, bpm, drum, synth, steps, bass, gtr, def_bass, def_gtr, flow):
    """Assemble a MusicGen text prompt from the UI selections ("none" falls back to a default)."""
    step_txt  = f" with {steps}" if steps != "none" else f" with {flow.format(bpm=bpm)}"
    drum_txt  = f", {drum} drums"       if drum  != "none" else ""
    synth_txt = f", {synth} accents"    if synth != "none" else ""
    bass_txt  = f", {bass}"             if bass  != "none" else def_bass
    gtr_txt   = f", {gtr} guitar riffs" if gtr   != "none" else def_gtr
    return f"{base}{bass_txt}{gtr_txt}{drum_txt}{synth_txt}{step_txt} at {bpm} BPM."

def set_red_hot_chili_peppers_prompt(bpm, drum, synth, steps, bass, gtr):
    return _make_prompt(
        "Instrumental funk rock", bpm, drum, synth, steps, bass, gtr,
        ", groovy basslines", ", syncopated guitar riffs",
        "{bpm} BPM funky flow" if bpm > 120 else "groovy rhythmic flow"
    )

# … include the other set_*_prompt functions exactly as before …
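# NOTE: the remaining genre prompt builders are elided above, but the UI below wires up
# set_nirvana_grunge_prompt, so a minimal sketch is included here so the script runs as-is.
# The genre wording and defaults are assumptions that mirror the RHCP builder, not the
# original definitions.
def set_nirvana_grunge_prompt(bpm, drum, synth, steps, bass, gtr):
    return _make_prompt(
        "Instrumental grunge rock", bpm, drum, synth, steps, bass, gtr,
        ", heavy distorted basslines", ", raw power-chord guitar riffs",
        "{bpm} BPM driving flow" if bpm > 120 else "moody grunge flow"
    )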

# ----------------------------------------------------------------------
# Audio post-processing
# ----------------------------------------------------------------------
def apply_eq(seg: AudioSegment):
    return seg.low_pass_filter(8000).high_pass_filter(80)

def apply_fade(seg: AudioSegment, fin=1000, fout=1000):
    return seg.fade_in(fin).fade_out(fout)

# ----------------------------------------------------------------------
# Core generation
# ----------------------------------------------------------------------
def generate_music(
    prompt, cfg, top_k, top_p, temp,
    total_len, chunk_len, crossfade,
    bpm, drum, synth, steps, bass, gtr
):
    if not prompt.strip():
        return None, "⚠️ Prompt cannot be empty."
    if not vram_ok():
        return None, "⚠️ Insufficient VRAM."

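    # Split the requested length into roughly equal chunks (5-15 s each) and render every
    # chunk slightly long so neighbouring chunks overlap by up to 1 s for the crossfade.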
    total_len  = int(total_len)
    chunk_len  = int(max(5, min(chunk_len, 15)))
    n_chunks   = max(1, total_len // chunk_len)
    chunk_len  = total_len / n_chunks
    overlap    = min(1.0, crossfade / 1000.0)
    render_len = chunk_len + overlap
    sr         = musicgen.sample_rate
    segments   = []

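    # Fixed seeds keep repeated runs with the same prompt and settings reproducible.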
    torch.manual_seed(42)
    np.random.seed(42)

    start = time.time()
    for i in range(n_chunks):
        log_resources(f"Before chunk {i+1}")
        musicgen.set_generation_params(
            duration=render_len,
            use_sampling=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temp,
            cfg_coef=cfg
        )
        # Autocast only when CUDA is in use so the CPU path stays in plain float32.
        with torch.no_grad(), autocast(enabled=(device == "cuda")):
            audio = musicgen.generate([prompt], progress=False)[0]

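        # MusicGen returns (channels, samples); force a stereo (2, T) layout for the WAV export.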
        audio = audio.cpu().to(torch.float32)
        if audio.dim() == 1:
            audio = torch.stack([audio, audio])
        elif audio.shape[0] == 1:
            audio = torch.cat([audio, audio], dim=0)
        elif audio.shape[0] != 2:
            audio = torch.cat([audio[:1], audio[:1]], dim=0)

        # Round-trip through a temporary WAV; write and read after the handle closes so this
        # also works on platforms that lock open files (e.g. Windows).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_path = tmp.name
        torchaudio.save(tmp_path, audio, sr, bits_per_sample=24)
        seg = AudioSegment.from_wav(tmp_path)
        os.unlink(tmp_path)
        segments.append(seg)

        gpu_clean()
        log_resources(f"After chunk {i+1}")

    # Stitch the chunks with a crossfade (each later chunk is nudged up 1 dB), trim to the
    # requested length, then EQ, normalize with 9 dB of headroom, and fade in/out.
    final = segments[0]
    for seg in segments[1:]:
        final = final.append(seg + 1, crossfade=crossfade)
    final = final[: total_len * 1000]
    final = apply_fade(apply_eq(final).normalize(headroom=9.0))

    out_path = "output_cleaned.mp3"
    final.export(
        out_path,
        format="mp3",
        bitrate="128k",
        tags={"title": "GhostAI Instrumental", "artist": "GhostAI"}
    )
    log_resources("After final")
    print(f"Total generation time: {time.time() - start:.2f}s")
    return out_path, "✅ Done!"

def clear_inputs():
    return (
        "", 3.0, 250, 0.9, 1.0,
        30, 10, 1000,
        120, "none", "none", "none", "none", "none"
    )

# ----------------------------------------------------------------------
# Gradio UI
# ----------------------------------------------------------------------
css = """
body {
  background: linear-gradient(135deg, #0A0A0A 0%, #1C2526 100%);
  color: #E0E0E0; font-family: 'Orbitron', sans-serif;
}
.header {
  text-align: center; padding: 10px; background: rgba(0,0,0,0.9);
  border-bottom: 1px solid #00FF9F;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML('<div class="header"><h1>👻 GhostAI Music Generator</h1></div>')

    prompt_box = gr.Textbox(label="Instrumental Prompt ✍️", lines=4)
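    # Genre preset buttons: each fills the prompt box from fixed defaults passed via gr.State
    # (the sliders/dropdowns further down affect generation, not these preset prompts).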
    with gr.Row():
        gr.Button("RHCP 🌶️").click(
            set_red_hot_chili_peppers_prompt,
            inputs=[gr.State(120), gr.State("none"), gr.State("none"),
                    gr.State("none"), gr.State("none"), gr.State("none")],
            outputs=prompt_box
        )
        gr.Button("Nirvana 🎸").click(
            set_nirvana_grunge_prompt,
            inputs=[gr.State(120), gr.State("none"), gr.State("none"),
                    gr.State("none"), gr.State("none"), gr.State("none")],
            outputs=prompt_box
        )
        # … add the other genre buttons in the same pattern …

    with gr.Group():
        cfg_scale   = gr.Slider(1.0, 10.0, value=3.0, step=0.1, label="CFG Scale")
        top_k       = gr.Slider(10, 500, value=250, step=10, label="Top-K")
        top_p       = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-P")
        temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
        total_len   = gr.Radio([30, 60, 90, 120], value=30, label="Length (s)")
        chunk_len   = gr.Slider(5, 15, value=10, step=1, label="Chunk (s)")
        crossfade   = gr.Slider(100, 2000, value=1000, step=100, label="Crossfade (ms)")

    bpm           = gr.Slider(60, 180, value=120, label="Tempo (BPM)")
    drum_beat     = gr.Dropdown(
        ["none","standard rock","funk groove","techno kick","jazz swing"],
        value="none", label="Drum Beat"
    )
    synthesizer   = gr.Dropdown(
        ["none","analog synth","digital pad","arpeggiated synth"],
        value="none", label="Synthesizer"
    )
    steps         = gr.Dropdown(
        ["none","syncopated steps","steady steps","complex steps"],
        value="none", label="Rhythmic Steps"
    )
    bass_style    = gr.Dropdown(
        ["none","slap bass","deep bass","melodic bass"],
        value="none", label="Bass Style"
    )
    guitar_style  = gr.Dropdown(
        ["none","distorted","clean","jangle"],
        value="none", label="Guitar Style"
    )

    gen_btn   = gr.Button("Generate Music 🚀")
    clr_btn   = gr.Button("Clear 🧹")
    out_audio = gr.Audio(label="Generated Track", type="filepath")
    status    = gr.Textbox(label="Status", interactive=False)

    gen_btn.click(
        generate_music,
        inputs=[
            prompt_box, cfg_scale, top_k, top_p, temperature,
            total_len, chunk_len, crossfade,
            bpm, drum_beat, synthesizer, steps, bass_style, guitar_style
        ],
        outputs=[out_audio, status]
    )
    clr_btn.click(
        clear_inputs, None,
        [
            prompt_box, cfg_scale, top_k, top_p, temperature,
            total_len, chunk_len, crossfade,
            bpm, drum_beat, synthesizer, steps, bass_style, guitar_style
        ]
    )

# Disable the FastAPI docs endpoints up front via app_kwargs; mutating the app after
# launch() (which blocks, and where demo._server does not exist) never took effect.
demo.launch(
    share=False,
    show_error=True,
    app_kwargs={"docs_url": None, "redoc_url": None, "openapi_url": None},
)