Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Commit 
							
							·
						
						d65f85e
	
1
								Parent(s):
							
							d5e2274
								
Add Gradio app for NER + RE
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,31 +1,80 @@ | |
| 1 | 
            -
            import torch
         | 
| 2 | 
            -
            from transformers import BertTokenizer, BertModel
         | 
| 3 | 
             
            import gradio as gr
         | 
| 4 | 
            -
            import  | 
| 5 | 
            -
            import  | 
| 6 | 
            -
            import json
         | 
| 7 | 
            -
            import pandas as pd
         | 
| 8 | 
            -
            import chardet
         | 
| 9 | 
            -
            from pyvis.network import Network
         | 
| 10 | 
            -
            import networkx as nx
         | 
| 11 | 
             
            from pathlib import Path
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 12 |  | 
| 13 | 
            -
            #  | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
            model = BertModel.from_pretrained(model_name)
         | 
| 17 |  | 
| 18 | 
            -
            #  | 
| 19 | 
             
            knowledge_graph = {
         | 
| 20 | 
            -
                "entities":  | 
| 21 | 
             
                "relations": []
         | 
| 22 | 
             
            }
         | 
| 23 |  | 
| 24 | 
            -
            def  | 
| 25 | 
            -
                 | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 29 |  | 
| 30 | 
             
            def visualize_kg():
         | 
| 31 | 
             
                net = Network(height="600px", width="100%", notebook=True, directed=True)
         | 
| @@ -52,159 +101,31 @@ def visualize_kg(): | |
| 52 | 
             
                }
         | 
