Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel | |
| import gradio as gr | |
| import re | |
| import os | |
| import json | |
| import chardet | |
| from sklearn.metrics import precision_score, recall_score, f1_score | |
| import time | |
# ======================== Model loading ========================
# BERT NER: the ckiplab token-classification checkpoint paired with the
# bert-base-chinese tokenizer. aggregation_strategy="simple" makes the
# pipeline emit grouped entities (the "entity_group" key read by ner()).
bert_model_name = "bert-base-chinese"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_ner_model = AutoModelForTokenClassification.from_pretrained(
    "ckiplab/bert-base-chinese-ner"
)
bert_ner_pipeline = pipeline(
    "ner",
    model=bert_ner_model,
    tokenizer=bert_tokenizer,
    aggregation_strategy="simple",
)

# Optional ChatGLM3-6B, attempted only when CUDA is available.
# ``use_chatglm`` is the flag the rest of the module checks: it becomes True
# only after both tokenizer and model loaded, and stays False on CPU or on any
# loading error (in which case only BERT is used).
chatglm_model = None
chatglm_tokenizer = None
use_chatglm = False
try:
    if not torch.cuda.is_available():
        print("⚠️ 当前为 CPU 环境,ChatGLM3 不可用,将仅使用 BERT。")
    else:
        chatglm_model_name = "THUDM/chatglm3-6b"
        chatglm_tokenizer = AutoTokenizer.from_pretrained(
            chatglm_model_name, trust_remote_code=True
        )
        chatglm_model = AutoModel.from_pretrained(
            chatglm_model_name,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.float16,
        ).eval()
        use_chatglm = True
except Exception as e:
    print(f"❌ ChatGLM 加载失败: {e}")
# ======================== Knowledge-graph structure ========================
# In-memory graph shared by every request handled by this process:
#   "entities"  -> set of (text, type) tuples (deduplicated by the set)
#   "relations" -> list of (head, tail, relation) tuples in insertion order
knowledge_graph = {"entities": set(), "relations": []}


def update_knowledge_graph(entities, relations):
    """Merge extracted entities and relations into ``knowledge_graph``.

    Args:
        entities: iterable of dicts; only those whose ``text`` and ``type``
            values are strings are added. Anything else (e.g. malformed LLM
            output) is skipped. ``None`` is treated as empty.
        relations: iterable of dicts; only those whose ``head``, ``tail`` and
            ``relation`` values are strings are added. ``None`` is treated as
            empty.

    A relation triple already present in the graph is not appended again, so
    analysing the same text twice does not duplicate edges (entities were
    already deduplicated because they live in a set).
    """
    for e in entities or ():
        # Require strings: the values go into a set, so they must be hashable.
        if isinstance(e, dict) and isinstance(e.get("text"), str) and isinstance(e.get("type"), str):
            knowledge_graph["entities"].add((e["text"], e["type"]))

    known = set(knowledge_graph["relations"])
    for r in relations or ():
        if not isinstance(r, dict):
            continue
        if not all(isinstance(r.get(k), str) for k in ("head", "tail", "relation")):
            continue
        triple = (r["head"], r["tail"], r["relation"])
        if triple not in known:
            known.add(triple)
            knowledge_graph["relations"].append(triple)
def visualize_kg_text(graph=None):
    """Render a knowledge graph as plain text for the UI.

    Args:
        graph: mapping with ``"entities"`` (iterable of ``(text, type)``) and
            ``"relations"`` (iterable of ``(head, tail, relation)``). Defaults to
            the module-level ``knowledge_graph``.

    Returns:
        A multi-line string with these parts, in order:
        - the entity header, then one ``"text (type)"`` line per entity,
          sorted for a stable display;
        - a blank line;
        - the relation header, then one ``"head --[relation]-> tail"`` line
          per relation, in insertion order.
    """
    if graph is None:
        graph = knowledge_graph
    # Entities are stored in a set, whose iteration order depends on string
    # hashing (randomised per process); sort so the output is reproducible.
    ordered = sorted(graph["entities"], key=lambda ent: (str(ent[0]), str(ent[1])))
    nodes = [f"{text} ({etype})" for text, etype in ordered]
    edges = [f"{h} --[{r}]-> {t}" for h, t, r in graph["relations"]]
    return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
