Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Upload 3 files
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -508,68 +508,59 @@ def process_text(text, model_type="bert"): 
     | 
|
| 508 | 
         
             
                return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
         
     | 
| 509 | 
         | 
| 510 | 
         
             
            # ======================== 知识图谱可视化 ========================
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 511 | 
         
             
            def generate_kg_image(entities, relations):
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 512 | 
         
             
                try:
         
     | 
| 513 | 
         
            -
                     
     | 
| 514 | 
         
            -
                     
     | 
| 515 | 
         
            -
                     
     | 
| 516 | 
         
            -
             
     | 
| 517 | 
         
            -
             
     | 
| 518 | 
         
            -
                     
     | 
| 519 | 
         
            -
             
     | 
| 520 | 
         
            -
             
     | 
| 521 | 
         
            -
             
     | 
| 522 | 
         
            -
             
     | 
| 523 | 
         
            -
             
     | 
| 524 | 
         
            -
                        matplotlib.font_manager._rebuild()
         
     | 
| 525 | 
         
            -
                        
         
     | 
| 526 | 
         
            -
                        # 查找所有可用字体
         
     | 
| 527 | 
         
            -
                        font_paths = []
         
     | 
| 528 | 
         
            -
                        for font in font_manager.fontManager.ttflist:
         
     | 
| 529 | 
         
            -
                            if 'Noto Sans CJK' in font.name or 'SimHei' in font.name:
         
     | 
| 530 | 
         
            -
                                font_paths.append(font.fname)
         
     | 
| 531 | 
         
            -
                        
         
     | 
| 532 | 
         
            -
                        # 优先使用Noto Sans CJK SC
         
     | 
| 533 | 
         
            -
                        selected_font = None
         
     | 
| 534 | 
         
            -
                        for font_path in font_paths:
         
     | 
| 535 | 
         
            -
                            if 'NotoSansCJKsc' in font_path:
         
     | 
| 536 | 
         
            -
                                selected_font = font_path
         
     | 
| 537 | 
         
            -
                                break
         
     | 
| 538 | 
         
            -
                        
         
     | 
| 539 | 
         
            -
                        if selected_font:
         
     | 
| 540 | 
         
            -
                            font_prop = font_manager.FontProperties(fname=selected_font)
         
     | 
| 541 | 
         
            -
                            plt.rcParams['font.family'] = font_prop.get_name()
         
     | 
| 542 | 
         
            -
                            print(f"使用字体: {selected_font}")
         
     | 
| 543 | 
         
            -
                        else:
         
     | 
| 544 | 
         
            -
                            # 方法2:使用绝对路径字体(上传到项目)
         
     | 
| 545 | 
         
            -
                            font_path = os.path.join(os.path.dirname(__file__), 'fonts', 'SimHei.ttf')
         
     | 
| 546 | 
         
            -
                            if os.path.exists(font_path):
         
     | 
| 547 | 
         
            -
                                font_prop = font_manager.FontProperties(fname=font_path)
         
     | 
| 548 | 
         
            -
                                plt.rcParams['font.family'] = font_prop.get_name()
         
     | 
| 549 | 
         
            -
                                print(f"使用本地字体: {font_path}")
         
     | 
| 550 | 
         
            -
                            else:
         
     | 
| 551 | 
         
            -
                                # 方法3:最后手段 - 使用系统默认字体
         
     | 
| 552 | 
         
            -
                                plt.rcParams['font.family'] = ['DejaVu Sans']
         
     | 
| 553 | 
         
            -
                                print("警告:使用默认字体,中文可能显示为方框")
         
     | 
| 554 | 
         
            -
                        
         
     | 
| 555 | 
         
            -
                        plt.rcParams['axes.unicode_minus'] = False
         
     | 
| 556 | 
         
            -
                    except Exception as e:
         
     | 
| 557 | 
         
            -
                        print(f"字体配置错误: {e}")
         
     | 
| 558 | 
         | 
| 559 | 
         
             
                    # === 2. 创建图谱 ===
         
     | 
| 560 | 
         
             
                    G = nx.DiGraph()
         
     | 
