chen666-666 committed
Commit d65f85e · 1 Parent(s): d5e2274

Add Gradio app for NER + RE

Files changed (1)
  1. app.py +95 -174
app.py CHANGED
@@ -1,31 +1,80 @@
- import torch
- from transformers import BertTokenizer, BertModel
  import gradio as gr
- import re
- import os
- import json
- import pandas as pd
- import chardet
- from pyvis.network import Network
- import networkx as nx
+ from transformers import BertTokenizerFast, BertForTokenClassification, BertForSequenceClassification
+ import torch
  from pathlib import Path
+ from pyvis.network import Network
+
+ # Load the models and tokenizers
+ ner_tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")
+ ner_model = BertForTokenClassification.from_pretrained("bert-base-chinese", num_labels=10)
+
+ re_model = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=5)
+ re_tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")

- # Initialize the model
- model_name = "bert-base-chinese"
- tokenizer = BertTokenizer.from_pretrained(model_name)
- model = BertModel.from_pretrained(model_name)
+ # Define the label and relation types
+ label_list = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC", "PAD"]
+ relation_list = ["no_relation", "per-org", "per-loc", "org-loc", "org-misc"]

- # Knowledge-graph data store
+ # Storage for the knowledge graph
  knowledge_graph = {
-     "entities": set(),
+     "entities": [],
      "relations": []
  }

- def update_knowledge_graph(entities, relations):
-     for e in entities:
-         knowledge_graph["entities"].add((e['text'], e['type']))
-     for r in relations:
-         knowledge_graph["relations"].append((r['head'], r['tail'], r['relation']))
+ def ner_predict(text):
+     inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)
+     with torch.no_grad():
+         outputs = ner_model(**inputs).logits
+     predictions = torch.argmax(outputs, dim=2)
+     tokens = ner_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+     predicted_labels = [label_list[label_id] for label_id in predictions[0].numpy()]
+     entities = []
+     current_entity = ""
+     current_label = ""
+     start = None
+     for idx, (token, label) in enumerate(zip(tokens, predicted_labels)):
+         if label.startswith("B-"):
+             if current_entity:
+                 entities.append((current_entity, current_label, start, idx))
+             current_entity = token.replace("##", "")
+             current_label = label[2:]
+             start = idx
+         elif label.startswith("I-") and current_label == label[2:]:
+             current_entity += token.replace("##", "")
+         else:
+             if current_entity:
+                 entities.append((current_entity, current_label, start, idx))
+             current_entity = ""
+             current_label = ""
+     if current_entity:
+         entities.append((current_entity, current_label, start, len(tokens)))
+     return entities
+
+ def re_predict(text, entities):
+     relations = []
+     for i in range(len(entities)):
+         for j in range(len(entities)):
+             if i == j:
+                 continue
+             head, tail = entities[i][0], entities[j][0]
+             input_text = f"{head} 和 {tail} 有什么关系?{text}"  # "What is the relation between {head} and {tail}?" + original text
+             inputs = re_tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
+             with torch.no_grad():
+                 outputs = re_model(**inputs).logits
+             prediction = torch.argmax(outputs, dim=1).item()
+             if relation_list[prediction] != "no_relation":
+                 relations.append((head, tail, relation_list[prediction]))
+     return relations
+
+ def analyze_text(text):
+     entities = ner_predict(text)
+     relations = re_predict(text, entities)
+     entity_list = [f"{ent[0]} ({ent[1]}) [{ent[2]}, {ent[3]}]" for ent in entities]
+     relation_list_text = [f"{rel[0]} --[{rel[2]}]-> {rel[1]}" for rel in relations]
+     # Replace the global knowledge graph with the latest results
+     knowledge_graph["entities"] = [(ent[0], ent[1]) for ent in entities]
+     knowledge_graph["relations"] = relations
+     return "\n".join(entity_list), "\n".join(relation_list_text)

  def visualize_kg():
      net = Network(height="600px", width="100%", notebook=True, directed=True)
@@ -52,159 +101,31 @@ def visualize_kg():
      }
      """)

-     temp_file = "kg.html"
-     net.save_graph(temp_file)
-
-     # ✅ Return the HTML content rather than a file path
-     html_content = Path(temp_file).read_text(encoding="utf-8")
-     return html_content
-
- # ----------- NER and RE extraction logic -----------------
- def ner(text):
-     pattern_name = r"[赵钱孙李周吴郑王冯陈褚卫蒋沈韩杨朱秦尤许何吕施张孔曹严华金魏陶姜][\u4e00-\u9fa5]{1,2}"  # a common Chinese surname followed by 1-2 characters
-     pattern_id = r"\b[a-zA-Z_][a-zA-Z0-9_]{4,}\b"
-     entities = []
-
-     for match in re.finditer(pattern_name, text):
-         entities.append({
-             "text": match.group(),
-             "start": match.start(),
-             "end": match.end(),
-             "type": "PersonName"
-         })
-
-     for match in re.finditer(pattern_id, text):
-         if not any(e["start"] == match.start() for e in entities):
-             entities.append({
-                 "text": match.group(),
-                 "start": match.start(),
-                 "end": match.end(),
-                 "type": "UserID"
-             })
-
-     return sorted(entities, key=lambda x: x["start"])
-
- def re_extract(entities, text):
-     relations = []
-     if len(entities) >= 2:
-         for i in range(len(entities) - 1):
-             head = entities[i]["text"]
-             tail = entities[i + 1]["text"]
-             context = text[entities[i]["end"]:entities[i + 1]["start"]]
-             if "推荐" in context or "找" in context:  # "recommend" / "look for"
-                 relation = "recommend"
-             elif "发送" in context or "发给" in context:  # "send" / "send to"
-                 relation = "send_to"
-             elif "提到" in context or "说" in context:  # "mention" / "say"
-                 relation = "mention"
-             else:
-                 relation = "knows"
-             relations.append({
-                 "head": head,
-                 "tail": tail,
-                 "relation": relation
-             })
-     return relations
-
- # ----------- Text-processing logic -----------------
- def process_text(text):
-     entities = ner(text)
-     relations = re_extract(entities, text)
-     update_knowledge_graph(entities, relations)
-     kg_html = visualize_kg()
-     entity_output = "\n".join([f"{e['text']} ({e['type']}) [{e['start']}, {e['end']}]" for e in entities])
-     relation_output = "\n".join([f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations])
-     return entity_output, relation_output, kg_html
-
- # ----------- File-upload handling -----------------
- def detect_encoding(file_path):
-     with open(file_path, 'rb') as f:
-         raw_data = f.read(4096)
-     result = chardet.detect(raw_data)
-     return result['encoding'] if result['encoding'] else 'utf-8'
-
- def process_file(file):
-     ext = os.path.splitext(file.name)[-1].lower()
-     full_text = ""
-     warning = ""
-
-     try:
-         encoding = detect_encoding(file.name)
-
-         if ext == ".txt":
-             with open(file.name, "r", encoding=encoding) as f:
-                 full_text = f.read()
-
-         elif ext == ".jsonl":
-             with open(file.name, "r", encoding=encoding) as f:
-                 lines = f.readlines()
-             texts = []
-             skipped_lines = []
-             for i, line in enumerate(lines, start=1):
-                 try:
-                     obj = json.loads(line)
-                     texts.append(obj.get("text", ""))
-                 except Exception:
-                     skipped_lines.append(i)
-             full_text = "\n".join(texts)
-             if skipped_lines:
-                 warning = f"⚠️ Skipped {len(skipped_lines)} invalid JSON line(s) (e.g. line {skipped_lines[0]})\n\n"
-
-         elif ext == ".json":
-             with open(file.name, "r", encoding=encoding) as f:
-                 data = json.load(f)
-             if isinstance(data, list):
-                 full_text = "\n".join([str(item.get("text", "")) for item in data])
-             elif isinstance(data, dict):
-                 full_text = data.get("text", "")
-             else:
-                 return "❌ Unparseable JSON file format", "", ""
-
-         elif ext == ".csv":
-             df = pd.read_csv(file.name, encoding=encoding)
-             if "text" in df.columns:
-                 full_text = "\n".join(df["text"].astype(str))
-             else:
-                 return "❌ No 'text' column found in the CSV", "", ""
-
-         else:
-             return f"❌ Unsupported file format: {ext}", "", ""
-
-     except Exception as e:
-         return f"❌ File read error: {str(e)}", "", ""
-
-     entity_out, relation_out, kg_html = process_text(full_text)
-     return warning + entity_out, relation_out, kg_html
-
- # ----------- Gradio UI -----------------
- with gr.Blocks() as demo:
-     gr.Markdown("""# 📱 WeChat Chat-Log Analysis System
-     Features: named-entity recognition (NER), relation extraction (RE), and knowledge-graph visualization""")
-
-     with gr.Tab("✍️ Enter text directly"):
-         with gr.Row():
-             input_text = gr.Textbox(label="Chat text", lines=8, placeholder="Paste a Chinese WeChat chat log here")
-             analyze_btn = gr.Button("Analyze text")
-         with gr.Row():
-             entity_output1 = gr.Textbox(label="Recognized entities")
-             relation_output1 = gr.Textbox(label="Extracted relations")
-         kg_html1 = gr.HTML(label="Knowledge-graph visualization")
-         analyze_btn.click(fn=process_text, inputs=[input_text], outputs=[entity_output1, relation_output1, kg_html1])
-
-     with gr.Tab("📁 Upload a file"):
-         file_input = gr.File(label="Chat-log file", file_types=[".txt", ".jsonl", ".json", ".csv"])
-         analyze_file_btn = gr.Button("Analyze file")
-         with gr.Row():
-             entity_output2 = gr.Textbox(label="Recognized entities")
-             relation_output2 = gr.Textbox(label="Extracted relations")
-         kg_html2 = gr.HTML(label="Knowledge-graph visualization")
-         analyze_file_btn.click(fn=process_file, inputs=[file_input], outputs=[entity_output2, relation_output2, kg_html2])
-
-     with gr.Tab("🗺️ Full knowledge graph"):
-         gr.Markdown("## Knowledge graph accumulated so far")
-         refresh_btn = gr.Button("Refresh graph")
-         full_kg = gr.HTML()
-         refresh_btn.click(fn=lambda: visualize_kg(), outputs=full_kg)
+     # Save the HTML where the /file= route below can serve it
+     file_path = "kg.html"
+     net.save_graph(file_path)
+
+     # Return an iframe that embeds the saved graph
+     return f'<iframe src="/file=kg.html" width="100%" height="600px" frameborder="0"></iframe>'
+
+ # Build the Gradio UI
+ with gr.Blocks(title="Wechat Ner Re") as demo:
+     gr.Markdown("## WeChat chat-log structuring system (NER + RE + knowledge graph)")
+     with gr.Row():
+         input_text = gr.Textbox(lines=5, label="Input text")
+         analyze_button = gr.Button("Analyze text")
+     with gr.Row():
+         ner_output = gr.Textbox(label="Recognized entities")
+         re_output = gr.Textbox(label="Extracted relations")
+     analyze_button.click(analyze_text, inputs=input_text, outputs=[ner_output, re_output])
+
+     # Knowledge-graph display
+     gr.Markdown("## Knowledge-graph visualization")
+     with gr.Row():
+         kg_button = gr.Button("Generate knowledge graph")
+         kg_html1 = gr.HTML(label="Knowledge-graph visualization", show_label=True)
+     kg_button.click(fn=visualize_kg, outputs=kg_html1)
+
+ # Launch the app
  if __name__ == "__main__":
      demo.launch()
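
Note: both ner_model and re_model are instantiated from the base bert-base-chinese checkpoint with freshly initialized classification heads (num_labels=10 and num_labels=5), so the NER and RE predictions are effectively random until fine-tuned weights are supplied. A minimal sketch of pointing the loader at a fine-tuned token-classification checkpoint instead, reading label names from the checkpoint config rather than the hard-coded label_list; the model id here is hypothetical:

import torch
from transformers import BertTokenizerFast, BertForTokenClassification

ckpt = "my-org/bert-chinese-wechat-ner"  # hypothetical fine-tuned checkpoint id
ner_tokenizer = BertTokenizerFast.from_pretrained(ckpt)
ner_model = BertForTokenClassification.from_pretrained(ckpt)
ner_model.eval()

def predict_token_labels(text):
    inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = ner_model(**inputs).logits
    ids = logits.argmax(dim=-1)[0].tolist()
    tokens = ner_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # Label names come from the checkpoint's config, not a hard-coded list
    return [(tok, ner_model.config.id2label[i]) for tok, i in zip(tokens, ids)]

Depending on the Gradio version, the saved kg.html may also need to be whitelisted for the /file= route to serve it, e.g. demo.launch(allowed_paths=["."]).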