"""Build a person-relationship graph from an uploaded chat log.

Pipeline: parse the chat file -> find person names with a Chinese NER model ->
count co-mentions of person pairs -> label each pair with a relation classifier
-> render the result as an interactive pyvis graph (HTML string).
"""

import itertools
import json
import os
import re
import tempfile
from collections import defaultdict

import torch
from pyvis.network import Network
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    BertTokenizerFast,
    pipeline,
)

# -------------------------------
# Named-entity recognition (NER) model
# -------------------------------
# The ckiplab model card asks for the bert-base-chinese fast tokenizer instead of
# AutoTokenizer; a fast tokenizer also reports the character offsets that
# extract_entities relies on.
ner_tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")
ner_model = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")

# -------------------------------
# Person-relation classification model (BERT-style sequence classifier)
# -------------------------------
# NOTE(review): confirm this checkpoint name exists on the Hugging Face Hub.
rel_model_name = "uer/roberta-base-finetuned-baike-chinese-relation-extraction"
rel_tokenizer = AutoTokenizer.from_pretrained(rel_model_name)
rel_model = AutoModelForSequenceClassification.from_pretrained(rel_model_name)
rel_model.eval()

# NOTE(review): assumed to match the checkpoint's output order; verify against
# rel_model.config.id2label. Unknown ids fall back to the model config.
id2label = {
    0: "夫妻",
    1: "父子",
    2: "朋友",
    3: "师生",
    4: "同事",
    5: "其他",
}

# Characters per NER call. WordPiece emits at most one token per non-whitespace
# character, so 500 characters plus [CLS]/[SEP] stay under BERT's 512 limit.
_NER_MAX_CHARS = 500

# Entity tags treated as persons, after removing any B-/I-/E-/S- prefix.
_PERSON_TAGS = frozenset({"PER", "PERSON"})
_BIES_PREFIX_RE = re.compile(r"^[BIES]-")

# One TXT chat line: "YYYY-MM-DD[ HH:MM[:SS]] sender: message".
# The time part is matched explicitly so it never leaks into the sender, and
# both ASCII ": " and the full-width "：" are accepted as the separator.
_CHAT_LINE_RE = re.compile(
    r"(\d{4}-\d{2}-\d{2}(?:[ T]\d{1,2}:\d{2}(?::\d{2})?)?)\s+(.+?)(?::\s|：\s*)(.*)"
)


def classify_relation_bert(e1, e2, context):
    """Classify the relation between two people from their shared context.

    Args:
        e1, e2: Person names.
        context: Concatenated messages in which both names appear.

    Returns:
        A display string "<label>(置信度 <p>)" where p is the softmax
        probability of the predicted class.
    """
    prompt = f"{e1}和{e2}的关系是?{context}"
    # truncation keeps the leading question and drops the tail of long contexts.
    inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = rel_model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)[0]
    pred = int(torch.argmax(probs))
    confidence = probs[pred].item()
    if pred in id2label:
        label = id2label[pred]
    else:
        # Avoid a KeyError when the checkpoint has more classes than id2label.
        label = rel_model.config.id2label.get(pred, str(pred))
    return f"{label}(置信度 {confidence:.2f})"


# -------------------------------
# Chat input parsing
# -------------------------------
def _read_upload(file):
    """Return the full text of an uploaded file (file object or path-like object)."""
    if hasattr(file, "read"):
        raw = file.read()
    else:
        # Also accept an object that only exposes its path (via .name or str()).
        with open(getattr(file, "name", file), "rb") as fh:
            raw = fh.read()
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    # Drop a UTF-8 byte-order mark so the first line/JSON value still parses.
    return raw.lstrip("\ufeff")


def _as_text(value):
    """Coerce a JSON field to str, mapping null to an empty string."""
    return "" if value is None else str(value)


def _normalize_json_chat(data):
    """Validate parsed JSON chat data and coerce sender/message to strings.

    Raises:
        ValueError: if the data is not a list of objects with sender/message.
    """
    if not isinstance(data, list):
        raise ValueError("JSON 文件的顶层结构必须是消息列表")
    chat_data = []
    for item in data:
        if not isinstance(item, dict) or "sender" not in item or "message" not in item:
            raise ValueError("JSON 中的每条消息都必须是包含 sender 和 message 字段的对象")
        # Keep any extra fields; only normalise the two the pipeline reads.
        chat_data.append(
            {**item, "sender": _as_text(item["sender"]), "message": _as_text(item["message"])}
        )
    return chat_data


def parse_input_file(file):
    """Parse an uploaded chat file into a list of {"sender", "message"} dicts.

    Supports ``.json`` (a list of message objects) and ``.txt`` (one
    timestamped "sender: message" line per message; other lines are skipped).

    Raises:
        ValueError: for unsupported extensions or malformed content.
    """
    filename = str(getattr(file, "name", file)).lower()
    if filename.endswith(".json"):
        return _normalize_json_chat(json.loads(_read_upload(file)))
    if filename.endswith(".txt"):
        chat_data = []
        for line in _read_upload(file).splitlines():
            match = _CHAT_LINE_RE.match(line.strip())
            if match:
                _, sender, message = match.groups()
                chat_data.append({"sender": sender, "message": message})
        return chat_data
    raise ValueError("不支持的文件格式,请上传 JSON 或 TXT 文件")