| 561 | 
         
             
                    entity_colors = {
         
     | 
| 562 | 
         
            -
                        'PER': '#FF6B6B',  #  
     | 
| 563 | 
         
            -
                        'ORG': '#4ECDC4',  #  
     | 
| 564 | 
         
            -
                        'LOC': '#45B7D1',  #  
     | 
| 565 | 
         
            -
                        'TIME': '#96CEB4' 
     | 
| 566 | 
         
             
                    }
         
     | 
| 567 | 
         | 
| 568 | 
         
            -
                    # 简化标签(只显示中文名称)
         
     | 
| 569 | 
         
             
                    for entity in entities:
         
     | 
| 570 | 
         
             
                        G.add_node(
         
     | 
| 571 | 
         
             
                            entity["text"],
         
     | 
| 572 | 
         
            -
                            label=entity[ 
     | 
| 573 | 
         
             
                            color=entity_colors.get(entity['type'], '#D3D3D3')
         
     | 
| 574 | 
         
             
                        )
         
     | 
| 575 | 
         | 
| 
         @@ -582,14 +573,14 @@ def generate_kg_image(entities, relations): 
     | 
|
| 582 | 
         
             
                            )
         
     | 
| 583 | 
         | 
| 584 | 
         
             
                    # === 3. 绘图配置 ===
         
     | 
