# wechat-ner-re / app.py
# Author: chen666-666 · commit f305260 ("add app.py and requirements.txt") · 8.5 kB
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
import gradio as gr
import re
import os
import json
import chardet
from sklearn.metrics import precision_score, recall_score, f1_score
import time
# ======================== 模型加载 ========================
# BERT token-classification pipeline: always loaded, used as the default backend.
bert_model_name = "bert-base-chinese"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_ner_model = AutoModelForTokenClassification.from_pretrained(
    "ckiplab/bert-base-chinese-ner"
)
bert_ner_pipeline = pipeline(
    "ner",
    model=bert_ner_model,
    tokenizer=bert_tokenizer,
    aggregation_strategy="simple",  # group sub-token predictions into whole entities
)

# Optional ChatGLM3 backend, attempted only when CUDA is available.
# `use_chatglm` is the flag the rest of the module checks before it touches
# `chatglm_model` / `chatglm_tokenizer`.
chatglm_model = chatglm_tokenizer = None
use_chatglm = False
try:
    if not torch.cuda.is_available():
        print("⚠️ 当前为 CPU 环境,ChatGLM3 不可用,将仅使用 BERT。")
    else:
        chatglm_model_name = "THUDM/chatglm3-6b"
        # trust_remote_code: ChatGLM ships its own modelling code on the hub.
        chatglm_tokenizer = AutoTokenizer.from_pretrained(
            chatglm_model_name, trust_remote_code=True
        )
        chatglm_model = AutoModel.from_pretrained(
            chatglm_model_name,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.float16,
        ).eval()
        use_chatglm = True
except Exception as e:
    # Best effort: any load failure leaves the app running in BERT-only mode.
    print(f"❌ ChatGLM 加载失败: {e}")
# ======================== 知识图谱结构 ========================
# Process-wide knowledge graph accumulated across every analysis request.
# "entities": set of (text, type) tuples.
# "relations": list of (head, tail, relation) tuples in insertion order, no duplicates.
knowledge_graph = {"entities": set(), "relations": []}


def update_knowledge_graph(entities, relations):
    """Merge extracted entities and relations into the global knowledge graph.

    Args:
        entities: iterable of dicts; only those with "text" and "type" are used.
        relations: iterable of dicts; only those with "head", "tail" and
            "relation" are used.

    Values are stored as strings. A relation triple already in the graph is not
    appended again, so re-analysing the same text does not grow the graph.
    """
    for ent in entities:
        if isinstance(ent, dict) and "text" in ent and "type" in ent:
            # str(): LLM-produced JSON may hold unhashable values (lists, dicts).
            knowledge_graph["entities"].add((str(ent["text"]), str(ent["type"])))

    known = set(knowledge_graph["relations"])
    for rel in relations:
        if isinstance(rel, dict) and all(k in rel for k in ("head", "tail", "relation")):
            triple = (str(rel["head"]), str(rel["tail"]), str(rel["relation"]))
            if triple not in known:
                known.add(triple)
                knowledge_graph["relations"].append(triple)


def visualize_kg_text():
    """Render the knowledge graph as plain text.

    Entities are listed sorted (stable across runs, unlike set iteration
    order), followed by a blank line and the relations in insertion order.
    """
    nodes = [f"{text} ({etype})" for text, etype in sorted(knowledge_graph["entities"])]
    edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
    return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
# ======================== 实体识别(NER) ========================
# Max characters sent to the BERT pipeline per call. Each WordPiece token of
# bert-base-chinese covers at least one input character, so <= 500 characters
# stays inside the model's 512-position window ([CLS]/[SEP] included).
_BERT_MAX_CHARS = 500
# Preferred cut points when a long text has to be split.
_SENTENCE_BREAKS = "\n。!?!?;;"


