# quazim's picture
# added elastic model
# 43f4544
# raw
# history blame
# 6.01 kB
import functools
import os
import subprocess
import sys

import gradio as gr
import numpy as np
import scipy.io.wavfile
import torch
from transformers import AutoProcessor

from elastic_models.transformers import MusicgenForConditionalGeneration
def setup_flash_attention():
    """Best-effort, one-time installation of flash-attn into this interpreter.

    Returns immediately if ``flash_attn`` is already importable, or if the
    marker file ``/tmp/flash_attn_installed`` exists. Otherwise it
    pip-installs ``flash-attn==2.7.3`` with ``--no-build-isolation``,
    uninstalls ``apex`` (ignoring failure), writes the marker file, and
    prints a completion message.

    A failing flash-attn install is reported with a printed warning instead
    of being raised, so the app can continue without flash-attn.
    """
    try:
        import flash_attn  # noqa: F401 -- imported only to probe availability
    except ImportError:
        have_flash_attn = False
    else:
        have_flash_attn = True

    if have_flash_attn:
        print("flash-attn already installed")
        return

    marker_path = "/tmp/flash_attn_installed"
    # The marker is written only after a successful install, so an earlier
    # failed attempt does not prevent a retry.
    if os.path.exists(marker_path):
        return

    pip_cmd = [sys.executable, "-m", "pip"]
    print("Installing flash-attn with --no-build-isolation...")
    try:
        subprocess.run(
            pip_cmd + ["install", "flash-attn==2.7.3", "--no-build-isolation"],
            check=True,
        )
    except subprocess.CalledProcessError as err:
        # Non-fatal: the model might still work without flash-attn.
        print(f"Warning: Failed to install flash-attn: {err}")
        return

    # check=False: apex may simply not be installed, which is fine.
    subprocess.run(pip_cmd + ["uninstall", "apex", "-y"], check=False)

    with open(marker_path, "w") as marker:
        marker.write("installed")
    print("flash-attn installation completed")
# Run setup once when the module is imported
# setup_flash_attention()
# Load model and processor
# @gr.cache()
def load_model():
"""Load the musicgen model and processor"""
processor = AutoProcessor.from_pretrained("facebook/musicgen-large")
model = MusicgenForConditionalGeneration.from_pretrained(
"facebook/musicgen-large",
torch_dtype=torch.float16,
device=self.device,
mode="S",
__paged=True,
)
return processor, model
# Approximate number of generated tokens per second of audio. MusicGen's
# audio codec produces roughly 50 frames per second.
_TOKENS_PER_SECOND = 50


def _peak_normalize(audio):
    """Return *audio* as float32, scaled so its largest |sample| is 1.0.

    Empty, silent (all-zero) or non-finite-peak input is returned without
    scaling. Dividing by a zero peak would fill the clip with NaNs.
    """
    audio = np.asarray(audio, dtype=np.float32)
    if audio.size == 0:
        return audio
    peak = float(np.max(np.abs(audio)))
    if peak > 0.0 and np.isfinite(peak):
        audio = audio / peak
    return audio


def _sampling_kwargs(duration, temperature, top_k, top_p):
    """Translate UI slider values into ``model.generate`` keyword arguments.

    Slider values may arrive as floats, but ``max_new_tokens`` and ``top_k``
    must be ints for transformers, so they are converted here.
    """
    top_p = 0.0 if top_p is None else float(top_p)
    return {
        "max_new_tokens": int(round(float(duration) * _TOKENS_PER_SECOND)),
        "do_sample": True,
        "temperature": float(temperature),
        "top_k": int(top_k),
        # In transformers, top_p=0.0 keeps only the single most likely token,
        # which is effectively greedy decoding and defeats temperature/top_k.
        # The UI defaults to 0, meaning "disabled" (audiocraft convention).
        # top_p=1.0 is the value transformers treats as "no nucleus filtering".
        "top_p": top_p if 0.0 < top_p <= 1.0 else 1.0,
    }


def _sampling_rate(config):
    """Return the output sampling rate (Hz) for a MusicGen model config.

    MusicGen configs expose the rate on their audio encoder
    (``config.audio_encoder.sampling_rate``). The top-level attributes are
    checked only as fallbacks.
    """
    audio_encoder = getattr(config, "audio_encoder", None)
    rate = getattr(audio_encoder, "sampling_rate", None)
    if rate is None:
        rate = getattr(config, "sampling_rate", None) or config.sample_rate
    return int(rate)


def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0.0):
    """Generate music from a text prompt.

    Args:
        text_prompt: Description of the desired music. Must be non-blank.
        duration: Target length in seconds, converted to tokens at
            ``_TOKENS_PER_SECOND``.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus-sampling threshold. Values <= 0 disable it.

    Returns:
        tuple: ``(sample_rate, audio)``, where ``audio`` is the first channel
        of the first generated clip as a peak-normalized float32 array. This
        is the format ``gr.Audio(type="numpy")`` expects.

    Raises:
        gr.Error: On a blank prompt or any failure during generation, so
            Gradio shows the message to the user. Returning an error string
            in place of audio would break the Audio component.
    """
    if not text_prompt or not text_prompt.strip():
        raise gr.Error("Please enter a music description.")

    try:
        processor, model = load_model()

        inputs = processor(
            text=[text_prompt],
            padding=True,
            return_tensors="pt",
        )
        # Move the tokenized prompt to the device that holds the weights.
        device = getattr(model, "device", None)
        if device is not None:
            inputs = inputs.to(device)

        with torch.no_grad():
            audio_values = model.generate(
                **inputs,
                **_sampling_kwargs(duration, temperature, top_k, top_p),
                cache_implementation="paged",
            )

        # Index [0, 0] takes the first batch item and first channel. .float()
        # upcasts fp16 output before handing it to NumPy and Gradio.
        audio_data = _peak_normalize(audio_values[0, 0].float().cpu().numpy())
        sample_rate = _sampling_rate(model.config)
        return sample_rate, audio_data
    except Exception as e:
        raise gr.Error(f"Error generating music: {str(e)}") from e
# Create Gradio interface.
# Layout: a row with two columns. The left column holds the prompt and
# sampling controls; the right column holds the generated audio and usage
# tips. Example prompts sit below the row.
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎡 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model.")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Music Description",
                placeholder="Enter a description of the music you want to generate (e.g., 'upbeat jazz with piano and drums')",
                lines=3,
            )
            with gr.Row():
                duration = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Temperature (creativity)",
                )
            with gr.Row():
                top_k = gr.Slider(
                    minimum=1,
                    maximum=500,
                    value=250,
                    step=1,
                    label="Top-k",
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1,
                    label="Top-p",
                )
            generate_btn = gr.Button("🎡 Generate Music", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Music",
                type="numpy",
            )
            gr.Markdown("### Tips:")
            gr.Markdown("""
- Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
- Higher temperature = more creative/random results
- Lower temperature = more predictable results
- Duration is limited to 30 seconds for faster generation
""")

    # Example prompts; selecting one fills the description box. The former
    # "test example" debug entry was removed from this user-facing list.
    gr.Examples(
        examples=[
            ["upbeat jazz with piano and drums"],
            ["relaxing acoustic guitar melody"],
            ["electronic dance music with heavy bass"],
            ["classical violin concerto"],
            ["reggae with steel drums and bass"],
            ["rock ballad with electric guitar solo"],
        ],
        inputs=text_input,
        label="Example Prompts",
    )

    # Connect the generate button to the function. The input order must
    # match generate_music's positional parameters.
    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, temperature, top_k, top_p],
        outputs=audio_output,
    )

if __name__ == "__main__":
    demo.launch()