Update app.py
Browse files
app.py
CHANGED
@@ -1,31 +1,38 @@
|
|
1 |
import gradio as gr
|
2 |
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
|
3 |
|
4 |
-
#
|
5 |
-
emotion_analyzer = pipeline(
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
8 |
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
|
9 |
|
10 |
-
#
|
11 |
SAFE_RESPONSES = {
|
12 |
-
"crisis": "
|
13 |
-
"sadness": "
|
14 |
-
"
|
15 |
}
|
16 |
|
17 |
def generate_response(user_input, history):
|
18 |
# 安全检查
|
19 |
-
|
20 |
-
|
21 |
-
return SAFE_RESPONSES["crisis"]
|
22 |
|
23 |
# 情绪分析
|
24 |
emotion = emotion_analyzer(user_input)[0]["label"]
|
25 |
|
26 |
-
#
|
27 |
if emotion in ["sadness", "fear"]:
|
28 |
-
return SAFE_RESPONSES.get(emotion, "
|
29 |
|
30 |
# 生成对话
|
31 |
inputs = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
|
@@ -33,18 +40,15 @@ def generate_response(user_input, history):
|
|
33 |
inputs,
|
34 |
max_length=1000,
|
35 |
pad_token_id=tokenizer.eos_token_id,
|
36 |
-
no_repeat_ngram_size=3
|
37 |
)
|
38 |
-
|
39 |
-
|
40 |
-
return response
|
41 |
|
42 |
-
#
|
43 |
demo = gr.ChatInterface(
|
44 |
fn=generate_response,
|
45 |
-
examples=["
|
46 |
-
title="青少年心理健康助手"
|
47 |
-
description="请随时倾诉您的感受,我会尽力帮助您调整情绪。"
|
48 |
)
|
49 |
|
50 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
|
3 |
|
4 |
+
# 使用公开模型初始化
|
5 |
+
emotion_analyzer = pipeline(
|
6 |
+
"text-classification",
|
7 |
+
model="j-hartmann/emotion-english-distilroberta-base"
|
8 |
+
)
|
9 |
+
|
10 |
+
safety_checker = pipeline(
|
11 |
+
"text-classification",
|
12 |
+
model="unitary/toxic-bert" # 替换为开源模型
|
13 |
+
)
|
14 |
+
|
15 |
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
16 |
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
|
17 |
|
18 |
+
# 安全回复模板
|
19 |
SAFE_RESPONSES = {
|
20 |
+
"crisis": "我注意到您可能需要专业帮助,建议联系心理咨询师或拨打援助热线。",
|
21 |
+
"sadness": "听起来您压力很大,试试深呼吸或听舒缓音乐?",
|
22 |
+
"toxic": "检测到敏感内容,已启动安全保护机制。"
|
23 |
}
|
24 |
|
25 |
def generate_response(user_input, history):
|
26 |
# 安全检查
|
27 |
+
if safety_checker(user_input)[0]["label"] == "toxic":
|
28 |
+
return SAFE_RESPONSES["toxic"]
|
|
|
29 |
|
30 |
# 情绪分析
|
31 |
emotion = emotion_analyzer(user_input)[0]["label"]
|
32 |
|
33 |
+
# 根据情绪生成回复
|
34 |
if emotion in ["sadness", "fear"]:
|
35 |
+
return SAFE_RESPONSES.get(emotion, "可以多聊聊您的感受吗?")
|
36 |
|
37 |
# 生成对话
|
38 |
inputs = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
|
|
|
40 |
inputs,
|
41 |
max_length=1000,
|
42 |
pad_token_id=tokenizer.eos_token_id,
|
43 |
+
no_repeat_ngram_size=3
|
44 |
)
|
45 |
+
return tokenizer.decode(reply_ids[:, inputs.shape[-1]:][0], skip_special_tokens=True)
|
|
|
|
|
46 |
|
47 |
+
# 创建界面
|
48 |
demo = gr.ChatInterface(
|
49 |
fn=generate_response,
|
50 |
+
examples=["最近学习压力好大", "和父母吵架了很难过"],
|
51 |
+
title="青少年心理健康助手"
|
|
|
52 |
)
|
53 |
|
54 |
demo.launch()
|