chen666-666 committed on
Commit
6129c00
·
1 Parent(s): 0378c00

add app.py and requirements.txt

Browse files
Files changed (1) hide show
  1. app.py +113 -65
app.py CHANGED
@@ -9,46 +9,59 @@ from sklearn.metrics import precision_score, recall_score, f1_score
9
  import time
10
 
11
  # ======================== 模型加载 ========================
12
- bert_model_name = "bert-base-chinese"
13
- bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
14
- bert_ner_model = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
15
- bert_ner_pipeline = pipeline("ner", model=bert_ner_model, tokenizer=bert_tokenizer, aggregation_strategy="simple")
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  chatglm_model, chatglm_tokenizer = None, None
18
  use_chatglm = False
19
  try:
20
- chatglm_model_name = "THUDM/chatglm-6b-int4" # 4-bit量化版本
21
- chatglm_tokenizer = AutoTokenizer.from_pretrained(
22
- chatglm_model_name,
23
- trust_remote_code=True
24
- )
25
  chatglm_model = AutoModel.from_pretrained(
26
  chatglm_model_name,
27
  trust_remote_code=True,
28
  device_map="cpu",
29
- torch_dtype=torch.float32 # 必须使用float32
30
  ).eval()
31
  use_chatglm = True
32
- print("✅ 4-bit量化版ChatGLM加载成功(需6GB内存)")
33
  except Exception as e:
34
- print(f"❌ 量化模型加载失败: {e}")
35
 
36
  # ======================== 知识图谱结构 ========================
37
  knowledge_graph = {"entities": set(), "relations": set()}
38
 
39
-
40
  def update_knowledge_graph(entities, relations):
41
  for e in entities:
42
  if isinstance(e, dict) and 'text' in e and 'type' in e:
43
  knowledge_graph["entities"].add((e['text'], e['type']))
44
-
 
45
  for r in relations:
46
  if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
47
- # 标准化关系方向
48
- relation_tuple = (r['head'], r['tail'], r['relation'])
49
- reverse_tuple = (r['tail'], r['head'], r['relation'])
50
- if reverse_tuple not in knowledge_graph["relations"]:
51
- knowledge_graph["relations"].add(relation_tuple)
 
 
 
52
 
53
 
54
  def visualize_kg_text():
@@ -58,50 +71,57 @@ def visualize_kg_text():
58
 
59
 
60
  # ======================== 实体识别(NER) ========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def ner(text, model_type="bert"):
62
  start_time = time.time()
63
  if model_type == "chatglm" and use_chatglm:
64
- try:
65
- prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
66
- 示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
67
- 文本:{text}"""
68
- response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
69
- if isinstance(response, tuple):
70
- response = response[0]
71
-
72
- # 增强 JSON 解析
73
- try:
74
- json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
75
- entities = json.loads(json_str)
76
- # 验证字段
77
- valid_entities = []
78
- for ent in entities:
79
- if all(k in ent for k in ("text", "type", "start", "end")):
80
- valid_entities.append(ent)
81
- return valid_entities, time.time() - start_time
82
- except Exception as e:
83
- print(f"JSON 解析失败: {e}")
84
- return [], time.time() - start_time
85
- except Exception as e:
86
- print(f"ChatGLM 调用失败:{e}")
87
- return [], time.time() - start_time
88
 
89
- # 使用微调的 BERT 中文 NER 模型
90
  raw_results = bert_ner_pipeline(text)
91
  entities = []
92
  for r in raw_results:
 
93
  entities.append({
94
- "text": r["word"],
95
- "start": r["start"],
96
- "end": r["end"],
97
- "type": r["entity_group"]
98
  })
99
- return entities, time.time() - start_time
100
 
 
 
 
101
 
102
  # ======================== 关系抽取(RE) ========================
103
  def re_extract(entities, text):
104
- if len(entities) < 2:
 
 
 
 
105
  return []
106
 
107
  relations = []
@@ -204,30 +224,55 @@ def convert_telegram_json_to_eval_format(path):
204
 
205
  def evaluate_ner_model(data, model_type):
206
  y_true, y_pred = [], []
 
 
207
  for item in data:
208
  text = item["text"]
209
  gold_entities = []
210
  for e in item.get("entities", []):
211
  if "text" in e and "type" in e:
212
- # 使用哈希避免重复
213
- gold_entities.append(f"{e['text']}|{e['type']}|{e.get('start', -1)}|{e.get('end', -1)}")
214
-
215
- pred_entities = []
216
- pred, _ = ner(text, model_type)
217
- for e in pred:
218
- pred_entities.append(f"{e['text']}|{e['type']}|{e['start']}|{e['end']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
- # 创建所有可能的实体集合
221
- all_entities = set(gold_entities + pred_entities)
222
- for ent in all_entities:
223
- y_true.append(1 if ent in gold_entities else 0)
224
- y_pred.append(1 if ent in pred_entities else 0)
225
 
226
  if not y_true:
227
  return "⚠️ 无有效标注数据"
228
 
229
- return f"Precision: {precision_score(y_true, y_pred):.2f}\nRecall: {recall_score(y_true, y_pred):.2f}\nF1: {f1_score(y_true, y_pred):.2f}"
230
-
 
231
 
232
  def auto_annotate(file, model_type):
233
  data = convert_telegram_json_to_eval_format(file.name)
@@ -245,7 +290,10 @@ def save_json(json_text):
245
 
246
 
247
  # ======================== Gradio 界面 ========================
248
- with gr.Blocks(css=".kg-graph {height: 500px;}") as demo:
 
 
 
249
  gr.Markdown("# 🤖 聊天记录实体关系识别系统")
250
 
251
  with gr.Tab("📄 文本分析"):
 
9
  import time
10
 
11
# ======================== Model loading ========================
# Chinese NER model fine-tuned on CLUENER 2020.
NER_MODEL_NAME = "uer/roberta-base-finetuned-cluener2020-chinese"
bert_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
bert_ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
bert_ner_pipeline = pipeline(
    "ner",
    model=bert_ner_model,
    tokenizer=bert_tokenizer,
    aggregation_strategy="first",
)

# Map CLUENER label names onto the coarse tag set used downstream.
LABEL_MAPPING = {
    "address": "LOC",
    "company": "ORG",
    "name": "PER",
    "organization": "ORG",
    "position": "TITLE",
}

# ChatGLM is optional: when loading fails we fall back to BERT-only mode.
chatglm_model, chatglm_tokenizer = None, None
use_chatglm = False
try:
    chatglm_model_name = "THUDM/chatglm-6b-int4"
    chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
    chatglm_model = AutoModel.from_pretrained(
        chatglm_model_name,
        trust_remote_code=True,
        device_map="cpu",
        torch_dtype=torch.float32,
    ).eval()
    use_chatglm = True
    print("✅ 4-bit量化版ChatGLM加载成功")
except Exception as e:
    print(f"❌ ChatGLM加载失败: {e}")
45
 
46
# ======================== Knowledge graph ========================
# Module-level store shared by all requests: entity tuples ``(text, type)``
# and relation triples ``(head, tail, relation)``.
knowledge_graph = {"entities": set(), "relations": set()}


def update_knowledge_graph(entities, relations):
    """Add well-formed entities and relations to the global graph.

    Args:
        entities: iterable of dicts with at least ``text`` and ``type`` keys;
            anything else is silently skipped.
        relations: iterable of dicts with ``head``, ``tail`` and ``relation``
            keys; malformed items are skipped.

    A relation is rejected when an equivalent triple — same endpoints in
    either direction with the same label — is already stored.
    """
    for e in entities:
        if isinstance(e, dict) and 'text' in e and 'type' in e:
            knowledge_graph["entities"].add((e['text'], e['type']))

    # Direction-insensitive dedup keys for everything already stored.
    seen = {frozenset({h, t, rel}) for h, t, rel in knowledge_graph["relations"]}
    for r in relations:
        if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
            key = frozenset({r['head'], r['tail'], r['relation']})
            # Bug fix: record keys as we insert, so duplicates arriving in the
            # SAME call (e.g. A->B followed by B->A) are also rejected — the
            # original snapshotted the set only once before the loop.
            if key not in seen:
                seen.add(key)
                knowledge_graph["relations"].add((r['head'], r['tail'], r['relation']))


def visualize_kg_text():
    """Render the global knowledge graph as a plain-text summary.

    NOTE(review): a second ``visualize_kg_text`` defined later in the file
    shadows this one — confirm which definition is intended to win.
    """
    nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
    edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
    return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
65
 
66
 
67
  def visualize_kg_text():
 
71
 
72
 
73
  # ======================== 实体识别(NER) ========================
74
def merge_adjacent_entities(entities):
    """Fuse runs of adjacent, same-type entities into single spans.

    Two consecutive entities are merged when they share a type, the second
    starts exactly where the first ends, and the second's text does not
    already occur inside the first's (guards against duplicated fragments).

    Args:
        entities: list of dicts with ``text``/``type``/``start``/``end`` keys,
            assumed ordered by position.

    Returns:
        A new list; unmerged items are the original dict objects.
    """
    merged = []
    for ent in entities:
        if merged:
            prev = merged[-1]
            same_type = ent["type"] == prev["type"]
            contiguous = ent["start"] == prev["end"]
            fresh_text = ent["text"] not in prev["text"]
            if same_type and contiguous and fresh_text:
                merged[-1] = {
                    "text": prev["text"] + ent["text"],
                    "type": prev["type"],
                    "start": prev["start"],
                    "end": ent["end"],
                }
                continue
        merged.append(ent)
    return merged
95
+
96
+
97
def ner(text, model_type="bert"):
    """Run named-entity recognition on ``text``.

    Args:
        text: input string.
        model_type: ``"chatglm"`` to use the LLM prompt path (only when the
            model loaded successfully); anything else uses the BERT pipeline.

    Returns:
        ``(entities, elapsed_seconds)`` where each entity is a dict with
        ``text``/``type``/``start``/``end`` keys. Failures in the ChatGLM
        path return an empty list rather than raising.
    """
    start_time = time.time()

    if model_type == "chatglm" and use_chatglm:
        # Bug fix: the previous revision replaced this branch with a
        # placeholder comment, leaving an empty `if` body. Restored from the
        # original (unchanged) implementation.
        try:
            prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
文本:{text}"""
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
            if isinstance(response, tuple):  # chat() may return (reply, history)
                response = response[0]

            # Defensive JSON parsing: grab the first array in the reply and
            # keep only well-formed entity dicts.
            try:
                json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
                entities = json.loads(json_str)
                valid_entities = [
                    ent for ent in entities
                    if all(k in ent for k in ("text", "type", "start", "end"))
                ]
                return valid_entities, time.time() - start_time
            except Exception as e:
                print(f"JSON 解析失败: {e}")
                return [], time.time() - start_time
        except Exception as e:
            print(f"ChatGLM 调用失败:{e}")
            return [], time.time() - start_time

    # BERT path: map CLUENER labels onto the coarse tag set, strip spaces the
    # tokenizer inserts, then merge adjacent same-type fragments.
    raw_results = bert_ner_pipeline(text)
    entities = []
    for r in raw_results:
        mapped_type = LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
        entities.append({
            "text": r['word'].replace(' ', ''),
            "start": r['start'],
            "end": r['end'],
            "type": mapped_type,
        })

    entities = merge_adjacent_entities(entities)
    return entities, time.time() - start_time
