chen666-666 commited on
Commit
81c0f3c
·
verified ·
1 Parent(s): f4befb6

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -18
app.py CHANGED
@@ -508,34 +508,38 @@ def process_text(text, model_type="bert"):
508
  return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
509
 
510
 
511
- from matplotlib.figure import Figure
512
- from io import BytesIO
513
 
514
  def generate_kg_image(entities, relations):
515
  """
516
- 生成知识图谱的图片
517
  """
518
- fig = Figure(figsize=(10, 8))
519
- ax = fig.add_subplot(1, 1, 1)
520
- ax.axis("off")
521
-
522
  # 创建网络图
523
  G = nx.DiGraph()
524
  for entity in entities:
525
  G.add_node(entity["text"], label=entity["type"])
526
  for relation in relations:
527
  G.add_edge(relation["head"], relation["tail"], label=relation["relation"])
528
-
529
  # 绘制网络图
530
  pos = nx.spring_layout(G)
531
- nx.draw(G, pos, with_labels=True, node_color="lightblue", ax=ax)
532
- nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): d["label"] for u, v, d in G.edges(data=True)}, ax=ax)
533
-
534
- # 保存为图片
535
- buf = BytesIO()
536
- fig.savefig(buf, format="png")
537
- buf.seek(0)
538
- return buf
 
 
 
 
 
 
539
 
540
  def process_file(file, model_type="bert"):
541
  try:
@@ -566,9 +570,9 @@ def process_file(file, model_type="bert"):
566
  # 生成知识图谱图片
567
  entities, _ = ner(text, model_type)
568
  relations, _ = re_extract(entities, text)
569
- kg_image = generate_kg_image(entities, relations)
570
 
571
- return ent_text, rel_text, kg_text, duration, kg_image
572
 
573
  except Exception as e:
574
  logging.error(f"文件处理错误: {str(e)}")
 
508
  return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
509
 
510
 
511
+ import os
 
512
 
513
  def generate_kg_image(entities, relations):
514
  """
515
+ 生成知识图谱的图片并保存到文件
516
  """
517
+ import matplotlib.pyplot as plt
518
+ import networkx as nx
519
+
 
520
  # 创建网络图
521
  G = nx.DiGraph()
522
  for entity in entities:
523
  G.add_node(entity["text"], label=entity["type"])
524
  for relation in relations:
525
  G.add_edge(relation["head"], relation["tail"], label=relation["relation"])
526
+
527
  # 绘制网络图
528
  pos = nx.spring_layout(G)
529
+ plt.figure(figsize=(10, 8))
530
+ nx.draw(G, pos, with_labels=True, node_color="lightblue")
531
+ nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): d["label"] for u, v, d in G.edges(data=True)})
532
+
533
+ # 确保保存目录存在
534
+ save_dir = "D:/image"
535
+ if not os.path.exists(save_dir):
536
+ os.makedirs(save_dir) # 自动创建文件夹
537
+
538
+ # 保存图片到指定目录
539
+ file_path = os.path.join(save_dir, f"knowledge_graph_{int(time.time())}.png")
540
+ plt.savefig(file_path)
541
+ plt.close()
542
+ return file_path
543
 
544
  def process_file(file, model_type="bert"):
545
  try:
 
570
  # 生成知识图谱图片
571
  entities, _ = ner(text, model_type)
572
  relations, _ = re_extract(entities, text)
573
+ kg_image_path = generate_kg_image(entities, relations) # 返回文件路径
574
 
575
+ return ent_text, rel_text, kg_text, duration, kg_image_path
576
 
577
  except Exception as e:
578
  logging.error(f"文件处理错误: {str(e)}")