chen666-666 committed on
Commit
921f671
·
verified ·
1 Parent(s): fb7b761

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -66
app.py CHANGED
@@ -508,59 +508,68 @@ def process_text(text, model_type="bert"):
508
  return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
509
 
510
  # ======================== 知识图谱可视化 ========================
511
- import matplotlib.pyplot as plt
512
- import networkx as nx
513
- import tempfile
514
- import os
515
- import logging
516
- from matplotlib import font_manager
517
-
518
- # 这个函数用于查找并验证中文字体路径
519
- def find_chinese_font():
520
- # 尝试查找 Noto Sans CJK 字体
521
- font_paths = [
522
- "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc", # Noto CJK 字体
523
- "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc" # 文泉驿微米黑 (WenQuanYi Micro Hei)
524
- ]
525
-
526
- for font_path in font_paths:
527
- if os.path.exists(font_path):
528
- logging.info(f"Found font at {font_path}")
529
- return font_path
530
-
531
- logging.error("No Chinese font found!")
532
- return None
533
-
534
  def generate_kg_image(entities, relations):
535
- """
536
- 中文知识图谱生成函数,支持自动匹配系统中的中文字体,避免中文显示为方框。
537
- """
538
  try:
539
- # === 1. 确保使用合适的中文字体 ===
540
- chinese_font = find_chinese_font() # 调用查找字体函数
541
- if chinese_font:
542
- font_prop = font_manager.FontProperties(fname=chinese_font)
543
- plt.rcParams['font.family'] = font_prop.get_name()
544
- else:
545
- # 如果字体路径未找到,使用默认字体(DejaVu Sans)
546
- logging.warning("Using default font")
547
- plt.rcParams['font.family'] = ['DejaVu Sans']
548
-
549
- plt.rcParams['axes.unicode_minus'] = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
 
551
  # === 2. 创建图谱 ===
552
  G = nx.DiGraph()
553
  entity_colors = {
554
- 'PER': '#FF6B6B', # 人物-红色
555
- 'ORG': '#4ECDC4', # 组织-青色
556
- 'LOC': '#45B7D1', # 地点-蓝色
557
- 'TIME': '#96CEB4' # 时间-绿色
558
  }
559
 
 
560
  for entity in entities:
561
  G.add_node(
562
  entity["text"],
563
- label=f"{entity['text']} ({entity['type']})",
564
  color=entity_colors.get(entity['type'], '#D3D3D3')
565
  )
566
 
@@ -573,14 +582,14 @@ def generate_kg_image(entities, relations):
573
  )
574
 
575
  # === 3. 绘图配置 ===
576
- plt.figure(figsize=(12, 8), dpi=150)
577
- pos = nx.spring_layout(G, k=0.7, seed=42)
578
 
 
579
  nx.draw_networkx_nodes(
580
  G, pos,
581
  node_color=[G.nodes[n]['color'] for n in G.nodes],
582
- node_size=800,
583
- alpha=0.9
584
  )
585
  nx.draw_networkx_edges(
586
  G, pos,
@@ -590,40 +599,35 @@ def generate_kg_image(entities, relations):
590
  arrowsize=20
591
  )
592
 
593
- node_labels = {n: G.nodes[n]['label'] for n in G.nodes}
594
- nx.draw_networkx_labels(
595
- G, pos,
596
- labels=node_labels,
597
- font_size=10,
598
- font_family=font_prop.get_name() if chinese_font else 'SimHei',
599
- font_weight='bold'
600
- )
601
 
 
602
  edge_labels = nx.get_edge_attributes(G, 'label')
603
  nx.draw_networkx_edge_labels(
604
  G, pos,
605
  edge_labels=edge_labels,
606
- font_size=8,
607
- font_family=font_prop.get_name() if chinese_font else 'SimHei'
608
  )
609
 
610
  plt.axis('off')
611
- plt.tight_layout()
612
-
613
- # === 4. 保存图片 ===
614
- temp_dir = tempfile.mkdtemp() # 确保在 Docker 容器中有权限写入
615
  output_path = os.path.join(temp_dir, "kg.png")
616
-
617
- # 打印路径以方便调试
618
- logging.info(f"Saving graph image to {output_path}")
619
-
620
  plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
621
  plt.close()
622
-
623
  return output_path
624
 
625
  except Exception as e:
626
- logging.error(f"[ERROR] 图谱生成失败: {str(e)}")
627
  return None