# ======================== Named-entity recognition (NER) ========================
def _loads_llm_json(reply):
    """Parse the JSON payload of a ChatGLM reply (helper for ``ner``).

    Chat models often wrap JSON in Markdown fences or add prose around it.
    This tries the whole reply first, then its outermost ``[...]`` span, then
    its outermost ``{...}`` span. Raises ``ValueError`` when none of them parses.
    """
    if isinstance(reply, tuple):
        # chat() may return a tuple; the reply text is its first element.
        reply = reply[0]
    reply = str(reply).strip()
    candidates = [reply]
    for opener, closer in (("[", "]"), ("{", "}")):
        lo, hi = reply.find(opener), reply.rfind(closer)
        if lo != -1 and hi > lo:
            candidates.append(reply[lo:hi + 1])
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError
            continue
    raise ValueError(f"no JSON found in model reply: {reply[:100]!r}")


def _normalize_llm_entities(parsed, text):
    """Convert parsed ChatGLM output to the entity schema of the BERT path.

    Accepted input:
    - the container is a list, or a dict holding an ``entities`` list;
    - each item is a bare string, or a dict with ``text`` (alias
      ``entity``/``name``) and ``type`` (alias ``label``).

    Each kept entity gets ``start``/``end`` offsets of its first occurrence in
    *text*. Entities that do not occur verbatim in *text* are dropped, because
    NER must return spans of the input.
    """
    if isinstance(parsed, dict):
        parsed = parsed.get("entities")
    if not isinstance(parsed, list):
        return []
    entities = []
    for item in parsed:
        if isinstance(item, str):
            ent_text, ent_type = item, "UNKNOWN"
        elif isinstance(item, dict):
            ent_text = item.get("text") or item.get("entity") or item.get("name")
            ent_type = item.get("type") or item.get("label") or "UNKNOWN"
        else:
            continue
        if not isinstance(ent_text, str) or not ent_text:
            continue
        start = text.find(ent_text)
        if start == -1:
            continue
        entities.append({
            "text": ent_text,
            "start": start,
            "end": start + len(ent_text),
            "type": str(ent_type),
        })
    return entities


def _bert_entities(text):
    """Run the BERT NER pipeline on *text* and return entity dicts."""
    entities = []
    for r in bert_ner_pipeline(text):
        start, end = r.get("start"), r.get("end")
        if start is not None and end is not None:
            # Slice the input rather than trusting r["word"]: "word" is
            # re-decoded from WordPiece tokens and can contain spaces between
            # Chinese characters, so it would not match the original text.
            word = text[start:end]
        else:
            word = r["word"]  # no offsets reported; fall back to decoded word
        entities.append({"text": word, "start": start, "end": end, "type": r["entity_group"]})
    return entities


def ner(text, model_type="bert"):
    """Recognise named entities in *text*.

    Args:
        text: input string.
        model_type: ``"chatglm"`` uses ChatGLM when it loaded successfully.
            Any other value, or ``"chatglm"`` when ChatGLM is unavailable,
            uses the BERT pipeline.

    Returns:
        ``(entities, seconds)``. *entities* is a list of
        ``{"text", "start", "end", "type"}`` dicts on both paths; on the BERT
        path ``start``/``end`` are ``None`` if the pipeline reported no offsets.
        *seconds* is the elapsed wall-clock time. A ChatGLM failure (call or
        parse) is logged and yields ``[]``.
    """
    start_time = time.time()
    if not text or not text.strip():
        # Nothing to recognise; skip the model call entirely.
        return [], time.time() - start_time
    if model_type == "chatglm" and use_chatglm:
        try:
            # Spell out the schema: downstream code needs "text" and "type".
            prompt = (
                "请从以下文本中识别所有实体,只返回JSON数组,"
                '每个元素形如 {"text": "实体原文", "type": "实体类型"}:\n'
                f"{text}"
            )
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
            entities = _normalize_llm_entities(_loads_llm_json(response), text)
            return entities, time.time() - start_time
        except Exception as e:  # model-call boundary: log and degrade to no entities
            print(f"❌ ChatGLM 实体识别失败:{e}")
            return [], time.time() - start_time
    return _bert_entities(text), time.time() - start_time
