chen666-666 commited on
Commit
8810e7b
·
1 Parent(s): 0207e75

add app.py and requirements.txt

Browse files
Files changed (2) hide show
  1. app.py +25 -178
  2. requirements.txt +0 -1
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
3
  import gradio as gr
4
  import re
5
  import os
@@ -14,21 +14,6 @@ bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
14
  bert_ner_model = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
15
  bert_ner_pipeline = pipeline("ner", model=bert_ner_model, tokenizer=bert_tokenizer, aggregation_strategy="simple")
16
 
17
- chatglm_model, chatglm_tokenizer = None, None
18
- use_chatglm = False
19
- try:
20
- if torch.cuda.is_available():
21
- chatglm_model_name = "THUDM/chatglm3-6b"
22
- chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
23
- chatglm_model = AutoModel.from_pretrained(
24
- chatglm_model_name, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16
25
- ).eval()
26
- use_chatglm = True
27
- else:
28
- print("⚠️ 当前为 CPU 环境,ChatGLM3 不可用,将仅使用 BERT。")
29
- except Exception as e:
30
- print(f"❌ ChatGLM 加载失败: {e}")
31
-
32
  # ======================== 知识图谱结构 ========================
33
  knowledge_graph = {"entities": set(), "relations": set()}
34
 
@@ -40,46 +25,21 @@ def update_knowledge_graph(entities, relations):
40
 
41
  for r in relations:
42
  if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
43
- # 标准化关系方向
44
  relation_tuple = (r['head'], r['tail'], r['relation'])
45
  reverse_tuple = (r['tail'], r['head'], r['relation'])
46
  if reverse_tuple not in knowledge_graph["relations"]:
47
  knowledge_graph["relations"].add(relation_tuple)
48
-
 
49
  def visualize_kg_text():
50
  nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
51
  edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
52
  return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
53
 
 
54
  # ======================== 实体识别(NER) ========================
55
- def ner(text, model_type="bert"):
56
  start_time = time.time()
57
- if model_type == "chatglm" and use_chatglm:
58
- try:
59
- prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
60
- 示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
61
- 文本:{text}"""
62
- response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
63
- if isinstance(response, tuple):
64
- response = response[0]
65
-
66
- # 增强 JSON 解析
67
- try:
68
- json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
69
- entities = json.loads(json_str)
70
- # 验证字段
71
- valid_entities = []
72
- for ent in entities:
73
- if all(k in ent for k in ("text", "type", "start", "end")):
74
- valid_entities.append(ent)
75
- return valid_entities, time.time() - start_time
76
- except Exception as e:
77
- print(f"JSON 解析失败: {e}")
78
- return [], time.time() - start_time
79
- except Exception as e:
80
- print(f"ChatGLM 调用失败:{e}")
81
- return [], time.time() - start_time
82
-
83
  # 使用微调的 BERT 中文 NER 模型
84
  raw_results = bert_ner_pipeline(text)
85
  entities = []
@@ -92,6 +52,7 @@ def ner(text, model_type="bert"):
92
  })
93
  return entities, time.time() - start_time
94
 
 
95
  # ======================== 关系抽取(RE) ========================
96
  def re_extract(entities, text):
97
  if len(entities) < 2:
@@ -108,164 +69,50 @@ def re_extract(entities, text):
108
  2. 关系类型使用:属于、位于、参与、其他
109
  3. 格式示例:[{{"head": "北京", "tail": "中国", "relation": "位于"}}]"""
110
 
111
- if use_chatglm:
112
- response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
113
- if isinstance(response, tuple):
114
- response = response[0]
115
-
116
- # 提取 JSON
117
- try:
118
- json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
119
- relations = json.loads(json_str)
120
- # 验证关系
121
- valid_relations = []
122
- valid_types = {"属于", "位于", "参与", "其他"}
123
- for rel in relations:
124
- if all(k in rel for k in ("head", "tail", "relation")) and rel["relation"] in valid_types:
125
- valid_relations.append(rel)
126
- return valid_relations
127
- except Exception as e:
128
- print(f"关系解析失败: {e}")
129
  except Exception as e:
130
  print(f"关系抽取失败: {e}")
131
 
132
  # 默认不生成任何关系
133
  return []
134
 
 
135
  # ======================== 文本分析主流程 ========================
136
- def process_text(text, model_type="bert"):
137
- entities, duration = ner(text, model_type)
138
  relations = re_extract(entities, text)
139
  update_knowledge_graph(entities, relations)
140
 
141
  ent_text = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
142
  rel_text = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
143
  kg_text = visualize_kg_text()
144
- return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"
145
 
146
 
