import json
import tempfile
import re
import os
from collections import defaultdict
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    AutoModelForSequenceClassification,
    pipeline,
)
import torch
from pyvis.network import Network
# -------------------------------
# Model configuration
# -------------------------------
# Model names are read from environment variables so they can be changed
# when deploying on Hugging Face without touching the code.
NER_MODEL_NAME = os.environ.get("NER_MODEL_NAME", "ckiplab/bert-base-chinese-ner")
REL_MODEL_NAME = os.environ.get("REL_MODEL_NAME", "hfl/chinese-roberta-wwm-ext")
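# Hypothetical override when launching locally (the entry-point name app.py and
# the alternative model name are assumptions, not part of this repo):
#   NER_MODEL_NAME=ckiplab/albert-tiny-chinese-ner python app.py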
# -------------------------------
# Named entity recognition (NER) model
# -------------------------------
try:
    ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
    # The "ner" pipeline expects a token-classification head, not a
    # sequence-classification one.
    ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
    ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
except Exception as e:
    print(f"NER模型加载失败: {e}")
    # Fall back to None so the module can still be imported without the model.
    ner_pipeline = None
# -------------------------------
# Person-relation classification model (RoBERTa)
# -------------------------------
try:
    rel_tokenizer = AutoTokenizer.from_pretrained(REL_MODEL_NAME)
    rel_model = AutoModelForSequenceClassification.from_pretrained(
        REL_MODEL_NAME,
        num_labels=6,  # must match the number of relation labels below
        id2label={0: "夫妻", 1: "父子", 2: "朋友", 3: "师生", 4: "同事", 5: "其他"},
        label2id={"夫妻": 0, "父子": 1, "朋友": 2, "师生": 3, "同事": 4, "其他": 5},
    )
    rel_model.eval()
    # Note: if REL_MODEL_NAME is a base checkpoint (such as the default
    # hfl/chinese-roberta-wwm-ext), the classification head is newly
    # initialised and predictions are not meaningful until fine-tuning.
except Exception as e:
    print(f"关系分类模型加载失败: {e}")
    # Fall back to None so the module can still be imported without the model.
    rel_tokenizer = None
    rel_model = None
# Label map for relation classification
relation_id2label = {
    0: "夫妻", 1: "父子", 2: "朋友", 3: "师生", 4: "同事", 5: "其他"
}
# Label map for legal-risk analysis
legal_id2label = {
    0: "无违法", 1: "赌博", 2: "毒品", 3: "色情", 4: "诈骗", 5: "暴力"
}
# -------------------------------
# Chat input parsing
# -------------------------------
def parse_input_file(file):
    """Parse an uploaded chat file; JSON and TXT formats are supported."""
    try:
        filename = file.name
        if filename.endswith(".json"):
            return json.load(file)
        elif filename.endswith(".txt"):
            content = file.read().decode("utf-8")
            lines = content.strip().splitlines()
            chat_data = []
            for line in lines:
                # Expected line format: "<YYYY-MM-DD ...> <sender>: <message>"
                match = re.match(r"(\d{4}-\d{2}-\d{2}.*?) (.*?): (.*)", line)
                if match:
                    _, sender, message = match.groups()
                    chat_data.append({"sender": sender, "message": message})
            return chat_data
        else:
            raise ValueError("不支持的文件格式,请上传JSON或TXT文件")
    except Exception as e:
        print(f"文件解析错误: {e}")
        raise
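# Illustrative example of the TXT format handled above (the names are made up):
#   "2024-01-01 10:23 张三: 今晚一起吃饭吗?"
# is parsed into {"sender": "张三", "message": "今晚一起吃饭吗?"}; lines that do
# not match the "<date> <sender>: <message>" pattern are silently skipped.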
# -------------------------------
# Entity extraction
# -------------------------------
def extract_entities(text):
    """Extract person entities from the text."""
    try:
        results = ner_pipeline(text)
        people = set()
        for r in results:
            if r["entity_group"] == "PER":
                people.add(r["word"])
        return list(people)
    except Exception as e:
        print(f"实体提取错误: {e}")
        return []
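# Illustrative usage (the exact output depends on the NER model):
#   extract_entities("张三: 李四,明天见")  ->  ["张三", "李四"]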
# -------------------------------
# Relation extraction (co-occurrence + BERT classification)
# -------------------------------
def extract_relations(chat_data, entities):
    """Analyse the relations between people."""
    relations = defaultdict(lambda: defaultdict(lambda: {"count": 0, "contexts": []}))
    for entry in chat_data:
        msg = entry["message"]
        found = [e for e in entities if e in msg]
        for i in range(len(found)):
            for j in range(i + 1, len(found)):
                e1, e2 = found[i], found[j]
                relations[e1][e2]["count"] += 1
                relations[e1][e2]["contexts"].append(msg)
                relations[e2][e1]["count"] += 1
                relations[e2][e1]["contexts"].append(msg)
    edges = []
    for e1 in relations:
        for e2 in relations[e1]:
            if e1 < e2:
                context_text = " ".join(relations[e1][e2]["contexts"])
                # Truncate overly long context before classification
                max_context_length = 500  # adjust as needed
                if len(context_text) > max_context_length:
                    context_text = context_text[:max_context_length] + "..."
                label = classify_relation_bert(e1, e2, context_text)
                edges.append((e1, e2, relations[e1][e2]["count"], label))
    return edges
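# Illustrative return value (names, count and confidence are made up):
#   [("张三", "李四", 3, "朋友(置信度 0.87)")]
# i.e. (entity1, entity2, co-occurrence count, predicted relation with confidence).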
# -------------------------------
# Relation classification (BERT)
# -------------------------------
def classify_relation_bert(e1, e2, context):
    """Classify the relation between two people with the BERT model."""
    try:
        prompt = f"{e1}{e2}的关系是?{context}"
        inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            logits = rel_model(**inputs).logits
        pred = torch.argmax(logits, dim=1).item()
        probs = torch.nn.functional.softmax(logits, dim=1)
        confidence = probs[0, pred].item()
        return f"{relation_id2label[pred]}(置信度 {confidence:.2f})"
    except Exception as e:
        print(f"关系分类错误: {e}")
        return "其他(置信度 0.00)"
# -------------------------------
# Legal-risk analysis (gambling, drugs, etc.)
# -------------------------------
def classify_illegal_behavior(chat_context):
    """Analyse the chat content for legal risks."""
    try:
        prompt = f"请分析以下聊天记录,判断是否涉及以下违法行为:赌博、毒品、色情、诈骗、暴力行为。\n聊天内容:{chat_context}\n请回答:"
        # Note: this reuses rel_model / rel_tokenizer, so the six class outputs
        # are mapped onto legal_id2label purely by position.
        inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            logits = rel_model(**inputs).logits
        pred = torch.argmax(logits, dim=1).item()
        probs = torch.nn.functional.softmax(logits, dim=1)
        confidence = probs[0, pred].item()
        return f"违法行为判断结果:{legal_id2label.get(pred, '未知')}(置信度 {confidence:.2f})"
    except Exception as e:
        print(f"法律风险分析错误: {e}")
        return "违法行为判断结果:未知(置信度 0.00)"
# -------------------------------
# Graph rendering
# -------------------------------
def draw_graph(entities, relations):
    """Generate the person-relation graph as an HTML string."""
    try:
        g = Network(height="600px", width="100%", notebook=False)
        g.barnes_hut()
        for ent in entities:
            g.add_node(ent, label=ent)
        for e1, e2, weight, label in relations:
            g.add_edge(e1, e2, value=weight, title=f"{label}(互动{weight}次)", label=label)
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
        # save_graph only writes the HTML file, without trying to open a
        # browser as show() would on a headless server.
        g.save_graph(tmp_file.name)
        with open(tmp_file.name, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        print(f"图谱绘制错误: {e}")
        return "<h3>图谱生成失败</h3><p>请检查输入数据是否有效</p>"
# -------------------------------
# Main pipeline
# -------------------------------
def analyze_chat(file):
    """Main entry point for analysing a chat log."""
    if file is None:
        return "请上传聊天文件", "", "", ""
    try:
        content = parse_input_file(file)
    except Exception as e:
        return f"读取文件失败: {e}", "", "", ""
    text = "\n".join([entry["sender"] + ": " + entry["message"] for entry in content])
    entities = extract_entities(text)
    if not entities:
        return "未识别到任何人物实体", "", "", ""
    relations = extract_relations(content, entities)
    if not relations:
        return "未发现人物之间的关系", "", "", ""
    # Per-message legal-risk analysis
    illegal_behavior_results = [classify_illegal_behavior(msg["message"]) for msg in content]
    graph_html = draw_graph(entities, relations)
    return str(entities), str(relations), graph_html, "\n".join(illegal_behavior_results)
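

# -------------------------------
# Local smoke test (a minimal sketch, not part of the deployed app: it assumes
# the models above loaded successfully and simply prints whatever the pipeline
# returns for a tiny made-up chat log)
# -------------------------------
if __name__ == "__main__":
    sample = (
        "2024-01-01 10:00 张三: 李四,晚上一起吃饭吗?\n"
        "2024-01-01 10:05 李四: 好的,张三,叫上王五。\n"
    )
    # Write the sample to a temporary .txt file so it goes through parse_input_file.
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as tmp:
        tmp.write(sample)
        sample_path = tmp.name
    with open(sample_path, "rb") as f:
        entities_out, relations_out, graph_html, risks = analyze_chat(f)
    print(entities_out)
    print(relations_out)
    print(risks)
    os.remove(sample_path)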