# ======================== Relation extraction (RE) ========================
def _relations_from_llm_reply(reply):
    """Parse a ChatGLM reply into well-formed relation dicts (helper for ``re_extract``).

    The reply should contain a JSON array. Surrounding prose or Markdown
    fences are tolerated by parsing only the outermost ``[...]`` span. Items
    lacking a non-empty ``head``, ``tail`` or ``relation`` are dropped.

    Raises:
        ValueError: when no JSON array can be parsed.
    """
    if isinstance(reply, tuple):
        # chat() may return a tuple; the reply text is its first element.
        reply = reply[0]
    reply = str(reply)
    lo, hi = reply.find("["), reply.rfind("]")
    if lo == -1 or hi < lo:
        raise ValueError("no JSON array in model reply")
    parsed = json.loads(reply[lo:hi + 1])
    return [
        {"head": str(r["head"]), "tail": str(r["tail"]), "relation": str(r["relation"])}
        for r in parsed
        if isinstance(r, dict) and all(r.get(k) for k in ("head", "tail", "relation"))
    ]


def re_extract(entities, text):
    """Extract relations between recognised entities.

    ChatGLM is used whenever it is loaded, regardless of the model chosen
    for NER. If ChatGLM is unavailable, fails, or returns no usable relation,
    every pair of distinct entity texts is linked with the generic relation
    "相关" ("related").

    Args:
        entities: entity dicts as produced by ``ner``; only their ``text`` is used.
        text: the source text, given to the LLM as context.

    Returns:
        A list of ``{"head", "tail", "relation"}`` dicts. It is empty when
        fewer than two distinct entity texts are present.
    """
    # Deduplicate by surface text (order-preserving) so repeated mentions do
    # not produce self-relations (A -> A) or duplicate pairs.
    names = list(dict.fromkeys(
        e["text"]
        for e in entities or ()
        if isinstance(e, dict) and isinstance(e.get("text"), str) and e["text"]
    ))
    if len(names) < 2:
        return []
    if use_chatglm:
        try:
            # Ask for the exact JSON schema that the parser expects.
            prompt = (
                f"分析以下实体之间的关系:{names}\n文本上下文:{text}\n"
                '请只返回JSON数组,每个元素形如 {"head": "实体1", "tail": "实体2", "relation": "关系"}。'
            )
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
            relations = _relations_from_llm_reply(response)
            if relations:
                return relations
        except Exception as e:  # model-call boundary: log, then use the fallback
            print(f"❌ ChatGLM 关系抽取失败:{e}")
    return [
        {"head": head, "tail": tail, "relation": "相关"}
        for i, head in enumerate(names)
        for tail in names[i + 1:]
    ]
