Wataru committed
Commit 6f5f35c · Parent: 7cfbe46

Add app.py and dependencies

Files changed (2)
  1. app.py +95 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,95 @@
+ import argparse
+ import sys
+ from pathlib import Path
+
+ import gradio as gr
+ import torch
+ import torchaudio
+
+ # Add src to path to import sfi_utmos
+ project_root = Path(__file__).resolve().parent
+ sys.path.insert(0, str(project_root / "src"))
+
+ from sfi_utmos.model.ssl_mos import SSLMOSLightningModule
+
+ # Global variable for the model
+ model: SSLMOSLightningModule | None = None
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def load_model(checkpoint_path: str):
+     """Loads the model from the given checkpoint path."""
+     global model
+     model = SSLMOSLightningModule.load_from_checkpoint(
+         checkpoint_path, map_location=device
+     )
+     model.eval()
+     print(f"Model loaded from {checkpoint_path}")
+
+
+ def predict_mos(audio_path: str):
+     """Predicts the MOS score for the given audio file."""
+     if model is None:
+         return "Error: Model not loaded. Please provide a valid checkpoint path."
+     ratings = []
+     for listener in range(1, 11):  # average the prediction over listener IDs 1-10
+         wav, sr = torchaudio.load(audio_path)
+         if sr not in model.sr2id.keys():
+             return f"Error: Sample rate {sr} not supported by the model. Supported rates: {list(model.sr2id.keys())}"
+         waves = [wav.view(-1).to(model.device)]
+         srs = torch.tensor(sr).view(1, -1).to(model.device)
+         if model.condition_sr:
+             srs = torch.stack(
+                 [torch.tensor(model.sr2id[sr.detach().cpu().item()]) for sr in srs]
+             ).to(model.device)
+         listener_tensor = torch.tensor(listener).view(-1).to(model.device)
+         if hasattr(model, "is_sfi") and model.is_sfi:  # SFI path: no resampling needed
+             model.ssl_model.set_sample_rate(srs[0].item())
+             waves = torch.nn.utils.rnn.pad_sequence(
+                 [w.view(-1) for w in waves], batch_first=True
+             ).to(device)
+         else:
+             waves = [torchaudio.functional.resample(w, sr, 16_000) for w in waves]  # non-SFI path: resample to 16 kHz
+         output = model.forward(
+             waves,
+             listener_tensor,
+             srs,
+         )
+         ratings.append(output.cpu().item())
+     mos_score = sum(ratings) / len(ratings)
+
+     return f"{mos_score:.3f}"
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Run MOS prediction demo with Gradio.")
+     parser.add_argument(
+         "--checkpoint_path",
+         type=str,
+         default="model.ckpt",
+         help="Path to the model checkpoint (.ckpt file).",
+     )
+     args = parser.parse_args()
+
+     load_model(args.checkpoint_path)
+
+     if model is None:
+         print("Failed to load model. Exiting.")
+         sys.exit(1)
+
+     # Gradio interface
+     iface = gr.Interface(
+         fn=predict_mos,
+         inputs=gr.Audio(type="filepath", label="Upload Audio File"),
+         outputs="text",
+         title="SFI-UTMOS: MOS Prediction Demo",
+         description=(
+             "Upload an audio file (WAV, MP3, etc.) to get its predicted Mean Opinion Score (MOS)."
+         ),
+     )
+     iface.launch()
+
+
+ if __name__ == "__main__":
+     main()
+
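
A minimal usage sketch: the functions above can also be called directly, bypassing the Gradio UI. This assumes the sfi_utmos package from requirements.txt is installed and that the script is run from the repository root; "model.ckpt" and "sample.wav" are placeholder paths, not files shipped with this commit.

import app

app.load_model("model.ckpt")          # loads the checkpoint onto CPU or GPU
print(app.predict_mos("sample.wav"))  # returns the averaged MOS as a string with three decimals
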
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ -e git+https://github.com/sarulab-speech/msr-utmos.git
+