Spaces:
Sleeping
Sleeping
Commit
·
c85af5a
1
Parent(s):
20683c1
Add Gradio app for NER + RE
Browse files
app.py
CHANGED
@@ -24,9 +24,11 @@ knowledge_graph = {
|
|
24 |
def update_knowledge_graph(entities, relations):
|
25 |
"""更新知识图谱数据"""
|
26 |
for e in entities:
|
27 |
-
|
|
|
28 |
for r in relations:
|
29 |
-
|
|
|
30 |
|
31 |
|
32 |
def visualize_kg():
|
@@ -34,17 +36,21 @@ def visualize_kg():
|
|
34 |
net = Network(height="600px", width="100%", notebook=True, directed=True)
|
35 |
node_map = {}
|
36 |
|
37 |
-
#
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
45 |
|
46 |
# 添加边
|
47 |
-
seen_edges = set()
|
48 |
for head, tail, relation in knowledge_graph["relations"]:
|
49 |
if head in node_map and tail in node_map:
|
50 |
edge_key = f"{head}-{tail}-{relation}"
|
@@ -55,7 +61,6 @@ def visualize_kg():
|
|
55 |
font={'size': 14})
|
56 |
seen_edges.add(edge_key)
|
57 |
|
58 |
-
# 优化布局配置
|
59 |
net.set_options("""
|
60 |
{
|
61 |
"nodes": {
|
@@ -82,24 +87,19 @@ def visualize_kg():
|
|
82 |
}
|
83 |
""")
|
84 |
|
85 |
-
# 生成HTML并修复资源引用
|
86 |
html = net.generate_html()
|
87 |
html = html.replace('//cdnjs.cloudflare.com', 'https://cdnjs.cloudflare.com')
|
88 |
html = html.replace('//unpkg.com', 'https://unpkg.com')
|
89 |
return f'<div class="kg-graph">{html}</div>'
|
90 |
|
91 |
|
92 |
-
# ----------- 增强的NER逻辑 -----------------
|
93 |
def ner(text):
|
94 |
-
# 优化中文姓名识别(排除常见动词后缀)
|
95 |
name_pattern = r"([赵钱孙李周吴郑王冯陈褚卫蒋沈韩杨朱秦尤许何吕施张孔曹严华金魏陶姜][\u4e00-\u9fa5]{1,2})(?![的地得啦啊呀])"
|
96 |
-
# 增强ID识别(支持带下划线和数字)
|
97 |
id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![\u4e00-\u9fa5])"
|
98 |
|
99 |
entities = []
|
100 |
occupied = set()
|
101 |
|
102 |
-
# 识别中文姓名
|
103 |
for match in re.finditer(name_pattern, text):
|
104 |
start, end = match.start(1), match.end(1)
|
105 |
if not any(s <= start < e for (s, e) in occupied):
|
@@ -111,7 +111,6 @@ def ner(text):
|
|
111 |
})
|
112 |
occupied.update(range(start, end))
|
113 |
|
114 |
-
# 识别用户ID
|
115 |
for match in re.finditer(id_pattern, text):
|
116 |
start, end = match.start(1), match.end(1)
|
117 |
if not any(s <= start < e for (s, e) in occupied):
|
@@ -126,7 +125,6 @@ def ner(text):
|
|
126 |
return sorted(entities, key=lambda x: x["start"])
|
127 |
|
128 |
|
129 |
-
# ----------- 改进的关系抽取逻辑 -----------------
|
130 |
def re_extract(entities, text):
|
131 |
relations = []
|
132 |
triggers = {
|
@@ -136,17 +134,14 @@ def re_extract(entities, text):
|
|
136 |
}
|
137 |
|
138 |
for i in range(len(entities)):
|
139 |
-
# 检查前后两个窗口范围
|
140 |
for j in range(max(0, i - 2), min(len(entities), i + 3)):
|
141 |
if i == j:
|
142 |
continue
|
143 |
|
144 |
-
# 获取上下文内容
|
145 |
ctx_start = entities[i]["end"]
|
146 |
ctx_end = entities[j]["start"]
|
147 |
context = text[ctx_start:ctx_end].strip()
|
148 |
|
149 |
-
# 处理@提及的情况
|
150 |
if text.startswith('@', entities[i]["start"] - 1):
|
151 |
relations.append({
|
152 |
"head": entities[i]["text"],
|
@@ -155,7 +150,6 @@ def re_extract(entities, text):
|
|
155 |
})
|
156 |
continue
|
157 |
|
158 |
-
# 关系判断
|
159 |
relation_type = "knows"
|
160 |
for rel_type, keywords in triggers.items():
|
161 |
if any(kw in context for kw in keywords):
|
@@ -168,7 +162,6 @@ def re_extract(entities, text):
|
|
168 |
"relation": relation_type
|
169 |
})
|
170 |
|
171 |
-
# 去重
|
172 |
unique_relations = []
|
173 |
seen = set()
|
174 |
for rel in relations:
|
@@ -180,7 +173,6 @@ def re_extract(entities, text):
|
|
180 |
return unique_relations
|
181 |
|
182 |
|
183 |
-
# ----------- 文本处理流程 -----------------
|
184 |
def process_text(text):
|
185 |
try:
|
186 |
entities = ner(text)
|
@@ -197,19 +189,12 @@ def process_text(text):
|
|
197 |
)
|
198 |
kg_html = visualize_kg()
|
199 |
|
200 |
-
# 调试日志
|
201 |
-
print(f"Entities: {entities}")
|
202 |
-
print(f"Relations: {relations}")
|
203 |
-
with open("debug_kg.html", "w", encoding="utf-8") as f:
|
204 |
-
f.write(kg_html)
|
205 |
-
|
206 |
return entity_output, relation_output, gr.HTML(kg_html)
|
207 |
|
208 |
except Exception as e:
|
209 |
return f"处理出错: {str(e)}", "", gr.HTML()
|
210 |
|
211 |
|
212 |
-
# ----------- 文件处理模块 -----------------
|
213 |
def detect_encoding(file_path):
|
214 |
with open(file_path, 'rb') as f:
|
215 |
return chardet.detect(f.read(4096))['encoding'] or 'utf-8'
|
@@ -218,7 +203,6 @@ def detect_encoding(file_path):
|
|
218 |
def process_file(file):
|
219 |
ext = os.path.splitext(file.name)[-1].lower()
|
220 |
full_text = ""
|
221 |
-
warning = ""
|
222 |
|
223 |
try:
|
224 |
encoding = detect_encoding(file.name)
|
@@ -262,7 +246,7 @@ def process_file(file):
|
|
262 |
return f"❌ 文件处理错误: {str(e)}", "", gr.HTML()
|
263 |
|
264 |
|
265 |
-
#
|
266 |
css = """
|
267 |
.kg-container {
|
268 |
border: 1px solid #e0e0e0;
|
@@ -285,13 +269,12 @@ css = """
|
|
285 |
"""
|
286 |
|
287 |
with gr.Blocks(css=css) as demo:
|
288 |
-
gr.Markdown("
|
289 |
-
**功能**: 实体识别 → 关系抽取 → 动态知识图谱""")
|
290 |
|
291 |
with gr.Tab("✍️ 文本分析"):
|
292 |
gr.Markdown("### 直接输入聊天内容")
|
293 |
input_text = gr.Textbox(label="输入内容", lines=8,
|
294 |
-
placeholder="
|
295 |
analyze_btn = gr.Button("开始分析", variant="primary")
|
296 |
|
297 |
with gr.Row():
|
@@ -315,8 +298,7 @@ with gr.Blocks(css=css) as demo:
|
|
315 |
|
316 |
with gr.Tab("📁 文件分析"):
|
317 |
gr.Markdown("### 上传聊天记录文件")
|
318 |
-
file_input = gr.File(label="选择文件",
|
319 |
-
file_types=[".txt", ".json", ".jsonl", ".csv"])
|
320 |
file_btn = gr.Button("分析文件", variant="primary")
|
321 |
|
322 |
with gr.Row():
|
@@ -343,4 +325,4 @@ with gr.Blocks(css=css) as demo:
|
|
343 |
)
|
344 |
|
345 |
if __name__ == "__main__":
|
346 |
-
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
24 |
def update_knowledge_graph(entities, relations):
    """Merge extracted entities and relations into the module-level ``knowledge_graph``.

    Args:
        entities: Iterable of candidate entity records. Only dicts that carry
            both a ``'text'`` and a ``'type'`` key are stored, as
            ``(text, type)`` tuples added to ``knowledge_graph["entities"]``.
        relations: Iterable of candidate relation records. Only dicts that
            carry ``"head"``, ``"tail"`` and ``"relation"`` keys are stored, as
            ``(head, tail, relation)`` tuples appended to
            ``knowledge_graph["relations"]``.

    Records that are not dicts, or that lack a required key, are skipped
    silently.
    """
    for entity in entities:
        if not isinstance(entity, dict):
            continue
        if 'text' in entity and 'type' in entity:
            knowledge_graph["entities"].add((entity['text'], entity['type']))

    required_keys = ("head", "tail", "relation")
    for rel in relations:
        if not isinstance(rel, dict):
            continue
        if any(key not in rel for key in required_keys):
            continue
        # Tuple order follows required_keys: (head, tail, relation).
        knowledge_graph["relations"].append(tuple(rel[key] for key in required_keys))