# ======================== Main text-analysis pipeline ========================
def process_text(text, model_type="bert"):
    """Run NER and relation extraction on *text* and update the knowledge graph.

    Args:
        text: the text to analyse.
        model_type: ``"bert"`` or ``"chatglm"``, forwarded to ``ner``.

    Returns:
        A 4-tuple of display strings for the UI:
        - entities, one ``"text (type) [start-end]"`` per line (``?`` for a
          missing field);
        - relations, one ``"head --[relation]-> tail"`` per line;
        - the full knowledge-graph text;
        - the NER duration, formatted as ``"<seconds> 秒"``.
    """
    entities, duration = ner(text, model_type)
    # LLM-backed paths may hand back arbitrary JSON; only iterate real lists.
    if not isinstance(entities, list):
        entities = []
    relations = re_extract(entities, text)
    if not isinstance(relations, list):
        relations = []
    update_knowledge_graph(entities, relations)
    # Skip malformed records instead of raising KeyError/TypeError while formatting.
    ent_text = "\n".join(
        f"{e['text']} ({e.get('type', '?')}) [{e.get('start', '?')}-{e.get('end', '?')}]"
        for e in entities
        if isinstance(e, dict) and "text" in e
    )
    rel_text = "\n".join(
        f"{r['head']} --[{r['relation']}]-> {r['tail']}"
        for r in relations
        if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation"))
    )
    return ent_text, rel_text, visualize_kg_text(), f"{duration:.2f} 秒"
# ======================== Model evaluation & auto-annotation ========================
def convert_telegram_json_to_eval_format(path):
    """Load a JSON file and normalise it into a list of evaluation records.

    Each record is ``{"text": ..., "entities": [...]}``. Supported inputs:

    * a single annotated sample ``{"text": ..., "entities": [...]}``. Each
      entity is reduced to ``{"text": ...}``, taken from its ``start``/``end``
      span of the sample text when present, else from its own ``text`` field.
    * a list of records (dicts with a ``text`` key, kept as-is) or plain
      strings (wrapped with empty ``entities``). Other items are dropped.
    * a Telegram-style export ``{"messages": [...]}``. A message's ``text`` may
      be a string or a list of strings and dict fragments carrying ``text``;
      the pieces are concatenated. Messages without text are skipped, and
      entities are left empty.

    Anything else yields ``[]``.

    The file is decoded as UTF-8, tolerating a leading BOM. If that fails, the
    encoding guessed by ``chardet`` is used.
    """
    with open(path, "rb") as f:
        raw = f.read()
    try:
        # "utf-8-sig" strips a leading BOM, which json.loads would reject.
        content = raw.decode("utf-8-sig")
    except UnicodeDecodeError:
        guess = chardet.detect(raw).get("encoding") or "utf-8"
        try:
            content = raw.decode(guess, errors="replace")
        except LookupError:  # chardet named a codec Python does not provide
            content = raw.decode("utf-8", errors="replace")
    data = json.loads(content)

    if isinstance(data, dict) and isinstance(data.get("text"), str):
        text = data["text"]
        gold = []
        for e in data.get("entities") or ():
            if not isinstance(e, dict):
                continue
            if "start" in e and "end" in e:
                gold.append({"text": text[e["start"]:e["end"]]})
            elif "text" in e:
                gold.append({"text": e["text"]})
        return [{"text": text, "entities": gold}]

    if isinstance(data, list):
        records = []
        for item in data:
            if isinstance(item, str):
                records.append({"text": item, "entities": []})
            elif isinstance(item, dict) and "text" in item:
                records.append(item)
        return records

    if isinstance(data, dict) and "messages" in data:
        records = []
        for message in data.get("messages") or ():
            if not isinstance(message, dict):
                continue
            body = message.get("text")
            if isinstance(body, list):
                body = "".join(
                    str(part.get("text", "")) if isinstance(part, dict) else part
                    for part in body
                    if isinstance(part, (dict, str))
                )
            # Media/service messages carry no text: nothing to annotate or score.
            if isinstance(body, str) and body.strip():
                records.append({"text": body, "entities": []})
        return records

    return []
