SanderGi commited on
Commit
7796889
·
1 Parent(s): f718747

powsm support; multiple output formats

Browse files
Files changed (8) hide show
  1. DEVELOPMENT.md +1 -0
  2. app/app.py +37 -4
  3. app/codes.py +394 -0
  4. app/hf.py +16 -4
  5. app/inference.py +50 -2
  6. app/tasks.py +16 -4
  7. requirements.txt +4 -2
  8. requirements_lock.txt +52 -2
DEVELOPMENT.md CHANGED
@@ -85,6 +85,7 @@ IPA-Transcription-EN/
85
  ├── app/ # All application code lives here
86
  │ ├── data/ # Phoneme transcription test set
87
  │ ├── app.py # Main Gradio UI
 
88
  │ ├── hf.py # Interface with the Huggingface API
89
  │ ├── inference.py # Model inference
90
  │ └── metrics.py # Evaluation metrics
 
85
  ├── app/ # All application code lives here
86
  │ ├── data/ # Phoneme transcription test set
87
  │ ├── app.py # Main Gradio UI
88
+ │ ├── codes.py # Phonetic Alphabet conversions
89
  │ ├── hf.py # Interface with the Huggingface API
90
  │ ├── inference.py # Model inference
91
  │ └── metrics.py # Evaluation metrics
app/app.py CHANGED
@@ -7,6 +7,25 @@ import pandas as pd
7
  from tasks import start_eval_task, get_status
8
  from hf import get_or_create_leaderboard
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def get_latest_leaderboard_html(datasets: list[str], sort_option: str) -> str:
12
  try:
@@ -28,6 +47,7 @@ def get_latest_leaderboard_html(datasets: list[str], sort_option: str) -> str:
28
  lambda r: f'<a href="https://huggingface.co/{r["repo_id"]}" target="_blank">{r["display_name"]}</a>',
29
  axis=1,
30
  ),
 
31
  "Average PER ⬇️": df["average_per"].apply(lambda x: f"{100 * x:.2f}%"),
32
  }
