Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -7,52 +7,7 @@ import json
|
|
| 7 |
import chardet
|
| 8 |
from sklearn.metrics import precision_score, recall_score, f1_score
|
| 9 |
import time
|
| 10 |
-
from functools import lru_cache
|
| 11 |
-
from sqlalchemy import create_engine
|
| 12 |
-
from sqlalchemy.orm import sessionmaker
|
| 13 |
-
from contextlib import contextmanager
|
| 14 |
-
import logging
|
| 15 |
-
import networkx as nx
|
| 16 |
-
from pyvis.network import Network
|
| 17 |
-
import pandas as pd
|
| 18 |
-
import matplotlib.pyplot as plt
|
| 19 |
-
from gqlalchemy import Memgraph
|
| 20 |
-
from mcp_use import RelationPredictor, insert_to_memgraph, get_memgraph_conn # 引入mcp_use中的功能
|
| 21 |
-
from relation_extraction.hparams import hparams # 引入模型超参数
|
| 22 |
-
|
| 23 |
-
# ======================== 数据库模块 ========================
|
| 24 |
-
MEMGRAPH_HOST = '18.159.132.161'
|
| 25 |
-
MEMGRAPH_PORT = 7687
|
| 26 |
-
MEMGRAPH_USERNAME = '[email protected]'
|
| 27 |
-
MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
|
| 28 |
-
|
| 29 |
-
# 初始化 Memgraph 连接
|
| 30 |
-
memgraph = get_memgraph_conn()
|
| 31 |
-
|
| 32 |
-
# 初始化关系抽取模型
|
| 33 |
-
relation_predictor = RelationPredictor(hparams)
|
| 34 |
-
|
| 35 |
-
# ======================== 关系抽取功能整合 ========================
|
| 36 |
-
def extract_and_save_relations(text, entity1, entity2):
|
| 37 |
-
"""
|
| 38 |
-
使用 mcp_use.py 中的 RelationPredictor 提取关系,并保存到 Memgraph
|
| 39 |
-
"""
|
| 40 |
-
try:
|
| 41 |
-
# 调用关系抽取模型
|
| 42 |
-
result = relation_predictor.predict_one(text, entity1, entity2)
|
| 43 |
-
if result is None:
|
| 44 |
-
return f"❌ 未找到实体 '{entity1}' 或 '{entity2}'"
|
| 45 |
-
|
| 46 |
-
# 提取关系
|
| 47 |
-
entity1, relation, entity2 = result
|
| 48 |
-
|
| 49 |
-
# 保存到 Memgraph
|
| 50 |
-
insert_to_memgraph(memgraph, entity1, relation, entity2)
|
| 51 |
-
return f"✅ 已写入 Memgraph:({entity1})-[:{relation}]->({entity2})"
|
| 52 |
-
except Exception as e:
|
| 53 |
-
logging.error(f"关系抽取失败: {e}")
|
| 54 |
-
return f"❌ 关系抽取失败: {e}"
|
| 55 |
-
|
| 56 |
# ======================== 数据库模块 ========================
|
| 57 |
from sqlalchemy import create_engine
|
| 58 |
from sqlalchemy.orm import sessionmaker
|
|
@@ -63,46 +18,6 @@ from pyvis.network import Network
|
|
| 63 |
import pandas as pd
|
| 64 |
import matplotlib.pyplot as plt
|
| 65 |
|
| 66 |
-
|
| 67 |
-
from gqlalchemy import Memgraph
|
| 68 |
-
import os
|
| 69 |
-
MEMGRAPH_HOST = '18.159.132.161'
|
| 70 |
-
MEMGRAPH_PORT = 7687
|
| 71 |
-
MEMGRAPH_USERNAME = '[email protected]'
|
| 72 |
-
MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
|
| 73 |
-
|
| 74 |
-
def hello_memgraph():
|
| 75 |
-
"""测试 Memgraph 数据库连接并进行健康检查"""
|
| 76 |
-
try:
|
| 77 |
-
connection = Memgraph(
|
| 78 |
-
host=os.environ["MEMGRAPH_HOST"],
|
| 79 |
-
port=int(os.environ["MEMGRAPH_PORT"]),
|
| 80 |
-
username=os.environ["MEMGRAPH_USERNAME"],
|
| 81 |
-
password=os.environ["MEMGRAPH_PASSWORD"], # 强制从环境变量获取
|
| 82 |
-
encrypted=True,
|
| 83 |
-
ssl_verify=True,
|
| 84 |
-
ca_path="/etc/ssl/certs/memgraph.crt"
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
# 健康检查查询
|
| 88 |
-
health = connection.execute_and_fetch("CALL mg.get('memgraph') YIELD value;")
|
| 89 |
-
health_status = next(health)["value"]["status"]
|
| 90 |
-
|
| 91 |
-
# 创建测试节点
|
| 92 |
-
connection.execute(
|
| 93 |
-
'CREATE (n:ConnectionTest { message: "Hello Memgraph", ts: $ts })',
|
| 94 |
-
{"ts": datetime.now().isoformat()}
|
| 95 |
-
)
|
| 96 |
-
|
| 97 |
-
return f"✅ 连接正常 | 状态: {health_status}"
|
| 98 |
-
|
| 99 |
-
except Exception as e:
|
| 100 |
-
logging.error(f"连接失败: {str(e)}", exc_info=True)
|
| 101 |
-
return f"❌ 连接失败: {str(e)}"
|
| 102 |
-
finally:
|
| 103 |
-
connection.close()
|
| 104 |
-
|
| 105 |
-
|
| 106 |
# 配置日志
|
| 107 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 108 |
|
|
@@ -593,55 +508,35 @@ def process_text(text, model_type="bert"):
|
|
| 593 |
return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
|
| 594 |
|
| 595 |
# ======================== 知识图谱可视化 ========================
|
| 596 |
-
import matplotlib.pyplot as plt
|
| 597 |
-
import networkx as nx
|
| 598 |
-
import tempfile
|
| 599 |
-
import os
|
| 600 |
-
import logging
|
| 601 |
-
from matplotlib import font_manager
|
| 602 |
-
|
| 603 |
-
# 这个函数用于查找并验证中文字体路径
|
| 604 |
-
def find_chinese_font():
|
| 605 |
-
# 尝试查找 Noto Sans CJK 字体
|
| 606 |
-
font_paths = [
|
| 607 |
-
"/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc", # Noto CJK 字体
|
| 608 |
-
"/usr/share/fonts/truetype/wqy/wqy-microhei.ttc" # 微软雅黑
|
| 609 |
-
]
|
| 610 |
-
|
| 611 |
-
for font_path in font_paths:
|
| 612 |
-
if os.path.exists(font_path):
|
| 613 |
-
logging.info(f"Found font at {font_path}")
|
| 614 |
-
return font_path
|
| 615 |
-
|
| 616 |
-
logging.error("No Chinese font found!")
|
| 617 |
-
return None
|
| 618 |
-
|
| 619 |
def generate_kg_image(entities, relations):
|
| 620 |
"""
|
| 621 |
-
|
| 622 |
"""
|
| 623 |
try:
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
plt.rcParams['font.family'] = ['DejaVu Sans']
|
| 633 |
|
| 634 |
-
|
|
|
|
|
|
|
| 635 |
|
| 636 |
-
# ===
|
| 637 |
G = nx.DiGraph()
|
| 638 |
entity_colors = {
|
| 639 |
-
'PER': '#FF6B6B', #
|
| 640 |
-
'ORG': '#4ECDC4', #
|
| 641 |
-
'LOC': '#45B7D1', #
|
| 642 |
-
'TIME': '#96CEB4'
|
|
|
|
| 643 |
}
|
| 644 |
|
|
|
|
| 645 |
for entity in entities:
|
| 646 |
G.add_node(
|
| 647 |
entity["text"],
|
|
@@ -649,6 +544,7 @@ def generate_kg_image(entities, relations):
|
|
| 649 |
color=entity_colors.get(entity['type'], '#D3D3D3')
|
| 650 |
)
|
| 651 |
|
|
|
|
| 652 |
for relation in relations:
|
| 653 |
if relation["head"] in G.nodes and relation["tail"] in G.nodes:
|
| 654 |
G.add_edge(
|
|
@@ -657,15 +553,15 @@ def generate_kg_image(entities, relations):
|
|
| 657 |
label=relation["relation"]
|
| 658 |
)
|
| 659 |
|
| 660 |
-
# ===
|
| 661 |
-
plt.figure(figsize=(12, 8), dpi=150)
|
| 662 |
-
pos = nx.spring_layout(G, k=0.7, seed=42)
|
| 663 |
|
|
|
|
| 664 |
nx.draw_networkx_nodes(
|
| 665 |
G, pos,
|
| 666 |
node_color=[G.nodes[n]['color'] for n in G.nodes],
|
| 667 |
-
node_size=800
|
| 668 |
-
alpha=0.9
|
| 669 |
)
|
| 670 |
nx.draw_networkx_edges(
|
| 671 |
G, pos,
|
|
@@ -675,44 +571,35 @@ def generate_kg_image(entities, relations):
|
|
| 675 |
arrowsize=20
|
| 676 |
)
|
| 677 |
|
| 678 |
-
|
| 679 |
nx.draw_networkx_labels(
|
| 680 |
G, pos,
|
| 681 |
-
labels=
|
| 682 |
font_size=10,
|
| 683 |
-
font_family=
|
| 684 |
-
font_weight='bold'
|
| 685 |
)
|
| 686 |
-
|
| 687 |
-
edge_labels = nx.get_edge_attributes(G, 'label')
|
| 688 |
nx.draw_networkx_edge_labels(
|
| 689 |
G, pos,
|
| 690 |
-
edge_labels=
|
| 691 |
font_size=8,
|
| 692 |
-
font_family=
|
| 693 |
)
|
| 694 |
|
| 695 |
plt.axis('off')
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
# 打印路径以方便调试
|
| 703 |
-
logging.info(f"Saving graph image to {output_path}")
|
| 704 |
-
|
| 705 |
-
plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
|
| 706 |
plt.close()
|
| 707 |
-
|
| 708 |
-
return
|
| 709 |
-
|
| 710 |
except Exception as e:
|
| 711 |
-
logging.error(f"
|
| 712 |
return None
|
| 713 |
|
| 714 |
|
| 715 |
-
# ======================== 文件处理 ========================
|
| 716 |
def process_file(file, model_type="bert"):
|
| 717 |
try:
|
| 718 |
with open(file.name, 'rb') as f:
|
|
@@ -864,11 +751,13 @@ with gr.Blocks(css="""
|
|
| 864 |
|
| 865 |
with gr.Tab("📄 文本分析"):
|
| 866 |
input_text = gr.Textbox(lines=6, label="输入文本")
|
| 867 |
-
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
|
|
|
|
|
|
|
| 872 |
|
| 873 |
with gr.Tab("🗂 文件分析"):
|
| 874 |
file_input = gr.File(file_types=[".txt", ".json"])
|
|
@@ -903,10 +792,4 @@ with gr.Blocks(css="""
|
|
| 903 |
import_output = gr.Textbox(label="导入日志")
|
| 904 |
import_btn.click(fn=lambda: import_dataset(dataset_path.value), outputs=import_output)
|
| 905 |
|
| 906 |
-
gr.Markdown("### 测试 Memgraph 数据库连接")
|
| 907 |
-
memgraph_btn = gr.Button("测试 Memgraph 连接")
|
| 908 |
-
memgraph_output = gr.Textbox(label="连接测试日志")
|
| 909 |
-
memgraph_btn.click(fn=hello_memgraph, outputs=memgraph_output)
|
| 910 |
-
|
| 911 |
-
|
| 912 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
| 7 |
import chardet
|
| 8 |
from sklearn.metrics import precision_score, recall_score, f1_score
|
| 9 |
import time
|
| 10 |
+
from functools import lru_cache # 添加这行导入
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
# ======================== 数据库模块 ========================
|
| 12 |
from sqlalchemy import create_engine
|
| 13 |
from sqlalchemy.orm import sessionmaker
|
|
|
|
| 18 |
import pandas as pd
|
| 19 |
import matplotlib.pyplot as plt
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
# 配置日志
|
| 22 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 23 |
|
|
|
|
| 508 |
return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
|
| 509 |
|
| 510 |
# ======================== 知识图谱可视化 ========================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
def generate_kg_image(entities, relations):
|
| 512 |
"""
|
| 513 |
+
生成知识图谱的图片并保存到临时文件(Hugging Face适配版)
|
| 514 |
"""
|
| 515 |
try:
|
| 516 |
+
import tempfile
|
| 517 |
+
import matplotlib.pyplot as plt
|
| 518 |
+
import networkx as nx
|
| 519 |
+
import os
|
| 520 |
+
|
| 521 |
+
# === 1. 强制设置中文字体 ===
|
| 522 |
+
plt.rcParams['font.sans-serif'] = ['Noto Sans CJK SC'] # Hugging Face内置字体
|
| 523 |
+
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
|
|
|
|
| 524 |
|
| 525 |
+
# === 2. 检查输入数据 ===
|
| 526 |
+
if not entities or not relations:
|
| 527 |
+
return None
|
| 528 |
|
| 529 |
+
# === 3. 创建图谱 ===
|
| 530 |
G = nx.DiGraph()
|
| 531 |
entity_colors = {
|
| 532 |
+
'PER': '#FF6B6B', # 红色
|
| 533 |
+
'ORG': '#4ECDC4', # 青色
|
| 534 |
+
'LOC': '#45B7D1', # 蓝色
|
| 535 |
+
'TIME': '#96CEB4', # 绿色
|
| 536 |
+
'TITLE': '#D4A5A5' # 灰色
|
| 537 |
}
|
| 538 |
|
| 539 |
+
# 添加节点(实体)
|
| 540 |
for entity in entities:
|
| 541 |
G.add_node(
|
| 542 |
entity["text"],
|
|
|
|
| 544 |
color=entity_colors.get(entity['type'], '#D3D3D3')
|
| 545 |
)
|
| 546 |
|
| 547 |
+
# 添加边(关系)
|
| 548 |
for relation in relations:
|
| 549 |
if relation["head"] in G.nodes and relation["tail"] in G.nodes:
|
| 550 |
G.add_edge(
|
|
|
|
| 553 |
label=relation["relation"]
|
| 554 |
)
|
| 555 |
|
| 556 |
+
# === 4. 绘图配置 ===
|
| 557 |
+
plt.figure(figsize=(12, 8), dpi=150) # 降低DPI以节省内存
|
| 558 |
+
pos = nx.spring_layout(G, k=0.7, seed=42) # 固定随机种子保证布局稳定
|
| 559 |
|
| 560 |
+
# 绘制节点和边
|
| 561 |
nx.draw_networkx_nodes(
|
| 562 |
G, pos,
|
| 563 |
node_color=[G.nodes[n]['color'] for n in G.nodes],
|
| 564 |
+
node_size=800
|
|
|
|
| 565 |
)
|
| 566 |
nx.draw_networkx_edges(
|
| 567 |
G, pos,
|
|
|
|
| 571 |
arrowsize=20
|
| 572 |
)
|
| 573 |
|
| 574 |
+
# === 5. 绘制中文标签(关键修改点)===
|
| 575 |
nx.draw_networkx_labels(
|
| 576 |
G, pos,
|
| 577 |
+
labels={n: G.nodes[n]['label'] for n in G.nodes},
|
| 578 |
font_size=10,
|
| 579 |
+
font_family='Noto Sans CJK SC' # 显式指定字体
|
|
|
|
| 580 |
)
|
|
|
|
|
|
|
| 581 |
nx.draw_networkx_edge_labels(
|
| 582 |
G, pos,
|
| 583 |
+
edge_labels=nx.get_edge_attributes(G, 'label'),
|
| 584 |
font_size=8,
|
| 585 |
+
font_family='Noto Sans CJK SC' # 显式指定字体
|
| 586 |
)
|
| 587 |
|
| 588 |
plt.axis('off')
|
| 589 |
+
|
| 590 |
+
# === 6. 保存到临时文件 ===
|
| 591 |
+
temp_dir = tempfile.mkdtemp()
|
| 592 |
+
file_path = os.path.join(temp_dir, "kg.png")
|
| 593 |
+
plt.savefig(file_path, bbox_inches='tight')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
plt.close()
|
| 595 |
+
|
| 596 |
+
return file_path
|
| 597 |
+
|
| 598 |
except Exception as e:
|
| 599 |
+
logging.error(f"生成知识图谱图片失败: {str(e)}")
|
| 600 |
return None
|
| 601 |
|
| 602 |
|
|
|
|
| 603 |
def process_file(file, model_type="bert"):
|
| 604 |
try:
|
| 605 |
with open(file.name, 'rb') as f:
|
|
|
|
| 751 |
|
| 752 |
with gr.Tab("📄 文本分析"):
|
| 753 |
input_text = gr.Textbox(lines=6, label="输入文本")
|
| 754 |
+
model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
|
| 755 |
+
btn = gr.Button("开始分析")
|
| 756 |
+
out1 = gr.Textbox(label="识别实体")
|
| 757 |
+
out2 = gr.Textbox(label="识别关系")
|
| 758 |
+
out3 = gr.HTML(label="知识图谱") # 使用HTML组件显示文本格式的知识图谱
|
| 759 |
+
out4 = gr.Textbox(label="耗时")
|
| 760 |
+
btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
|
| 761 |
|
| 762 |
with gr.Tab("🗂 文件分析"):
|
| 763 |
file_input = gr.File(file_types=[".txt", ".json"])
|
|
|
|
| 792 |
import_output = gr.Textbox(label="导入日志")
|
| 793 |
import_btn.click(fn=lambda: import_dataset(dataset_path.value), outputs=import_output)
|
| 794 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 795 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|