|
32 |
|
33 |
|
34 |
def visualize_kg():
|
|
|
36 |
net = Network(height="600px", width="100%", notebook=True, directed=True)
|
37 |
node_map = {}
|
38 |
|
39 |
+
# 添加节点
|
40 |
+
idx = 0
|
41 |
+
for ent in knowledge_graph["entities"]:
|
42 |
+
if isinstance(ent, tuple) and len(ent) == 2:
|
43 |
+
name, type_ = ent
|
44 |
+
node_map[name] = idx
|
45 |
+
net.add_node(idx,
|
46 |
+
label=name,
|
47 |
+
title=f"类型:{type_}",
|
48 |
+
group=type_,
|
49 |
+
font={'size': 20, 'face': 'SimHei'})
|
50 |
+
idx += 1
|
51 |
|
52 |
# 添加边
|
53 |
+
seen_edges = set()
|
54 |
for head, tail, relation in knowledge_graph["relations"]:
|
55 |
if head in node_map and tail in node_map:
|
56 |
edge_key = f"{head}-{tail}-{relation}"
|
|
|
61 |
font={'size': 14})
|
62 |
seen_edges.add(edge_key)
|
63 |
|
|
|
64 |
net.set_options("""
|
65 |
{
|
66 |
"nodes": {
|
|
|
87 |
}
|
88 |
""")
|
89 |
|
|
|
90 |
html = net.generate_html()
|
91 |
html = html.replace('//cdnjs.cloudflare.com', 'https://cdnjs.cloudflare.com')
|
92 |
html = html.replace('//unpkg.com', 'https://unpkg.com')
|
93 |
return f'<div class="kg-graph">{html}</div>'
|
94 |
|
95 |
|
|
|
96 |
def ner(text):
|
|
|
97 |
name_pattern = r"([赵钱孙李周吴郑王冯陈褚卫蒋沈韩杨朱秦尤许何吕施张孔曹严华金魏陶姜][\u4e00-\u9fa5]{1,2})(?![的地得啦啊呀])"
|
|
|
98 |
id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![\u4e00-\u9fa5])"
|
99 |
|
100 |
entities = []
|
101 |
occupied = set()
|
102 |
|
|
|
103 |
for match in re.finditer(name_pattern, text):
|
104 |
start, end = match.start(1), match.end(1)
|
105 |
if not any(s <= start < e for (s, e) in occupied):
|
|
|
111 |
})
|
112 |
occupied.update(range(start, end))
|
113 |
|
|
|
114 |
for match in re.finditer(id_pattern, text):
|
115 |
start, end = match.start(1), match.end(1)
|
116 |
if not any(s <= start < e for (s, e) in occupied):
|
|
|
125 |
return sorted(entities, key=lambda x: x["start"])
|
126 |
|
127 |
|
|
|
128 |
def re_extract(entities, text):
|
129 |
relations = []
|
130 |
triggers = {
|
|
|
134 |
}
|
135 |
|
136 |
for i in range(len(entities)):
|
|
|
137 |
for j in range(max(0, i - 2), min(len(entities), i + 3)):
|
138 |
if i == j:
|
139 |
continue
|
140 |
|
|
|
141 |
ctx_start = entities[i]["end"]
|
142 |
ctx_end = entities[j]["start"]
|
143 |
context = text[ctx_start:ctx_end].strip()
|
144 |
|
|
|
145 |
if text.startswith('@', entities[i]["start"] - 1):
|
146 |
relations.append({
|
147 |
"head": entities[i]["text"],
|
|
|
150 |
})
|
151 |
continue
|
152 |
|
|
|
153 |
relation_type = "knows"
|
154 |
for rel_type, keywords in triggers.items():
|
155 |
if any(kw in context for kw in keywords):
|
|
|
162 |
"relation": relation_type
|
163 |
})
|
164 |
|
|
|
165 |
unique_relations = []
|
166 |
seen = set()
|
167 |
for rel in relations:
|
|
|
173 |
return unique_relations
|
174 |
|
175 |
|
|
|
176 |
def process_text(text):
|
177 |
try:
|
178 |
entities = ner(text)
|
|
|
189 |
)
|
190 |
kg_html = visualize_kg()
|
191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
return entity_output, relation_output, gr.HTML(kg_html)
|
193 |
|
194 |
except Exception as e:
|
195 |
return f"处理出错: {str(e)}", "", gr.HTML()
|
196 |
|
197 |
|
|
|
198 |
def detect_encoding(file_path):
    """Guess the text encoding of a file from its leading bytes.

    Args:
        file_path: Path of the file to inspect; it is opened in binary mode.

    Returns:
        The ``'encoding'`` value reported by ``chardet.detect`` for the first
        4096 bytes of the file, or ``'utf-8'`` when that value is falsy.
    """
    with open(file_path, 'rb') as handle:
        sample = handle.read(4096)
    guessed = chardet.detect(sample)['encoding']
    # Fall back to UTF-8 when the detector reports no encoding.
    return guessed if guessed else 'utf-8'