def _split_for_bert(text, max_chars=_BERT_MAX_CHARS):
    """Yield (offset, chunk) pairs covering *text* in order, each chunk at most
    *max_chars* long and cut just after a newline/sentence mark when possible."""
    start = 0
    while start < len(text):
        end = min(start + max_chars, len(text))
        if end < len(text):
            cut = max(text.rfind(mark, start, end) for mark in _SENTENCE_BREAKS)
            if cut > start:
                end = cut + 1
        yield start, text[start:end]
        start = end


def _parse_llm_entities(response, text):
    """Convert a ChatGLM reply into entity dicts shaped like the BERT output.

    Accepts a JSON list (of dicts or plain strings), {"entities": [...]}, or a
    type-keyed mapping such as {"人物": ["张三"]}. Surrounding prose or ```
    fences are ignored. Offsets are located with str.find; entities that are
    not found verbatim in *text* get start = end = -1.

    Raises:
        ValueError (json.JSONDecodeError) if no JSON can be parsed.
    """
    if isinstance(response, tuple):  # .chat() returns (reply, history)
        response = response[0]
    reply = str(response)
    # Keep the outermost bracketed span so prose / code fences don't break json.loads.
    match = re.search(r"[\[{].*[\]}]", reply, re.S)
    parsed = json.loads(match.group(0) if match else reply)

    if isinstance(parsed, dict):
        if isinstance(parsed.get("entities"), list):
            parsed = parsed["entities"]
        else:
            parsed = [
                {"text": value, "type": key}
                for key, values in parsed.items()
                for value in (values if isinstance(values, list) else [values])
            ]

    entities = []
    for item in parsed if isinstance(parsed, list) else []:
        if isinstance(item, str):
            item = {"text": item}
        if not isinstance(item, dict) or not item.get("text"):
            continue
        ent_text = str(item["text"])
        start = text.find(ent_text)
        entities.append({
            "text": ent_text,
            "start": start,
            "end": start + len(ent_text) if start >= 0 else -1,
            "type": str(item.get("type", "未知")),
        })
    return entities


def ner(text, model_type="bert"):
    """Recognise named entities in *text*.

    Args:
        text: input string; empty or whitespace-only input yields no entities.
        model_type: "chatglm" uses ChatGLM3 when it was loaded (``use_chatglm``);
            any other value, or ChatGLM being unavailable, uses the BERT pipeline.

    Returns:
        (entities, seconds): entities is a list of dicts with keys "text",
        "start", "end", "type"; seconds is the wall-clock time spent. If the
        ChatGLM call or its JSON parsing fails, the error is printed and an
        empty list is returned.
    """
    start_time = time.time()
    if not text or not text.strip():
        return [], time.time() - start_time

    if model_type == "chatglm" and use_chatglm:
        prompt = (
            "请从以下文本中识别所有实体,只返回 JSON 数组,每个元素形如 "
            '{"text": "实体原文", "type": "实体类型"},不要输出其他内容。\n'
            f"文本:{text}"
        )
        try:
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
            entities = _parse_llm_entities(response, text)
        except Exception as e:
            print(f"❌ ChatGLM 实体识别失败:{e}")
            entities = []
        return entities, time.time() - start_time

    # Fine-tuned BERT Chinese NER, run chunk by chunk so long chat logs are not truncated.
    entities = []
    for offset, chunk in _split_for_bert(text):
        for r in bert_ner_pipeline(chunk):
            begin, end = offset + r["start"], offset + r["end"]
            entities.append({
                # Slice the source text: the pipeline's "word" is rebuilt from
                # WordPiece tokens and inserts spaces between Chinese characters.
                "text": text[begin:end],
                "start": begin,
                "end": end,
                "type": r["entity_group"],
            })
    return entities, time.time() - start_time
