|
import gradio as gr |
|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
import torch |
|
import os |
|
|
|
|
|
model_name_or_path = "DeepLearning101/Corrector101zhTWT5" |
|
auth_token = os.getenv("HF_HOME") |
|
|
|
|
|
try: |
|
tokenizer = T5Tokenizer.from_pretrained(model_name_or_path, use_auth_token=auth_token) |
|
model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, use_auth_token=auth_token) |
|
model.eval() |
|
except Exception as e: |
|
print(f"加載模型或分詞器失敗,錯誤信息:{e}") |
|
exit(1) |
|
|
|
if torch.cuda.is_available(): |
|
model.cuda() |
|
|
|
def correct_text(text): |
|
"""將輸入的文本通過 T5 模型進行修正""" |
|
inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True, padding=True) |
|
if torch.cuda.is_available(): |
|
inputs = {k: v.cuda() for k, v in inputs.items()} |
|
with torch.no_grad(): |
|
outputs = model.generate(**inputs) |
|
corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
return corrected_text |
|
|
|
def main(): |
|
interface = gr.Interface( |
|
fn=correct_text, |
|
inputs=gr.Textbox(lines=5, placeholder="請輸入需要修正的中文文本..."), |
|
outputs=gr.Textbox(label="修正後的文本"), |
|
title="客服ASR文本AI糾錯系統", |
|
description="""<a href='https://www.twman.org' target='_blank'>TonTon Huang Ph.D. @ 2024/04 </a><br> |
|
輸入ASR文本,糾正同音字/詞錯誤<br> |
|
<a href='https://blog.twman.org/2021/04/ASR.html' target='_blank'>那些語音處理 (Speech Processing) 踩的坑</a><br> |
|
<a href='https://blog.twman.org/2024/02/asr-tts.html' target='_blank'>那些ASR和TTS可能會踩的坑</a><br> |
|
<a href='https://blog.twman.org/2021/04/NLP.html' target='_blank'>那些自然語言處理 (Natural Language Processing, NLP) 踩的坑</a><br> |
|
基於transformers的T5ForConditionalGeneration""", |
|
theme="default", |
|
examples=[ |
|
["你究輸入利的手機門號跟生分證就可以了。"], |
|
["這裡是客服中新,很高性為您服物,請問金天有什麼須要幫忙您得"], |
|
["因為我們這邊是按天術比例計蒜給您的,其實不會有態大的穎響。也就是您用前面的資非的廢率來做計算"], |
|
["我來看以下,他的時價是多少?起實您就可以直皆就不用到門事"], |
|
["因為你現在月富是六九九嘛,我幫擬減衣百塊,兒且也不會江速"] |
|
] |
|
) |
|
interface.launch(share=True) |
|
|
|
if __name__ == "__main__": |
|
main() |
|
|