import json
import tempfile
import re
from collections import defaultdict
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    AutoModelForSequenceClassification,
    pipeline,
)
import torch
from pyvis.network import Network
# -------------------------------
# Named-entity recognition (NER) model
# -------------------------------
ner_tokenizer = AutoTokenizer.from_pretrained("ckiplab/bert-base-chinese-ner")
ner_model = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
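# With aggregation_strategy="simple" each result is a merged entity dict such
# as (illustrative, not from a real run):
#   {"entity_group": "PERSON", "score": 0.99, "word": "张 三", "start": 0, "end": 2}
# Note the spaces the character-level tokenizer leaves inside "word".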
# -------------------------------
# Person-relation classifier (BERT)
# -------------------------------
rel_model_name = "uer/roberta-base-finetuned-baike-chinese-relation-extraction"
rel_tokenizer = AutoTokenizer.from_pretrained(rel_model_name)
rel_model = AutoModelForSequenceClassification.from_pretrained(rel_model_name)
rel_model.eval()
# Manual label map (spouse, father-son, friend, teacher-student, colleague,
# other). This ordering must match the labels the checkpoint was fine-tuned
# with; the model's own config.id2label is the authoritative source.
id2label = {
    0: "夫妻", 1: "父子", 2: "朋友", 3: "师生", 4: "同事", 5: "其他"
}
def classify_relation_bert(e1, e2, context):
    # Frame the pair as a question followed by its co-occurrence context; the
    # tokenizer truncates anything beyond BERT's 512-token limit.
    prompt = f"{e1}和{e2}的关系是?{context}"
    inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = rel_model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=1)
    pred = torch.argmax(logits, dim=1).item()
    confidence = probs[0, pred].item()
    return f"{id2label[pred]}(置信度 {confidence:.2f})"
# -------------------------------
# Chat input parsing
# -------------------------------
def parse_input_file(file):
    filename = file.name
    if filename.endswith(".json"):
        # Expect a JSON list of {"sender": ..., "message": ...} objects.
        return json.load(file)
    elif filename.endswith(".txt"):
        content = file.read().decode("utf-8")
        lines = content.strip().splitlines()
        chat_data = []
        for line in lines:
            # Expect "YYYY-MM-DD[ time] sender: message"; the timestamp is
            # captured but discarded.
            match = re.match(r"(\d{4}-\d{2}-\d{2}.*?) (.*?): (.*)", line)
            if match:
                _, sender, message = match.groups()
                chat_data.append({"sender": sender, "message": message})
        return chat_data
    else:
        raise ValueError("不支持的文件格式,请上传 JSON 或 TXT 文件")
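# Illustrative inputs accepted by parse_input_file:
#   .json -- a list of message objects:
#       [{"sender": "张三", "message": "你好"}, ...]
#   .txt  -- one message per line, "date [time] sender: message":
#       2024-01-01 12:00 张三: 你好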
# -------------------------------
# Entity extraction
# -------------------------------
def extract_entities(text):
    # Run NER line by line so long transcripts stay within BERT's 512-token
    # limit (the pipeline raises on over-long inputs).
    lines = [line for line in text.splitlines() if line.strip()]
    people = set()
    for results in ner_pipeline(lines):
        for r in results:
            # CKIP uses OntoNotes-style tags, so persons are grouped as
            # "PERSON" rather than "PER"; accept both to be safe.
            if r["entity_group"] in ("PER", "PERSON"):
                # Strip the spaces the character-level tokenizer leaves
                # inside aggregated words ("张 三" -> "张三").
                people.add(r["word"].replace(" ", ""))
    return list(people)
# -------------------------------
# Relation extraction (co-occurrence + BERT classification)
# -------------------------------
def extract_relations(chat_data, entities):
    relations = defaultdict(lambda: defaultdict(lambda: {"count": 0, "contexts": []}))
    for entry in chat_data:
        msg = entry["message"]
        # Entities co-occurring in the same message count as one interaction.
        found = [e for e in entities if e in msg]
        for i in range(len(found)):
            for j in range(i + 1, len(found)):
                e1, e2 = found[i], found[j]
                relations[e1][e2]["count"] += 1
                relations[e1][e2]["contexts"].append(msg)
                relations[e2][e1]["count"] += 1
                relations[e2][e1]["contexts"].append(msg)
    edges = []
    for e1 in relations:
        for e2 in relations[e1]:
            # Counts are stored symmetrically, so keep each pair only once.
            if e1 < e2:
                context_text = " ".join(relations[e1][e2]["contexts"])
                label = classify_relation_bert(e1, e2, context_text)
                edges.append((e1, e2, relations[e1][e2]["count"], label))
    return edges
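# Each edge is (entity1, entity2, co-occurrence count, predicted label), e.g.
# (illustrative): ("张三", "李四", 3, "朋友(置信度 0.91)")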
# -------------------------------
# Graph rendering
# -------------------------------
def draw_graph(entities, relations):
    g = Network(height="600px", width="100%", notebook=False)
    g.barnes_hut()
    for ent in entities:
        g.add_node(ent, label=ent)
    for e1, e2, weight, label in relations:
        g.add_edge(e1, e2, value=weight, title=f"{label}(互动{weight}次)", label=label)
    # save_graph only writes the HTML file; show() may also try to open a
    # browser, which fails in a headless environment. Close the temp-file
    # handle first so pyvis can write to the path (required on Windows).
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
    tmp_file.close()
    g.save_graph(tmp_file.name)
    with open(tmp_file.name, "r", encoding="utf-8") as f:
        return f.read()
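# The returned HTML document is meant to be embedded by the caller, e.g. in a
# gr.HTML component if this Space uses a Gradio front end (an assumption; the
# UI wiring is not part of this file).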
# -------------------------------
# Main pipeline
# -------------------------------
def analyze_chat(file):
    if file is None:
        return "请上传聊天文件", "", ""
    try:
        content = parse_input_file(file)
    except Exception as e:
        return f"读取文件失败: {e}", "", ""
    # Run NER over the whole transcript (sender names included, so senders can
    # be recognized as person entities too).
    text = "\n".join(entry["sender"] + ": " + entry["message"] for entry in content)
    entities = extract_entities(text)
    if not entities:
        return "未识别到任何人物实体", "", ""
    relations = extract_relations(content, entities)
    if not relations:
        return "未发现人物之间的关系", "", ""
    graph_html = draw_graph(entities, relations)
    return str(entities), str(relations), graph_html
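# -------------------------------
# Minimal smoke test (illustrative; the sample senders and messages are made
# up, and this block is not part of the Space's UI wiring)
# -------------------------------
if __name__ == "__main__":
    sample_chat = [
        {"sender": "张三", "message": "李四,明天一起开会吗?"},
        {"sender": "李四", "message": "好的张三,会后再约王老师讨论。"},
    ]
    text = "\n".join(e["sender"] + ": " + e["message"] for e in sample_chat)
    entities = extract_entities(text)
    print("entities:", entities)
    if entities:
        edges = extract_relations(sample_chat, entities)
        print("edges:", edges)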