Commit d574298 by admin · 1 parent: 05f952e

Files changed (4):
  1. app.py +57 -59
  2. model.py +9 -4
  3. requirements.txt +5 -3
  4. utils.py +62 -12
app.py CHANGED
````diff
@@ -8,25 +8,17 @@ import gradio as gr
 import librosa.display
 import matplotlib.pyplot as plt
 from model import EvalNet
-from utils import get_modelist, find_wav_files, embed_img
-
-
-TRANSLATE = {
-    "vibrato": "Rou xian",
-    "trill": "Chan yin",
-    "tremolo": "Chan gong",
-    "staccato": "Dun gong",
-    "ricochet": "Pao gong",
-    "pizzicato": "Bo xian",
-    "percussive": "Ji gong",
-    "legato_slide_glissando": "Lian hua yin",
-    "harmonic": "Fan yin",
-    "diangong": "Dian gong",
-    "detache": "Fen gong",
-}
-CLASSES = list(TRANSLATE.keys())
-TEMP_DIR = "./__pycache__/tmp"
-SAMPLE_RATE = 44100
+from utils import (
+    get_modelist,
+    find_wav_files,
+    embed_img,
+    _L,
+    EN_US,
+    SAMPLE_RATE,
+    TEMP_DIR,
+    TRANSLATE,
+    CLASSES,
+)
 
 
 def circular_padding(y: np.ndarray, sr: int, dur=3):
@@ -88,33 +80,38 @@ def wav2chroma(audio_path: str):
 
 
 def infer(wav_path: str, log_name: str, folder_path=TEMP_DIR):
-    if os.path.exists(folder_path):
-        shutil.rmtree(folder_path)
+    status = "Success"
+    filename = result = None
+    try:
+        if os.path.exists(folder_path):
+            shutil.rmtree(folder_path)
 
-    if not wav_path:
-        return None, "Please input an audio!"
+        if not wav_path:
+            return "请输入音频!", None, None
 
-    spec = log_name.split("_")[-3]
-    os.makedirs(folder_path, exist_ok=True)
-    try:
+        spec = log_name.split("_")[-3]
+        os.makedirs(folder_path, exist_ok=True)
         model = EvalNet(log_name, len(TRANSLATE)).model
         eval("wav2%s" % spec)(wav_path)
+        input = embed_img(f"{folder_path}/output.jpg")
+        output: torch.Tensor = model(input)
+        pred_id = torch.max(output.data, 1)[1]
+        filename = os.path.basename(wav_path)
+        result = (
+            CLASSES[pred_id].capitalize()
+            if EN_US
+            else f"{TRANSLATE[CLASSES[pred_id]]} ({CLASSES[pred_id].capitalize()})"
+        )
 
     except Exception as e:
-        return None, f"{e}"
-
-    input = embed_img(f"{folder_path}/output.jpg")
-    output: torch.Tensor = model(input)
-    pred_id = torch.max(output.data, 1)[1]
-    return (
-        os.path.basename(wav_path),
-        f"{TRANSLATE[CLASSES[pred_id]]} ({CLASSES[pred_id].capitalize()})",
-    )
+        status = f"{e}"
+
+    return status, filename, result
 
 
 if __name__ == "__main__":
     warnings.filterwarnings("ignore")
-    models = get_modelist(assign_model="Swin_T_mel")
+    models = get_modelist(assign_model="swin_t_mel")
     examples = []
     example_wavs = find_wav_files()
     for wav in example_wavs:
@@ -124,36 +121,37 @@ if __name__ == "__main__":
     gr.Interface(
         fn=infer,
         inputs=[
-            gr.Audio(label="Upload a recording", type="filepath"),
-            gr.Dropdown(choices=models, label="Select a model", value=models[0]),
+            gr.Audio(label=_L("上传录音"), type="filepath"),
+            gr.Dropdown(choices=models, label=_L("选择模型"), value=models[0]),
         ],
         outputs=[
-            gr.Textbox(label="Audio filename", show_copy_button=True),
-            gr.Textbox(label="Playing tech recognition", show_copy_button=True),
+            gr.Textbox(label=_L("状态栏"), show_copy_button=True),
+            gr.Textbox(label=_L("音频文件名"), show_copy_button=True),
+            gr.Textbox(label=_L("演奏技法识别"), show_copy_button=True),
         ],
         examples=examples,
         cache_examples=False,
-        allow_flagging="never",
-        title="It is recommended to keep the recording length around 3s.",
+        flagging_mode="never",
+        title=_L("建议录音时长保持在 3s 左右"),
     )
 
     gr.Markdown(
-        """
-        # Cite
-        ```bibtex
-        @article{Zhou-2025,
-            author = {Monan Zhou and Shenyang Xu and Zhaorui Liu and Zhaowen Wang and Feng Yu and Wei Li and Baoqiang Han},
-            title = {CCMusic: An Open and Diverse Database for Chinese Music Information Retrieval Research},
-            journal = {Transactions of the International Society for Music Information Retrieval},
-            volume = {8},
-            number = {1},
-            pages = {22--38},
-            month = {Mar},
-            year = {2025},
-            url = {https://doi.org/10.5334/tismir.194},
-            doi = {10.5334/tismir.194}
-        }
-        ```"""
+        f"# {_L('引用')}"
+        + """
+```bibtex
+@article{Zhou-2025,
+    author = {Monan Zhou and Shenyang Xu and Zhaorui Liu and Zhaowen Wang and Feng Yu and Wei Li and Baoqiang Han},
+    title = {CCMusic: An Open and Diverse Database for Chinese Music Information Retrieval Research},
+    journal = {Transactions of the International Society for Music Information Retrieval},
+    volume = {8},
+    number = {1},
+    pages = {22--38},
+    month = {Mar},
+    year = {2025},
+    url = {https://doi.org/10.5334/tismir.194},
+    doi = {10.5334/tismir.194}
+}
+```"""
     )
 
-    demo.launch(ssr_mode=False)
+    demo.launch()
````
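The reworked `infer` funnels every failure into a status string instead of hijacking the result fields, so the interface now drives three Textboxes. A minimal smoke test of the new (status, filename, result) contract, as a sketch: it assumes the Space's modules are importable locally, and `some_clip.wav` is a hypothetical path.

```python
# Sketch: exercising the new three-value contract of infer().
# Assumes app.py and utils.py are importable; "some_clip.wav" is hypothetical.
from utils import get_modelist
from app import infer

models = get_modelist(assign_model="swin_t_mel")
status, filename, result = infer("some_clip.wav", models[0])
print(status)    # "Success" on a clean run, otherwise the exception text
print(filename)  # basename of the input wav, or None if inference failed
print(result)    # e.g. "Vibrato" when EN_US, else "揉弦 (Vibrato)"
```

Note the `flagging_mode="never"` rename in the diff: `allow_flagging` is the older Gradio keyword, replaced in Gradio 5.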
model.py CHANGED
````diff
@@ -1,8 +1,9 @@
 import torch
 import torch.nn as nn
 import torchvision.models as models
+from modelscope.msdatasets import MsDataset
 from datasets import load_dataset
-from utils import MODEL_DIR
+from utils import MODEL_DIR, EN_US
 
 
 class EvalNet:
@@ -17,7 +18,7 @@ class EvalNet:
         self.m_type, self.input_size = self._model_info(m_ver)
 
         if not hasattr(models, m_ver):
-            raise Exception("Unsupported model.")
+            raise ValueError("不支持的模型")
 
         self.model = eval("models.%s()" % m_ver)
         linear_output = self._set_outsize()
@@ -34,11 +35,15 @@ class EvalNet:
             if ver == bb["ver"]:
                 return bb
 
-        print("Backbone name not found, using default option - alexnet.")
+        print("未找到骨干网络名称,使用默认选项 - alexnet")
         return backbone_list[0]
 
     def _model_info(self, m_ver: str):
-        backbone_list = load_dataset("monetjoe/cv_backbones", split="train")
+        backbone_list = (
+            load_dataset("monetjoe/cv_backbones", split="train")
+            if EN_US
+            else MsDataset.load("monetjoe/cv_backbones", split="v1")
+        )
         backbone = self._get_backbone(m_ver, backbone_list)
         m_type = str(backbone["type"])
         input_size = int(backbone["input_size"])
````
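The backbone metadata now comes from one of two mirrors of the same dataset: Hugging Face when `EN_US` is set, ModelScope otherwise. A condensed sketch of that lookup, using the `ver`/`type`/`input_size` record fields that `_get_backbone` and `_model_info` read; `"swin_t"` is just an example torchvision model name.

```python
# Sketch of the dual-source backbone lookup used by EvalNet._model_info.
import os

EN_US = os.getenv("LANG") != "zh_CN.UTF-8"

if EN_US:
    from datasets import load_dataset
    backbone_list = load_dataset("monetjoe/cv_backbones", split="train")
else:
    from modelscope.msdatasets import MsDataset
    backbone_list = MsDataset.load("monetjoe/cv_backbones", split="v1")

# Each record describes one torchvision backbone, keyed by its version tag.
for bb in backbone_list:
    if bb["ver"] == "swin_t":
        print(bb["type"], bb["input_size"])
        break
```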
requirements.txt CHANGED
````diff
@@ -1,5 +1,7 @@
-torch
-pillow
+torch==2.6.0+cu118
+-f https://download.pytorch.org/whl/torch
+torchvision==0.21.0+cu118
+-f https://download.pytorch.org/whl/torchvision
 librosa
 matplotlib
-torchvision
+modelscope[framework]==1.21.0
````
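The pins switch to CUDA 11.8 wheels, with `-f` (find-links) lines pointing pip at the PyTorch wheel index, and add ModelScope for the Chinese-locale path. A quick post-install check, as a sketch; the `+cu118` suffix only appears when the CUDA wheels actually resolved.

```python
# Verify the pinned CUDA 11.8 wheels after `pip install -r requirements.txt`.
import torch
import torchvision

print(torch.__version__)          # expected: 2.6.0+cu118
print(torchvision.__version__)    # expected: 0.21.0+cu118
print(torch.cuda.is_available())  # True only with a CUDA-capable driver
```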
utils.py CHANGED
````diff
@@ -1,15 +1,68 @@
 import os
 import torch
 import torchvision.transforms as transforms
-from huggingface_hub import snapshot_download
+import huggingface_hub
+import modelscope
 from PIL import Image
 
-MODEL_DIR = snapshot_download(
-    "ccmusic-database/erhu_playing_tech",
-    cache_dir="./__pycache__",
+EN_US = os.getenv("LANG") != "zh_CN.UTF-8"
+
+ZH2EN = {
+    "上传录音": "Upload a recording",
+    "选择模型": "Select a model",
+    "状态栏": "Status",
+    "音频文件名": "Audio filename",
+    "演奏技法识别": "Playing tech recognition",
+    "建议录音时长保持在 3s 左右": "It is recommended to keep the recording length around 3s.",
+    "引用": "Cite",
+    "揉弦": "Rou xian",
+    "颤音": "Chan yin",
+    "颤弓": "Chan gong",
+    "顿弓": "Dun gong",
+    "抛弓": "Pao gong",
+    "拨弦": "Bo xian",
+    "击弓": "Ji gong",
+    "连滑音": "Lian hua yin",
+    "泛音": "Fan yin",
+    "垫弓": "Dian gong",
+    "分弓": "Fen gong",
+}
+
+MODEL_DIR = (
+    huggingface_hub.snapshot_download(
+        "ccmusic-database/erhu_playing_tech",
+        cache_dir="./__pycache__",
+    )
+    if EN_US
+    else modelscope.snapshot_download(
+        "ccmusic-database/erhu_playing_tech",
+        cache_dir="./__pycache__",
+    )
 )
 
 
+def _L(zh_txt: str):
+    return ZH2EN[zh_txt] if EN_US else zh_txt
+
+
+TRANSLATE = {
+    "vibrato": _L("揉弦"),
+    "trill": _L("颤音"),
+    "tremolo": _L("颤弓"),
+    "staccato": _L("顿弓"),
+    "ricochet": _L("抛弓"),
+    "pizzicato": _L("拨弦"),
+    "percussive": _L("击弓"),
+    "legato_slide_glissando": _L("连滑音"),
+    "harmonic": _L("泛音"),
+    "diangong": _L("垫弓"),
+    "detache": _L("分弓"),
+}
+CLASSES = list(TRANSLATE.keys())
+TEMP_DIR = "./__pycache__/tmp"
+SAMPLE_RATE = 44100
+
+
 def toCUDA(x):
     if hasattr(x, "cuda"):
         if torch.cuda.is_available():
@@ -30,19 +83,16 @@ def find_wav_files(folder_path=f"{MODEL_DIR}/examples"):
 
 
 def get_modelist(model_dir=MODEL_DIR, assign_model=""):
-    try:
-        entries = os.listdir(model_dir)
-    except OSError as e:
-        print(f"Cannot access {model_dir}: {e}")
-        return
-
     output = []
-    for entry in entries:
+    for entry in os.listdir(model_dir):
+        # get the full path of the entry
         full_path = os.path.join(model_dir, entry)
+        # skip the .git and examples folders
        if entry == ".git" or entry == "examples":
-            print(f"Skip .git / examples dir: {full_path}")
+            print(f"跳过 .git examples 文件夹: {full_path}")
             continue
 
+        # check whether the entry is a directory
         if os.path.isdir(full_path):
             model = os.path.basename(full_path)
             if assign_model and assign_model.lower() in model:
````
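The `_L` helper inverts the usual gettext direction: Chinese literals are the source keys and English is looked up on demand, with the locale decided once at import time from `LANG`. A self-contained sketch of the pattern; note that under `EN_US`, any string missing from `ZH2EN` raises `KeyError`.

```python
# Condensed sketch of the _L localization gate defined in utils.py.
import os

EN_US = os.getenv("LANG") != "zh_CN.UTF-8"

ZH2EN = {"状态栏": "Status"}  # Chinese source string -> English translation

def _L(zh_txt: str):
    # Pass Chinese through unchanged; translate only for English locales.
    return ZH2EN[zh_txt] if EN_US else zh_txt

print(_L("状态栏"))  # "Status" unless LANG is zh_CN.UTF-8
```

Because `EN_US` is evaluated at module import, `LANG` must be set before `utils` is first imported; changing it afterwards has no effect on an already-loaded module.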