147
- def process_file(file, model_type="bert"):
148
- try:
149
- with open(file.name, 'rb') as f:
150
- content = f.read()
151
-
152
- if len(content) > 5 * 1024 * 1024:
153
- return "❌ 文件太大", "", "", ""
154
-
155
- # 检测编码
156
- try:
157
- encoding = chardet.detect(content)['encoding'] or 'utf-8'
158
- text = content.decode(encoding)
159
- except UnicodeDecodeError:
160
- # 尝试常见中文编码
161
- for enc in ['gb18030', 'utf-16', 'big5']:
162
- try:
163
- text = content.decode(enc)
164
- break
165
- except:
166
- continue
167
- else:
168
- return "❌ 编码解析失败", "", "", ""
169
-
170
- return process_text(text, model_type)
171
- except Exception as e:
172
- return f"❌ 文件处理错误: {str(e)}", "", "", ""
173
-
174
- # ======================== 模型评估与自动标注 ========================
175
- def convert_telegram_json_to_eval_format(path):
176
- with open(path, encoding="utf-8") as f:
177
- data = json.load(f)
178
- if isinstance(data, dict) and "text" in data:
179
- return [{"text": data["text"], "entities": [
180
- {"text": data["text"][e["start"]:e["end"]]} for e in data.get("entities", [])
181
- ]}]
182
- elif isinstance(data, list):
183
- return data
184
- elif isinstance(data, dict) and "messages" in data:
185
- result = []
186
- for m in data.get("messages", []):
187
- if isinstance(m.get("text"), str):
188
- result.append({"text": m["text"], "entities": []})
189
- elif isinstance(m.get("text"), list):
190
- txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
191
- result.append({"text": txt, "entities": []})
192
- return result
193
- return []
194
-
195
-
196
- def evaluate_ner_model(data, model_type):
197
- y_true, y_pred = [], []
198
- for item in data:
199
- text = item["text"]
200
- gold_entities = []
201
- for e in item.get("entities", []):
202
- if "text" in e and "type" in e:
203
- # 使用哈希避免重复
204
- gold_entities.append(f"{e['text']}|{e['type']}|{e.get('start', -1)}|{e.get('end', -1)}")
205
-
206
- pred_entities = []
207
- pred, _ = ner(text, model_type)
208
- for e in pred:
209
- pred_entities.append(f"{e['text']}|{e['type']}|{e['start']}|{e['end']}")
210
-
211
- # 创建所有可能的实体集合
212
- all_entities = set(gold_entities + pred_entities)
213
- for ent in all_entities:
214
- y_true.append(1 if ent in gold_entities else 0)
215
- y_pred.append(1 if ent in pred_entities else 0)
216
-
217
- if not y_true:
218
- return "⚠️ 无有效标注数据"
219
-
220
- return f"Precision: {precision_score(y_true, y_pred):.2f}\nRecall: {recall_score(y_true, y_pred):.2f}\nF1: {f1_score(y_true, y_pred):.2f}"
221
-
222
- def auto_annotate(file, model_type):
223
- data = convert_telegram_json_to_eval_format(file.name)
224
- for item in data:
225
- ents, _ = ner(item["text"], model_type)
226
- item["entities"] = ents
227
- return json.dumps(data, ensure_ascii=False, indent=2)
228
-
229
- def save_json(json_text):
230
- fname = f"auto_labeled_{int(time.time())}.json"
231
- with open(fname, "w", encoding="utf-8") as f:
232
- f.write(json_text)
233
- return fname
234
-
235
  # ======================== Gradio 界面 ========================
236
  with gr.Blocks(css=".kg-graph {height: 500px;}") as demo:
237
  gr.Markdown("# 🤖 聊天记录实体关系识别系统")
238
 
239
  with gr.Tab("📄 文本分析"):
240
  input_text = gr.Textbox(lines=6, label="输入文本")
241
- model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
242
  btn = gr.Button("开始分析")
243
  out1 = gr.Textbox(label="识别实体")
244
  out2 = gr.Textbox(label="识别关系")
245
  out3 = gr.Textbox(label="知识图谱")
246
  out4 = gr.Textbox(label="耗时")
