ahk-d committed on
Commit
b0f2644
·
verified ·
1 Parent(s): 57c442c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -65
app.py CHANGED
@@ -2,76 +2,153 @@ import gradio as gr
2
  import torchaudio
3
  import torch
4
  import numpy as np
 
5
  from huggingface_hub import hf_hub_download
6
 
7
- # βœ… Map of model names to files on Hugging Face
 
 
 
8
  RAVE_MODELS = {
9
- "Guitar": "guitar_iil_b2048_r48000_z16.ts",
10
- "Soprano Sax": "sax_soprano_franziskaschroeder_b2048_r48000_z20.ts",
11
- "Organ (Archive)": "organ_archive_b2048_r48000_z16.ts",
12
- "Organ (Bach)": "organ_bach_b2048_r48000_z16.ts",
13
- "Voice Multivoice": "voice-multi-b2048-r48000-z11.ts",
14
- "Birds Dawn Chorus": "birds_dawnchorus_b2048_r48000_z8.ts",
15
- "Magnets": "magnets_b2048_r48000_z8.ts",
16
- "Whale Songs": "humpbacks_pondbrain_b2048_r48000_z20.ts"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  }
18
 
19
  MODEL_CACHE = {}
20
-
21
- def load_rave_model(model_name):
22
- """Load TorchScript RAVE model from Hugging Face Hub."""
23
- if model_name in MODEL_CACHE:
24
- return MODEL_CACHE[model_name]
25
-
26
- model_file = hf_hub_download(
27
- repo_id="Intelligent-Instruments-Lab/rave-models",
28
- filename=RAVE_MODELS[model_name]
29
- )
30
-
31
- model = torch.jit.load(model_file, map_location="cpu")
32
- model.eval()
33
- MODEL_CACHE[model_name] = model
34
- return model
35
-
36
- def apply_rave(audio, model_name):
37
- """Apply selected RAVE model to uploaded audio."""
38
- model = load_rave_model(model_name)
39
-
40
- # βœ… Unpack properly
41
- waveform, sr = audio # waveform: np.array [samples, channels]
42
-
43
- # βœ… Convert stereo -> mono if needed
44
- if waveform.ndim > 1:
45
- waveform = np.mean(waveform, axis=1)
46
-
47
- # βœ… Convert numpy to torch tensor
48
- audio_tensor = torch.tensor(waveform).unsqueeze(0) # shape: [1, samples]
49
-
50
- # βœ… Resample if needed
51
- if int(sr) != 48000:
52
- audio_tensor = torchaudio.functional.resample(audio_tensor, int(sr), 48000)
53
- sr = 48000
54
-
55
- with torch.no_grad():
56
- z = model.encode(audio_tensor)
57
- processed_audio = model.decode(z)
58
-
59
- return (processed_audio.squeeze().cpu().numpy(), sr)
60
-
61
-
62
- # πŸŽ› Gradio UI
63
- with gr.Blocks() as demo:
64
- gr.Markdown("## πŸŽ› RAVE Style Transfer on Stems")
65
- gr.Markdown("Upload audio, pick a RAVE model, and get a transformed version.")
66
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  with gr.Row():
68
- audio_input = gr.Audio(type="numpy", label="Upload Audio", sources=["upload", "microphone"])
69
- model_selector = gr.Dropdown(list(RAVE_MODELS.keys()), label="Select Style", value="Guitar")
70
-
71
- with gr.Row():
72
- output_audio = gr.Audio(type="numpy", label="Transformed Audio")
73
-
74
- process_btn = gr.Button("Apply Style Transfer")
75
- process_btn.click(fn=apply_rave, inputs=[audio_input, model_selector], outputs=output_audio)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- demo.launch()
 
 
 
2
  import torchaudio
3
  import torch
4
  import numpy as np
5
+ import os
6
  from huggingface_hub import hf_hub_download
7
 
8
+ # HF Spaces doesn't need this, but keeps local compatibility
9
+ # os.environ["GRADIO_TEMP_DIR"] = "/tmp/gradio_cache"
10
+
11
+ # ✅ Updated list: only confirmed existing models
12
# Hugging Face Hub repositories that host the TorchScript (.ts) model files.
# Kept as constants so each repo id is spelled exactly once.
_IIL_REPO = "Intelligent-Instruments-Lab/rave-models"
_JASPER_REPO = "shuoyang-zheng/jaspers-rave-models"
_LANCELOT_REPO = "lancelotblanchard/rave_percussion"

