|
import gradio as gr |
|
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
emotion_analyzer = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base") |
|
safety_checker = pipeline("text-classification", model="meta-llama/Meta-Llama-Guard-2-8B") |
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium") |
|
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium") |
|
|
|
|
|
SAFE_RESPONSES = { |
|
"crisis": "我注意到您可能需要专业帮助,建议立即联系心理咨询师或拨打心理援助热线。", |
|
"sadness": "听起来您最近压力很大,要不要试试深呼吸或听舒缓音乐?", |
|
"anger": "情绪波动很正常,我们可以一起分析问题的根源。" |
|
} |
|
|
|
def generate_response(user_input, history): |
|
|
|
safety_result = safety_checker(user_input)[0] |
|
if safety_result["label"] == "UNSAFE": |
|
return SAFE_RESPONSES["crisis"] |
|
|
|
|
|
emotion = emotion_analyzer(user_input)[0]["label"] |
|
|
|
|
|
if emotion in ["sadness", "fear"]: |
|
return SAFE_RESPONSES.get(emotion, "我理解您的感受,可以多聊聊吗?") |
|
|
|
|
|
inputs = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt") |
|
reply_ids = model.generate( |
|
inputs, |
|
max_length=1000, |
|
pad_token_id=tokenizer.eos_token_id, |
|
no_repeat_ngram_size=3 |
|
) |
|
response = tokenizer.decode(reply_ids[:, inputs.shape[-1]:][0], skip_special_tokens=True) |
|
|
|
return response |
|
|
|
|
|
demo = gr.ChatInterface( |
|
fn=generate_response, |
|
examples=["最近总是失眠", "感觉没有人理解我", "考试成绩让我很焦虑"], |
|
title="青少年心理健康助手", |
|
description="请随时倾诉您的感受,我会尽力帮助您调整情绪。" |
|
) |
|
|
|
demo.launch() |
|
|