Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -33,6 +33,7 @@ from transformers import MegatronBertForMaskedLM
|
|
| 33 |
import argparse
|
| 34 |
import copy
|
| 35 |
import streamlit as st
|
|
|
|
| 36 |
# os.environ["CUDA_VISIBLE_DEVICES"] = '6'
|
| 37 |
|
| 38 |
|
|
@@ -612,12 +613,12 @@ def comp_acc(pred_data, test_data):
|
|
| 612 |
|
| 613 |
|
| 614 |
@st.experimental_memo()
|
| 615 |
-
def load_model():
|
| 616 |
total_parser = argparse.ArgumentParser("TASK NAME")
|
| 617 |
total_parser = UniMCPipelines.pipelines_args(total_parser)
|
| 618 |
args = total_parser.parse_args()
|
| 619 |
|
| 620 |
-
args.pretrained_model_path =
|
| 621 |
args.max_length = 512
|
| 622 |
args.batchsize = 8
|
| 623 |
args.default_root_dir = './'
|
|
@@ -628,14 +629,52 @@ def load_model():
|
|
| 628 |
|
| 629 |
def main():
|
| 630 |
|
| 631 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 632 |
|
| 633 |
st.subheader("UniMC Zero-shot 体验")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 634 |
st.info("请输入以下信息...")
|
|
|
|
| 635 |
|
| 636 |
-
sentences = st.text_area("请输入句子:",
|
| 637 |
-
question = st.text_input("请输入问题(不输入问题也可以):", "
|
| 638 |
-
choice = st.text_input("输入标签(以中文;分割):",
|
| 639 |
choice = choice.split(';')
|
| 640 |
|
| 641 |
data = [{"texta": sentences,
|
|
@@ -646,8 +685,9 @@ def main():
|
|
| 646 |
"id": 0}]
|
| 647 |
|
| 648 |
if st.button("点击一下,开始预测!"):
|
|
|
|
| 649 |
result = model.predict(data, cuda=False)
|
| 650 |
-
st.success("
|
| 651 |
st.json(result[0])
|
| 652 |
else:
|
| 653 |
st.info(
|
|
|
|
| 33 |
import argparse
|
| 34 |
import copy
|
| 35 |
import streamlit as st
|
| 36 |
+
import time
|
| 37 |
# os.environ["CUDA_VISIBLE_DEVICES"] = '6'
|
| 38 |
|
| 39 |
|
|
|
|
| 613 |
|
| 614 |
|
| 615 |
@st.experimental_memo()
|
| 616 |
+
def load_model(model_parh):
|
| 617 |
total_parser = argparse.ArgumentParser("TASK NAME")
|
| 618 |
total_parser = UniMCPipelines.pipelines_args(total_parser)
|
| 619 |
args = total_parser.parse_args()
|
| 620 |
|
| 621 |
+
args.pretrained_model_path = model_path
|
| 622 |
args.max_length = 512
|
| 623 |
args.batchsize = 8
|
| 624 |
args.default_root_dir = './'
|
|
|
|
| 629 |
|
| 630 |
def main():
|
| 631 |
|
| 632 |
+
text_dict={
|
| 633 |
+
'文本分类':"微软披露拓扑量子计算机计划!",
|
| 634 |
+
'情感分析':"刚买iphone13 pro 还不到一个月,天天死机最差的一次购物体验",
|
| 635 |
+
'语义匹配':"今天心情不好,我很不开心",
|
| 636 |
+
'自然语言推理':"小明正在上高中[unused1]小明是一个初中生",
|
| 637 |
+
'多项式阅读理解':"这个男的是什么意思?[unused1][SEP]女:您看这件衣服挺不错的,质量好,价钱也不贵。\n男:再看看吧。",
|
| 638 |
+
}
|
| 639 |
+
|
| 640 |
+
question_dict={
|
| 641 |
+
'文本分类':"故事;文化;娱乐;体育;财经;房产;汽车;教育;科技",
|
| 642 |
+
'情感分析':"好评;差评",
|
| 643 |
+
'语义匹配':"可以理解为;不能理解为",
|
| 644 |
+
'自然语言推理':"可以推断出;不能推断出;很难推断出",
|
| 645 |
+
'多项式阅读理解':"不想要这件;衣服挺好的;衣服质量不好",
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
choice_dict={
|
| 649 |
+
'文本分类':"故事;文化;娱乐;体育;财经;房产;汽车;教育;科技",
|
| 650 |
+
'情感分析':"好评;差评",
|
| 651 |
+
'语义匹配':"可以理解为;不能理解为",
|
| 652 |
+
'自然语言推理':"可以推断出;不能推断出;很难推断出",
|
| 653 |
+
'多项式阅读理解':"不想要这件;衣服挺好的;衣服质量不好",
|
| 654 |
+
}
|
| 655 |
+
|
| 656 |
+
|
| 657 |
|
| 658 |
st.subheader("UniMC Zero-shot 体验")
|
| 659 |
+
|
| 660 |
+
st.sidebar.header("参数配置")
|
| 661 |
+
sbform = st.sidebar.form("固定参数设置")
|
| 662 |
+
language = sbform.selectbox('选择语言', ['中文', 'English'])
|
| 663 |
+
sbform.form_submit_button("配置")
|
| 664 |
+
|
| 665 |
+
if language == '中文':
|
| 666 |
+
model = load_model('IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese')
|
| 667 |
+
else:
|
| 668 |
+
model = load_model('IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese')
|
| 669 |
+
|
| 670 |
+
model_type = st.selectbox('选择任务类型',['文本分类','情感分析','语义匹配','自然语言推理','多项式阅读理解'])
|
| 671 |
+
|
| 672 |
st.info("请输入以下信息...")
|
| 673 |
+
|
| 674 |
|
| 675 |
+
sentences = st.text_area("请输入句子:", text_dict[model_type])
|
| 676 |
+
question = st.text_input("请输入问题(不输入问题也可以):", "")
|
| 677 |
+
choice = st.text_input("输入标签(以中文;分割):", choice_dict[model_type])
|
| 678 |
choice = choice.split(';')
|
| 679 |
|
| 680 |
data = [{"texta": sentences,
|
|
|
|
| 685 |
"id": 0}]
|
| 686 |
|
| 687 |
if st.button("点击一下,开始预测!"):
|
| 688 |
+
start=time.time()
|
| 689 |
result = model.predict(data, cuda=False)
|
| 690 |
+
st.success(f"Prediction is successful, consumes {str(time.time()-start)} seconds")
|
| 691 |
st.json(result[0])
|
| 692 |
else:
|
| 693 |
st.info(
|