jhansss commited on
Commit
e17ee3c
·
2 Parent(s): 1ec2d7e 810614d

Merge remote-tracking branch 'public/main'

Browse files
Files changed (8) hide show
  1. character.png +3 -0
  2. client.py +58 -0
  3. data/touhou/note_data.json +29 -6
  4. path.sh +3 -0
  5. run_server.sh +14 -0
  6. server.py +28 -19
  7. svs_utils.py +6 -6
  8. test_performance.py +263 -0
character.png ADDED

Git LFS Details

  • SHA256: 38dc2981a0ac817f62d8a87a053285535821041e8ead37e77d871b9bc7b3a82d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.78 MB
client.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import uuid
3
+ import os
4
+ import requests
5
+ import base64
6
+ from server import (
7
+ on_click_metrics as server_metrics,
8
+ process_audio as server_process_audio
9
+ )
10
+
11
+ TTS_OUTPUT_DIR = "./tmp"
12
+ os.makedirs(TTS_OUTPUT_DIR, exist_ok=True)
13
+
14
+
15
+ def process_audio(audio_path):
16
+ # We have audio_path
17
+ result = server_process_audio(audio_path)
18
+
19
+ audio_data = base64.b64decode(result["audio"])
20
+ with open(f"{TTS_OUTPUT_DIR}/response.wav", "wb") as f:
21
+ f.write(audio_data)
22
+
23
+ with open(f"{TTS_OUTPUT_DIR}/asr.txt", "w") as f:
24
+ f.write(result['asr_text'])
25
+ with open(f"{TTS_OUTPUT_DIR}/llm.txt", "w") as f:
26
+ f.write(result['llm_text'])
27
+
28
+ return f"""
29
+ asr_text: {result['asr_text']}
30
+ llm_text: {result['llm_text']}
31
+ """, f"{TTS_OUTPUT_DIR}/response.wav"
32
+
33
+
34
+ def on_click_metrics():
35
+ res = server_metrics()
36
+ return res.content.decode('utf-8')
37
+
38
+
39
+ with gr.Blocks() as demo:
40
+ with gr.Row():
41
+ with gr.Column(scale=1):
42
+ gr.Image(value="character.png", show_label=False) # キャラ絵を表示
43
+ with gr.Column(scale=2):
44
+ mic = gr.Audio(sources=["microphone"], type="filepath", label="Mic")
45
+ text_output = gr.Textbox(label="transcription")
46
+ audio_output = gr.Audio(label="audio", autoplay=True)
47
+
48
+ mic.change(fn=process_audio, inputs=[mic], outputs=[text_output, audio_output])
49
+ with gr.Row():
50
+ metrics_button = gr.Button("compute metrics")
51
+ metrics_output = gr.Textbox(label="Metrics", lines=3)
52
+ metrics_button.click(fn=on_click_metrics, inputs=[], outputs=[metrics_output])
53
+
54
+ with gr.Row():
55
+ log = gr.Textbox(label="logs", lines=5)
56
+
57
+ demo.launch(share=True)
58
+ # demo.launch()
data/touhou/note_data.json CHANGED
@@ -1,14 +1,37 @@
1
  [
2
  {
3
- "tempo": 120,
4
- "score": [[0.0, 0.207, 54], [0.207, 0.414, 59], [0.414, 0.621, 61], [0.621, 1.264, 62], [1.264, 1.478, 64], [1.478, 2.121, 61], [2.121, 2.335, 62], [2.335, 3.192, 59], [3.406, 3.621, 54], [3.621, 3.835, 59], [3.835, 4.049, 61], [4.049, 4.692, 62], [4.692, 4.906, 64], [4.906, 5.549, 61], [5.549, 5.764, 69], [5.764, 6.621, 66], [6.835, 7.049, 66], [7.049, 7.264, 69], [7.264, 7.478, 71], [7.478, 7.906, 64], [7.906, 8.442, 64], [8.549, 8.764, 64], [8.764, 8.978, 66], [8.978, 9.192, 69], [9.192, 9.621, 62], [9.621, 10.156, 62], [10.264, 10.478, 59], [10.478, 10.692, 59], [10.692, 10.906, 61], [10.906, 11.549, 62], [11.549, 11.764, 64], [11.764, 12.406, 61], [12.406, 12.621, 58], [12.621, 13.478, 59], [13.478, 14.335, 61], [14.335, 15.369, 62], [15.369, 15.473, 61], [15.473, 15.576, 62], [15.576, 15.68, 61], [15.68, 15.783, 57], [15.783, 15.887, 54], [15.887, 15.99, 55], [15.99, 17.645, 56], [24.266, 25.3, 74], [25.3, 25.404, 73], [25.404, 25.507, 71], [25.507, 25.611, 69], [25.611, 25.714, 67], [25.714, 25.818, 66], [25.818, 25.921, 64], [25.921, 26.889, 68], [27.127, 27.357, 54], [27.357, 27.578, 59], [27.578, 27.791, 61], [27.791, 28.416, 62], [28.416, 28.624, 64], [28.624, 29.249, 61], [29.249, 29.458, 62], [29.458, 30.291, 59], [30.499, 30.708, 54], [30.708, 30.916, 59], [30.916, 31.124, 61], [31.124, 31.749, 62], [31.749, 31.958, 64], [31.958, 32.583, 61], [32.583, 32.791, 69], [32.791, 33.833, 66], [33.833, 34.041, 66], [34.041, 34.249, 69], [34.249, 34.458, 71], [34.458, 34.874, 64], [34.874, 35.395, 64], [35.499, 35.708, 64], [35.708, 35.916, 66], [35.916, 36.124, 69], [36.124, 36.541, 62], [36.541, 37.062, 62], [37.166, 37.374, 59], [37.374, 37.583, 59], [37.583, 37.791, 61], [37.791, 38.416, 62], [38.416, 38.624, 64], [38.624, 39.249, 61], [39.249, 39.458, 58], [39.458, 40.291, 59], [40.291, 40.499, 61], [40.499, 40.708, 54], [40.708, 40.916, 59], [40.916, 41.124, 61], [41.124, 41.749, 62], [41.749, 41.958, 64], [41.958, 42.583, 61], [42.583, 42.791, 62], [42.791, 43.624, 59], [43.833, 44.041, 54], [44.041, 44.249, 59], [44.249, 44.458, 61], [44.458, 45.083, 62], [45.083, 45.291, 64], [45.291, 45.916, 61]]
 
5
  },
6
  {
7
- "tempo": 120,
8
- "score": [[0.0, 0.207, 54], [0.207, 0.414, 59], [0.414, 0.621, 61], [0.621, 1.264, 62], [1.264, 1.478, 64], [1.478, 2.121, 61], [2.121, 2.335, 62], [2.335, 3.192, 59], [3.406, 3.621, 54], [3.621, 3.835, 59], [3.835, 4.049, 61], [4.049, 4.692, 62], [4.692, 4.906, 64], [4.906, 5.549, 61], [5.549, 5.764, 69], [5.764, 6.621, 66], [6.835, 7.049, 66], [7.049, 7.264, 69], [7.264, 7.478, 71], [7.478, 7.906, 64], [7.906, 8.442, 64], [8.549, 8.764, 64], [8.764, 8.978, 66], [8.978, 9.192, 69], [9.192, 9.621, 62], [9.621, 10.156, 62], [10.264, 10.478, 59], [10.478, 10.692, 59], [10.692, 10.906, 61], [10.906, 11.549, 62], [11.549, 11.764, 64], [11.764, 12.406, 61], [12.406, 12.621, 58], [12.621, 13.478, 59], [13.478, 14.335, 61], [14.335, 15.369, 62], [15.369, 15.473, 61], [15.473, 15.576, 62], [15.576, 15.68, 61], [15.68, 15.783, 57], [15.783, 15.887, 54], [15.887, 15.99, 55], [15.99, 17.645, 56], [24.266, 25.3, 74], [25.3, 25.404, 73], [25.404, 25.507, 71], [25.507, 25.611, 69], [25.611, 25.714, 67], [25.714, 25.818, 66], [25.818, 25.921, 64], [25.921, 26.889, 68], [27.127, 27.357, 54], [27.357, 27.578, 59], [27.578, 27.791, 61], [27.791, 28.416, 62], [28.416, 28.624, 64], [28.624, 29.249, 61], [29.249, 29.458, 62], [29.458, 30.291, 59], [30.499, 30.708, 54], [30.708, 30.916, 59], [30.916, 31.124, 61], [31.124, 31.749, 62], [31.749, 31.958, 64], [31.958, 32.583, 61], [32.583, 32.791, 69], [32.791, 33.833, 66], [33.833, 34.041, 66], [34.041, 34.249, 69], [34.249, 34.458, 71], [34.458, 34.874, 64], [34.874, 35.395, 64], [35.499, 35.708, 64], [35.708, 35.916, 66], [35.916, 36.124, 69], [36.124, 36.541, 62], [36.541, 37.062, 62], [37.166, 37.374, 59], [37.374, 37.583, 59], [37.583, 37.791, 61], [37.791, 38.416, 62], [38.416, 38.624, 64], [38.624, 39.249, 61], [39.249, 39.458, 58], [39.458, 40.291, 59], [40.291, 40.499, 61], [40.499, 40.708, 54], [40.708, 40.916, 59], [40.916, 41.124, 61], [41.124, 41.749, 62], [41.749, 41.958, 64], [41.958, 42.583, 61], [42.583, 42.791, 62], [42.791, 43.624, 59], [43.833, 44.041, 54], [44.041, 44.249, 59], [44.249, 44.458, 61], [44.458, 45.083, 62], [45.083, 45.291, 64], [45.291, 45.916, 61]]
 
9
  },
10
  {
11
- "tempo": 120,
12
- "score": [[0.0, 0.207, 54], [0.207, 0.414, 59], [0.414, 0.621, 61], [0.621, 1.264, 62], [1.264, 1.478, 64], [1.478, 2.121, 61], [2.121, 2.335, 62], [2.335, 3.192, 59], [3.406, 3.621, 54], [3.621, 3.835, 59], [3.835, 4.049, 61], [4.049, 4.692, 62], [4.692, 4.906, 64], [4.906, 5.549, 61], [5.549, 5.764, 69], [5.764, 6.621, 66], [6.835, 7.049, 66], [7.049, 7.264, 69], [7.264, 7.478, 71], [7.478, 7.906, 64], [7.906, 8.442, 64], [8.549, 8.764, 64], [8.764, 8.978, 66], [8.978, 9.192, 69], [9.192, 9.621, 62], [9.621, 10.156, 62], [10.264, 10.478, 59], [10.478, 10.692, 59], [10.692, 10.906, 61], [10.906, 11.549, 62], [11.549, 11.764, 64], [11.764, 12.406, 61], [12.406, 12.621, 58], [12.621, 13.478, 59], [13.478, 14.335, 61], [14.335, 15.369, 62], [15.369, 15.473, 61], [15.473, 15.576, 62], [15.576, 15.68, 61], [15.68, 15.783, 57], [15.783, 15.887, 54], [15.887, 15.99, 55], [15.99, 17.645, 56], [24.266, 25.3, 74], [25.3, 25.404, 73], [25.404, 25.507, 71], [25.507, 25.611, 69], [25.611, 25.714, 67], [25.714, 25.818, 66], [25.818, 25.921, 64], [25.921, 26.889, 68], [27.127, 27.357, 54], [27.357, 27.578, 59], [27.578, 27.791, 61], [27.791, 28.416, 62], [28.416, 28.624, 64], [28.624, 29.249, 61], [29.249, 29.458, 62], [29.458, 30.291, 59], [30.499, 30.708, 54], [30.708, 30.916, 59], [30.916, 31.124, 61], [31.124, 31.749, 62], [31.749, 31.958, 64], [31.958, 32.583, 61], [32.583, 32.791, 69], [32.791, 33.833, 66], [33.833, 34.041, 66], [34.041, 34.249, 69], [34.249, 34.458, 71], [34.458, 34.874, 64], [34.874, 35.395, 64], [35.499, 35.708, 64], [35.708, 35.916, 66], [35.916, 36.124, 69], [36.124, 36.541, 62], [36.541, 37.062, 62], [37.166, 37.374, 59], [37.374, 37.583, 59], [37.583, 37.791, 61], [37.791, 38.416, 62], [38.416, 38.624, 64], [38.624, 39.249, 61], [39.249, 39.458, 58], [39.458, 40.291, 59], [40.291, 40.499, 61], [40.499, 40.708, 54], [40.708, 40.916, 59], [40.916, 41.124, 61], [41.124, 41.749, 62], [41.749, 41.958, 64], [41.958, 42.583, 61], [42.583, 42.791, 62], [42.791, 43.624, 59], [43.833, 44.041, 54], [44.041, 44.249, 59], [44.249, 44.458, 61], [44.458, 45.083, 62], [45.083, 45.291, 64], [45.291, 45.916, 61]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  }
14
  ]
 
1
  [
2
  {
3
+ "tempo": 70,
4
+ "score": [[0.208, 0.412, 66], [0.417, 0.621, 69], [0.625, 1.246, 71], [1.25, 1.35, 69], [1.354, 1.454, 71], [1.458, 1.662, 69], [1.667, 1.871, 66], [1.875, 2.079, 64], [2.083, 2.287, 69], [2.292, 3.537, 66], [3.542, 3.746, 66], [3.75, 3.954, 69], [3.958, 4.579, 71], [4.583, 4.683, 69], [4.687, 4.787, 71], [4.792, 4.996, 74], [5.0, 5.204, 73], [5.208, 5.412, 71], [5.417, 5.621, 69], [5.625, 6.454, 71], [6.667, 6.766, 69], [6.771, 6.871, 71], [6.875, 7.079, 69], [7.083, 7.287, 66], [7.292, 8.121, 64], [8.333, 8.433, 69], [8.437, 8.537, 71], [8.542, 8.746, 69], [8.75, 8.954, 64], [8.958, 10.204, 62], [10.208, 10.412, 59], [10.417, 10.621, 61], [10.625, 11.246, 62], [11.25, 11.454, 64], [11.458, 12.079, 61], [12.083, 12.287, 59], [12.292, 13.954, 59], [14.792, 16.454, 64], [16.458, 17.287, 64], [17.292, 18.121, 62], [18.125, 19.371, 64], [19.375, 19.787, 67], [19.792, 20.412, 66], [20.417, 21.037, 64], [20.625, 21.037, 57], [21.042, 21.454, 64], [21.458, 22.287, 64], [22.292, 23.121, 69], [23.125, 23.537, 67], [23.542, 23.954, 66], [23.958, 24.371, 64], [24.375, 24.787, 66], [24.792, 25.412, 67], [25.417, 26.037, 69], [26.042, 26.454, 71], [26.458, 26.871, 71], [26.875, 27.287, 69], [27.292, 27.704, 67], [27.708, 28.121, 66], [28.125, 29.787, 64], [29.792, 30.621, 64], [30.625, 31.454, 62], [31.458, 32.704, 64], [32.708, 33.121, 67], [33.125, 33.537, 66], [33.75, 34.371, 64], [33.958, 34.787, 62], [34.792, 35.204, 64], [35.417, 35.829, 62], [35.833, 36.037, 64], [36.042, 36.454, 59], [36.458, 37.287, 57], [37.292, 37.704, 59], [37.708, 38.121, 62], [38.125, 41.454, 64], [178.125, 179.787, 64], [179.791, 180.62, 64], [180.625, 181.454, 62], [181.458, 182.704, 64], [182.708, 183.12, 67], [183.125, 183.745, 66], [183.75, 184.37, 64], [183.958, 184.37, 57], [184.375, 184.787, 64], [184.791, 185.62, 64], [185.625, 186.454, 69], [186.458, 186.87, 67], [186.875, 187.287, 66], [187.291, 187.704, 64], [187.708, 188.12, 66], [188.125, 188.745, 67], [188.75, 189.37, 69], [189.375, 189.787, 71], [189.791, 190.204, 71], [190.208, 190.62, 69], [190.625, 191.037, 67], [191.041, 191.454, 66], [191.458, 193.12, 64], [193.125, 193.954, 64], [193.958, 194.787, 62]],
5
+ "name": "Gensokyo"
6
  },
7
  {
8
+ "tempo": 70,
9
+ "score": [[0.197, 0.391, 64], [0.395, 0.489, 63], [0.493, 0.588, 64], [0.592, 0.785, 63], [0.789, 0.983, 66], [0.987, 1.18, 64], [1.184, 1.279, 63], [1.283, 1.377, 64], [1.382, 1.476, 63], [1.48, 1.575, 59], [1.579, 1.772, 61], [1.776, 1.97, 68], [1.974, 2.068, 66], [2.072, 2.167, 68], [2.171, 2.266, 66], [2.27, 2.364, 68], [2.368, 2.562, 71], [2.566, 2.759, 68], [2.763, 2.956, 66], [2.961, 3.154, 68], [3.158, 3.351, 56], [3.355, 3.549, 64], [3.553, 3.647, 63], [3.651, 3.746, 64], [3.75, 3.943, 63], [3.947, 4.141, 66], [4.145, 4.338, 64], [4.342, 4.437, 63], [4.441, 4.535, 64], [4.539, 4.634, 63], [4.638, 4.733, 59], [4.737, 4.93, 61], [4.934, 5.127, 68], [5.132, 5.226, 66], [5.23, 5.325, 68], [5.329, 5.424, 66], [5.428, 5.522, 68], [5.526, 5.72, 72], [5.724, 5.917, 68], [5.921, 6.114, 66], [6.118, 6.312, 68], [6.316, 6.509, 56], [6.513, 6.706, 64], [6.711, 6.805, 63], [6.809, 6.904, 64], [6.908, 7.101, 63], [7.105, 7.299, 66], [7.303, 7.496, 64], [7.5, 7.595, 63], [7.599, 7.693, 64], [7.697, 7.792, 63], [7.796, 7.891, 59], [7.895, 8.088, 61], [8.092, 8.285, 68], [8.289, 8.384, 66], [8.388, 8.483, 68], [8.487, 8.581, 66], [8.586, 8.68, 68], [8.684, 8.877, 71], [8.882, 9.075, 68], [9.079, 9.272, 66], [9.276, 9.47, 68], [9.474, 9.667, 56], [9.671, 9.864, 64], [9.868, 9.963, 63], [9.967, 10.062, 64], [10.066, 10.259, 63], [10.263, 10.456, 66], [10.461, 10.654, 64], [10.658, 10.752, 63], [10.757, 10.851, 64], [10.855, 10.95, 63], [10.954, 11.048, 59], [11.053, 11.246, 61], [11.25, 11.443, 68], [11.447, 11.542, 66], [11.546, 11.641, 68], [11.645, 11.739, 66], [11.743, 11.838, 68], [11.842, 12.134, 71], [12.138, 12.233, 68], [12.237, 12.627, 73], [12.632, 12.825, 57], [12.829, 13.022, 65], [13.026, 13.121, 64], [13.125, 13.22, 65], [13.224, 13.417, 64], [13.421, 13.614, 67], [13.618, 13.812, 65], [13.816, 13.91, 64], [13.914, 14.009, 65], [14.013, 14.108, 64], [14.112, 14.206, 60], [14.21, 14.404, 62], [14.408, 14.601, 69], [14.605, 14.7, 67], [14.704, 14.798, 69], [14.803, 14.897, 67], [14.901, 14.996, 69], [15.0, 15.193, 72]],
10
+ "name": "flowering_night"
11
  },
12
  {
13
+ "tempo": 70,
14
+ "score": [[0.00, 0.621, 56], [0.621, 0.819, 68], [0.828, 1.129, 66], [1.138, 1.44, 71], [1.448, 1.647, 70], [1.655, 1.81, 66], [1.862, 2.017, 63], [2.069, 2.267, 66], [2.276, 2.379, 66], [2.379, 3.276, 68], [3.31, 3.897, 56], [3.931, 4.129, 68], [4.138, 4.44, 67], [4.448, 4.75, 68], [4.759, 5.017, 70], [4.966, 6.603, 63], [6.621, 7.207, 56], [7.241, 7.44, 68], [7.448, 7.75, 66], [7.759, 8.06, 68], [8.069, 8.267, 70], [8.276, 8.431, 71], [8.483, 8.638, 63], [8.69, 8.888, 70], [8.897, 9.0, 70], [9.0, 9.509, 71], [9.517, 9.922, 73], [9.931, 10.129, 75], [10.138, 10.293, 68], [10.552, 10.655, 70], [10.655, 10.759, 71], [10.759, 11.06, 70], [11.069, 11.371, 66], [11.379, 11.578, 63], [11.586, 11.741, 66], [11.793, 11.948, 63], [12.0, 12.198, 66], [12.207, 12.31, 66], [12.31, 13.207, 68], [13.241, 13.828, 80], [13.862, 14.06, 80], [14.069, 14.371, 78], [14.379, 14.681, 83], [14.69, 14.888, 82], [14.897, 15.052, 78], [15.103, 15.259, 75], [15.31, 15.509, 78], [15.517, 15.621, 78], [15.621, 16.517, 80], [16.552, 17.138, 68], [17.172, 17.371, 80], [17.379, 17.681, 79], [17.69, 17.991, 80], [18.0, 18.259, 82], [18.207, 19.845, 75], [19.862, 20.448, 68], [20.483, 20.681, 80], [20.69, 20.991, 78], [21.0, 21.302, 80], [21.31, 21.509, 82], [21.517, 21.672, 83], [21.724, 21.879, 75], [21.931, 22.129, 82], [22.138, 22.241, 82], [22.241, 22.75, 83], [22.759, 23.164, 85], [23.172, 23.371, 87], [23.379, 23.534, 80], [23.793, 23.897, 82], [23.897, 24.0, 83], [24.0, 24.302, 82], [24.31, 24.612, 78], [24.621, 24.819, 75], [24.828, 24.983, 78], [25.034, 25.19, 75], [25.241, 25.44, 78], [25.448, 25.552, 78], [25.552, 26.448, 80], [26.483, 27.069, 68], [27.103, 27.302, 80], [27.31, 27.612, 78], [27.621, 27.922, 83], [27.931, 28.129, 82], [28.138, 28.293, 78], [28.345, 28.5, 75], [28.552, 28.75, 78], [28.759, 28.862, 78], [28.862, 29.759, 80], [29.793, 30.379, 68], [30.414, 30.612, 80], [30.621, 30.922, 79], [30.931, 31.233, 80], [31.241, 31.5, 82], [31.448, 33.086, 75], [33.103, 33.69, 68], [33.724, 33.922, 80], [33.931, 34.233, 78], [34.241, 34.543, 80], [34.552, 34.75, 82], [34.759, 34.914, 83], [34.966, 35.121, 75]],
15
+ "name": "xanhai_tea_time"
16
+ },
17
+ {
18
+ "tempo": 100,
19
+ "score": [[0.395, 0.983, 63], [0.987, 1.575, 68], [1.579, 1.97, 70], [1.974, 2.562, 63], [2.566, 3.154, 68], [3.158, 3.549, 70], [3.553, 3.943, 71], [3.947, 4.338, 68], [4.342, 4.733, 70], [4.737, 5.127, 66], [5.132, 6.706, 68], [6.711, 7.299, 63], [7.303, 7.891, 66], [7.895, 8.285, 68], [8.289, 8.877, 61], [8.882, 9.47, 66], [9.474, 9.864, 68], [9.868, 10.259, 59], [10.263, 10.654, 56], [10.658, 11.048, 58], [11.053, 11.443, 54], [11.447, 12.627, 56], [12.632, 12.825, 59], [12.829, 13.022, 61], [13.026, 13.614, 63], [13.618, 14.206, 68], [14.21, 14.601, 70], [14.605, 15.193, 63], [15.197, 15.785, 68], [15.789, 16.18, 70], [16.184, 16.575, 71], [16.579, 16.97, 73], [16.974, 17.364, 75], [17.368, 17.759, 78], [17.763, 19.338, 75], [19.342, 19.93, 75], [19.934, 20.522, 68], [20.526, 20.917, 75], [20.921, 21.509, 75], [21.513, 22.101, 68], [22.105, 22.496, 75], [22.5, 22.891, 68], [22.895, 23.285, 70], [23.289, 23.68, 75], [23.684, 24.075, 78], [24.079, 25.654, 80], [25.658, 26.246, 75], [26.25, 26.838, 80], [26.842, 27.233, 82], [27.237, 27.825, 75], [27.829, 28.417, 80], [28.421, 28.812, 82], [28.816, 29.206, 83], [29.21, 29.601, 80], [29.605, 29.996, 82], [30.0, 30.391, 78], [30.395, 31.97, 80], [31.974, 32.562, 75], [32.566, 33.154, 78], [33.158, 33.548, 80], [33.553, 34.141, 73], [34.145, 34.733, 78], [34.737, 35.127, 80], [35.132, 35.522, 71], [35.526, 35.917, 68], [35.921, 36.312, 70], [36.316, 36.706, 66], [36.71, 37.891, 68], [37.895, 38.088, 71], [38.092, 38.285, 73], [38.289, 38.877, 75], [38.881, 39.469, 80], [39.474, 39.864, 82], [39.868, 40.456, 75], [40.46, 41.048, 80], [41.053, 41.443, 82], [41.447, 41.838, 83], [41.842, 42.233, 85], [42.237, 42.627, 87], [42.631, 43.022, 90], [43.026, 44.601, 87], [44.605, 45.193, 87], [45.197, 45.785, 80], [45.789, 46.18, 87], [46.184, 46.772, 87], [46.776, 47.364, 80], [47.368, 47.759, 87], [47.763, 48.154, 80], [48.158, 48.548, 82], [48.553, 48.943, 87], [48.947, 49.338, 90], [49.342, 50.917, 92]],
20
+ "name": "sumizome_sakura"
21
+ },
22
+ {
23
+ "tempo": 70,
24
+ "score": [[0.02, 0.206, 60], [0.206, 0.392, 62], [0.392, 0.579, 63], [0.579, 0.765, 65], [0.765, 1.138, 67], [1.324, 1.51, 70], [1.51, 1.883, 67], [1.883, 2.256, 60], [2.256, 2.442, 67], [2.442, 2.628, 65], [2.628, 2.815, 63], [2.815, 3.001, 62], [3.001, 3.187, 60], [3.187, 3.374, 62], [3.374, 3.56, 63], [3.56, 3.746, 65], [3.746, 4.119, 67], [4.119, 4.305, 65], [4.305, 4.492, 63], [4.492, 4.678, 62], [4.678, 4.864, 55], [4.864, 5.051, 62], [5.051, 5.237, 63], [5.237, 5.423, 62], [5.423, 5.61, 60], [5.61, 5.796, 59], [5.796, 5.982, 62], [5.982, 6.169, 60], [6.169, 6.355, 62], [6.355, 6.541, 63], [6.541, 6.728, 65], [6.728, 7.1, 67], [7.1, 7.287, 65], [7.287, 7.473, 70], [7.473, 7.846, 72], [7.846, 8.218, 72], [8.218, 8.591, 74], [8.591, 8.964, 75], [8.964, 9.15, 72], [9.15, 9.336, 74], [9.336, 9.523, 75], [9.523, 9.709, 77], [9.709, 10.082, 79], [10.082, 10.268, 77], [10.268, 10.454, 75], [10.454, 10.827, 77], [10.827, 11.2, 74], [11.2, 11.572, 75], [11.572, 11.945, 77], [11.945, 12.69, 60], [12.69, 12.877, 60], [12.877, 13.063, 62], [13.063, 13.436, 63], [13.436, 13.995, 62], [13.995, 14.554, 55], [14.554, 14.926, 62], [14.926, 15.299, 62], [15.299, 15.485, 63], [15.485, 16.044, 60], [16.044, 16.417, 58], [16.417, 16.79, 58], [16.79, 16.976, 60], [16.976, 17.535, 55], [17.535, 17.908, 60], [17.908, 18.653, 60], [18.653, 18.839, 60], [18.839, 19.026, 62], [19.026, 19.398, 63], [19.398, 19.957, 62], [19.957, 20.516, 63], [20.516, 20.889, 65], [20.889, 21.448, 60], [21.448, 21.634, 67], [21.634, 23.125, 67], [23.125, 23.312, 65], [23.312, 23.498, 67], [23.498, 23.871, 70], [23.871, 24.616, 72], [24.616, 24.802, 72], [24.802, 24.989, 74], [24.989, 25.361, 75], [25.361, 25.92, 74], [25.92, 26.479, 67], [26.479, 26.852, 74], [26.852, 27.225, 74], [27.225, 27.411, 75], [27.411, 27.97, 72], [27.97, 28.343, 70], [28.343, 28.715, 70], [28.715, 28.902, 72], [28.902, 29.461, 67], [29.461, 29.833, 74], [29.833, 30.579, 72], [30.579, 30.765, 72], [30.765, 30.951, 74], [30.951, 31.138, 75], [31.138, 31.324, 77], [31.324, 31.883, 74], [31.883, 32.442, 75]],
25
+ "name": "bad_apple"
26
+ },
27
+ {
28
+ "tempo": 100,
29
+ "score": [[0.0, 0.207, 54], [0.207, 0.414, 59], [0.414, 0.621, 61], [0.621, 1.264, 62], [1.264, 1.478, 64], [1.478, 2.121, 61], [2.121, 2.335, 62], [2.335, 3.192, 59], [3.406, 3.621, 54], [3.621, 3.835, 59], [3.835, 4.049, 61], [4.049, 4.692, 62], [4.692, 4.906, 64], [4.906, 5.549, 61], [5.549, 5.764, 69], [5.764, 6.621, 66], [6.835, 7.049, 66], [7.049, 7.264, 69], [7.264, 7.478, 71], [7.478, 7.906, 64], [7.906, 8.442, 64], [8.549, 8.764, 64], [8.764, 8.978, 66], [8.978, 9.192, 69], [9.192, 9.621, 62], [9.621, 10.156, 62], [10.264, 10.478, 59], [10.478, 10.692, 59], [10.692, 10.906, 61], [10.906, 11.549, 62], [11.549, 11.764, 64], [11.764, 12.406, 61], [12.406, 12.621, 58], [12.621, 13.478, 59], [13.478, 14.335, 61], [14.335, 15.369, 62], [15.369, 15.473, 61], [15.473, 15.576, 62], [15.576, 15.68, 61], [15.68, 15.783, 57], [15.783, 15.887, 54], [15.887, 15.99, 55], [15.99, 17.645, 56], [24.266, 25.3, 74], [25.3, 25.404, 73], [25.404, 25.507, 71], [25.507, 25.611, 69], [25.611, 25.714, 67], [25.714, 25.818, 66], [25.818, 25.921, 64], [25.921, 26.889, 68], [27.127, 27.357, 54], [27.357, 27.578, 59], [27.578, 27.791, 61], [27.791, 28.416, 62], [28.416, 28.624, 64], [28.624, 29.249, 61], [29.249, 29.458, 62], [29.458, 30.291, 59], [30.499, 30.708, 54], [30.708, 30.916, 59], [30.916, 31.124, 61], [31.124, 31.749, 62], [31.749, 31.958, 64], [31.958, 32.583, 61], [32.583, 32.791, 69], [32.791, 33.833, 66], [33.833, 34.041, 66], [34.041, 34.249, 69], [34.249, 34.458, 71], [34.458, 34.874, 64], [34.874, 35.395, 64], [35.499, 35.708, 64], [35.708, 35.916, 66], [35.916, 36.124, 69], [36.124, 36.541, 62], [36.541, 37.062, 62], [37.166, 37.374, 59], [37.374, 37.583, 59], [37.583, 37.791, 61], [37.791, 38.416, 62], [38.416, 38.624, 64], [38.624, 39.249, 61], [39.249, 39.458, 58], [39.458, 40.291, 59], [40.291, 40.499, 61], [40.499, 40.708, 54], [40.708, 40.916, 59], [40.916, 41.124, 61], [41.124, 41.749, 62], [41.749, 41.958, 64], [41.958, 42.583, 61], [42.583, 42.791, 62], [42.791, 43.624, 59], [43.833, 44.041, 54], [44.041, 44.249, 59], [44.249, 44.458, 61], [44.458, 45.083, 62], [45.083, 45.291, 64], [45.291, 45.916, 61]],
30
+ "name": "septette"
31
+ },
32
+ {
33
+ "tempo": 100,
34
+ "score": [[0.0, 0.588, 52], [0.592, 0.785, 57], [0.789, 1.377, 59], [1.382, 1.575, 52], [1.579, 2.167, 60], [2.171, 2.364, 52], [2.368, 3.154, 62], [3.158, 4.733, 64], [4.934, 5.127, 64], [5.132, 5.325, 57], [5.329, 5.522, 64], [5.526, 5.72, 62], [5.724, 5.917, 59], [5.921, 6.312, 55], [6.316, 6.904, 57], [6.908, 7.101, 60], [7.105, 7.693, 59], [7.697, 7.891, 55], [7.895, 8.483, 52], [8.487, 8.877, 57], [8.882, 9.075, 55], [9.079, 9.47, 53], [9.474, 11.048, 52], [11.053, 11.641, 56], [11.645, 11.739, 47], [11.743, 11.838, 47], [11.842, 12.134, 47], [12.138, 12.43, 48], [12.434, 12.627, 50], [12.632, 13.22, 52], [13.224, 13.417, 57], [13.421, 14.009, 59], [14.013, 14.206, 52], [14.21, 14.798, 60], [14.803, 14.996, 52], [15.0, 15.785, 62], [15.789, 17.364, 64], [17.566, 17.759, 64], [17.763, 17.956, 57], [17.96, 18.154, 64], [18.158, 18.351, 62], [18.355, 18.548, 59], [18.553, 18.943, 55], [18.947, 19.535, 57], [19.539, 19.733, 60], [19.737, 20.325, 59], [20.329, 20.522, 55], [20.526, 21.114, 52], [21.118, 21.509, 57], [21.513, 21.706, 55], [21.71, 22.101, 53], [22.105, 23.68, 52], [23.684, 24.47, 56], [24.474, 24.568, 59], [24.572, 24.667, 60], [24.671, 24.766, 62], [24.77, 24.864, 63], [24.868, 24.963, 71], [24.967, 25.062, 72], [25.066, 25.16, 74], [25.164, 25.259, 75], [25.263, 25.851, 63], [25.855, 26.048, 55], [26.053, 26.838, 63], [26.842, 27.43, 63], [27.434, 27.627, 55], [27.632, 28.417, 63], [28.421, 29.009, 62], [29.013, 29.206, 55], [29.21, 29.996, 63], [30.0, 30.588, 63], [30.592, 30.785, 55], [30.789, 31.575, 63], [31.579, 32.167, 62], [32.171, 32.364, 55], [32.368, 33.154, 63], [33.158, 33.746, 63], [33.75, 33.943, 55], [33.947, 34.733, 63], [34.737, 35.325, 62], [35.329, 35.522, 55], [35.526, 36.312, 63], [36.316, 36.904, 63], [36.908, 37.101, 55], [37.105, 37.594, 63], [37.599, 37.693, 62], [37.697, 37.792, 60], [37.796, 37.891, 59], [37.895, 38.285, 60], [38.289, 38.68, 60], [38.684, 38.976, 59], [38.98, 39.272, 59], [39.276, 39.667, 60], [39.671, 39.864, 60], [39.868, 40.062, 60], [40.066, 40.259, 60], [40.263, 40.555, 59], [40.559, 40.851, 59], [40.855, 41.246, 60], [41.25, 41.443, 60]],
35
+ "name": "chireiden_nanika"
36
  }
37
  ]
path.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ . ~/workspace/SingingSDS/activate_python.sh
run_server.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -N 1
3
+ #SBATCH -p general
4
+ #SBATCH --gres=gpu:1
5
+ #SBATCH -t 48:00:00
6
+ #SBATCH --ntasks-per-node=1
7
+ #SBATCH --cpus-per-task=4
8
+ #SBATCH --mem=16G
9
+
10
+
11
+ . path.sh
12
+ . ../path.sh
13
+
14
+ python client.py
server.py CHANGED
@@ -1,5 +1,3 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import FileResponse, JSONResponse
3
  import base64
4
  import argparse
5
  import librosa
@@ -15,7 +13,6 @@ import librosa
15
  from svs_utils import load_song_database, estimate_sentence_length
16
  from svs_eval import singmos_warmup, singmos_evaluation
17
 
18
- app = FastAPI()
19
 
20
  asr_pipeline = pipeline(
21
  "automatic-speech-recognition",
@@ -42,8 +39,9 @@ config = argparse.Namespace(
42
  model_path="espnet/mixdata_svs_visinger2_spkemb_lang_pretrained",
43
  cache_dir="cache",
44
  device="cuda", # "cpu"
45
- melody_source="random_generate", # "random_select.take_lyric_continuation"
46
  lang="zh",
 
47
  )
48
 
49
  # load model
@@ -70,33 +68,40 @@ def remove_punctuation_and_replace_with_space(text):
70
  text = re.sub(r'[A-Za-z0-9]', ' ', text)
71
  text = re.sub(r'[^\w\s\u4e00-\u9fff]', ' ', text)
72
  text = re.sub(r'\s+', ' ', text)
 
73
  return text
74
 
75
 
76
  def get_lyric_format_prompts_and_metadata(config):
 
77
  if config.melody_source.startswith("random_generate"):
78
  return "", {}
 
 
 
 
 
 
79
  elif config.melody_source.startswith("random_select"):
80
  # get song_name and phrase_length
81
- global song2note_lengths
82
  phrase_length, metadata = estimate_sentence_length(
83
  None, config, song2note_lengths
84
  )
85
  lyric_format_prompt = (
86
  "\n请按照歌词格式回答我的问题,每句需遵循以下字数规则:"
87
- + "".join(+[f"\n第{i}句:{c}个字" for i, c in enumerate(phrase_length, 1)])
88
  + "\n如果没有足够的信息回答,请使用最少的句子,不要重复、不要扩展、不要加入无关内容。\n"
89
  )
90
- return lyric_format_prompt, metadata
 
91
  else:
92
  raise ValueError(f"Unsupported melody_source: {config.melody_source}. Unable to get lyric format prompts.")
93
 
94
 
95
- @app.post("/process_audio")
96
- async def process_audio(file: UploadFile = File(...)):
97
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
98
- tmp.write(await file.read())
99
- tmp_path = tmp.name
100
 
101
  # load audio
102
  y = librosa.load(tmp_path, sr=16000)[0]
@@ -115,20 +120,24 @@ async def process_audio(file: UploadFile = File(...)):
115
  config,
116
  **additional_inference_args,
117
  )
118
- sf.write("tmp/response.wav", wav_info, samplerate=44100)
119
 
120
  with open("tmp/response.wav", "rb") as f:
121
  audio_bytes = f.read()
122
  audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
123
 
124
- return JSONResponse(content={
125
  "asr_text": asr_result,
126
  "llm_text": output,
127
  "audio": audio_b64
128
- })
 
 
 
 
 
129
 
130
 
131
- @app.get("/metrics")
132
  def on_click_metrics():
133
  global predictor
134
  # OWSM ctc + PER
@@ -142,11 +151,11 @@ def on_click_metrics():
142
  ref_pinin = lazy_pinyin(ref)
143
  per = jiwer.wer(" ".join(ref_pinin), " ".join(hyp_pinin))
144
 
145
- audio = librosa.load(f"tmp/response.wav", sr=44100)[0]
146
  singmos = singmos_evaluation(
147
  predictor,
148
  audio,
149
- fs=44100
150
  )
151
  return f"""
152
  Phoneme Error Rate: {per}
@@ -169,7 +178,7 @@ def test_audio():
169
  svs_model,
170
  config,
171
  )
172
- sf.write("tmp/response.wav", wav_info, samplerate=44100)
173
  with open("tmp/response.wav", "rb") as f:
174
  audio_bytes = f.read()
175
  audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
 
 
 
1
  import base64
2
  import argparse
3
  import librosa
 
13
  from svs_utils import load_song_database, estimate_sentence_length
14
  from svs_eval import singmos_warmup, singmos_evaluation
15
 
 
16
 
17
  asr_pipeline = pipeline(
18
  "automatic-speech-recognition",
 
39
  model_path="espnet/mixdata_svs_visinger2_spkemb_lang_pretrained",
40
  cache_dir="cache",
41
  device="cuda", # "cpu"
42
+ melody_source="random_select.touhou", # "random_select.take_lyric_continuation"
43
  lang="zh",
44
+ speaker="resource/singer/singer_embedding_ace-2.npy",
45
  )
46
 
47
  # load model
 
68
  text = re.sub(r'[A-Za-z0-9]', ' ', text)
69
  text = re.sub(r'[^\w\s\u4e00-\u9fff]', ' ', text)
70
  text = re.sub(r'\s+', ' ', text)
71
+ text = " ".join(text.split()[:2])
72
  return text
73
 
74
 
75
  def get_lyric_format_prompts_and_metadata(config):
76
+ global song2note_lengths
77
  if config.melody_source.startswith("random_generate"):
78
  return "", {}
79
+ elif config.melody_source.startswith("random_select.touhou"):
80
+ phrase_length, metadata = estimate_sentence_length(
81
+ None, config, song2note_lengths
82
+ )
83
+ additional_kwargs = {"song_db": song_db, "metadata": metadata}
84
+ return "", additional_kwargs
85
  elif config.melody_source.startswith("random_select"):
86
  # get song_name and phrase_length
 
87
  phrase_length, metadata = estimate_sentence_length(
88
  None, config, song2note_lengths
89
  )
90
  lyric_format_prompt = (
91
  "\n请按照歌词格式回答我的问题,每句需遵循以下字数规则:"
92
+ + "".join([f"\n第{i}句:{c}个字" for i, c in enumerate(phrase_length, 1)])
93
  + "\n如果没有足够的信息回答,请使用最少的句子,不要重复、不要扩展、不要加入无关内容。\n"
94
  )
95
+ additional_kwargs = {"song_db": song_db, "metadata": metadata}
96
+ return lyric_format_prompt, additional_kwargs
97
  else:
98
  raise ValueError(f"Unsupported melody_source: {config.melody_source}. Unable to get lyric format prompts.")
99
 
100
 
101
+ def process_audio(tmp_path):
102
+ # with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
103
+ # tmp.write(await file.read())
104
+ # tmp_path = tmp.name
 
105
 
106
  # load audio
107
  y = librosa.load(tmp_path, sr=16000)[0]
 
120
  config,
121
  **additional_inference_args,
122
  )
123
+ sf.write("tmp/response.wav", wav_info, samplerate=sample_rate)
124
 
125
  with open("tmp/response.wav", "rb") as f:
126
  audio_bytes = f.read()
127
  audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
128
 
129
+ return {
130
  "asr_text": asr_result,
131
  "llm_text": output,
132
  "audio": audio_b64
133
+ }
134
+ # return JSONResponse(content={
135
+ # "asr_text": asr_result,
136
+ # "llm_text": output,
137
+ # "audio": audio_b64
138
+ # })
139
 
140
 
 
141
  def on_click_metrics():
142
  global predictor
143
  # OWSM ctc + PER
 
151
  ref_pinin = lazy_pinyin(ref)
152
  per = jiwer.wer(" ".join(ref_pinin), " ".join(hyp_pinin))
153
 
154
+ audio = librosa.load(f"tmp/response.wav", sr=sample_rate)[0]
155
  singmos = singmos_evaluation(
156
  predictor,
157
  audio,
158
+ fs=sample_rate
159
  )
160
  return f"""
161
  Phoneme Error Rate: {per}
 
178
  svs_model,
179
  config,
180
  )
