chen666-666 committed
Commit 24834ac · verified · 1 Parent(s): b45e38c

Upload 2 files

Files changed (2):
  1. app.py +507 -426
  2. requirements.txt +14 -12
app.py CHANGED
@@ -1,427 +1,508 @@
- import torch
- from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
- import gradio as gr
- import re
- import os
- import json
- import chardet
- from sklearn.metrics import precision_score, recall_score, f1_score
- import time
-
- # ======================== Model loading ========================
- NER_MODEL_NAME = "uer/roberta-base-finetuned-cluener2020-chinese"
- bert_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
- bert_ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
- bert_ner_pipeline = pipeline(
-     "ner",
-     model=bert_ner_model,
-     tokenizer=bert_tokenizer,
-     aggregation_strategy="first"
- )
-
- LABEL_MAPPING = {
-     "address": "LOC",
-     "company": "ORG",
-     "name": "PER",
-     "organization": "ORG",
-     "position": "TITLE"
- }
-
- chatglm_model, chatglm_tokenizer = None, None
- use_chatglm = False
- try:
-     chatglm_model_name = "THUDM/chatglm-6b-int4"
-     chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
-     chatglm_model = AutoModel.from_pretrained(
-         chatglm_model_name,
-         trust_remote_code=True,
-         device_map="cpu",
-         torch_dtype=torch.float32
-     ).eval()
-     use_chatglm = True
-     print("✅ 4-bit量化版ChatGLM加载成功")
- except Exception as e:
-     print(f"❌ ChatGLM加载失败: {e}")
-
- # ======================== Knowledge graph structure ========================
- knowledge_graph = {"entities": set(), "relations": set()}
-
- def update_knowledge_graph(entities, relations):
-     for e in entities:
-         if isinstance(e, dict) and 'text' in e and 'type' in e:
-             knowledge_graph["entities"].add((e['text'], e['type']))
-     # Change 4: de-duplicate relations before inserting
-     existing_relations = {frozenset({r[0], r[1], r[2]}) for r in knowledge_graph["relations"]}
-     for r in relations:
-         if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
-             new_rel = frozenset({r['head'], r['tail'], r['relation']})
-             if new_rel not in existing_relations:
-                 knowledge_graph["relations"].add((r['head'], r['tail'], r['relation']))
-
- def visualize_kg_text():
-     nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
-     edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
-     return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
-
- # ======================== Named entity recognition (NER) ========================
- def merge_adjacent_entities(entities):
-     merged = []
-     for entity in entities:
-         if not merged:
-             merged.append(entity)
-             continue
-
-         last = merged[-1]
-         # Merge adjacent entities of the same type
-         if (entity["type"] == last["type"] and
-                 entity["start"] == last["end"] and
-                 entity["text"] not in last["text"]):
-             merged[-1] = {
-                 "text": last["text"] + entity["text"],
-                 "type": last["type"],
-                 "start": last["start"],
-                 "end": entity["end"]
-             }
-         else:
-             merged.append(entity)
-     return merged
-
-
- def ner(text, model_type="bert"):
-     start_time = time.time()
-     if model_type == "chatglm" and use_chatglm:
-         try:
-             prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
- 示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
- 文本:{text}"""
-             response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
-             if isinstance(response, tuple):
-                 response = response[0]
-
-             # More robust JSON parsing
-             try:
-                 json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
-                 entities = json.loads(json_str)
-                 # Keep only entities with all required fields
-                 valid_entities = []
-                 for ent in entities:
-                     if all(k in ent for k in ("text", "type", "start", "end")):
-                         valid_entities.append(ent)
-                 return valid_entities, time.time() - start_time
-             except Exception as e:
-                 print(f"JSON 解析失败: {e}")
-                 return [], time.time() - start_time
-         except Exception as e:
-             print(f"ChatGLM 调用失败:{e}")
-             return [], time.time() - start_time
-
-     # Run the fine-tuned Chinese BERT NER model
-     raw_results = bert_ner_pipeline(text)
-     entities = []
-     for r in raw_results:
-         mapped_type = LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
-         entities.append({
-             "text": r['word'].replace(' ', ''),
-             "start": r['start'],
-             "end": r['end'],
-             "type": mapped_type
-         })
-
-     # Merge adjacent fragments
-     entities = merge_adjacent_entities(entities)
-     return entities, time.time() - start_time
-
-
- # ======================== Relation extraction (RE) ========================
- def re_extract(entities, text):
-     # Validate arguments
-     if not entities or not text:
-         return []
-
-     # Filter by entity type (adjust to business needs)
-     valid_entity_types = {"PER", "LOC", "ORG", "TITLE"}
-     filtered_entities = [e for e in entities if e.get("type") in valid_entity_types]
-
-     # --------------------- Single-entity case ---------------------
-     if len(filtered_entities) == 1:
-         single_relations = []
-         ent = filtered_entities[0]
-
-         # Rule 1: person-title detection
-         if ent["type"] == "PER":
-             position_keywords = ["CEO", "经理", "总监", "工程师", "教授"]
-             for keyword in position_keywords:
-                 if keyword in text:
-                     single_relations.append({
-                         "head": ent["text"],
-                         "tail": keyword,
-                         "relation": "担任职位"
-                     })
-                     break
-
-         # Rule 2: organization/location detection
-         if ent["type"] in ["ORG", "LOC"]:
-             location_verbs = ["位于", "坐落于", "地处"]
-             for verb in location_verbs:
-                 if verb in text:
-                     match = re.search(fr"{ent['text']}{verb}(.*?)[,。]", text)
-                     if match:
-                         single_relations.append({
-                             "head": ent["text"],
-                             "tail": match.group(1).strip(),
-                             "relation": "位置"
-                         })
-                         break
-         return single_relations
-
-     # --------------------- Multi-entity relation extraction ---------------------
-     relations = []
-
-     # Option 1: relation extraction with ChatGLM
-     if use_chatglm and len(filtered_entities) >= 2:
-         try:
-             entity_list = [e["text"] for e in filtered_entities]
-             prompt = f"""请分析以下文本中的实体关系,严格按照JSON列表格式返回:
- 文本内容:{text}
- 候选实体:{entity_list}
- 要求:
- 1. 只返回存在明确关系的实体对
- 2. 关系类型使用:属于、位于、任职于、合作、其他
- 3. 示例格式:[{{"head":"实体1", "tail":"实体2", "relation":"关系类型"}}]
- 请直接返回JSON,不要多余内容:"""
-
-             response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.01)
-             if isinstance(response, tuple):
-                 response = response[0]
-
-             # More robust JSON parsing
-             try:
-                 json_str = re.search(r'(\[.*?\])', response, re.DOTALL)
-                 if json_str:
-                     json_str = json_str.group(1)
-                     json_str = re.sub(r'[\u201c\u201d]', '"', json_str)  # replace Chinese curly quotes
-                     json_str = re.sub(r'(?<!,)\n', '', json_str)  # drop newlines except after commas
-                     relations = json.loads(json_str)
-
-                     # Validate extracted relations
-                     valid_relations = []
-                     valid_rel_types = {"属于", "位于", "任职于", "合作", "其他"}
-                     for rel in relations:
-                         if (isinstance(rel, dict) and
-                                 rel.get("head") in entity_list and
-                                 rel.get("tail") in entity_list and
-                                 rel.get("relation") in valid_rel_types):
-                             valid_relations.append(rel)
-                     relations = valid_relations
-             except Exception as e:
-                 print(f"[DEBUG] 关系解析失败: {str(e)}")
-
-         except Exception as e:
-             print(f"ChatGLM关系抽取异常: {str(e)}")
-
-     # Option 2: rule-based fallback (model unavailable or returned nothing)
-     if len(relations) == 0:
-         # Rule 1: A is located in B
-         location_matches = re.finditer(r'([^\s,。]+)(?:位于|坐落于|地处)([^\s,。]+)', text)
-         for match in location_matches:
-             head, tail = match.groups()
-             relations.append({"head": head, "tail": tail, "relation": "位于"})
-
-         # Rule 2: A belongs to B
-         belong_matches = re.finditer(r'([^\s,。]+)(属于|隶属于)([^\s,。]+)', text)
-         for match in belong_matches:
-             head, _, tail = match.groups()
-             relations.append({"head": head, "tail": tail, "relation": "属于"})
-
-         # Rule 3: person-organization relations
-         person_org_pattern = r'([\u4e00-\u9fa5]{2,4})(现任|担任|就职于)([\u4e00-\u9fa5]+?公司|[\u4e00-\u9fa5]+?大学)'
-         for match in re.finditer(person_org_pattern, text):
-             head, _, tail = match.groups()
-             relations.append({"head": head, "tail": tail, "relation": "任职于"})
-
-     # Post-processing: de-duplicate and validate
-     seen = set()
-     final_relations = []
-     for rel in relations:
-         key = (rel["head"], rel["tail"], rel["relation"])
-         if key not in seen:
-             # Both endpoints must be recognized entities
-             head_exists = any(e["text"] == rel["head"] for e in filtered_entities)
-             tail_exists = any(e["text"] == rel["tail"] for e in filtered_entities)
-             if head_exists and tail_exists:
-                 final_relations.append(rel)
-                 seen.add(key)
-
-     return final_relations
-
-
- # ======================== Main text-analysis pipeline ========================
- def process_text(text, model_type="bert"):
-     entities, duration = ner(text, model_type)
-     relations = re_extract(entities, text)
-     update_knowledge_graph(entities, relations)
-
-     ent_text = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
-     rel_text = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
-     kg_text = visualize_kg_text()
-     return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"
-
-
- def process_file(file, model_type="bert"):
-     try:
-         with open(file.name, 'rb') as f:
-             content = f.read()
-
-         if len(content) > 5 * 1024 * 1024:
-             return "❌ 文件太大", "", "", ""
-
-         # Detect the file encoding
-         try:
-             encoding = chardet.detect(content)['encoding'] or 'utf-8'
-             text = content.decode(encoding)
-         except UnicodeDecodeError:
-             # Fall back to common Chinese encodings
-             for enc in ['gb18030', 'utf-16', 'big5']:
-                 try:
-                     text = content.decode(enc)
-                     break
-                 except UnicodeDecodeError:
-                     continue
-             else:
-                 return "❌ 编码解析失败", "", "", ""
-
-         return process_text(text, model_type)
-     except Exception as e:
-         return f"❌ 文件处理错误: {str(e)}", "", "", ""
-
-
- # ======================== Model evaluation and auto-annotation ========================
- def convert_telegram_json_to_eval_format(path):
-     with open(path, encoding="utf-8") as f:
-         data = json.load(f)
-     if isinstance(data, dict) and "text" in data:
-         return [{"text": data["text"], "entities": [
-             {"text": data["text"][e["start"]:e["end"]]} for e in data.get("entities", [])
-         ]}]
-     elif isinstance(data, list):
-         return data
-     elif isinstance(data, dict) and "messages" in data:
-         result = []
-         for m in data.get("messages", []):
-             if isinstance(m.get("text"), str):
-                 result.append({"text": m["text"], "entities": []})
-             elif isinstance(m.get("text"), list):
-                 txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
-                 result.append({"text": txt, "entities": []})
-         return result
-     return []
-
-
- def evaluate_ner_model(data, model_type):
-     tp, fp, fn = 0, 0, 0
-     POS_TOLERANCE = 1
-
-     for item in data:
-         text = item["text"]
-         # Normalize the gold annotations
-         gold_entities = []
-         for e in item.get("entities", []):
-             if "text" in e and "type" in e:
-                 norm_type = LABEL_MAPPING.get(e["type"], e["type"])
-                 gold_entities.append({
-                     "text": e["text"],
-                     "type": norm_type,
-                     "start": e.get("start", -1),
-                     "end": e.get("end", -1)
-                 })
-
-         # Get model predictions
-         pred_entities, _ = ner(text, model_type)
-
-         # Initialize match flags
-         matched_gold = [False] * len(gold_entities)
-         matched_pred = [False] * len(pred_entities)
-
-         # Match each predicted entity against the gold set
-         for p_idx, p in enumerate(pred_entities):
-             for g_idx, g in enumerate(gold_entities):
-                 if not matched_gold[g_idx] and \
-                         p["text"] == g["text"] and \
-                         p["type"] == g["type"] and \
-                         abs(p["start"] - g["start"]) <= POS_TOLERANCE and \
-                         abs(p["end"] - g["end"]) <= POS_TOLERANCE:
-                     matched_gold[g_idx] = True
-                     matched_pred[p_idx] = True
-                     break
-
-         # Accumulate counts
-         tp += sum(matched_pred)
-         fp += len(pred_entities) - sum(matched_pred)
-         fn += len(gold_entities) - sum(matched_gold)
-
-     # Guard against division by zero
-     precision = tp / (tp + fp) if (tp + fp) > 0 else 0
-     recall = tp / (tp + fn) if (tp + fn) > 0 else 0
-     f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
-
-     return (f"Precision: {precision:.2f}\n"
-             f"Recall: {recall:.2f}\n"
-             f"F1: {f1:.2f}")
-
- def auto_annotate(file, model_type):
-     data = convert_telegram_json_to_eval_format(file.name)
-     for item in data:
-         ents, _ = ner(item["text"], model_type)
-         item["entities"] = ents
-     return json.dumps(data, ensure_ascii=False, indent=2)
-
-
- def save_json(json_text):
-     fname = f"auto_labeled_{int(time.time())}.json"
-     with open(fname, "w", encoding="utf-8") as f:
-         f.write(json_text)
-     return fname
-
-
- # ======================== Gradio UI ========================
- with gr.Blocks(css="""
- .kg-graph {height: 500px; overflow-y: auto;}
- .warning {color: #ff6b6b;}
- """) as demo:
-     gr.Markdown("# 🤖 聊天记录实体关系识别系统")
-
-     with gr.Tab("📄 文本分析"):
-         input_text = gr.Textbox(lines=6, label="输入文本")
-         model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
-         btn = gr.Button("开始分析")
-         out1 = gr.Textbox(label="识别实体")
-         out2 = gr.Textbox(label="识别关系")
-         out3 = gr.Textbox(label="知识图谱")
-         out4 = gr.Textbox(label="耗时")
-         btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
-
-     with gr.Tab("🗂 文件分析"):
-         file_input = gr.File(file_types=[".txt", ".json"])
-         file_btn = gr.Button("上传并分析")
-         fout1, fout2, fout3, fout4 = gr.Textbox(), gr.Textbox(), gr.Textbox(), gr.Textbox()
-         file_btn.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])
-
-     with gr.Tab("📊 模型评估"):
-         eval_file = gr.File(label="上传标注 JSON")
-         eval_model = gr.Radio(["bert", "chatglm"], value="bert")
-         eval_btn = gr.Button("开始评估")
-         eval_output = gr.Textbox(label="评估结果", lines=5)
-         eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m),
-                        [eval_file, eval_model], eval_output)
-
-     with gr.Tab("✏️ 自动标注"):
-         raw_file = gr.File(label="上传 Telegram 原始 JSON")
-         auto_model = gr.Radio(["bert", "chatglm"], value="bert")
-         auto_btn = gr.Button("自动标注")
-         marked_texts = gr.Textbox(label="标注结果", lines=20)
-         download_btn = gr.Button("💾 下载标注文件")
-         auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
-         download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())
-
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
+ import torch
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
+ import gradio as gr
+ import re
+ import os
+ import json
+ import chardet
+ from sklearn.metrics import precision_score, recall_score, f1_score
+ import time
+ # ======================== Database module ========================
+ import pymysql
+ from configparser import ConfigParser
+
+
+ def get_db_connection():
+     config = ConfigParser()
+     config.read('db_config.ini')
+
+     return pymysql.connect(
+         host=config.get('mysql', 'host'),
+         user=config.get('mysql', 'user'),
+         password=config.get('mysql', 'password'),
+         database=config.get('mysql', 'database'),
+         port=config.getint('mysql', 'port', fallback=3306),
+         charset=config.get('mysql', 'charset', fallback='utf8mb4'),
+         cursorclass=pymysql.cursors.DictCursor
+     )
+
+
+ def save_to_db(table, data):
+     conn = None  # keep the finally block safe if connecting itself fails
+     try:
+         conn = get_db_connection()
+         with conn.cursor() as cursor:
+             placeholders = ', '.join(['%s'] * len(data))
+             columns = ', '.join(data.keys())
+             sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders})"
+             cursor.execute(sql, list(data.values()))
+         conn.commit()
+     except Exception as e:
+         print(f"数据库写入失败: {e}")
+     finally:
+         if conn:
+             conn.close()
+
+ # ======================== Model loading ========================
+ NER_MODEL_NAME = "hfl/chinese-roberta-wwm-ext-large"
+ bert_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
+ bert_ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
+ bert_ner_pipeline = pipeline(
+     "ner",
+     model=bert_ner_model,
+     tokenizer=bert_tokenizer,
+     aggregation_strategy="first"
+ )
+
+ LABEL_MAPPING = {
+     "address": "LOC",
+     "company": "ORG",
+     "name": "PER",
+     "organization": "ORG",
+     "position": "TITLE"
+ }
+
+ chatglm_model, chatglm_tokenizer = None, None
+ use_chatglm = False
+ try:
+     chatglm_model_name = "THUDM/chatglm-6b-int4"
+     chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
+     chatglm_model = AutoModel.from_pretrained(
+         chatglm_model_name,
+         trust_remote_code=True,
+         device_map="cpu",
+         torch_dtype=torch.float32
+     ).eval()
+     use_chatglm = True
+     print("✅ 4-bit量化版ChatGLM加载成功")
+ except Exception as e:
+     print(f"❌ ChatGLM加载失败: {e}")
+
+ # ======================== Knowledge graph structure ========================
+ knowledge_graph = {"entities": set(), "relations": set()}
+
+
+ def update_knowledge_graph(entities, relations):
+     # Save entities (also kept in memory so visualize_kg_text() has data to show)
+     for e in entities:
+         if isinstance(e, dict) and 'text' in e and 'type' in e:
+             knowledge_graph["entities"].add((e['text'], e['type']))
+             save_to_db('entities', {
+                 'text': e['text'],
+                 'type': e['type'],
+                 'start_pos': e.get('start', -1),
+                 'end_pos': e.get('end', -1),
+                 'source': 'user_input'
+             })
+
+     # Save relations
+     for r in relations:
+         if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
+             knowledge_graph["relations"].add((r['head'], r['tail'], r['relation']))
+             save_to_db('relations', {
+                 'head_entity': r['head'],
+                 'tail_entity': r['tail'],
+                 'relation_type': r['relation'],
+                 'source_text': ''  # could be linked back to the source text
+             })
+
+
+ def visualize_kg_text():
+     nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
+     edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
+     return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
+
+ # ======================== Named entity recognition (NER) ========================
+ def merge_adjacent_entities(entities):
+     merged = []
+     for entity in entities:
+         if not merged:
+             merged.append(entity)
+             continue
+
+         last = merged[-1]
+         # Merge adjacent entities of the same type
+         if (entity["type"] == last["type"] and
+                 entity["start"] == last["end"] and
+                 entity["text"] not in last["text"]):
+             merged[-1] = {
+                 "text": last["text"] + entity["text"],
+                 "type": last["type"],
+                 "start": last["start"],
+                 "end": entity["end"]
+             }
+         else:
+             merged.append(entity)
+     return merged
+
+
+ def ner(text, model_type="bert"):
+     start_time = time.time()
+     if model_type == "chatglm" and use_chatglm:
+         try:
+             prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
+ 示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
+ 文本:{text}"""
+             response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
+             if isinstance(response, tuple):
+                 response = response[0]
+
+             # More robust JSON parsing
+             try:
+                 json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
+                 entities = json.loads(json_str)
+                 # Keep only entities with all required fields
+                 valid_entities = []
+                 for ent in entities:
+                     if all(k in ent for k in ("text", "type", "start", "end")):
+                         valid_entities.append(ent)
+                 return valid_entities, time.time() - start_time
+             except Exception as e:
+                 print(f"JSON 解析失败: {e}")
+                 return [], time.time() - start_time
+         except Exception as e:
+             print(f"ChatGLM 调用失败:{e}")
+             return [], time.time() - start_time
+
+     # Run the fine-tuned Chinese BERT NER model, chunking long inputs
+     raw_results = []
+     max_len = 510  # stay safely below BERT's 512-token limit
+     text_chunks = [text[i:i + max_len] for i in range(0, len(text), max_len)]
+
+     for idx, chunk in enumerate(text_chunks):
+         chunk_results = bert_ner_pipeline(chunk)
+         # Shift each chunk's entity offsets back into whole-text coordinates
+         for r in chunk_results:
+             r["start"] += idx * max_len
+             r["end"] += idx * max_len
+         raw_results.extend(chunk_results)
+
+     entities = []
+     for r in raw_results:
+         mapped_type = LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
+         entities.append({
+             "text": r['word'].replace(' ', ''),
+             "start": r['start'],
+             "end": r['end'],
+             "type": mapped_type
+         })
+
+     # Merge adjacent fragments
+     entities = merge_adjacent_entities(entities)
+     return entities, time.time() - start_time
+
+
+ # ======================== Relation extraction (RE) ========================
+ def re_extract(entities, text):
+     # Validate arguments
+     if not entities or not text:
+         return []
+
+     # Filter by entity type (adjust to business needs)
+     valid_entity_types = {"PER", "LOC", "ORG", "TITLE"}
+     filtered_entities = [e for e in entities if e.get("type") in valid_entity_types]
+
+     # --------------------- Single-entity case ---------------------
+     if len(filtered_entities) == 1:
+         single_relations = []
+         ent = filtered_entities[0]
+
+         # Rule 1: person-title detection
+         if ent["type"] == "PER":
+             position_keywords = ["CEO", "经理", "总监", "工程师", "教授"]
+             for keyword in position_keywords:
+                 if keyword in text:
+                     single_relations.append({
+                         "head": ent["text"],
+                         "tail": keyword,
+                         "relation": "担任职位"
+                     })
+                     break
+
+         # Rule 2: organization/location detection
+         if ent["type"] in ["ORG", "LOC"]:
+             location_verbs = ["位于", "坐落于", "地处"]
+             for verb in location_verbs:
+                 if verb in text:
+                     match = re.search(fr"{ent['text']}{verb}(.*?)[,。]", text)
+                     if match:
+                         single_relations.append({
+                             "head": ent["text"],
+                             "tail": match.group(1).strip(),
+                             "relation": "位置"
+                         })
+                         break
+         return single_relations
+
+     # --------------------- Multi-entity relation extraction ---------------------
+     relations = []
+
+     # Option 1: relation extraction with ChatGLM
+     if use_chatglm and len(filtered_entities) >= 2:
+         try:
+             entity_list = [e["text"] for e in filtered_entities]
+             prompt = f"""请分析以下文本中的实体关系,严格按照JSON列表格式返回:
+ 文本内容:{text}
+ 候选实体:{entity_list}
+ 要求:
+ 1. 只返回存在明确关系的实体对
+ 2. 关系类型使用:属于、位于、任职于、合作、其他
+ 3. 示例格式:[{{"head":"实体1", "tail":"实体2", "relation":"关系类型"}}]
+ 请直接返回JSON,不要多余内容:"""
+
+             response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.01)
+             if isinstance(response, tuple):
+                 response = response[0]
+
+             # More robust JSON parsing
+             try:
+                 json_str = re.search(r'(\[.*?\])', response, re.DOTALL)
+                 if json_str:
+                     json_str = json_str.group(1)
+                     json_str = re.sub(r'[\u201c\u201d]', '"', json_str)  # replace Chinese curly quotes
+                     json_str = re.sub(r'(?<!,)\n', '', json_str)  # drop newlines except after commas
+                     relations = json.loads(json_str)
+
+                     # Validate extracted relations
+                     valid_relations = []
+                     valid_rel_types = {"属于", "位于", "任职于", "合作", "其他"}
+                     for rel in relations:
+                         if (isinstance(rel, dict) and
+                                 rel.get("head") in entity_list and
+                                 rel.get("tail") in entity_list and
+                                 rel.get("relation") in valid_rel_types):
+                             valid_relations.append(rel)
+                     relations = valid_relations
+             except Exception as e:
+                 print(f"[DEBUG] 关系解析失败: {str(e)}")
+
+         except Exception as e:
+             print(f"ChatGLM关系抽取异常: {str(e)}")
+
+     # Option 2: rule-based fallback (model unavailable or returned nothing)
+     if len(relations) == 0:
+         # Rule 1: A is located in B
+         location_matches = re.finditer(r'([^\s,。]+)(?:位于|坐落于|地处)([^\s,。]+)', text)
+         for match in location_matches:
+             head, tail = match.groups()
+             relations.append({"head": head, "tail": tail, "relation": "位于"})
+
+         # Rule 2: A belongs to B
+         belong_matches = re.finditer(r'([^\s,。]+)(属于|隶属于)([^\s,。]+)', text)
+         for match in belong_matches:
+             head, _, tail = match.groups()
+             relations.append({"head": head, "tail": tail, "relation": "属于"})
+
+         # Rule 3: person-organization relations
+         person_org_pattern = r'([\u4e00-\u9fa5]{2,4})(现任|担任|就职于)([\u4e00-\u9fa5]+?公司|[\u4e00-\u9fa5]+?大学)'
+         for match in re.finditer(person_org_pattern, text):
+             head, _, tail = match.groups()
+             relations.append({"head": head, "tail": tail, "relation": "任职于"})
+
+     # Post-processing: de-duplicate and validate
+     seen = set()
+     final_relations = []
+     for rel in relations:
+         key = (rel["head"], rel["tail"], rel["relation"])
+         if key not in seen:
+             # Both endpoints must be recognized entities
+             head_exists = any(e["text"] == rel["head"] for e in filtered_entities)
+             tail_exists = any(e["text"] == rel["tail"] for e in filtered_entities)
+             if head_exists and tail_exists:
+                 final_relations.append(rel)
+                 seen.add(key)
+
+     return final_relations
+
+
+ # ======================== Main text-analysis pipeline ========================
+ def process_text(text, model_type="bert"):
+     entities, duration = ner(text, model_type)
+     relations = re_extract(entities, text)
+     update_knowledge_graph(entities, relations)
+
+     ent_text = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
+     rel_text = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
+     kg_text = visualize_kg_text()
+     return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"
+
+
+ def process_file(file, model_type="bert"):
+     try:
+         with open(file.name, 'rb') as f:
+             content = f.read()
+
+         if len(content) > 5 * 1024 * 1024:
+             return "❌ 文件太大", "", "", ""
+
+         # Detect the file encoding
+         try:
+             encoding = chardet.detect(content)['encoding'] or 'utf-8'
+             text = content.decode(encoding)
+         except UnicodeDecodeError:
+             # Fall back to common Chinese encodings
+             for enc in ['gb18030', 'utf-16', 'big5']:
+                 try:
+                     text = content.decode(enc)
+                     break
+                 except UnicodeDecodeError:
+                     continue
+             else:
+                 return "❌ 编码解析失败", "", "", ""
+
+         return process_text(text, model_type)
+     except Exception as e:
+         return f"❌ 文件处理错误: {str(e)}", "", "", ""
+
+
+ # ======================== Model evaluation and auto-annotation ========================
+ def convert_telegram_json_to_eval_format(path):
+     with open(path, encoding="utf-8") as f:
+         data = json.load(f)
+     if isinstance(data, dict) and "text" in data:
+         return [{"text": data["text"], "entities": [
+             {"text": data["text"][e["start"]:e["end"]]} for e in data.get("entities", [])
+         ]}]
+     elif isinstance(data, list):
+         return data
+     elif isinstance(data, dict) and "messages" in data:
+         result = []
+         for m in data.get("messages", []):
+             if isinstance(m.get("text"), str):
+                 result.append({"text": m["text"], "entities": []})
+             elif isinstance(m.get("text"), list):
+                 txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
+                 result.append({"text": txt, "entities": []})
+         return result
+     return []
+
+
+ def evaluate_ner_model(data, model_type):
+     tp, fp, fn = 0, 0, 0
+     POS_TOLERANCE = 1
+
+     for item in data:
+         text = item["text"]
+         # Normalize the gold annotations
+         gold_entities = []
+         for e in item.get("entities", []):
+             if "text" in e and "type" in e:
+                 norm_type = LABEL_MAPPING.get(e["type"], e["type"])
+                 gold_entities.append({
+                     "text": e["text"],
+                     "type": norm_type,
+                     "start": e.get("start", -1),
+                     "end": e.get("end", -1)
+                 })
+
+         # Get model predictions
+         pred_entities, _ = ner(text, model_type)
+
+         # Initialize match flags
+         matched_gold = [False] * len(gold_entities)
+         matched_pred = [False] * len(pred_entities)
+
+         # Match each predicted entity against the gold set
+         for p_idx, p in enumerate(pred_entities):
+             for g_idx, g in enumerate(gold_entities):
+                 if not matched_gold[g_idx] and \
+                         p["text"] == g["text"] and \
+                         p["type"] == g["type"] and \
+                         abs(p["start"] - g["start"]) <= POS_TOLERANCE and \
+                         abs(p["end"] - g["end"]) <= POS_TOLERANCE:
+                     matched_gold[g_idx] = True
+                     matched_pred[p_idx] = True
+                     break
+
+         # Accumulate counts
+         tp += sum(matched_pred)
+         fp += len(pred_entities) - sum(matched_pred)
+         fn += len(gold_entities) - sum(matched_gold)
+
+     # Guard against division by zero
+     precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+     recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+     f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+
+     return (f"Precision: {precision:.2f}\n"
+             f"Recall: {recall:.2f}\n"
+             f"F1: {f1:.2f}")
+
+
+ def auto_annotate(file, model_type):
+     data = convert_telegram_json_to_eval_format(file.name)
+     for item in data:
+         ents, _ = ner(item["text"], model_type)
+         item["entities"] = ents
+     return json.dumps(data, ensure_ascii=False, indent=2)
+
+
+ def save_json(json_text):
+     fname = f"auto_labeled_{int(time.time())}.json"
+     with open(fname, "w", encoding="utf-8") as f:
+         f.write(json_text)
+     return fname
+
+
+ # ======================== Dataset import ========================
+ def import_dataset(path="D:/云边智算/暗语识别/filtered_results"):
+     logs = []
+     for filename in os.listdir(path):
+         if filename.endswith('.json'):
+             filepath = os.path.join(path, filename)
+             with open(filepath, 'r', encoding='utf-8') as f:
+                 data = json.load(f)
+             # Reuse the existing analysis pipeline
+             process_text(data['text'])
+             logs.append(f"已处理文件: {filename}")
+     return "\n".join(logs)
+
+
+ # ======================== Gradio UI ========================
+ with gr.Blocks(css="""
+ .kg-graph {height: 500px; overflow-y: auto;}
+ .warning {color: #ff6b6b;}
+ """) as demo:
+     gr.Markdown("# 🤖 聊天记录实体关系识别系统")
+
+     with gr.Tab("📄 文本分析"):
+         input_text = gr.Textbox(lines=6, label="输入文本")
+         model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
+         btn = gr.Button("开始分析")
+         out1 = gr.Textbox(label="识别实体")
+         out2 = gr.Textbox(label="识别关系")
+         out3 = gr.Textbox(label="知识图谱")
+         out4 = gr.Textbox(label="耗时")
+         btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
+
+     with gr.Tab("🗂 文件分析"):
+         file_input = gr.File(file_types=[".txt", ".json"])
+         file_btn = gr.Button("上传并分析")
+         fout1, fout2, fout3, fout4 = gr.Textbox(), gr.Textbox(), gr.Textbox(), gr.Textbox()
+         file_btn.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])
+
+     with gr.Tab("📊 模型评估"):
+         eval_file = gr.File(label="上传标注 JSON")
+         eval_model = gr.Radio(["bert", "chatglm"], value="bert")
+         eval_btn = gr.Button("开始评估")
+         eval_output = gr.Textbox(label="评估结果", lines=5)
+         eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m),
+                        [eval_file, eval_model], eval_output)
+
+     with gr.Tab("✏️ 自动标注"):
+         raw_file = gr.File(label="上传 Telegram 原始 JSON")
+         auto_model = gr.Radio(["bert", "chatglm"], value="bert")
+         auto_btn = gr.Button("自动标注")
+         marked_texts = gr.Textbox(label="标注结果", lines=20)
+         download_btn = gr.Button("💾 下载标注文件")
+         auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
+         download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())
+
+     with gr.Tab("📂 数据管理"):
+         gr.Markdown("### 数据集导入")
+         dataset_path = gr.Textbox(
+             value="D:/云边智算/暗语识别/filtered_results",
+             label="数据集路径"
+         )
+         import_btn = gr.Button("导入数据集到数据库")
+         import_output = gr.Textbox(label="导入日志")
+         # Pass the textbox as an input so the user-edited path reaches import_dataset
+         import_btn.click(fn=import_dataset, inputs=dataset_path, outputs=import_output)
+
  demo.launch(server_name="0.0.0.0", server_port=7860)
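
The new database module reads db_config.ini, but the commit ships neither that file nor the table definitions. Below is a minimal setup sketch, not part of the commit: the section name and keys follow the ConfigParser reads in get_db_connection(), the DDL is inferred from the dict keys passed to save_to_db(), and all values (host, user, password, database, column sizes) are placeholder assumptions.

# setup_db.py - hypothetical helper, not included in this commit
from configparser import ConfigParser

import pymysql

# Write a db_config.ini with the keys get_db_connection() expects.
config = ConfigParser()
config['mysql'] = {
    'host': 'localhost',      # placeholder values; adjust for your environment
    'user': 'kg_user',
    'password': 'change-me',
    'database': 'kg_db',
    'port': '3306',
    'charset': 'utf8mb4',
}
with open('db_config.ini', 'w') as f:
    config.write(f)

# Assumed DDL matching the columns save_to_db() inserts; the real schema may differ.
DDL = [
    """CREATE TABLE IF NOT EXISTS entities (
        id INT AUTO_INCREMENT PRIMARY KEY,
        `text` VARCHAR(255) NOT NULL,
        `type` VARCHAR(32) NOT NULL,
        start_pos INT,
        end_pos INT,
        source VARCHAR(64)
    ) DEFAULT CHARSET=utf8mb4""",
    """CREATE TABLE IF NOT EXISTS relations (
        id INT AUTO_INCREMENT PRIMARY KEY,
        head_entity VARCHAR(255) NOT NULL,
        tail_entity VARCHAR(255) NOT NULL,
        relation_type VARCHAR(64) NOT NULL,
        source_text TEXT
    ) DEFAULT CHARSET=utf8mb4""",
]

conn = pymysql.connect(host='localhost', user='kg_user',
                       password='change-me', database='kg_db')
try:
    with conn.cursor() as cursor:
        for stmt in DDL:
            cursor.execute(stmt)
    conn.commit()
finally:
    conn.close()

With these tables in place, the save_to_db('entities', ...) and save_to_db('relations', ...) calls in update_knowledge_graph() line up column-for-column with the INSERT statements they generate.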
requirements.txt CHANGED
@@ -1,12 +1,14 @@
- gradio==3.50.2
- transformers==4.39.3
- torch>=2.1.0
- networkx>=3.0
- python-dotenv>=1.0.0
- sentencepiece>=0.2.0
- cpm-kernels>=1.0.11
- accelerate>=0.27.0
- scikit-learn>=1.3.0
- chardet>=5.2.0
- pandas>=2.1.0
- pyvis>=0.3.2
+ gradio==3.50.2
+ transformers==4.39.3
+ torch>=2.1.0,<3.0.0
+ networkx>=3.0
+ python-dotenv>=1.0.0
+ sentencepiece>=0.2.0
+ cpm-kernels>=1.0.11
+ accelerate>=0.27.0
+ scikit-learn>=1.3.0
+ chardet>=5.2.0
+ pandas>=2.1.0
+ pyvis>=0.3.2
+ pymysql==1.1.0
+ protobuf==3.20.3  # avoid conflicts with newer transformers
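
For reference, a sketch of the gold-label file the model-evaluation tab (模型评估) accepts: convert_telegram_json_to_eval_format() passes a top-level list straight through, and evaluate_ner_model() reads text plus entities[].text/type/start/end, where start/end are character offsets with end exclusive. The sentence and labels below are illustrative only.

# make_gold_sample.py - illustrative example, not part of this commit
import json

sample = [
    {
        "text": "张三担任北京智算科技有限公司的CEO。",
        "entities": [
            {"text": "张三", "type": "PER", "start": 0, "end": 2},
            {"text": "北京智算科技有限公司", "type": "ORG", "start": 4, "end": 14},
        ],
    }
]

with open("gold_labels.json", "w", encoding="utf-8") as f:
    json.dump(sample, f, ensure_ascii=False, indent=2)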