chen666-666 commited on
Commit
51333c9
·
verified ·
1 Parent(s): a455064

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -26
app.py CHANGED
@@ -7,7 +7,52 @@ import json
7
  import chardet
8
  from sklearn.metrics import precision_score, recall_score, f1_score
9
  import time
10
- from functools import lru_cache # 添加这行导入
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # ======================== 数据库模块 ========================
12
  from sqlalchemy import create_engine
13
  from sqlalchemy.orm import sessionmaker
@@ -27,30 +72,36 @@ MEMGRAPH_USERNAME = '[email protected]'
27
  MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
28
 
29
  def hello_memgraph():
30
- """
31
- 测试 Memgraph 数据库连接并创建一个节点
32
- """
33
  try:
34
- # 初始化连接
35
  connection = Memgraph(
36
- host=MEMGRAPH_HOST,
37
- port=MEMGRAPH_PORT,
38
- username=MEMGRAPH_USERNAME,
39
- password=MEMGRAPH_PASSWORD,
40
- encrypted=True
 
 
41
  )
42
- # 执行查询
43
- results = connection.execute_and_fetch(
44
- 'CREATE (n:FirstNode { message: "Hello Memgraph from Python!" }) RETURN n.message AS message'
 
 
 
 
 
 
45
  )
46
- return f"成功创建节点: {next(results)['message']}"
 
 
47
  except Exception as e:
48
- logging.error(f"❌ Memgraph 连接失败: {e}")
49
- return f"连接失败: {str(e)}"
50
  finally:
51
- # 确保连接关闭
52
- if 'connection' in locals():
53
- connection.close()
54
 
55
  # 配置日志
56
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -813,13 +864,11 @@ with gr.Blocks(css="""
813
 
814
  with gr.Tab("📄 文本分析"):
815
  input_text = gr.Textbox(lines=6, label="输入文本")
816
- model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
817
- btn = gr.Button("开始分析")
818
- out1 = gr.Textbox(label="识别实体")
819
- out2 = gr.Textbox(label="识别关系")
820
- out3 = gr.HTML(label="知识图谱") # 使用HTML组件显示文本格式的知识图谱
821
- out4 = gr.Textbox(label="耗时")
822
- btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
823
 
824
  with gr.Tab("🗂 文件分析"):
825
  file_input = gr.File(file_types=[".txt", ".json"])
 
7
  import chardet
8
  from sklearn.metrics import precision_score, recall_score, f1_score
9
  import time
10
+ from functools import lru_cache
11
+ from sqlalchemy import create_engine
12
+ from sqlalchemy.orm import sessionmaker
13
+ from contextlib import contextmanager
14
+ import logging
15
+ import networkx as nx
16
+ from pyvis.network import Network
17
+ import pandas as pd
18
+ import matplotlib.pyplot as plt
19
+ from gqlalchemy import Memgraph
20
+ from mcp_use import RelationPredictor, insert_to_memgraph, get_memgraph_conn # 引入mcp_use中的功能
21
+ from relation_extraction.hparams import hparams # 引入模型超参数
22
+
23
+ # ======================== 数据库模块 ========================
24
+ MEMGRAPH_HOST = '18.159.132.161'
25
+ MEMGRAPH_PORT = 7687
26
+ MEMGRAPH_USERNAME = '[email protected]'
27
+ MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
28
+
29
+ # 初始化 Memgraph 连接
30
+ memgraph = get_memgraph_conn()
31
+
32
+ # 初始化关系抽取模型
33
+ relation_predictor = RelationPredictor(hparams)
34
+
35
+ # ======================== 关系抽取功能整合 ========================
36
+ def extract_and_save_relations(text, entity1, entity2):
37
+ """
38
+ 使用 mcp_use.py 中的 RelationPredictor 提取关系,并保存到 Memgraph
39
+ """
40
+ try:
41
+ # 调用关系抽取模型
42
+ result = relation_predictor.predict_one(text, entity1, entity2)
43
+ if result is None:
44
+ return f"❌ 未找到实体 '{entity1}' 或 '{entity2}'"
45
+
46
+ # 提取关系
47
+ entity1, relation, entity2 = result
48
+
49
+ # 保存到 Memgraph
50
+ insert_to_memgraph(memgraph, entity1, relation, entity2)
51
+ return f"✅ 已写入 Memgraph:({entity1})-[:{relation}]->({entity2})"
52
+ except Exception as e:
53
+ logging.error(f"关系抽取失败: {e}")
54
+ return f"❌ 关系抽取失败: {e}"
55
+
56
  # ======================== 数据库模块 ========================
57
  from sqlalchemy import create_engine
58
  from sqlalchemy.orm import sessionmaker
 
72
  MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
73
 
74
  def hello_memgraph():
75
+ """测试 Memgraph 数据库连接并进行健康检查"""
 
 
76
  try:
 
77
  connection = Memgraph(
78
+ host=os.environ["MEMGRAPH_HOST"],
79
+ port=int(os.environ["MEMGRAPH_PORT"]),
80
+ username=os.environ["MEMGRAPH_USERNAME"],
81
+ password=os.environ["MEMGRAPH_PASSWORD"], # 强制从环境变量获取
82
+ encrypted=True,
83
+ ssl_verify=True,
84
+ ca_path="/etc/ssl/certs/memgraph.crt"
85
  )
86
+
87
+ # 健康检查查询
88
+ health = connection.execute_and_fetch("CALL mg.get('memgraph') YIELD value;")
89
+ health_status = next(health)["value"]["status"]
90
+
91
+ # 创建测试节点
92
+ connection.execute(
93
+ 'CREATE (n:ConnectionTest { message: "Hello Memgraph", ts: $ts })',
94
+ {"ts": datetime.now().isoformat()}
95
  )
96
+
97
+ return f"✅ 连接正常 | 状态: {health_status}"
98
+
99
  except Exception as e:
100
+ logging.error(f"连接失败: {str(e)}", exc_info=True)
101
+ return f"连接失败: {str(e)}"
102
  finally:
103
+ connection.close()
104
+
 
105
 
106
  # 配置日志
107
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
864
 
865
  with gr.Tab("📄 文本分析"):
866
  input_text = gr.Textbox(lines=6, label="输入文本")
867
+ entity1 = gr.Textbox(label="实体1")
868
+ entity2 = gr.Textbox(label="实体2")
869
+ btn = gr.Button("提取关系并保存到 Memgraph")
870
+ output = gr.Textbox(label="结果")
871
+ btn.click(fn=extract_and_save_relations, inputs=[input_text, entity1, entity2], outputs=output)
 
 
872
 
873
  with gr.Tab("🗂 文件分析"):
874
  file_input = gr.File(file_types=[".txt", ".json"])