chen666-666 committed
Commit d18c4dc · verified
1 Parent(s): 8c073c8

Upload app.py

Files changed (1)
  1. app.py +770 -768
app.py CHANGED
@@ -1,769 +1,771 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
3
- import gradio as gr
4
- import re
5
- import os
6
- import json
7
- import chardet
8
- from sklearn.metrics import precision_score, recall_score, f1_score
9
- import time
10
- from functools import lru_cache # 添加这行导入
11
- # ======================== 数据库模块 ========================
12
- from sqlalchemy import create_engine
13
- from sqlalchemy.orm import sessionmaker
14
- from contextlib import contextmanager
15
- import logging
16
- import networkx as nx
17
- from pyvis.network import Network
18
- import pandas as pd
19
-
20
- # 配置日志
21
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
22
-
23
- # 使用SQLAlchemy的连接池来管理数据库连接
24
- DATABASE_URL = "mysql+pymysql://user:password@host/dbname" # 请根据实际情况修改连接字符串
25
-
26
- # 创建引擎(连接池)
27
- engine = create_engine(DATABASE_URL, pool_size=10, max_overflow=20, echo=True)
28
-
29
- # 创建session类
30
- Session = sessionmaker(bind=engine)
31
-
32
- @contextmanager
33
- def get_db_connection():
34
- """
35
- 使用上下文管理器获取数据库连接
36
- """
37
- session = None
38
- try:
39
- session = Session() # 从连接池中获取一个连接
40
- logging.info("✅ 数据库连接已建立")
41
- yield session # 使用session进行数据库操作
42
- except Exception as e:
43
- logging.error(f"❌ 数据库操作时发生错误: {e}")
44
- if session:
45
- session.rollback() # 回滚事务
46
- finally:
47
- if session:
48
- try:
49
- session.commit() # 提交事务
50
- logging.info("✅ 数据库事务已提交")
51
- except Exception as e:
52
- logging.error(f"❌ 提交事务时发生错误: {e}")
53
- finally:
54
- session.close() # 关闭会话,释放连接
55
- logging.info("✅ 数据库连接已关闭")
56
-
57
- def save_to_db(table, data):
58
- """
59
- 将数据保存到数据库
60
- :param table: 表名
61
- :param data: 数据字典
62
- """
63
- try:
64
- valid_tables = ["entities", "relations"] # 只允许保存到这些表
65
- if table not in valid_tables:
66
- raise ValueError(f"Invalid table: {table}")
67
-
68
- with get_db_connection() as conn:
69
- if conn:
70
- # 这里的操作假设使用了ORM模型来处理插入,实际根据你数据库的表结构来调整
71
- table_model = get_table_model(table) # 假设你有一个方法来根据表名获得ORM模型
72
- new_record = table_model(**data)
73
- conn.add(new_record)
74
- conn.commit() # 提交事务
75
- except Exception as e:
76
- logging.error(f"❌ 保存数据时发生错误: {e}")
77
- return False
78
- return True
79
-
80
- def get_table_model(table_name):
81
- """
82
- 根据表名获取ORM模型(这里假设你有一个映射到数据库表的模型)
83
- :param table_name: 表名
84
- :return: 对应的ORM模型
85
- """
86
- if table_name == "entities":
87
- from models import Entity # 假设你已经定义了ORM模型
88
- return Entity
89
- elif table_name == "relations":
90
- from models import Relation # 假设你已经定义了ORM模型
91
- return Relation
92
- else:
93
- raise ValueError(f"Unknown table: {table_name}")
94
-
95
-
96
- # ======================== 模型加载 ========================
97
- NER_MODEL_NAME = "uer/roberta-base-finetuned-cluener2020-chinese"
98
-
99
- @lru_cache(maxsize=1)
100
- def get_ner_pipeline():
101
- tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
102
- model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
103
- return pipeline(
104
- "ner",
105
- model=model,
106
- tokenizer=tokenizer,
107
- aggregation_strategy="first"
108
- )
109
-
110
- @lru_cache(maxsize=1)
111
- def get_re_pipeline():
112
- return pipeline(
113
- "text2text-generation",
114
- model=NER_MODEL_NAME,
115
- tokenizer=NER_MODEL_NAME,
116
- max_length=512,
117
- device=0 if torch.cuda.is_available() else -1
118
- )
119
-
120
-
121
- # chatglm_model, chatglm_tokenizer = None, None
122
- # use_chatglm = False
123
- # try:
124
- # chatglm_model_name = "THUDM/chatglm-6b-int4"
125
- # chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
126
- # chatglm_model = AutoModel.from_pretrained(
127
- # chatglm_model_name,
128
- # trust_remote_code=True,
129
- # device_map="cpu",
130
- # torch_dtype=torch.float32
131
- # ).eval()
132
- # use_chatglm = True
133
- # print("✅ 4-bit量化版ChatGLM加载成功")
134
- # except Exception as e:
135
- # print(f"❌ ChatGLM加载失败: {e}")
136
-
137
- # ======================== 知识图谱结构 ========================
138
- knowledge_graph = {"entities": set(), "relations": set()}
139
-
140
-
141
- def update_knowledge_graph(entities, relations):
142
- # 保存实体
143
- for e in entities:
144
- if isinstance(e, dict) and 'text' in e and 'type' in e:
145
- save_to_db('entities', {
146
- 'text': e['text'],
147
- 'type': e['type'],
148
- 'start_pos': e.get('start', -1),
149
- 'end_pos': e.get('end', -1),
150
- 'source': 'user_input'
151
- })
152
-
153
- # 保存关系
154
- for r in relations:
155
- if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
156
- save_to_db('relations', {
157
- 'head_entity': r['head'],
158
- 'tail_entity': r['tail'],
159
- 'relation_type': r['relation'],
160
- 'source_text': '' # 可添加原文关联
161
- })
162
-
163
-
164
- def visualize_kg_text():
165
- nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
166
- edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
167
- return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
168
-
169
- def visualize_kg_interactive(entities, relations):
170
- """
171
- 生成交互式的知识图谱可视化
172
- """
173
- # 创建一个新的网络图
174
- net = Network(height="500px", width="100%", bgcolor="#ffffff", font_color="black")
175
-
176
- # 添加节点
177
- entity_colors = {
178
- 'PER': '#FF6B6B', # 人物-红色
179
- 'ORG': '#4ECDC4', # 组织-青色
180
- 'LOC': '#45B7D1', # 地点-蓝色
181
- 'TIME': '#96CEB4', # 时间-绿色
182
- 'MISC': '#D4A5A5' # 其他-灰色
183
- }
184
-
185
- # 添加实体节点
186
- for entity in entities:
187
- node_color = entity_colors.get(entity['type'], '#D3D3D3')
188
- net.add_node(entity['text'],
189
- label=f"{entity['text']}\n({entity['type']})",
190
- color=node_color,
191
- title=f"类型: {entity['type']}")
192
-
193
- # 添加关系边
194
- for relation in relations:
195
- net.add_edge(relation['head'],
196
- relation['tail'],
197
- label=relation['relation'],
198
- arrows='to')
199
-
200
- # 设置物理布局
201
- net.set_options('''
202
- var options = {
203
- "physics": {
204
- "forceAtlas2Based": {
205
- "gravitationalConstant": -50,
206
- "centralGravity": 0.01,
207
- "springLength": 100,
208
- "springConstant": 0.08
209
- },
210
- "maxVelocity": 50,
211
- "solver": "forceAtlas2Based",
212
- "timestep": 0.35,
213
- "stabilization": {"iterations": 150}
214
- }
215
- }
216
- ''')
217
-
218
- # 生成HTML文件
219
- html_path = "knowledge_graph.html"
220
- net.save_graph(html_path)
221
- return html_path
222
-
223
- # ======================== 实体识别(NER) ========================
224
- def merge_adjacent_entities(entities):
225
- if not entities:
226
- return entities
227
-
228
- merged = [entities[0]]
229
- for entity in entities[1:]:
230
- last = merged[-1]
231
- # 合并相邻的同类型实体
232
- if (entity["type"] == last["type"] and
233
- entity["start"] == last["end"]):
234
- last["text"] += entity["text"]
235
- last["end"] = entity["end"]
236
- else:
237
- merged.append(entity)
238
-
239
- return merged
240
-
241
-
242
- def ner(text, model_type="bert"):
243
- start_time = time.time()
244
-
245
- # 如果使用的是 ChatGLM 模型,执行 ChatGLM 的NER
246
- if model_type == "chatglm" and use_chatglm:
247
- try:
248
- prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
249
- 示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
250
- 文本:{text}"""
251
- response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
252
- if isinstance(response, tuple):
253
- response = response[0]
254
-
255
- try:
256
- json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
257
- entities = json.loads(json_str)
258
- valid_entities = [ent for ent in entities if all(k in ent for k in ("text", "type", "start", "end"))]
259
- return valid_entities, time.time() - start_time
260
- except Exception as e:
261
- print(f"JSON解析失败: {e}")
262
- return [], time.time() - start_time
263
- except Exception as e:
264
- print(f"ChatGLM调用失败: {e}")
265
- return [], time.time() - start_time
266
-
267
- # 使用BERT NER
268
- text_chunks = [text[i:i + 510] for i in range(0, len(text), 510)] # 安全分段
269
- raw_results = []
270
-
271
- # 获取NER pipeline
272
- ner_pipeline = get_ner_pipeline() # 使用缓存的pipeline
273
-
274
- for idx, chunk in enumerate(text_chunks):
275
- chunk_results = ner_pipeline(chunk) # 使用获取的pipeline
276
- for r in chunk_results:
277
- r["start"] += idx * 510
278
- r["end"] += idx * 510
279
- raw_results.extend(chunk_results)
280
-
281
- entities = [{
282
- "text": r['word'].replace(' ', ''),
283
- "start": r['start'],
284
- "end": r['end'],
285
- "type": LABEL_MAPPING.get(r.get('entity_group') or r.get('entity'), r.get('entity_group') or r.get('entity'))
286
- } for r in raw_results]
287
-
288
- entities = merge_adjacent_entities(entities)
289
- return entities, time.time() - start_time
290
-
291
-
292
- # ------------------ 实体类型标准化 ------------------
293
- LABEL_MAPPING = {
294
- "address": "LOC",
295
- "company": "ORG",
296
- "name": "PER",
297
- "organization": "ORG",
298
- "position": "TITLE",
299
- "government": "ORG",
300
- "scene": "LOC",
301
- "book": "WORK",
302
- "movie": "WORK",
303
- "game": "WORK"
304
- }
305
-
306
- # 提取实体
307
- entities, processing_time = ner("Google in New York met Alice")
308
-
309
- # 标准化实体类型
310
- for e in entities:
311
- e["type"] = LABEL_MAPPING.get(e.get("type"), e.get("type"))
312
-
313
- # 打印标准化后的实体
314
- print(f"[DEBUG] 标准化后实体列表: {[{'text': e['text'], 'type': e['type']} for e in entities]}")
315
-
316
- # 打印处理时间
317
- print(f"处理时间: {processing_time:.2f}秒")
318
-
319
-
320
- # ======================== 关系抽取(RE) ========================
321
- @lru_cache(maxsize=1)
322
- def get_re_pipeline():
323
- tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
324
- model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
325
- return pipeline(
326
- "ner", # 使用NER pipeline
327
- model=model,
328
- tokenizer=tokenizer,
329
- aggregation_strategy="first"
330
- )
331
-
332
- def re_extract(entities, text, use_bert_model=True):
333
- if not entities or not text:
334
- return [], 0
335
-
336
- start_time = time.time()
337
- try:
338
- # 使用规则匹配关系
339
- relations = []
340
-
341
- # 定义关系关键词和对应的实体类型约束
342
- relation_rules = {
343
- "位于": {
344
- "keywords": ["位于", "在", "坐落于"],
345
- "valid_types": {
346
- "head": ["ORG", "PER", "LOC"],
347
- "tail": ["LOC"]
348
- }
349
- },
350
- "属于": {
351
- "keywords": ["属于", "是", "为"],
352
- "valid_types": {
353
- "head": ["ORG", "PER"],
354
- "tail": ["ORG", "LOC"]
355
- }
356
- },
357
- "任职于": {
358
- "keywords": ["任职于", "就职于", "工作于"],
359
- "valid_types": {
360
- "head": ["PER"],
361
- "tail": ["ORG"]
362
- }
363
- }
364
- }
365
-
366
- # 预处理实体,去除重复和部分匹配
367
- processed_entities = []
368
- for e in entities:
369
- # 检查是否与已有实体重叠
370
- is_subset = False
371
- for pe in processed_entities:
372
- if e["text"] in pe["text"] and e["text"] != pe["text"]:
373
- is_subset = True
374
- break
375
- if not is_subset:
376
- processed_entities.append(e)
377
-
378
- # 遍历文本中的每个句子
379
- sentences = re.split('[。!?.!?]', text)
380
- for sentence in sentences:
381
- if not sentence.strip():
382
- continue
383
-
384
- # 获取当前句子中的实体
385
- sentence_entities = [e for e in processed_entities if e["text"] in sentence]
386
-
387
- # 检查每个关系类型
388
- for rel_type, rule in relation_rules.items():
389
- for keyword in rule["keywords"]:
390
- if keyword in sentence:
391
- # 在句子中查找符合类型约束的实体对
392
- for i, ent1 in enumerate(sentence_entities):
393
- for j, ent2 in enumerate(sentence_entities):
394
- if i != j: # 避免自循环
395
- # 检查实体类型是否符合规则
396
- if (ent1["type"] in rule["valid_types"]["head"] and
397
- ent2["type"] in rule["valid_types"]["tail"]):
398
- # 检查实体在句子中的位置关系
399
- if sentence.find(ent1["text"]) < sentence.find(ent2["text"]):
400
- relations.append({
401
- "head": ent1["text"],
402
- "tail": ent2["text"],
403
- "relation": rel_type
404
- })
405
-
406
- # 去重
407
- unique_relations = []
408
- seen = set()
409
- for rel in relations:
410
- rel_key = (rel["head"], rel["tail"], rel["relation"])
411
- if rel_key not in seen:
412
- seen.add(rel_key)
413
- unique_relations.append(rel)
414
-
415
- return unique_relations, time.time() - start_time
416
-
417
- except Exception as e:
418
- logging.error(f"关系抽取失败: {e}")
419
- return [], time.time() - start_time
420
-
421
-
422
- # ======================== 文本分析主流程 ========================
423
- def create_knowledge_graph(entities, relations):
424
- """
425
- 创建交互式网络图形式的知识图谱
426
- """
427
- # 创建一个新的网络图
428
- net = Network(height="600px", width="100%", bgcolor="#ffffff", font_color="black", directed=True)
429
-
430
- # 设置实体类型的颜色映射
431
- entity_colors = {
432
- 'PER': '#FF6B6B', # 人物-红色
433
- 'ORG': '#4ECDC4', # 组织-青色
434
- 'LOC': '#45B7D1', # 地点-蓝色
435
- 'TIME': '#96CEB4', # 时间-绿色
436
- 'TITLE': '#D4A5A5' # 职位-粉色
437
- }
438
-
439
- # 添加实体节点
440
- added_nodes = set()
441
- for entity in entities:
442
- if entity['text'] not in added_nodes:
443
- node_color = entity_colors.get(entity['type'], '#D3D3D3')
444
- net.add_node(
445
- entity['text'],
446
- label=entity['text'],
447
- title=f"类型: {entity['type']}",
448
- color=node_color,
449
- size=20,
450
- font={'size': 16}
451
- )
452
- added_nodes.add(entity['text'])
453
-
454
- # 添加关系边
455
- for relation in relations:
456
- if relation['head'] in added_nodes and relation['tail'] in added_nodes:
457
- net.add_edge(
458
- relation['head'],
459
- relation['tail'],
460
- label=relation['relation'],
461
- title=relation['relation'],
462
- arrows={'to': {'enabled': True, 'type': 'arrow'}},
463
- color={'color': '#666666'},
464
- font={'size': 12}
465
- )
466
-
467
- # 设置物理布局参数
468
- net.set_options('''
469
- {
470
- "nodes": {
471
- "shape": "dot",
472
- "shadow": true
473
- },
474
- "edges": {
475
- "smooth": {
476
- "type": "continuous",
477
- "forceDirection": "none"
478
- },
479
- "shadow": true
480
- },
481
- "physics": {
482
- "barnesHut": {
483
- "gravitationalConstant": -2000,
484
- "centralGravity": 0.3,
485
- "springLength": 200,
486
- "springConstant": 0.04,
487
- "damping": 0.09
488
- },
489
- "minVelocity": 0.75
490
- },
491
- "interaction": {
492
- "hover": true,
493
- "navigationButtons": true,
494
- "keyboard": true
495
- }
496
- }
497
- ''')
498
-
499
- # 生成HTML文件
500
- try:
501
- # 创建临时目录(如果不存在)
502
- temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp")
503
- os.makedirs(temp_dir, exist_ok=True)
504
-
505
- # 生成唯一的文件名
506
- output_path = os.path.join(temp_dir, f"kg_{int(time.time())}.html")
507
-
508
- # 保存图谱
509
- net.save_graph(output_path)
510
-
511
- # 读取生成的HTML文件内容
512
- with open(output_path, 'r', encoding='utf-8') as f:
513
- html_content = f.read()
514
-
515
- # 删除临时文件
516
- os.remove(output_path)
517
-
518
- # 修改HTML内容以适应Gradio界面
519
- html_content = html_content.replace('height: 600px', 'height: 700px')
520
-
521
- # 添加图例
522
- legend_html = f"""
523
- <div style="margin-bottom: 10px; padding: 10px; background-color: #f8f9fa; border-radius: 5px;">
524
- <div style="font-weight: bold; margin-bottom: 5px;">图例说明:</div>
525
- <div style="display: flex; gap: 15px; flex-wrap: wrap;">
526
- <div style="display: flex; align-items: center; gap: 5px;">
527
- <div style="width: 15px; height: 15px; background: {entity_colors['PER']}; border-radius: 50%;"></div>
528
- <span>人物 (PER)</span>
529
- </div>
530
- <div style="display: flex; align-items: center; gap: 5px;">
531
- <div style="width: 15px; height: 15px; background: {entity_colors['ORG']}; border-radius: 50%;"></div>
532
- <span>组织 (ORG)</span>
533
- </div>
534
- <div style="display: flex; align-items: center; gap: 5px;">
535
- <div style="width: 15px; height: 15px; background: {entity_colors['LOC']}; border-radius: 50%;"></div>
536
- <span>地点 (LOC)</span>
537
- </div>
538
- <div style="display: flex; align-items: center; gap: 5px;">
539
- <div style="width: 15px; height: 15px; background: {entity_colors['TITLE']}; border-radius: 50%;"></div>
540
- <span>职位 (TITLE)</span>
541
- </div>
542
- </div>
543
- </div>
544
- """
545
-
546
- # 将图例添加到HTML内容中
547
- html_content = legend_html + html_content
548
-
549
- return html_content
550
-
551
- except Exception as e:
552
- logging.error(f"生成知识图谱失败: {str(e)}")
553
- return f"<div class='error'>生成知识图谱失败: {str(e)}</div>"
554
-
555
- def process_text(text, model_type="bert"):
556
- """
557
- 处理文本,进行实体识别和关系抽取
558
- """
559
- start_time = time.time()
560
-
561
- # 实体识别
562
- entities, ner_duration = ner(text, model_type)
563
- if not entities:
564
- return "", "", "", f"{time.time() - start_time:.2f} 秒"
565
-
566
- # 关系抽取
567
- relations, re_duration = re_extract(entities, text)
568
-
569
- # 生成文本格式的实体和关系描述
570
- ent_text = "📌 实体:\n" + "\n".join([f"{e['text']} ({e['type']})" for e in entities])
571
- rel_text = "\n\n📎 关系:\n" + "\n".join([f"{r['head']} --[{r['relation']}]--> {r['tail']}" for r in relations])
572
-
573
- # 生成知识图谱
574
- kg_text = create_knowledge_graph(entities, relations)
575
-
576
- total_duration = time.time() - start_time
577
- return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
578
-
579
-
580
- def process_file(file, model_type="bert"):
581
- try:
582
- with open(file.name, 'rb') as f:
583
- content = f.read()
584
-
585
- if len(content) > 5 * 1024 * 1024:
586
- return "❌ 文件太大", "", "", ""
587
-
588
- # 检测编码
589
- try:
590
- encoding = chardet.detect(content)['encoding'] or 'utf-8'
591
- text = content.decode(encoding)
592
- except UnicodeDecodeError:
593
- # 尝试常见中文编码
594
- for enc in ['gb18030', 'utf-16', 'big5']:
595
- try:
596
- text = content.decode(enc)
597
- break
598
- except:
599
- continue
600
- else:
601
- return "❌ 编码解析失败", "", "", ""
602
-
603
- # 直接调用process_text处理文本
604
- return process_text(text, model_type)
605
-
606
- except Exception as e:
607
- logging.error(f"文件处理错误: {str(e)}")
608
- return f"❌ 文件处理错误: {str(e)}", "", "", ""
609
-
610
-
611
-
612
- # ======================== 模型评估与自动标注 ========================
613
- def convert_telegram_json_to_eval_format(path):
614
- with open(path, encoding="utf-8") as f:
615
- data = json.load(f)
616
- if isinstance(data, dict) and "text" in data:
617
- return [{"text": data["text"], "entities": [
618
- {"text": data["text"][e["start"]:e["end"]]} for e in data.get("entities", [])
619
- ]}]
620
- elif isinstance(data, list):
621
- return data
622
- elif isinstance(data, dict) and "messages" in data:
623
- result = []
624
- for m in data.get("messages", []):
625
- if isinstance(m.get("text"), str):
626
- result.append({"text": m["text"], "entities": []})
627
- elif isinstance(m.get("text"), list):
628
- txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
629
- result.append({"text": txt, "entities": []})
630
- return result
631
- return []
632
-
633
-
634
- def evaluate_ner_model(data, model_type):
635
- tp, fp, fn = 0, 0, 0
636
- POS_TOLERANCE = 1
637
-
638
- for item in data:
639
- text = item["text"]
640
- # 处理标注数据
641
- gold_entities = []
642
- for e in item.get("entities", []):
643
- if "text" in e and "type" in e:
644
- norm_type = LABEL_MAPPING.get(e["type"], e["type"])
645
- gold_entities.append({
646
- "text": e["text"],
647
- "type": norm_type,
648
- "start": e.get("start", -1),
649
- "end": e.get("end", -1)
650
- })
651
-
652
- # 获取预测结果
653
- pred_entities, _ = ner(text, model_type)
654
-
655
- # 初始化匹配状态
656
- matched_gold = [False] * len(gold_entities)
657
- matched_pred = [False] * len(pred_entities)
658
-
659
- # 遍历预测实体寻找匹配
660
- for p_idx, p in enumerate(pred_entities):
661
- for g_idx, g in enumerate(gold_entities):
662
- if not matched_gold[g_idx] and \
663
- p["text"] == g["text"] and \
664
- p["type"] == g["type"] and \
665
- abs(p["start"] - g["start"]) <= POS_TOLERANCE and \
666
- abs(p["end"] - g["end"]) <= POS_TOLERANCE:
667
- matched_gold[g_idx] = True
668
- matched_pred[p_idx] = True
669
- break
670
-
671
- # 统计指标
672
- tp += sum(matched_pred)
673
- fp += len(pred_entities) - sum(matched_pred)
674
- fn += len(gold_entities) - sum(matched_gold)
675
-
676
- # 处理除零情况
677
- precision = tp / (tp + fp) if (tp + fp) > 0 else 0
678
- recall = tp / (tp + fn) if (tp + fn) > 0 else 0
679
- f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
680
-
681
- return (f"Precision: {precision:.2f}\n"
682
- f"Recall: {recall:.2f}\n"
683
- f"F1: {f1:.2f}")
684
-
685
-
686
- def auto_annotate(file, model_type):
687
- data = convert_telegram_json_to_eval_format(file.name)
688
- for item in data:
689
- ents, _ = ner(item["text"], model_type)
690
- item["entities"] = ents
691
- return json.dumps(data, ensure_ascii=False, indent=2)
692
-
693
-
694
- def save_json(json_text):
695
- fname = f"auto_labeled_{int(time.time())}.json"
696
- with open(fname, "w", encoding="utf-8") as f:
697
- f.write(json_text)
698
- return fname
699
-
700
-
701
- # ======================== 数据集导入 ========================
702
- def import_dataset(path="D:/云边智算/暗语识别/filtered_results"):
703
- import os
704
- import json
705
-
706
- for filename in os.listdir(path):
707
- if filename.endswith('.json'):
708
- filepath = os.path.join(path, filename)
709
- with open(filepath, 'r', encoding='utf-8') as f:
710
- data = json.load(f)
711
- # 调用现有处理流程
712
- process_text(data['text'])
713
- print(f"已处理文件: {filename}")
714
-
715
-
716
- # ======================== Gradio 界面 ========================
717
- with gr.Blocks(css="""
718
- .kg-graph {height: 500px; overflow-y: auto;}
719
- .warning {color: #ff6b6b;}
720
- """) as demo:
721
- gr.Markdown("# 🤖 聊天记录实体关系识别系统")
722
-
723
- with gr.Tab("📄 文本分析"):
724
- input_text = gr.Textbox(lines=6, label="输入文本")
725
- model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
726
- btn = gr.Button("开始分析")
727
- out1 = gr.Textbox(label="识别实体")
728
- out2 = gr.Textbox(label="识别关系")
729
- out3 = gr.HTML(label="知识图谱")
730
- out4 = gr.Textbox(label="耗时")
731
- btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
732
-
733
- with gr.Tab("🗂 文件分析"):
734
- file_input = gr.File(file_types=[".txt", ".json"])
735
- file_btn = gr.Button("上传并分析")
736
- fout1 = gr.Textbox()
737
- fout2 = gr.Textbox()
738
- fout3 = gr.HTML()
739
- fout4 = gr.Textbox()
740
- file_btn.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])
741
-
742
- with gr.Tab("📊 模型评估"):
743
- eval_file = gr.File(label="上传标注 JSON")
744
- eval_model = gr.Radio(["bert", "chatglm"], value="bert")
745
- eval_btn = gr.Button("开始评估")
746
- eval_output = gr.Textbox(label="评估结果", lines=5)
747
- eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m),
748
- [eval_file, eval_model], eval_output)
749
-
750
- with gr.Tab("✏️ 自动标注"):
751
- raw_file = gr.File(label="上传 Telegram 原始 JSON")
752
- auto_model = gr.Radio(["bert", "chatglm"], value="bert")
753
- auto_btn = gr.Button("自动标注")
754
- marked_texts = gr.Textbox(label="标注结果", lines=20)
755
- download_btn = gr.Button("💾 下载标注文件")
756
- auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
757
- download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())
758
-
759
- with gr.Tab("📂 数据管理"):
760
- gr.Markdown("### 数据集导入")
761
- dataset_path = gr.Textbox(
762
- value="D:/云边智算/暗语识别/filtered_results",
763
- label="数据集路径"
764
- )
765
- import_btn = gr.Button("导入数据集到数据库")
766
- import_output = gr.Textbox(label="导入日志")
767
- import_btn.click(fn=lambda: import_dataset(dataset_path.value), outputs=import_output)
768
-
 
 
769
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
3
+ import gradio as gr
4
+ import re
5
+ import os
6
+ import json
7
+ import chardet
8
+ from sklearn.metrics import precision_score, recall_score, f1_score
9
+ import time
10
+ from functools import lru_cache # 添加这行导入
11
+ # ======================== 数据库模块 ========================
12
+ from sqlalchemy import create_engine
13
+ from sqlalchemy.orm import sessionmaker
14
+ from contextlib import contextmanager
15
+ import logging
16
+ import networkx as nx
17
+ from pyvis.network import Network
18
+ import pandas as pd
19
+
20
+ # 配置日志
21
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
22
+
23
+ # 使用SQLAlchemy的连接池来管理数据库连接
24
+ DATABASE_URL = "mysql+pymysql://user:password@host/dbname" # 请根据实际情况修改连接字符串
25
+
26
+ # 创建引擎(连接池)
27
+ engine = create_engine(DATABASE_URL, pool_size=10, max_overflow=20, echo=True)
28
+
29
+ # 创建session类
30
+ Session = sessionmaker(bind=engine)
31
+
32
+ @contextmanager
33
+ def get_db_connection():
34
+ """
35
+ 使用上下文管理器获取数据库连接
36
+ """
37
+ session = None
38
+ try:
39
+ session = Session() # 从连接池中获取一个连接
40
+ logging.info("✅ 数据库连接已建立")
41
+ yield session # 使用session进行数据库操作
42
+ except Exception as e:
43
+ logging.error(f"❌ 数据库操作时发生错误: {e}")
44
+ if session:
45
+ session.rollback() # 回滚事务
46
+ finally:
47
+ if session:
48
+ try:
49
+ session.commit() # 提交事务
50
+ logging.info("✅ 数据库事务已提交")
51
+ except Exception as e:
52
+ logging.error(f"❌ 提交事务时发生错误: {e}")
53
+ finally:
54
+ session.close() # 关闭会话,释放连接
55
+ logging.info("✅ 数据库连接已关闭")
56
+
57
+ def save_to_db(table, data):
58
+ """
59
+ 将数据保存到数据库
60
+ :param table: 表名
61
+ :param data: 数据字典
62
+ """
63
+ try:
64
+ valid_tables = ["entities", "relations"] # 只允许保存到这些表
65
+ if table not in valid_tables:
66
+ raise ValueError(f"Invalid table: {table}")
67
+
68
+ with get_db_connection() as conn:
69
+ if conn:
70
+ # 这里的操作假设使用了ORM模型来处理插入,实际根据你数据库的表结构来调整
71
+ table_model = get_table_model(table) # 假设你有一个方法来根据表名获得ORM模型
72
+ new_record = table_model(**data)
73
+ conn.add(new_record)
74
+ conn.commit() # 提交事务
75
+ except Exception as e:
76
+ logging.error(f"❌ 保存数据时发生错误: {e}")
77
+ return False
78
+ return True
79
+
80
+ def get_table_model(table_name):
81
+ """
82
+ 根据表名获取ORM模型(这里假设你有一个映射到数据库表的模型)
83
+ :param table_name: 表名
84
+ :return: 对应的ORM模型
85
+ """
86
+ if table_name == "entities":
87
+ from models import Entity # 假设你已经定义了ORM模型
88
+ return Entity
89
+ elif table_name == "relations":
90
+ from models import Relation # 假设你已经定义了ORM模型
91
+ return Relation
92
+ else:
93
+ raise ValueError(f"Unknown table: {table_name}")
94
+
95
+
96
+ # ======================== 模型加载 ========================
97
+ NER_MODEL_NAME = "uer/roberta-base-finetuned-cluener2020-chinese"
98
+
99
+ @lru_cache(maxsize=1)
100
+ def get_ner_pipeline():
101
+ tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
102
+ model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
103
+ return pipeline(
104
+ "ner",
105
+ model=model,
106
+ tokenizer=tokenizer,
107
+ aggregation_strategy="first"
108
+ )
109
+
110
+ @lru_cache(maxsize=1)
111
+ def get_re_pipeline():
112
+ return pipeline(
113
+ "text2text-generation",
114
+ model=NER_MODEL_NAME,
115
+ tokenizer=NER_MODEL_NAME,
116
+ max_length=512,
117
+ device=0 if torch.cuda.is_available() else -1
118
+ )
119
+
120
+
121
+ # chatglm_model, chatglm_tokenizer = None, None
122
+ # use_chatglm = False
123
+ # try:
124
+ # chatglm_model_name = "THUDM/chatglm-6b-int4"
125
+ # chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
126
+ # chatglm_model = AutoModel.from_pretrained(
127
+ # chatglm_model_name,
128
+ # trust_remote_code=True,
129
+ # device_map="cpu",
130
+ # torch_dtype=torch.float32
131
+ # ).eval()
132
+ # use_chatglm = True
133
+ # print("✅ 4-bit量化版ChatGLM加载成功")
134
+ # except Exception as e:
135
+ # print(f"❌ ChatGLM加载失败: {e}")
136
+
137
+ # ======================== 知识图谱结构 ========================
138
+ knowledge_graph = {"entities": set(), "relations": set()}
139
+
140
+
141
+ def update_knowledge_graph(entities, relations):
142
+ # 保存实体
143
+ for e in entities:
144
+ if isinstance(e, dict) and 'text' in e and 'type' in e:
145
+ save_to_db('entities', {
146
+ 'text': e['text'],
147
+ 'type': e['type'],
148
+ 'start_pos': e.get('start', -1),
149
+ 'end_pos': e.get('end', -1),
150
+ 'source': 'user_input'
151
+ })
152
+
153
+ # 保存关系
154
+ for r in relations:
155
+ if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
156
+ save_to_db('relations', {
157
+ 'head_entity': r['head'],
158
+ 'tail_entity': r['tail'],
159
+ 'relation_type': r['relation'],
160
+ 'source_text': '' # 可添加原文关联
161
+ })
162
+
163
+
164
+ def visualize_kg_text():
165
+ nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
166
+ edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
167
+ return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
168
+
169
+ def visualize_kg_interactive(entities, relations):
170
+ """
171
+ 生成交互式的知识图谱可视化
172
+ """
173
+ # 创建一个新的网络图
174
+ net = Network(height="500px", width="100%", bgcolor="#ffffff", font_color="black")
175
+
176
+ # 添加节点
177
+ entity_colors = {
178
+ 'PER': '#FF6B6B', # 人物-红色
179
+ 'ORG': '#4ECDC4', # 组织-青色
180
+ 'LOC': '#45B7D1', # 地点-蓝色
181
+ 'TIME': '#96CEB4', # 时间-绿色
182
+ 'MISC': '#D4A5A5' # 其他-灰色
183
+ }
184
+
185
+ # 添加实体节点
186
+ for entity in entities:
187
+ node_color = entity_colors.get(entity['type'], '#D3D3D3')
188
+ net.add_node(entity['text'],
189
+ label=f"{entity['text']}\n({entity['type']})",
190
+ color=node_color,
191
+ title=f"类型: {entity['type']}")
192
+
193
+ # 添加关系边
194
+ for relation in relations:
195
+ net.add_edge(relation['head'],
196
+ relation['tail'],
197
+ label=relation['relation'],
198
+ arrows='to')
199
+
200
+ # 设置物理布局
201
+ net.set_options('''
202
+ var options = {
203
+ "physics": {
204
+ "forceAtlas2Based": {
205
+ "gravitationalConstant": -50,
206
+ "centralGravity": 0.01,
207
+ "springLength": 100,
208
+ "springConstant": 0.08
209
+ },
210
+ "maxVelocity": 50,
211
+ "solver": "forceAtlas2Based",
212
+ "timestep": 0.35,
213
+ "stabilization": {"iterations": 150}
214
+ }
215
+ }
216
+ ''')
217
+
218
+ # 生成HTML文件
219
+ html_path = "knowledge_graph.html"
220
+ net.save_graph(html_path)
221
+ return html_path
222
+
223
+ # ======================== 实体识别(NER) ========================
224
+ def merge_adjacent_entities(entities):
225
+ if not entities:
226
+ return entities
227
+
228
+ merged = [entities[0]]
229
+ for entity in entities[1:]:
230
+ last = merged[-1]
231
+ # 合并相邻的同类型实体
232
+ if (entity["type"] == last["type"] and
233
+ entity["start"] == last["end"]):
234
+ last["text"] += entity["text"]
235
+ last["end"] = entity["end"]
236
+ else:
237
+ merged.append(entity)
238
+
239
+ return merged
240
+
241
+
242
+ def ner(text, model_type="bert"):
243
+ start_time = time.time()
244
+
245
+ # 如果使用的是 ChatGLM 模型,执行 ChatGLM 的NER
246
+ if model_type == "chatglm" and use_chatglm:
247
+ try:
248
+ prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
249
+ 示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
250
+ 文本:{text}"""
251
+ response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
252
+ if isinstance(response, tuple):
253
+ response = response[0]
254
+
255
+ try:
256
+ json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
257
+ entities = json.loads(json_str)
258
+ valid_entities = [ent for ent in entities if all(k in ent for k in ("text", "type", "start", "end"))]
259
+ return valid_entities, time.time() - start_time
260
+ except Exception as e:
261
+ print(f"JSON解析失败: {e}")
262
+ return [], time.time() - start_time
263
+ except Exception as e:
264
+ print(f"ChatGLM调用失败: {e}")
265
+ return [], time.time() - start_time
266
+
267
+ # 使用BERT NER
268
+ text_chunks = [text[i:i + 510] for i in range(0, len(text), 510)] # 安全分段
269
+ raw_results = []
270
+
271
+ # 获取NER pipeline
272
+ ner_pipeline = get_ner_pipeline() # 使用缓存的pipeline
273
+
274
+ for idx, chunk in enumerate(text_chunks):
275
+ chunk_results = ner_pipeline(chunk) # 使用获取的pipeline
276
+ for r in chunk_results:
277
+ r["start"] += idx * 510
278
+ r["end"] += idx * 510
279
+ raw_results.extend(chunk_results)
280
+
281
+ entities = [{
282
+ "text": r['word'].replace(' ', ''),
283
+ "start": r['start'],
284
+ "end": r['end'],
285
+ "type": LABEL_MAPPING.get(r.get('entity_group') or r.get('entity'), r.get('entity_group') or r.get('entity'))
286
+ } for r in raw_results]
287
+
288
+ entities = merge_adjacent_entities(entities)
289
+ return entities, time.time() - start_time
290
+
291
+
292
+ # ------------------ 实体类型标准化 ------------------
293
+ LABEL_MAPPING = {
294
+ "address": "LOC",
295
+ "company": "ORG",
296
+ "name": "PER",
297
+ "organization": "ORG",
298
+ "position": "TITLE",
299
+ "government": "ORG",
300
+ "scene": "LOC",
301
+ "book": "WORK",
302
+ "movie": "WORK",
303
+ "game": "WORK"
304
+ }
305
+
306
+ # 提取实体
307
+ entities, processing_time = ner("Google in New York met Alice")
308
+
309
+ # 标准化实体类型
310
+ for e in entities:
311
+ e["type"] = LABEL_MAPPING.get(e.get("type"), e.get("type"))
312
+
313
+ # 打印标准化后的实体
314
+ print(f"[DEBUG] 标准化后实体列表: {[{'text': e['text'], 'type': e['type']} for e in entities]}")
315
+
316
+ # 打印处理时间
317
+ print(f"处理时间: {processing_time:.2f}秒")
318
+
319
+
320
+ # ======================== 关系抽取(RE) ========================
321
+ @lru_cache(maxsize=1)
322
+ def get_re_pipeline():
323
+ tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
324
+ model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
325
+ return pipeline(
326
+ "ner", # 使用NER pipeline
327
+ model=model,
328
+ tokenizer=tokenizer,
329
+ aggregation_strategy="first"
330
+ )
331
+
332
+ def re_extract(entities, text, use_bert_model=True):
333
+ if not entities or not text:
334
+ return [], 0
335
+
336
+ start_time = time.time()
337
+ try:
338
+ # 使用规则匹配关系
339
+ relations = []
340
+
341
+ # 定义关系关键词和对应的实体类型约束
342
+ relation_rules = {
343
+ "位于": {
344
+ "keywords": ["位于", "在", "��落于"],
345
+ "valid_types": {
346
+ "head": ["ORG", "PER", "LOC"],
347
+ "tail": ["LOC"]
348
+ }
349
+ },
350
+ "属于": {
351
+ "keywords": ["属于", "是", "为"],
352
+ "valid_types": {
353
+ "head": ["ORG", "PER"],
354
+ "tail": ["ORG", "LOC"]
355
+ }
356
+ },
357
+ "任职于": {
358
+ "keywords": ["任职于", "就职于", "工作于"],
359
+ "valid_types": {
360
+ "head": ["PER"],
361
+ "tail": ["ORG"]
362
+ }
363
+ }
364
+ }
365
+
366
+ # 预处理实体,去除重复和部分匹配
367
+ processed_entities = []
368
+ for e in entities:
369
+ # 检查是否与已有实体重叠
370
+ is_subset = False
371
+ for pe in processed_entities:
372
+ if e["text"] in pe["text"] and e["text"] != pe["text"]:
373
+ is_subset = True
374
+ break
375
+ if not is_subset:
376
+ processed_entities.append(e)
377
+
378
+ # 遍历文本中的每个句子
379
+ sentences = re.split('[。!?.!?]', text)
380
+ for sentence in sentences:
381
+ if not sentence.strip():
382
+ continue
383
+
384
+ # 获取当前句子中的实体
385
+ sentence_entities = [e for e in processed_entities if e["text"] in sentence]
386
+
387
+ # 检查每个关系类型
388
+ for rel_type, rule in relation_rules.items():
389
+ for keyword in rule["keywords"]:
390
+ if keyword in sentence:
391
+ # 在句子中查找符合类型约束的实体对
392
+ for i, ent1 in enumerate(sentence_entities):
393
+ for j, ent2 in enumerate(sentence_entities):
394
+ if i != j: # 避免自循环
395
+ # 检查实体类型是否符合规则
396
+ if (ent1["type"] in rule["valid_types"]["head"] and
397
+ ent2["type"] in rule["valid_types"]["tail"]):
398
+ # 检查实体在句子中的位置关系
399
+ if sentence.find(ent1["text"]) < sentence.find(ent2["text"]):
400
+ relations.append({
401
+ "head": ent1["text"],
402
+ "tail": ent2["text"],
403
+ "relation": rel_type
404
+ })
405
+
406
+ # 去重
407
+ unique_relations = []
408
+ seen = set()
409
+ for rel in relations:
410
+ rel_key = (rel["head"], rel["tail"], rel["relation"])
411
+ if rel_key not in seen:
412
+ seen.add(rel_key)
413
+ unique_relations.append(rel)
414
+
415
+ return unique_relations, time.time() - start_time
416
+
417
+ except Exception as e:
418
+ logging.error(f"关系抽取失败: {e}")
419
+ return [], time.time() - start_time
420
+
421
+
422
+ # ======================== 文本分析主流程 ========================
423
+ def create_knowledge_graph(entities, relations):
424
+ """
425
+ 创建交互式网络图形式的知识图谱
426
+ """
427
+ # 创建一个新的网络图
428
+ net = Network(height="600px", width="100%", bgcolor="#ffffff", font_color="black", directed=True)
429
+
430
+ # 设置实体类型的颜色映射
431
+ entity_colors = {
432
+ 'PER': '#FF6B6B', # 人物-红色
433
+ 'ORG': '#4ECDC4', # 组织-青色
434
+ 'LOC': '#45B7D1', # 地点-蓝色
435
+ 'TIME': '#96CEB4', # 时间-绿色
436
+ 'TITLE': '#D4A5A5' # 职位-粉色
437
+ }
438
+
439
+ # 添加实体节点
440
+ added_nodes = set()
441
+ for entity in entities:
442
+ if entity['text'] not in added_nodes:
443
+ node_color = entity_colors.get(entity['type'], '#D3D3D3')
444
+ net.add_node(
445
+ entity['text'],
446
+ label=entity['text'],
447
+ title=f"类型: {entity['type']}",
448
+ color=node_color,
449
+ size=20,
450
+ font={'size': 16}
451
+ )
452
+ added_nodes.add(entity['text'])
453
+
454
+ # 添加关系边
455
+ for relation in relations:
456
+ if relation['head'] in added_nodes and relation['tail'] in added_nodes:
457
+ net.add_edge(
458
+ relation['head'],
459
+ relation['tail'],
460
+ label=relation['relation'],
461
+ title=relation['relation'],
462
+ arrows={'to': {'enabled': True, 'type': 'arrow'}},
463
+ color={'color': '#666666'},
464
+ font={'size': 12}
465
+ )
466
+
467
+ # 设置物理布局参数
468
+ net.set_options('''
469
+ {
470
+ "nodes": {
471
+ "shape": "dot",
472
+ "shadow": true
473
+ },
474
+ "edges": {
475
+ "smooth": {
476
+ "type": "continuous",
477
+ "forceDirection": "none"
478
+ },
479
+ "shadow": true
480
+ },
481
+ "physics": {
482
+ "barnesHut": {
483
+ "gravitationalConstant": -2000,
484
+ "centralGravity": 0.3,
485
+ "springLength": 200,
486
+ "springConstant": 0.04,
487
+ "damping": 0.09
488
+ },
489
+ "minVelocity": 0.75
490
+ },
491
+ "interaction": {
492
+ "hover": true,
493
+ "navigationButtons": true,
494
+ "keyboard": true
495
+ }
496
+ }
497
+ ''')
498
+
499
+ try:
500
+ # 创建临时目录(如果不存在)
501
+ temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp")
502
+ os.makedirs(temp_dir, exist_ok=True)
503
+
504
+ # 生成唯一的文件名
505
+ timestamp = int(time.time())
506
+ html_file = os.path.join(temp_dir, f"kg_{timestamp}.html")
507
+
508
+ # 保存HTML文件
509
+ net.save_graph(html_file)
510
+
511
+ # 读取HTML内容
512
+ with open(html_file, 'r', encoding='utf-8') as f:
513
+ html_content = f.read()
514
+
515
+ # 添加图例
516
+ legend_html = f"""
517
+ <div style="margin-bottom: 10px; padding: 10px; background-color: #f8f9fa; border-radius: 5px;">
518
+ <div style="font-weight: bold; margin-bottom: 5px;">图例说明:</div>
519
+ <div style="display: flex; gap: 15px; flex-wrap: wrap;">
520
+ <div style="display: flex; align-items: center; gap: 5px;">
521
+ <div style="width: 15px; height: 15px; background: {entity_colors['PER']}; border-radius: 50%;"></div>
522
+ <span>人物 (PER)</span>
523
+ </div>
524
+ <div style="display: flex; align-items: center; gap: 5px;">
525
+ <div style="width: 15px; height: 15px; background: {entity_colors['ORG']}; border-radius: 50%;"></div>
526
+ <span>组织 (ORG)</span>
527
+ </div>
528
+ <div style="display: flex; align-items: center; gap: 5px;">
529
+ <div style="width: 15px; height: 15px; background: {entity_colors['LOC']}; border-radius: 50%;"></div>
530
+ <span>地点 (LOC)</span>
531
+ </div>
532
+ <div style="display: flex; align-items: center; gap: 5px;">
533
+ <div style="width: 15px; height: 15px; background: {entity_colors['TITLE']}; border-radius: 50%;"></div>
534
+ <span>职位 (TITLE)</span>
535
+ </div>
536
+ </div>
537
+ </div>
538
+ """
539
+
540
+ # 将图例添加到HTML内容中
541
+ html_content = legend_html + html_content
542
+
543
+ # 清理旧的临时文件
544
+ for old_file in os.listdir(temp_dir):
545
+ if old_file.startswith("kg_") and old_file.endswith(".html"):
546
+ old_path = os.path.join(temp_dir, old_file)
547
+ if os.path.getmtime(old_path) < time.time() - 3600: # 删除1小时前的文件
548
+ try:
549
+ os.remove(old_path)
550
+ except:
551
+ pass
552
+
553
+ return html_content
554
+
555
+ except Exception as e:
556
+ logging.error(f"生成知识图谱失败: {str(e)}")
557
+ return f"<div class='error'>生成知识图谱失败: {str(e)}</div>"
558
+
559
+ def process_text(text, model_type="bert"):
560
+ """
561
+ 处理文本,进行实体识别和关系抽取
562
+ """
563
+ start_time = time.time()
564
+
565
+ # 实体识别
566
+ entities, ner_duration = ner(text, model_type)
567
+ if not entities:
568
+ return "", "", "", f"{time.time() - start_time:.2f} 秒"
569
+
570
+ # 关系抽取
571
+ relations, re_duration = re_extract(entities, text)
572
+
573
+ # 生成文本格式的实体和关系描述
574
+ ent_text = "📌 实体:\n" + "\n".join([f"{e['text']} ({e['type']})" for e in entities])
575
+ rel_text = "\n\n📎 关系:\n" + "\n".join([f"{r['head']} --[{r['relation']}]--> {r['tail']}" for r in relations])
576
+
577
+ # 生成知识图谱
578
+ kg_text = create_knowledge_graph(entities, relations)
579
+
580
+ total_duration = time.time() - start_time
581
+ return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
582
+
583
+
584
+ def process_file(file, model_type="bert"):
585
+ try:
586
+ with open(file.name, 'rb') as f:
587
+ content = f.read()
588
+
589
+ if len(content) > 5 * 1024 * 1024:
590
+ return "❌ 文件太大", "", "", ""
591
+
592
+ # 检测编码
593
+ try:
594
+ encoding = chardet.detect(content)['encoding'] or 'utf-8'
595
+ text = content.decode(encoding)
596
+ except UnicodeDecodeError:
597
+ # 尝试常见中文编码
598
+ for enc in ['gb18030', 'utf-16', 'big5']:
599
+ try:
600
+ text = content.decode(enc)
601
+ break
602
+ except:
603
+ continue
604
+ else:
605
+ return "❌ 编码解析失败", "", "", ""
606
+
607
+ # 直接调用process_text处理文本
608
+ return process_text(text, model_type)
609
+
610
+ except Exception as e:
611
+ logging.error(f"文件处理错误: {str(e)}")
612
+ return f"❌ 文件处理错误: {str(e)}", "", "", ""
613
+
614
+
615
+
616
+ # ======================== 模型评估与自动标注 ========================
617
+ def convert_telegram_json_to_eval_format(path):
618
+ with open(path, encoding="utf-8") as f:
619
+ data = json.load(f)
620
+ if isinstance(data, dict) and "text" in data:
621
+ return [{"text": data["text"], "entities": [
622
+ {"text": data["text"][e["start"]:e["end"]]} for e in data.get("entities", [])
623
+ ]}]
624
+ elif isinstance(data, list):
625
+ return data
626
+ elif isinstance(data, dict) and "messages" in data:
627
+ result = []
628
+ for m in data.get("messages", []):
629
+ if isinstance(m.get("text"), str):
630
+ result.append({"text": m["text"], "entities": []})
631
+ elif isinstance(m.get("text"), list):
632
+ txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
633
+ result.append({"text": txt, "entities": []})
634
+ return result
635
+ return []
636
+
637
+
638
+ def evaluate_ner_model(data, model_type):
639
+ tp, fp, fn = 0, 0, 0
640
+ POS_TOLERANCE = 1
641
+
642
+ for item in data:
643
+ text = item["text"]
644
+ # 处理标注数据
645
+ gold_entities = []
646
+ for e in item.get("entities", []):
647
+ if "text" in e and "type" in e:
648
+ norm_type = LABEL_MAPPING.get(e["type"], e["type"])
649
+ gold_entities.append({
650
+ "text": e["text"],
651
+ "type": norm_type,
652
+ "start": e.get("start", -1),
653
+ "end": e.get("end", -1)
654
+ })
655
+
656
+ # 获取预测结果
657
+ pred_entities, _ = ner(text, model_type)
658
+
659
+ # 初始化匹配状态
660
+ matched_gold = [False] * len(gold_entities)
661
+ matched_pred = [False] * len(pred_entities)
662
+
663
+ # 遍历预测实体寻找匹配
664
+ for p_idx, p in enumerate(pred_entities):
665
+ for g_idx, g in enumerate(gold_entities):
666
+ if not matched_gold[g_idx] and \
667
+ p["text"] == g["text"] and \
668
+ p["type"] == g["type"] and \
669
+ abs(p["start"] - g["start"]) <= POS_TOLERANCE and \
670
+ abs(p["end"] - g["end"]) <= POS_TOLERANCE:
671
+ matched_gold[g_idx] = True
672
+ matched_pred[p_idx] = True
673
+ break
674
+
675
+ # 统计指标
676
+ tp += sum(matched_pred)
677
+ fp += len(pred_entities) - sum(matched_pred)
678
+ fn += len(gold_entities) - sum(matched_gold)
679
+
680
+ # 处理除零情况
681
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0
682
+ recall = tp / (tp + fn) if (tp + fn) > 0 else 0
683
+ f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
684
+
685
+ return (f"Precision: {precision:.2f}\n"
686
+ f"Recall: {recall:.2f}\n"
687
+ f"F1: {f1:.2f}")
688
+
689
+
690
+ def auto_annotate(file, model_type):
691
+ data = convert_telegram_json_to_eval_format(file.name)
692
+ for item in data:
693
+ ents, _ = ner(item["text"], model_type)
694
+ item["entities"] = ents
695
+ return json.dumps(data, ensure_ascii=False, indent=2)
696
+
697
+
698
+ def save_json(json_text):
699
+ fname = f"auto_labeled_{int(time.time())}.json"
700
+ with open(fname, "w", encoding="utf-8") as f:
701
+ f.write(json_text)
702
+ return fname
703
+
704
+
705
+ # ======================== 数据集导入 ========================
706
+ def import_dataset(path="D:/云边智算/暗语识别/filtered_results"):
707
+ import os
708
+ import json
709
+
710
+ for filename in os.listdir(path):
711
+ if filename.endswith('.json'):
712
+ filepath = os.path.join(path, filename)
713
+ with open(filepath, 'r', encoding='utf-8') as f:
714
+ data = json.load(f)
715
+ # 调用现有处理流程
716
+ process_text(data['text'])
717
+ print(f"已处理文件: {filename}")
718
+
719
+
720
+ # ======================== Gradio 界面 ========================
721
+ with gr.Blocks(css="""
722
+ .kg-graph {height: 700px; overflow-y: auto;}
723
+ .warning {color: #ff6b6b;}
724
+ .error {color: #ff0000; padding: 10px; background-color: #ffeeee; border-radius: 5px;}
725
+ """) as demo:
726
+ gr.Markdown("# 🤖 聊天记录实体关系识别系统")
727
+
728
+ with gr.Tab("📄 文本分析"):
729
+ input_text = gr.Textbox(lines=6, label="输入文本")
730
+ model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
731
+ btn = gr.Button("开始分析")
732
+ out1 = gr.Textbox(label="识别实体")
733
+ out2 = gr.Textbox(label="识别关系")
734
+ out3 = gr.HTML(label="知识图谱")
735
+ out4 = gr.Textbox(label="耗时")
736
+ btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
737
+
738
+ with gr.Tab("🗂 文件分析"):
739
+ file_input = gr.File(file_types=[".txt", ".json"])
740
+ file_btn = gr.Button("上传并分析")
741
+ fout1, fout2, fout3, fout4 = gr.Textbox(), gr.Textbox(), gr.Textbox(), gr.Textbox()
742
+ file_btn.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])
743
+
744
+ with gr.Tab("📊 模型评估"):
745
+ eval_file = gr.File(label="上传标注 JSON")
746
+ eval_model = gr.Radio(["bert", "chatglm"], value="bert")
747
+ eval_btn = gr.Button("开始评估")
748
+ eval_output = gr.Textbox(label="评估结果", lines=5)
749
+ eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m),
750
+ [eval_file, eval_model], eval_output)
751
+
752
+ with gr.Tab("✏️ 自动标注"):
753
+ raw_file = gr.File(label="上传 Telegram 原始 JSON")
754
+ auto_model = gr.Radio(["bert", "chatglm"], value="bert")
755
+ auto_btn = gr.Button("自动标注")
756
+ marked_texts = gr.Textbox(label="标注结果", lines=20)
757
+ download_btn = gr.Button("💾 下载标注文件")
758
+ auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
759
+ download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())
760
+
761
+ with gr.Tab("📂 数据管理"):
762
+ gr.Markdown("### 数据集导入")
763
+ dataset_path = gr.Textbox(
764
+ value="D:/云边智算/暗语识别/filtered_results",
765
+ label="数据集路径"
766
+ )
767
+ import_btn = gr.Button("导入数据集到数据库")
768
+ import_output = gr.Textbox(label="导入日志")
769
+ import_btn.click(fn=lambda: import_dataset(dataset_path.value), outputs=import_output)
770
+
771
  demo.launch(server_name="0.0.0.0", server_port=7860)
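
Note on an external dependency: app.py imports Entity and Relation from a models module that is not part of this commit (get_table_model() itself says "假设你已经定义了ORM模型", i.e. it assumes the ORM models are already defined). A minimal sketch of what that assumed models.py could look like, using SQLAlchemy declarative models whose columns mirror the dicts passed to save_to_db(), is given below; it is hypothetical and not taken from the commit.

# hypothetical models.py sketch; column names follow the keys used in save_to_db()
from sqlalchemy import Column, Integer, String, Text
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Entity(Base):
    __tablename__ = "entities"
    id = Column(Integer, primary_key=True, autoincrement=True)
    text = Column(String(255), nullable=False)        # entity surface form
    type = Column(String(32), nullable=False)         # normalized label, e.g. PER / ORG / LOC
    start_pos = Column(Integer, default=-1)
    end_pos = Column(Integer, default=-1)
    source = Column(String(64), default="user_input")

class Relation(Base):
    __tablename__ = "relations"
    id = Column(Integer, primary_key=True, autoincrement=True)
    head_entity = Column(String(255), nullable=False)
    tail_entity = Column(String(255), nullable=False)
    relation_type = Column(String(64), nullable=False)
    source_text = Column(Text, default="")

# Base.metadata.create_all(engine)  # would create both tables against the pooled engine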
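
Separately, ner() checks use_chatglm and calls chatglm_model, but the block that defines them is commented out in both versions of the file, so selecting "chatglm" in the UI would raise a NameError. Assuming ChatGLM is meant to stay optional, a minimal guard would be to keep the module-level defaults active even when loading is skipped (sketch only, not part of the commit):

use_chatglm = False      # flipped to True only if the optional ChatGLM block loads successfully
chatglm_model = None
chatglm_tokenizer = None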