| 53 | 
             
                """)
         | 
| 54 |  | 
| 55 | 
            -
                 | 
| 56 | 
            -
                 | 
| 57 | 
            -
             | 
| 58 | 
            -
             | 
| 59 | 
            -
                 | 
| 60 | 
            -
                return  | 
| 61 | 
            -
             | 
| 62 | 
            -
            #  | 
| 63 | 
            -
             | 
| 64 | 
            -
                 | 
| 65 | 
            -
                 | 
| 66 | 
            -
             | 
| 67 | 
            -
             | 
| 68 | 
            -
                 | 
| 69 | 
            -
                     | 
| 70 | 
            -
             | 
| 71 | 
            -
             | 
| 72 | 
            -
             | 
| 73 | 
            -
             | 
| 74 | 
            -
             | 
| 75 | 
            -
             | 
| 76 | 
            -
             | 
| 77 | 
            -
             | 
| 78 | 
            -
             | 
| 79 | 
            -
             | 
| 80 | 
            -
             | 
| 81 | 
            -
                            "end": match.end(),
         | 
| 82 | 
            -
                            "type": "UserID"
         | 
| 83 | 
            -
                        })
         | 
| 84 | 
            -
             | 
| 85 | 
            -
                return sorted(entities, key=lambda x: x["start"])
         | 
| 86 | 
            -
             | 
| 87 | 
            -
            def re_extract(entities, text):
                """Derive one relation for every pair of neighbouring entities.

                The text lying between two consecutive entities is scanned for
                trigger keywords. The first rule that matches decides the relation
                label; "knows" is used when no keyword is present.

                Args:
                    entities: Entity dicts carrying "text", "start" and "end" keys,
                        in the order in which they should be paired.
                    text: The string that the "start"/"end" offsets index into.

                Returns:
                    A list of {"head", "tail", "relation"} dicts; empty when fewer
                    than two entities are supplied.
                """
                # Rules are checked top to bottom, so earlier entries win.
                keyword_rules = (
                    (("推荐", "找"), "recommend"),
                    (("发送", "发给"), "send_to"),
                    (("提到", "说"), "mention"),
                )
                found = []
                for left, right in zip(entities, entities[1:]):
                    between = text[left["end"]:right["start"]]
                    label = "knows"
                    for keywords, candidate in keyword_rules:
                        if any(word in between for word in keywords):
                            label = candidate
                            break
                    found.append({
                        "head": left["text"],
                        "tail": right["text"],
                        "relation": label,
                    })
                return found
         | 
| 108 | 
            -
             | 
| 109 | 
            -
            # ----------- 文本处理逻辑 -----------------
         | 
| 110 | 
            -
            def process_text(text):
                """Run entity and relation extraction on text and render the graph.

                Calls ner() and re_extract(), passes both results to
                update_knowledge_graph(), then calls visualize_kg().

                Args:
                    text: The raw text to analyse.

                Returns:
                    A 3-tuple of (entity listing, relation listing, graph HTML);
                    both listings hold one item per line.
                """
                found_entities = ner(text)
                found_relations = re_extract(found_entities, text)
                update_knowledge_graph(found_entities, found_relations)
                graph_markup = visualize_kg()

                entity_lines = (
                    f"{e['text']} ({e['type']}) [{e['start']}, {e['end']}]"
                    for e in found_entities
                )
                relation_lines = (
                    f"{r['head']} --[{r['relation']}]-> {r['tail']}"
                    for r in found_relations
                )
                return "\n".join(entity_lines), "\n".join(relation_lines), graph_markup
         | 
| 118 | 
            -
             | 
| 119 | 
            -
            # ----------- 文件上传处理逻辑 -----------------
         | 
| 120 | 
            -
            def detect_encoding(file_path):
                """Guess a file's text encoding from its first 4096 bytes.

                Args:
                    file_path: Path of the file to sample.

                Returns:
                    The encoding name reported by chardet, or 'utf-8' when chardet
                    reports none.
                """
                with open(file_path, 'rb') as handle:
                    sample = handle.read(4096)
                guessed = chardet.detect(sample)['encoding']
                return guessed or 'utf-8'
         | 
| 125 | 
            -
             | 
| 126 | 
            -
            def process_file(file):
                """Read an uploaded chat-log file and run the text pipeline on it.

                Supported extensions are .txt, .jsonl, .json and .csv. For the
                structured formats the value of each record's "text" field (or the
                "text" column for CSV) is collected and joined with newlines.

                Args:
                    file: Uploaded-file object exposing its local path as
                        ``file.name``, or None when nothing was uploaded.

                Returns:
                    A 3-tuple (entity text, relation text, graph HTML). On failure
                    the first item is a "❌ ..." message and the other two are empty
                    strings. When JSONL lines were skipped, a warning is prepended
                    to the entity text.
                """
                # Nothing uploaded: report it instead of failing on ``file.name``.
                if file is None:
                    return "❌ 请先上传文件", "", ""

                ext = os.path.splitext(file.name)[-1].lower()
                # Reject unsupported types before touching the file, so the user
                # sees the format error rather than an unrelated read error.
                if ext not in (".txt", ".jsonl", ".json", ".csv"):
                    return f"❌ 不支持的文件格式:{ext}", "", ""

                full_text = ""
                warning = ""

                try:
                    encoding = detect_encoding(file.name)

                    if ext == ".txt":
                        with open(file.name, "r", encoding=encoding) as f:
                            full_text = f.read()

                    elif ext == ".jsonl":
                        texts = []
                        skipped_lines = []
                        with open(file.name, "r", encoding=encoding) as f:
                            for i, line in enumerate(f, start=1):
                                try:
                                    obj = json.loads(line)
                                except ValueError:
                                    # Invalid JSON (JSONDecodeError is a ValueError).
                                    skipped_lines.append(i)
                                    continue
                                if not isinstance(obj, dict):
                                    # Valid JSON but not a record with fields.
                                    skipped_lines.append(i)
                                    continue
                                # Coerce to str so one non-string value cannot
                                # break the join below (matches the .json branch).
                                texts.append(str(obj.get("text", "")))
                        full_text = "\n".join(texts)
                        if skipped_lines:
                            warning = f"⚠️ 跳过 {len(skipped_lines)} 行无效 JSON(如第 {skipped_lines[0]} 行)\n\n"

                    elif ext == ".json":
                        with open(file.name, "r", encoding=encoding) as f:
                            data = json.load(f)
                        if isinstance(data, list):
                            full_text = "\n".join([str(item.get("text", "")) for item in data])
                        elif isinstance(data, dict):
                            full_text = str(data.get("text", ""))
                        else:
                            return "❌ JSON 文件格式无法解析", "", ""

                    else:  # ".csv" -- the only extension left after the check above
                        df = pd.read_csv(file.name, encoding=encoding)
                        if "text" in df.columns:
                            full_text = "\n".join(df["text"].astype(str))
                        else:
                            return "❌ CSV 中未找到 'text' 列", "", ""

                except Exception as e:
                    # UI boundary: surface any read/parse failure as a message.
                    return f"❌ 文件读取错误:{str(e)}", "", ""

                entity_out, relation_out, kg_html = process_text(full_text)
                return warning + entity_out, relation_out, kg_html
         | 
| 178 | 
            -
             | 
| 179 | 
            -
            # ----------- Gradio UI -----------------
         | 
| 180 | 
            -
            # Three-tab Gradio interface: free-text analysis, file upload analysis,
            # and a view of the knowledge graph returned by visualize_kg().
            with gr.Blocks() as demo:
                gr.Markdown("""# 📱 微信聊天记录分析系统  
                功能包括:实体识别(NER)、关系抽取(RE)和知识图谱可视化""")

                # Tab 1: analyse text typed directly into a textbox.
                with gr.Tab("✍️ 直接输入文本"):
                    with gr.Row():
                        input_text = gr.Textbox(label="输入聊天内容", lines=8, placeholder="请输入中文微信聊天记录")
                    analyze_btn = gr.Button("分析文本")
                    with gr.Row():
                        entity_output1 = gr.Textbox(label="识别出的实体")
                        relation_output1 = gr.Textbox(label="抽取的关系")
                    kg_html1 = gr.HTML(label="知识图谱可视化")
                    # process_text returns (entities, relations, graph HTML) in this order.
                    analyze_btn.click(fn=process_text, inputs=[input_text], outputs=[entity_output1, relation_output1, kg_html1])

                # Tab 2: analyse an uploaded .txt / .jsonl / .json / .csv file.
                with gr.Tab("📁 上传文件"):
                    file_input = gr.File(label="上传聊天记录文件", file_types=[".txt", ".jsonl", ".json", ".csv"])
                    analyze_file_btn = gr.Button("分析文件")
                    with gr.Row():
                        entity_output2 = gr.Textbox(label="识别出的实体")
                        relation_output2 = gr.Textbox(label="抽取的关系")
                    kg_html2 = gr.HTML(label="知识图谱可视化")
                    analyze_file_btn.click(fn=process_file, inputs=[file_input], outputs=[entity_output2, relation_output2, kg_html2])

                # Tab 3: re-render the graph on demand via visualize_kg().
                with gr.Tab("🗺️ 完整知识图谱"):
                    gr.Markdown("## 当前累计构建的知识图谱")
                    refresh_btn = gr.Button("刷新图谱")
                    full_kg = gr.HTML()
                    refresh_btn.click(fn=lambda: visualize_kg(), outputs=full_kg)
         | 
| 208 | 
            -
             | 
| 209 | 
             
            if __name__ == "__main__":
         | 
| 210 | 
             
                demo.launch()
         | 
|  | |
|  | |
|  | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
            +
            from transformers import BertTokenizerFast, BertForTokenClassification, BertForSequenceClassification
         | 
| 3 | 
            +
            import torch
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 4 | 
             
            from pathlib import Path
         | 
| 5 | 
            +
            from pyvis.network import Network
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            # Load the models and tokenizers.
            # NOTE(review): "bert-base-chinese" looks like the base pretrained
            # checkpoint; if so, the token/sequence classification heads created
            # here are newly initialised and their predictions are not meaningful
            # until fine-tuned weights are loaded -- confirm.
            ner_tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")
            ner_model = BertForTokenClassification.from_pretrained("bert-base-chinese", num_labels=10)

            re_model = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=5)
            re_tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")

            # NER tag set and relation types; list positions are used as class ids,
            # so the lengths match num_labels above (10 and 5).
            label_list = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC", "PAD"]
            relation_list = ["no_relation", "per-org", "per-loc", "org-loc", "org-misc"]

            # Module-level storage for the knowledge graph.
            knowledge_graph = {
                "entities": [],
                "relations": []
            }
         | 
| 23 |  | 
| 24 | 
            +
            def ner_predict(text):
                """Tag text with the NER model and return the recognised entities.

                Args:
                    text: The raw text to analyse.

                Returns:
                    A list of (entity_text, entity_type, start, end) tuples, where
                    start/end are positions in the tokenizer's token sequence
                    (end exclusive), not character offsets. Empty or
                    whitespace-only input yields an empty list without running
                    the model.
                """
                if not text or not text.strip():
                    return []

                inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)
                with torch.no_grad():
                    logits = ner_model(**inputs).logits
                label_ids = torch.argmax(logits, dim=2)[0].tolist()
                tokens = ner_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
                predicted_labels = [label_list[label_id] for label_id in label_ids]

                # Structural tokens only; [UNK] is left out on purpose because it
                # stands for real input characters that may belong to an entity.
                structural_tokens = {
                    ner_tokenizer.cls_token,
                    ner_tokenizer.sep_token,
                    ner_tokenizer.pad_token,
                }
                return _decode_bio_entities(tokens, predicted_labels, structural_tokens)


            def _decode_bio_entities(tokens, labels, ignored_tokens=()):
                """Merge BIO-labelled tokens into (text, type, start, end) tuples.

                Tokens listed in ``ignored_tokens`` are treated as outside any
                entity regardless of their predicted label, so they can neither
                start nor extend an entity.
                """
                entities = []
                current_entity = ""
                current_label = ""
                start = None
                for idx, (token, label) in enumerate(zip(tokens, labels)):
                    if token in ignored_tokens:
                        label = "O"
                    if label.startswith("I-") and current_label == label[2:]:
                        # Continuation of the open entity; drop the word-piece marker.
                        current_entity += token.replace("##", "")
                        continue
                    # Any other label closes the entity that is currently open.
                    if current_entity:
                        entities.append((current_entity, current_label, start, idx))
                        current_entity = ""
                        current_label = ""
                    if label.startswith("B-"):
                        current_entity = token.replace("##", "")
                        current_label = label[2:]
                        start = idx
                if current_entity:
                    entities.append((current_entity, current_label, start, len(tokens)))
                return entities
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            def re_predict(text, entities):
                """Classify the relation between every ordered pair of distinct entities.

                Args:
                    text: The source text, appended to each prompt as context.
                    entities: Entity tuples whose first item is the entity text.

                Returns:
                    A list of (head, tail, relation) tuples for every pair whose
                    predicted relation is not "no_relation". Fewer than two
                    distinct entity texts yields an empty list without running
                    the model.
                """
                from itertools import permutations

                # Repeated mentions build identical prompts, and pairing a name
                # with itself is meaningless, so each distinct ordered pair is
                # classified once. dict.fromkeys keeps first-seen order.
                mentions = list(dict.fromkeys(entity[0] for entity in entities))

                relations = []
                for head, tail in permutations(mentions, 2):
                    input_text = f"{head} 和 {tail} 有什么关系?{text}"
                    inputs = re_tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
                    with torch.no_grad():
                        logits = re_model(**inputs).logits
                    relation = relation_list[torch.argmax(logits, dim=1).item()]
                    if relation != "no_relation":
                        relations.append((head, tail, relation))
                return relations
         | 
| 68 | 
            +
             | 
| 69 | 
            +
            def analyze_text(text):
                """Extract entities and relations from text and store them globally.

                Calls ner_predict() and re_predict(), then overwrites the
                "entities" and "relations" entries of the module-level
                knowledge_graph with the results of this call.

                Args:
                    text: The raw text to analyse.

                Returns:
                    A 2-tuple of (entity listing, relation listing), each holding
                    one item per line.
                """
                found_entities = ner_predict(text)
                found_relations = re_predict(text, found_entities)

                entity_lines = []
                for item in found_entities:
                    entity_lines.append(f"{item[0]} ({item[1]}) [{item[2]}, {item[3]}]")
                relation_lines = []
                for triple in found_relations:
                    relation_lines.append(f"{triple[0]} --[{triple[2]}]-> {triple[1]}")

                # Update the global knowledge graph: only (text, type) is kept per entity.
                knowledge_graph["entities"] = [(item[0], item[1]) for item in found_entities]
                knowledge_graph["relations"] = found_relations

                entity_report = "\n".join(entity_lines)
                relation_report = "\n".join(relation_lines)
                return entity_report, relation_report
         | 
| 78 |  | 
| 79 | 
             
            def visualize_kg():
         | 
| 80 | 
             
                net = Network(height="600px", width="100%", notebook=True, directed=True)
         | 
|  | |
| 101 | 
             
                }
         | 
| 102 | 
             
                """)
         | 