| 585 | 
         
            -
                    plt.figure(figsize=(12, 8), dpi= 
     | 
| 586 | 
         
            -
                    pos = nx.spring_layout(G, k=0. 
     | 
| 587 | 
         | 
| 588 | 
         
            -
                    # 绘制元素
         
     | 
| 589 | 
         
             
                    nx.draw_networkx_nodes(
         
     | 
| 590 | 
         
             
                        G, pos,
         
     | 
| 591 | 
         
             
                        node_color=[G.nodes[n]['color'] for n in G.nodes],
         
     | 
| 592 | 
         
            -
                        node_size=800
         
     | 
| 
         | 
|
| 593 | 
         
             
                    )
         
     | 
| 594 | 
         
             
                    nx.draw_networkx_edges(
         
     | 
| 595 | 
         
             
                        G, pos,
         
     | 
| 
         @@ -599,35 +590,40 @@ def generate_kg_image(entities, relations): 
     | 
|
| 599 | 
         
             
                        arrowsize=20
         
     | 
| 600 | 
         
             
                    )
         
     | 
| 601 | 
         | 
| 602 | 
         
            -
                     
     | 
| 603 | 
         
            -
                     
     | 
| 604 | 
         
            -
                         
     | 
| 605 | 
         
            -
             
     | 
| 606 | 
         
            -
             
     | 
| 607 | 
         
            -
             
     | 
| 608 | 
         
            -
             
     | 
| 609 | 
         
            -
             
     | 
| 610 | 
         | 
| 611 | 
         
            -
                    # 边标签(英文关系)
         
     | 
| 612 | 
         
             
                    edge_labels = nx.get_edge_attributes(G, 'label')
         
     | 
| 613 | 
         
             
                    nx.draw_networkx_edge_labels(
         
     | 
| 614 | 
         
             
                        G, pos,
         
     | 
| 615 | 
         
             
                        edge_labels=edge_labels,
         
     | 
| 616 | 
         
            -
                        font_size=8
         
     | 
| 
         | 
|
| 617 | 
         
             
                    )
         
     | 
| 618 | 
         | 
| 619 | 
         
             
                    plt.axis('off')
         
     | 
| 620 | 
         
            -
                    
         
     | 
| 621 | 
         
            -
             
     | 
| 622 | 
         
            -
                     
     | 
| 
         | 
|
| 623 | 
         
             
                    output_path = os.path.join(temp_dir, "kg.png")
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 624 | 
         
             
                    plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
         
     | 
| 625 | 
         
             
                    plt.close()
         
     | 
| 626 | 
         
            -
             
     | 
| 627 | 
         
             
                    return output_path
         
     | 
| 628 | 
         | 
| 629 | 
         
             
                except Exception as e:
         
     | 
| 630 | 
         
            -
                     
     | 
| 631 | 
         
             
                    return None
         
     | 
| 632 | 
         | 
| 633 | 
         | 
| 
         | 
|
| 508 | 
         
             
                return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
         
     | 
| 509 | 
         | 
| 510 | 
         
             
            # ======================== 知识图谱可视化 ========================
         
     | 
| 511 | 
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 512 | 
         
            +
            import networkx as nx
         
     | 
| 513 | 
         
            +
            import tempfile
         
     | 
| 514 | 
         
            +
            import os
         
     | 
| 515 | 
         
            +
            import logging
         
     | 
| 516 | 
         
            +
            from matplotlib import font_manager
         
     | 
| 517 | 
         
            +
             
     | 
| 518 | 
         
            +
            # 这个函数用于查找并验证中文字体路径
         
     | 
| 519 | 
         
            +
            def find_chinese_font(font_paths=None):
                """Return the path of the first installed CJK-capable font, or None.

                Args:
                    font_paths: Optional iterable of candidate font file paths, checked
                        in order. When omitted, a built-in list of common Debian/Ubuntu
                        locations is used.

                Returns:
                    The first candidate path that exists on disk, or ``None`` when
                    none of the candidates exist (an error is logged in that case).
                """
                if font_paths is None:
                    font_paths = [
                        # Debian/Ubuntu `fonts-noto-cjk` ships its .ttc files under
                        # opentype/, not truetype/, so this location is checked first.
                        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
                        # Original location, kept for images that place the font here.
                        "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc",
                        # WenQuanYi Micro Hei (package `fonts-wqy-microhei`).
                        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
                    ]

                for candidate in font_paths:
                    if os.path.exists(candidate):
                        logging.info("Found font at %s", candidate)
                        return candidate

                logging.error("No Chinese font found!")
                return None
         
     | 
| 533 | 
         
            +
             
     | 
| 534 | 
         
             
            def generate_kg_image(entities, relations):
         
     | 
| 535 | 
         
            +
                """
         
     | 
| 536 | 
         
            +
                中文知识图谱生成函数,支持自动匹配系统中的中文字体,避免中文显示为方框。
         
     | 
| 537 | 
         
            +
                """
         
     | 
| 538 | 
         
             
                try:
         
     | 
| 539 | 
         
            +
                    # === 1. 确保使用合适的中文字体 ===
         
     | 
| 540 | 
         
            +
                    chinese_font = find_chinese_font()  # 调用查找字体函数
         
     | 
| 541 | 
         
            +
                    if chinese_font:
         
     | 
| 542 | 
         
            +
                        font_prop = font_manager.FontProperties(fname=chinese_font)
         
     | 
| 543 | 
         
            +
                        plt.rcParams['font.family'] = font_prop.get_name()
         
     | 
| 544 | 
         
            +
                    else:
         
     | 
| 545 | 
         
            +
                        # 如果字体路径未找到,使用默认字体(DejaVu Sans)
         
     | 
| 546 | 
         
            +
                        logging.warning("Using default font")
         
     | 
| 547 | 
         
            +
                        plt.rcParams['font.family'] = ['DejaVu Sans']
         
     | 
| 548 | 
         
            +
             
     | 
| 549 | 
         
            +
                    plt.rcParams['axes.unicode_minus'] = False
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 550 | 
         | 
| 551 | 
         
             
                    # === 2. 创建图谱 ===
         
     | 
| 552 | 
         
             
                    G = nx.DiGraph()
         
     | 
| 553 | 
         
             
                    entity_colors = {
         
     | 
| 554 | 
         
            +
                        'PER': '#FF6B6B',  # 人物-红色
         
     | 
| 555 | 
         
            +
                        'ORG': '#4ECDC4',  # 组织-青色
         
     | 
| 556 | 
         
            +
                        'LOC': '#45B7D1',  # 地点-蓝色
         
     | 
| 557 | 
         
            +
                        'TIME': '#96CEB4'  # 时间-绿色
         
     | 
| 558 | 
         
             
                    }
         
     | 
| 559 | 
         | 
| 
         | 
|
| 560 | 
         
             
                    for entity in entities:
         
     | 
| 561 | 
         
             
                        G.add_node(
         
     | 
| 562 | 
         
             
                            entity["text"],
         
     | 
| 563 | 
         
            +
                            label=f"{entity['text']} ({entity['type']})",
         
     | 
| 564 | 
         
             
                            color=entity_colors.get(entity['type'], '#D3D3D3')
         
     | 
| 565 | 
         
             
                        )
         
     | 
| 566 | 
         | 
| 
         | 
|
| 573 | 
         
             
                            )
         
     | 
| 574 | 
         | 
| 575 | 
         
             
                    # === 3. 绘图配置 ===
         
     | 
| 576 | 
         
            +
                    plt.figure(figsize=(12, 8), dpi=150)
         
     | 
| 577 | 
         
            +
                    pos = nx.spring_layout(G, k=0.7, seed=42)
         
     | 
| 578 | 
         | 
| 
         | 
|
| 579 | 
         
             
                    nx.draw_networkx_nodes(
         
     | 
| 580 | 
         
             
                        G, pos,
         
     | 
| 581 | 
         
             
                        node_color=[G.nodes[n]['color'] for n in G.nodes],
         
     | 
| 582 | 
         
            +
                        node_size=800,
         
     | 
| 583 | 
         
            +
                        alpha=0.9
         
     | 
| 584 | 
         
             
                    )
         
     | 
| 585 | 
         
             
                    nx.draw_networkx_edges(
         
     | 
| 586 | 
         
             
                        G, pos,
         
     | 
| 
         | 
|
| 590 | 
         
             
                        arrowsize=20
         
     | 
| 591 | 
         
             
                    )
         
     | 
| 592 | 
         | 
| 593 | 
         
            +
                    node_labels = {n: G.nodes[n]['label'] for n in G.nodes}
         
     | 
| 594 | 
         
            +
                    nx.draw_networkx_labels(
         
     | 
| 595 | 
         
            +
                        G, pos,
         
     | 
| 596 | 
         
            +
                        labels=node_labels,
         
     | 
| 597 | 
         
            +
                        font_size=10,
         
     | 
| 598 | 
         
            +
                        font_family=font_prop.get_name() if chinese_font else 'SimHei',
         
     | 
| 599 | 
         
            +
                        font_weight='bold'
         
     | 
| 600 | 
         
            +
                    )
         
     | 
| 601 | 
         | 
| 
         | 
|
| 602 | 
         
             
                    edge_labels = nx.get_edge_attributes(G, 'label')
         
     | 
| 603 | 
         
             
                    nx.draw_networkx_edge_labels(
         
     | 
| 604 | 
         
             
                        G, pos,
         
     | 
| 605 | 
         
             
                        edge_labels=edge_labels,
         
     | 
| 606 | 
         
            +
                        font_size=8,
         
     | 
| 607 | 
         
            +
                        font_family=font_prop.get_name() if chinese_font else 'SimHei'
         
     | 
| 608 | 
         
             
                    )
         
     | 
| 609 | 
         | 
| 610 | 
         
             
                    plt.axis('off')
         
     | 
| 611 | 
         
            +
                    plt.tight_layout()
         
     | 
| 612 | 
         
            +
             
     | 
| 613 | 
         
            +
                    # === 4. 保存图片 ===
         
     | 
| 614 | 
         
            +
                    temp_dir = tempfile.mkdtemp()  # 确保在 Docker 容器中有权限写入
         
     | 
| 615 | 
         
             
                    output_path = os.path.join(temp_dir, "kg.png")
         
     | 
| 616 | 
         
            +
             
     | 
| 617 | 
         
            +
                    # 打印路径以方便调试
         
     | 
| 618 | 
         
            +
                    logging.info(f"Saving graph image to {output_path}")
         
     | 
| 619 | 
         
            +
             
     | 
| 620 | 
         
             
                    plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
         
     | 
| 621 | 
         
             
                    plt.close()
         
     | 
| 622 | 
         
            +
             
     | 
| 623 | 
         
             
                    return output_path
         
     | 
| 624 | 
         | 
| 625 | 
         
             
                except Exception as e:
         
     | 
| 626 | 
         
            +
                    logging.error(f"[ERROR] 图谱生成失败: {str(e)}")
         
     | 
| 627 | 
         
             
                    return None
         
     | 
| 628 | 
         | 
| 629 | 
         | 
    	
        apt.txt
    ADDED
    
    | 
         @@ -0,0 +1 @@ 
     | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            fonts-noto-cjk
         
     |