117
 
118
  # ======================== 关系抽取(RE) ========================
119
  def re_extract(entities, text):
120
+ # 修改7:添加实体类型过滤
121
+ valid_entity_types = {"PER", "LOC", "ORG"}
122
+ filtered_entities = [e for e in entities if e["type"] in valid_entity_types]
123
+
124
+ if len(filtered_entities) < 2:
125
  return []
126
 
127
  relations = []
 
224
 
225
def evaluate_ner_model(data, model_type):
    """Score the NER model against gold annotations.

    Args:
        data: list of ``{"text": str, "entities": [{"text","type","start","end"}]}``.
        model_type: ``"bert"`` or ``"chatglm"``, forwarded to :func:`ner`.

    Returns:
        A formatted Precision/Recall/F1 string, or a warning string when no
        usable annotations are present.

    Bug fix: the previous version collected booleans into a *set* (collapsing
    every example to at most {True, False}), extended ``y_true``/``y_pred``
    to different lengths, and never counted false positives. Metrics are now
    built from a greedy one-to-one match of predictions to gold entities.
    """
    POS_TOLERANCE = 1  # allowed start/end offset between gold and prediction
    y_true, y_pred = [], []

    for item in data:
        text = item["text"]

        gold_entities = []
        for e in item.get("entities", []):
            if "text" in e and "type" in e:
                gold_entities.append({
                    "text": e["text"],
                    # Normalize gold labels to the same scheme the model emits.
                    "type": LABEL_MAPPING.get(e["type"], e["type"]),
                    "start": e.get("start", -1),
                    "end": e.get("end", -1),
                })

        pred_entities, _ = ner(text, model_type)

        # Greedy one-to-one matching: each gold entity may be claimed by at
        # most one prediction.
        matched_gold = set()
        for p in pred_entities:
            hit = None
            for i, g in enumerate(gold_entities):
                if i in matched_gold:
                    continue
                if (p["text"] == g["text"] and
                        p["type"] == g["type"] and
                        abs(p["start"] - g["start"]) <= POS_TOLERANCE and
                        abs(p["end"] - g["end"]) <= POS_TOLERANCE):
                    hit = i
                    break
            if hit is not None:
                # True positive.
                matched_gold.add(hit)
                y_true.append(1)
                y_pred.append(1)
            else:
                # False positive: predicted entity with no gold counterpart.
                y_true.append(0)
                y_pred.append(1)

        # Unmatched gold entities are false negatives.
        for i in range(len(gold_entities)):
            if i not in matched_gold:
                y_true.append(1)
                y_pred.append(0)

    if not y_true:
        return "⚠️ 无有效标注数据"

    return (f"Precision: {precision_score(y_true, y_pred, zero_division=0):.2f}\n"
            f"Recall: {recall_score(y_true, y_pred, zero_division=0):.2f}\n"
            f"F1: {f1_score(y_true, y_pred, zero_division=0):.2f}")
276
 
277
  def auto_annotate(file, model_type):
278
  data = convert_telegram_json_to_eval_format(file.name)
 
290
 
291
 
292
  # ======================== Gradio 界面 ========================
293
+ with gr.Blocks(css="""
294
+ .kg-graph {height: 500px; overflow-y: auto;}
295
+ .warning {color: #ff6b6b;}
296
+ """) as demo:
297
  gr.Markdown("# 🤖 聊天记录实体关系识别系统")
298
 
299
  with gr.Tab("📄 文本分析"):