QiLi520 committed
Commit c9ee3a0 · verified · 1 parent: e3d81cc

Update app.py

Files changed (1)
  1. app.py +24 -20
app.py CHANGED
@@ -1,31 +1,38 @@
 import gradio as gr
 from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 
-# Initialize the models
-emotion_analyzer = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
-safety_checker = pipeline("text-classification", model="meta-llama/Meta-Llama-Guard-2-8B")
+# Initialize with publicly available models
+emotion_analyzer = pipeline(
+    "text-classification",
+    model="j-hartmann/emotion-english-distilroberta-base"
+)
+
+safety_checker = pipeline(
+    "text-classification",
+    model="unitary/toxic-bert"  # swapped in an open-source model
+)
+
 tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
 model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
 
-# Predefined safe-response templates
+# Safe-response templates
 SAFE_RESPONSES = {
-    "crisis": "我注意到您可能需要专业帮助,建议立即联系心理咨询师或拨打心理援助热线。",
-    "sadness": "听起来您最近压力很大,要不要试试深呼吸或听舒缓音乐?",
-    "anger": "情绪波动很正常,我们可以一起分析问题的根源。"
+    "crisis": "我注意到您可能需要专业帮助,建议联系心理咨询师或拨打援助热线。",
+    "sadness": "听起来您压力很大,试试深呼吸或听舒缓音乐?",
+    "toxic": "检测到敏感内容,已启动安全保护机制。"
 }
 
 def generate_response(user_input, history):
     # Safety check
-    safety_result = safety_checker(user_input)[0]
-    if safety_result["label"] == "UNSAFE":
-        return SAFE_RESPONSES["crisis"]
+    if safety_checker(user_input)[0]["label"] == "toxic":
+        return SAFE_RESPONSES["toxic"]
 
     # Emotion analysis
     emotion = emotion_analyzer(user_input)[0]["label"]
 
-    # Pick a generation strategy based on the emotion
+    # Respond based on the emotion
     if emotion in ["sadness", "fear"]:
-        return SAFE_RESPONSES.get(emotion, "我理解您的感受,可以多聊聊吗?")
+        return SAFE_RESPONSES.get(emotion, "可以多聊聊您的感受吗?")
 
     # Generate a dialogue reply
     inputs = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
@@ -33,18 +40,15 @@ def generate_response(user_input, history):
         inputs,
         max_length=1000,
         pad_token_id=tokenizer.eos_token_id,
-        no_repeat_ngram_size=3  # avoid repetition
+        no_repeat_ngram_size=3
     )
-    response = tokenizer.decode(reply_ids[:, inputs.shape[-1]:][0], skip_special_tokens=True)
-
-    return response
+    return tokenizer.decode(reply_ids[:, inputs.shape[-1]:][0], skip_special_tokens=True)
 
-# Create the Gradio interface
+# Create the interface
 demo = gr.ChatInterface(
     fn=generate_response,
-    examples=["最近总是失眠", "感觉没有人理解我", "考试成绩让我很焦虑"],
-    title="青少年心理健康助手",
-    description="请随时倾诉您的感受,我会尽力帮助您调整情绪。"
+    examples=["最近学习压力好大", "和父母吵架了很难过"],
+    title="青少年心理健康助手"
 )
 
 demo.launch()
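
The new safety check compares the single top label returned by the pipeline against "toxic". unitary/toxic-bert is a multi-label Jigsaw toxicity model (labels such as toxic, severe_toxic, obscene, threat, insult, identity_hate), and it is an English-language model, so scores on Chinese input may be unreliable. A minimal sketch of a score-thresholded check, not part of this commit; the 0.5 threshold and the is_unsafe helper name are illustrative assumptions:

# Sketch only (not in the commit): threshold the "toxic" score instead of
# relying on the top-ranked label alone.
from transformers import pipeline

safety_checker = pipeline(
    "text-classification",
    model="unitary/toxic-bert",
    top_k=None,                   # return scores for every label
    function_to_apply="sigmoid",  # multi-label model: independent per-label scores
)

def is_unsafe(text, threshold=0.5):  # threshold is an assumed value
    scores = {r["label"]: r["score"] for r in safety_checker([text])[0]}
    return scores.get("toxic", 0.0) >= threshold

With something like this, generate_response could call is_unsafe(user_input) instead of comparing the label string directly.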
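
generate_response accepts the history that gr.ChatInterface passes in but never uses it, so each reply is generated from the latest message alone. A sketch of folding that history into the DialoGPT prompt, assuming the tuple-style history ChatInterface provides when type="messages" is not set; build_inputs is a hypothetical helper, not part of the commit:

# Sketch only: join prior turns with DialoGPT's EOS separator before generating.
def build_inputs(user_input, history):
    # tokenizer is the module-level DialoGPT tokenizer from app.py
    turns = []
    for user_msg, bot_msg in history:  # assumed (user, bot) tuple format
        turns.append(user_msg)
        if bot_msg:
            turns.append(bot_msg)
    turns.append(user_input)
    text = tokenizer.eos_token.join(turns) + tokenizer.eos_token
    return tokenizer.encode(text, return_tensors="pt")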
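
The generate call keeps max_length=1000, which bounds prompt and reply together, and passes only input_ids. A sketch of the same step using max_new_tokens and an explicit attention mask; the 200-token cap and the generate_reply name are illustrative assumptions, not part of the commit:

# Sketch only: cap the reply length separately from the prompt and pass an
# explicit attention mask to model.generate.
def generate_reply(user_input):
    encoded = tokenizer(user_input + tokenizer.eos_token, return_tensors="pt")
    reply_ids = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=200,                 # assumed cap on the reply alone
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
    )
    # Decode only the newly generated tokens, as in the committed code
    return tokenizer.decode(reply_ids[:, encoded["input_ids"].shape[-1]:][0], skip_special_tokens=True)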