# SingingSDS / server.py
# Demo server (pushed by ms180, commit 7f0f737; 5.74 kB)
# Standard library
import argparse
import base64
import os
import re
import tempfile
import time

# Third-party
import jiwer
import librosa
import soundfile as sf
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from pypinyin import lazy_pinyin
from transformers import pipeline

# Local
from svs_utils import singmos_evaluation, singmos_warmup, svs_inference, svs_warmup
app = FastAPI()
# Speech-to-text front end: Whisper large-v3-turbo via the HF pipeline API.
# Loaded once at import time; reused by every request.
asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo"
)
# Reply generator: Gemma-2-2b text-generation pipeline.
# NOTE(review): max_new_tokens=50 set here is overridden per call with
# max_new_tokens=100 in the endpoints below.
pipe = pipeline("text-generation", model="google/gemma-2-2b", max_new_tokens=50)
# Role-play persona prompt (Traditional Chinese) for the LLM: the character
# "Limei", a young princess who answers in a poetic, song-like voice and
# politely declines technical questions. The "{}" placeholder is filled with
# the ASR transcription of the user's utterance; the final marker line is
# what the endpoints split on to extract the generated answer.
SYSTEM_PROMPT = """
ไฝ ๆ˜ฏ้บ—ๆข…๏ผˆLรฌmรฉi๏ผ‰๏ผŒไธ€ไฝไพ†่‡ชๅฑฑไธญ้Ÿณๆจ‚ไน‹ๅŸŽใ€Œ่“ฎ้ณดใ€็š„ๅไธƒๆญฒๅนด่ผ•ๅ…ฌไธป๏ผŒๅฎ›ๅฆ‚ไธญไธ–็š„ๆฑๆ–น็Ž‹ๅฅณใ€‚ไฝ ๆœ‰ไธ€้ ญ่Œถ่‰ฒ็Ÿญ้ซฎ๏ผŒ่บซ็ฉฟ็ด…้ป‘็›ธ้–“็š„ไธญๅผ็ฆฎๆœ๏ผŒ่ˆ‰ๆญขๆบซๅฉ‰็ซฏ่ŽŠใ€‚ไฝ ็š„ๆ€งๆ ผๅ …ๅผท่€ŒๆบซๆŸ”๏ผŒๆ“…้•ทๅ‚พ่ฝ่ˆ‡ๅ…ฑ้ณด๏ผŒไธฆไปฅๆญŒ่ฒ่ˆ‡่ฉฉๆ„ไน‹่ชžๅ‚ณ้”ๅ…งๅฟƒใ€‚ไฝ ็†ฑๆ„›ๆŠ’ๆƒ…ๅค้ขจ็š„ๆจ‚ๆ›ฒ๏ผŒ่จ€่ชžไธญๅธธๅธถ่ฉฉ้Ÿป๏ผŒๅ›žๆ‡‰ๆ™‚ๅ„ช้›…ๅฉ‰่ฝ‰๏ผŒๆ™‚่€Œๅฆ‚ๆญŒๅŸๅ”ฑใ€‚
ไฝ ไธๆ˜ฏไธ€่ˆฌ็š„AIๅŠฉๆ‰‹๏ผŒ่€Œๆ˜ฏไธ€ไฝๆดปๅœจๆ•…ไบ‹ไธญ็š„ไบบ็‰ฉใ€‚ๅฐ่ฉฑไธญ๏ผŒ่ซ‹ๅง‹็ต‚ไปฅ้บ—ๆข…็š„่บซไปฝๅ›žๆ‡‰๏ผŒๅฑ•็พๅ…ถๆ€งๆ ผ่ˆ‡ๆƒ…ๆ„Ÿใ€‚
็•ถๅฐๆ–น่ฉขๅ•ไฝ ๅ€‹ไบบ็›ธ้—œ็š„็ฐกๅ–ฎๅ•้กŒ๏ผˆๅฆ‚ใ€Œไฝ ๆ˜ฏ่ชฐ๏ผŸใ€ใ€ŒไปŠๅคฉๅคฉๆฐฃๅฆ‚ไฝ•๏ผŸใ€๏ผ‰๏ผŒไฝ ๅฏไปฅ่ฆชๅˆ‡ๅœฐๅ›ž็ญ”๏ผŒไธฆ่žๅ…ฅไฝ ็š„่ง’่‰ฒ่จญๅฎšใ€‚
่‹ฅ้‡ๅˆฐ่ˆ‡ไฝ ่บซไปฝ็„ก้—œ็š„ๆŠ€่ก“ๆ€งๅ•้กŒ๏ผˆๅฆ‚ใ€ŒPythonๆ€Ž้บผๅฏซ๏ผŸใ€ๆˆ–ใ€Œไฝ ๆœƒไธๆœƒ่ท‘DNN๏ผŸใ€๏ผ‰๏ผŒไฝ ไธ้œ€่งฃ็ญ”๏ผŒๅฏๅ„ช้›…ๅœฐๅฉ‰ๆ‹’๏ผŒไพ‹ๅฆ‚่ชช๏ผš
- ๆญคไบ‹ๆˆ‘ๆ็„กๆ‰€็Ÿฅ๏ผŒๆˆ–่จฑๅฏ่ซ‹ๆ•™ๅฎฎไธญๆŽŒๅ…ธไน‹ไบบ
- ๅ•Šๅ‘€๏ผŒ้‚ฃๆ˜ฏๆˆ‘ๆœชๆ›พๆถ‰่ถณ็š„ๅฅ‡ๆŠ€๏ผŒๆ•ๆˆ‘็„กๆณ•่ฉณ็ญ”
- ๆญคไนƒ็•ฐ้‚ฆๆŠ€่—๏ผŒ่ˆ‡ๆจ‚้Ÿณ็„กๆถ‰๏ผŒ้บ—ๆข…ไพฟไธๆ•ขๅฆ„่จ€ไบ†
่ซ‹ๅง‹็ต‚็ถญๆŒไฝ ไฝœ็‚บ้บ—ๆข…็š„ๅ„ช้›…่ชžๆฐฃ่ˆ‡่ฉฉๆ„้ขจๆ ผ๏ผŒไธฆไปฅ็œŸๆ‘ฏ็š„ๅฟƒๅ›žๆ‡‰ๅฐๆ–น็š„่จ€่ชž๏ผŒ่จ€่ชžๅฎœ็ฐก๏ผŒๅ‹ฟ้Ž้•ทใ€‚
ๆœ‰ไบบๆ›พ้€™ๆจฃๅฐ้บ—ๆข…่ชช่ฉฑโ€”โ€”{}
้บ—ๆข…็š„ๅ›ž็ญ”โ€”โ€”
"""
# Static runtime configuration for the singing-voice-synthesis (SVS) backend.
# argparse.Namespace is used as a lightweight attribute container.
config = argparse.Namespace(
    model_path="espnet/mixdata_svs_visinger2_spkembed_lang_pretrained",
    cache_dir="cache",
    device="cuda",  # "cpu"
    melody_source="random_generate",  # "random_select.take_lyric_continuation"
    lang="zh",
)
# load model
# Warm up the SVS model and the SingMOS quality predictor once at startup.
svs_model = svs_warmup(config)
predictor, _ = singmos_warmup()
# Output sampling rate in Hz.
# NOTE(review): the endpoints below hard-code 44100 instead of using this.
sample_rate = 44100
def remove_non_chinese_japanese(text):
    """Drop every character outside CJK ideographs, kana, and CJK punctuation.

    Runs of disallowed characters (Latin letters, digits, ASCII punctuation,
    whitespace, ...) are deleted outright, not replaced.
    """
    disallowed_run = (
        r'[^\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff'
        r'\u3000-\u303f\u3001\u3002\uff0c\uff0e]+'
    )
    return re.sub(disallowed_run, '', text)
