# dolphin-asr / app.py
import os
import gradio as gr
import spaces
import urllib.request
import shutil
import dolphin
from dolphin.languages import LANGUAGE_CODES, LANGUAGE_REGION_CODES
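# Directory where downloaded model checkpoints are cached.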
MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
os.makedirs(MODEL_DIR, exist_ok=True)
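# Build (label, code) choices for the language dropdown, sorted alphabetically by label.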
language_options = [(f"{code}: {name[0]}", code)
for code, name in LANGUAGE_CODES.items()]
language_options.sort(key=lambda x: x[0])
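# Model sizes shown in the UI, mapped to Dolphin model identifiers.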
MODELS = {
"base (140M)": "base",
"small (372M)": "small",
}
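# Checkpoint download URLs on the Hugging Face Hub.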
MODEL_URLS = {
"base": "https://huggingface.co/DataoceanAI/dolphin-base/resolve/main/base.pt",
"small": "https://huggingface.co/DataoceanAI/dolphin-small/resolve/main/small.pt",
}
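# Shared assets (tokenizer model, config, feature statistics) expected under dolphin/assets.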
ASSET_URLS = {
"bpe.model": "https://huggingface.co/DataoceanAI/dolphin-base/resolve/main/bpe.model",
"config.yaml": "https://huggingface.co/DataoceanAI/dolphin-base/resolve/main/config.yaml",
"feats_stats.npz": "https://huggingface.co/DataoceanAI/dolphin-base/resolve/main/feats_stats.npz",
}
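# Map each base language code to its available region variants ("lang-REGION" entries).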
language_to_regions = {}
for lang_region, names in LANGUAGE_REGION_CODES.items():
if "-" in lang_region:
lang, region = lang_region.split("-", 1)
if lang not in language_to_regions:
language_to_regions[lang] = []
language_to_regions[lang].append((f"{region}: {names[0]}", region))
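# Download a file to dest_path, skipping the download if it already exists.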
def download_file(url, dest_path):
if not os.path.exists(dest_path):
print(f"Downloading {url} to {dest_path}")
with urllib.request.urlopen(url) as response, open(dest_path, 'wb') as out_file:
shutil.copyfileobj(response, out_file)
print(f"Downloaded {dest_path}")
else:
print(f"File already exists: {dest_path}")
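# Ensure the shared Dolphin assets are present before loading a model.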
def ensure_assets_downloaded():
assets_dir = os.path.join(os.path.dirname(
os.path.abspath(__file__)), "dolphin", "assets")
os.makedirs(assets_dir, exist_ok=True)
for filename, url in ASSET_URLS.items():
download_file(url, os.path.join(assets_dir, filename))
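# Download the selected checkpoint into MODEL_DIR if needed and return its local path.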
def ensure_model_downloaded(model_key):
if model_key not in MODEL_URLS:
raise ValueError(f"Unknown model: {model_key}")
model_path = os.path.join(MODEL_DIR, f"{model_key}.pt")
if not os.path.exists(model_path):
download_file(MODEL_URLS[model_key], model_path)
return model_path
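# Update the region dropdown (choices, default value, visibility) for the selected language.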
def update_regions(language):
    # Return a single dropdown update carrying new choices, default selection, and visibility.
    if language and language in language_to_regions:
        regions = sorted(language_to_regions[language], key=lambda x: x[0])
        return gr.update(choices=regions, value=regions[0][1], visible=True)
    return gr.update(choices=[], value=None, visible=False)
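# Request a GPU for this call (Hugging Face Spaces ZeroGPU), then load the model and transcribe.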
@spaces.GPU
def transcribe_audio(audio_file, model_name, language, region, predict_timestamps, padding_speech):
try:
ensure_assets_downloaded()
model_key = MODELS[model_name]
ensure_model_downloaded(model_key)
model = dolphin.load_model(model_key, MODEL_DIR, "cuda")
waveform = dolphin.load_audio(audio_file)
kwargs = {
"predict_time": predict_timestamps,
"padding_speech": padding_speech
}
if language:
kwargs["lang_sym"] = language
if region:
kwargs["region_sym"] = region
result = model(waveform, **kwargs)
output_text = result.text
detected_info = f"Detected language: {result.language}" + (
f", region: {result.region}" if result.region else "")
return output_text, detected_info
except Exception as e:
return f"Error: {str(e)}", "Transcription failed"
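# Build the Gradio interface.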
with gr.Blocks(title="Dolphin Speech Recognition") as demo:
gr.Markdown("# Dolphin ASR")
gr.Markdown("""
A multilingual, multitask ASR model supporting 40 Eastern languages and 22 Chinese dialects.
The model comes from [DataoceanAI/Dolphin](https://github.com/DataoceanAI/Dolphin) and performs speech
recognition for Eastern languages, including Chinese, Japanese, Korean, and many more.
""")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(
type="filepath", label="Upload or Record Audio")
with gr.Row():
model_dropdown = gr.Dropdown(
choices=list(MODELS.keys()),
value=list(MODELS.keys())[1],
label="Model Size"
)
with gr.Row():
language_dropdown = gr.Dropdown(
choices=language_options,
value=None,
label="Language (Optional)",
info="If not selected, the model will auto-detect language"
)
region_dropdown = gr.Dropdown(
choices=[],
value=None,
label="Region (Optional)",
visible=False
)
with gr.Row():
timestamp_checkbox = gr.Checkbox(
value=True,
label="Include Timestamps"
)
padding_checkbox = gr.Checkbox(
value=True,
label="Pad Speech to 30s"
)
transcribe_button = gr.Button("Transcribe", variant="primary")
with gr.Column():
output_text = gr.Textbox(label="Transcription", lines=10)
language_info = gr.Textbox(label="Detected Language", lines=1)
    # Refresh region choices whenever the language selection changes.
    language_dropdown.change(
        fn=update_regions,
        inputs=[language_dropdown],
        outputs=[region_dropdown]
    )
transcribe_button.click(
fn=transcribe_audio,
inputs=[
audio_input,
model_dropdown,
language_dropdown,
region_dropdown,
timestamp_checkbox,
padding_checkbox
],
outputs=[output_text, language_info]
)
gr.Markdown("""
## Usage Notes
- The model supports 40 Eastern languages and 22 Chinese dialects
- You can let the model auto-detect language or specify language and region
- Timestamps can be included in the output
- Speech can be padded to 30 seconds for better processing
## Credits
- Model: [DataoceanAI/Dolphin](https://github.com/DataoceanAI/Dolphin)
- Paper: [Dolphin: A Multilingual Model for Eastern Languages](https://arxiv.org/abs/2503.20212)
""")
demo.launch()