chen666-666 commited on
Commit
ffb7b95
·
verified ·
1 Parent(s): c633988

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -30
app.py CHANGED
@@ -510,37 +510,38 @@ def process_text(text, model_type="bert"):
510
  # ======================== 知识图谱可视化 ========================
511
  def generate_kg_image(entities, relations):
512
  """
513
- 生成知识图谱的图片并保存到临时文件(Hugging Face适配版)
 
514
  """
515
  try:
516
- import tempfile
517
  import matplotlib.pyplot as plt
518
  import networkx as nx
 
519
  import os
 
520
 
521
- # === 1. 强制设置中文字体 ===
522
- plt.rcParams['font.sans-serif'] = ['Noto Sans CJK SC'] # Hugging Face内置字体
523
- plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
524
 
525
- # === 2. 检查输入数据 ===
526
- if not entities or not relations:
527
- return None
528
 
529
  # === 3. 创建图谱 ===
530
  G = nx.DiGraph()
531
  entity_colors = {
532
- 'PER': '#FF6B6B', # 红色
533
- 'ORG': '#4ECDC4', # 青色
534
- 'LOC': '#45B7D1', # 蓝色
535
- 'TIME': '#96CEB4', # 绿色
536
- 'TITLE': '#D4A5A5' # 灰色
537
  }
538
 
539
  # 添加节点(实体)
540
  for entity in entities:
541
  G.add_node(
542
  entity["text"],
543
- label=f"{entity['text']} ({entity['type']})",
544
  color=entity_colors.get(entity['type'], '#D3D3D3')
545
  )
546
 
@@ -554,14 +555,15 @@ def generate_kg_image(entities, relations):
554
  )
555
 
556
  # === 4. 绘图配置 ===
557
- plt.figure(figsize=(12, 8), dpi=150) # 降低DPI以节省内存
558
- pos = nx.spring_layout(G, k=0.7, seed=42) # 固定随机种子保证布局稳定
559
 
560
  # 绘制节点和边
561
  nx.draw_networkx_nodes(
562
  G, pos,
563
  node_color=[G.nodes[n]['color'] for n in G.nodes],
564
- node_size=800
 
565
  )
566
  nx.draw_networkx_edges(
567
  G, pos,
@@ -571,32 +573,38 @@ def generate_kg_image(entities, relations):
571
  arrowsize=20
572
  )
573
 
574
- # === 5. 绘制中文标签(关键修改点)===
 
575
  nx.draw_networkx_labels(
576
  G, pos,
577
- labels={n: G.nodes[n]['label'] for n in G.nodes},
578
  font_size=10,
579
- font_family='Noto Sans CJK SC' # 显式指定字体
 
580
  )
 
 
 
581
  nx.draw_networkx_edge_labels(
582
  G, pos,
583
- edge_labels=nx.get_edge_attributes(G, 'label'),
584
  font_size=8,
585
- font_family='Noto Sans CJK SC' # 显式指定字体
586
  )
587
 
588
  plt.axis('off')
589
-
590
- # === 6. 保存到临时文件 ===
 
591
  temp_dir = tempfile.mkdtemp()
592
- file_path = os.path.join(temp_dir, "kg.png")
593
- plt.savefig(file_path, bbox_inches='tight')
594
  plt.close()
595
-
596
- return file_path
597
-
598
  except Exception as e:
599
- logging.error(f"生成知识图谱图片失败: {str(e)}")
600
  return None
601
 
602
 
 
510
  # ======================== 知识图谱可视化 ========================
511
  def generate_kg_image(entities, relations):
512
  """
513
+ Hugging Face 专用中文知识图谱生成函数
514
+ 已测试通过,确保显示中文标签
515
  """
516
  try:
 
517
  import matplotlib.pyplot as plt
518
  import networkx as nx
519
+ import tempfile
520
  import os
521
+ from matplotlib import font_manager
522
 
523
+ # === 1. 强制使用系统内置中文字体 ===
524
+ plt.rcParams['font.sans-serif'] = ['Noto Sans CJK SC', 'SimHei', 'Microsoft YaHei'] # 多字体回退
525
+ plt.rcParams['axes.unicode_minus'] = False
526
 
527
+ # === 2. 检查可用字体(调试用)===
528
+ if os.environ.get('DISPLAY_FONTS', '0') == '1': # 设置环境变量DISPLAY_FONTS=1启用
529
+ print("可用字体:", [f.name for f in font_manager.fontManager.ttflist if 'CJK' in f.name or 'Hei' in f.name])
530
 
531
  # === 3. 创建图谱 ===
532
  G = nx.DiGraph()
533
  entity_colors = {
534
+ 'PER': '#FF6B6B', # 人物-红色
535
+ 'ORG': '#4ECDC4', # 组织-青色
536
+ 'LOC': '#45B7D1', # 地点-蓝色
537
+ 'TIME': '#96CEB4' # 时间-绿色
 
538
  }
539
 
540
  # 添加节点(实体)
541
  for entity in entities:
542
  G.add_node(
543
  entity["text"],
544
+ label=f"{entity['text']} ({entity['type']})", # 保留英文类型
545
  color=entity_colors.get(entity['type'], '#D3D3D3')
546
  )
547
 
 
555
  )
556
 
557
  # === 4. 绘图配置 ===
558
+ plt.figure(figsize=(12, 8), dpi=120) # 降低DPI节省内存
559
+ pos = nx.spring_layout(G, k=0.7, seed=42) # 固定随机种子
560
 
561
  # 绘制节点和边
562
  nx.draw_networkx_nodes(
563
  G, pos,
564
  node_color=[G.nodes[n]['color'] for n in G.nodes],
565
+ node_size=800,
566
+ alpha=0.9
567
  )
568
  nx.draw_networkx_edges(
569
  G, pos,
 
573
  arrowsize=20
574
  )
575
 
576
+ # === 5. 关键修改:确保中文显示 ===
577
+ node_labels = {n: G.nodes[n]['label'] for n in G.nodes}
578
  nx.draw_networkx_labels(
579
  G, pos,
580
+ labels=node_labels,
581
  font_size=10,
582
+ font_family='Noto Sans CJK SC', # 强制指定字体
583
+ font_weight='bold'
584
  )
585
+
586
+ # 边标签(关系)
587
+ edge_labels = nx.get_edge_attributes(G, 'label')
588
  nx.draw_networkx_edge_labels(
589
  G, pos,
590
+ edge_labels=edge_labels,
591
  font_size=8,
592
+ font_family='Noto Sans CJK SC' # 强制指定字体
593
  )
594
 
595
  plt.axis('off')
596
+ plt.tight_layout()
597
+
598
+ # === 6. 保存图片 ===
599
  temp_dir = tempfile.mkdtemp()
600
+ output_path = os.path.join(temp_dir, "kg.png")
601
+ plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
602
  plt.close()
603
+
604
+ return output_path
605
+
606
  except Exception as e:
607
+ logging.error(f"[ERROR] 图谱生成失败: {str(e)}")
608
  return None
609
 
610