import json import tempfile import re import os from collections import defaultdict from transformers import ( AutoTokenizer, AutoModelForSequenceClassification, pipeline, ) import torch from pyvis.network import Network # ------------------------------- # 模型配置 # ------------------------------- # 使用环境变量配置模型名称,便于在Hugging Face上部署时修改 NER_MODEL_NAME = os.environ.get("NER_MODEL_NAME", "ckiplab/bert-base-chinese-ner") REL_MODEL_NAME = os.environ.get("REL_MODEL_NAME", "hfl/chinese-roberta-wwm-ext") # ------------------------------- # 实体识别模型(NER) # ------------------------------- try: ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME) ner_model = AutoModelForSequenceClassification.from_pretrained(NER_MODEL_NAME) ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple") except Exception as e: print(f"NER模型加载失败: {e}") # 可以添加备选方案或错误处理逻辑 # ------------------------------- # 人物关系分类模型(使用 RoBERTa) # ------------------------------- try: rel_tokenizer = AutoTokenizer.from_pretrained(REL_MODEL_NAME) rel_model = AutoModelForSequenceClassification.from_pretrained( REL_MODEL_NAME, num_labels=6, # 确保标签数量匹配 id2label={0: "夫妻", 1: "父子", 2: "朋友", 3: "师生", 4: "同事", 5: "其他"}, label2id={"夫妻": 0, "父子": 1, "朋友": 2, "师生": 3, "同事": 4, "其他": 5} ) rel_model.eval() except Exception as e: print(f"关系分类模型加载失败: {e}") # 可以添加备选方案或错误处理逻辑 # 关系分类的标签映射 relation_id2label = { 0: "夫妻", 1: "父子", 2: "朋友", 3: "师生", 4: "同事", 5: "其他" } # 法律风险分析的标签映射 legal_id2label = { 0: "无违法", 1: "赌博", 2: "毒品", 3: "色情", 4: "诈骗", 5: "暴力" } # ------------------------------- # 聊天输入解析 # ------------------------------- def parse_input_file(file): """解析聊天文件,支持JSON和TXT格式""" try: filename = file.name if filename.endswith(".json"): return json.load(file) elif filename.endswith(".txt"): content = file.read().decode("utf-8") lines = content.strip().splitlines() chat_data = [] for line in lines: match = re.match(r"(\d{4}-\d{2}-\d{2}.*?) (.*?): (.*)", line) if match: _, sender, message = match.groups() chat_data.append({"sender": sender, "message": message}) return chat_data else: raise ValueError("不支持的文件格式,请上传JSON或TXT文件") except Exception as e: print(f"文件解析错误: {e}") raise # ------------------------------- # 实体提取函数 # ------------------------------- def extract_entities(text): """从文本中提取人物实体""" try: results = ner_pipeline(text) people = set() for r in results: if r["entity_group"] == "PER": people.add(r["word"]) return list(people) except Exception as e: print(f"实体提取错误: {e}") return [] # ------------------------------- # 关系抽取函数(共现 + BERT 分类) # ------------------------------- def extract_relations(chat_data, entities): """分析人物之间的关系""" relations = defaultdict(lambda: defaultdict(lambda: {"count": 0, "contexts": []})) for entry in chat_data: msg = entry["message"] found = [e for e in entities if e in msg] for i in range(len(found)): for j in range(i + 1, len(found)): e1, e2 = found[i], found[j] relations[e1][e2]["count"] += 1 relations[e1][e2]["contexts"].append(msg) relations[e2][e1]["count"] += 1 relations[e2][e1]["contexts"].append(msg) edges = [] for e1 in relations: for e2 in relations[e1]: if e1 < e2: context_text = " ".join(relations[e1][e2]["contexts"]) # 截断过长的文本 max_context_length = 500 # 根据需要调整 if len(context_text) > max_context_length: context_text = context_text[:max_context_length] + "..." label = classify_relation_bert(e1, e2, context_text) edges.append((e1, e2, relations[e1][e2]["count"], label)) return edges # ------------------------------- # 法律风险分析(黄赌毒等)函数 # ------------------------------- def classify_relation_bert(e1, e2, context): """使用BERT模型分析人物关系""" try: prompt = f"{e1}和{e2}的关系是?{context}" inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) with torch.no_grad(): logits = rel_model(**inputs).logits pred = torch.argmax(logits, dim=1).item() probs = torch.nn.functional.softmax(logits, dim=1) confidence = probs[0, pred].item() return f"{relation_id2label[pred]}(置信度 {confidence:.2f})" except Exception as e: print(f"关系分类错误: {e}") return "其他(置信度 0.00)" def classify_illegal_behavior(chat_context): """分析聊天内容中的法律风险""" try: prompt = f"请分析以下聊天记录,判断是否涉及以下违法行为:赌博、毒品、色情、诈骗、暴力行为。\n聊天内容:{chat_context}\n请回答:" inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) with torch.no_grad(): logits = rel_model(**inputs).logits pred = torch.argmax(logits, dim=1).item() probs = torch.nn.functional.softmax(logits, dim=1) confidence = probs[0, pred].item() return f"违法行为判断结果:{legal_id2label.get(pred, '未知')}(置信度 {confidence:.2f})" except Exception as e: print(f"法律风险分析错误: {e}") return "违法行为判断结果:未知(置信度 0.00)" # ------------------------------- # 图谱绘制 # ------------------------------- def draw_graph(entities, relations): """生成人物关系图谱""" try: g = Network(height="600px", width="100%", notebook=False) g.barnes_hut() for ent in entities: g.add_node(ent, label=ent) for e1, e2, weight, label in relations: g.add_edge(e1, e2, value=weight, title=f"{label}(互动{weight}次)", label=label) tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html") g.show(tmp_file.name) with open(tmp_file.name, 'r', encoding='utf-8') as f: return f.read() except Exception as e: print(f"图谱绘制错误: {e}") return "
请检查输入数据是否有效
" # ------------------------------- # 主流程函数 # ------------------------------- def analyze_chat(file): """分析聊天记录的主函数""" if file is None: return "请上传聊天文件", "", "", "" try: content = parse_input_file(file) except Exception as e: return f"读取文件失败: {e}", "", "", "" text = "\n".join([entry["sender"] + ": " + entry["message"] for entry in content]) entities = extract_entities(text) if not entities: return "未识别到任何人物实体", "", "", "" relations = extract_relations(content, entities) if not relations: return "未发现人物之间的关系", "", "", "" # 法律风险分析 illegal_behavior_results = [classify_illegal_behavior(msg["message"]) for msg in content] graph_html = draw_graph(entities, relations) return str(entities), str(relations), graph_html, "\n".join(illegal_behavior_results)