def evaluate_ner_model(data, model_type, ner_fn=None):
    """Score NER predictions against gold annotations at the entity-string level.

    For each record:
    - the gold set is built from entity ``text`` fields, or from the
      ``start``/``end`` spans of the record text;
    - the gold set is compared with the set of predicted entity texts;
    - strings are compared with whitespace removed.

    True positives, false positives and false negatives are summed over all
    records. This gives the same figures as binary precision/recall/F1 over
    the pooled per-entity decisions, computed directly from the counts. An
    empty denominator is reported as 0.00 instead of triggering an
    undefined-metric warning or error.

    Args:
        data: iterable of ``{"text": ..., "entities": [...]}`` records. Items
            that are not dicts with a string ``text`` are skipped.
        model_type: forwarded to the NER function (``"bert"`` or ``"chatglm"``).
        ner_fn: callable ``(text, model_type) -> (entities, seconds)``.
            Defaults to this module's ``ner``.

    Returns:
        ``"Precision: x.xx\\nRecall: x.xx\\nF1: x.xx"``.
    """
    if ner_fn is None:
        ner_fn = ner

    def norm(s):
        # Whitespace-insensitive match, so a decoded form like "张 三" equals "张三".
        return "".join(str(s).split())

    tp = fp = fn = 0
    for item in data or ():
        if not isinstance(item, dict) or not isinstance(item.get("text"), str):
            continue
        text = item["text"]
        gold = set()
        for e in item.get("entities") or ():
            if not isinstance(e, dict):
                continue
            if "text" in e:
                gold.add(norm(e["text"]))
            elif "start" in e and "end" in e:
                gold.add(norm(text[e["start"]:e["end"]]))
        predicted, _ = ner_fn(text, model_type)
        pred = {norm(e["text"]) for e in predicted or () if isinstance(e, dict) and "text" in e}
        # An empty string is not an entity on either side.
        gold.discard("")
        pred.discard("")
        tp += len(gold & pred)
        fp += len(pred - gold)
        fn += len(gold - pred)

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
    return f"Precision: {precision:.2f}\nRecall: {recall:.2f}\nF1: {f1:.2f}"
def auto_annotate(file, model_type):
    """Run NER over every record of an uploaded JSON file.

    Args:
        file: the Gradio upload. Depending on the Gradio version and component
            settings, this may be a filepath string or a tempfile-like object
            exposing ``.name``. It is ``None`` when nothing was uploaded.
        model_type: ``"bert"`` or ``"chatglm"``, forwarded to ``ner``.

    Returns:
        Pretty-printed JSON (non-ASCII kept readable) of the normalised records.
        Each record has its ``entities`` replaced by the model's predictions.
        When no file was supplied, a warning message is returned instead.
    """
    if file is None:
        return "⚠️ 请先上传 JSON 文件"
    path = getattr(file, "name", file)
    data = convert_telegram_json_to_eval_format(path)
    for item in data:
        if isinstance(item, dict) and isinstance(item.get("text"), str):
            ents, _ = ner(item["text"], model_type)
            item["entities"] = ents
    return json.dumps(data, ensure_ascii=False, indent=2)
def save_json(json_text):
    """Write annotation JSON to a new file in the working directory and return its path.

    The file name keeps the ``auto_labeled_<unix-seconds>`` prefix and adds a
    random ``tempfile`` suffix. Two saves in the same second (e.g. from
    concurrent users of a shared app) therefore cannot overwrite each other.

    Returns:
        The path of the written file, or ``None`` when *json_text* is empty
        (nothing to download).

    Raises:
        ValueError: if *json_text* is not valid JSON (e.g. a warning message
            shown in the result box). It is checked before any file is created.
    """
    import tempfile

    if not json_text or not json_text.strip():
        return None
    json.loads(json_text)  # validation only; json.JSONDecodeError is a ValueError
    with tempfile.NamedTemporaryFile(
        "w",
        encoding="utf-8",
        prefix=f"auto_labeled_{int(time.time())}_",
        suffix=".json",
        dir=".",
        delete=False,
    ) as f:
        f.write(json_text)
    return f.name
# ======================== Gradio UI ========================
def _upload_path(file):
    """Return the filesystem path of a Gradio upload (a path string or an object with ``.name``)."""
    return getattr(file, "name", file)