# -------------------------------
# Entity extraction
# -------------------------------
def _split_for_ner(text, max_chars=_NER_MAX_CHARS):
    """Pack the non-empty lines of *text* into chunks of at most *max_chars* characters."""
    chunks = []
    current = ""
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        # A single over-long line is hard-split so every piece fits.
        for i in range(0, len(line), max_chars):
            piece = line[i:i + max_chars]
            if current and len(current) + 1 + len(piece) > max_chars:
                chunks.append(current)
                current = piece
            else:
                current = f"{current}\n{piece}" if current else piece
    if current:
        chunks.append(current)
    return chunks


def _people_in_chunk(chunk):
    """Return the person names the NER pipeline finds in one text chunk."""
    names = []
    last = None  # (label, end offset) of the fragment that produced names[-1]
    for ent in ner_pipeline(chunk):
        label = ent["entity_group"]
        if _BIES_PREFIX_RE.sub("", label) not in _PERSON_TAGS:
            last = None
            continue
        start, end = ent.get("start"), ent.get("end")
        if start is None or end is None:
            # No offsets available: use the decoded word, whose Chinese
            # characters the BERT decoder separates with spaces.
            names.append(ent["word"].replace(" ", ""))
            last = None
            continue
        # Slice the original text; ent["word"] is space-separated for Chinese
        # and would never match a message by substring search.
        fragment = chunk[start:end]
        # "simple" aggregation only strips B-/I- prefixes, so with B/I/E/S tags
        # a name's closing E- token arrives as its own adjacent group.
        continues_name = (
            last is not None
            and last[1] == start
            and label.startswith("E-")
            and not _BIES_PREFIX_RE.match(last[0])
        )
        if continues_name:
            names[-1] += fragment
        else:
            names.append(fragment)
        last = (label, end)
    return [name.strip() for name in names if name.strip()]


def extract_entities(text):
    """Return the sorted, de-duplicated person names found in *text*."""
    people = set()
    for chunk in _split_for_ner(text):
        people.update(_people_in_chunk(chunk))
    return sorted(people)


# -------------------------------
# Relation extraction (co-occurrence + BERT classification)
# -------------------------------
def extract_relations(chat_data, entities):
    """Find person pairs mentioned in the same message and classify each pair.

    Returns:
        A list of (e1, e2, count, label) tuples with e1 < e2, where count is the
        number of messages mentioning both people.
    """
    counts = defaultdict(int)
    contexts = defaultdict(list)
    candidates = sorted(set(entities))  # sorted so pairs come out as (smaller, larger)
    for entry in chat_data:
        msg = entry["message"]
        found = [e for e in candidates if e in msg]
        for pair in itertools.combinations(found, 2):
            counts[pair] += 1
            contexts[pair].append(msg)

    edges = []
    for (e1, e2), count in counts.items():
        label = classify_relation_bert(e1, e2, " ".join(contexts[(e1, e2)]))
        edges.append((e1, e2, count, label))
    return edges


# -------------------------------
# Graph rendering
# -------------------------------
def draw_graph(entities, relations):
    """Render people and their relations as a pyvis network; return the HTML."""
    g = Network(height="600px", width="100%", notebook=False)
    g.barnes_hut()
    for ent in entities:
        g.add_node(ent, label=ent)
    for e1, e2, weight, label in relations:
        g.add_edge(e1, e2, value=weight, title=f"{label}(互动{weight}次)", label=label)

    if hasattr(g, "generate_html"):
        # Build the HTML in memory: no temp file, no stdout print, no
        # platform-encoding issues with Chinese labels.
        return g.generate_html()

    # Older pyvis without generate_html: write to a temp file and clean it up.
    fd, path = tempfile.mkstemp(suffix=".html")
    os.close(fd)
    try:
        g.write_html(path)
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    finally:
        os.remove(path)


# -------------------------------
# Main entry point
# -------------------------------
def analyze_chat(file):
    """Run the full analysis on an uploaded chat file.

    Returns:
        A 3-tuple (entities_text, relations_text, graph_html). On failure the
        first element holds a user-facing message and the others are "".
    """
    if file is None:
        return "请上传聊天文件", "", ""
    try:
        content = parse_input_file(file)
    except Exception as e:  # UI boundary: report any read/parse failure to the user
        return f"读取文件失败: {e}", "", ""

    text = "\n".join(f"{entry['sender']}: {entry['message']}" for entry in content)
    entities = extract_entities(text)
    if not entities:
        return "未识别到任何人物实体", "", ""

    relations = extract_relations(content, entities)
    if not relations:
        return "未发现人物之间的关系", "", ""

    graph_html = draw_graph(entities, relations)
    return str(entities), str(relations), graph_html