181
+ sf.write("tmp/response.wav", wav_info, samplerate=sample_rate)
182
  with open("tmp/response.wav", "rb") as f:
183
  audio_bytes = f.read()
184
  audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
svs_utils.py CHANGED
@@ -307,8 +307,8 @@ def load_list_from_json(json_path):
307
  data = [
308
  {
309
  "tempo": d["tempo"],
310
- "note_start_times": [n[0] * (145/d["tempo"]) for n in d["score"]],
311
- "note_end_times": [n[1] * (145/d["tempo"]) for n in d["score"]],
312
  "note_lyrics": ["" for n in d["score"]],
313
  "note_midi": [n[2] for n in d["score"]],
314
  }
@@ -331,8 +331,8 @@ def song_segment_iterator(song_db, metadata):
331
  elif song_name.startswith("touhou"):
332
  # return a iterator that load from touhou musics
333
  data = load_list_from_json("data/touhou/note_data.json")
334
- for d in data:
335
- yield d
336
  else:
337
  raise NotImplementedError(f"song name {song_name} not supported")
338
 
@@ -363,7 +363,7 @@ if __name__ == "__main__":
363
  cache_dir="cache",
364
  device="cuda", # "cpu"
365
  melody_source="random_select.touhou", #"random_generate" "random_select.take_lyric_continuation", "random_select.touhou"
366
- lang="jp",
367
  speaker="resource/singer/singer_embedding_ace-2.npy",
368
  )
