Spaces:
Sleeping
Sleeping
Commit
·
07e97de
1
Parent(s):
ff6d08e
Add Gradio app for NER + RE
Browse files- app.py +244 -107
- requirements.txt +6 -6
app.py
CHANGED
@@ -1,133 +1,270 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from transformers import BertTokenizerFast, BertForTokenClassification, BertForSequenceClassification
|
3 |
import torch
|
4 |
-
from
|
5 |
-
|
|
|
6 |
import os
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
#
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
re_model = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=5)
|
13 |
-
re_tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")
|
14 |
-
|
15 |
-
# 定义标签和关系类型
|
16 |
-
label_list = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC", "PAD"]
|
17 |
-
relation_list = ["no_relation", "per-org", "per-loc", "org-loc", "org-misc"]
|
18 |
|
19 |
-
#
|
20 |
knowledge_graph = {
|
21 |
-
"entities":
|
22 |
"relations": []
|
23 |
}
|
24 |
|
25 |
-
def ner_predict(text):
|
26 |
-
inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)
|
27 |
-
with torch.no_grad():
|
28 |
-
outputs = ner_model(**inputs).logits
|
29 |
-
predictions = torch.argmax(outputs, dim=2)
|
30 |
-
tokens = ner_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
|
31 |
-
predicted_labels = [label_list[label_id] for label_id in predictions[0].numpy()]
|
32 |
-
|
33 |
-
entities = []
|
34 |
-
current_entity = ""
|
35 |
-
current_label = ""
|
36 |
-
start = None
|
37 |
-
special_tokens = {"[CLS]", "[SEP]", "[PAD]"}
|
38 |
-
|
39 |
-
for idx, (token, label) in enumerate(zip(tokens, predicted_labels)):
|
40 |
-
if token in special_tokens:
|
41 |
-
continue
|
42 |
-
if label.startswith("B-"):
|
43 |
-
if current_entity:
|
44 |
-
entities.append((current_entity, current_label, start, idx))
|
45 |
-
current_entity = token.replace("##", "")
|
46 |
-
current_label = label[2:]
|
47 |
-
start = idx
|
48 |
-
elif label.startswith("I-") and current_label == label[2:]:
|
49 |
-
current_entity += token.replace("##", "")
|
50 |
-
else:
|
51 |
-
if current_entity:
|
52 |
-
entities.append((current_entity, current_label, start, idx))
|
53 |
-
current_entity = ""
|
54 |
-
current_label = ""
|
55 |
-
if current_entity:
|
56 |
-
entities.append((current_entity, current_label, start, len(tokens)))
|
57 |
-
return entities
|
58 |
-
|
59 |
-
def re_predict(text, entities):
|
60 |
-
relations = []
|
61 |
-
for i in range(len(entities)):
|
62 |
-
for j in range(len(entities)):
|
63 |
-
if i == j:
|
64 |
-
continue
|
65 |
-
head, tail = entities[i][0], entities[j][0]
|
66 |
-
input_text = f"{head} 和 {tail} 有什么关系?{text}"
|
67 |
-
inputs = re_tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
|
68 |
-
with torch.no_grad():
|
69 |
-
outputs = re_model(**inputs).logits
|
70 |
-
prediction = torch.argmax(outputs, dim=1).item()
|
71 |
-
if relation_list[prediction] != "no_relation":
|
72 |
-
relations.append((head, tail, relation_list[prediction]))
|
73 |
-
return relations
|
74 |
|
75 |
-
def
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
80 |
|
81 |
-
# 更新知识图谱
|
82 |
-
knowledge_graph["entities"] = [(ent[0], ent[1]) for ent in entities]
|
83 |
-
knowledge_graph["relations"] = relations
|
84 |
-
|
85 |
-
return "\n".join(entity_list), "\n".join(relation_list_text)
|
86 |
|
87 |
def visualize_kg():
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
net = Network(height="600px", width="100%", notebook=False, directed=True)
|
92 |
node_map = {}
|
93 |
|
|
|
94 |
for idx, (name, type_) in enumerate(knowledge_graph["entities"]):
|
95 |
node_map[name] = idx
|
96 |
-
net.add_node(idx,
|
|
|
|
|
|
|
|
|
97 |
|
|
|
98 |
for head, tail, relation in knowledge_graph["relations"]:
|
99 |
if head in node_map and tail in node_map:
|
100 |
-
net.add_edge(node_map[head], node_map[tail],
|
|
|
|
|
|
|
101 |
|
|
|
102 |
net.set_options("""
|
103 |
{
|
104 |
-
"
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
}
|
107 |
""")
|
108 |
|
109 |
-
|
110 |
-
net.
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
if __name__ == "__main__":
|
133 |
-
demo.launch()
|
|
|
|
|
|
|
1 |
import torch
|
2 |
+
from transformers import BertTokenizer, BertModel
|
3 |
+
import gradio as gr
|
4 |
+
import re
|
5 |
import os
|
6 |
+
import json
|
7 |
+
import pandas as pd
|
8 |
+
import chardet
|
9 |
+
from pyvis.network import Network
|
10 |
+
import networkx as nx
|
11 |
|
12 |
+
# 初始化模型
|
13 |
+
model_name = "bert-base-chinese"
|
14 |
+
tokenizer = BertTokenizer.from_pretrained(model_name)
|
15 |
+
model = BertModel.from_pretrained(model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
+
# 知识图谱数据存储
|
18 |
knowledge_graph = {
|
19 |
+
"entities": set(),
|
20 |
"relations": []
|
21 |
}
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
+
def update_knowledge_graph(entities, relations):
|
25 |
+
"""更新知识图谱数据"""
|
26 |
+
for e in entities:
|
27 |
+
knowledge_graph["entities"].add((e['text'], e['type']))
|
28 |
+
for r in relations:
|
29 |
+
knowledge_graph["relations"].append((r['head'], r['tail'], r['relation']))
|
30 |
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
def visualize_kg():
|
33 |
+
"""生成交互式知识图谱可视化(返回HTML内容)"""
|
34 |
+
net = Network(height="600px", width="100%", notebook=True, directed=True)
|
|
|
|
|
35 |
node_map = {}
|
36 |
|
37 |
+
# 添加节点
|
38 |
for idx, (name, type_) in enumerate(knowledge_graph["entities"]):
|
39 |
node_map[name] = idx
|
40 |
+
net.add_node(idx,
|
41 |
+
label=name,
|
42 |
+
title=f"类型:{type_}",
|
43 |
+
group=type_,
|
44 |
+
font={"size": 20})
|
45 |
|
46 |
+
# 添加边
|
47 |
for head, tail, relation in knowledge_graph["relations"]:
|
48 |
if head in node_map and tail in node_map:
|
49 |
+
net.add_edge(node_map[head], node_map[tail],
|
50 |
+
label=relation,
|
51 |
+
arrows='to',
|
52 |
+
font={"size": 16})
|
53 |
|
54 |
+
# 配置可视化参数
|
55 |
net.set_options("""
|
56 |
{
|
57 |
+
"nodes": {
|
58 |
+
"scaling": {
|
59 |
+
"min": 20,
|
60 |
+
"max": 40
|
61 |
+
}
|
62 |
+
},
|
63 |
+
"physics": {
|
64 |
+
"stabilization": {
|
65 |
+
"iterations": 200
|
66 |
+
},
|
67 |
+
"barnesHut": {
|
68 |
+
"springLength": 200
|
69 |
+
}
|
70 |
+
},
|
71 |
+
"interaction": {
|
72 |
+
"hover": true,
|
73 |
+
"tooltipDelay": 200
|
74 |
+
}
|
75 |
}
|
76 |
""")
|
77 |
|
78 |
+
# 生成HTML内容并修复CDN引用
|
79 |
+
html = net.generate_html()
|
80 |
+
html = html.replace('//cdnjs.cloudflare.com', 'https://cdnjs.cloudflare.com')
|
81 |
+
html = html.replace('//unpkg.com', 'https://unpkg.com')
|
82 |
+
return html
|
83 |
+
|
84 |
+
|
85 |
+
# ----------- NER 和 RE 抽取逻辑 -----------------
|
86 |
+
def ner(text):
|
87 |
+
pattern_name = r"[赵钱孙李周吴郑王冯陈褚卫蒋沈韩杨朱秦尤许何吕施张孔曹严华金魏陶姜][\u4e00-\u9fa5]{1,2}"
|
88 |
+
pattern_id = r"\b[a-zA-Z_][a-zA-Z0-9_]{4,}\b"
|
89 |
+
entities = []
|
90 |
+
|
91 |
+
# 中文姓名识别
|
92 |
+
for match in re.finditer(pattern_name, text):
|
93 |
+
entities.append({
|
94 |
+
"text": match.group(),
|
95 |
+
"start": match.start(),
|
96 |
+
"end": match.end(),
|
97 |
+
"type": "PersonName"
|
98 |
+
})
|
99 |
+
|
100 |
+
# 用户ID识别
|
101 |
+
for match in re.finditer(pattern_id, text):
|
102 |
+
if not any(e["start"] == match.start() for e in entities):
|
103 |
+
entities.append({
|
104 |
+
"text": match.group(),
|
105 |
+
"start": match.start(),
|
106 |
+
"end": match.end(),
|
107 |
+
"type": "UserID"
|
108 |
+
})
|
109 |
+
|
110 |
+
return sorted(entities, key=lambda x: x["start"])
|
111 |
+
|
112 |
+
|
113 |
+
def re_extract(entities, text):
|
114 |
+
relations = []
|
115 |
+
if len(entities) >= 2:
|
116 |
+
for i in range(len(entities) - 1):
|
117 |
+
head = entities[i]["text"]
|
118 |
+
tail = entities[i + 1]["text"]
|
119 |
+
context = text[entities[i]["end"]:entities[i + 1]["start"]]
|
120 |
+
|
121 |
+
# 关系判断逻辑
|
122 |
+
if "推荐" in context or "找" in context:
|
123 |
+
relation = "recommend"
|
124 |
+
elif "发送" in context or "发给" in context:
|
125 |
+
relation = "send_to"
|
126 |
+
elif "提到" in context or "说" in context:
|
127 |
+
relation = "mention"
|
128 |
+
else:
|
129 |
+
relation = "knows"
|
130 |
+
|
131 |
+
relations.append({
|
132 |
+
"head": head,
|
133 |
+
"tail": tail,
|
134 |
+
"relation": relation
|
135 |
+
})
|
136 |
+
return relations
|
137 |
+
|
138 |
+
|
139 |
+
# ----------- 文本处理逻辑 -----------------
|
140 |
+
def process_text(text):
|
141 |
+
# 实体识别
|
142 |
+
entities = ner(text)
|
143 |
+
|
144 |
+
# 关系抽取
|
145 |
+
relations = re_extract(entities, text)
|
146 |
+
|
147 |
+
# 更新知识图谱
|
148 |
+
update_knowledge_graph(entities, relations)
|
149 |
+
|
150 |
+
# 生成输出
|
151 |
+
entity_output = "\n".join([f"{e['text']} ({e['type']}) [{e['start']}, {e['end']}]" for e in entities])
|
152 |
+
relation_output = "\n".join([f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations])
|
153 |
+
kg_html = visualize_kg()
|
154 |
+
|
155 |
+
return entity_output, relation_output, gr.HTML(kg_html)
|
156 |
+
|
157 |
+
|
158 |
+
# ----------- 文件处理逻辑 -----------------
|
159 |
+
def detect_encoding(file_path):
|
160 |
+
with open(file_path, 'rb') as f:
|
161 |
+
raw_data = f.read(4096)
|
162 |
+
result = chardet.detect(raw_data)
|
163 |
+
return result['encoding'] if result['encoding'] else 'utf-8'
|
164 |
+
|
165 |
+
|
166 |
+
def process_file(file):
|
167 |
+
ext = os.path.splitext(file.name)[-1].lower()
|
168 |
+
full_text = ""
|
169 |
+
warning = ""
|
170 |
+
|
171 |
+
try:
|
172 |
+
encoding = detect_encoding(file.name)
|
173 |
+
|
174 |
+
# 处理不同文件格式
|
175 |
+
if ext == ".txt":
|
176 |
+
with open(file.name, "r", encoding=encoding) as f:
|
177 |
+
full_text = f.read()
|
178 |
+
|
179 |
+
elif ext == ".jsonl":
|
180 |
+
with open(file.name, "r", encoding=encoding) as f:
|
181 |
+
lines = f.readlines()
|
182 |
+
texts = []
|
183 |
+
skipped_lines = []
|
184 |
+
for i, line in enumerate(lines, start=1):
|
185 |
+
try:
|
186 |
+
obj = json.loads(line)
|
187 |
+
texts.append(obj.get("text", ""))
|
188 |
+
except Exception:
|
189 |
+
skipped_lines.append(i)
|
190 |
+
full_text = "\n".join(texts)
|
191 |
+
if skipped_lines:
|
192 |
+
warning = f"⚠️ 跳过 {len(skipped_lines)} 行无效 JSON(如第 {skipped_lines[0]} 行)\n\n"
|
193 |
+
|
194 |
+
elif ext == ".json":
|
195 |
+
with open(file.name, "r", encoding=encoding) as f:
|
196 |
+
data = json.load(f)
|
197 |
+
if isinstance(data, list):
|
198 |
+
full_text = "\n".join([str(item.get("text", "")) for item in data])
|
199 |
+
elif isinstance(data, dict):
|
200 |
+
full_text = data.get("text", "")
|
201 |
+
else:
|
202 |
+
return "❌ JSON 文件格式无法解析", "", gr.HTML()
|
203 |
+
|
204 |
+
elif ext == ".csv":
|
205 |
+
df = pd.read_csv(file.name, encoding=encoding)
|
206 |
+
if "text" in df.columns:
|
207 |
+
full_text = "\n".join(df["text"].astype(str))
|
208 |
+
else:
|
209 |
+
return "❌ CSV 中未找到 'text' 列", "", gr.HTML()
|
210 |
+
|
211 |
+
else:
|
212 |
+
return f"❌ 不支持的文件格式:{ext}", "", gr.HTML()
|
213 |
+
|
214 |
+
except Exception as e:
|
215 |
+
return f"❌ 文件读取错误:{str(e)}", "", gr.HTML()
|
216 |
+
|
217 |
+
# 处理文本并生成结果
|
218 |
+
entity_out, relation_out, kg_html = process_text(full_text)
|
219 |
+
return warning + entity_out, relation_out, kg_html
|
220 |
+
|
221 |
+
|
222 |
+
# ----------- Gradio 界面 -----------------
|
223 |
+
with gr.Blocks(
|
224 |
+
css=".kg-container {border: 1px solid #e0e0e0; border-radius: 10px; padding: 20px; margin-top: 20px;}") as demo:
|
225 |
+
gr.Markdown("""# 📱 微信聊天记录智能分析系统
|
226 |
+
**功能**:实体识别(NER) → 关系抽取(RE) → 动态知识图谱""")
|
227 |
+
|
228 |
+
with gr.Tab("✍️ 直接输入文本"):
|
229 |
+
gr.Markdown("## 直接输入聊天内容进行分析")
|
230 |
+
input_text = gr.Textbox(label="输入内容", lines=8,
|
231 |
+
placeholder="示例:\n张三:推荐李四加入项目组\n王五:把需求文档发送给赵六")
|
232 |
+
analyze_btn = gr.Button("开始分析", variant="primary")
|
233 |
+
|
234 |
+
with gr.Row():
|
235 |
+
entity_output1 = gr.Textbox(label="识别出的实体", interactive=False)
|
236 |
+
relation_output1 = gr.Textbox(label="抽取的关系", interactive=False)
|
237 |
+
kg_html1 = gr.HTML(label="知识图谱展示", elem_classes="kg-container")
|
238 |
+
|
239 |
+
analyze_btn.click(
|
240 |
+
fn=process_text,
|
241 |
+
inputs=[input_text],
|
242 |
+
outputs=[entity_output1, relation_output1, kg_html1]
|
243 |
+
)
|
244 |
+
|
245 |
+
with gr.Tab("📁 上传文件"):
|
246 |
+
gr.Markdown("## 上传聊天记录文件(支持多种格式)")
|
247 |
+
file_input = gr.File(label="选择文件", file_types=[".txt", ".jsonl", ".json", ".csv"])
|
248 |
+
analyze_file_btn = gr.Button("分析文件", variant="primary")
|
249 |
+
|
250 |
+
with gr.Row():
|
251 |
+
entity_output2 = gr.Textbox(label="识别出的实体", interactive=False)
|
252 |
+
relation_output2 = gr.Textbox(label="抽取的关系", interactive=False)
|
253 |
+
kg_html2 = gr.HTML(label="知识图谱展示", elem_classes="kg-container")
|
254 |
+
|
255 |
+
analyze_file_btn.click(
|
256 |
+
fn=process_file,
|
257 |
+
inputs=[file_input],
|
258 |
+
outputs=[entity_output2, relation_output2, kg_html2]
|
259 |
+
)
|
260 |
+
|
261 |
+
with gr.Tab("🗺️ 完整知识图谱"):
|
262 |
+
gr.Markdown("## 动态更新的完整知识图谱")
|
263 |
+
with gr.Row():
|
264 |
+
gr.Markdown("点击按钮刷新查看累计分析结果")
|
265 |
+
refresh_btn = gr.Button("立即刷新", variant="secondary")
|
266 |
+
full_kg = gr.HTML(elem_classes="kg-container")
|
267 |
+
refresh_btn.click(fn=lambda: visualize_kg(), outputs=full_kg)
|
268 |
+
|
269 |
if __name__ == "__main__":
|
270 |
+
demo.launch()
|
requirements.txt
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
gradio==
|
4 |
-
pandas==2.
|
5 |
-
chardet==5.
|
6 |
-
networkx==3.2.1
|
7 |
pyvis==0.3.2
|
|
|
|
1 |
+
transformers==4.30.2
|
2 |
+
torch==2.0.1
|
3 |
+
gradio==3.39.0
|
4 |
+
pandas==2.0.3
|
5 |
+
chardet==5.1.0
|
|
|
6 |
pyvis==0.3.2
|
7 |
+
networkx==3.1
|