chen666-666 committed · Commit 950dc1a · verified · 1 Parent(s): 6a568bb

Update utils.py

Files changed (1)
  1. utils.py +108 -75
utils.py CHANGED
@@ -1,6 +1,7 @@
 import json
 import tempfile
 import re
 from collections import defaultdict
 from transformers import (
     AutoTokenizer,
@@ -11,25 +12,38 @@ import torch
 from pyvis.network import Network

 # -------------------------------
-# Named-entity recognition (NER) model
 # -------------------------------
-ner_tokenizer = AutoTokenizer.from_pretrained("ckiplab/bert-base-chinese-ner")
-ner_model = AutoModelForSequenceClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
-ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")


 # -------------------------------
 # Person-relation classification model (RoBERTa)
 # -------------------------------
-rel_model_name = "hfl/chinese-roberta-wwm-ext"  # recommended Chinese RoBERTa model
-rel_tokenizer = AutoTokenizer.from_pretrained(rel_model_name)
-rel_model = AutoModelForSequenceClassification.from_pretrained(
-    rel_model_name,
-    num_labels=6,  # make sure the number of labels matches
-    id2label={0: "夫妻", 1: "父子", 2: "朋友", 3: "师生", 4: "同事", 5: "其他"},
-    label2id={"夫妻": 0, "父子": 1, "朋友": 2, "师生": 3, "同事": 4, "其他": 5}
-)
-rel_model.eval()

 # Label map for relation classification
 relation_id2label = {
@@ -41,55 +55,52 @@ legal_id2label = {
     0: "无违法", 1: "赌博", 2: "毒品", 3: "色情", 4: "诈骗", 5: "暴力"
 }

-
-def classify_relation_bert(e1, e2, context):
-    prompt = f"{e1}和{e2}的关系是?{context}"
-    inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
-    with torch.no_grad():
-        logits = rel_model(**inputs).logits
-    pred = torch.argmax(logits, dim=1).item()
-    probs = torch.nn.functional.softmax(logits, dim=1)
-    confidence = probs[0, pred].item()
-    return f"{relation_id2label[pred]}(置信度 {confidence:.2f})"
-
-
 # -------------------------------
 # Chat input parsing
 # -------------------------------
 def parse_input_file(file):
-    filename = file.name
-    if filename.endswith(".json"):
-        return json.load(file)
-    elif filename.endswith(".txt"):
-        content = file.read().decode("utf-8")
-        lines = content.strip().splitlines()
-        chat_data = []
-        for line in lines:
-            match = re.match(r"(\d{4}-\d{2}-\d{2}.*?) (.*?): (.*)", line)
-            if match:
-                _, sender, message = match.groups()
-                chat_data.append({"sender": sender, "message": message})
-        return chat_data
-    else:
-        raise ValueError("不支持的文件格式,请上传 JSON 或 TXT 文件")
-

 # -------------------------------
 # Entity extraction
 # -------------------------------
 def extract_entities(text):
-    results = ner_pipeline(text)
-    people = set()
-    for r in results:
-        if r["entity_group"] == "PER":
-            people.add(r["word"])
-    return list(people)
-

 # -------------------------------
 # Relation extraction (co-occurrence + BERT classification)
 # -------------------------------
 def extract_relations(chat_data, entities):
     relations = defaultdict(lambda: defaultdict(lambda: {"count": 0, "contexts": []}))

     for entry in chat_data:
@@ -116,62 +127,84 @@ def extract_relations(chat_data, entities):
         edges.append((e1, e2, relations[e1][e2]["count"], label))
     return edges

-
 # -------------------------------
 # Legal-risk analysis (gambling, drugs, pornography, etc.)
 # -------------------------------
-def classify_illegal_behavior(chat_context):
-    prompt = f"请分析以下聊天记录,判断是否涉及以下违法行为:赌博、毒品、色情、诈骗、暴力行为。\n聊天内容:{chat_context}\n请回答:"
-    inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
-    with torch.no_grad():
-        logits = rel_model(**inputs).logits
-    pred = torch.argmax(logits, dim=1).item()
-    probs = torch.nn.functional.softmax(logits, dim=1)
-    confidence = probs[0, pred].item()
-
-    return f"违法行为判断结果:{legal_id2label.get(pred, '未知')}(置信度 {confidence:.2f})"


 # -------------------------------
 # Graph rendering
 # -------------------------------
 def draw_graph(entities, relations):
-    g = Network(height="600px", width="100%", notebook=False)
-    g.barnes_hut()
-    for ent in entities:
-        g.add_node(ent, label=ent)
-    for e1, e2, weight, label in relations:
-        g.add_edge(e1, e2, value=weight, title=f"{label}(互动{weight}次)", label=label)
-    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
-    g.show(tmp_file.name)
-    with open(tmp_file.name, 'r', encoding='utf-8') as f:
-        return f.read()
-

 # -------------------------------
 # Main pipeline
 # -------------------------------
 def analyze_chat(file):
     if file is None:
-        return "请上传聊天文件", "", ""

     try:
         content = parse_input_file(file)
     except Exception as e:
-        return f"读取文件失败: {e}", "", ""

     text = "\n".join([entry["sender"] + ": " + entry["message"] for entry in content])
     entities = extract_entities(text)
     if not entities:
-        return "未识别到任何人物实体", "", ""

     relations = extract_relations(content, entities)
     if not relations:
-        return "未发现人物之间的关系", "", ""

     # Legal-risk analysis
     illegal_behavior_results = [classify_illegal_behavior(msg["message"]) for msg in content]

     graph_html = draw_graph(entities, relations)

-    return str(entities), str(relations), graph_html, "\n".join(illegal_behavior_results)
 
 import json
 import tempfile
 import re
+import os
 from collections import defaultdict
 from transformers import (
     AutoTokenizer,

 from pyvis.network import Network

 # -------------------------------
+# Model configuration
 # -------------------------------
+# Model names are read from environment variables so they are easy to change when deploying on Hugging Face
+NER_MODEL_NAME = os.environ.get("NER_MODEL_NAME", "ckiplab/bert-base-chinese-ner")
+REL_MODEL_NAME = os.environ.get("REL_MODEL_NAME", "hfl/chinese-roberta-wwm-ext")

+# -------------------------------
+# Named-entity recognition (NER) model
+# -------------------------------
+try:
+    ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
+    ner_model = AutoModelForSequenceClassification.from_pretrained(NER_MODEL_NAME)
+    ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
+except Exception as e:
+    print(f"NER模型加载失败: {e}")
+    # a fallback or extra error handling could be added here

 # -------------------------------
 # Person-relation classification model (RoBERTa)
 # -------------------------------
+try:
+    rel_tokenizer = AutoTokenizer.from_pretrained(REL_MODEL_NAME)
+    rel_model = AutoModelForSequenceClassification.from_pretrained(
+        REL_MODEL_NAME,
+        num_labels=6,  # make sure the number of labels matches
+        id2label={0: "夫妻", 1: "父子", 2: "朋友", 3: "师生", 4: "同事", 5: "其他"},
+        label2id={"夫妻": 0, "父子": 1, "朋友": 2, "师生": 3, "同事": 4, "其他": 5}
+    )
+    rel_model.eval()
+except Exception as e:
+    print(f"关系分类模型加载失败: {e}")
+    # a fallback or extra error handling could be added here

 # Label map for relation classification
 relation_id2label = {

     0: "无违法", 1: "赌博", 2: "毒品", 3: "色情", 4: "诈骗", 5: "暴力"
 }

 # -------------------------------
 # Chat input parsing
 # -------------------------------
 def parse_input_file(file):
+    """Parse a chat file; JSON and TXT formats are supported."""
+    try:
+        filename = file.name
+        if filename.endswith(".json"):
+            return json.load(file)
+        elif filename.endswith(".txt"):
+            content = file.read().decode("utf-8")
+            lines = content.strip().splitlines()
+            chat_data = []
+            for line in lines:
+                match = re.match(r"(\d{4}-\d{2}-\d{2}.*?) (.*?): (.*)", line)
+                if match:
+                    _, sender, message = match.groups()
+                    chat_data.append({"sender": sender, "message": message})
+            return chat_data
+        else:
+            raise ValueError("不支持的文件格式,请上传JSON或TXT文件")
+    except Exception as e:
+        print(f"文件解析错误: {e}")
+        raise

 # -------------------------------
 # Entity extraction
 # -------------------------------
 def extract_entities(text):
+    """Extract person entities from the text."""
+    try:
+        results = ner_pipeline(text)
+        people = set()
+        for r in results:
+            if r["entity_group"] == "PER":
+                people.add(r["word"])
+        return list(people)
+    except Exception as e:
+        print(f"实体提取错误: {e}")
+        return []

 # -------------------------------
 # Relation extraction (co-occurrence + BERT classification)
 # -------------------------------
 def extract_relations(chat_data, entities):
+    """Analyze the relations between people."""
     relations = defaultdict(lambda: defaultdict(lambda: {"count": 0, "contexts": []}))

     for entry in chat_data:

         edges.append((e1, e2, relations[e1][e2]["count"], label))
     return edges

 # -------------------------------
 # Legal-risk analysis (gambling, drugs, pornography, etc.)
 # -------------------------------
+def classify_relation_bert(e1, e2, context):
+    """Classify the relation between two people with the BERT model."""
+    try:
+        prompt = f"{e1}和{e2}的关系是?{context}"
+        inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+        with torch.no_grad():
+            logits = rel_model(**inputs).logits
+        pred = torch.argmax(logits, dim=1).item()
+        probs = torch.nn.functional.softmax(logits, dim=1)
+        confidence = probs[0, pred].item()
+        return f"{relation_id2label[pred]}(置信度 {confidence:.2f})"
+    except Exception as e:
+        print(f"关系分类错误: {e}")
+        return "其他(置信度 0.00)"

+def classify_illegal_behavior(chat_context):
+    """Analyze the chat content for legal risks."""
+    try:
+        prompt = f"请分析以下聊天记录,判断是否涉及以下违法行为:赌博、毒品、色情、诈骗、暴力行为。\n聊天内容:{chat_context}\n请回答:"
+        inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+        with torch.no_grad():
+            logits = rel_model(**inputs).logits
+        pred = torch.argmax(logits, dim=1).item()
+        probs = torch.nn.functional.softmax(logits, dim=1)
+        confidence = probs[0, pred].item()
+        return f"违法行为判断结果:{legal_id2label.get(pred, '未知')}(置信度 {confidence:.2f})"
+    except Exception as e:
+        print(f"法律风险分析错误: {e}")
+        return "违法行为判断结果:未知(置信度 0.00)"

 # -------------------------------
 # Graph rendering
 # -------------------------------
 def draw_graph(entities, relations):
+    """Generate the person-relation graph."""
+    try:
+        g = Network(height="600px", width="100%", notebook=False)
+        g.barnes_hut()
+        for ent in entities:
+            g.add_node(ent, label=ent)
+        for e1, e2, weight, label in relations:
+            g.add_edge(e1, e2, value=weight, title=f"{label}(互动{weight}次)", label=label)
+        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
+        g.show(tmp_file.name)
+        with open(tmp_file.name, 'r', encoding='utf-8') as f:
+            return f.read()
+    except Exception as e:
+        print(f"图谱绘制错误: {e}")
+        return "<h3>图谱生成失败</h3><p>请检查输入数据是否有效</p>"

 # -------------------------------
 # Main pipeline
 # -------------------------------
 def analyze_chat(file):
+    """Main function for analyzing a chat log."""
     if file is None:
+        return "请上传聊天文件", "", "", ""

     try:
         content = parse_input_file(file)
     except Exception as e:
+        return f"读取文件失败: {e}", "", "", ""

     text = "\n".join([entry["sender"] + ": " + entry["message"] for entry in content])
     entities = extract_entities(text)
     if not entities:
+        return "未识别到任何人物实体", "", "", ""

     relations = extract_relations(content, entities)
     if not relations:
+        return "未发现人物之间的关系", "", "", ""

     # Legal-risk analysis
     illegal_behavior_results = [classify_illegal_behavior(msg["message"]) for msg in content]

     graph_html = draw_graph(entities, relations)

+    return str(entities), str(relations), graph_html, "\n".join(illegal_behavior_results)
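
Both model loaders above keep AutoModelForSequenceClassification. Two caveats are worth noting, neither of which this commit addresses: ckiplab/bert-base-chinese-ner is a token-classification checkpoint, and the transformers "ner" pipeline expects a token-classification head rather than a sentence-level one; and hfl/chinese-roberta-wwm-ext loaded with num_labels=6 gets a freshly initialized classification head, so its relation and legal-risk scores are only meaningful after fine-tuning on labelled data. A minimal sketch of the conventional NER loading, assuming the same model name and no other changes to the file:

# Sketch only: load the NER checkpoint with a token-classification head,
# which is what the "ner" pipeline is designed for.
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

NER_MODEL_NAME = "ckiplab/bert-base-chinese-ner"
ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
ner_pipeline = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",  # merge word pieces into whole entities
)

# Check which entity_group values the checkpoint actually emits before
# filtering on a hard-coded tag such as "PER" in extract_entities().
print(ner_pipeline("张三和李四昨天一起去了北京。"))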
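
For the TXT branch, parse_input_file only keeps lines matching re.match(r"(\d{4}-\d{2}-\d{2}.*?) (.*?): (.*)", line), i.e. a date-stamped prefix, a sender, then a colon and the message; everything else is silently dropped. A small, hypothetical smoke test of the updated four-value analyze_chat (the file name chat.txt and the sample messages are made up for illustration):

# Hypothetical smoke test; assumes this module is importable as `utils`
# and that both models loaded successfully at import time.
from utils import analyze_chat

sample = (
    "2024-01-01 张三: 今晚一起吃饭吗?\n"
    "2024-01-01 李四: 好啊,叫上王五。\n"
)

with open("chat.txt", "w", encoding="utf-8") as f:
    f.write(sample)

# parse_input_file calls file.read().decode("utf-8"), so hand it a binary
# file object, the same way an upload widget would.
with open("chat.txt", "rb") as f:
    entities, relations, graph_html, risks = analyze_chat(f)

print(entities)  # on success: stringified list of detected person names
print(risks)     # one legal-risk line per parsed message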