ahk-d committed on
Commit 7383a83 · verified · 1 Parent(s): 76f1409

Update app.py

Files changed (1):
  app.py +27 -5
app.py CHANGED
@@ -19,8 +19,28 @@ RAVE_MODELS = {
 
 MODEL_CACHE = {}
 
+import gradio as gr
+import torchaudio
+import torch
+import numpy as np
+from huggingface_hub import hf_hub_download
+
+# ✅ Available RAVE models
+RAVE_MODELS = {
+    "Guitar": "guitar_iil_b2048_r48000_z16.ts",
+    "Soprano Sax": "sax_soprano_franziskaschroeder_b2048_r48000_z20.ts",
+    "Organ (Archive)": "organ_archive_b2048_r48000_z16.ts",
+    "Organ (Bach)": "organ_bach_b2048_r48000_z16.ts",
+    "Voice Multivoice": "voice-multi-b2048-r48000-z11.ts",
+    "Birds Dawn Chorus": "birds_dawnchorus_b2048_r48000_z8.ts",
+    "Magnets": "magnets_b2048_r48000_z8.ts",
+    "Whale Songs": "humpbacks_pondbrain_b2048_r48000_z20.ts"
+}
+
+MODEL_CACHE = {}
+
 def load_rave_model(model_name):
-    """Load a RAVE model from Hugging Face or cache."""
+    """Load a TorchScript RAVE model directly from Hugging Face."""
     if model_name in MODEL_CACHE:
         return MODEL_CACHE[model_name]
 
@@ -29,7 +49,7 @@ def load_rave_model(model_name):
         filename=RAVE_MODELS[model_name]
     )
 
-    model = RAVE.load(model_file)  # RAVE.load assumes wrapper for loading .ts file
+    model = torch.jit.load(model_file, map_location="cpu")
     model.eval()
     MODEL_CACHE[model_name] = model
     return model
@@ -42,17 +62,19 @@ def apply_rave(audio, model_name):
     audio_tensor = torch.tensor(audio[0]).unsqueeze(0)  # [1, samples]
     sr = audio[1]
 
+    # ✅ resample if needed
     if sr != 48000:
         audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 48000)
         sr = 48000
 
-    # Pass through model (encode -> decode)
     with torch.no_grad():
+        # ✅ pass audio through RAVE TorchScript (encode/decode)
+        # TorchScript models are usually structured like: model.encode(x) / model.decode(z)
         z = model.encode(audio_tensor)
         processed_audio = model.decode(z)
 
-    processed_audio = processed_audio.squeeze().cpu().numpy()
-    return (processed_audio, sr)
+    return (processed_audio.squeeze().cpu().numpy(), sr)
+
 
 # 🎛 Gradio Interface
 with gr.Blocks() as demo:
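
For reference, below is a minimal standalone sketch of the updated loading and inference path outside the Gradio UI. The repo_id passed to hf_hub_download is not visible in this diff, so the value here is a placeholder; the run_rave helper, file paths, and mono mixdown are illustrative additions. Only torch.jit.load, the 48 kHz resampling, and the encode/decode calls come from the commit itself.

# Sketch only: exercises the same hf_hub_download + torch.jit.load path as the commit.
import torch
import torchaudio
from huggingface_hub import hf_hub_download

REPO_ID = "your-username/rave-models"        # placeholder; the real repo_id is not shown in the diff
FILENAME = "guitar_iil_b2048_r48000_z16.ts"  # one of the RAVE_MODELS entries

def run_rave(input_path: str, output_path: str) -> None:
    # Download the TorchScript checkpoint and load it on CPU, as in the commit.
    model_file = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
    model = torch.jit.load(model_file, map_location="cpu")
    model.eval()

    # Load audio and resample to the model's 48 kHz rate if needed.
    waveform, sr = torchaudio.load(input_path)  # [channels, samples]
    if sr != 48000:
        waveform = torchaudio.functional.resample(waveform, sr, 48000)
        sr = 48000

    # Mix down to mono, giving a [1, samples] tensor as in apply_rave; note that
    # some RAVE exports expect a 3-D [batch, channels, samples] input instead.
    audio_tensor = waveform.mean(dim=0, keepdim=True)

    with torch.no_grad():
        z = model.encode(audio_tensor)   # latent representation
        processed = model.decode(z)      # resynthesized audio

    # Flatten to a single mono channel for saving, whatever shape decode returned.
    torchaudio.save(output_path, processed.reshape(1, -1).cpu(), sr)

if __name__ == "__main__":
    run_rave("input.wav", "rave_output.wav")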