Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
GhostAI Music Generator — runs on CPU-only (zero-GPU) or CUDA GPU hosts
Python : 3.10
Torch  : 2.1 CPU wheels (or CUDA 11.8/12.1)
Gradio : 5.31.0
Last updated: 2025-05-29
"""
import os | |
import sys | |
import gc | |
import time | |
import random | |
import warnings | |
import tempfile | |
import psutil | |
import numpy as np | |
import torch | |
import torchaudio | |
import gradio as gr | |
from pydub import AudioSegment | |
from torch.cuda.amp import autocast | |
from audiocraft.models import MusicGen | |
from huggingface_hub import login | |
# ----------------------------------------------------------------------
# Compatibility shim (torch < 2.3)
# ----------------------------------------------------------------------
if not hasattr(torch, "get_default_device"):
    def _get_default_device():
        """Backport of ``torch.get_default_device`` (added in torch 2.3).

        Report the device a freshly created tensor lands on. That matches
        the real function (cpu unless the default device was changed).
        Returning "cuda whenever a GPU exists" would disagree with where
        tensors are actually allocated.
        """
        return torch.empty(0).device

    torch.get_default_device = _get_default_device
# ----------------------------------------------------------------------
# Silence warnings & CUDA fragmentation tuning
# ----------------------------------------------------------------------
# Cap the CUDA allocator's block split size to limit fragmentation.
# NOTE(review): presumably read when the CUDA allocator first initialises —
# keep this before any CUDA work.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128")
# Hide every warning category process-wide.
warnings.simplefilter("ignore")
# ----------------------------------------------------------------------
# Hugging Face authentication
# ----------------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception as e:
        # A token was supplied but rejected: this is a misconfiguration,
        # so stop instead of silently running with the wrong credentials.
        print(f"ERROR: Hugging Face login failed: {e}")
        sys.exit(1)
else:
    # The MusicGen checkpoint this app loads is a public repo. Anonymous
    # downloads work, so a missing token should not crash the app.
    print("WARNING: environment variable HF_TOKEN not set; "
          "continuing without Hugging Face login.")
# ----------------------------------------------------------------------
# Device setup & cleanup
# ----------------------------------------------------------------------
# Default to the CPU and upgrade to CUDA only when a GPU is visible.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Running on {device.upper()}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
def gpu_clean():
    """Free unreachable Python objects and, on CUDA, return cached GPU memory.

    ``gc.collect()`` runs first. Tensors kept alive only by reference cycles
    are released before ``torch.cuda.empty_cache()`` asks the caching
    allocator to give its unused blocks back. In the reverse order, memory
    freed by the collector would stay cached until the next call.
    """
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()


gpu_clean()
# ----------------------------------------------------------------------
# Load MusicGen model
# ----------------------------------------------------------------------
# NOTE(review): recent audiocraft releases also accept the full Hub id
# ("facebook/musicgen-medium"); confirm against the installed version
# before switching.
_CHECKPOINT = "medium"
print(f"Loading MusicGen ‘{_CHECKPOINT}’ checkpoint …")
musicgen = MusicGen.get_pretrained(_CHECKPOINT, device=device)
# Initial generation settings; generate_music() overrides them per request.
musicgen.set_generation_params(two_step_cfg=False, duration=10)
# ----------------------------------------------------------------------
# Resource monitoring
# ----------------------------------------------------------------------
def log_resources(stage=""):
    """Print a memory snapshot, preceded by a ``--- stage ---`` banner if given.

    On CUDA, PyTorch's allocated and reserved GPU memory (GB) is printed.
    System RAM usage from psutil is always printed.
    """
    if stage:
        print(f"--- {stage} ---")
    if device == "cuda":
        gib = 1024 ** 3
        print(
            f"GPU Mem | Alloc {torch.cuda.memory_allocated() / gib:.2f} GB "
            f"Reserved {torch.cuda.memory_reserved() / gib:.2f} GB"
        )
    print(f"CPU Mem | {psutil.virtual_memory().percent}% used")
def vram_ok(threshold=3.5):
    """Return True when at least *threshold* GB of VRAM appears free.

    Always True off-GPU. "Free" is estimated as device 0's total memory
    minus what PyTorch currently has allocated. A warning is printed when
    the estimate is below the threshold.
    """
    if device != "cuda":
        return True
    gib = 1024 ** 3
    total_gb = torch.cuda.get_device_properties(0).total_memory / gib
    free_gb = total_gb - torch.cuda.memory_allocated() / gib
    enough = free_gb >= threshold
    if not enough:
        print(f"WARNING: Only {free_gb:.2f} GB VRAM free (<{threshold} GB).")
    return enough
# ---------------------------------------------------------------------- | |
# Prompt builders | |
# ---------------------------------------------------------------------- | |
def _make_prompt(base, bpm, drum, synth, steps, bass, gtr, def_bass, def_gtr, flow): | |
step_txt = f" with {steps}" if steps != "none" else flow.format(bpm=bpm) | |
drum_txt = f", {drum} drums" if drum != "none" else "" | |
synth_txt = f", {synth} accents" if synth != "none" else "" | |
bass_txt = f", {bass}" if bass != "none" else def_bass | |
gtr_txt = f", {gtr} guitar riffs"if gtr != "none" else def_gtr | |
return f"{base}{bass_txt}{gtr_txt}{drum_txt}{synth_txt}{step_txt} at {bpm} BPM." | |
def set_red_hot_chili_peppers_prompt(bpm, drum, synth, steps, bass, gtr):
    """Build a funk-rock prompt from the current tempo and style selections.

    Above 120 BPM the default rhythm mentions the tempo explicitly;
    otherwise a generic groovy flow is used.
    """
    flow = "{bpm} BPM funky flow" if bpm > 120 else "groovy rhythmic flow"
    return _make_prompt(
        "Instrumental funk rock", bpm, drum, synth, steps, bass, gtr,
        ", groovy basslines", ", syncopated guitar riffs", flow,
    )


def set_nirvana_grunge_prompt(bpm, drum, synth, steps, bass, gtr):
    """Build a grunge prompt from the current tempo and style selections.

    The UI's "Nirvana" preset button calls this. The preset wording may
    need tuning to taste.
    """
    flow = "{bpm} BPM driving grunge rhythm" if bpm > 120 else "slow heavy grunge groove"
    return _make_prompt(
        "Instrumental grunge", bpm, drum, synth, steps, bass, gtr,
        ", heavy distorted bass", ", raw distorted guitar riffs", flow,
    )


# TODO: add the remaining genre presets (set_*_prompt) here, built with _make_prompt.
# ----------------------------------------------------------------------
# Audio post-processing
# ----------------------------------------------------------------------
def apply_eq(seg: AudioSegment):
    """Band-limit *seg*: low-pass at 8000 Hz, then high-pass at 80 Hz."""
    filtered = seg.low_pass_filter(8000)
    return filtered.high_pass_filter(80)
def apply_fade(seg: AudioSegment, fin=1000, fout=1000):
    """Return *seg* with a *fin* ms fade-in followed by a *fout* ms fade-out."""
    faded_in = seg.fade_in(fin)
    return faded_in.fade_out(fout)
# ----------------------------------------------------------------------
# Core generation
# ----------------------------------------------------------------------
def _to_stereo(audio):
    """Return *audio* as a float32 CPU tensor with exactly two channels.

    1-D or single-channel input is duplicated into both channels. Input with
    any other channel count other than two keeps only its first channel,
    duplicated.
    """
    audio = audio.detach().cpu().to(torch.float32)
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    if audio.shape[0] != 2:
        audio = torch.cat([audio[:1], audio[:1]], dim=0)
    return audio


def _tensor_to_segment(audio, sr):
    """Convert a (channels, samples) tensor to a pydub AudioSegment via a 24-bit WAV.

    The temporary WAV is always removed, even if saving or decoding fails.
    """
    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        torchaudio.save(wav_path, audio, sr, bits_per_sample=24)
        return AudioSegment.from_wav(wav_path)
    finally:
        os.unlink(wav_path)


def generate_music(
    prompt, cfg, top_k, top_p, temp,
    total_len, chunk_len, crossfade,
    bpm, drum, synth, steps, bass, gtr
):
    """Render an instrumental track for *prompt*; return ``(mp3_path, status)``.

    The track is produced by independent MusicGen calls of about
    ``chunk_len`` seconds each. Every chunk is rendered ``crossfade`` ms
    longer than its share, so the crossfaded joins still add up to
    ``total_len`` seconds. The stitched audio is then trimmed, EQ'd,
    peak-normalised to -9 dBFS, faded, and exported as a 128 kbps MP3.
    Each call writes a uniquely named file in the system temp directory.

    Args:
        prompt: Text prompt; None or whitespace-only is rejected.
        cfg, top_k, top_p, temp: MusicGen guidance and sampling settings.
            NOTE(review): audiocraft's sampler appears to use top_p instead of
            top_k whenever top_p > 0 — confirm for the installed version.
        total_len: Requested length in seconds; must be positive.
        chunk_len: Target chunk length in seconds, clamped to 5-15.
        crossfade: Crossfade between chunks in ms, clamped to 0-2000.
        bpm, drum, synth, steps, bass, gtr: Passed by the UI. They are not
            used here because the prompt text already carries them.

    Returns:
        ``(path, "✅ Done!")`` on success, or ``(None, warning)`` when the
        input is rejected or not enough VRAM is free.
    """
    if not prompt or not prompt.strip():
        return None, "⚠️ Prompt cannot be empty."
    total_len = int(total_len)
    if total_len <= 0:
        return None, "⚠️ Length must be positive."
    if not vram_ok():
        return None, "⚠️ Insufficient VRAM."

    chunk_len = int(max(5, min(chunk_len, 15)))
    n_chunks = max(1, total_len // chunk_len)
    chunk_len = total_len / n_chunks  # spread the length evenly over chunks
    # Each pydub crossfade consumes exactly `crossfade` ms at a join. The
    # extra render length must match it, or the stitched track comes out
    # short.
    crossfade = max(0, min(int(crossfade), 2000))
    render_len = chunk_len + crossfade / 1000.0
    sr = musicgen.sample_rate

    # Sampling settings are the same for every chunk, so set them once.
    musicgen.set_generation_params(
        duration=render_len,
        use_sampling=True,
        top_k=int(top_k),
        top_p=float(top_p),
        temperature=float(temp),
        cfg_coef=float(cfg),
    )

    # Fixed seeds: identical settings reproduce the identical track.
    torch.manual_seed(42)
    np.random.seed(42)
    start = time.time()

    segments = []
    for i in range(n_chunks):
        log_resources(f"Before chunk {i + 1}")
        # Mixed precision only on CUDA; on CPU autocast stays disabled.
        with torch.no_grad(), torch.autocast(device_type="cuda", enabled=device == "cuda"):
            audio = musicgen.generate([prompt], progress=False)[0]
        segments.append(_tensor_to_segment(_to_stereo(audio), sr))
        del audio
        gpu_clean()
        log_resources(f"After chunk {i + 1}")

    # Chunks are joined at their natural level; boosting later chunks would
    # create an audible jump at the first join.
    final = segments[0]
    for seg in segments[1:]:
        final = final.append(seg, crossfade=crossfade)
    final = final[: total_len * 1000]
    # pydub's headroom is the distance *below* full scale. A negative value
    # would push the peak above 0 dBFS and hard-clip.
    final = apply_fade(apply_eq(final).normalize(headroom=9.0))

    # Unique file per request, so concurrent users never overwrite each other.
    fd, out_path = tempfile.mkstemp(prefix="ghostai_", suffix=".mp3")
    os.close(fd)
    # export() returns the open output file; close it explicitly.
    final.export(
        out_path,
        format="mp3",
        bitrate="128k",
        tags={"title": "GhostAI Instrumental", "artist": "GhostAI"},
    ).close()

    log_resources("After final")
    print(f"Total generation time: {time.time() - start:.2f}s")
    return out_path, "✅ Done!"
def clear_inputs():
    """Return the UI's default values, ordered like the Clear button's outputs."""
    sampling = ("", 3.0, 250, 0.9, 1.0)   # prompt, CFG, top-k, top-p, temperature
    timing = (30, 10, 1000)               # length (s), chunk (s), crossfade (ms)
    style = (120,) + ("none",) * 5        # BPM, then the five style dropdowns
    return sampling + timing + style
# ----------------------------------------------------------------------
# Gradio UI
# ----------------------------------------------------------------------
# Page styling passed to gr.Blocks(css=...): a dark gradient background and
# the banner rendered by the gr.HTML header (class "header").
# NOTE(review): nothing in this file loads the 'Orbitron' web font, so
# browsers fall back to sans-serif unless it is installed locally.
css = """
body {
    background: linear-gradient(135deg, #0A0A0A 0%, #1C2526 100%);
    color: #E0E0E0; font-family: 'Orbitron', sans-serif;
}
.header {
    text-align: center; padding: 10px; background: rgba(0,0,0,0.9);
    border-bottom: 1px solid #00FF9F;
}
"""
with gr.Blocks(css=css) as demo:
    gr.HTML('<div class="header"><h1>👻 GhostAI Music Generator</h1></div>')
    prompt_box = gr.Textbox(label="Instrumental Prompt ✍️", lines=4)

    # Genre preset buttons. Their handlers are attached further down, once
    # the style controls exist, so each preset uses the user's current tempo
    # and style selections instead of hard-coded defaults.
    with gr.Row():
        rhcp_btn = gr.Button("RHCP 🌶️")
        nirvana_btn = gr.Button("Nirvana 🎸")
        # … add the other genre buttons here and wire them up below …

    with gr.Group():
        cfg_scale = gr.Slider(1.0, 10.0, value=3.0, step=0.1, label="CFG Scale")
        top_k = gr.Slider(10, 500, value=250, step=10, label="Top-K")
        top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-P")
        temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
        total_len = gr.Radio([30, 60, 90, 120], value=30, label="Length (s)")
        chunk_len = gr.Slider(5, 15, value=10, step=1, label="Chunk (s)")
        crossfade = gr.Slider(100, 2000, value=1000, step=100, label="Crossfade (ms)")
        bpm = gr.Slider(60, 180, value=120, label="Tempo (BPM)")
        drum_beat = gr.Dropdown(
            ["none", "standard rock", "funk groove", "techno kick", "jazz swing"],
            value="none", label="Drum Beat"
        )
        synthesizer = gr.Dropdown(
            ["none", "analog synth", "digital pad", "arpeggiated synth"],
            value="none", label="Synthesizer"
        )
        steps = gr.Dropdown(
            ["none", "syncopated steps", "steady steps", "complex steps"],
            value="none", label="Rhythmic Steps"
        )
        bass_style = gr.Dropdown(
            ["none", "slap bass", "deep bass", "melodic bass"],
            value="none", label="Bass Style"
        )
        guitar_style = gr.Dropdown(
            ["none", "distorted", "clean", "jangle"],
            value="none", label="Guitar Style"
        )

    gen_btn = gr.Button("Generate Music 🚀")
    clr_btn = gr.Button("Clear 🧹")
    out_audio = gr.Audio(label="Generated Track", type="filepath")
    status = gr.Textbox(label="Status", interactive=False)

    # Order matches the (bpm, drum, synth, steps, bass, gtr) parameters of
    # the set_*_prompt presets and the tail of generate_music's signature.
    style_controls = [bpm, drum_beat, synthesizer, steps, bass_style, guitar_style]
    # Order matches generate_music's parameters and clear_inputs' return tuple.
    all_controls = [
        prompt_box, cfg_scale, top_k, top_p, temperature,
        total_len, chunk_len, crossfade,
        *style_controls,
    ]

    rhcp_btn.click(set_red_hot_chili_peppers_prompt, inputs=style_controls, outputs=prompt_box)
    nirvana_btn.click(set_nirvana_grunge_prompt, inputs=style_controls, outputs=prompt_box)
    gen_btn.click(generate_music, inputs=all_controls, outputs=[out_audio, status])
    clr_btn.click(clear_inputs, None, all_controls)
if __name__ == "__main__":
    # When run as a script, launch() blocks until the server shuts down.
    # Anything placed after it cannot configure the running app. FastAPI's
    # interactive docs and OpenAPI schema are therefore disabled up front,
    # through the app's constructor arguments.
    demo.launch(
        share=False,
        show_error=True,
        app_kwargs={"docs_url": None, "redoc_url": None, "openapi_url": None},
    )