# Maps the label shown in the style dropdown to (repo_id, filename) on the Hub.
# NOTE(review): the `r48000` / `r44100` token in most filenames looks like the
# sample rate the model expects -- confirm against the model repositories.
RAVE_MODELS = {
    # Models from Intelligent-Instruments-Lab/rave-models
    "Electric Guitar (IIL)": (_IIL_REPO, "guitar_iil_b2048_r48000_z16.ts"),
    "Soprano Sax (IIL)": (_IIL_REPO, "sax_soprano_franziskaschroeder_b2048_r48000_z20.ts"),
    "Organ (Archive IIL)": (_IIL_REPO, "organ_archive_b2048_r48000_z16.ts"),
    "Organ (Bach IIL)": (_IIL_REPO, "organ_bach_b2048_r48000_z16.ts"),
    "Magnetic Resonator Piano (IIL)": (_IIL_REPO, "mrp_strengjavera_b2048_r44100_z16.ts"),
    "Multi-Voice (IIL)": (_IIL_REPO, "voice-multi-b2048-r48000-z11.ts"),
    "Birds (Dawn Chorus IIL)": (_IIL_REPO, "birds_dawnchorus_b2048_r48000_z8.ts"),
    "Water (Pond Brain IIL)": (_IIL_REPO, "water_pondbrain_b2048_r48000_z16.ts"),
    "Marine Mammals (IIL)": (_IIL_REPO, "marinemammals_pondbrain_b2048_r48000_z20.ts"),

    # Models from shuoyang-zheng/jaspers-rave-models
    "Guitar Picking (Jasper Causal)": (_JASPER_REPO, "guitar_picking_dm_b2048_r44100_z8_causal.ts"),
    "Singing Voice (Jasper Non-Causal)": (_JASPER_REPO, "gtsinger_b2048_r44100_z16_noncausal.ts"),
    "Drums (Jasper AAM)": (_JASPER_REPO, "aam_drum_b2048_r44100_z16_noncausal.ts"),
    "Bass (Jasper AAM)": (_JASPER_REPO, "aam_bass_b2048_r44100_z16_noncausal.ts"),
    "Strings (Jasper AAM)": (_JASPER_REPO, "aam_string_b2048_r44100_z16_noncausal.ts"),
    "Speech (Jasper Causal)": (_JASPER_REPO, "librispeech100_b2048_r44100_z8_causal.ts"),
    "Brass/Sax (Jasper AAM)": (_JASPER_REPO, "aam_brass_sax_b2048_r44100_z8_noncausal.ts"),

    # Model from lancelotblanchard/rave_percussion
    "Percussion (Lancelot)": (_LANCELOT_REPO, "percussion.ts"),
}

# Already-loaded models keyed by their RAVE_MODELS label, so each file is
# downloaded and deserialized at most once per process.
MODEL_CACHE = {}
print("πŸŽ› RAVE Style Transfer - Starting up...")
39
+
40
def load_rave_model(model_key):
    """Return the TorchScript model registered under ``model_key``.

    The first request for a key downloads the file named in ``RAVE_MODELS``
    from the Hugging Face Hub, loads it on the CPU, switches it to eval mode
    and stores it in ``MODEL_CACHE``; later requests return the cached object.

    Any exception raised while resolving, downloading or loading the model is
    printed and then re-raised unchanged.
    """
    try:
        return MODEL_CACHE[model_key]
    except KeyError:
        pass  # not loaded yet: fall through to the download path

    print(f"πŸ“₯ Loading model: {model_key}...")
    try:
        hub_repo, hub_file = RAVE_MODELS[model_key]
        local_path = hf_hub_download(repo_id=hub_repo, filename=hub_file)
        scripted = torch.jit.load(local_path, map_location="cpu")
        scripted.eval()
        MODEL_CACHE[model_key] = scripted
        print(f"βœ… Loaded: {model_key}")
        return scripted
    except Exception as exc:
        print(f"❌ Error loading {model_key}: {exc!s}")
        raise