33
  | {
@@ -53,14 +73,16 @@ def get_latest_leaderboard_html(datasets: list[str], sort_option: str) -> str:
53
  return f"Error updating leaderboard: {type(e).__name__} - {e}"
54
 
55
 
56
- def submit_evaluation(model_id: str, display_name: str, url: str) -> str:
 
 
57
  model_id = model_id.strip()
58
  display_name = display_name.strip()
59
  if not model_id or not display_name:
60
  return "⚠️ Please provide both model name and submission name."
61
 
62
  try:
63
- task_id = start_eval_task(display_name, model_id, url)
64
  return f"✅ Evaluation submitted successfully! Task ID: {task_id}"
65
  except Exception as e:
66
  return f"❌ Error: {str(e)}"
@@ -100,7 +122,7 @@ with gr.Blocks(
100
  - **PER (Phoneme Error Rate)**: The Levenshtein distance calculated between phoneme sequences of the predicted and actual transcriptions.
101
  - **FER (Feature Error Rate)**: The edit distance between the predicted and actual phoneme sequences, weighted by the phonetic features from [panphon](https://github.com/dmort27/panphon).
102
 
103
- Models are evaluated on a variety of English speech: native, non-native, and impaired. Read more about evaluations on [our blog](https://www.koellabs.com/blog/phonemic-transcription-metrics)
104
 
105
  ## Compute
106
  This leaderboard uses the free basic plan (16GB RAM, 2vCPUs) to allow for reproducibility. The evaluation may take several hours to complete. Please be patient and do not submit the same model multiple times.
@@ -163,12 +185,23 @@ with gr.Blocks(
163
  label="Github/Kaggle/HF URL (optional)",
164
  placeholder="https://github.com/username/repo",
165
  )
 
 
 
 
 
 
 
 
 
 
 
166
  submit_btn = gr.Button("Submit")
167
  result = gr.Textbox(label="Submission Status")
168
 
169
  submit_btn.click(
170
  fn=submit_evaluation,
171
- inputs=[model_id, display_name, url],
172
  outputs=result,
173
  )
174
 
 
7
  from tasks import start_eval_task, get_status
8
  from hf import get_or_create_leaderboard
9
 
10
+ from codes import CODES
11
+ from inference import MODEL_TYPES
12
+
13
+ from math import log
14
+
15
+ unit_list = list(zip(["B", "KB", "MB", "GB", "TB", "PB"], [0, 0, 1, 2, 2, 2]))
16
+
17
+
18
+ def sizeof_fmt(num):
19
+ """Human friendly file size"""
20
+ if isinstance(num, int):
21
+ exponent = min(int(log(num, 1024)), len(unit_list) - 1)
22
+ quotient = float(num) / 1024**exponent
23
+ unit, num_decimals = unit_list[exponent]
24
+ format_string = "{:.%sf} {}" % (num_decimals)
25
+ return format_string.format(quotient, unit)
26
+ else:
27
+ return "unknown"
28
+
29
 
30
  def get_latest_leaderboard_html(datasets: list[str], sort_option: str) -> str:
31
  try:
 
47
  lambda r: f'<a href="https://huggingface.co/{r["repo_id"]}" target="_blank">{r["display_name"]}</a>',
48
  axis=1,
49
  ),
50
+ "Size": df["model_bytes"].apply(sizeof_fmt),
51
  "Average PER ⬇️": df["average_per"].apply(lambda x: f"{100 * x:.2f}%"),
52
  }
53
  | {
 
73
  return f"Error updating leaderboard: {type(e).__name__} - {e}"
74
 
75
 
76
+ def submit_evaluation(
77
+ model_id: str, display_name: str, url: str, model_type: str, phone_code: str
78
+ ) -> str:
79
  model_id = model_id.strip()
80
  display_name = display_name.strip()
81
  if not model_id or not display_name:
82
  return "⚠️ Please provide both model name and submission name."
83
 
84
  try:
85
+ task_id = start_eval_task(display_name, model_id, url, model_type, phone_code)
86
  return f"✅ Evaluation submitted successfully! Task ID: {task_id}"
87
  except Exception as e:
88
  return f"❌ Error: {str(e)}"
 
122
  - **PER (Phoneme Error Rate)**: The Levenshtein distance calculated between phoneme sequences of the predicted and actual transcriptions.
123
  - **FER (Feature Error Rate)**: The edit distance between the predicted and actual phoneme sequences, weighted by the phonetic features from [panphon](https://github.com/dmort27/panphon).
124
 
125
+ Models are evaluated on a variety of English speech: native, non-native, and impaired. Read more about [evaluations](https://www.koellabs.com/blog/phonemic-transcription-metrics) or [how to build your own leaderboards](https://www.koellabs.com/blog/building-open-source-leaderboards) on our blog.
126
 
127
  ## Compute
128
  This leaderboard uses the free basic plan (16GB RAM, 2vCPUs) to allow for reproducibility. The evaluation may take several hours to complete. Please be patient and do not submit the same model multiple times.
 
185
  label="Github/Kaggle/HF URL (optional)",
186
  placeholder="https://github.com/username/repo",
187
  )
188
+ model_type = gr.Dropdown(
189
+ choices=["Transformers CTC"]
190
+ + [c for c in sorted(MODEL_TYPES) if c != "Transformers CTC"],
191
+ label="Model Type",
192
+ interactive=True,
193
+ )
194
+ output_code = gr.Dropdown(
195
+ choices=["ipa"] + [c for c in sorted(CODES) if c != "ipa"],
196
+ label="Model Output Phonetic Code",
197
+ interactive=True,
198
+ )
199
  submit_btn = gr.Button("Submit")
200
  result = gr.Textbox(label="Submission Status")
201
 
202
  submit_btn.click(
203
  fn=submit_evaluation,
204
+ inputs=[model_id, display_name, url, model_type, output_code],
205
  outputs=result,
206
  )
207
 
app/codes.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Conversion between different phonetic codes
4
+ # Modified from https://github.com/jhasegaw/phonecodes/blob/master/src/phonecodes.py
5
+
6
+ # Canonical version of this file lives in https://github.com/KoelLabs/ML
7
+
8
+ import sys
9
+
10
+ # CODES = set(("ipa", "timit", "arpabet", "xsampa", "buckeye", "epadb", "isle", "disc", "callhome"))
11
+ CODES = set(("ipa", "timit", "arpabet", "xsampa", "buckeye", "epadb", "isle"))
12
+
13
+
14
+ def convert(phoneme_string, from_code, to_code):
15
+ assert from_code in CODES, f"from_code must be one of {CODES}"
16
+ assert to_code in CODES, f"to_code must be one of {CODES}"
17
+
18
+ if from_code == "ipa":
19
+ return globals()[f"ipa2{to_code}"](phoneme_string)
20
+ elif to_code == "ipa":
21
+ return globals()[f"{from_code}2ipa"](phoneme_string)
22
+ else:
23
+ return globals()[f"ipa2{to_code}"](
24
+ globals()[f"{from_code}2ipa"](phoneme_string)
25
+ )
26
+
27
+
28
+ def string2symbols(string, symbols):
29
+ """Converts a string of symbols into a list of symbols, minimizing the number of untranslatable symbols, then minimizing the number of translated symbols."""
30
+ N = len(string)
31
+ symcost = 1 # path cost per translated symbol
32
+ oovcost = len(string) # path cost per untranslatable symbol
33
+ maxsym = max(len(k) for k in symbols) # max input symbol length
34
+ # (pathcost to s[(n-m):n], n-m, translation[s[(n-m):n]], True/False)
35
+ lattice = [(0, 0, "", True)]
36
+ for n in range(1, N + 1):
37
+ # Initialize on the assumption that s[n-1] is untranslatable
38
+ lattice.append((oovcost + lattice[n - 1][0], n - 1, string[(n - 1) : n], False))
39
+ # Search for translatable sequences s[(n-m):n], and keep the best
40
+ for m in range(1, min(n + 1, maxsym + 1)):
41
+ if (
42
+ string[(n - m) : n] in symbols
43
+ and symcost + lattice[n - m][0] < lattice[n][0]
44
+ ):
45
+ lattice[n] = (
46
+ symcost + lattice[n - m][0],
47
+ n - m,
48
+ string[(n - m) : n],
49
+ True,
50
+ )
51
+ # Back-trace
52
+ tl = []
53
+ translated = []
54
+ n = N
55
+ while n > 0:
56
+ tl.append(lattice[n][2])
57
+ translated.append(lattice[n][3])
58
+ n = lattice[n][1]
59
+ return (tl[::-1], translated[::-1])
60
+
61
+
62
+ #####################################################################
63
+ # Handle tones/stress markers
64
+ # fmt: off
65
+ TONE2IPA = {
66
+ 'arz' : { '0':'', '1':'ˈ', '2':'ˌ', '3': '', '4': '', '5': '', '6': '' },
67
+ 'eng' : { '0':'', '1':'ˈ', '2':'ˌ', '3': '', '4': '', '5': '', '6': '' },
68
+ 'yue' : { '0':'', '1':'˥', '2':'˧˥', '3':'˧', '4':'˨˩', '5':'˩˧', '6':'˨' },
69
+ 'lao' : { '0':'', '1':'˧', '2':'˥˧', '3':'˧˩', '4':'˥', '5':'˩˧', '6':'˩' },
70
+ 'cmn' : { '0':'', '1':'˥', '2':'˧˥', '3':'˨˩˦', '4':'˥˩', '5': '', '6': '' },
71
+ 'spa' : { '0':'', '1':'ˈ', '2':'ˌ', '3': '', '4': '', '5': '', '6': '' },
72
+ 'vie' : { '0':'', '1':'˧', '2':'˨˩h', '3':'˧˥', '4':'˨˩˨', '5':'˧ʔ˥', '6':'˧˨ʔ' },
73
+ }
74
+ IPA2TONE = {key: {v: k for k, v in val.items()} for key, val in TONE2IPA.items()}
75
+ # fmt: on
76
+
77
+
78
+ def update_dict_with_tones(code2ipa: dict, ipa2code: dict, lang):
79
+ code2ipa.update(TONE2IPA[lang])
80
+ ipa2code.update(IPA2TONE[lang])
81
+
82
+
83
+ #####################################################################
84
+ # X-SAMPA
85
+ # XSAMPA2IPA = {"_": "͡", "a": "a", "b": "b", "b_<": "ɓ", "c": "c", "d": "d", "d`": "ɖ", "d_<": "ɗ", "e": "e", "f": "f", "g": "ɡ", "g_<": "ɠ", "h": "h", "h\\": "ɦ", "i": "i", "j": "j", "j\\": "ʝ", "k": "k", "l": "l", "l`": "ɭ", "l\\": "ɺ", "m": "m", "n": "n", "n`": "ɳ", "o": "o", "p": "p", "p\\": "ɸ", "q": "q", "r": "r", "r`": "ɽ", "r\\": "ɹ", "r\\`": "ɻ", "s": "s", "s`": "ʂ", "s\\": "ɕ", "t": "t", "t`": "ʈ", "u": "u", "v": "v", "v\\": "ʋ", "P": "ʋ", "w": "w", "x": "x", "x\\": "ɧ", "y": "y", "z": "z", "z`": "ʐ", "z\\": "ʑ", "A": "ɑ", "B": "β", "B\\": "ʙ", "C": "ç", "D": "ð", "E": "ɛ", "F": "ɱ", "G": "ɣ", "G\\": "ɢ", "G\\_<": "ʛ", "H": "ɥ", "H\\": "ʜ", "I": "ɪ", "I\\": "ᵻ", "J": "ɲ", "J\\": "ɟ", "J\\_<": "ʄ", "K": "ɬ", "K\\": "ɮ", "L": "ʎ", "L\\": "ʟ", "M": "ɯ", "M\\": "ɰ", "N": "ŋ", "N\\": "ɴ", "O": "ɔ", "O\\": "ʘ", "Q": "ɒ", "R": "ʁ", "R\\": "ʀ", "S": "ʃ", "T": "θ", "U": "ʊ", "U\\": "ᵿ", "V": "ʌ", "W": "ʍ", "X": "χ", "X\\": "ħ", "Y": "ʏ", "Z": "ʒ", ".": ".", '"': "ˈ", "%": "ˌ", "'": "ʲ", "_j": "ʲ", ":": "ː", ":\\": "ˑ", "@": "ə", "@\\": "ɘ", "@`": "ɚ", "{": "æ", "}": "ʉ", "1": "ɨ", "2": "ø", "3": "ɜ", "3\\": "ɞ", "4": "ɾ", "5": "ɫ", "6": "ɐ", "7": "ɤ", "8": "ɵ", "9": "œ", "&": "ɶ", "?": "ʔ", "?\\": "ʕ", "/": "/", "<": "⟨", "<\\": "ʢ", ">": "⟩", ">\\": "ʡ", "^": "ꜛ", "!": "ꜜ", "!\\": "ǃ", "|": "|", "|\\": "ǀ", "||": "‖", "|\\|\\": "ǁ", "=\\": "ǂ", "-\\": "‿", '_"': '̈', "_+": " ", "_-": " ", "_/": " ", "_0": " ", "=": " ", "_=": " ", "_>": "ʼ", "_?\\": "ˤ", "_^": " ", "_}": " ", "`": "˞", "~": " ", "_~": " ", "_A": " ", "_a": " ̺", "_B": " ̏", "_B_L": " ᷅", "_c": " ", "_d": " ̪", "_e": " ̴", "<F>": "↘", "_F": " ", "_\\": " ", "_G": "ˠ", "_H": " ", "_H_T": " ᷄", "_h": "ʰ", "_k": " ̰", "_L": " ̀", "_l": "ˡ", "_M": " ̄", "_m": " ", "_N": " ̼", "_n": "ⁿ", "_O": " ", "_o": " ", "_q": " ", "<R>": "↗", "_R": " ", "_R_F": " ᷈", "_r": " ", "_T": " ", "_t": " ", "_v": " ", "_w": "ʷ", "_X": " ", "_x": " "} # fmt: skip
86
+ XSAMPA2IPA = {"_": "͡", "a": "a", "b": "b", "b_<": "ɓ", "c": "c", "d": "d", "d`": "ɖ", "d_<": "ɗ", "e": "e", "f": "f", "g": "ɡ", "g_<": "ɠ", "h": "h", "h\\": "ɦ", "i": "i", "j": "j", "j\\": "ʝ", "k": "k", "l": "l", "l`": "ɭ", "l\\": "ɺ", "m": "m", "n": "n", "n`": "ɳ", "o": "o", "p": "p", "p\\": "ɸ", "q": "q", "r": "r", "r`": "ɽ", "r\\": "ɹ", "r\\`": "ɻ", "s": "s", "s`": "ʂ", "s\\": "ɕ", "t": "t", "t`": "ʈ", "u": "u", "v": "v", "v\\": "ʋ", "P": "ʋ", "w": "w", "x": "x", "x\\": "ɧ", "y": "y", "z": "z", "z`": "ʐ", "z\\": "ʑ", "A": "ɑ", "B": "β", "B\\": "ʙ", "C": "ç", "D": "ð", "E": "ɛ", "F": "ɱ", "G": "ɣ", "G\\": "ɢ", "G\\_<": "ʛ", "H": "ɥ", "H\\": "ʜ", "I": "ɪ", "I\\": "ᵻ", "J": "ɲ", "J\\": "ɟ", "J\\<": "ʄ", "K": "ɬ", "K\\": "ɮ", "L": "ʎ", "L\\": "ʟ", "M": "ɯ", "M\\": "ɰ", "N": "ŋ", "N\\": "ɴ", "O": "ɔ", "O\\": "ʘ", "Q": "ɒ", "R": "ʁ", "R\\": "ʀ", "S": "ʃ", "T": "θ", "U": "ʊ", "U\\": "ᵿ", "V": "ʌ", "W": "ʍ", "X": "χ", "X\\": "ħ", "Y": "ʏ", "Z": "ʒ", ".": ".", '"': "ˈ", "%": "ˌ", "'": "ʲ", "_j": "ʲ", ":": "ː", ":\\": "ˑ", "@": "ə", "@\\": "ɘ", "@`": "ɚ", "{": "æ", "}": "ʉ", "1": "ɨ", "2": "ø", "3": "ɜ", "3\\": "ɞ", "4": "ɾ", "5": "ɫ", "6": "ɐ", "7": "ɤ", "8": "ɵ", "9": "œ", "&": "ɶ", "?": "ʔ", "?\\": "ʕ", "/": "/", "<": "⟨", "<\\": "ʢ", ">": "⟩", ">\\": "ʡ", "^": "ꜛ", "!": "ꜜ", "!\\": "ǃ", "|": "|", "|\\": "ǀ", "||": "‖", "|\\|\\": "ǁ", "=\\": "ǂ", "-\\": "‿", '_"': '̈', "_+": " ", "_-": " ", "_/": " ", "_0": " ", "=": " ", "_=": " ", "_>": "ʼ", "_?\\": "ˤ", "_^": " ", "_}": " ", "`": "˞", "~": " ", "_~": " ", "_A": " ", "_a": " ̺", "_B": " ̏", "_B_L": " ᷅", "_c": " ", "_d": " ̪", "_e": " ̴", "<f>": "↘", "_F": " ", "_\\": " ", "_G": "ˠ", "_H": " ", "_H_T": " ᷄", "_h": "ʰ", "_k": " ̰", "_L": " ̀", "_l": "ˡ", "_M": " ̄", "_m": " ", "_N": " ̼", "_n": "ⁿ", "_O": " ", "_o": " ", "_q": " ", "<r>": "↗", "_R": " ", "_R_F": " ᷈", "_r": " ", "_T": " ", "_t": " ", "_v": " ", "_w": "ʷ", "_X": " ", "_x": " "} # fmt: skip
87
+ # Not supported yet:
88
+ # _<
89
+ # -
90
+ # *
91
+ # rhotization for consonants
92
+ IPA2XSAMPA = {v: k for k, v in XSAMPA2IPA.items()}
93
+
94
+
95
+ def ipa2xsampa(ipa_string, lang="eng"):
96
+ ipa_symbols = string2symbols(ipa_string, IPA2XSAMPA.keys())[0]
97
+ xsampa_symbols = [IPA2XSAMPA[x] for x in ipa_symbols]
98
+ return " ".join(xsampa_symbols)
99
+
100
+
101
+ def xsampa2ipa(xsampa_string, lang="eng"):
102
+ if " " in xsampa_string:
103
+ xsampa_symbols = xsampa_string.split()
104
+ else:
105
+ xsampa_symbols = string2symbols(xsampa_string, XSAMPA2IPA.keys())[0]
106
+ return "".join([XSAMPA2IPA[x] for x in xsampa_symbols])
107
+
108
+
109
+ #####################################################################
110
+ # DISC, the system used by CELEX
111
+ def ipa2disc(ipa_string, lang="eng"):
112
+ raise NotImplementedError
113
+
114
+
115
+ def disc2ipa(disc_string, lang="eng"):
116
+ raise NotImplementedError
117
+
118
+
119
+ #####################################################################
120
+ # Kirshenbaum
121
+ def ipa2kirshenbaum(ipa_string, lang="eng"):
122
+ raise NotImplementedError
123
+
124
+
125
+ def kirshenbaum2ipa(kirshenbaum_string, lang="eng"):
126
+ raise NotImplementedError
127
+
128
+
129
+ #######################################################################
130
+ # Callhome phone codes
131
+ def ipa2callhome(ipa_string, lang="eng"):
132
+ raise NotImplementedError
133
+
134
+
135
+ def callhome2ipa(callhome_string, lang="eng"):
136
+ raise NotImplementedError
137
+
138
+
139
+ #########################################################################
140
+ # Buckeye
141
+ BUCKEYE2IPA = {'aa':'ɑ', 'ae':'æ', 'ay':'aɪ', 'aw':'aʊ', 'ao':'ɔ', 'oy':'ɔɪ', 'ow':'oʊ', 'eh':'ɛ', 'ey':'eɪ', 'er':'ɝ', 'ah':'ʌ', 'uw':'u', 'uh':'ʊ', 'ih':'ɪ', 'iy':'i', 'm':'m', 'n':'n', 'en':'n̩', 'ng':'ŋ', 'l':'l', 'el':'l̩', 't':'t', 'd':'d', 'ch':'tʃ', 'jh':'dʒ', 'th':'θ', 'dh':'ð', 'sh':'ʃ', 'zh':'ʒ', 's':'s', 'z':'z', 'k':'k', 'g':'ɡ', 'p':'p', 'b':'b', 'f':'f', 'v':'v', 'w':'w', 'hh':'h', 'y':'j', 'r':'ɹ', 'dx':'ɾ', 'nx':'ɾ̃', 'tq':'ʔ', 'er':'ɚ', 'em':'m̩', 'ihn': 'ĩ', 'ehn': 'ɛ̃', 'own': 'oʊ̃', 'ayn': 'aɪ̃', 'aen': 'æ̃', 'aan': 'ɑ̃', 'ahn': 'ə̃', 'eng': 'ŋ̍', 'iyn': 'ĩ', 'uhn': 'ʊ̃'} # fmt: skip
142
+ IPA2BUCKEYE = {v: k for k, v in BUCKEYE2IPA.items()}
143
+ # 'Vn':'◌̃'
144
+
145
+
146
+ def ipa2buckeye(ipa_string, lang="eng"):
147
+ update_dict_with_tones(BUCKEYE2IPA, IPA2BUCKEYE, lang)
148
+ ipa_symbols = string2symbols(ipa_string, IPA2BUCKEYE.keys())[0]
149
+ buckeye_symbols = [IPA2BUCKEYE[x] for x in ipa_symbols]
150
+ return " ".join(buckeye_symbols)
151
+
152
+
153
+ def buckeye2ipa(buckeye_string, lang="eng"):
154
+ update_dict_with_tones(BUCKEYE2IPA, IPA2BUCKEYE, lang)
155
+ if " " in buckeye_string:
156
+ buckeye_symbols = buckeye_string.split()
157
+ else:
158
+ buckeye_symbols = string2symbols(buckeye_string, BUCKEYE2IPA.keys())[0]
159
+ return "".join([BUCKEYE2IPA[x] for x in buckeye_symbols])
160
+
161
+
162
+ #########################################################################
163
+ # ARPABET
164
+ ARPABET2IPA = {'AA':'ɑ','AE':'æ','AH':'ʌ','AO':'ɔ','IX':'ɨ','AW':'aʊ','AX':'ə','AXR':'ɚ','AY':'aɪ','EH':'ɛ','ER':'ɝ','EY':'eɪ','IH':'ɪ','IY':'i','OW':'oʊ','OY':'ɔɪ','UH':'ʊ','UW':'u','UX':'ʉ','B':'b','CH':'tʃ','D':'d','DH':'ð','EL':'l̩','EM':'m̩','EN':'n̩','F':'f','G':'ɡ','HH':'h','H':'h','JH':'dʒ','K':'k','L':'l','M':'m','N':'n','NG':'ŋ','NX':'ɾ̃','P':'p','Q':'ʔ','R':'ɹ','S':'s','SH':'ʃ','T':'t','TH':'θ','V':'v','W':'w','WH':'ʍ','Y':'j','Z':'z','ZH':'ʒ','DX':'ɾ'} # fmt: skip
165
+ IPA2ARPABET = {v: k for k, v in ARPABET2IPA.items()}
166
+
167
+
168
+ def ipa2arpabet(ipa_string, lang="eng"):
169
+ update_dict_with_tones(ARPABET2IPA, IPA2ARPABET, lang)
170
+ ipa_symbols = string2symbols(ipa_string, IPA2ARPABET.keys())[0]
171
+ arpabet_symbols = [IPA2ARPABET[x] for x in ipa_symbols]
172
+ return " ".join(arpabet_symbols)
173
+
174
+
175
+ def arpabet2ipa(arpabet_string, lang="eng"):
176
+ update_dict_with_tones(ARPABET2IPA, IPA2ARPABET, lang)
177
+ if " " in arpabet_string:
178
+ arpabet_symbols = arpabet_string.split()
179
+ else:
180
+ arpabet_symbols = string2symbols(arpabet_string, ARPABET2IPA.keys())[0]
181
+ return "".join([ARPABET2IPA[x] for x in arpabet_symbols])
182
+
183
+
184
+ #########################################################################
185
+ # EpaDB
186
+ # We simplify 'A' to 'a' instead of 'ä'
187
+ EPADB2IPA = dict(ARPABET2IPA, **{"PH": "pʰ", "TH": "θʰ", "SH": "sʰ", "KH": "kʰ", "DH": "ð", 'BH': 'β', 'GH': 'ɣ', 'RR': 'r', 'DX': 'ɾ', 'X': 'x', 'A': 'a', 'E': 'e', 'O': 'o', 'U': ARPABET2IPA['UW'], 'I': ARPABET2IPA['IY'], 'LL': 'ʟ'}) # fmt: skip
188
+ IPA2EPADB = {v: k for k, v in EPADB2IPA.items()}
189
+
190
+
191
+ def ipa2epadb(ipa_string, lang="eng"):
192
+ update_dict_with_tones(EPADB2IPA, IPA2EPADB, lang)
193
+ ipa_symbols = string2symbols(ipa_string, IPA2EPADB.keys())[0]
194
+ epadb_symbols = [IPA2EPADB[x] for x in ipa_symbols]
195
+ return " ".join(epadb_symbols)
196
+
197
+
198
+ def epadb2ipa(epadb_string, lang="eng"):
199
+ update_dict_with_tones(EPADB2IPA, IPA2EPADB, lang)
200
+ if " " in epadb_string:
201
+ epadb_symbols = epadb_string.split()
202
+ else:
203
+ epadb_symbols = string2symbols(epadb_string, EPADB2IPA.keys())[0]
204
+ return "".join([EPADB2IPA[x] for x in epadb_symbols])
205
+
206
+
207
+ #########################################################################
208
+ # TIMIT
209
+ CLOSURE_INTERVALS = {
210
+ "BCL": ["B"],
211
+ "DCL": ["D", "JH"],
212
+ "GCL": ["G"],
213
+ "PCL": ["P"],
214
+ "TCL": ["T", "CH"],
215
+ "KCL": ["K"],
216
+ }
217
+ TIMIT2IPA = {'AA': 'ɑ', 'AE': 'æ', 'AH': 'ʌ', 'AO': 'ɔ', 'AW': 'aʊ', 'AX': 'ə', 'AXR': 'ɚ', 'AX-H': 'ə̥', 'AY': 'aɪ', 'EH': 'ɛ', 'ER': 'ɝ', 'EY': 'eɪ', 'IH': 'ɪ', 'IY': 'i', 'OW': 'oʊ', 'OY': 'ɔɪ', 'UH': 'ʊ', 'UW': 'u', 'B': 'b', 'CH': 'tʃ', 'D': 'd', 'DH': 'ð', 'EL': 'l̩', 'EM': 'm̩', 'EN': 'n̩', 'F': 'f', 'G': 'ɡ', 'HH': 'h', 'JH': 'dʒ', 'K': 'k', 'L': 'l', 'M': 'm', 'N': 'n', 'NG': 'ŋ', 'P': 'p', 'Q': 'ʔ', 'R': 'ɹ', 'S': 's', 'SH': 'ʃ', 'T': 't', 'TH': 'θ', 'V': 'v', 'W': 'w', 'WH': 'ʍ', 'Y': 'j', 'Z': 'z', 'ZH': 'ʒ', 'DX': 'ɾ', 'ENG': 'ŋ̍', 'EPI': '', 'HV': 'ɦ', 'H#': '', 'IX': 'ɨ', 'NX': 'ɾ̃', 'PAU': '', 'UX': 'ʉ'} # fmt: skip
218
+ IPA2TIMIT = {v: k for k, v in TIMIT2IPA.items()}
219
+ INVERSE_CLOSURE_INTERVALS = {v: k for k, val in CLOSURE_INTERVALS.items() for v in val}
220
+
221
+
222
+ def parse_timit(lines):
223
+ # parses the format of a TIMIT .PHN file, handling edge cases where the closure interval and stops are not always paired
224
+ timestamped_phonemes = []
225
+ closure_interval_start = None
226
+ for line in lines:
227
+ if line == "":
228
+ continue
229
+ start, end, phoneme = line.split()
230
+ phoneme = phoneme.upper()
231
+
232
+ if closure_interval_start:
233
+ cl_start, cl_end, cl_phoneme = closure_interval_start
234
+ if phoneme not in CLOSURE_INTERVALS[cl_phoneme]:
235
+ ipa_phoneme = TIMIT2IPA[CLOSURE_INTERVALS[cl_phoneme][0]]
236
+ timestamped_phonemes.append((ipa_phoneme, int(cl_start), int(cl_end)))
237
+ else:
238
+ assert phoneme not in CLOSURE_INTERVALS
239
+ start = cl_start
240
+
241
+ if phoneme in CLOSURE_INTERVALS:
242
+ closure_interval_start = (start, end, phoneme)
243
+ continue
244
+
245
+ ipa_phoneme = TIMIT2IPA[phoneme]
246
+ timestamped_phonemes.append((ipa_phoneme, int(start), int(end)))
247
+
248
+ closure_interval_start = None
249
+
250
+ if closure_interval_start:
251
+ cl_start, cl_end, cl_phoneme = closure_interval_start
252
+ ipa_phoneme = TIMIT2IPA[CLOSURE_INTERVALS[cl_phoneme][0]]
253
+ timestamped_phonemes.append((ipa_phoneme, int(cl_start), int(cl_end)))
254
+
255
+ return timestamped_phonemes
256
+
257
+
258
+ def ipa2timit(ipa_string, lang="eng"):
259
+ update_dict_with_tones(TIMIT2IPA, IPA2TIMIT, lang)
260
+ ipa_symbols = string2symbols(ipa_string, IPA2TIMIT.keys())[0]
261
+ timit_symbols = [IPA2TIMIT[x] for x in ipa_symbols]
262
+ # insert closure intervals before each stop
263
+ timit_symbols_with_closures = []
264
+ for timit_symbol in timit_symbols:
265
+ if timit_symbol in INVERSE_CLOSURE_INTERVALS:
266
+ timit_symbols_with_closures.append(INVERSE_CLOSURE_INTERVALS[timit_symbol])
267
+ timit_symbols_with_closures.append(timit_symbol)
268
+ return " ".join(timit_symbols_with_closures)
269
+
270
+
271
+ def timit2ipa(timit_string, lang="eng"):
272
+ update_dict_with_tones(TIMIT2IPA, IPA2TIMIT, lang)
273
+ if " " in timit_string:
274
+ timit_symbols = timit_string.split()
275
+ else:
276
+ timit_symbols = string2symbols(
277
+ timit_string, TIMIT2IPA.keys() | CLOSURE_INTERVALS.keys()
278
+ )[0]
279
+ timestamped_phonemes = parse_timit((f"0 0 {x}" for x in timit_symbols))
280
+ return "".join([x[0] for x in timestamped_phonemes])
281
+
282
+
283
+ #########################################################################
284
+ # Isle (adaptation of deprecated Entropic GrapHvite UK Phone Set), see http://www.lrec-conf.org/proceedings/lrec2000/pdf/313.pdf
285
+ #
286
+ # Closely matches ARPABet but:
287
+ # - Some simplifications:
288
+ # - Only use HH to denote h, not also H
289
+ # - Drop IX (ɨ), UX (ʉ), EL (l̩), EM (m̩), EN (n̩), NX (ɾ̃), Q (ʔ), WH (ʍ), DX (ɾ)
290
+ # - Some adaptations to UK dialect:
291
+ # - Distinguish ɑ vs ɒ by adding OH for ɒ and restricting AA to ɑ
292
+ # - Map OW to əʊ instead of oʊ
293
+ # - ER maps to ɜ instead of ɝ because British English is non-rhotic (the r sound is dropped at the end of syllables)
294
+ # - Some adaptations to Italian/German dialects:
295
+ # - ER (ɜ) explicitly followed by R (ɹ) maps to ɝ because most Italian/German dialects are rhotic
296
+ # - We also keep AXR from ARPABet even though it is not in the UK Phone set, so we now map AX (ə) explicitly followed by R (ɹ) to AXR (ɚ) for the same reason
297
+ #
298
+ # symbol : example - UK G2P / US G2P | UK / US / ARPABet | comments
299
+ # Aa : balm - bɑːm / bɑm | ɑ / ɑ / ɑ |
300
+ # Aa : barn - bɑːn / bɑrn | ɑ / ɑ / ɑ |
301
+ # Ae : bat - bæt / bæt | æ / æ / æ |
302
+ # Ah : bat - bæt / bæt | æ / æ / ʌ |
303
+ # Ao : bought - bɔːt / bɑt | ɔ / ɑ / ɔ |
304
+ # Aw : bout - baʊt / baʊt | aʊ / aʊ / aʊ |
305
+ # Ax : about - əˈbaʊt / əˈbaʊt | ə / ə / ə |
306
+ # Ay : bite - baɪt / baɪt | aɪ / aɪ / aɪ |
307
+ # Eh : bet - bɛt / bɛt | ɛ / ɛ / ɛ |
308
+ # Er : bird - bɜːd / bɜrd | ɜ / ɝ / ɝ | different, ER represents non-r-colored ɜ in UK English because it is non-rhotic unlike American English which is what ARPABet is based on
309
+ # Ey : bait - beɪt / beɪt | eɪ / eɪ / eɪ |
310
+ # Ih : bit - bɪt / bɪt | ɪ / ɪ / ɪ |
311
+ # Iy : beet - biːt / bit | i / i / i |
312
+ # Ow : boat - bəʊt / boʊt | əʊ / oʊ / oʊ | different, map OW to əʊ
313
+ # Oy : boy - bɔɪ / bɔɪ | ɔɪ / ɔɪ / ɔɪ |
314
+ # Oh : box - bɒks / bɑks | ɒ / ɑ / - | added OH to disambiguate ɒ
315
+ # Uh : book - bʊk / bʊk | ʊ / ʊ / ʊ |
316
+ # Uw : boot - buːt / but | u / u / u |
317
+ # B : bet - bɛt / bɛt | b / b / b |
318
+ # Ch : cheap - ʧiːp / ʧip | ʧ / ʧ / tʃ |
319
+ # D : debt - dɛt / dɛt | d / d / d |
320
+ # Dh : that - ðæt / ðæt | ð / ð / ð |
321
+ # F : fan - fæn / fæn | f / f / f |
322
+ # G : get - ɡɛt / ɡɛt | ɡ / ɡ / ɡ |
323
+ # Hh : hat - hæt / hæt | h / h / h | match, but drop alternative H
324
+ # Jh : jeep - ʤiːp / ʤip | ʤ / ʤ / dʒ |
325
+ # K : cat - kæt / kæt | k / k / k |
326
+ # L : led - lɛd / lɛd | l / l / l |
327
+ # M : met - mɛt / mɛt | m / m / m |
328
+ # N : net - nɛt / nɛt | n / n / n |
329
+ # Ng : thing - θɪŋ / θɪŋ | ŋ / ŋ / ŋ |
330
+ # P : pet - pɛt / pɛt | p / p / p |
331
+ # R : red - rɛd / ˈɹɛd | r / ɹ / ɹ | different, but due to broad vs narrow and other annotation conventions; the sounds are actually different too but not sure how to model this
332
+ # S : sue - sjuː / su | s / s / s |
333
+ # Sh : shoe - ʃuː / ʃu | ʃ / ʃ / ʃ |
334
+ # T : tat - tæt / tæt | t / t / t |
335
+ # Th : thin - θɪn / θɪn | θ / θ / θ |
336
+ # V : van - væn / væn | v / v / v |
337
+ # W : wed - wɛd / wɛd | w / w / w |
338
+ # Y : yet - jɛt / jɛt | j / j / j |
339
+ # Z : zoo - zuː / zu | z / z / z |
340
+ # Zh : measure - ˈmɛʒə / ˈmɛʒər | ʒ / ʒ / ʒ |
341
+
342
+ ISLE2IPA = {'AA':'ɑ','AE':'æ','AH':'ʌ','AO':'ɔ','AW':'aʊ','AX':'ə','AXR':'ɚ','AY':'aɪ','EH':'ɛ','ER':'ɜ','ERR':'ɝ','EY':'eɪ','IH':'ɪ','IY':'i','OW':'əʊ','OY':'ɔɪ','OH':'ɒ','UH':'ʊ','UW':'u','B':'b','CH':'tʃ','D':'d','DH':'ð','F':'f','G':'ɡ','HH':'h','JH':'dʒ','K':'k','L':'l','M':'m','N':'n','NG':'ŋ','P':'p','R':'ɹ','S':'s','SH':'ʃ','T':'t','TH':'θ','V':'v','W':'w','Y':'j','Z':'z','ZH':'ʒ'} # fmt: skip
343
+ IPA2ISLE = {v: k for k, v in ISLE2IPA.items()}
344
+
345
+
346
+ def ipa2isle(ipa_string, lang="eng"):
347
+ update_dict_with_tones(ISLE2IPA, IPA2ISLE, lang)
348
+ ipa_symbols = string2symbols(ipa_string, IPA2ISLE.keys())[0]
349
+ isle_symbols = [IPA2ISLE[x] for x in ipa_symbols]
350
+ return " ".join(isle_symbols)
351
+
352
+
353
+ def isle2ipa(isle_string, lang="eng"):
354
+ update_dict_with_tones(ISLE2IPA, IPA2ISLE, lang)
355
+ if " " in isle_string:
356
+ isle_symbols = isle_string.split()
357
+ else:
358
+ isle_symbols = string2symbols(isle_string, ISLE2IPA.keys())[0]
359
+ return "".join([ISLE2IPA[x] for x in isle_symbols])
360
+
361
+
362
+ #########################################################################
363
+ # CLI
364
+ def usage():
365
+ print("Usage: python ./scripts/core/codes.py <src> <tgt> <phoneme_string>")
366
+ print("Supported codes:", CODES)
367
+
368
+
369
+ ALL_ANNOTATED_IPA_SYMBOLS = set()
370
+ for code in CODES:
371
+ if code == "ipa":
372
+ continue
373
+ ALL_ANNOTATED_IPA_SYMBOLS |= set(globals()[f"{code.upper()}2IPA"].values())
374
+ ALL_ANNOTATED_IPA_SYMBOLS.discard("")
375
+ ALL_ANNOTATED_IPA_SYMBOLS.discard(" ")
376
+ ALL_ANNOTATED_IPA_SYMBOLS.discard("ʰ")
377
+ ALL_ANNOTATED_IPA_SYMBOLS |= set(f"{s}ʰ" for s in ALL_ANNOTATED_IPA_SYMBOLS)
378
+
379
+
380
+ def main(args):
381
+ if len(args) != 3:
382
+ usage()
383
+ return
384
+
385
+ src, tgt, phoneme_string = args
386
+ print(phoneme_string, "=>", convert(phoneme_string, src, tgt))
387
+
388
+
389
+ if __name__ == "__main__":
390
+ try:
391
+ main(sys.argv[1:])
392
+ except Exception as e:
393
+ print(f"Line {e.__traceback__.tb_lineno}:", e) # type: ignore
394
+ usage()
app/hf.py CHANGED
@@ -26,6 +26,7 @@ LEADERBOARD_FEATURES = Features(
26
  "fer_PSST": Value("float32"),
27
  "fer_SpeechOcean": Value("float32"),
28
  "fer_ISLE": Value("float32"),
 
29
  }
30
  )
31
  LEADERBOARD_DEFAULTS = {
@@ -35,17 +36,26 @@ LEADERBOARD_DEFAULTS = {
35
  "fer_PSST": None,
36
  "fer_SpeechOcean": None,
37
  "fer_ISLE": None,
 
38
  }
39
 
40
 
 
 
 
 
 
 
 
 
41
  def get_repo_info(
42
  repo_id, type: Literal["model", "dataset", "space"] = "model"
43
- ) -> tuple[str, datetime]:
44
  try:
45
- repo_info = api.repo_info(repo_id=repo_id, repo_type=type)
46
- return repo_info.sha, repo_info.last_modified # type: ignore
47
  except RepositoryNotFoundError:
48
- return "", datetime(year=1970, month=1, day=1)
49
 
50
 
51
  def get_or_create_leaderboard() -> Dataset:
@@ -81,6 +91,7 @@ def add_leaderboard_entry(
81
  average_per: float,
82
  average_fer: float,
83
  url: str,
 
84
  per_dataset_fers: dict = {},
85
  ):
86
  existing_dataset = get_or_create_leaderboard()
@@ -99,6 +110,7 @@ def add_leaderboard_entry(
99
  fer_PSST=[per_dataset_fers.get("PSST")],
100
  fer_SpeechOcean=[per_dataset_fers.get("SpeechOcean")],
101
  fer_ISLE=[per_dataset_fers.get("ISLE")],
 
102
  ),
103
  features=LEADERBOARD_FEATURES,
104
  )
 
26
  "fer_PSST": Value("float32"),
27
  "fer_SpeechOcean": Value("float32"),
28
  "fer_ISLE": Value("float32"),
29
+ "model_bytes": Value("int64"),
30
  }
31
  )
32
  LEADERBOARD_DEFAULTS = {
 
36
  "fer_PSST": None,
37
  "fer_SpeechOcean": None,
38
  "fer_ISLE": None,
39
+ "model_bytes": None,
40
  }
41
 
42
 
43
+ def get_size(repo_info):
44
+ total_size_bytes = 0
45
+ for sibling in repo_info.siblings:
46
+ size_in_bytes = sibling.size or 0
47
+ total_size_bytes += size_in_bytes
48
+ return total_size_bytes
49
+
50
+
51
  def get_repo_info(
52
  repo_id, type: Literal["model", "dataset", "space"] = "model"
53
+ ) -> tuple[str, datetime, int | None]:
54
  try:
55
+ repo_info = api.repo_info(repo_id=repo_id, repo_type=type, files_metadata=True)
56
+ return repo_info.sha, repo_info.last_modified, get_size(repo_info) or None # type: ignore
57
  except RepositoryNotFoundError:
58
+ return "", datetime(year=1970, month=1, day=1), None
59
 
60
 
61
  def get_or_create_leaderboard() -> Dataset:
 
91
  average_per: float,
92
  average_fer: float,
93
  url: str,
94
+ model_bytes: int | None,
95
  per_dataset_fers: dict = {},
96
  ):