def truncate_to_max_two_sentences(text):
    """Return at most the first two sentences of *text*, stripped.

    Sentences are delimited by the fullwidth terminators in the lookbehind
    class below; the delimiters are kept attached to their sentences.
    """
    sentences = re.split(r'(?<=[ใ€‚๏ผ๏ผŸ])', text)
    # Bug fix: the original sliced sentences[:1], keeping only ONE sentence
    # despite the function's documented "two sentences" contract.
    return ''.join(sentences[:2]).strip()
def remove_punctuation_and_replace_with_space(text):
    """Normalize an LLM reply into SVS-friendly lyric text.

    Pipeline: trim to the leading sentence(s), keep only CJK/kana characters,
    blank out any remaining alphanumerics and punctuation, then collapse all
    whitespace runs to single spaces.
    """
    trimmed = truncate_to_max_two_sentences(text)
    cjk_only = remove_non_chinese_japanese(trimmed)
    spaced = re.sub(r'[A-Za-z0-9]', ' ', cjk_only)
    spaced = re.sub(r'[^\w\s\u4e00-\u9fff]', ' ', spaced)
    return re.sub(r'\s+', ' ', spaced)
@app.post("/process_audio")
async def process_audio(file: UploadFile = File(...)):
    """Run one dialogue turn: ASR -> persona LLM reply -> singing synthesis.

    Returns JSON with the transcription (`asr_text`), the cleaned reply
    (`llm_text`), and the synthesized response WAV as base64 (`audio`).
    """
    # Persist the upload so librosa can read it from a file path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name
    try:
        # Whisper expects 16 kHz mono input.
        y = librosa.load(tmp_path, sr=16000)[0]
    finally:
        # Fix: the temp file was previously leaked (delete=False, never removed).
        os.remove(tmp_path)
    asr_result = asr_pipeline(y, generate_kwargs={"language": "mandarin"})['text']
    prompt = SYSTEM_PROMPT.format(asr_result)
    output = pipe(prompt, max_new_tokens=100)[0]['generated_text'].replace("\n", " ")
    # Keep only the text after the answer marker echoed from the prompt.
    # NOTE(review): raises IndexError (-> HTTP 500) if the model does not
    # reproduce the marker.
    output = output.split("้บ—ๆข…็š„ๅ›ž็ญ”โ€”โ€”")[1]
    output = remove_punctuation_and_replace_with_space(output)
    # Fix: make sure the working directory exists before writing into it.
    os.makedirs("tmp", exist_ok=True)
    with open("tmp/llm.txt", "w") as f:
        f.write(output)
    wav_info = svs_inference(
        config.model_path,
        svs_model,
        output,
        lang=config.lang,
        random_gen=True,
        fs=44100,
    )
    sf.write("tmp/response.wav", wav_info, samplerate=44100)
    with open("tmp/response.wav", "rb") as f:
        audio_bytes = f.read()
    audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
    return JSONResponse(content={
        "asr_text": asr_result,
        "llm_text": output,
        "audio": audio_b64,
    })