55
+
56
def _model_sample_rate(model_name, default=48000):
    """Return the sample rate encoded in the model's filename, else ``default``.

    Filenames in ``RAVE_MODELS`` carry an ``r<digits>`` token (for example
    ``_r44100_`` or ``-r48000-``); names without one fall back to ``default``.
    """
    import re

    entry = RAVE_MODELS.get(model_name)
    if entry is None:
        return default
    # Entries are (repo_id, filename); a bare filename string is tolerated too.
    file_name = entry if isinstance(entry, str) else entry[-1]
    found = re.search(r"[_-]r(\d+)(?=[_.-])", file_name)
    return int(found.group(1)) if found else default


def apply_rave(audio_path, model_name):
    """Apply RAVE style transfer to an audio file.

    Args:
        audio_path: Path of the audio file handed over by Gradio; falsy when
            nothing was uploaded.
        model_name: Key into ``RAVE_MODELS`` selecting the model to apply.

    Returns:
        A 2-tuple ``(audio, status)`` matching the two Gradio outputs:
        ``((sample_rate, numpy_array), message)`` on success, or
        ``(None, message)`` when no file was given or processing failed.
        Exceptions are reported through the status message, never raised.
    """
    if not audio_path:
        return None, "❌ Please upload an audio file."

    try:
        print(f"🎡 Processing audio: {os.path.basename(audio_path)} with {model_name}")

        # Load and preprocess audio
        waveform, sr = torchaudio.load(audio_path)
        print(f"πŸ“Š Original: {waveform.shape}, {sr}Hz")

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            print("πŸ”„ Converting stereo to mono")
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to the rate named in the model's filename rather than a
        # fixed 48 kHz: several listed models are tagged r44100.
        target_sr = _model_sample_rate(model_name)
        if sr != target_sr:
            print(f"πŸ”„ Resampling from {sr}Hz to {target_sr}Hz")
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)
            sr = target_sr

        # Add batch dimension
        waveform = waveform.unsqueeze(0)

        # Load model and process
        model = load_rave_model(model_name)
        print("πŸ€– Applying RAVE transformation...")

        with torch.no_grad():
            z = model.encode(waveform)
            processed = model.decode(z)

        # Prepare output: drop the singleton batch/channel dimensions.
        arr = processed.squeeze().cpu().numpy()
        if arr.ndim == 2:
            # If the model emitted several channels the array is
            # (channels, samples); Gradio expects (samples, channels).
            arr = arr.T

        print("βœ… Transformation complete!")
        return (sr, arr), "βœ… Style transfer successful!"

    except Exception as exc:
        error_msg = f"❌ Error: {exc}"
        print(error_msg)
        return None, error_msg
104
+
105
+ # --- Gradio UI ---
106
print("πŸš€ Creating Gradio interface...")

# Two-column page: inputs and the trigger button on the left, results on the
# right, with a static footer underneath.
with gr.Blocks(theme=gr.themes.Soft(), title="RAVE Style Transfer") as demo:
    gr.Markdown("# πŸŽ› RAVE Style Transfer Stem Remixer")
    gr.Markdown("Transform your audio using AI-powered style transfer. Upload audio and choose an instrument style!")

    with gr.Row():
        # Left column: source audio, model choice and the action button.
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="🎡 Upload Your Audio", sources=["upload", "microphone"])
            model_selector = gr.Dropdown(
                choices=list(RAVE_MODELS.keys()),
                label="🎸 Select Instrument Style",
                value="Electric Guitar (IIL)",
                interactive=True,
            )
            process_btn = gr.Button("πŸ”„ Apply RAVE Transform", variant="primary", size="lg")

        # Right column: transformed audio plus a read-only status line.
        with gr.Column():
            output_audio = gr.Audio(type="numpy", label="🎧 Transformed Audio")
            status_output = gr.Textbox(label="πŸ“Š Status", interactive=False, value="Ready to transform audio...")

    # apply_rave returns (audio, status), matching the two outputs in order.
    process_btn.click(
        fn=apply_rave,
        inputs=[audio_input, model_selector],
        outputs=[output_audio, status_output],
    )

    gr.Markdown("---")
    gr.Markdown(
        "<p style='text-align: center; font-size: small;'>"
        "Powered by RAVE (Realtime Audio Variational autoEncoder) | "
        "Models from Intelligent Instruments Lab & Community"
        "</p>"
    )

# Printed on every import; the server itself only starts when run as a script.
print("🌐 Launching demo...")
if __name__ == "__main__":
    demo.launch()