Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,52 @@ import json
|
|
7 |
import chardet
|
8 |
from sklearn.metrics import precision_score, recall_score, f1_score
|
9 |
import time
|
10 |
-
from functools import lru_cache
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
# ======================== 数据库模块 ========================
|
12 |
from sqlalchemy import create_engine
|
13 |
from sqlalchemy.orm import sessionmaker
|
@@ -27,30 +72,36 @@ MEMGRAPH_USERNAME = '[email protected]'
|
|
27 |
MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
|
28 |
|
29 |
def hello_memgraph():
|
30 |
-
"""
|
31 |
-
测试 Memgraph 数据库连接并创建一个节点
|
32 |
-
"""
|
33 |
try:
|
34 |
-
# 初始化连接
|
35 |
connection = Memgraph(
|
36 |
-
host=MEMGRAPH_HOST,
|
37 |
-
port=MEMGRAPH_PORT,
|
38 |
-
username=MEMGRAPH_USERNAME,
|
39 |
-
password=MEMGRAPH_PASSWORD,
|
40 |
-
encrypted=True
|
|
|
|
|
41 |
)
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
)
|
46 |
-
|
|
|
|
|
47 |
except Exception as e:
|
48 |
-
logging.error(f"
|
49 |
-
return f"连接失败: {str(e)}"
|
50 |
finally:
|
51 |
-
|
52 |
-
|
53 |
-
connection.close()
|
54 |
|
55 |
# 配置日志
|
56 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
@@ -813,13 +864,11 @@ with gr.Blocks(css="""
|
|
813 |
|
814 |
with gr.Tab("📄 文本分析"):
|
815 |
input_text = gr.Textbox(lines=6, label="输入文本")
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
-
|
821 |
-
out4 = gr.Textbox(label="耗时")
|
822 |
-
btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
|
823 |
|
824 |
with gr.Tab("🗂 文件分析"):
|
825 |
file_input = gr.File(file_types=[".txt", ".json"])
|
|
|
7 |
import chardet
|
8 |
from sklearn.metrics import precision_score, recall_score, f1_score
|
9 |
import time
|
10 |
+
from functools import lru_cache
|
11 |
+
from sqlalchemy import create_engine
|
12 |
+
from sqlalchemy.orm import sessionmaker
|
13 |
+
from contextlib import contextmanager
|
14 |
+
import logging
|
15 |
+
import networkx as nx
|
16 |
+
from pyvis.network import Network
|
17 |
+
import pandas as pd
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
from gqlalchemy import Memgraph
|
20 |
+
from mcp_use import RelationPredictor, insert_to_memgraph, get_memgraph_conn # 引入mcp_use中的功能
|
21 |
+
from relation_extraction.hparams import hparams # 引入模型超参数
|
22 |
+
|
23 |
+
# ======================== 数据库模块 ========================
|
24 |
+
MEMGRAPH_HOST = '18.159.132.161'
|
25 |
+
MEMGRAPH_PORT = 7687
|
26 |
+
MEMGRAPH_USERNAME = '[email protected]'
|
27 |
+
MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
|
28 |
+
|
29 |
+
# 初始化 Memgraph 连接
|
30 |
+
memgraph = get_memgraph_conn()
|
31 |
+
|
32 |
+
# 初始化关系抽取模型
|
33 |
+
relation_predictor = RelationPredictor(hparams)
|
34 |
+
|
35 |
+
# ======================== 关系抽取功能整合 ========================
|
36 |
+
def extract_and_save_relations(text, entity1, entity2):
|
37 |
+
"""
|
38 |
+
使用 mcp_use.py 中的 RelationPredictor 提取关系,并保存到 Memgraph
|
39 |
+
"""
|
40 |
+
try:
|
41 |
+
# 调用关系抽取模型
|
42 |
+
result = relation_predictor.predict_one(text, entity1, entity2)
|
43 |
+
if result is None:
|
44 |
+
return f"❌ 未找到实体 '{entity1}' 或 '{entity2}'"
|
45 |
+
|
46 |
+
# 提取关系
|
47 |
+
entity1, relation, entity2 = result
|
48 |
+
|
49 |
+
# 保存到 Memgraph
|
50 |
+
insert_to_memgraph(memgraph, entity1, relation, entity2)
|
51 |
+
return f"✅ 已写入 Memgraph:({entity1})-[:{relation}]->({entity2})"
|
52 |
+
except Exception as e:
|
53 |
+
logging.error(f"关系抽取失败: {e}")
|
54 |
+
return f"❌ 关系抽取失败: {e}"
|
55 |
+
|
56 |
# ======================== 数据库模块 ========================
|
57 |
from sqlalchemy import create_engine
|
58 |
from sqlalchemy.orm import sessionmaker
|
|
|
72 |
MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
|
73 |
|
74 |
def hello_memgraph():
|
75 |
+
"""测试 Memgraph 数据库连接并进行健康检查"""
|
|
|
|
|
76 |
try:
|
|
|
77 |
connection = Memgraph(
|
78 |
+
host=os.environ["MEMGRAPH_HOST"],
|
79 |
+
port=int(os.environ["MEMGRAPH_PORT"]),
|
80 |
+
username=os.environ["MEMGRAPH_USERNAME"],
|
81 |
+
password=os.environ["MEMGRAPH_PASSWORD"], # 强制从环境变量获取
|
82 |
+
encrypted=True,
|
83 |
+
ssl_verify=True,
|
84 |
+
ca_path="/etc/ssl/certs/memgraph.crt"
|
85 |
)
|
86 |
+
|
87 |
+
# 健康检查查询
|
88 |
+
health = connection.execute_and_fetch("CALL mg.get('memgraph') YIELD value;")
|
89 |
+
health_status = next(health)["value"]["status"]
|
90 |
+
|
91 |
+
# 创建测试节点
|
92 |
+
connection.execute(
|
93 |
+
'CREATE (n:ConnectionTest { message: "Hello Memgraph", ts: $ts })',
|
94 |
+
{"ts": datetime.now().isoformat()}
|
95 |
)
|
96 |
+
|
97 |
+
return f"✅ 连接正常 | 状态: {health_status}"
|
98 |
+
|
99 |
except Exception as e:
|
100 |
+
logging.error(f"连接失败: {str(e)}", exc_info=True)
|
101 |
+
return f"❌ 连接失败: {str(e)}"
|
102 |
finally:
|
103 |
+
connection.close()
|
104 |
+
|
|
|
105 |
|
106 |
# 配置日志
|
107 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
864 |
|
865 |
with gr.Tab("📄 文本分析"):
|
866 |
input_text = gr.Textbox(lines=6, label="输入文本")
|
867 |
+
entity1 = gr.Textbox(label="实体1")
|
868 |
+
entity2 = gr.Textbox(label="实体2")
|
869 |
+
btn = gr.Button("提取关系并保存到 Memgraph")
|
870 |
+
output = gr.Textbox(label="结果")
|
871 |
+
btn.click(fn=extract_and_save_relations, inputs=[input_text, entity1, entity2], outputs=output)
|
|
|
|
|
872 |
|
873 |
with gr.Tab("🗂 文件分析"):
|
874 |
file_input = gr.File(file_types=[".txt", ".json"])
|