import json
import tempfile
import re
from collections import defaultdict

from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    AutoModelForSequenceClassification,
    pipeline,
)
import torch
from pyvis.network import Network
# -------------------------------
# Named entity recognition (NER) model
# -------------------------------
ner_tokenizer = AutoTokenizer.from_pretrained("ckiplab/bert-base-chinese-ner")
ner_model = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
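# With aggregation_strategy="simple", the pipeline merges sub-word tokens and
# returns one dict per entity span, e.g. (illustrative output, not a real run):
#   [{"entity_group": "PER", "word": "小明", "score": 0.99, "start": 0, "end": 2}]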
# -------------------------------
# Person-relation classification model (BERT classifier)
# -------------------------------
rel_model_name = "uer/roberta-base-finetuned-baike-chinese-relation-extraction"
rel_tokenizer = AutoTokenizer.from_pretrained(rel_model_name)
rel_model = AutoModelForSequenceClassification.from_pretrained(rel_model_name)
rel_model.eval()

# Label map: 0 spouse, 1 father-son, 2 friend, 3 teacher-student, 4 colleague, 5 other.
# This hard-coded mapping must agree with the checkpoint's label order
# (see rel_model.config.id2label); adjust it if the model defines a different scheme.
id2label = {
    0: "夫妻", 1: "父子", 2: "朋友", 3: "师生", 4: "同事", 5: "其他"
}
def classify_relation_bert(e1, e2, context):
    # Chinese prompt: "What is the relation between e1 and e2? <context>"
    prompt = f"{e1}和{e2}的关系是?{context}"
    inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = rel_model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=1)
    pred = torch.argmax(logits, dim=1).item()
    confidence = probs[0, pred].item()
    return f"{id2label[pred]}(置信度 {confidence:.2f})"  # "置信度" = confidence
# -------------------------------
# Chat input parsing
# -------------------------------
def parse_input_file(file):
    filename = file.name
    if filename.endswith(".json"):
        return json.load(file)
    elif filename.endswith(".txt"):
        content = file.read().decode("utf-8")
        lines = content.strip().splitlines()
        chat_data = []
        for line in lines:
            # Match "YYYY-MM-DD [HH:MM[:SS]] sender: message". The time component is
            # matched explicitly: a lazy ".*?" in the timestamp group would let the
            # colons of "HH:MM" bleed into the sender group.
            match = re.match(r"(\d{4}-\d{2}-\d{2}(?: \d{2}:\d{2}(?::\d{2})?)?) (.*?): (.*)", line)
            if match:
                _, sender, message = match.groups()
                chat_data.append({"sender": sender, "message": message})
        return chat_data
    else:
        raise ValueError("不支持的文件格式,请上传 JSON 或 TXT 文件")  # unsupported format: upload JSON or TXT
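# Expected TXT line format (illustrative):
#   2024-01-01 10:00 小明: 今天和小红一起吃饭
# JSON input is expected to already be a list of {"sender": ..., "message": ...} dicts,
# matching how analyze_chat() consumes the parsed entries below.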
# -------------------------------
# Entity extraction
# -------------------------------
def extract_entities(text):
    results = ner_pipeline(text)
    people = set()
    for r in results:
        if r["entity_group"] == "PER":  # keep person entities only
            people.add(r["word"])
    return list(people)
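# Note: the whole transcript is fed to the pipeline at once, and BERT-style
# models cap input at 512 tokens, so long chats may be truncated. A minimal
# sketch of a chunked variant (per-message batching is an assumption, not
# part of the original flow):
def extract_entities_chunked(messages, chunk_size=20):
    people = set()
    # Run NER over fixed-size groups of messages so no single input
    # exceeds the model's length limit.
    for i in range(0, len(messages), chunk_size):
        chunk_text = "\n".join(messages[i:i + chunk_size])
        for r in ner_pipeline(chunk_text):
            if r["entity_group"] == "PER":
                people.add(r["word"])
    return list(people)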
# -------------------------------
# Relation extraction (co-occurrence + BERT classification)
# -------------------------------
def extract_relations(chat_data, entities):
    # relations[a][b] counts messages in which a and b co-occur, and keeps those messages.
    relations = defaultdict(lambda: defaultdict(lambda: {"count": 0, "contexts": []}))
    for entry in chat_data:
        msg = entry["message"]
        found = [e for e in entities if e in msg]
        for i in range(len(found)):
            for j in range(i + 1, len(found)):
                e1, e2 = found[i], found[j]
                relations[e1][e2]["count"] += 1
                relations[e1][e2]["contexts"].append(msg)
                relations[e2][e1]["count"] += 1
                relations[e2][e1]["contexts"].append(msg)
    edges = []
    for e1 in relations:
        for e2 in relations[e1]:
            if e1 < e2:  # each pair is stored in both directions; emit it only once
                context_text = " ".join(relations[e1][e2]["contexts"])
                label = classify_relation_bert(e1, e2, context_text)
                edges.append((e1, e2, relations[e1][e2]["count"], label))
    return edges
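# The result is a list of weighted, labeled edges, e.g. (illustrative):
#   [("小明", "小红", 3, "朋友(置信度 0.85)")]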
# -------------------------------
# Graph rendering
# -------------------------------
def draw_graph(entities, relations):
    g = Network(height="600px", width="100%", notebook=False)
    g.barnes_hut()
    for ent in entities:
        g.add_node(ent, label=ent)
    for e1, e2, weight, label in relations:
        # Edge tooltip: "互动{weight}次" = "interacted {weight} times"
        g.add_edge(e1, e2, value=weight, title=f"{label}(互动{weight}次)", label=label)
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
    tmp_file.close()
    # save_graph() just writes the HTML file; show() would also try to open a
    # browser, which fails in a headless environment such as a Space.
    g.save_graph(tmp_file.name)
    with open(tmp_file.name, "r", encoding="utf-8") as f:
        return f.read()
# -------------------------------
# Main pipeline
# -------------------------------
def analyze_chat(file):
    if file is None:
        return "请上传聊天文件", "", ""  # "please upload a chat file"
    try:
        content = parse_input_file(file)
    except Exception as e:
        return f"读取文件失败: {e}", "", ""  # "failed to read file"
    text = "\n".join([entry["sender"] + ": " + entry["message"] for entry in content])
    entities = extract_entities(text)
    if not entities:
        return "未识别到任何人物实体", "", ""  # "no person entities recognized"
    relations = extract_relations(content, entities)
    if not relations:
        return "未发现人物之间的关系", "", ""  # "no relations between people found"
    graph_html = draw_graph(entities, relations)
    return str(entities), str(relations), graph_html
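# The Spaces context suggests this script backs a hosted demo, but no UI wiring
# appears above. A minimal sketch, assuming a Gradio app with a file upload and
# an embedded HTML graph (component choices and labels are assumptions; labels
# mirror the app's Chinese UI strings):
import gradio as gr

with gr.Blocks() as demo:
    file_in = gr.File(label="聊天记录 (JSON/TXT)")      # chat log upload
    entities_out = gr.Textbox(label="识别到的人物")      # recognized people
    relations_out = gr.Textbox(label="人物关系")         # extracted relations
    graph_out = gr.HTML(label="关系图谱")                # relation graph
    gr.Button("分析").click(                             # "analyze"
        analyze_chat,
        inputs=file_in,
        outputs=[entities_out, relations_out, graph_out],
    )

demo.launch()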