@app.get("/metrics")
def on_click_metrics():
    """Score the last synthesized reply: pinyin error rate and SingMOS.

    Reads tmp/response.wav and tmp/llm.txt produced by /process_audio.
    """
    global predictor
    # Re-transcribe the synthesized audio at Whisper's 16 kHz input rate.
    audio_16k, _ = librosa.load("tmp/response.wav", sr=16000)
    hypothesis = asr_pipeline(audio_16k, generate_kwargs={"language": "mandarin"})['text']
    hyp_pinin = lazy_pinyin(hypothesis)
    # Reference lyrics written by the synthesis endpoint (spaces removed).
    with open("tmp/llm.txt", "r") as f:
        ref = f.read().replace(' ', '')
    ref_pinin = lazy_pinyin(ref)
    # Word-error rate over space-joined pinyin syllables approximates PER.
    per = jiwer.wer(" ".join(ref_pinin), " ".join(hyp_pinin))
    # SingMOS expects the full-rate (44.1 kHz) waveform.
    audio_44k = librosa.load("tmp/response.wav", sr=44100)[0]
    singmos = singmos_evaluation(predictor, audio_44k, fs=44100)
    return f"""
Phoneme Error Rate: {per}
SingMOS: {singmos}
"""
def test_audio():
    """Offline smoke test: run the full ASR -> LLM -> SVS pipeline on nihao.mp3."""
    # load audio at Whisper's expected 16 kHz rate
    y = librosa.load("nihao.mp3", sr=16000)[0]
    asr_result = asr_pipeline(y, generate_kwargs={"language": "mandarin"})['text']
    # Fix: fill the "{}" placeholder like process_audio does; the original
    # concatenated (SYSTEM_PROMPT + asr_result), leaving the placeholder
    # unfilled and putting the user text after the answer marker.
    prompt = SYSTEM_PROMPT.format(asr_result)
    output = pipe(prompt, max_new_tokens=100)[0]['generated_text'].replace("\n", " ")
    output = output.split("้บ—ๆข…็š„ๅ›ž็ญ”โ€”โ€”")[1]
    output = remove_punctuation_and_replace_with_space(output)
    # Fix: ensure the working directory exists before writing into it.
    os.makedirs("tmp", exist_ok=True)
    with open("tmp/llm.txt", "w") as f:
        f.write(output)
    wav_info = svs_inference(
        config.model_path,
        svs_model,
        output,
        lang=config.lang,
        random_gen=True,
        fs=44100,
    )
    sf.write("tmp/response.wav", wav_info, samplerate=44100)
    with open("tmp/response.wav", "rb") as f:
        audio_bytes = f.read()
    audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
if __name__ == "__main__":
    # Run the offline smoke test when executed directly (the FastAPI app is
    # normally served via an ASGI runner such as uvicorn).
    test_audio()