|
|
|
203 |
def process_file(file):
|
204 |
ext = os.path.splitext(file.name)[-1].lower()
|
205 |
full_text = ""
|
|
|
206 |
|
207 |
try:
|
208 |
encoding = detect_encoding(file.name)
|
|
|
246 |
return f"❌ 文件处理错误: {str(e)}", "", gr.HTML()
|
247 |
|
248 |
|
249 |
+
# Gradio UI
|
250 |
css = """
|
251 |
.kg-container {
|
252 |
border: 1px solid #e0e0e0;
|
|
|
269 |
"""
|
270 |
|
271 |
with gr.Blocks(css=css) as demo:
|
272 |
+
gr.Markdown("# 🚀 智能聊天记录分析系统\n**功能**: 实体识别 → 关系抽取 → 动态知识图谱")
|
|
|
273 |
|
274 |
with gr.Tab("✍️ 文本分析"):
|
275 |
gr.Markdown("### 直接输入聊天内容")
|
276 |
input_text = gr.Textbox(label="输入内容", lines=8,
|
277 |
+
placeholder="示例:张三@李四 请把需求文档_v2发送给王五")
|
278 |
analyze_btn = gr.Button("开始分析", variant="primary")
|
279 |
|
280 |
with gr.Row():
|
|
|
298 |
|
299 |
with gr.Tab("📁 文件分析"):
|
300 |
gr.Markdown("### 上传聊天记录文件")
|
301 |
+
file_input = gr.File(label="选择文件", file_types=[".txt", ".json", ".jsonl", ".csv"])
|
|
|
302 |
file_btn = gr.Button("分析文件", variant="primary")
|
303 |
|
304 |
with gr.Row():
|
|
|
325 |
)
|
326 |
|
327 |
if __name__ == "__main__":
|
328 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|