247
- btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
248
-
249
- with gr.Tab("🗂 文件分析"):
250
- file_input = gr.File(file_types=[".txt", ".json"])
251
- file_btn = gr.Button("上传并分析")
252
- fout1, fout2, fout3, fout4 = gr.Textbox(), gr.Textbox(), gr.Textbox(), gr.Textbox()
253
- file_btn.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])
254
-
255
- with gr.Tab("📊 模型评估"):
256
- eval_file = gr.File(label="上传标注 JSON")
257
- eval_model = gr.Radio(["bert", "chatglm"], value="bert")
258
- eval_btn = gr.Button("开始评估")
259
- eval_output = gr.Textbox(label="评估结果", lines=5)
260
- eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m), [eval_file, eval_model], eval_output)
261
-
262
- with gr.Tab("✏️ 自动标注"):
263
- raw_file = gr.File(label="上传 Telegram 原始 JSON")
264
- auto_model = gr.Radio(["bert", "chatglm"], value="bert")
265
- auto_btn = gr.Button("自动标注")
266
- marked_texts = gr.Textbox(label="标注结果", lines=20)
267
- download_btn = gr.Button("💾 下载标注文件")
268
- auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
269
- download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())
270
 
271
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
3
  import gradio as gr
4
  import re
5
  import os
 
14
  bert_ner_model = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
15
  bert_ner_pipeline = pipeline("ner", model=bert_ner_model, tokenizer=bert_tokenizer, aggregation_strategy="simple")
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  # ======================== 知识图谱结构 ========================
18
  knowledge_graph = {"entities": set(), "relations": set()}
19
 
 
25
 
26
  for r in relations:
27
  if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
 
28
  relation_tuple = (r['head'], r['tail'], r['relation'])
29
  reverse_tuple = (r['tail'], r['head'], r['relation'])
30
  if reverse_tuple not in knowledge_graph["relations"]:
31
  knowledge_graph["relations"].add(relation_tuple)
32
+
33
+
34
  def visualize_kg_text():
35
  nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
36
  edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
37
  return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
38
 
39
+
40
  # ======================== 实体识别(NER) ========================
41
+ def ner(text):
42
  start_time = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # 使用微调的 BERT 中文 NER 模型
44
  raw_results = bert_ner_pipeline(text)
45
  entities = []
 
52
  })
53
  return entities, time.time() - start_time
54
 
55
+
56
  # ======================== 关系抽取(RE) ========================
57
  def re_extract(entities, text):
58
  if len(entities) < 2:
 
69
  2. 关系类型使用:属于、位于、参与、其他
70
  3. 格式示例:[{{"head": "北京", "tail": "中国", "relation": "位于"}}]"""
71
 
72
+ # 仅使用 BERT
73
+ response = bert_ner_pipeline(prompt)
74
+ try:
75
+ json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
76
+ relations = json.loads(json_str)
77
+ valid_relations = []
78
+ valid_types = {"属于", "位于", "参与", "其他"}
79
+ for rel in relations:
80
+ if all(k in rel for k in ("head", "tail", "relation")) and rel["relation"] in valid_types:
81
+ valid_relations.append(rel)
82
+ return valid_relations
83
+ except Exception as e:
84
+ print(f"关系解析失败: {e}")
 
 
 
 
 
85
  except Exception as e:
86
  print(f"关系抽取失败: {e}")
87
 
88
  # 默认不生成任何关系
89
  return []
90
 
91
+
92
  # ======================== 文本分析主流程 ========================
93
+ def process_text(text, state=None):
94
+ entities, duration = ner(text)
95
  relations = re_extract(entities, text)
96
  update_knowledge_graph(entities, relations)
97
 
98
  ent_text = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
99
  rel_text = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
100
  kg_text = visualize_kg_text()
101
+ return ent_text, rel_text, kg_text, f"{duration:.2f} 秒", state
102
 
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  # ======================== Gradio 界面 ========================
105
  with gr.Blocks(css=".kg-graph {height: 500px;}") as demo:
106
  gr.Markdown("# 🤖 聊天记录实体关系识别系统")
107
 
108
  with gr.Tab("📄 文本分析"):
109
  input_text = gr.Textbox(lines=6, label="输入文本")
 
110
  btn = gr.Button("开始分析")
111
  out1 = gr.Textbox(label="识别实体")
112
  out2 = gr.Textbox(label="识别关系")
113
  out3 = gr.Textbox(label="知识图谱")
114
  out4 = gr.Textbox(label="耗时")
115
+ state = gr.State()
116
+ btn.click(fn=process_text, inputs=[input_text, state], outputs=[out1, out2, out3, out4, state])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  demo.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt CHANGED
@@ -4,7 +4,6 @@ torch>=2.1.0
4
  networkx>=3.0
5
  python-dotenv>=1.0.0
6
  sentencepiece>=0.2.0
7
- cpm-kernels>=1.0.11
8
  accelerate>=0.27.0
9
  scikit-learn>=1.3.0
10
  chardet>=5.2.0
 
4
  networkx>=3.0
5
  python-dotenv>=1.0.0
6
  sentencepiece>=0.2.0
 
7
  accelerate>=0.27.0
8
  scikit-learn>=1.3.0
9
  chardet>=5.2.0