# ======================== 关系抽取(RE) ========================
def re_extract(entities, text):
    """Extract relations between the recognised entities of *text*.

    When ChatGLM is loaded (``use_chatglm``) it is asked for a JSON list of
    {"head", "tail", "relation"} dicts. Surrounding prose or ``` fences are
    ignored and malformed items are dropped. If ChatGLM is unavailable, or its
    call or parsing fails, every unordered pair of distinct entity texts is
    linked with the generic relation "相关".

    Args:
        entities: list of entity dicts; only those with a non-empty "text" count.
        text: the source text, passed to ChatGLM as context.

    Returns:
        list of dicts with keys "head", "tail", "relation" (string values).
        The list is empty when fewer than two distinct entities are given.
    """
    # Distinct entity strings in first-seen order: repeated mentions of one
    # entity must not yield "X --[相关]-> X" self-relations.
    names = list(dict.fromkeys(
        str(e["text"]) for e in entities if isinstance(e, dict) and e.get("text")
    ))
    if len(names) < 2:
        return []
    try:
        if use_chatglm:
            schema = '{"head": "实体1", "tail": "实体2", "relation": "关系"}'
            prompt = (
                f"分析以下实体之间的关系,只返回 JSON 数组,每个元素形如 {schema},不要输出其他内容。\n"
                f"实体:{names}\n文本上下文:{text}"
            )
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
            if isinstance(response, tuple):  # .chat() returns (reply, history)
                response = response[0]
            match = re.search(r"\[.*\]", response, re.S)
            parsed = json.loads(match.group(0) if match else response)
            if not isinstance(parsed, list):
                raise ValueError(f"expected a JSON list, got {type(parsed).__name__}")
            return [
                {"head": str(r["head"]), "tail": str(r["tail"]), "relation": str(r["relation"])}
                for r in parsed
                if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation"))
            ]
    except Exception as e:
        print(f"❌ ChatGLM 关系抽取失败:{e}")
    return [
        {"head": head, "tail": tail, "relation": "相关"}
        for i, head in enumerate(names)
        for tail in names[i + 1:]
    ]
# ======================== 文本分析主流程 ========================
def process_text(text, model_type="bert"):
    """Run NER and relation extraction on *text*, then update the knowledge graph.

    Args:
        text: input string.
        model_type: "bert" or "chatglm", forwarded to ner().

    Returns:
        Four strings for the Gradio outputs: one line per entity, one line per
        relation, the whole knowledge graph as text, and the elapsed time
        ("X.XX 秒") covering both NER and relation extraction.
    """
    started = time.time()
    entities, _ = ner(text, model_type)
    relations = re_extract(entities, text)
    update_knowledge_graph(entities, relations)

    # .get(): LLM-produced entities/relations are not guaranteed to carry every key.
    ent_text = "\n".join(
        f"{e.get('text', '')} ({e.get('type', '')}) [{e.get('start', '')}-{e.get('end', '')}]"
        for e in entities
        if isinstance(e, dict)
    )
    rel_text = "\n".join(
        f"{r.get('head', '')} --[{r.get('relation', '')}]-> {r.get('tail', '')}"
        for r in relations
        if isinstance(r, dict)
    )
    kg_text = visualize_kg_text()
    return ent_text, rel_text, kg_text, f"{time.time() - started:.2f} 秒"