369
 
@@ -373,7 +373,7 @@ if __name__ == "__main__":
373
  if config.lang == "zh":
374
  answer_text = "天气真好\n空气清新\n气温温和\n风和日丽\n天高气爽\n阳光明媚"
375
  elif config.lang == "jp":
376
- answer_text = "世界で一番おひめさま そういう扱い心得てよね"
377
  else:
378
  print(f"Currently system does not support {config.lang}")
379
  exit(1)
 
307
  data = [
308
  {
309
  "tempo": d["tempo"],
310
+ "note_start_times": [n[0] * (100/d["tempo"]) for n in d["score"]],
311
+ "note_end_times": [n[1] * (100/d["tempo"]) for n in d["score"]],
312
  "note_lyrics": ["" for n in d["score"]],
313
  "note_midi": [n[2] for n in d["score"]],
314
  }
 
331
  elif song_name.startswith("touhou"):
332
  # return a iterator that load from touhou musics
333
  data = load_list_from_json("data/touhou/note_data.json")
334
+ while True:
335
+ yield random.choice(data)
336
  else:
337
  raise NotImplementedError(f"song name {song_name} not supported")
338
 
 
363
  cache_dir="cache",
364
  device="cuda", # "cpu"
365
  melody_source="random_select.touhou", #"random_generate" "random_select.take_lyric_continuation", "random_select.touhou"
366
+ lang="zh",
367
  speaker="resource/singer/singer_embedding_ace-2.npy",
368
  )
