# Page header captured together with the file (not part of the program):
# chen666-666's picture | Upload 3 files | 3d9242d verified | raw | history blame | 4.93 kB
import json
import tempfile
import re
from collections import defaultdict
from transformers import (
AutoTokenizer,
AutoModelForTokenClassification,
AutoModelForSequenceClassification,
pipeline,
)
import torch
from pyvis.network import Network
# -------------------------------
# Named-entity recognition (NER) model
# -------------------------------
# The tokenizer and the token-classification weights are loaded from the same
# checkpoint and then wrapped in a single "ner" pipeline.
ner_tokenizer, ner_model = (
    AutoTokenizer.from_pretrained("ckiplab/bert-base-chinese-ner"),
    AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner"),
)
# With aggregation enabled, the results are consumed elsewhere in this file
# through their "entity_group" and "word" fields.
ner_pipeline = pipeline(
    task="ner",
    tokenizer=ner_tokenizer,
    model=ner_model,
    aggregation_strategy="simple",
)
# -------------------------------
# Person-relation classification model (BERT classifier)
# -------------------------------
rel_model_name = "uer/roberta-base-finetuned-baike-chinese-relation-extraction"
rel_tokenizer = AutoTokenizer.from_pretrained(rel_model_name)
rel_model = AutoModelForSequenceClassification.from_pretrained(rel_model_name)
# The model is only used for inference, so put it in evaluation mode.
rel_model.eval()

# Class index -> relation label. This table is hard-coded here rather than
# read from the checkpoint's configuration.
id2label = dict(enumerate(("夫妻", "父子", "朋友", "师生", "同事", "其他")))
def classify_relation_bert(e1, e2, context):
    """Classify the relation between two people with the relation model.

    Args:
        e1: Name of the first person.
        e2: Name of the second person.
        context: Text in which both names occur. It is appended to the
            question prompt; the tokenizer truncates the combined input to
            512 tokens.

    Returns:
        A string of the form ``"<label>(置信度 <p>)"``, where ``<label>`` is
        looked up in ``id2label`` and ``<p>`` is the softmax probability of
        the predicted class formatted with two decimals.
    """
    # NOTE(review): the two names are concatenated without a separator --
    # confirm this is the input format the checkpoint expects.
    prompt = f"{e1}{e2}的关系是?{context}"
    inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

    with torch.no_grad():
        logits = rel_model(**inputs).logits

    pred = torch.argmax(logits, dim=1).item()
    probs = torch.nn.functional.softmax(logits, dim=1)
    confidence = probs[0, pred].item()

    # id2label is hard-coded with six classes. If the checkpoint predicts an
    # index outside that table, fall back to the catch-all label instead of
    # raising KeyError.
    label = id2label.get(pred, "其他")
    return f"{label}(置信度 {confidence:.2f})"
# -------------------------------
# Chat input parsing
# -------------------------------
# One chat line: "<timestamp> <sender>: <message>".
# The timestamp is a YYYY-MM-DD date, optionally followed by more non-space
# characters (e.g. "T12:00:00") and/or a separate clock-time token. Matching
# the time token explicitly keeps it out of the sender name.
_CHAT_LINE_RE = re.compile(
    r"(\d{4}-\d{2}-\d{2}\S*(?:\s+\d{1,2}:\d{2}\S*)?)\s+(.*?): (.*)"
)


def parse_input_file(file):
    """Parse an uploaded chat log into a list of message records.

    Args:
        file: A file-like object with a ``name`` attribute and a ``read``
            method. The extension of ``name`` (case-insensitive) selects the
            parser.

    Returns:
        For ``.json`` files, whatever ``json.load`` yields (the rest of this
        file expects a list of dicts with ``"sender"`` and ``"message"``
        keys). For ``.txt`` files, a list of
        ``{"sender": ..., "message": ...}`` dicts, one per line that matches
        ``<timestamp> <sender>: <message>``; non-matching lines are skipped.

    Raises:
        ValueError: If the file extension is neither ``.json`` nor ``.txt``.
    """
    filename = file.name
    lowered = filename.lower()

    if lowered.endswith(".json"):
        return json.load(file)

    if lowered.endswith(".txt"):
        raw = file.read()
        if isinstance(raw, bytes):
            # "utf-8-sig" drops a leading BOM, which would otherwise stop the
            # first line from matching the pattern.
            content = raw.decode("utf-8-sig")
        else:
            # File was opened in text mode; only the BOM needs removing.
            content = raw.lstrip("\ufeff")

        chat_data = []
        for line in content.strip().splitlines():
            match = _CHAT_LINE_RE.match(line)
            if match:
                _, sender, message = match.groups()
                chat_data.append({"sender": sender, "message": message})
        return chat_data

    raise ValueError("不支持的文件格式,请上传 JSON 或 TXT 文件")
