#!/usr/bin/env python3
"""
GhostAI Music Generator – ZeroGPU Space
Streams facebook/musicgen-medium with dynamic GPU bursts.
"""
# 0️⃣ Import spaces *first* so CUDA isn't touched beforehand
import spaces  # HF ZeroGPU decorator

# 1️⃣ Standard libs
import os, sys, gc, time, warnings, random, tempfile
import numpy as np, psutil

# 2️⃣ Torch (CPU wheels; ZeroGPU migrates tensors when needed)
import torch, torchaudio

# 3️⃣ Other deps
import gradio as gr
from pydub import AudioSegment
from audiocraft.models import MusicGen
from huggingface_hub import login
from torch.cuda.amp import autocast
# ────────────────────────────────────────────────────────────────────
# Torch <2.3 shim (transformers may call torch.get_default_device)
if not hasattr(torch, "get_default_device"):
    torch.get_default_device = lambda: torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

warnings.filterwarnings("ignore")
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# 🔑 Authenticate so we can pull the model
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    sys.exit("ERROR: Add HF_TOKEN as a secret in your Space.")
login(HF_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"ZeroGPU detected – initial device is {device.upper()}")
# 📥 Download model from Hub
print("Loading facebook/musicgen-medium … (first run may take ~6 GB download)")
musicgen = MusicGen.get_pretrained("facebook/musicgen-medium")
musicgen.set_generation_params(duration=10, two_step_cfg=False)
SAMPLE_RATE = musicgen.sample_rate
# ╭───────────────────────────────────────────────────────────────╮
# │ Prompt helpers (kept exactly from your original script)       │
# ╰───────────────────────────────────────────────────────────────╯
def _p(base, bpm, dr, syn, st, bass, gtr, dflt_bass, dflt_gtr, flow):
    # "none" selections fall back to the genre defaults. Note the flow
    # phrase needs its own ", " separator; without it the prompt runs
    # straight into the guitar text.
    step = f" with {st}" if st != "none" else ", " + flow.format(bpm=bpm)
    dr   = f", {dr} drums" if dr != "none" else ""
    syn  = f", {syn} accents" if syn != "none" else ""
    bass = f", {bass}" if bass != "none" else dflt_bass
    gtr  = f", {gtr} guitar riffs" if gtr != "none" else dflt_gtr
    return f"{base}{bass}{gtr}{dr}{syn}{step} at {bpm} BPM."
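# For example, the RHCP helper below with bpm=120 and every option "none"
# now yields:
#   "Instrumental funk rock, groovy basslines, syncopated guitar riffs,
#    groovy rhythmic flow at 120 BPM."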
def set_red_hot_chili_peppers_prompt(bpm, dr, syn, st, bass, gtr):
    return _p("Instrumental funk rock", bpm, dr, syn, st, bass, gtr,
              ", groovy basslines", ", syncopated guitar riffs",
              "{bpm} BPM funky flow" if bpm > 120 else "groovy rhythmic flow")

def set_nirvana_grunge_prompt(bpm, dr, syn, st, bass, gtr):
    return _p("Instrumental grunge", bpm, dr, syn, st, bass, gtr,
              ", melodic basslines", ", raw distorted guitar riffs",
              "{bpm} BPM grungy pulse" if bpm > 120 else "grungy rhythmic pulse")
# … include your other genre functions unchanged …
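# A hypothetical extra genre helper, just to show the _p() pattern; the
# name and wording are illustrative, not from the original script.
def set_techno_club_prompt(bpm, dr, syn, st, bass, gtr):
    return _p("Instrumental techno", bpm, dr, syn, st, bass, gtr,
              ", deep rolling basslines", ", minimal guitar riffs",
              "{bpm} BPM four-on-the-floor drive" if bpm > 120
              else "steady four-on-the-floor groove")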
# Audio FX (pydub)
def apply_eq(s):
    # Band-pass: roll off above 8 kHz and below 80 Hz.
    return s.low_pass_filter(8000).high_pass_filter(80)

def apply_fade(s):
    # 1 s fade-in and fade-out.
    return s.fade_in(1000).fade_out(1000)
def log(stage=""):
    if stage:
        print(f"── {stage} ──")
    if torch.cuda.is_available():
        alloc = torch.cuda.memory_allocated() / 1024**3
        res   = torch.cuda.memory_reserved() / 1024**3
        print(f"GPU mem alloc {alloc:.2f} GB reserved {res:.2f} GB")
    print(f"CPU mem {psutil.virtual_memory().percent}% used")
# ╭───────────────────────────────────────────────────────────────╮
# │ Core generator – wrapped with @spaces.GPU                     │
# ╰───────────────────────────────────────────────────────────────╯
@spaces.GPU(duration=120)  # decorator was missing; 120 s is an assumption, tune to your render time
def generate_music(prompt, cfg, k, p, temp,
                   total_len, chunk_len, xfade,
                   bpm, dr, syn, step, bass, gtr):
    if not prompt.strip():
        return None, "⚠️ Prompt is empty."
    total_len = int(total_len)
    chunk_len = max(5, min(int(chunk_len), 15))  # clamp chunks to 5-15 s
    n_chunks  = max(1, total_len // chunk_len)
    chunk_len = total_len / n_chunks             # spread evenly across chunks
    overlap   = min(1.0, xfade / 1000.0)         # crossfade overlap, capped at 1 s
    render    = chunk_len + overlap              # render extra audio to crossfade away
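    # Worked example: total_len=30, chunk_len=10, xfade=1000 ms gives
    # n_chunks=3 and render=11.0 s per chunk; the three 11 s renders are
    # crossfaded together and trimmed back down to 30 s below.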
    pieces = []
    torch.manual_seed(42)  # fixed seeds for reproducible output
    np.random.seed(42)
    t0 = time.time()
    for i in range(n_chunks):
        log(f"before chunk {i+1}")
        musicgen.set_generation_params(
            duration=render, use_sampling=True,
            top_k=k, top_p=p, temperature=temp, cfg_coef=cfg
        )
        with torch.no_grad(), autocast():
            audio = musicgen.generate([prompt], progress=False)[0]
        audio = audio.cpu().float()
        # Coerce to stereo: duplicate mono, otherwise keep the first
        # channel and duplicate it.
        if audio.dim() == 1:
            audio = audio.repeat(2, 1)
        elif audio.shape[0] == 1:
            audio = audio.repeat(2, 1)
        elif audio.shape[0] != 2:
            audio = audio[:1].repeat(2, 1)
        # Round-trip through a temp WAV so pydub can read it.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            torchaudio.save(tmp.name, audio, SAMPLE_RATE)
            seg = AudioSegment.from_wav(tmp.name)
        os.unlink(tmp.name)
        pieces.append(seg)
        torch.cuda.empty_cache()
        gc.collect()
        log(f"after chunk {i+1}")
    # Stitch the chunks with crossfades, trim to the requested length,
    # then EQ, normalize, and fade.
    track = pieces[0]
    for seg in pieces[1:]:
        track = track.append(seg, crossfade=xfade)
    track = track[: total_len * 1000]
    # pydub's normalize() expects positive headroom in dB; the original
    # -9.0 would push the peak above full scale and clip.
    track = apply_fade(apply_eq(track).normalize(headroom=9.0))
    out_file = "output_cleaned.mp3"
    track.export(out_file, format="mp3", bitrate="128k",
                 tags={"title": "GhostAI Track", "artist": "GhostAI"})
    log("final")
    print(f"Total {time.time() - t0:.1f}s")
    return out_file, "✅ Done!"
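# Quick sanity check outside the UI (hypothetical argument values):
#   generate_music("Instrumental funk rock, groovy basslines at 120 BPM.",
#                  3.0, 250, 0.9, 1.0, 30, 10, 1000,
#                  120, "none", "none", "none", "none", "none")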
def clear_inputs():
    # Reset every control to its UI default (same order as the outputs list).
    return ("", 3.0, 250, 0.9, 1.0, 30, 10, 1000,
            120, "none", "none", "none", "none", "none")
# ╭───────────────────────────────────────────────────────────────╮
# │ Gradio Blocks UI with your CSS & controls                     │
# ╰───────────────────────────────────────────────────────────────╯
css = "body{background:#0A0A0A;color:#E0E0E0;font-family:'Orbitron',sans-serif}"
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1 style='text-align:center'>👻 GhostAI Music Generator</h1>")
    prompt = gr.Textbox(label="Prompt", lines=4)
    with gr.Row():
        # .click() inputs must be Gradio components, not raw strings;
        # bake the constant defaults into no-input lambdas instead.
        gr.Button("RHCP 🌶️").click(
            lambda: set_red_hot_chili_peppers_prompt(
                120, "none", "none", "none", "none", "none"),
            outputs=prompt)
        gr.Button("Nirvana 🎸").click(
            lambda: set_nirvana_grunge_prompt(
                120, "none", "none", "none", "none", "none"),
            outputs=prompt)
        # add more genre buttons here …
    cfg    = gr.Slider(1, 10, 3, label="CFG")
    top_k  = gr.Slider(10, 500, 250, step=10, label="Top-K")
    top_p  = gr.Slider(0, 1, 0.9, step=0.05, label="Top-P")
    temp   = gr.Slider(0.1, 2, 1, step=0.1, label="Temperature")
    length = gr.Radio([30, 60, 90, 120], value=30, label="Length (s)")
    chunk  = gr.Slider(5, 15, 10, step=1, label="Chunk (s)")
    xfade  = gr.Slider(100, 2000, 1000, step=100, label="Cross-fade (ms)")
    bpm    = gr.Slider(60, 180, 120, label="BPM")
    # Dropdown's value and label are keyword-only arguments.
    drum  = gr.Dropdown(["none", "standard rock", "funk groove", "techno kick", "jazz swing"],
                        value="none", label="Drum")
    synth = gr.Dropdown(["none", "analog synth", "digital pad", "arpeggiated synth"],
                        value="none", label="Synth")
    steps = gr.Dropdown(["none", "syncopated steps", "steady steps", "complex steps"],
                        value="none", label="Steps")
    bass  = gr.Dropdown(["none", "slap bass", "deep bass", "melodic bass"],
                        value="none", label="Bass")
    gtr   = gr.Dropdown(["none", "distorted", "clean", "jangle"],
                        value="none", label="Guitar")
    gen = gr.Button("Generate 🎶")
    clr = gr.Button("Clear 🧹")
    audio  = gr.Audio(type="filepath")
    status = gr.Textbox(interactive=False)

    gen.click(generate_music,
              inputs=[prompt, cfg, top_k, top_p, temp, length, chunk, xfade,
                      bpm, drum, synth, steps, bass, gtr],
              outputs=[audio, status])
    clr.click(clear_inputs, None,
              [prompt, cfg, top_k, top_p, temp, length, chunk, xfade,
               bpm, drum, synth, steps, bass, gtr])

demo.launch(share=False)
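# Assumed Space dependencies (a requirements.txt sketch, not from the
# original post; pin versions to match your audiocraft build):
#   spaces  torch  torchaudio  audiocraft  gradio  pydub  psutil
#   huggingface_hub  numpy
# pydub's mp3 export also needs ffmpeg available in the image.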