Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,52 @@ import json
|
|
| 7 |
import chardet
|
| 8 |
from sklearn.metrics import precision_score, recall_score, f1_score
|
| 9 |
import time
|
| 10 |
-
from functools import lru_cache
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
# ======================== 数据库模块 ========================
|
| 12 |
from sqlalchemy import create_engine
|
| 13 |
from sqlalchemy.orm import sessionmaker
|
|
@@ -27,30 +72,36 @@ MEMGRAPH_USERNAME = '[email protected]'
|
|
| 27 |
MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
|
| 28 |
|
| 29 |
def hello_memgraph():
|
| 30 |
-
"""
|
| 31 |
-
测试 Memgraph 数据库连接并创建一个节点
|
| 32 |
-
"""
|
| 33 |
try:
|
| 34 |
-
# 初始化连接
|
| 35 |
connection = Memgraph(
|
| 36 |
-
host=MEMGRAPH_HOST,
|
| 37 |
-
port=MEMGRAPH_PORT,
|
| 38 |
-
username=MEMGRAPH_USERNAME,
|
| 39 |
-
password=MEMGRAPH_PASSWORD,
|
| 40 |
-
encrypted=True
|
|
|
|
|
|
|
| 41 |
)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
)
|
| 46 |
-
|
|
|
|
|
|
|
| 47 |
except Exception as e:
|
| 48 |
-
logging.error(f"
|
| 49 |
-
return f"连接失败: {str(e)}"
|
| 50 |
finally:
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
connection.close()
|
| 54 |
|
| 55 |
# 配置日志
|
| 56 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
@@ -813,13 +864,11 @@ with gr.Blocks(css="""
|
|
| 813 |
|
| 814 |
with gr.Tab("📄 文本分析"):
|
| 815 |
input_text = gr.Textbox(lines=6, label="输入文本")
|
| 816 |
-
|
| 817 |
-
|
| 818 |
-
|
| 819 |
-
|
| 820 |
-
|
| 821 |
-
out4 = gr.Textbox(label="耗时")
|
| 822 |
-
btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
|
| 823 |
|
| 824 |
with gr.Tab("🗂 文件分析"):
|
| 825 |
file_input = gr.File(file_types=[".txt", ".json"])
|
|
|
|
| 7 |
import chardet
|
| 8 |
from sklearn.metrics import precision_score, recall_score, f1_score
|
| 9 |
import time
|
| 10 |
+
from functools import lru_cache
|
| 11 |
+
from sqlalchemy import create_engine
|
| 12 |
+
from sqlalchemy.orm import sessionmaker
|
| 13 |
+
from contextlib import contextmanager
|
| 14 |
+
import logging
|
| 15 |
+
import networkx as nx
|
| 16 |
+
from pyvis.network import Network
|
| 17 |
+
import pandas as pd
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
from gqlalchemy import Memgraph
|
| 20 |
+
from mcp_use import RelationPredictor, insert_to_memgraph, get_memgraph_conn # 引入mcp_use中的功能
|
| 21 |
+
from relation_extraction.hparams import hparams # 引入模型超参数
|
| 22 |
+
|
| 23 |
+
# ======================== 数据库模块 ========================
|
| 24 |
+
MEMGRAPH_HOST = '18.159.132.161'
|
| 25 |
+
MEMGRAPH_PORT = 7687
|
| 26 |
+
MEMGRAPH_USERNAME = '[email protected]'
|
| 27 |
+
MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
|
| 28 |
+
|
| 29 |
+
# 初始化 Memgraph 连接
|
| 30 |
+
memgraph = get_memgraph_conn()
|
| 31 |
+
|
| 32 |
+
# 初始化关系抽取模型
|
| 33 |
+
relation_predictor = RelationPredictor(hparams)
|
| 34 |
+
|
| 35 |
+
# ======================== 关系抽取功能整合 ========================
|
| 36 |
+
def extract_and_save_relations(text, entity1, entity2):
|
| 37 |
+
"""
|
| 38 |
+
使用 mcp_use.py 中的 RelationPredictor 提取关系,并保存到 Memgraph
|
| 39 |
+
"""
|
| 40 |
+
try:
|
| 41 |
+
# 调用关系抽取模型
|
| 42 |
+
result = relation_predictor.predict_one(text, entity1, entity2)
|
| 43 |
+
if result is None:
|
| 44 |
+
return f"❌ 未找到实体 '{entity1}' 或 '{entity2}'"
|
| 45 |
+
|
| 46 |
+
# 提取关系
|
| 47 |
+
entity1, relation, entity2 = result
|
| 48 |
+
|
| 49 |
+
# 保存到 Memgraph
|
| 50 |
+
insert_to_memgraph(memgraph, entity1, relation, entity2)
|
| 51 |
+
return f"✅ 已写入 Memgraph:({entity1})-[:{relation}]->({entity2})"
|
| 52 |
+
except Exception as e:
|
| 53 |
+
logging.error(f"关系抽取失败: {e}")
|
| 54 |
+
return f"❌ 关系抽取失败: {e}"
|
| 55 |
+
|
| 56 |
# ======================== 数据库模块 ========================
|
| 57 |
from sqlalchemy import create_engine
|
| 58 |
from sqlalchemy.orm import sessionmaker
|
|
|
|
| 72 |
MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
|
| 73 |
|
| 74 |
def hello_memgraph():
|
| 75 |
+
"""测试 Memgraph 数据库连接并进行健康检查"""
|
|
|
|
|
|
|
| 76 |
try:
|
|
|
|
| 77 |
connection = Memgraph(
|
| 78 |
+
host=os.environ["MEMGRAPH_HOST"],
|
| 79 |
+
port=int(os.environ["MEMGRAPH_PORT"]),
|
| 80 |
+
username=os.environ["MEMGRAPH_USERNAME"],
|
| 81 |
+
password=os.environ["MEMGRAPH_PASSWORD"], # 强制从环境变量获取
|
| 82 |
+
encrypted=True,
|
| 83 |
+
ssl_verify=True,
|
| 84 |
+
ca_path="/etc/ssl/certs/memgraph.crt"
|
| 85 |
)
|
| 86 |
+
|
| 87 |
+
# 健康检查查询
|
| 88 |
+
health = connection.execute_and_fetch("CALL mg.get('memgraph') YIELD value;")
|
| 89 |
+
health_status = next(health)["value"]["status"]
|
| 90 |
+
|
| 91 |
+
# 创建测试节点
|
| 92 |
+
connection.execute(
|
| 93 |
+
'CREATE (n:ConnectionTest { message: "Hello Memgraph", ts: $ts })',
|
| 94 |
+
{"ts": datetime.now().isoformat()}
|
| 95 |
)
|
| 96 |
+
|
| 97 |
+
return f"✅ 连接正常 | 状态: {health_status}"
|
| 98 |
+
|
| 99 |
except Exception as e:
|
| 100 |
+
logging.error(f"连接失败: {str(e)}", exc_info=True)
|
| 101 |
+
return f"❌ 连接失败: {str(e)}"
|
| 102 |
finally:
|
| 103 |
+
connection.close()
|
| 104 |
+
|
|
|
|
| 105 |
|
| 106 |
# 配置日志
|
| 107 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
|
| 864 |
|
| 865 |
with gr.Tab("📄 文本分析"):
|
| 866 |
input_text = gr.Textbox(lines=6, label="输入文本")
|
| 867 |
+
entity1 = gr.Textbox(label="实体1")
|
| 868 |
+
entity2 = gr.Textbox(label="实体2")
|
| 869 |
+
btn = gr.Button("提取关系并保存到 Memgraph")
|
| 870 |
+
output = gr.Textbox(label="结果")
|
| 871 |
+
btn.click(fn=extract_and_save_relations, inputs=[input_text, entity1, entity2], outputs=output)
|
|
|
|
|
|
|
| 872 |
|
| 873 |
with gr.Tab("🗂 文件分析"):
|
| 874 |
file_input = gr.File(file_types=[".txt", ".json"])
|