def process_file(file, model_type="bert"):
    """Analyse an uploaded ``.txt`` or ``.json`` file segment by segment.

    Input handling:
    - ``.json`` files go through ``convert_telegram_json_to_eval_format``, and
      each record's text is analysed.
    - Other files are decoded as UTF-8 (BOM tolerated), falling back to the
      encoding guessed by ``chardet``, and analysed line by line.
    - Blank segments are skipped.

    Analysing per segment keeps each model input short; this is presumably
    needed, since the BERT checkpoint has a fixed maximum sequence length.

    Returns:
        The same four display strings as ``process_text``:
        - all entity lines;
        - all relation lines;
        - the knowledge-graph text;
        - the total processing time.
    """
    if file is None:
        return "", "", visualize_kg_text(), "0.00 秒"
    start_time = time.time()
    path = _upload_path(file)
    if str(path).lower().endswith(".json"):
        records = convert_telegram_json_to_eval_format(path)
        segments = [r.get("text") if isinstance(r, dict) else r for r in records]
    else:
        with open(path, "rb") as f:
            raw = f.read()
        try:
            content = raw.decode("utf-8-sig")
        except UnicodeDecodeError:
            guess = chardet.detect(raw).get("encoding") or "utf-8"
            try:
                content = raw.decode(guess, errors="replace")
            except LookupError:  # chardet named a codec Python does not provide
                content = raw.decode("utf-8", errors="replace")
        segments = content.splitlines()

    ent_parts, rel_parts = [], []
    for segment in segments:
        if not isinstance(segment, str) or not segment.strip():
            continue
        ent_text, rel_text, _, _ = process_text(segment, model_type)
        if ent_text:
            ent_parts.append(ent_text)
        if rel_text:
            rel_parts.append(rel_text)
    return (
        "\n".join(ent_parts),
        "\n".join(rel_parts),
        visualize_kg_text(),
        f"{time.time() - start_time:.2f} 秒",
    )


def _run_evaluation(file, model):
    """Evaluate *model* on an uploaded annotated JSON file (evaluation tab handler)."""
    if file is None:
        return "⚠️ 请先上传标注 JSON 文件"
    return evaluate_ner_model(convert_telegram_json_to_eval_format(_upload_path(file)), model)


with gr.Blocks(css=".kg-graph {height: 500px;}") as demo:
    gr.Markdown("# 🤖 聊天记录实体关系识别系统")

    with gr.Tab("📄 文本分析"):
        input_text = gr.Textbox(lines=6, label="输入文本")
        model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
        btn = gr.Button("开始分析")
        out1 = gr.Textbox(label="识别实体")
        out2 = gr.Textbox(label="识别关系")
        out3 = gr.Textbox(label="知识图谱")
        out4 = gr.Textbox(label="耗时")
        btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])

    with gr.Tab("🗂 文件分析"):
        file_input = gr.File(file_types=[".txt", ".json"])
        file_btn = gr.Button("上传并分析")
        fout1 = gr.Textbox(label="识别实体")
        fout2 = gr.Textbox(label="识别关系")
        fout3 = gr.Textbox(label="知识图谱")
        fout4 = gr.Textbox(label="耗时")
        # Reuses the model selector defined in the text-analysis tab.
        file_btn.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])

    with gr.Tab("📊 模型评估"):
        eval_file = gr.File(label="上传标注 JSON")
        eval_model = gr.Radio(["bert", "chatglm"], value="bert")
        eval_btn = gr.Button("开始评估")
        eval_output = gr.Textbox(label="评估结果", lines=5)
        eval_btn.click(_run_evaluation, [eval_file, eval_model], eval_output)

    with gr.Tab("✏️ 自动标注"):
        raw_file = gr.File(label="上传 Telegram 原始 JSON")
        auto_model = gr.Radio(["bert", "chatglm"], value="bert")
        auto_btn = gr.Button("自动标注")
        marked_texts = gr.Textbox(label="标注结果", lines=20)
        download_btn = gr.Button("💾 下载标注文件")
        download_file = gr.File()
        auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
        download_btn.click(fn=save_json, inputs=marked_texts, outputs=download_file)

demo.launch(server_name="0.0.0.0", server_port=7860)