chen666-666 committed
Commit c85af5a · 1 parent: 20683c1

Add Gradio app for NER + RE

Files changed (1): app.py +22 -40
app.py CHANGED
@@ -24,9 +24,11 @@ knowledge_graph = {
 def update_knowledge_graph(entities, relations):
     """Update the knowledge-graph data."""
     for e in entities:
-        knowledge_graph["entities"].add((e['text'], e['type']))
+        if isinstance(e, dict) and 'text' in e and 'type' in e:
+            knowledge_graph["entities"].add((e['text'], e['type']))
     for r in relations:
-        knowledge_graph["relations"].append((r['head'], r['tail'], r['relation']))
+        if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
+            knowledge_graph["relations"].append((r['head'], r['tail'], r['relation']))
 
 
 def visualize_kg():
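Note on this hunk: the new `isinstance`/key guards silently skip malformed items instead of raising. A minimal sketch of the new behavior, assuming the module-level store `knowledge_graph = {"entities": set(), "relations": []}` that the hunk header points at:

```python
knowledge_graph = {"entities": set(), "relations": []}

def update_knowledge_graph(entities, relations):
    """Update the knowledge-graph data, skipping malformed items."""
    for e in entities:
        if isinstance(e, dict) and 'text' in e and 'type' in e:
            knowledge_graph["entities"].add((e['text'], e['type']))
    for r in relations:
        if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
            knowledge_graph["relations"].append((r['head'], r['tail'], r['relation']))

# A bare string and a dict missing 'tail'/'relation' are now ignored, where the
# old code would have raised TypeError/KeyError:
update_knowledge_graph(
    [{"text": "张三", "type": "PERSON"}, "not-a-dict"],
    [{"head": "张三", "tail": "李四", "relation": "knows"}, {"head": "李四"}],
)
print(knowledge_graph)
# {'entities': {('张三', 'PERSON')}, 'relations': [('张三', '李四', 'knows')]}
```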
@@ -34,17 +36,21 @@ def visualize_kg():
     net = Network(height="600px", width="100%", notebook=True, directed=True)
     node_map = {}
 
-    # Add nodes (with Chinese-label support)
-    for idx, (name, type_) in enumerate(knowledge_graph["entities"]):
-        node_map[name] = idx
-        net.add_node(idx,
-                     label=name,
-                     title=f"类型:{type_}",
-                     group=type_,
-                     font={'size': 20, 'face': 'SimHei'})  # Chinese-capable font
+    # Add nodes
+    idx = 0
+    for ent in knowledge_graph["entities"]:
+        if isinstance(ent, tuple) and len(ent) == 2:
+            name, type_ = ent
+            node_map[name] = idx
+            net.add_node(idx,
+                         label=name,
+                         title=f"类型:{type_}",
+                         group=type_,
+                         font={'size': 20, 'face': 'SimHei'})
+            idx += 1
 
     # Add edges
-    seen_edges = set()  # prevent duplicate edges
+    seen_edges = set()
     for head, tail, relation in knowledge_graph["relations"]:
         if head in node_map and tail in node_map:
             edge_key = f"{head}-{tail}-{relation}"
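The rewrite swaps `enumerate` for a manually advanced `idx`, so the counter only moves when a well-formed `(name, type_)` tuple actually becomes a node. One caveat: if the same name ever appears with two types, `node_map[name]` keeps only the last index. A self-contained sketch of the loop against pyvis (the sample entities are illustrative):

```python
from pyvis.network import Network  # same dependency app.py already uses

entities = {("张三", "PERSON"), ("user_12345", "ID"), ("bad-entry",)}  # 1-tuple is skipped
net = Network(height="600px", width="100%", notebook=True, directed=True)

node_map, idx = {}, 0
for ent in entities:
    if isinstance(ent, tuple) and len(ent) == 2:
        name, type_ = ent
        node_map[name] = idx
        net.add_node(idx, label=name, title=f"类型:{type_}", group=type_,
                     font={'size': 20, 'face': 'SimHei'})
        idx += 1  # only advances for nodes that were really added
print(node_map)  # e.g. {'张三': 0, 'user_12345': 1} (set order varies)
```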
@@ -55,7 +61,6 @@ def visualize_kg():
                          font={'size': 14})
             seen_edges.add(edge_key)
 
-    # Tuned layout configuration
     net.set_options("""
     {
       "nodes": {
@@ -82,24 +87,19 @@ def visualize_kg():
     }
     """)
 
-    # Generate the HTML and fix resource references
     html = net.generate_html()
     html = html.replace('//cdnjs.cloudflare.com', 'https://cdnjs.cloudflare.com')
     html = html.replace('//unpkg.com', 'https://unpkg.com')
     return f'<div class="kg-graph">{html}</div>'
 
 
-# ----------- Enhanced NER logic -----------------
 def ner(text):
-    # Improved Chinese-name recognition (excludes common verb suffixes)
     name_pattern = r"([赵钱孙李周吴郑王冯陈褚卫蒋沈韩杨朱秦尤许何吕施张孔曹严华金魏陶姜][\u4e00-\u9fa5]{1,2})(?![的地得啦啊呀])"
-    # Enhanced ID recognition (underscores and digits allowed)
     id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![\u4e00-\u9fa5])"
 
     entities = []
     occupied = set()
 
-    # Recognize Chinese personal names
     for match in re.finditer(name_pattern, text):
         start, end = match.start(1), match.end(1)
         if not any(s <= start < e for (s, e) in occupied):
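Since these two patterns carry most of the NER logic, here is a quick, illustrative check of what they accept (the sample line is made up). Note that `需求文档_v2` from the UI placeholder would *not* match `id_pattern`: the `(?<!\S)` lookbehind requires the ID to start after whitespace, and `_v2` is also shorter than the five-character minimum.

```python
import re

name_pattern = r"([赵钱孙李周吴郑王冯陈褚卫蒋沈韩杨朱秦尤许何吕施张孔曹严华金魏陶姜][\u4e00-\u9fa5]{1,2})(?![的地得啦啊呀])"
id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![\u4e00-\u9fa5])"

text = "张三@李四 请把文档发送给王五 account_007"
print([m.group(1) for m in re.finditer(name_pattern, text)])  # ['张三', '李四', '王五']
print([m.group(1) for m in re.finditer(id_pattern, text)])    # ['account_007']
```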
@@ -111,7 +111,6 @@ def ner(text):
             })
             occupied.update(range(start, end))
 
-    # Recognize user IDs
     for match in re.finditer(id_pattern, text):
         start, end = match.start(1), match.end(1)
         if not any(s <= start < e for (s, e) in occupied):
@@ -126,7 +125,6 @@ def ner(text):
     return sorted(entities, key=lambda x: x["start"])
 
 
-# ----------- Improved relation-extraction logic -----------------
 def re_extract(entities, text):
     relations = []
     triggers = {
@@ -136,17 +134,14 @@ def re_extract(entities, text):
     }
 
     for i in range(len(entities)):
-        # Check a window of two entities before and after
        for j in range(max(0, i - 2), min(len(entities), i + 3)):
             if i == j:
                 continue
 
-            # Grab the text between the two entities
             ctx_start = entities[i]["end"]
             ctx_end = entities[j]["start"]
             context = text[ctx_start:ctx_end].strip()
 
-            # Handle @-mentions
             if text.startswith('@', entities[i]["start"] - 1):
                 relations.append({
                     "head": entities[i]["text"],
@@ -155,7 +150,6 @@ def re_extract(entities, text):
                 })
                 continue
 
-            # Determine the relation type
             relation_type = "knows"
             for rel_type, keywords in triggers.items():
                 if any(kw in context for kw in keywords):
@@ -168,7 +162,6 @@ def re_extract(entities, text):
                 "relation": relation_type
             })
 
-    # Deduplicate
     unique_relations = []
     seen = set()
     for rel in relations:
@@ -180,7 +173,6 @@ def re_extract(entities, text):
     return unique_relations
 
 
-# ----------- Text-processing pipeline -----------------
 def process_text(text):
     try:
         entities = ner(text)
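The dedup block above is unchanged apart from the dropped comment; the key-building lines fall outside this hunk, but the shape is the usual order-preserving set filter, presumably keyed on the `(head, tail, relation)` triple:

```python
# Generic order-preserving dedup sketch; the exact key in app.py's elided
# lines is an assumption.
relations = [
    {"head": "张三", "tail": "李四", "relation": "knows"},
    {"head": "张三", "tail": "李四", "relation": "knows"},
]
unique_relations, seen = [], set()
for rel in relations:
    key = (rel["head"], rel["tail"], rel["relation"])
    if key not in seen:
        seen.add(key)
        unique_relations.append(rel)
print(unique_relations)  # the duplicate collapses to one entry
```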
@@ -197,19 +189,12 @@ def process_text(text):
         )
         kg_html = visualize_kg()
 
-        # Debug logging
-        print(f"Entities: {entities}")
-        print(f"Relations: {relations}")
-        with open("debug_kg.html", "w", encoding="utf-8") as f:
-            f.write(kg_html)
-
         return entity_output, relation_output, gr.HTML(kg_html)
 
     except Exception as e:
         return f"处理出错: {str(e)}", "", gr.HTML()
 
 
-# ----------- File-handling module -----------------
 def detect_encoding(file_path):
     with open(file_path, 'rb') as f:
         return chardet.detect(f.read(4096))['encoding'] or 'utf-8'
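`detect_encoding` (context lines above) samples only the first 4096 bytes and falls back to `'utf-8'` when chardet returns `None`. An illustrative round trip, with a made-up temp file and GBK sample:

```python
import os
import tempfile

import chardet

with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
    tmp.write("张三@李四 请把文档发给王五".encode("gbk"))

with open(tmp.name, 'rb') as f:
    guess = chardet.detect(f.read(4096))['encoding'] or 'utf-8'
print(guess)  # typically a GB-family guess such as 'GB2312'
os.remove(tmp.name)
```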
@@ -218,7 +203,6 @@ def detect_encoding(file_path):
 def process_file(file):
     ext = os.path.splitext(file.name)[-1].lower()
     full_text = ""
-    warning = ""
 
     try:
         encoding = detect_encoding(file.name)
@@ -262,7 +246,7 @@ def process_file(file):
         return f"❌ 文件处理错误: {str(e)}", "", gr.HTML()
 
 
-# ----------- Gradio UI -----------------
+# Gradio UI
 css = """
 .kg-container {
     border: 1px solid #e0e0e0;
@@ -285,13 +269,12 @@ css = """
 """
 
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("""# 🚀 智能聊天记录分析系统
-    **功能**: 实体识别 → 关系抽取 → 动态知识图谱""")
+    gr.Markdown("# 🚀 智能聊天记录分析系统\n**功能**: 实体识别 → 关系抽取 → 动态知识图谱")
 
     with gr.Tab("✍️ 文本分析"):
         gr.Markdown("### 直接输入聊天内容")
         input_text = gr.Textbox(label="输入内容", lines=8,
-                                placeholder="示例:\n张三@李四 请把需求文档_v2发送给王五\n李四回复:已发送至[email protected]")
+                                placeholder="示例:张三@李四 请把需求文档_v2发送给王五")
         analyze_btn = gr.Button("开始分析", variant="primary")
 
         with gr.Row():
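The click bindings themselves are untouched by this commit (only their closing parentheses appear as context below). For reference, the generic Gradio binding pattern these components rely on looks like the following sketch; the component and function names here are stand-ins, not taken from the diff:

```python
import gradio as gr

def toy_process(text):  # stand-in for process_text
    return text.upper()

with gr.Blocks() as sketch:
    inp = gr.Textbox(label="输入内容")
    btn = gr.Button("开始分析", variant="primary")
    out = gr.Textbox(label="结果")
    btn.click(fn=toy_process, inputs=inp, outputs=out)  # same fn/inputs/outputs shape
```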
@@ -315,8 +298,7 @@ with gr.Blocks(css=css) as demo:
 
     with gr.Tab("📁 文件分析"):
         gr.Markdown("### 上传聊天记录文件")
-        file_input = gr.File(label="选择文件",
-                             file_types=[".txt", ".json", ".jsonl", ".csv"])
+        file_input = gr.File(label="选择文件", file_types=[".txt", ".json", ".jsonl", ".csv"])
         file_btn = gr.Button("分析文件", variant="primary")
 
         with gr.Row():
@@ -343,4 +325,4 @@ with gr.Blocks(css=css) as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+    demo.launch(server_name="0.0.0.0", server_port=7860)