97
  existing_dataset = get_or_create_leaderboard()
 
110
  fer_PSST=[per_dataset_fers.get("PSST")],
111
  fer_SpeechOcean=[per_dataset_fers.get("SpeechOcean")],
112
  fer_ISLE=[per_dataset_fers.get("ISLE")],
113
+ model_bytes=[model_bytes],
114
  ),
115
  features=LEADERBOARD_FEATURES,
116
  )
app/inference.py CHANGED
@@ -2,6 +2,9 @@
2
 
3
  import torch
4
  from transformers import AutoProcessor, AutoModelForCTC
 
 
 
5
 
6
  DEVICE = (
7
  "cuda"
@@ -27,13 +30,37 @@ def clear_cache():
27
  torch.mps.empty_cache()
28
 
29
 
30
- def load_model(model_id, device=DEVICE):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  processor = AutoProcessor.from_pretrained(model_id)
32
  model = AutoModelForCTC.from_pretrained(model_id).to(device)
33
  return model, processor
34
 
35
 
36
- def transcribe(audio, model, processor) -> str:
 
37
  input_values = (
38
  processor(
39
  [audio],
@@ -49,3 +76,24 @@ def transcribe(audio, model, processor) -> str:
49
 
50
  predicted_ids = torch.argmax(logits, dim=-1)
51
  return processor.decode(predicted_ids[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import torch
4
  from transformers import AutoProcessor, AutoModelForCTC
5
+ from espnet2.bin.s2t_inference import Speech2Text
6
+
7
+ MODEL_TYPES = ["Transformers CTC", "POWSM"]
8
 
9
  DEVICE = (
10
  "cuda"
 
30
  torch.mps.empty_cache()
31
 
32
 
33
+ # ================================== POWSM ==================================
34
+ def load_powsm(model_id, language="<eng>", device=DEVICE):
35
+ s2t = Speech2Text.from_pretrained(
36
+ model_id,
37
+ device=device.replace("mps", "cpu"),
38
+ lang_sym=language,
39
+ task_sym="<pr>",
40
+ )
41
+ if device == "mps":
42
+ s2t.s2t_model.to(device=device, dtype=torch.float32)
43
+ s2t.beam_search.to(device=device, dtype=torch.float32)
44
+ s2t.dtype = "float32"
45
+ s2t.device = device
46
+ return s2t
47
+
48
+
49
+ def transcribe_powsm(audio, model):
50
+ pred = model(audio, text_prev="<na>")[0][0]
51
+ return pred.split("<notimestamps>")[1].strip().replace("/", "")
52
+
53
+
54
+ # ===========================================================================
55
+ # ============================= Transformers CTC ============================
56
+ def load_transformers_ctc(model_id, device=DEVICE):
57
  processor = AutoProcessor.from_pretrained(model_id)
58
  model = AutoModelForCTC.from_pretrained(model_id).to(device)
59
  return model, processor
60
 
61
 
62
+ def transcribe_transformers_ctc(audio, model) -> str:
63
+ model, processor = model
64
  input_values = (
65
  processor(
66
  [audio],
 
76
 
77
  predicted_ids = torch.argmax(logits, dim=-1)
78
  return processor.decode(predicted_ids[0])
79
+
80
+
81
+ # ===========================================================================
82
+
83
+
84
+ def load_model(model_id, type, device=DEVICE):
85
+ if type == "POWSM":
86
+ return load_powsm(model_id, device=device)
87
+ elif type == "Transformers CTC":
88
+ return load_transformers_ctc(model_id, device=device)
89
+ else:
90
+ raise ValueError("Unsupported model type: " + str(type))
91
+
92
+
93
+ def transcribe(audio, type, model) -> str:
94
+ if type == "POWSM":
95
+ return transcribe_powsm(audio, model)
96
+ elif type == "Transformers CTC":
97
+ return transcribe_transformers_ctc(audio, model)
98
+ else:
99
+ raise ValueError("Unsupported model type: " + str(type))
app/tasks.py CHANGED
@@ -10,6 +10,7 @@ from metrics import per, fer
10
  from datasets import load_from_disk
11
  from hf import get_repo_info, add_leaderboard_entry
12
  from inference import clear_cache, load_model, transcribe
 
13
 
14
  leaderboard_lock = multiprocessing.Lock()
15
 
@@ -21,6 +22,9 @@ class Task(TypedDict):
21
  repo_hash: str
22
  repo_last_modified: datetime
23
  submission_timestamp: datetime
 
 
 
24
  url: str
25
  error: str | None
26
 
@@ -42,10 +46,12 @@ def get_status(query: str) -> dict:
42
  return {"error": f"No results found for '{query}'"}
43
 
44
 
45
- def start_eval_task(display_name: str, repo_id: str, url: str) -> str:
 
 
46
  """Start evaluation task in background. Returns task ID that can be used to check status."""
47
 
48
- repo_hash, last_modified = get_repo_info(repo_id)
49
  # TODO: check if hash is different from the most recent submission if any for repo_id, otherwise don't recompute
50
  task = Task(
51
  status="submitted",
@@ -54,6 +60,9 @@ def start_eval_task(display_name: str, repo_id: str, url: str) -> str:
54
  repo_hash=repo_hash,
55
  repo_last_modified=last_modified,
56
  submission_timestamp=datetime.now(),
 
 
 
57
  url=url,
58
  error=None,
59
  )
@@ -83,9 +92,11 @@ def _eval_task(task: Task, leaderboard_lock):
83
  per_dataset_fers = {}
84
 
85
  clear_cache()
86
- model, processor = load_model(task["repo_id"])
87
  for row in test_ds:
88
- transcript = transcribe(row["audio"]["array"], model, processor) # type: ignore
 
 
89
  row_per = per(transcript, row["ipa"]) # type: ignore
90
  row_fer = fer(transcript, row["ipa"]) # type: ignore
91
  average_per += row_per
@@ -107,6 +118,7 @@ def _eval_task(task: Task, leaderboard_lock):
107
  average_per=average_per,
108
  average_fer=average_fer,
109
  url=task["url"],
 
110
  per_dataset_fers=per_dataset_fers,
111
  )
112
 
 
10
  from datasets import load_from_disk
11
  from hf import get_repo_info, add_leaderboard_entry
12
  from inference import clear_cache, load_model, transcribe
13
+ from codes import convert
14
 
15
  leaderboard_lock = multiprocessing.Lock()
16
 
 
22
  repo_hash: str
23
  repo_last_modified: datetime
24
  submission_timestamp: datetime
25
+ model_type: str
26
+ phone_code: str
27
+ model_bytes: int | None
28
  url: str
29
  error: str | None
30
 
 
46
  return {"error": f"No results found for '{query}'"}
47
 
48
 
49
+ def start_eval_task(
50
+ display_name: str, repo_id: str, url: str, model_type: str, phone_code: str
51
+ ) -> str:
52
  """Start evaluation task in background. Returns task ID that can be used to check status."""
53
 
54
+ repo_hash, last_modified, size_bytes = get_repo_info(repo_id)
55
  # TODO: check if hash is different from the most recent submission if any for repo_id, otherwise don't recompute
56
  task = Task(
57
  status="submitted",
 
60
  repo_hash=repo_hash,
61
  repo_last_modified=last_modified,
62
  submission_timestamp=datetime.now(),
63
+ model_type=model_type,
64
+ phone_code=phone_code,
65
+ model_bytes=size_bytes,
66
  url=url,
67
  error=None,
68
  )
 
92
  per_dataset_fers = {}
93
 
94
  clear_cache()
95
+ model = load_model(task["repo_id"], task["model_type"])
96
  for row in test_ds:
97
+ transcript = transcribe(row["audio"]["array"], task["model_type"], model) # type: ignore
98
+ if task["phone_code"] != "ipa":
99
+ transcript = convert(transcript, task["phone_code"], "ipa")
100
  row_per = per(transcript, row["ipa"]) # type: ignore
101
  row_fer = fer(transcript, row["ipa"]) # type: ignore
102
  average_per += row_per
 
118
  average_per=average_per,
119
  average_fer=average_fer,
120
  url=task["url"],
121
+ model_bytes=task["model_bytes"],
122
  per_dataset_fers=per_dataset_fers,
123
  )
124
 
requirements.txt CHANGED
@@ -3,14 +3,16 @@ huggingface_hub==0.34.4
3
  datasets==4.0.0
4
 
5
  # Data processing
6
- pandas==2.0.3
7
- numpy==1.25.2
8
  panphon==0.21.2
9
  torch==2.8.0
10
  torchaudio==2.8.0
11
  torchcodec==0.6.0
12
  transformers==4.56.0
13
  phonemizer==3.3.0
 
 
14
 
15
  # UI
16
  gradio==5.12.0
 
3
  datasets==4.0.0
4
 
5
  # Data processing
6
+ pandas==2.3.3
7
+ numpy==2.0.2
8
  panphon==0.21.2
9
  torch==2.8.0
10
  torchaudio==2.8.0
11
  torchcodec==0.6.0
12
  transformers==4.56.0
13
  phonemizer==3.3.0
14
+ espnet==202509
15
+ espnet-model-zoo==0.1.7
16
 
17
  # UI
18
  gradio==5.12.0
requirements_lock.txt CHANGED
@@ -3,65 +3,104 @@ aiohappyeyeballs==2.6.1
3
  aiohttp==3.12.15
4
  aiosignal==1.4.0
5
  annotated-types==0.7.0
 
6
  anyio==4.10.0
 
7
  async-timeout==5.0.1
8
  attrs==25.3.0
 
9
  babel==2.17.0
10
  certifi==2025.8.3
 
11
  charset-normalizer==3.4.3
 
12
  click==8.2.1
13
  colorama==0.4.6
 
14
  csvw==3.5.1
15
  datasets==4.0.0
 
16
  dill==0.3.8
 
17
  dlinfo==2.0.0
18
  editdistance==0.8.1
 
 
 
 
19
  exceptiongroup==1.3.0
 
20
  fastapi==0.116.1
21
  ffmpy==0.6.1
22
  filelock==3.19.1
23
  frozenlist==1.7.0
24
  fsspec==2025.3.0
 
25
  gradio==5.12.0
26
  gradio_client==1.5.4
27
  h11==0.16.0
 
28
  hf-xet==1.1.9
29
  httpcore==1.0.9
30
  httpx==0.28.1
31
  huggingface-hub==0.34.4
 
 
32
  idna==3.10
 
 
33
  isodate==0.7.2
 
 
34
  Jinja2==3.1.6
35
  joblib==1.5.2
36
  jsonschema==4.25.1
37
  jsonschema-specifications==2025.4.1
 
38
  language-tags==1.2.0
 
 
 
 
 
39
  markdown-it-py==4.0.0
40
  MarkupSafe==2.1.5
41
  mdurl==0.1.2
 
42
  mpmath==1.3.0
 
43
  multidict==6.6.4
44
  multiprocess==0.70.16
45
  munkres==1.1.4
46
  networkx==3.4.2
47
- numpy==1.25.2
 
 
 
 
48
  orjson==3.11.3
49
  packaging==25.0
50
- pandas==2.0.3
51
  panphon==0.21.2
52
  phonemizer==3.3.0
53
  pillow==11.3.0
 
 
54
  propcache==0.3.2
55
  protobuf==6.32.0
56
  pyarrow==21.0.0
 
57
  pydantic==2.11.7
58
  pydantic_core==2.33.2
59
  pydub==0.25.1
60
  Pygments==2.19.2
61
  pyparsing==3.2.3
 
62
  python-dateutil==2.9.0.post0
63
  python-multipart==0.0.20
 
64
  pytz==2025.2
 
65
  PyYAML==6.0.2
66
  rdflib==7.1.4
67
  referencing==0.36.2
@@ -73,28 +112,39 @@ rpds-py==0.27.1
73
  ruff==0.12.11
74
  safehttpx==0.1.6
75
  safetensors==0.6.2
 
 
76
  segments==2.3.0
77
  semantic-version==2.10.0
 
78
  shellingham==1.5.4
79
  six==1.17.0
80
  sniffio==1.3.1
 
 
81
  starlette==0.47.3
82
  sympy==1.14.0
 
83
  tokenizers==0.22.0
84
  tomlkit==0.13.3
85
  torch==2.8.0
 
86
  torchaudio==2.8.0
87
  torchcodec==0.6.0
 
88
  tqdm==4.67.1
89
  transformers==4.56.0
 
90
  typer==0.17.3
91
  typing-inspection==0.4.1
92
  typing_extensions==4.15.0
93
  tzdata==2025.2
94
  unicodecsv==0.14.1
 
95
  uritemplate==4.2.0
96
  urllib3==2.5.0
97
  uvicorn==0.35.0
98
  websockets==14.2
99
  xxhash==3.5.0
100
  yarl==1.20.1
 
 
3
  aiohttp==3.12.15
4
  aiosignal==1.4.0
5
  annotated-types==0.7.0
6
+ antlr4-python3-runtime==4.9.3
7
  anyio==4.10.0
8
+ asteroid-filterbanks==0.4.0
9
  async-timeout==5.0.1
10
  attrs==25.3.0
11
+ audioread==3.1.0
12
  babel==2.17.0
13
  certifi==2025.8.3
14
+ cffi==2.0.0
15
  charset-normalizer==3.4.3
16
+ ci-sdr==0.0.2
17
  click==8.2.1
18
  colorama==0.4.6
19
+ ConfigArgParse==1.7.1
20
  csvw==3.5.1
21
  datasets==4.0.0
22
+ decorator==5.2.1
23
  dill==0.3.8
24
+ Distance==0.1.3
25
  dlinfo==2.0.0
26
  editdistance==0.8.1
27
+ einops==0.8.1
28
+ espnet==202509
29
+ espnet-model-zoo==0.1.7
30
+ espnet-tts-frontend==0.0.3
31
  exceptiongroup==1.3.0
32
+ fast-bss-eval==0.1.3
33
  fastapi==0.116.1
34
  ffmpy==0.6.1
35
  filelock==3.19.1
36
  frozenlist==1.7.0
37
  fsspec==2025.3.0
38
+ g2p-en==2.1.0
39
  gradio==5.12.0
40
  gradio_client==1.5.4
41
  h11==0.16.0
42
+ h5py==3.15.1
43
  hf-xet==1.1.9
44
  httpcore==1.0.9
45
  httpx==0.28.1
46
  huggingface-hub==0.34.4
47
+ humanfriendly==10.0
48
+ hydra-core==1.3.2
49
  idna==3.10
50
+ importlib-metadata==4.13.0
51
+ inflect==7.5.0
52
  isodate==0.7.2
53
+ jaconv==0.4.0
54
+ jamo==0.4.1
55
  Jinja2==3.1.6
56
  joblib==1.5.2
57
  jsonschema==4.25.1
58
  jsonschema-specifications==2025.4.1
59
+ kaldiio==2.18.1
60
  language-tags==1.2.0
61
+ lazy_loader==0.4
62
+ librosa==0.11.0
63
+ lightning==2.5.5
64
+ lightning-utilities==0.15.2
65
+ llvmlite==0.45.1
66
  markdown-it-py==4.0.0
67
  MarkupSafe==2.1.5
68
  mdurl==0.1.2
69
+ more-itertools==10.8.0
70
  mpmath==1.3.0
71
+ msgpack==1.1.2
72
  multidict==6.6.4
73
  multiprocess==0.70.16
74
  munkres==1.1.4
75
  networkx==3.4.2
76
+ nltk==3.9.2
77
+ numba==0.62.1
78
+ numpy==2.0.2
79
+ omegaconf==2.3.0
80
+ opt_einsum==3.4.0
81
  orjson==3.11.3
82
  packaging==25.0
83
+ pandas==2.3.3
84
  panphon==0.21.2
85
  phonemizer==3.3.0
86
  pillow==11.3.0
87
+ platformdirs==4.5.0
88
+ pooch==1.8.2
89
  propcache==0.3.2
90
  protobuf==6.32.0
91
  pyarrow==21.0.0
92
+ pycparser==2.23
93
  pydantic==2.11.7
94
  pydantic_core==2.33.2
95
  pydub==0.25.1
96
  Pygments==2.19.2
97
  pyparsing==3.2.3
98
+ pypinyin==0.44.0
99
  python-dateutil==2.9.0.post0
100
  python-multipart==0.0.20
101
+ pytorch-lightning==2.5.5
102
  pytz==2025.2
103
+ pyworld==0.3.5
104
  PyYAML==6.0.2
105
  rdflib==7.1.4
106
  referencing==0.36.2
 
112
  ruff==0.12.11
113
  safehttpx==0.1.6
114
  safetensors==0.6.2
115
+ scikit-learn==1.7.2
116
+ scipy==1.15.3
117
  segments==2.3.0
118
  semantic-version==2.10.0
119
+ sentencepiece==0.2.0
120
  shellingham==1.5.4
121
  six==1.17.0
122
  sniffio==1.3.1
123
+ soundfile==0.13.1
124
+ soxr==1.0.0
125
  starlette==0.47.3
126
  sympy==1.14.0
127
+ threadpoolctl==3.6.0
128
  tokenizers==0.22.0
129
  tomlkit==0.13.3
130
  torch==2.8.0
131
+ torch-complex==0.4.4
132
  torchaudio==2.8.0
133
  torchcodec==0.6.0
134
+ torchmetrics==1.8.2
135
  tqdm==4.67.1
136
  transformers==4.56.0
137
+ typeguard==4.4.4
138
  typer==0.17.3
139
  typing-inspection==0.4.1
140
  typing_extensions==4.15.0
141
  tzdata==2025.2
142
  unicodecsv==0.14.1
143
+ Unidecode==1.4.0
144
  uritemplate==4.2.0
145
  urllib3==2.5.0
146
  uvicorn==0.35.0
147
  websockets==14.2
148
  xxhash==3.5.0
149
  yarl==1.20.1
150
+ zipp==3.23.0