# wechat-ner-re / app.py
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
import gradio as gr
import re
import os
import json
import chardet
from sklearn.metrics import precision_score, recall_score, f1_score
import time
# ======================== Model loading ========================
NER_MODEL_NAME = "uer/roberta-base-finetuned-cluener2020-chinese"
bert_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
bert_ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
bert_ner_pipeline = pipeline(
    "ner",
    model=bert_ner_model,
    tokenizer=bert_tokenizer,
    aggregation_strategy="first"
)
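# aggregation_strategy="first" merges WordPiece sub-tokens into whole entity
# spans, labeling each span by its first sub-token. Illustrative raw output
# (values are hypothetical): [{'entity_group': 'company', 'word': '阿 里 巴 巴',
# 'score': 0.99, 'start': 5, 'end': 9}]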
LABEL_MAPPING = {
    "address": "LOC",
    "company": "ORG",
    "name": "PER",
    "organization": "ORG",
    "position": "TITLE"
}
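# CLUENER labels not listed above (e.g. "game", "movie") pass through
# unchanged via LABEL_MAPPING.get(label, label) in ner() below.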
chatglm_model, chatglm_tokenizer = None, None
use_chatglm = False
try:
    chatglm_model_name = "THUDM/chatglm-6b-int4"
    chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
    chatglm_model = AutoModel.from_pretrained(
        chatglm_model_name,
        trust_remote_code=True,
        device_map="cpu",
        torch_dtype=torch.float32
    ).eval()
    use_chatglm = True
    print("✅ Loaded 4-bit quantized ChatGLM")
except Exception as e:
    print(f"❌ Failed to load ChatGLM: {e}")
# ======================== Knowledge graph structure ========================
knowledge_graph = {"entities": set(), "relations": set()}
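# Entities are stored as (text, type) tuples and relations as
# (head, tail, relation) triples; sets make repeated insertions idempotent.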
def update_knowledge_graph(entities, relations):
    for e in entities:
        if isinstance(e, dict) and 'text' in e and 'type' in e:
            knowledge_graph["entities"].add((e['text'], e['type']))
    # Deduplicate relations, ignoring head/tail direction (frozenset is unordered)
    existing_relations = {frozenset({r[0], r[1], r[2]}) for r in knowledge_graph["relations"]}
    for r in relations:
        if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
            new_rel = frozenset({r['head'], r['tail'], r['relation']})
            if new_rel not in existing_relations:
                existing_relations.add(new_rel)  # also dedupe within this batch
                knowledge_graph["relations"].add((r['head'], r['tail'], r['relation']))
def visualize_kg_text():
    nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
    edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
    return "\n".join(["📌 Entities:"] + nodes + ["", "📎 Relations:"] + edges)
# ======================== Named entity recognition (NER) ========================
def merge_adjacent_entities(entities):
    merged = []
    for entity in entities:
        if not merged:
            merged.append(entity)
            continue
        last = merged[-1]
        # Merge adjacent entities of the same type
        if (entity["type"] == last["type"] and
                entity["start"] == last["end"] and
                entity["text"] not in last["text"]):
            merged[-1] = {
                "text": last["text"] + entity["text"],
                "type": last["type"],
                "start": last["start"],
                "end": entity["end"]
            }
        else:
            merged.append(entity)
    return merged
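# Worked example (hypothetical spans): if the pipeline returns "中国" [0-2] and
# "银行" [2-4], both ORG and exactly adjacent, they merge into "中国银行" [0-4].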
def ner(text, model_type="bert"):
    start_time = time.time()
    if model_type == "chatglm" and use_chatglm:
        try:
            # Chinese prompt: extract all entities and return a strict JSON
            # list with text / type / start / end fields.
            prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
文本:{text}"""
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
            if isinstance(response, tuple):
                response = response[0]
            # Robust JSON extraction: take the first [...] block in the reply
            try:
                json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
                entities = json.loads(json_str)
                # Keep only entities with all required fields
                valid_entities = []
                for ent in entities:
                    if all(k in ent for k in ("text", "type", "start", "end")):
                        valid_entities.append(ent)
                return valid_entities, time.time() - start_time
            except Exception as e:
                print(f"JSON parsing failed: {e}")
                return [], time.time() - start_time
        except Exception as e:
            print(f"ChatGLM call failed: {e}")
            return [], time.time() - start_time
    # Fall back to the fine-tuned Chinese BERT NER model
    raw_results = bert_ner_pipeline(text)
    entities = []
    for r in raw_results:
        mapped_type = LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
        entities.append({
            "text": r['word'].replace(' ', ''),  # strip the spaces the pipeline inserts between Chinese tokens
            "start": r['start'],
            "end": r['end'],
            "type": mapped_type
        })
    # Merge adjacent fragments of the same entity
    entities = merge_adjacent_entities(entities)
    return entities, time.time() - start_time
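# Sketch of a call (spans and timing are illustrative):
# ner("马云创立了阿里巴巴") ->
#     ([{"text": "马云", "start": 0, "end": 2, "type": "PER"},
#       {"text": "阿里巴巴", "start": 5, "end": 9, "type": "ORG"}], 0.12)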
# ======================== Relation extraction (RE) ========================
def re_extract(entities, text):
    # Only consider person, location, and organization entities for relations
    valid_entity_types = {"PER", "LOC", "ORG"}
    filtered_entities = [e for e in entities if e["type"] in valid_entity_types]
    if len(filtered_entities) < 2:
        return []
    relations = []
    try:
        # Chinese prompt: return only entity pairs with a clear relation, as a
        # JSON list, using one of four relation types.
        prompt = f"""分析文本中的实体关系,返回JSON列表:
文本:{text}
实体列表:{[e['text'] for e in filtered_entities]}
要求:
1. 仅返回存在明确关系的实体对
2. 关系类型使用:属于、位于、参与、其他
3. 格式示例:[{{"head": "北京", "tail": "中国", "relation": "位于"}}]"""
        if use_chatglm:
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
            if isinstance(response, tuple):
                response = response[0]
            # Extract the JSON list from the reply
            try:
                json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
                relations = json.loads(json_str)
                # Keep only well-formed relations with an allowed type
                # (属于 = belongs to, 位于 = located in, 参与 = participates in, 其他 = other)
                valid_relations = []
                valid_types = {"属于", "位于", "参与", "其他"}
                for rel in relations:
                    if all(k in rel for k in ("head", "tail", "relation")) and rel["relation"] in valid_types:
                        valid_relations.append(rel)
                return valid_relations
            except Exception as e:
                print(f"Relation parsing failed: {e}")
    except Exception as e:
        print(f"Relation extraction failed: {e}")
    # Default: no relations
    return []
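# Example return value (hypothetical): [{"head": "马云", "tail": "阿里巴巴",
# "relation": "属于"}]; always an empty list when ChatGLM is unavailable.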
# ======================== Main text-analysis pipeline ========================
def process_text(text, model_type="bert"):
    entities, duration = ner(text, model_type)
    relations = re_extract(entities, text)
    update_knowledge_graph(entities, relations)
    ent_text = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
    rel_text = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
    kg_text = visualize_kg_text()
    return ent_text, rel_text, kg_text, f"{duration:.2f} s"
def process_file(file, model_type="bert"):
    try:
        with open(file.name, 'rb') as f:
            content = f.read()
        if len(content) > 5 * 1024 * 1024:
            return "❌ File too large (limit: 5 MB)", "", "", ""
        # Detect encoding
        try:
            encoding = chardet.detect(content)['encoding'] or 'utf-8'
            text = content.decode(encoding)
        except UnicodeDecodeError:
            # Fall back to common Chinese encodings
            for enc in ['gb18030', 'utf-16', 'big5']:
                try:
                    text = content.decode(enc)
                    break
                except UnicodeDecodeError:
                    continue
            else:
                return "❌ Could not decode file", "", "", ""
        return process_text(text, model_type)
    except Exception as e:
        return f"❌ File processing error: {str(e)}", "", "", ""
# ======================== Model evaluation and auto-annotation ========================
def convert_telegram_json_to_eval_format(path):
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    if isinstance(data, dict) and "text" in data:
        # Note: these gold entities carry no type, so evaluation will skip them
        return [{"text": data["text"], "entities": [
            {"text": data["text"][e["start"]:e["end"]]} for e in data.get("entities", [])
        ]}]
    elif isinstance(data, list):
        return data
    elif isinstance(data, dict) and "messages" in data:
        result = []
        for m in data.get("messages", []):
            if isinstance(m.get("text"), str):
                result.append({"text": m["text"], "entities": []})
            elif isinstance(m.get("text"), list):
                # Telegram exports can split a message into text runs (str or dict)
                txt = ''.join(x["text"] if isinstance(x, dict) else x for x in m["text"])
                result.append({"text": txt, "entities": []})
        return result
    return []
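# Accepted shapes: a single {"text": ..., "entities": [...]} object, an already
# converted list, or a Telegram export with a top-level "messages" array; any
# other shape yields an empty dataset.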
def evaluate_ner_model(data, model_type):
    y_true, y_pred = [], []
    POS_TOLERANCE = 1  # allowed start/end offset when matching entities
    for item in data:
        text = item["text"]
        gold_entities = []
        for e in item.get("entities", []):
            if "text" in e and "type" in e:
                # Normalize labels to the coarse type set
                norm_type = LABEL_MAPPING.get(e["type"], e["type"])
                gold_entities.append({
                    "text": e["text"],
                    "type": norm_type,
                    "start": e.get("start", -1),
                    "end": e.get("end", -1)
                })
        pred_entities, _ = ner(text, model_type)
        # Greedily match each prediction to a not-yet-matched gold entity
        matched_gold = set()
        false_positives = 0
        for p in pred_entities:
            hit = None
            for gi, g in enumerate(gold_entities):
                if gi in matched_gold:
                    continue
                if (p["text"] == g["text"] and
                        p["type"] == g["type"] and
                        (g["start"] < 0 or  # gold without offsets: match on text/type only
                         (abs(p["start"] - g["start"]) <= POS_TOLERANCE and
                          abs(p["end"] - g["end"]) <= POS_TOLERANCE))):
                    hit = gi
                    break
            if hit is not None:
                matched_gold.add(hit)
            else:
                false_positives += 1
        # One label pair per gold entity, plus one per false positive
        for gi in range(len(gold_entities)):
            y_true.append(1)
            y_pred.append(1 if gi in matched_gold else 0)
        y_true.extend([0] * false_positives)
        y_pred.extend([1] * false_positives)
    if not y_true:
        return "⚠️ No valid annotated data"
    return (f"Precision: {precision_score(y_true, y_pred, zero_division=0):.2f}\n"
            f"Recall: {recall_score(y_true, y_pred, zero_division=0):.2f}\n"
            f"F1: {f1_score(y_true, y_pred, zero_division=0):.2f}")
def auto_annotate(file, model_type):
    data = convert_telegram_json_to_eval_format(file.name)
    for item in data:
        ents, _ = ner(item["text"], model_type)
        item["entities"] = ents
    return json.dumps(data, ensure_ascii=False, indent=2)
def save_json(json_text):
    fname = f"auto_labeled_{int(time.time())}.json"
    with open(fname, "w", encoding="utf-8") as f:
        f.write(json_text)
    return fname
# ======================== Gradio UI ========================
with gr.Blocks(css="""
.kg-graph {height: 500px; overflow-y: auto;}
.warning {color: #ff6b6b;}
""") as demo:
    gr.Markdown("# 🤖 Chat-log entity and relation recognition")
    with gr.Tab("📄 Text analysis"):
        input_text = gr.Textbox(lines=6, label="Input text")
        model_type = gr.Radio(["bert", "chatglm"], value="bert", label="Model")
        btn = gr.Button("Analyze")
        out1 = gr.Textbox(label="Entities")
        out2 = gr.Textbox(label="Relations")
        out3 = gr.Textbox(label="Knowledge graph")
        out4 = gr.Textbox(label="Elapsed time")
        btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
    with gr.Tab("🗂 File analysis"):
        file_input = gr.File(file_types=[".txt", ".json"])
        # Give this tab its own model selector instead of reusing the one
        # from the text-analysis tab
        file_model = gr.Radio(["bert", "chatglm"], value="bert", label="Model")
        file_btn = gr.Button("Upload and analyze")
        fout1 = gr.Textbox(label="Entities")
        fout2 = gr.Textbox(label="Relations")
        fout3 = gr.Textbox(label="Knowledge graph")
        fout4 = gr.Textbox(label="Elapsed time")
        file_btn.click(fn=process_file, inputs=[file_input, file_model], outputs=[fout1, fout2, fout3, fout4])
    with gr.Tab("📊 Model evaluation"):
        eval_file = gr.File(label="Upload annotated JSON")
        eval_model = gr.Radio(["bert", "chatglm"], value="bert", label="Model")
        eval_btn = gr.Button("Evaluate")
        eval_output = gr.Textbox(label="Results", lines=5)
        eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m),
                       [eval_file, eval_model], eval_output)
    with gr.Tab("✏️ Auto-annotation"):
        raw_file = gr.File(label="Upload raw Telegram JSON")
        auto_model = gr.Radio(["bert", "chatglm"], value="bert", label="Model")
        auto_btn = gr.Button("Auto-annotate")
        marked_texts = gr.Textbox(label="Annotation result", lines=20)
        download_btn = gr.Button("💾 Download annotation file")
        download_file = gr.File(label="Annotation file")
        auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
        download_btn.click(fn=save_json, inputs=marked_texts, outputs=download_file)
demo.launch(server_name="0.0.0.0", server_port=7860)