# ======================== 模型评估与自动标注 ========================
def convert_telegram_json_to_eval_format(path):
    """Load an uploaded JSON file as a list of {"text", "entities"} items.

    Accepted layouts:
      * a single annotated sample {"text": ..., "entities": [...]}: each entity
        is given by its "text", or by "start"/"end" offsets into the text;
      * a list, returned unchanged (assumed to already be in eval format);
      * a Telegram chat export {"messages": [...]}: each message's text (a
        plain string, or a list of strings / rich-text dicts) becomes one item
        with no entities; messages with empty text are skipped.
    Any other JSON value yields [].

    Args:
        path: filesystem path (str / PathLike), or an uploaded-file object that
            exposes the path as ``.name``.
    """
    if not isinstance(path, (str, bytes, os.PathLike)):
        path = path.name
    # utf-8-sig also accepts files saved with a byte-order mark.
    with open(path, encoding="utf-8-sig") as f:
        data = json.load(f)

    if isinstance(data, list):
        return data
    if not isinstance(data, dict):
        return []

    if "text" in data:
        text = data["text"]
        entities = []
        for e in data.get("entities") or []:
            if not isinstance(e, dict):
                continue
            if "text" in e:
                entities.append({"text": e["text"]})
            elif "start" in e and "end" in e:
                entities.append({"text": text[e["start"]:e["end"]]})
        return [{"text": text, "entities": entities}]

    if "messages" in data:
        result = []
        for m in data.get("messages") or []:
            if not isinstance(m, dict):
                continue
            raw = m.get("text")
            if isinstance(raw, list):
                # Telegram rich text: plain strings mixed with {"type": ..., "text": ...} dicts.
                raw = "".join(
                    str(part.get("text", "")) if isinstance(part, dict) else str(part)
                    for part in raw
                )
            if isinstance(raw, str) and raw.strip():
                result.append({"text": raw, "entities": []})
        return result

    return []
def evaluate_ner_model(data, model_type):
    """Score NER predictions against gold entity strings (micro precision/recall/F1).

    For each dict item, gold entities come from item["entities"]: plain
    strings, dicts with "text", or dicts with "start"/"end" offsets into
    item["text"]. Gold and predicted entities are compared as sets of surface
    strings per item. True/false positives and false negatives are summed over
    all items; a metric whose denominator is zero is reported as 0. Non-dict
    items are skipped.

    Args:
        data: list of items as produced by convert_telegram_json_to_eval_format.
        model_type: "bert" or "chatglm", forwarded to ner().

    Returns:
        Three lines "Precision: p", "Recall: r", "F1: f" with two decimals.
    """
    tp = fp = fn = 0
    for item in data:
        if not isinstance(item, dict):
            continue
        text = item.get("text") or ""
        gold = set()
        for e in item.get("entities") or []:
            if isinstance(e, str):
                gold.add(e)
            elif isinstance(e, dict):
                if "text" in e:
                    gold.add(e["text"])
                elif "start" in e and "end" in e:
                    gold.add(text[e["start"]:e["end"]])
        predicted, _ = ner(text, model_type)
        pred = {e["text"] for e in predicted if isinstance(e, dict) and "text" in e}
        tp += len(gold & pred)
        fp += len(pred - gold)
        fn += len(gold - pred)

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return f"Precision: {precision:.2f}\nRecall: {recall:.2f}\nF1: {f1:.2f}"
def auto_annotate(file, model_type):
    """Annotate every text of an uploaded JSON file with predicted entities.

    Args:
        file: the Gradio upload value: a path (str / PathLike), an object
            exposing the path as ``.name``, or None when nothing was uploaded.
        model_type: "bert" or "chatglm", forwarded to ner().

    Returns:
        The items as a pretty-printed JSON string, each dict item's "entities"
        replaced by the predictions ([] for items with no text). Returns "[]"
        when no file was uploaded.
    """
    if file is None:
        return "[]"
    path = file if isinstance(file, (str, bytes, os.PathLike)) else file.name
    data = convert_telegram_json_to_eval_format(path)
    for item in data:
        if not isinstance(item, dict):
            continue
        text = item.get("text")
        if isinstance(text, str) and text.strip():
            ents, _ = ner(text, model_type)
        else:
            ents = []
        item["entities"] = ents
    return json.dumps(data, ensure_ascii=False, indent=2)
def save_json(json_text):
    """Write *json_text* to a new file in the current directory and return its name.

    The name is auto_labeled_<unix-seconds>_<random hex>.json. The random
    suffix keeps two saves in the same second (e.g. two users of the shared
    app) from overwriting each other. A None value is written as an empty file.
    """
    import uuid

    fname = f"auto_labeled_{int(time.time())}_{uuid.uuid4().hex[:8]}.json"
    with open(fname, "w", encoding="utf-8") as f:
        f.write(json_text or "")
    return fname
