chen666-666 committed
Commit 07e97de · 1 Parent(s): ff6d08e

Add Gradio app for NER + RE

Files changed (2):
  1. app.py +244 -107
  2. requirements.txt +6 -6
app.py CHANGED
@@ -1,133 +1,270 @@
- import gradio as gr
- from transformers import BertTokenizerFast, BertForTokenClassification, BertForSequenceClassification
  import torch
- from pathlib import Path
- from pyvis.network import Network
  import os

- # Load the models and tokenizers
- ner_tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")
- ner_model = BertForTokenClassification.from_pretrained("bert-base-chinese", num_labels=10)
-
- re_model = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=5)
- re_tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")
-
- # Define the label and relation types
- label_list = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC", "PAD"]
- relation_list = ["no_relation", "per-org", "per-loc", "org-loc", "org-misc"]

- # Storage for the knowledge graph
  knowledge_graph = {
-     "entities": [],
      "relations": []
  }

- def ner_predict(text):
-     inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)
-     with torch.no_grad():
-         outputs = ner_model(**inputs).logits
-     predictions = torch.argmax(outputs, dim=2)
-     tokens = ner_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
-     predicted_labels = [label_list[label_id] for label_id in predictions[0].numpy()]
-
-     entities = []
-     current_entity = ""
-     current_label = ""
-     start = None
-     special_tokens = {"[CLS]", "[SEP]", "[PAD]"}
-
-     for idx, (token, label) in enumerate(zip(tokens, predicted_labels)):
-         if token in special_tokens:
-             continue
-         if label.startswith("B-"):
-             if current_entity:
-                 entities.append((current_entity, current_label, start, idx))
-             current_entity = token.replace("##", "")
-             current_label = label[2:]
-             start = idx
-         elif label.startswith("I-") and current_label == label[2:]:
-             current_entity += token.replace("##", "")
-         else:
-             if current_entity:
-                 entities.append((current_entity, current_label, start, idx))
-             current_entity = ""
-             current_label = ""
-     if current_entity:
-         entities.append((current_entity, current_label, start, len(tokens)))
-     return entities
-
- def re_predict(text, entities):
-     relations = []
-     for i in range(len(entities)):
-         for j in range(len(entities)):
-             if i == j:
-                 continue
-             head, tail = entities[i][0], entities[j][0]
-             input_text = f"{head} 和 {tail} 有什么关系?{text}"
-             inputs = re_tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
-             with torch.no_grad():
-                 outputs = re_model(**inputs).logits
-             prediction = torch.argmax(outputs, dim=1).item()
-             if relation_list[prediction] != "no_relation":
-                 relations.append((head, tail, relation_list[prediction]))
-     return relations

- def analyze_text(text):
-     entities = ner_predict(text)
-     relations = re_predict(text, entities)
-     entity_list = [f"{ent[0]} ({ent[1]}) [{ent[2]}, {ent[3]}]" for ent in entities]
-     relation_list_text = [f"{rel[0]} --[{rel[2]}]-> {rel[1]}" for rel in relations]

-     # Update the knowledge graph
-     knowledge_graph["entities"] = [(ent[0], ent[1]) for ent in entities]
-     knowledge_graph["relations"] = relations
-
-     return "\n".join(entity_list), "\n".join(relation_list_text)

  def visualize_kg():
-     if not knowledge_graph["entities"]:
-         return "<p style='color:red;'>知识图谱为空,请先进行文本分析。</p>"
-
-     net = Network(height="600px", width="100%", notebook=False, directed=True)
      node_map = {}

      for idx, (name, type_) in enumerate(knowledge_graph["entities"]):
          node_map[name] = idx
-         net.add_node(idx, label=name, title=type_, group=type_)

      for head, tail, relation in knowledge_graph["relations"]:
          if head in node_map and tail in node_map:
-             net.add_edge(node_map[head], node_map[tail], label=relation, arrows="to")

      net.set_options("""
      {
-         "physics": { "stabilization": { "iterations": 100 }},
-         "interaction": { "hover": true }
      }
      """)

-     file_path = Path("kg.html")
-     net.save_graph(str(file_path))
-
-     return f'<iframe src="file={file_path.name}" width="100%" height="600px" frameborder="0"></iframe>'
-
- # Gradio interface
- with gr.Blocks(title="Wechat Ner Re") as demo:
-     gr.Markdown("## 微信聊天记录结构化系统(NER + RE + 知识图谱)")
-     with gr.Row():
-         input_text = gr.Textbox(lines=5, label="请输入文本")
-         analyze_button = gr.Button("分析文本")
-     with gr.Row():
-         ner_output = gr.Textbox(label="识别出的实体")
-         re_output = gr.Textbox(label="抽取的关系")
-     analyze_button.click(analyze_text, inputs=input_text, outputs=[ner_output, re_output])
-
-     gr.Markdown("## 知识图谱可视化")
-     with gr.Row():
-         kg_button = gr.Button("生成知识图谱")
-         kg_html1 = gr.HTML(label="知识图谱可视化", show_label=True)
-     kg_button.click(fn=visualize_kg, outputs=kg_html1)
-
- # Launch the app

  if __name__ == "__main__":
-     demo.launch()

  import torch
+ from transformers import BertTokenizer, BertModel
+ import gradio as gr
+ import re
  import os
+ import json
+ import pandas as pd
+ import chardet
+ from pyvis.network import Network
+ import networkx as nx

+ # Initialize the model
+ model_name = "bert-base-chinese"
+ tokenizer = BertTokenizer.from_pretrained(model_name)
+ model = BertModel.from_pretrained(model_name)

+ # Knowledge-graph data store
  knowledge_graph = {
+     "entities": set(),
      "relations": []
  }

+ def update_knowledge_graph(entities, relations):
+     """Update the knowledge-graph data"""
+     for e in entities:
+         knowledge_graph["entities"].add((e['text'], e['type']))
+     for r in relations:
+         knowledge_graph["relations"].append((r['head'], r['tail'], r['relation']))

  def visualize_kg():
+     """Generate an interactive knowledge-graph visualization (returns HTML content)"""
+     net = Network(height="600px", width="100%", notebook=True, directed=True)
      node_map = {}

+     # Add nodes
      for idx, (name, type_) in enumerate(knowledge_graph["entities"]):
          node_map[name] = idx
+         net.add_node(idx,
+                      label=name,
+                      title=f"类型:{type_}",
+                      group=type_,
+                      font={"size": 20})

+     # Add edges
      for head, tail, relation in knowledge_graph["relations"]:
          if head in node_map and tail in node_map:
+             net.add_edge(node_map[head], node_map[tail],
+                          label=relation,
+                          arrows='to',
+                          font={"size": 16})

+     # Configure the visualization
      net.set_options("""
      {
+         "nodes": {
+             "scaling": {
+                 "min": 20,
+                 "max": 40
+             }
+         },
+         "physics": {
+             "stabilization": {
+                 "iterations": 200
+             },
+             "barnesHut": {
+                 "springLength": 200
+             }
+         },
+         "interaction": {
+             "hover": true,
+             "tooltipDelay": 200
+         }
      }
      """)

+     # Generate the HTML and force https on the protocol-relative CDN references
+     html = net.generate_html()
+     html = html.replace('//cdnjs.cloudflare.com', 'https://cdnjs.cloudflare.com')
+     html = html.replace('//unpkg.com', 'https://unpkg.com')
+     return html
+
+
+ # ----------- NER / RE extraction logic -----------------
+ def ner(text):
+     pattern_name = r"[赵钱孙李周吴郑王冯陈褚卫蒋沈韩杨朱秦尤许何吕施张孔曹严华金魏陶姜][\u4e00-\u9fa5]{1,2}"
+     pattern_id = r"\b[a-zA-Z_][a-zA-Z0-9_]{4,}\b"
+     entities = []
+
+     # Recognize Chinese personal names
+     for match in re.finditer(pattern_name, text):
+         entities.append({
+             "text": match.group(),
+             "start": match.start(),
+             "end": match.end(),
+             "type": "PersonName"
+         })
+
+     # Recognize user IDs
+     for match in re.finditer(pattern_id, text):
+         if not any(e["start"] == match.start() for e in entities):
+             entities.append({
+                 "text": match.group(),
+                 "start": match.start(),
+                 "end": match.end(),
+                 "type": "UserID"
+             })
+
+     return sorted(entities, key=lambda x: x["start"])
+
+
+ def re_extract(entities, text):
+     relations = []
+     if len(entities) >= 2:
+         for i in range(len(entities) - 1):
+             head = entities[i]["text"]
+             tail = entities[i + 1]["text"]
+             context = text[entities[i]["end"]:entities[i + 1]["start"]]
+
+             # Relation heuristics
+             if "推荐" in context or "找" in context:
+                 relation = "recommend"
+             elif "发送" in context or "发给" in context:
+                 relation = "send_to"
+             elif "提到" in context or "说" in context:
+                 relation = "mention"
+             else:
+                 relation = "knows"
+
+             relations.append({
+                 "head": head,
+                 "tail": tail,
+                 "relation": relation
+             })
+     return relations
+
+
+ # ----------- Text processing -----------------
+ def process_text(text):
+     # Entity recognition
+     entities = ner(text)
+
+     # Relation extraction
+     relations = re_extract(entities, text)
+
+     # Update the knowledge graph
+     update_knowledge_graph(entities, relations)
+
+     # Build the outputs
+     entity_output = "\n".join([f"{e['text']} ({e['type']}) [{e['start']}, {e['end']}]" for e in entities])
+     relation_output = "\n".join([f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations])
+     kg_html = visualize_kg()
+
+     return entity_output, relation_output, gr.HTML(kg_html)
+
+
+ # ----------- File handling -----------------
+ def detect_encoding(file_path):
+     with open(file_path, 'rb') as f:
+         raw_data = f.read(4096)
+     result = chardet.detect(raw_data)
+     return result['encoding'] if result['encoding'] else 'utf-8'
+
+
+ def process_file(file):
+     ext = os.path.splitext(file.name)[-1].lower()
+     full_text = ""
+     warning = ""
+
+     try:
+         encoding = detect_encoding(file.name)
+
+         # Handle the different file formats
+         if ext == ".txt":
+             with open(file.name, "r", encoding=encoding) as f:
+                 full_text = f.read()
+
+         elif ext == ".jsonl":
+             with open(file.name, "r", encoding=encoding) as f:
+                 lines = f.readlines()
+             texts = []
+             skipped_lines = []
+             for i, line in enumerate(lines, start=1):
+                 try:
+                     obj = json.loads(line)
+                     texts.append(obj.get("text", ""))
+                 except Exception:
+                     skipped_lines.append(i)
+             full_text = "\n".join(texts)
+             if skipped_lines:
+                 warning = f"⚠️ 跳过 {len(skipped_lines)} 行无效 JSON(如第 {skipped_lines[0]} 行)\n\n"
+
+         elif ext == ".json":
+             with open(file.name, "r", encoding=encoding) as f:
+                 data = json.load(f)
+             if isinstance(data, list):
+                 full_text = "\n".join([str(item.get("text", "")) for item in data])
+             elif isinstance(data, dict):
+                 full_text = data.get("text", "")
+             else:
+                 return "❌ JSON 文件格式无法解析", "", gr.HTML()
+
+         elif ext == ".csv":
+             df = pd.read_csv(file.name, encoding=encoding)
+             if "text" in df.columns:
+                 full_text = "\n".join(df["text"].astype(str))
+             else:
+                 return "❌ CSV 中未找到 'text' 列", "", gr.HTML()
+
+         else:
+             return f"❌ 不支持的文件格式:{ext}", "", gr.HTML()
+
+     except Exception as e:
+         return f"❌ 文件读取错误:{str(e)}", "", gr.HTML()
+
+     # Run the pipeline on the extracted text
+     entity_out, relation_out, kg_html = process_text(full_text)
+     return warning + entity_out, relation_out, kg_html
+
+
+ # ----------- Gradio UI -----------------
+ with gr.Blocks(
+         css=".kg-container {border: 1px solid #e0e0e0; border-radius: 10px; padding: 20px; margin-top: 20px;}") as demo:
+     gr.Markdown("""# 📱 微信聊天记录智能分析系统
+     **功能**:实体识别(NER) → 关系抽取(RE) → 动态知识图谱""")
+
+     with gr.Tab("✍️ 直接输入文本"):
+         gr.Markdown("## 直接输入聊天内容进行分析")
+         input_text = gr.Textbox(label="输入内容", lines=8,
+                                 placeholder="示例:\n张三:推荐李四加入项目组\n王五:把需求文档发送给赵六")
+         analyze_btn = gr.Button("开始分析", variant="primary")
+
+         with gr.Row():
+             entity_output1 = gr.Textbox(label="识别出的实体", interactive=False)
+             relation_output1 = gr.Textbox(label="抽取的关系", interactive=False)
+         kg_html1 = gr.HTML(label="知识图谱展示", elem_classes="kg-container")
+
+         analyze_btn.click(
+             fn=process_text,
+             inputs=[input_text],
+             outputs=[entity_output1, relation_output1, kg_html1]
+         )
+
+     with gr.Tab("📁 上传文件"):
+         gr.Markdown("## 上传聊天记录文件(支持多种格式)")
+         file_input = gr.File(label="选择文件", file_types=[".txt", ".jsonl", ".json", ".csv"])
+         analyze_file_btn = gr.Button("分析文件", variant="primary")
+
+         with gr.Row():
+             entity_output2 = gr.Textbox(label="识别出的实体", interactive=False)
+             relation_output2 = gr.Textbox(label="抽取的关系", interactive=False)
+         kg_html2 = gr.HTML(label="知识图谱展示", elem_classes="kg-container")
+
+         analyze_file_btn.click(
+             fn=process_file,
+             inputs=[file_input],
+             outputs=[entity_output2, relation_output2, kg_html2]
+         )
+
+     with gr.Tab("🗺️ 完整知识图谱"):
+         gr.Markdown("## 动态更新的完整知识图谱")
+         with gr.Row():
+             gr.Markdown("点击按钮刷新查看累计分析结果")
+             refresh_btn = gr.Button("立即刷新", variant="secondary")
+         full_kg = gr.HTML(elem_classes="kg-container")
+         refresh_btn.click(fn=lambda: visualize_kg(), outputs=full_kg)

  if __name__ == "__main__":
+     demo.launch()
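
A quick way to exercise the new rule-based extraction outside the UI — a minimal sketch, not part of the commit, assuming app.py is importable from the working directory (note the import also downloads bert-base-chinese, since the model is loaded at module level):

# Hypothetical smoke test for the committed rule-based pipeline.
from app import ner, re_extract

text = "张三:推荐李四"                 # "Zhang San: recommend Li Si"
entities = ner(text)                    # surname-anchored name regex, then ID regex
relations = re_extract(entities, text)  # keyword heuristics on the span between entities

for e in entities:
    print(e["text"], e["type"], e["start"], e["end"])
for r in relations:
    print(f"{r['head']} --[{r['relation']}]-> {r['tail']}")

# Given the regexes above, this should yield 张三 and 李四 as PersonName entities
# and 张三 --[recommend]-> 李四, since the text between them contains "推荐".

One behavior worth knowing: the name pattern's greedy `{1,2}` quantifier can swallow a trailing character (e.g. "李四加入" matches as "李四加"), which is inherent to this heuristic rather than a bug in the wiring.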
requirements.txt CHANGED
@@ -1,7 +1,7 @@
- torch==2.1.2
- transformers==4.36.2
- gradio==4.19.2
- pandas==2.2.1
- chardet==5.2.0
- networkx==3.2.1
  pyvis==0.3.2

+ transformers==4.30.2
+ torch==2.0.1
+ gradio==3.39.0
+ pandas==2.0.3
+ chardet==5.1.0
  pyvis==0.3.2
+ networkx==3.1
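
Since the commit downgrades to the Gradio 3.x API surface, version drift is the most likely source of runtime breakage. A minimal pre-launch sanity check — a sketch, not part of the commit — comparing the installed packages against the pins above:

# Hypothetical helper: verify installed versions match requirements.txt
# before running `python app.py`.
from importlib.metadata import version, PackageNotFoundError

PINS = {
    "transformers": "4.30.2",
    "torch": "2.0.1",
    "gradio": "3.39.0",
    "pandas": "2.0.3",
    "chardet": "5.1.0",
    "pyvis": "0.3.2",
    "networkx": "3.1",
}

for pkg, pinned in PINS.items():
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        print(f"{pkg}: MISSING (pinned {pinned})")
        continue
    status = "ok" if installed == pinned else f"MISMATCH (pinned {pinned})"
    print(f"{pkg}=={installed}  {status}")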