628
 
629
 
 
508
  return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
509
 
510
  # ======================== 知识图谱可视化 ========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
511
  def generate_kg_image(entities, relations):
 
 
 
512
  try:
513
+ import matplotlib.pyplot as plt
514
+ import networkx as nx
515
+ import tempfile
516
+ import os
517
+ from matplotlib import font_manager
518
+ import matplotlib
519
+
520
+ # === 1. 强制字体配置 ===
521
+ # 方法1:尝试使用系统字体
522
+ try:
523
+ # 重建字体缓存
524
+ matplotlib.font_manager._rebuild()
525
+
526
+ # 查找所有可用字体
527
+ font_paths = []
528
+ for font in font_manager.fontManager.ttflist:
529
+ if 'Noto Sans CJK' in font.name or 'SimHei' in font.name:
530
+ font_paths.append(font.fname)
531
+
532
+ # 优先使用Noto Sans CJK SC
533
+ selected_font = None
534
+ for font_path in font_paths:
535
+ if 'NotoSansCJKsc' in font_path:
536
+ selected_font = font_path
537
+ break
538
+
539
+ if selected_font:
540
+ font_prop = font_manager.FontProperties(fname=selected_font)
541
+ plt.rcParams['font.family'] = font_prop.get_name()
542
+ print(f"使用字体: {selected_font}")
543
+ else:
544
+ # 方法2:使用绝对路径字体(上传到项目)
545
+ font_path = os.path.join(os.path.dirname(__file__), 'fonts', 'SimHei.ttf')
546
+ if os.path.exists(font_path):
547
+ font_prop = font_manager.FontProperties(fname=font_path)
548
+ plt.rcParams['font.family'] = font_prop.get_name()
549
+ print(f"使用本地字体: {font_path}")
550
+ else:
551
+ # 方法3:最后手段 - 使用系统默认字体
552
+ plt.rcParams['font.family'] = ['DejaVu Sans']
553
+ print("警告:使用默认字体,中文可能显示为方框")
554
+
555
+ plt.rcParams['axes.unicode_minus'] = False
556
+ except Exception as e:
557
+ print(f"字体配置错误: {e}")
558
 
559
  # === 2. 创建图谱 ===
560
  G = nx.DiGraph()
561
  entity_colors = {
562
+ 'PER': '#FF6B6B', # 红色
563
+ 'ORG': '#4ECDC4', # 青色
564
+ 'LOC': '#45B7D1', # 蓝色
565
+ 'TIME': '#96CEB4' # 绿色
566
  }
567
 
568
+ # 简化标签(只显示中文名称)
569
  for entity in entities:
570
  G.add_node(
571
  entity["text"],
572
+ label=entity["text"], # 只显示中文文本
573
  color=entity_colors.get(entity['type'], '#D3D3D3')
574
  )
575
 
 
582
  )
583
 
584
  # === 3. 绘图配置 ===
585
+ plt.figure(figsize=(12, 8), dpi=120)
586
+ pos = nx.spring_layout(G, k=0.8, seed=42) # 增大k值避免重叠
587
 
588
+ # 绘制元素
589
  nx.draw_networkx_nodes(
590
  G, pos,
591
  node_color=[G.nodes[n]['color'] for n in G.nodes],
592
+ node_size=800
 
593
  )
594
  nx.draw_networkx_edges(
595
  G, pos,
 
599
  arrowsize=20
600
  )
601
 
602
+ # === 4. 手动添加文本标签 ===
603
+ for node, (x, y) in pos.items():
604
+ plt.text(x, y,
605
+ node, # 直接使用中文文本
606
+ fontsize=10,
607
+ ha='center',
608
+ va='center',
609
+ bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))
610
 
611
+ # 边标签(英文关系)
612
  edge_labels = nx.get_edge_attributes(G, 'label')
613
  nx.draw_networkx_edge_labels(
614
  G, pos,
615
  edge_labels=edge_labels,
616
+ font_size=8
 
617
  )
618
 
619
  plt.axis('off')
620
+
621
+ # === 5. 保存图片 ===
622
+ temp_dir = tempfile.mkdtemp()
 
623
  output_path = os.path.join(temp_dir, "kg.png")
 
 
 
 
624
  plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
625
  plt.close()
626
+
627
  return output_path
628
 
629
  except Exception as e:
630
+ print(f"[ERROR] 图谱生成失败: {e}")
631
  return None
632
 
633