RAVE-n / app.py
ahk-d's picture
Update app.py
b0f2644 verified
import os
import re

import gradio as gr
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
# HF Spaces doesn't need this, but keeps local compatibility
# os.environ["GRADIO_TEMP_DIR"] = "/tmp/gradio_cache"
# ✅ Updated list: only confirmed existing models
# Registry of selectable models: display label -> (Hugging Face repo id, file name).
# Entries are grouped by the repository that hosts them; the flattened dict keeps
# the listing order below.
RAVE_MODELS = {
    label: (repo_id, file_name)
    for repo_id, entries in (
        (
            # Models from Intelligent-Instruments-Lab/rave-models
            "Intelligent-Instruments-Lab/rave-models",
            (
                ("Electric Guitar (IIL)", "guitar_iil_b2048_r48000_z16.ts"),
                ("Soprano Sax (IIL)", "sax_soprano_franziskaschroeder_b2048_r48000_z20.ts"),
                ("Organ (Archive IIL)", "organ_archive_b2048_r48000_z16.ts"),
                ("Organ (Bach IIL)", "organ_bach_b2048_r48000_z16.ts"),
                ("Magnetic Resonator Piano (IIL)", "mrp_strengjavera_b2048_r44100_z16.ts"),
                ("Multi-Voice (IIL)", "voice-multi-b2048-r48000-z11.ts"),
                ("Birds (Dawn Chorus IIL)", "birds_dawnchorus_b2048_r48000_z8.ts"),
                ("Water (Pond Brain IIL)", "water_pondbrain_b2048_r48000_z16.ts"),
                ("Marine Mammals (IIL)", "marinemammals_pondbrain_b2048_r48000_z20.ts"),
            ),
        ),
        (
            # Models from shuoyang-zheng/jaspers-rave-models
            "shuoyang-zheng/jaspers-rave-models",
            (
                ("Guitar Picking (Jasper Causal)", "guitar_picking_dm_b2048_r44100_z8_causal.ts"),
                ("Singing Voice (Jasper Non-Causal)", "gtsinger_b2048_r44100_z16_noncausal.ts"),
                ("Drums (Jasper AAM)", "aam_drum_b2048_r44100_z16_noncausal.ts"),
                ("Bass (Jasper AAM)", "aam_bass_b2048_r44100_z16_noncausal.ts"),
                ("Strings (Jasper AAM)", "aam_string_b2048_r44100_z16_noncausal.ts"),
                ("Speech (Jasper Causal)", "librispeech100_b2048_r44100_z8_causal.ts"),
                ("Brass/Sax (Jasper AAM)", "aam_brass_sax_b2048_r44100_z8_noncausal.ts"),
            ),
        ),
        (
            # Model from lancelotblanchard/rave_percussion
            "lancelotblanchard/rave_percussion",
            (
                ("Percussion (Lancelot)", "percussion.ts"),
            ),
        ),
    )
    for label, file_name in entries
}
# Cache of already-loaded models, keyed by the same display names used in
# RAVE_MODELS, so a model is downloaded and deserialized at most once.
MODEL_CACHE = {}
print("🎛 RAVE Style Transfer - Starting up...")
def load_rave_model(model_key):
    """Return the TorchScript model registered under ``model_key``.

    The model file is fetched with ``hf_hub_download``, loaded on the CPU with
    ``torch.jit.load``, switched to eval mode, and stored in ``MODEL_CACHE`` so
    later calls with the same key return the cached instance.

    Args:
        model_key: A key of ``RAVE_MODELS`` (the label shown in the UI).

    Returns:
        The loaded (or cached) model object.

    Raises:
        KeyError: If ``model_key`` is not a key of ``RAVE_MODELS``.
        Exception: Any download or deserialization error is logged and
            re-raised unchanged.
    """
    cached = MODEL_CACHE.get(model_key)
    if cached is not None:
        return cached

    print(f"📥 Loading model: {model_key}...")
    try:
        if model_key not in RAVE_MODELS:
            # A bare KeyError would stringify to just the key (e.g. "None"),
            # which is useless in the status box; spell out what went wrong.
            raise KeyError(f"unknown RAVE model {model_key!r}")
        repo_id, model_file_name = RAVE_MODELS[model_key]
        model_file = hf_hub_download(repo_id=repo_id, filename=model_file_name)
        model = torch.jit.load(model_file, map_location="cpu")
        model.eval()
        MODEL_CACHE[model_key] = model
        print(f"✅ Loaded: {model_key}")
        return model
    except Exception as e:
        # Log for the server console, then let the caller decide how to report it.
        print(f"❌ Error loading {model_key}: {e}")
        raise