# -------------------------------
# Entity extraction
# -------------------------------
def extract_entities(text):
    """Return the distinct person names the NER pipeline finds in ``text``.

    Args:
        text: The text to analyse. Empty or whitespace-only input yields an
            empty list without calling the model.

    Returns:
        A list of unique person-name strings, in no particular order.
    """
    if not text or not text.strip():
        return []

    people = set()
    for r in ner_pipeline(text):
        # Accept both the "PER" and "PERSON" tag spellings, and drop any
        # leftover scheme prefix such as "E-" or "S-".
        # NOTE(review): confirm the exact label set of the NER checkpoint.
        tag = r["entity_group"].rsplit("-", 1)[-1].upper()
        if tag not in ("PER", "PERSON"):
            continue

        # Prefer the original surface form: "word" is rebuilt from tokens and
        # can contain spaces between characters, which would never match the
        # raw messages later on. Offsets are only present when the tokenizer
        # provides them.
        start, end = r.get("start"), r.get("end")
        if start is not None and end is not None:
            name = text[start:end]
        else:
            name = r["word"]

        name = name.strip()
        if name:
            people.add(name)
    return list(people)
# -------------------------------
# Relation extraction (co-occurrence + BERT classification)
# -------------------------------
def extract_relations(chat_data, entities):
    """Build labelled edges between people mentioned in the same message.

    Args:
        chat_data: Iterable of records, each with a ``"message"`` entry.
        entities: Names to look for; a name counts as present in a message
            when it occurs as a substring.

    Returns:
        A list of ``(e1, e2, count, label)`` tuples with ``e1 < e2``.
        ``count`` is the number of messages mentioning both names and
        ``label`` is the result of ``classify_relation_bert`` on the
        space-joined messages.
    """
    def _new_stats():
        return {"count": 0, "contexts": []}

    cooccur = defaultdict(lambda: defaultdict(_new_stats))

    for record in chat_data:
        text = record["message"]
        present = [name for name in entities if name in text]
        for pos, first in enumerate(present):
            for second in present[pos + 1:]:
                # Record the pair in both directions, forward first.
                for src, dst in ((first, second), (second, first)):
                    stats = cooccur[src][dst]
                    stats["count"] += 1
                    stats["contexts"].append(text)

    edges = []
    for first, neighbours in cooccur.items():
        for second, stats in neighbours.items():
            # Every pair is stored twice; emit only the ordered one.
            if not first < second:
                continue
            joined = " ".join(stats["contexts"])
            label = classify_relation_bert(first, second, joined)
            edges.append((first, second, stats["count"], label))
    return edges
# -------------------------------
# Graph rendering
# -------------------------------
def draw_graph(entities, relations):
    """Render the relation graph with pyvis and return it as an HTML string.

    Args:
        entities: Node identifiers; each is added as a node labelled with
            itself.
        relations: Iterable of ``(e1, e2, weight, label)`` tuples, one per
            edge.

    Returns:
        The generated HTML document as a string.
    """
    import os

    g = Network(height="600px", width="100%", notebook=False)
    g.barnes_hut()

    for ent in entities:
        g.add_node(ent, label=ent)
    for e1, e2, weight, label in relations:
        g.add_edge(e1, e2, value=weight, title=f"{label}(互动{weight}次)", label=label)

    # Write into a temporary directory that is removed on exit, so no open
    # handle or stray .html file is left behind. write_html is used instead
    # of show so that rendering only produces the file.
    with tempfile.TemporaryDirectory() as tmp_dir:
        html_path = os.path.join(tmp_dir, "graph.html")
        g.write_html(html_path, notebook=False)
        with open(html_path, "r", encoding="utf-8") as f:
            return f.read()
# -------------------------------
# Main pipeline
# -------------------------------
def _chunk_lines(lines, max_chars=400):
    """Group consecutive lines into newline-joined chunks of at most ``max_chars`` characters.

    A single line longer than ``max_chars`` becomes a chunk of its own.
    """
    chunks = []
    current = []
    current_len = 0
    for line in lines:
        # +1 accounts for the newline that would join this line to the chunk.
        if current and current_len + 1 + len(line) > max_chars:
            chunks.append("\n".join(current))
            current, current_len = [], 0
        current_len += len(line) + (1 if current else 0)
        current.append(line)
    if current:
        chunks.append("\n".join(current))
    return chunks


def analyze_chat(file):
    """Run the full analysis on an uploaded chat file.

    Args:
        file: The uploaded file object, or ``None`` when nothing was uploaded.

    Returns:
        A 3-tuple of strings. On success:
        ``(str(entities), str(relations), graph_html)``. Otherwise the first
        element is a user-facing status message and the other two are empty.
    """
    if file is None:
        return "请上传聊天文件", "", ""

    try:
        content = parse_input_file(file)
    except Exception as e:
        return f"读取文件失败: {e}", "", ""

    # Parsed JSON can have any shape; report malformed records instead of
    # crashing with an uncaught KeyError/TypeError.
    try:
        lines = [entry["sender"] + ": " + entry["message"] for entry in content]
    except (KeyError, TypeError) as e:
        return f"读取文件失败: 聊天记录缺少 sender/message 字段 ({e!r})", "", ""

    # The NER model accepts a limited number of tokens per call, so the chat
    # is processed in line-aligned chunks rather than as one long string.
    # The chunk size is a heuristic margin below that limit.
    entities = []
    seen = set()
    for chunk in _chunk_lines(lines):
        for name in extract_entities(chunk):
            if name not in seen:
                seen.add(name)
                entities.append(name)

    if not entities:
        return "未识别到任何人物实体", "", ""

    relations = extract_relations(content, entities)
    if not relations:
        return "未发现人物之间的关系", "", ""

    graph_html = draw_graph(entities, relations)
    return str(entities), str(relations), graph_html