Upload app.py
app.py
CHANGED
@@ -1,769 +1,771 @@
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
import gradio as gr
import re
import os
import json
import chardet
from sklearn.metrics import precision_score, recall_score, f1_score
import time
from functools import lru_cache  # needed for caching the pipelines below
# ======================== Database module ========================
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager
import logging
import networkx as nx
from pyvis.network import Network
import pandas as pd

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Manage database connections with SQLAlchemy's connection pool
DATABASE_URL = "mysql+pymysql://user:password@host/dbname"  # adjust the connection string to your environment

# Create the engine (connection pool)
engine = create_engine(DATABASE_URL, pool_size=10, max_overflow=20, echo=True)

# Create the session factory
Session = sessionmaker(bind=engine)

@contextmanager
def get_db_connection():
    """
    Obtain a database session via a context manager.
    """
    session = None
    try:
        session = Session()  # take a connection from the pool
        logging.info("✅ 数据库连接已建立")
        yield session  # hand the session to the caller for database operations
    except Exception as e:
        logging.error(f"❌ 数据库操作时发生错误: {e}")
        if session:
            session.rollback()  # roll back the transaction
    finally:
        if session:
            try:
                session.commit()  # commit the transaction
                logging.info("✅ 数据库事务已提交")
            except Exception as e:
                logging.error(f"❌ 提交事务时发生错误: {e}")
            finally:
                session.close()  # close the session and release the connection
                logging.info("✅ 数据库连接已关闭")

def save_to_db(table, data):
    """
    Save a record to the database.
    :param table: table name
    :param data: dict of column values
    """
    try:
        valid_tables = ["entities", "relations"]  # writes are only allowed to these tables
        if table not in valid_tables:
            raise ValueError(f"Invalid table: {table}")

        with get_db_connection() as conn:
            if conn:
                # The insert assumes ORM models; adapt this to your actual table schema
                table_model = get_table_model(table)  # resolve the ORM model for the table name
                new_record = table_model(**data)
                conn.add(new_record)
                conn.commit()  # commit the transaction
    except Exception as e:
        logging.error(f"❌ 保存数据时发生错误: {e}")
        return False
    return True

def get_table_model(table_name):
    """
    Resolve the ORM model for a table name (assumes models mapped to the tables are defined).
    :param table_name: table name
    :return: the corresponding ORM model
    """
    if table_name == "entities":
        from models import Entity  # assumes the ORM model is defined in models.py
        return Entity
    elif table_name == "relations":
        from models import Relation  # assumes the ORM model is defined in models.py
        return Relation
    else:
        raise ValueError(f"Unknown table: {table_name}")

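# NOTE: models.py is not included in this upload. A minimal sketch of what it is assumed
# to provide, based only on the column names used by update_knowledge_graph() below
# (hypothetical; the real schema may differ):
#
#   from sqlalchemy.orm import declarative_base
#   from sqlalchemy import Column, Integer, String, Text
#
#   Base = declarative_base()
#
#   class Entity(Base):
#       __tablename__ = "entities"
#       id = Column(Integer, primary_key=True)
#       text = Column(String(255))
#       type = Column(String(64))
#       start_pos = Column(Integer)
#       end_pos = Column(Integer)
#       source = Column(String(64))
#
#   class Relation(Base):
#       __tablename__ = "relations"
#       id = Column(Integer, primary_key=True)
#       head_entity = Column(String(255))
#       tail_entity = Column(String(255))
#       relation_type = Column(String(64))
#       source_text = Column(Text)
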
# ======================== Model loading ========================
NER_MODEL_NAME = "uer/roberta-base-finetuned-cluener2020-chinese"

@lru_cache(maxsize=1)
def get_ner_pipeline():
    tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
    return pipeline(
        "ner",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="first"
    )


# ChatGLM is optional; these defaults keep the "chatglm" code path from raising
# NameError when the model is not loaded.
chatglm_model, chatglm_tokenizer = None, None
use_chatglm = False
# try:
#     chatglm_model_name = "THUDM/chatglm-6b-int4"
#     chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
#     chatglm_model = AutoModel.from_pretrained(
#         chatglm_model_name,
#         trust_remote_code=True,
#         device_map="cpu",
#         torch_dtype=torch.float32
#     ).eval()
#     use_chatglm = True
#     print("✅ 4-bit量化版ChatGLM加载成功")
# except Exception as e:
#     print(f"❌ ChatGLM加载失败: {e}")

# ======================== Knowledge graph structure ========================
knowledge_graph = {"entities": set(), "relations": set()}


def update_knowledge_graph(entities, relations):
    # Save entities
    for e in entities:
        if isinstance(e, dict) and 'text' in e and 'type' in e:
            save_to_db('entities', {
                'text': e['text'],
                'type': e['type'],
                'start_pos': e.get('start', -1),
                'end_pos': e.get('end', -1),
                'source': 'user_input'
            })

    # Save relations
    for r in relations:
        if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
            save_to_db('relations', {
                'head_entity': r['head'],
                'tail_entity': r['tail'],
                'relation_type': r['relation'],
                'source_text': ''  # the source text could be attached here
            })


def visualize_kg_text():
    nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
    edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
    return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)

def visualize_kg_interactive(entities, relations):
    """
    Generate an interactive visualization of the knowledge graph.
    """
    # Create a new network graph
    net = Network(height="500px", width="100%", bgcolor="#ffffff", font_color="black")

    # Colors for entity nodes
    entity_colors = {
        'PER': '#FF6B6B',   # person - red
        'ORG': '#4ECDC4',   # organization - teal
        'LOC': '#45B7D1',   # location - blue
        'TIME': '#96CEB4',  # time - green
        'MISC': '#D4A5A5'   # other - grey
    }

    # Add entity nodes
    for entity in entities:
        node_color = entity_colors.get(entity['type'], '#D3D3D3')
        net.add_node(entity['text'],
                     label=f"{entity['text']}\n({entity['type']})",
                     color=node_color,
                     title=f"类型: {entity['type']}")

    # Add relation edges
    for relation in relations:
        net.add_edge(relation['head'],
                     relation['tail'],
                     label=relation['relation'],
                     arrows='to')

    # Physics layout options
    net.set_options('''
    var options = {
      "physics": {
        "forceAtlas2Based": {
          "gravitationalConstant": -50,
          "centralGravity": 0.01,
          "springLength": 100,
          "springConstant": 0.08
        },
        "maxVelocity": 50,
        "solver": "forceAtlas2Based",
        "timestep": 0.35,
        "stabilization": {"iterations": 150}
      }
    }
    ''')

    # Write the HTML file
    html_path = "knowledge_graph.html"
    net.save_graph(html_path)
    return html_path

# ======================== Named entity recognition (NER) ========================
def merge_adjacent_entities(entities):
    if not entities:
        return entities

    merged = [entities[0]]
    for entity in entities[1:]:
        last = merged[-1]
        # Merge adjacent entities of the same type
        if (entity["type"] == last["type"] and
                entity["start"] == last["end"]):
            last["text"] += entity["text"]
            last["end"] = entity["end"]
        else:
            merged.append(entity)

    return merged


def ner(text, model_type="bert"):
    start_time = time.time()

    # If ChatGLM is selected, run NER through ChatGLM
    if model_type == "chatglm" and use_chatglm:
        try:
            prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
文本:{text}"""
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
            if isinstance(response, tuple):
                response = response[0]

            try:
                json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
                entities = json.loads(json_str)
                valid_entities = [ent for ent in entities if all(k in ent for k in ("text", "type", "start", "end"))]
                return valid_entities, time.time() - start_time
            except Exception as e:
                print(f"JSON解析失败: {e}")
                return [], time.time() - start_time
        except Exception as e:
            print(f"ChatGLM调用失败: {e}")
            return [], time.time() - start_time

    # BERT-based NER
    text_chunks = [text[i:i + 510] for i in range(0, len(text), 510)]  # split into chunks that fit the model
    raw_results = []

    # Get the cached NER pipeline
    ner_pipeline = get_ner_pipeline()

    for idx, chunk in enumerate(text_chunks):
        chunk_results = ner_pipeline(chunk)
        for r in chunk_results:
            r["start"] += idx * 510
            r["end"] += idx * 510
        raw_results.extend(chunk_results)

    entities = [{
        "text": r['word'].replace(' ', ''),
        "start": r['start'],
        "end": r['end'],
        "type": LABEL_MAPPING.get(r.get('entity_group') or r.get('entity'), r.get('entity_group') or r.get('entity'))
    } for r in raw_results]

    entities = merge_adjacent_entities(entities)
    return entities, time.time() - start_time


# ------------------ Entity type normalization ------------------
LABEL_MAPPING = {
    "address": "LOC",
    "company": "ORG",
    "name": "PER",
    "organization": "ORG",
    "position": "TITLE",
    "government": "ORG",
    "scene": "LOC",
    "book": "WORK",
    "movie": "WORK",
    "game": "WORK"
}

# Extract entities (module-level smoke test; runs at import)
entities, processing_time = ner("Google in New York met Alice")

# Normalize entity types
for e in entities:
    e["type"] = LABEL_MAPPING.get(e.get("type"), e.get("type"))

# Print the normalized entities
print(f"[DEBUG] 标准化后实体列表: {[{'text': e['text'], 'type': e['type']} for e in entities]}")

# Print the processing time
print(f"处理时间: {processing_time:.2f}秒")


# ======================== Relation extraction (RE) ========================
@lru_cache(maxsize=1)
def get_re_pipeline():
    tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
    return pipeline(
        "ner",  # reuse the NER pipeline
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="first"
    )

def re_extract(entities, text, use_bert_model=True):
    if not entities or not text:
        return [], 0

    start_time = time.time()
    try:
        # Rule-based relation matching
        relations = []

        # Relation keywords and the entity-type constraints for each relation
        relation_rules = {
            "位于": {
                "keywords": ["位于", "在", "坐落于"],
                "valid_types": {
                    "head": ["ORG", "PER", "LOC"],
                    "tail": ["LOC"]
                }
            },
            "属于": {
                "keywords": ["属于", "是", "为"],
                "valid_types": {
                    "head": ["ORG", "PER"],
                    "tail": ["ORG", "LOC"]
                }
            },
            "任职于": {
                "keywords": ["任职于", "就职于", "工作于"],
                "valid_types": {
                    "head": ["PER"],
                    "tail": ["ORG"]
                }
            }
        }

        # Pre-process entities: drop duplicates and partial matches
        processed_entities = []
        for e in entities:
            # Skip entities that are substrings of an already kept entity
            is_subset = False
            for pe in processed_entities:
                if e["text"] in pe["text"] and e["text"] != pe["text"]:
                    is_subset = True
                    break
            if not is_subset:
                processed_entities.append(e)

        # Walk through each sentence in the text
        sentences = re.split('[。!?.!?]', text)
        for sentence in sentences:
            if not sentence.strip():
                continue

            # Entities appearing in the current sentence
            sentence_entities = [e for e in processed_entities if e["text"] in sentence]

            # Check every relation type
            for rel_type, rule in relation_rules.items():
                for keyword in rule["keywords"]:
                    if keyword in sentence:
                        # Look for entity pairs that satisfy the type constraints
                        for i, ent1 in enumerate(sentence_entities):
                            for j, ent2 in enumerate(sentence_entities):
                                if i != j:  # avoid self-loops
                                    # Check whether the entity types match the rule
                                    if (ent1["type"] in rule["valid_types"]["head"] and
                                            ent2["type"] in rule["valid_types"]["tail"]):
                                        # Check the order of the entities in the sentence
                                        if sentence.find(ent1["text"]) < sentence.find(ent2["text"]):
                                            relations.append({
                                                "head": ent1["text"],
                                                "tail": ent2["text"],
                                                "relation": rel_type
                                            })

        # Deduplicate
        unique_relations = []
        seen = set()
        for rel in relations:
            rel_key = (rel["head"], rel["tail"], rel["relation"])
            if rel_key not in seen:
                seen.add(rel_key)
                unique_relations.append(rel)

        return unique_relations, time.time() - start_time

    except Exception as e:
        logging.error(f"关系抽取失败: {e}")
        return [], time.time() - start_time


# ======================== Main text-analysis pipeline ========================
def create_knowledge_graph(entities, relations):
    """
    Build the knowledge graph as an interactive network visualization.
    """
    # Create a new network graph
    net = Network(height="600px", width="100%", bgcolor="#ffffff", font_color="black", directed=True)

    # Color mapping for entity types
    entity_colors = {
        'PER': '#FF6B6B',    # person - red
        'ORG': '#4ECDC4',    # organization - teal
        'LOC': '#45B7D1',    # location - blue
        'TIME': '#96CEB4',   # time - green
        'TITLE': '#D4A5A5'   # title/position - pink
    }

    # Add entity nodes
    added_nodes = set()
    for entity in entities:
        if entity['text'] not in added_nodes:
            node_color = entity_colors.get(entity['type'], '#D3D3D3')
            net.add_node(
                entity['text'],
                label=entity['text'],
                title=f"类型: {entity['type']}",
                color=node_color,
                size=20,
                font={'size': 16}
            )
            added_nodes.add(entity['text'])

    # Add relation edges
    for relation in relations:
        if relation['head'] in added_nodes and relation['tail'] in added_nodes:
            net.add_edge(
                relation['head'],
                relation['tail'],
                label=relation['relation'],
                title=relation['relation'],
                arrows={'to': {'enabled': True, 'type': 'arrow'}},
                color={'color': '#666666'},
                font={'size': 12}
            )

    # Physics layout options
    net.set_options('''
    {
      "nodes": {
        "shape": "dot",
        "shadow": true
      },
      "edges": {
        "smooth": {
          "type": "continuous",
          "forceDirection": "none"
        },
        "shadow": true
      },
      "physics": {
        "barnesHut": {
          "gravitationalConstant": -2000,
          "centralGravity": 0.3,
          "springLength": 200,
          "springConstant": 0.04,
          "damping": 0.09
        },
        "minVelocity": 0.75
      },
      "interaction": {
        "hover": true,
        "navigationButtons": true,
        "keyboard": true
      }
    }
    ''')

    try:
        # Create the temp directory if it does not exist
        temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp")
        os.makedirs(temp_dir, exist_ok=True)

        # Generate a unique file name
        timestamp = int(time.time())
        html_file = os.path.join(temp_dir, f"kg_{timestamp}.html")

        # Save the HTML file
        net.save_graph(html_file)

        # Read the HTML content back
        with open(html_file, 'r', encoding='utf-8') as f:
            html_content = f.read()

        # Add a legend
        legend_html = f"""
        <div style="margin-bottom: 10px; padding: 10px; background-color: #f8f9fa; border-radius: 5px;">
            <div style="font-weight: bold; margin-bottom: 5px;">图例说明:</div>
            <div style="display: flex; gap: 15px; flex-wrap: wrap;">
                <div style="display: flex; align-items: center; gap: 5px;">
                    <div style="width: 15px; height: 15px; background: {entity_colors['PER']}; border-radius: 50%;"></div>
                    <span>人物 (PER)</span>
                </div>
                <div style="display: flex; align-items: center; gap: 5px;">
                    <div style="width: 15px; height: 15px; background: {entity_colors['ORG']}; border-radius: 50%;"></div>
                    <span>组织 (ORG)</span>
                </div>
                <div style="display: flex; align-items: center; gap: 5px;">
                    <div style="width: 15px; height: 15px; background: {entity_colors['LOC']}; border-radius: 50%;"></div>
                    <span>地点 (LOC)</span>
                </div>
                <div style="display: flex; align-items: center; gap: 5px;">
                    <div style="width: 15px; height: 15px; background: {entity_colors['TITLE']}; border-radius: 50%;"></div>
                    <span>职位 (TITLE)</span>
                </div>
            </div>
        </div>
        """

        # Prepend the legend to the HTML content
        html_content = legend_html + html_content

        # Clean up old temporary files
        for old_file in os.listdir(temp_dir):
            if old_file.startswith("kg_") and old_file.endswith(".html"):
                old_path = os.path.join(temp_dir, old_file)
                if os.path.getmtime(old_path) < time.time() - 3600:  # delete files older than one hour
                    try:
                        os.remove(old_path)
                    except OSError:
                        pass

        return html_content

    except Exception as e:
        logging.error(f"生成知识图谱失败: {str(e)}")
        return f"<div class='error'>生成知识图谱失败: {str(e)}</div>"

def process_text(text, model_type="bert"):
    """
    Run entity recognition and relation extraction on a piece of text.
    """
    start_time = time.time()

    # Entity recognition
    entities, ner_duration = ner(text, model_type)
    if not entities:
        return "", "", "", f"{time.time() - start_time:.2f} 秒"

    # Relation extraction
    relations, re_duration = re_extract(entities, text)

    # Plain-text summary of entities and relations
    ent_text = "📌 实体:\n" + "\n".join([f"{e['text']} ({e['type']})" for e in entities])
    rel_text = "\n\n📎 关系:\n" + "\n".join([f"{r['head']} --[{r['relation']}]--> {r['tail']}" for r in relations])

    # Build the knowledge graph
    kg_text = create_knowledge_graph(entities, relations)

    total_duration = time.time() - start_time
    return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"


def process_file(file, model_type="bert"):
    try:
        with open(file.name, 'rb') as f:
            content = f.read()

        if len(content) > 5 * 1024 * 1024:
            return "❌ 文件太大", "", "", ""

        # Detect the encoding
        try:
            encoding = chardet.detect(content)['encoding'] or 'utf-8'
            text = content.decode(encoding)
        except UnicodeDecodeError:
            # Fall back to common Chinese encodings
            for enc in ['gb18030', 'utf-16', 'big5']:
                try:
                    text = content.decode(enc)
                    break
                except UnicodeDecodeError:
                    continue
            else:
                return "❌ 编码解析失败", "", "", ""

        # Delegate to process_text
        return process_text(text, model_type)

    except Exception as e:
        logging.error(f"文件处理错误: {str(e)}")
        return f"❌ 文件处理错误: {str(e)}", "", "", ""


# ======================== Model evaluation and auto-annotation ========================
def convert_telegram_json_to_eval_format(path):
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    if isinstance(data, dict) and "text" in data:
        return [{"text": data["text"], "entities": [
            {"text": data["text"][e["start"]:e["end"]]} for e in data.get("entities", [])
        ]}]
    elif isinstance(data, list):
        return data
    elif isinstance(data, dict) and "messages" in data:
        result = []
        for m in data.get("messages", []):
            if isinstance(m.get("text"), str):
                result.append({"text": m["text"], "entities": []})
            elif isinstance(m.get("text"), list):
                txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
                result.append({"text": txt, "entities": []})
        return result
    return []

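# For reference, evaluate_ner_model() below expects each item of the annotated data to
# look roughly like the following (a hypothetical example inferred from the fields the
# function reads; it is not a file shipped with this Space):
#
#   {
#       "text": "马云在杭州创办了阿里巴巴",
#       "entities": [
#           {"text": "马云", "type": "name", "start": 0, "end": 2},
#           {"text": "杭州", "type": "address", "start": 3, "end": 5},
#           {"text": "阿里巴巴", "type": "company", "start": 8, "end": 12}
#       ]
#   }
#
# Types are normalized through LABEL_MAPPING, and start/end offsets are compared with a
# tolerance of POS_TOLERANCE characters.
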
def evaluate_ner_model(data, model_type):
    tp, fp, fn = 0, 0, 0
    POS_TOLERANCE = 1

    for item in data:
        text = item["text"]
        # Collect the gold annotations
        gold_entities = []
        for e in item.get("entities", []):
            if "text" in e and "type" in e:
                norm_type = LABEL_MAPPING.get(e["type"], e["type"])
                gold_entities.append({
                    "text": e["text"],
                    "type": norm_type,
                    "start": e.get("start", -1),
                    "end": e.get("end", -1)
                })

        # Get the model predictions
        pred_entities, _ = ner(text, model_type)

        # Initialize match bookkeeping
        matched_gold = [False] * len(gold_entities)
        matched_pred = [False] * len(pred_entities)

        # Match each predicted entity against the gold entities
        for p_idx, p in enumerate(pred_entities):
            for g_idx, g in enumerate(gold_entities):
                if not matched_gold[g_idx] and \
                        p["text"] == g["text"] and \
                        p["type"] == g["type"] and \
                        abs(p["start"] - g["start"]) <= POS_TOLERANCE and \
                        abs(p["end"] - g["end"]) <= POS_TOLERANCE:
                    matched_gold[g_idx] = True
                    matched_pred[p_idx] = True
                    break

        # Accumulate the counts
        tp += sum(matched_pred)
        fp += len(pred_entities) - sum(matched_pred)
        fn += len(gold_entities) - sum(matched_gold)

    # Guard against division by zero
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return (f"Precision: {precision:.2f}\n"
            f"Recall: {recall:.2f}\n"
            f"F1: {f1:.2f}")


def auto_annotate(file, model_type):
    data = convert_telegram_json_to_eval_format(file.name)
    for item in data:
        ents, _ = ner(item["text"], model_type)
        item["entities"] = ents
    return json.dumps(data, ensure_ascii=False, indent=2)


def save_json(json_text):
    fname = f"auto_labeled_{int(time.time())}.json"
    with open(fname, "w", encoding="utf-8") as f:
        f.write(json_text)
    return fname


# ======================== Dataset import ========================
def import_dataset(path="D:/云边智算/暗语识别/filtered_results"):
    processed = []
    for filename in os.listdir(path):
        if filename.endswith('.json'):
            filepath = os.path.join(path, filename)
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Run the existing processing pipeline
            process_text(data['text'])
            print(f"已处理文件: {filename}")
            processed.append(filename)
    # Return a short log so the Gradio output box has something to show
    return "已处理文件:\n" + "\n".join(processed) if processed else "未找到可导入的JSON文件"


# ======================== Gradio UI ========================
with gr.Blocks(css="""
    .kg-graph {height: 700px; overflow-y: auto;}
    .warning {color: #ff6b6b;}
    .error {color: #ff0000; padding: 10px; background-color: #ffeeee; border-radius: 5px;}
""") as demo:
    gr.Markdown("# 🤖 聊天记录实体关系识别系统")

    with gr.Tab("📄 文本分析"):
        input_text = gr.Textbox(lines=6, label="输入文本")
        model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
        btn = gr.Button("开始分析")
        out1 = gr.Textbox(label="识别实体")
        out2 = gr.Textbox(label="识别关系")
        out3 = gr.HTML(label="知识图谱")
        out4 = gr.Textbox(label="耗时")
        btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])

    with gr.Tab("🗂 文件分析"):
        file_input = gr.File(file_types=[".txt", ".json"])
        file_btn = gr.Button("上传并分析")
        fout1, fout2, fout3, fout4 = gr.Textbox(), gr.Textbox(), gr.Textbox(), gr.Textbox()
        file_btn.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])

    with gr.Tab("📊 模型评估"):
        eval_file = gr.File(label="上传标注 JSON")
        eval_model = gr.Radio(["bert", "chatglm"], value="bert")
        eval_btn = gr.Button("开始评估")
        eval_output = gr.Textbox(label="评估结果", lines=5)
        eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m),
                       [eval_file, eval_model], eval_output)

    with gr.Tab("✏️ 自动标注"):
        raw_file = gr.File(label="上传 Telegram 原始 JSON")
        auto_model = gr.Radio(["bert", "chatglm"], value="bert")
        auto_btn = gr.Button("自动标注")
        marked_texts = gr.Textbox(label="标注结果", lines=20)
        download_btn = gr.Button("💾 下载标注文件")
        auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
        download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())

    with gr.Tab("📂 数据管理"):
        gr.Markdown("### 数据集导入")
        dataset_path = gr.Textbox(
            value="D:/云边智算/暗语识别/filtered_results",
            label="数据集路径"
        )
        import_btn = gr.Button("导入数据集到数据库")
        import_output = gr.Textbox(label="导入日志")
        # Pass the textbox as an input so the user-edited path (not just the default value) is used
        import_btn.click(fn=import_dataset, inputs=dataset_path, outputs=import_output)

demo.launch(server_name="0.0.0.0", server_port=7860)