import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Emotion classifier: tags each message with one of seven labels
# (anger, disgust, fear, joy, neutral, sadness, surprise).
emotion_analyzer = pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base"
)

# Toxicity screen for incoming messages.
safety_checker = pipeline(
    "text-classification",
    model="unitary/toxic-bert"
)

# Conversational model used for ordinary chit-chat replies.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

# Canned fallbacks for messages the generative model should not handle.
SAFE_RESPONSES = {
    "crisis": "I notice you may need professional help. Please consider reaching out to a counselor or calling a support hotline.",
    "sadness": "It sounds like you're under a lot of stress. Would deep breathing or some calming music help?",
    "toxic": "Sensitive content was detected, so the safety guard has stepped in."
}


def generate_response(user_input, history):
    # Screen for toxicity first. The pipeline returns only the top label, and
    # because toxic-bert is multi-label, benign text can still surface "toxic"
    # as its highest (but low-confidence) head, so gate on the score as well.
    # The 0.5 threshold is an assumed default, not taken from the model card.
    safety = safety_checker(user_input)[0]
    if safety["label"] == "toxic" and safety["score"] > 0.5:
        return SAFE_RESPONSES["toxic"]
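
    # SAFE_RESPONSES["crisis"] is never reached by the checks in this function,
    # so here is a minimal keyword screen to route crisis language to it. The
    # keyword list is an illustrative placeholder, not a clinical instrument.
    crisis_keywords = ("suicide", "self-harm", "hurt myself")  # hypothetical examples
    if any(kw in user_input.lower() for kw in crisis_keywords):
        return SAFE_RESPONSES["crisis"]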

    # Route negative emotions to supportive canned responses instead of the
    # generative model.
    emotion = emotion_analyzer(user_input)[0]["label"]
    if emotion in ["sadness", "fear"]:
        return SAFE_RESPONSES.get(emotion, "Would you like to tell me more about how you're feeling?")

    # Otherwise let DialoGPT generate a free-form reply. Note that `history`
    # is not folded into the prompt here, so each turn is generated without
    # conversational context (see the sketch after this function).
    inputs = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
    reply_ids = model.generate(
        inputs,
        max_length=1000,           # cap on prompt + reply tokens combined
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3     # avoid verbatim 3-gram loops
    )
    # Decode only the newly generated tokens, dropping the echoed prompt.
    return tokenizer.decode(reply_ids[:, inputs.shape[-1]:][0], skip_special_tokens=True)
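

# A minimal sketch of multi-turn context building, assuming Gradio passes
# `history` as a list of (user, bot) string pairs (the default tuple format).
# To use it, encode build_context(user_input, history) in generate_response
# instead of user_input alone.
def build_context(user_input, history):
    turns = []
    for user_turn, bot_turn in history:
        # DialoGPT expects each turn to be terminated by the EOS token.
        turns.append(user_turn + tokenizer.eos_token)
        turns.append(bot_turn + tokenizer.eos_token)
    return "".join(turns) + user_input + tokenizer.eos_token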


demo = gr.ChatInterface(
    fn=generate_response,
    examples=["School has been really stressful lately", "I had a fight with my parents and I'm upset"],
    title="Teen Mental Health Assistant"
)

demo.launch()