Spaces:
Sleeping
Sleeping
Commit
·
6129c00
1
Parent(s):
0378c00
add app.py and requirements.txt
Browse files
app.py
CHANGED
@@ -9,46 +9,59 @@ from sklearn.metrics import precision_score, recall_score, f1_score
|
|
9 |
import time
|
10 |
|
11 |
# ======================== 模型加载 ========================
|
12 |
-
|
13 |
-
bert_tokenizer = AutoTokenizer.from_pretrained(
|
14 |
-
bert_ner_model = AutoModelForTokenClassification.from_pretrained(
|
15 |
-
bert_ner_pipeline = pipeline(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
chatglm_model, chatglm_tokenizer = None, None
|
18 |
use_chatglm = False
|
19 |
try:
|
20 |
-
chatglm_model_name = "THUDM/chatglm-6b-int4"
|
21 |
-
chatglm_tokenizer = AutoTokenizer.from_pretrained(
|
22 |
-
chatglm_model_name,
|
23 |
-
trust_remote_code=True
|
24 |
-
)
|
25 |
chatglm_model = AutoModel.from_pretrained(
|
26 |
chatglm_model_name,
|
27 |
trust_remote_code=True,
|
28 |
device_map="cpu",
|
29 |
-
torch_dtype=torch.float32
|
30 |
).eval()
|
31 |
use_chatglm = True
|
32 |
-
print("✅ 4-bit量化版ChatGLM
|
33 |
except Exception as e:
|
34 |
-
print(f"❌
|
35 |
|
36 |
# ======================== 知识图谱结构 ========================
|
37 |
knowledge_graph = {"entities": set(), "relations": set()}
|
38 |
|
39 |
-
|
40 |
def update_knowledge_graph(entities, relations):
|
41 |
for e in entities:
|
42 |
if isinstance(e, dict) and 'text' in e and 'type' in e:
|
43 |
knowledge_graph["entities"].add((e['text'], e['type']))
|
44 |
-
|
|
|
45 |
for r in relations:
|
46 |
if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
52 |
|
53 |
|
54 |
def visualize_kg_text():
|
@@ -58,50 +71,57 @@ def visualize_kg_text():
|
|
58 |
|
59 |
|
60 |
# ======================== 实体识别(NER) ========================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
def ner(text, model_type="bert"):
|
62 |
start_time = time.time()
|
63 |
if model_type == "chatglm" and use_chatglm:
|
64 |
-
|
65 |
-
prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
|
66 |
-
示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
|
67 |
-
文本:{text}"""
|
68 |
-
response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
|
69 |
-
if isinstance(response, tuple):
|
70 |
-
response = response[0]
|
71 |
-
|
72 |
-
# 增强 JSON 解析
|
73 |
-
try:
|
74 |
-
json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
|
75 |
-
entities = json.loads(json_str)
|
76 |
-
# 验证字段
|
77 |
-
valid_entities = []
|
78 |
-
for ent in entities:
|
79 |
-
if all(k in ent for k in ("text", "type", "start", "end")):
|
80 |
-
valid_entities.append(ent)
|
81 |
-
return valid_entities, time.time() - start_time
|
82 |
-
except Exception as e:
|
83 |
-
print(f"JSON 解析失败: {e}")
|
84 |
-
return [], time.time() - start_time
|
85 |
-
except Exception as e:
|
86 |
-
print(f"ChatGLM 调用失败:{e}")
|
87 |
-
return [], time.time() - start_time
|
88 |
|
89 |
-
#
|
90 |
raw_results = bert_ner_pipeline(text)
|
91 |
entities = []
|
92 |
for r in raw_results:
|
|
|
93 |
entities.append({
|
94 |
-
"text": r[
|
95 |
-
"start": r[
|
96 |
-
"end": r[
|
97 |
-
"type":
|
98 |
})
|
99 |
-
return entities, time.time() - start_time
|
100 |
|
|
|
|
|
|
|
101 |
|
102 |
# ======================== 关系抽取(RE) ========================
|
103 |
def re_extract(entities, text):
|
104 |
-
|
|
|
|
|
|
|
|
|
105 |
return []
|
106 |
|
107 |
relations = []
|
@@ -204,30 +224,55 @@ def convert_telegram_json_to_eval_format(path):
|
|
204 |
|
205 |
def evaluate_ner_model(data, model_type):
|
206 |
y_true, y_pred = [], []
|
|
|
|
|
207 |
for item in data:
|
208 |
text = item["text"]
|
209 |
gold_entities = []
|
210 |
for e in item.get("entities", []):
|
211 |
if "text" in e and "type" in e:
|
212 |
-
#
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
|
220 |
-
#
|
221 |
-
|
222 |
-
for
|
223 |
-
y_true.append(1 if ent in gold_entities else 0)
|
224 |
-
y_pred.append(1 if ent in pred_entities else 0)
|
225 |
|
226 |
if not y_true:
|
227 |
return "⚠️ 无有效标注数据"
|
228 |
|
229 |
-
return f"Precision: {precision_score(y_true, y_pred
|
230 |
-
|
|
|
231 |
|
232 |
def auto_annotate(file, model_type):
|
233 |
data = convert_telegram_json_to_eval_format(file.name)
|
@@ -245,7 +290,10 @@ def save_json(json_text):
|
|
245 |
|
246 |
|
247 |
# ======================== Gradio 界面 ========================
|
248 |
-
with gr.Blocks(css="
|
|
|
|
|
|
|
249 |
gr.Markdown("# 🤖 聊天记录实体关系识别系统")
|
250 |
|
251 |
with gr.Tab("📄 文本分析"):
|
|
|
9 |
import time
|
10 |
|
11 |
# ======================== 模型加载 ========================
|
12 |
+
NER_MODEL_NAME = "uer/roberta-base-finetuned-cluener2020-chinese"
|
13 |
+
bert_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
|
14 |
+
bert_ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
|
15 |
+
bert_ner_pipeline = pipeline(
|
16 |
+
"ner",
|
17 |
+
model=bert_ner_model,
|
18 |
+
tokenizer=bert_tokenizer,
|
19 |
+
aggregation_strategy="first"
|
20 |
+
)
|
21 |
+
|
22 |
+
LABEL_MAPPING = {
|
23 |
+
"address": "LOC",
|
24 |
+
"company": "ORG",
|
25 |
+
"name": "PER",
|
26 |
+
"organization": "ORG",
|
27 |
+
"position": "TITLE"
|
28 |
+
}
|
29 |
|
30 |
chatglm_model, chatglm_tokenizer = None, None
|
31 |
use_chatglm = False
|
32 |
try:
|
33 |
+
chatglm_model_name = "THUDM/chatglm-6b-int4"
|
34 |
+
chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
|
|
|
|
|
|
|
35 |
chatglm_model = AutoModel.from_pretrained(
|
36 |
chatglm_model_name,
|
37 |
trust_remote_code=True,
|
38 |
device_map="cpu",
|
39 |
+
torch_dtype=torch.float32
|
40 |
).eval()
|
41 |
use_chatglm = True
|
42 |
+
print("✅ 4-bit量化版ChatGLM加载成功")
|
43 |
except Exception as e:
|
44 |
+
print(f"❌ ChatGLM加载失败: {e}")
|
45 |
|
46 |
# ======================== 知识图谱结构 ========================
|
47 |
knowledge_graph = {"entities": set(), "relations": set()}
|
48 |
|
|
|
49 |
def update_knowledge_graph(entities, relations):
|
50 |
for e in entities:
|
51 |
if isinstance(e, dict) and 'text' in e and 'type' in e:
|
52 |
knowledge_graph["entities"].add((e['text'], e['type']))
|
53 |
+
# 修改4:添加关系去重逻辑
|
54 |
+
existing_relations = {frozenset({r[0], r[1], r[2]}) for r in knowledge_graph["relations"]}
|
55 |
for r in relations:
|
56 |
if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
|
57 |
+
new_rel = frozenset({r['head'], r['tail'], r['relation']})
|
58 |
+
if new_rel not in existing_relations:
|
59 |
+
knowledge_graph["relations"].add((r['head'], r['tail'], r['relation']))
|
60 |
+
|
61 |
+
def visualize_kg_text():
|
62 |
+
nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
|
63 |
+
edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
|
64 |
+
return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
|
65 |
|
66 |
|
67 |
def visualize_kg_text():
|
|
|
71 |
|
72 |
|
73 |
# ======================== 实体识别(NER) ========================
|
74 |
+
def merge_adjacent_entities(entities):
|
75 |
+
merged = []
|
76 |
+
for entity in entities:
|
77 |
+
if not merged:
|
78 |
+
merged.append(entity)
|
79 |
+
continue
|
80 |
+
|
81 |
+
last = merged[-1]
|
82 |
+
# 合并相邻的同类型实体
|
83 |
+
if (entity["type"] == last["type"] and
|
84 |
+
entity["start"] == last["end"] and
|
85 |
+
entity["text"] not in last["text"]):
|
86 |
+
merged[-1] = {
|
87 |
+
"text": last["text"] + entity["text"],
|
88 |
+
"type": last["type"],
|
89 |
+
"start": last["start"],
|
90 |
+
"end": entity["end"]
|
91 |
+
}
|
92 |
+
else:
|
93 |
+
merged.append(entity)
|
94 |
+
return merged
|
95 |
+
|
96 |
+
|
97 |
def ner(text, model_type="bert"):
|
98 |
start_time = time.time()
|
99 |
if model_type == "chatglm" and use_chatglm:
|
100 |
+
# ... [原有ChatGLM代码保持不变] ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
+
# 修改6:优化BERT模型处理流程
|
103 |
raw_results = bert_ner_pipeline(text)
|
104 |
entities = []
|
105 |
for r in raw_results:
|
106 |
+
mapped_type = LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
|
107 |
entities.append({
|
108 |
+
"text": r['word'].replace(' ', ''),
|
109 |
+
"start": r['start'],
|
110 |
+
"end": r['end'],
|
111 |
+
"type": mapped_type
|
112 |
})
|
|
|
113 |
|
114 |
+
# 执行合并处理
|
115 |
+
entities = merge_adjacent_entities(entities)
|
116 |
+
return entities, time.time() - start_time
|
117 |
|
118 |
# ======================== 关系抽取(RE) ========================
|
119 |
def re_extract(entities, text):
|
120 |
+
# 修改7:添加实体类型过滤
|
121 |
+
valid_entity_types = {"PER", "LOC", "ORG"}
|
122 |
+
filtered_entities = [e for e in entities if e["type"] in valid_entity_types]
|
123 |
+
|
124 |
+
if len(filtered_entities) < 2:
|
125 |
return []
|
126 |
|
127 |
relations = []
|
|
|
224 |
|
225 |
def evaluate_ner_model(data, model_type):
|
226 |
y_true, y_pred = [], []
|
227 |
+
POS_TOLERANCE = 1 # 允许的位置误差
|
228 |
+
|
229 |
for item in data:
|
230 |
text = item["text"]
|
231 |
gold_entities = []
|
232 |
for e in item.get("entities", []):
|
233 |
if "text" in e and "type" in e:
|
234 |
+
# 标准化标签
|
235 |
+
norm_type = LABEL_MAPPING.get(e["type"], e["type"])
|
236 |
+
gold_entities.append({
|
237 |
+
"text": e["text"],
|
238 |
+
"type": norm_type,
|
239 |
+
"start": e.get("start", -1),
|
240 |
+
"end": e.get("end", -1)
|
241 |
+
})
|
242 |
+
|
243 |
+
pred_entities, _ = ner(text, model_type)
|
244 |
+
|
245 |
+
# 构建对比集合
|
246 |
+
all_entities = set()
|
247 |
+
# 处理标注数据
|
248 |
+
for g in gold_entities:
|
249 |
+
key = f"{g['text']}|{g['type']}|{g['start']}|{g['end']}"
|
250 |
+
all_entities.add(key)
|
251 |
+
|
252 |
+
# 处理预测结果
|
253 |
+
pred_set = set()
|
254 |
+
for p in pred_entities:
|
255 |
+
# 允许位置误差
|
256 |
+
matched = False
|
257 |
+
for g in gold_entities:
|
258 |
+
if (p["text"] == g["text"] and
|
259 |
+
p["type"] == g["type"] and
|
260 |
+
abs(p["start"] - g["start"]) <= POS_TOLERANCE and
|
261 |
+
abs(p["end"] - g["end"]) <= POS_TOLERANCE):
|
262 |
+
matched = True
|
263 |
+
break
|
264 |
+
pred_set.add(matched)
|
265 |
|
266 |
+
# 构建指标
|
267 |
+
y_true.extend([1] * len(gold_entities))
|
268 |
+
y_pred.extend([1 if m else 0 for m in pred_set])
|
|
|
|
|
269 |
|
270 |
if not y_true:
|
271 |
return "⚠️ 无有效标注数据"
|
272 |
|
273 |
+
return (f"Precision: {precision_score(y_true, y_pred, zero_division=0):.2f}\n"
|
274 |
+
f"Recall: {recall_score(y_true, y_pred, zero_division=0):.2f}\n"
|
275 |
+
f"F1: {f1_score(y_true, y_pred, zero_division=0):.2f}")
|
276 |
|
277 |
def auto_annotate(file, model_type):
|
278 |
data = convert_telegram_json_to_eval_format(file.name)
|
|
|
290 |
|
291 |
|
292 |
# ======================== Gradio 界面 ========================
|
293 |
+
with gr.Blocks(css="""
|
294 |
+
.kg-graph {height: 500px; overflow-y: auto;}
|
295 |
+
.warning {color: #ff6b6b;}
|
296 |
+
""") as demo:
|
297 |
gr.Markdown("# 🤖 聊天记录实体关系识别系统")
|
298 |
|
299 |
with gr.Tab("📄 文本分析"):
|