# ======================== Gradio 界面 ========================
def _upload_path(file):
    """Return the filesystem path of a Gradio upload value (path or object with .name)."""
    return file if isinstance(file, (str, bytes, os.PathLike)) else file.name


def process_file(file, model_type="bert"):
    """Run the text-analysis pipeline on an uploaded .txt or .json file.

    .json files are read with convert_telegram_json_to_eval_format and their
    texts joined with newlines. Any other file is read as bytes and decoded
    with the encoding chardet detects (UTF-8 if detection gives nothing).
    Returns the same four strings as process_text; empty results when no
    file was uploaded.
    """
    if file is None:
        return "", "", "", "0.00 秒"
    path = _upload_path(file)
    if str(path).lower().endswith(".json"):
        items = convert_telegram_json_to_eval_format(path)
        text = "\n".join(
            item["text"]
            for item in items
            if isinstance(item, dict) and isinstance(item.get("text"), str)
        )
    else:
        with open(path, "rb") as f:
            raw = f.read()
        encoding = chardet.detect(raw).get("encoding") or "utf-8"
        if encoding.lower() in ("gb2312", "gbk"):
            encoding = "gb18030"  # superset: avoids errors on characters outside GB2312
        try:
            text = raw.decode(encoding, errors="replace")
        except LookupError:  # encoding name unknown to Python's codec registry
            text = raw.decode("utf-8", errors="replace")
    return process_text(text, model_type)


def _run_evaluation(file, model_type):
    """Evaluate the selected model on an uploaded annotated JSON file."""
    if file is None:
        return "请先上传标注 JSON 文件"
    data = convert_telegram_json_to_eval_format(_upload_path(file))
    return evaluate_ner_model(data, model_type)


with gr.Blocks(css=".kg-graph {height: 500px;}") as demo:
    gr.Markdown("# 🤖 聊天记录实体关系识别系统")

    with gr.Tab("📄 文本分析"):
        input_text = gr.Textbox(lines=6, label="输入文本")
        model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
        btn = gr.Button("开始分析")
        out1 = gr.Textbox(label="识别实体")
        out2 = gr.Textbox(label="识别关系")
        out3 = gr.Textbox(label="知识图谱")
        out4 = gr.Textbox(label="耗时")
        btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])

    with gr.Tab("🗂 文件分析"):
        file_input = gr.File(file_types=[".txt", ".json"])
        file_model = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
        file_btn = gr.Button("上传并分析")
        fout1 = gr.Textbox(label="识别实体")
        fout2 = gr.Textbox(label="识别关系")
        fout3 = gr.Textbox(label="知识图谱")
        fout4 = gr.Textbox(label="耗时")
        file_btn.click(fn=process_file, inputs=[file_input, file_model], outputs=[fout1, fout2, fout3, fout4])

    with gr.Tab("📊 模型评估"):
        eval_file = gr.File(label="上传标注 JSON")
        eval_model = gr.Radio(["bert", "chatglm"], value="bert")
        eval_btn = gr.Button("开始评估")
        eval_output = gr.Textbox(label="评估结果", lines=5)
        eval_btn.click(fn=_run_evaluation, inputs=[eval_file, eval_model], outputs=eval_output)

    with gr.Tab("✏️ 自动标注"):
        raw_file = gr.File(label="上传 Telegram 原始 JSON")
        auto_model = gr.Radio(["bert", "chatglm"], value="bert")
        auto_btn = gr.Button("自动标注")
        marked_texts = gr.Textbox(label="标注结果", lines=20)
        download_btn = gr.Button("💾 下载标注文件")
        download_file = gr.File(label="标注文件")
        auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
        download_btn.click(fn=save_json, inputs=marked_texts, outputs=download_file)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)