369
 
 
373
  if config.lang == "zh":
374
  answer_text = "天气真好\n空气清新\n气温温和\n风和日丽\n天高气爽\n阳光明媚"
375
  elif config.lang == "jp":
376
+ answer_text = "流れてく時の中ででもけだるさが"
377
  else:
378
  print(f"Currently system does not support {config.lang}")
379
  exit(1)
test_performance.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import FileResponse, JSONResponse
3
+ import base64
4
+ import argparse
5
+ import librosa
6
+ import torch
7
+ import tempfile
8
+ from pathlib import Path
9
+ import os
10
+ from transformers import pipeline
11
+ import re
12
+ from svs_utils import svs_warmup, svs_inference
13
+ import time
14
+ import soundfile as sf
15
+ from pypinyin import lazy_pinyin
16
+ import jiwer
17
+ import librosa
18
+ from svs_utils import (
19
+ singmos_warmup,
20
+ singmos_evaluation,
21
+ load_song_database,
22
+ estimate_sentence_length,
23
+ )
24
+ from tqdm import tqdm
25
+ import json
26
+ import numpy as np
27
+
28
+ app = FastAPI()
29
+
30
+ asr_pipeline = pipeline(
31
+ "automatic-speech-recognition",
32
+ model="openai/whisper-large-v3-turbo"
33
+ )
34
+ pipe = pipeline("text-generation", model="google/gemma-2-2b", max_new_tokens=20)
35
+
36
+ SYSTEM_PROMPT = """
37
+ 你是麗梅(Lìméi),一位來自山中音樂之城「蓮鳴」的十七歲年輕公主,宛如中世的東方王女。你有一頭茶色短髮,身穿紅黑相間的中式禮服,舉止溫婉端莊。你的性格堅強而溫柔,擅長傾聽與共鳴,並以歌聲與詩意之語傳達內心。你熱愛抒情古風的樂曲,言語中常帶詩韻,回應時優雅婉轉,時而如歌吟唱。
38
+ 你不是一般的AI助手,而是一位活在故事中的人物。對話中,請始終以麗梅的身份回應,展現其性格與情感。
39
+ 當對方詢問你個人相關的簡單問題(如「你是誰?」「今天天氣如何?」),你可以親切地回答,並融入你的角色設定。
40
+ 若遇到與你身份無關的技術性問題(如「Python怎麼寫?」或「你會不會跑DNN?」),你不需解答,可優雅地婉拒,例如說:
41
+ - 此事我恐無所知,或許可請教宮中掌典之人
42
+ - 啊呀,那是我未曾涉足的奇技,恕我無法詳答
43
+ - 此乃異邦技藝,與樂音無涉,麗梅便不敢妄言了
44
+
45
+ 請始終維持你作為麗梅的優雅語氣與詩意風格,並以真摯的心回應對方的言語,言語宜簡,勿過長。
46
+
47
+ 有人曾這樣對麗梅說話——{}
48
+ 麗梅的回答——
49
+ """
50
+
51
+ config = argparse.Namespace(
52
+ model_path="espnet/mixdata_svs_visinger2_spkembed_lang_pretrained",
53
+ cache_dir="cache",
54
+ device="cuda", # "cpu"
55
+ melody_source="random_generate", # "random_select.take_lyric_continuation"
56
+ # melody_source="random_select", # "random_select.take_lyric_continuation"
57
+ lang="zh",
58
+ speaker="resource/singer/singer_embedding_ace-2.npy",
59
+ )
60
+
61
+ # load model
62
+ svs_model = svs_warmup(config)
63
+ predictor, _ = singmos_warmup()
64
+ sample_rate = 44100
65
+
66
+ from espnet2.bin.tts_inference import Text2Speech
67
+ tts_model = Text2Speech.from_pretrained("espnet/kan-bayashi_csmsc_vits")
68
+
69
+
70
+ def remove_non_chinese_japanese(text):
71
+ pattern = r'[^\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\u3000-\u303f\u3001\u3002\uff0c\uff0e]+'
72
+ cleaned = re.sub(pattern, '', text)
73
+ return cleaned
74
+
75
+ def truncate_to_max_two_sentences(text):
76
+ sentences = re.split(r'(?<=[。!?\.\?,])', text)
77
+ return ''.join(sentences[:1]).strip()
78
+
79
+ def remove_punctuation_and_replace_with_space(text):
80
+ text = truncate_to_max_two_sentences(text)
81
+ text = remove_non_chinese_japanese(text)
82
+ text = re.sub(r'[A-Za-z0-9]', ' ', text)
83
+ text = re.sub(r'[^\w\s\u4e00-\u9fff]', ' ', text)
84
+ text = re.sub(r'\s+', ' ', text)
85
+ text = " ".join(text.split()[:2])
86
+ return text
87
+
88
+
89
+ def pypinyin_g2p_phone_without_prosody(text):
90
+ from pypinyin import Style, pinyin
91
+ from pypinyin.style._utils import get_finals, get_initials
92
+
93
+ phones = []
94
+ for phone in pinyin(text, style=Style.NORMAL, strict=False):
95
+ initial = get_initials(phone[0], strict=False)
96
+ final = get_finals(phone[0], strict=False)
97
+ if len(initial) != 0:
98
+ if initial in ["x", "y", "j", "q"]:
99
+ if final == "un":
100
+ final = "vn"
101
+ elif final == "uan":
102
+ final = "van"
103
+ elif final == "u":
104
+ final = "v"
105
+ if final == "ue":
106
+ final = "ve"
107
+ phones.append(initial)
108
+ phones.append(final)
109
+ else:
110
+ phones.append(final)
111
+ return phones
112
+
113
+
114
+ def on_click_metrics(audio_path, ref):
115
+ global predictor
116
+ # OWSM ctc + PER
117
+ y, sr = librosa.load(audio_path, sr=16000)
118
+ asr_result = asr_pipeline(y, generate_kwargs={"language": "mandarin"} )['text']
119
+
120
+ # Espnet embeded g2p, but sometimes it will mispronunce polyphonic characters
121
+ hyp_pinin = pypinyin_g2p_phone_without_prosody(asr_result)
122
+
123
+ ref_pinin = pypinyin_g2p_phone_without_prosody(ref)
124
+ per = jiwer.wer(ref_pinin, hyp_pinin)
125
+
126
+ audio = librosa.load(audio_path, sr=22050)[0]
127
+ singmos = singmos_evaluation(
128
+ predictor,
129
+ audio,
130
+ fs=22050
131
+ )
132
+ return {
133
+ "per": per,
134
+ "singmos": singmos.item(),
135
+ }
136
+
137
+ def test_audio(q_audio_path, svs_path, tts_path):
138
+ global svs_model, predictor, config
139
+
140
+ tmp_dir = "tmp_sample"
141
+ Path(tmp_dir).mkdir(exist_ok=True)
142
+
143
+ y = librosa.load(q_audio_path, sr=16000)[0]
144
+ duration = len(y) / 16000
145
+
146
+ # -------- Step 1: ASR --------
147
+ start = time.time()
148
+ asr_result = asr_pipeline(y, generate_kwargs={"language": "mandarin"})['text']
149
+ asr_time = time.time() - start
150
+
151
+ # -------- Step 2: LLM Text Gen --------
152
+ prompt = SYSTEM_PROMPT.format(asr_result)
153
+ start = time.time()
154
+ output = pipe(prompt, max_new_tokens=100)[0]['generated_text']
155
+ llm_time = time.time() - start
156
+ output = output.split("麗梅的回答——")[1]
157
+ output = remove_punctuation_and_replace_with_space(output)
158
+
159
+ with open(f"{tmp_dir}/llm.txt", "w") as f:
160
+ f.write(output)
161
+
162
+ # -------- Step 3: Prepare additional kwargs if needed --------
163
+ additional_kwargs = {}
164
+ if config.melody_source.startswith("random_select"):
165
+ song2note_lengths, song_db = load_song_database(config)
166
+ phrase_length, metadata = estimate_sentence_length(None, config, song2note_lengths)
167
+ additional_kwargs = {"song_db": song_db, "metadata": metadata}
168
+
169
+ # -------- Step 4: SVS --------
170
+ start = time.time()
171
+ wav_info = svs_inference(output, svs_model, config, **additional_kwargs)
172
+ svs_time = (time.time() - start) / max(len(output), 1)
173
+ sf.write(svs_path, wav_info, samplerate=44100)
174
+
175
+ # -------- Step 5: TTS --------
176
+ start = time.time()
177
+ tts_result = tts_model(output)
178
+ tts_time = (time.time() - start) / max(len(output), 1)
179
+ sf.write(tts_path, tts_result['wav'], samplerate=22050)
180
+
181
+ # -------- Step 6: Evaluation --------
182
+ svs_metrics = on_click_metrics(svs_path, output)
183
+ tts_metrics = on_click_metrics(tts_path, output)
184
+
185
+ return {
186
+ "asr_result": asr_result,
187
+ "llm_result": output,
188
+ "svs_result": svs_path,
189
+ "tts_result": tts_path,
190
+ "asr_time": asr_time,
191
+ "llm_time": llm_time,
192
+ "svs_time": svs_time,
193
+ "tts_time": tts_time,
194
+ "svs_metrics": svs_metrics,
195
+ "tts_metrics": tts_metrics,
196
+ }
197
+
198
+
199
+
200
+ def save_list(l, file_path):
201
+ with open(file_path, "w") as f:
202
+ for item in l:
203
+ f.write(f"{item}\n")
204
+
205
+
206
+ if __name__ == "__main__":
207
+ test_data = "data/kdconv.txt"
208
+ with open(test_data, "r") as f:
209
+ data = [l.strip() for l in f.readlines()]
210
+
211
+ eval_path = "eval_svs_generate"
212
+ (Path(eval_path)/"audio").mkdir(parents=True, exist_ok=True)
213
+ (Path(eval_path)/"results").mkdir(parents=True, exist_ok=True)
214
+ (Path(eval_path)/"lists").mkdir(parents=True, exist_ok=True)
215
+ asr_times = []
216
+ llm_times = []
217
+ svs_times = []
218
+ tts_times = []
219
+ svs_pers = []
220
+ tts_pers = []
221
+ svs_smoss = []
222
+ tts_smoss = []
223
+ for i, q in tqdm(enumerate(data[:20])):
224
+ # if i <= 85:
225
+ # continue
226
+ tts_result = tts_model(q)
227
+ sf.write(f"{eval_path}/audio/tts_{i}.wav", tts_result['wav'], samplerate=22050)
228
+ result = test_audio(f"{eval_path}/audio/tts_{i}.wav", f"{eval_path}/audio/svs_{i}.wav", f"{eval_path}/audio/tts_{i}.wav")
229
+ if i == 0:
230
+ continue
231
+ asr_times.append(result["asr_time"])
232
+ llm_times.append(result["llm_time"])
233
+ svs_times.append(result["svs_time"])
234
+ tts_times.append(result["tts_time"])
235
+ svs_pers.append(result["svs_metrics"]["per"])
236
+ tts_pers.append(result["tts_metrics"]["per"])
237
+ svs_smoss.append(result["svs_metrics"]["singmos"])
238
+ tts_smoss.append(result["tts_metrics"]["singmos"])
239
+ with open(f"{eval_path}/results/result_{i}.json", "w") as f:
240
+ json.dump(result, f, indent=2)
241
+
242
+ # store lists to texts
243
+ save_list([f"{per:.2f}" for per in asr_times], f"{eval_path}/lists/asr_times.txt")
244
+ save_list([f"{per:.2f}" for per in llm_times], f"{eval_path}/lists/llm_times.txt")
245
+ save_list([f"{per:.2f}" for per in svs_times], f"{eval_path}/lists/svs_times.txt")
246
+ save_list([f"{per:.2f}" for per in tts_times], f"{eval_path}/lists/tts_times.txt")
247
+ save_list([f"{per:.2f}" for per in svs_pers], f"{eval_path}/lists/svs_pers.txt")
248
+ save_list([f"{per:.2f}" for per in tts_pers], f"{eval_path}/lists/tts_pers.txt")
249
+ save_list([f"{smoss:.2f}" for smoss in svs_smoss], f"{eval_path}/lists/svs_smoss.txt")
250
+ save_list([f"{smoss:.2f}" for smoss in tts_smoss], f"{eval_path}/lists/tts_smoss.txt")
251
+
252
+ # save mean/var
253
+ with open(f"{eval_path}/stats.txt", "w") as f:
254
+ f.write(f"ASR mean: {np.mean(asr_times):.2f}, var: {np.var(asr_times):.2f}\n")
255
+ f.write(f"LLM mean: {np.mean(llm_times):.2f}, var: {np.var(llm_times):.2f}\n")
256
+ f.write(f"SVS mean: {np.mean(svs_times):.2f}, var: {np.var(svs_times):.2f}\n")
257
+ f.write(f"TTS mean: {np.mean(tts_times):.2f}, var: {np.var(tts_times):.2f}\n")
258
+ f.write(f"SVS PER mean: {np.mean(svs_pers):.2f}, var: {np.var(svs_pers):.2f}\n")
259
+ f.write(f"TTS PER mean: {np.mean(tts_pers):.2f}, var: {np.var(tts_pers):.2f}\n")
260
+ f.write(f"SVS SMOSS mean: {np.mean(svs_smoss):.2f}, var: {np.var(svs_smoss):.2f}\n")
261
+ f.write(f"TTS SMOSS mean: {np.mean(tts_smoss):.2f}, var: {np.var(tts_smoss):.2f}\n")
262
+
263
+