# Matches the stand-alone "r<rate>" token in model file names such as
# "..._b2048_r44100_z16.ts" or "...-b2048-r48000-z11.ts".
_SAMPLE_RATE_TOKEN = re.compile(r"(?<![A-Za-z0-9])r(\d+)(?![A-Za-z0-9])")

# Rate used when a file name carries no "r<rate>" token (e.g. "percussion.ts").
_DEFAULT_SAMPLE_RATE = 48000


def _model_sample_rate(model_name):
    """Return the sample rate encoded in the file name of ``model_name``."""
    _, file_name = RAVE_MODELS[model_name]
    match = _SAMPLE_RATE_TOKEN.search(file_name)
    return int(match.group(1)) if match else _DEFAULT_SAMPLE_RATE


def apply_rave(audio_path, model_name):
    """Run an audio file through the selected RAVE model.

    The audio is down-mixed to mono, resampled to the rate named in the
    model's file name (48 kHz when the name does not state one), encoded and
    decoded by the model, and returned in the form Gradio's numpy audio
    output accepts.

    Args:
        audio_path: Path of the input audio file; falsy when nothing was
            uploaded.
        model_name: A key of ``RAVE_MODELS``.

    Returns:
        A 2-tuple ``(audio, status)``. On success ``audio`` is
        ``(sample_rate, numpy_array)`` and ``status`` is a success message.
        On failure ``audio`` is ``None`` and ``status`` describes the error.
        The function never raises; every exception is reported via ``status``.
    """
    if not audio_path:
        return None, "❌ Please upload an audio file."

    try:
        if model_name not in RAVE_MODELS:
            raise ValueError(f"Unknown model: {model_name!r}")

        print(f"🎵 Processing audio: {os.path.basename(audio_path)} with {model_name}")

        # Load and preprocess audio
        waveform, sr = torchaudio.load(audio_path)
        print(f"📊 Original: {waveform.shape}, {sr}Hz")

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            print("🔄 Converting stereo to mono")
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to the rate the chosen model's file name advertises; several
        # registered models are 44.1 kHz exports, so a fixed 48 kHz is wrong
        # for them.
        target_sr = _model_sample_rate(model_name)
        if sr != target_sr:
            print(f"🔄 Resampling from {sr}Hz to {target_sr}Hz")
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)
            sr = target_sr

        # Add batch dimension
        waveform = waveform.unsqueeze(0)

        # Load model and process
        model = load_rave_model(model_name)
        print("🤖 Applying RAVE transformation...")
        with torch.no_grad():
            z = model.encode(waveform)
            processed = model.decode(z)

        # Drop the batch dimension and any other singleton dimensions
        arr = processed.squeeze().cpu().numpy()

        print("✅ Transformation complete!")
        return (sr, arr), "✅ Style transfer successful!"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        error_msg = f"❌ Error: {e}"
        print(error_msg)
        return None, error_msg
# --- Gradio UI ---
# Two-column layout: inputs (audio, model choice, run button) on the left,
# outputs (transformed audio, status text) on the right.
print("🚀 Creating Gradio interface...")
with gr.Blocks(theme=gr.themes.Soft(), title="RAVE Style Transfer") as demo:
    gr.Markdown("# 🎛 RAVE Style Transfer Stem Remixer")
    gr.Markdown("Transform your audio using AI-powered style transfer. Upload audio and choose an instrument style!")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                type="filepath",
                label="🎵 Upload Your Audio",
                sources=["upload", "microphone"],
            )
            model_selector = gr.Dropdown(
                choices=list(RAVE_MODELS.keys()),
                label="🎸 Select Instrument Style",
                value="Electric Guitar (IIL)",
                interactive=True,
            )
            process_btn = gr.Button("🔄 Apply RAVE Transform", variant="primary", size="lg")

        with gr.Column():
            output_audio = gr.Audio(
                type="numpy",
                label="🎧 Transformed Audio",
            )
            status_output = gr.Textbox(
                label="📊 Status",
                interactive=False,
                value="Ready to transform audio...",
            )

    # apply_rave returns (audio, status), matching the two outputs in order.
    process_btn.click(
        fn=apply_rave,
        inputs=[audio_input, model_selector],
        outputs=[output_audio, status_output],
    )

    gr.Markdown("---")
    gr.Markdown(
        "<p style='text-align: center; font-size: small;'>"
        "Powered by RAVE (Realtime Audio Variational autoEncoder) | "
        "Models from Intelligent Instruments Lab & Community"
        "</p>"
    )

print("🌐 Launching demo...")
if __name__ == "__main__":
    demo.launch()