| 103 |  | 
| 104 | 
            +
                # 保存 HTML 到 Hugging Face Spaces 可访问路径
         | 
| 105 | 
            +
                file_path = "/home/user/kg.html"
         | 
| 106 | 
            +
                net.save_graph(file_path)
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                # 返回 iframe HTML
         | 
| 109 | 
            +
                return f'<iframe src="/file=kg.html" width="100%" height="600px" frameborder="0"></iframe>'
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            # Build the Gradio interface.
            with gr.Blocks(title="Wechat Ner Re") as demo:
                gr.Markdown("## 微信聊天记录结构化系统(NER + RE + 知识图谱)")
                with gr.Row():
                    input_text = gr.Textbox(lines=5, label="请输入文本")
                    analyze_button = gr.Button("分析文本")
                with gr.Row():
                    ner_output = gr.Textbox(label="识别出的实体")
                    re_output = gr.Textbox(label="抽取的关系")
                # analyze_text returns (entity listing, relation listing) in this order.
                analyze_button.click(analyze_text, inputs=input_text, outputs=[ner_output, re_output])

                # Knowledge-graph display: rendered on demand by visualize_kg().
                gr.Markdown("## 知识图谱可视化")
                with gr.Row():
                    kg_button = gr.Button("生成知识图谱")
                kg_html1 = gr.HTML(label="知识图谱可视化", show_label=True)
                kg_button.click(fn=visualize_kg, outputs=kg_html1)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
            # 启动应用
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 130 | 
             
            if __name__ == "__main__":
         | 
| 131 | 
             
                demo.launch()
         | 
