import json
import tempfile
import re
from collections import defaultdict

from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    AutoModelForSequenceClassification,
    pipeline,
)
import torch
from pyvis.network import Network


# -------------------------------
# Named-entity recognition (NER) model
# -------------------------------
ner_tokenizer = AutoTokenizer.from_pretrained("ckiplab/bert-base-chinese-ner")
ner_model = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")


# -------------------------------
# Person-relation classifier (BERT sequence classification)
# -------------------------------
rel_model_name = "uer/roberta-base-finetuned-baike-chinese-relation-extraction"
rel_tokenizer = AutoTokenizer.from_pretrained(rel_model_name)
rel_model = AutoModelForSequenceClassification.from_pretrained(rel_model_name)
rel_model.eval()

# NOTE: this label map is hardcoded; it is assumed to match the label order
# of the fine-tuned model loaded above.
id2label = {
    0: "夫妻",   # spouses
    1: "父子",   # father-son
    2: "朋友",   # friends
    3: "师生",   # teacher-student
    4: "同事",   # colleagues
    5: "其他",   # other
}


def classify_relation_bert(e1, e2, context):
    """Classify the relation between two people from their shared message context."""
    # Cloze-style prompt: "what is the relation between e1 and e2? <context>"
    prompt = f"{e1}{e2}的关系是?{context}"
    inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = rel_model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=1)
    pred = torch.argmax(probs, dim=1).item()
    confidence = probs[0, pred].item()
    return f"{id2label[pred]}(置信度 {confidence:.2f})"  # "置信度" = confidence


# -------------------------------
# Chat-log input parsing
# -------------------------------
def parse_input_file(file):
    """Parse an uploaded chat export into a list of {"sender", "message"} dicts.

    Assumes `file` is a binary-mode file object exposing a `.name` attribute,
    as upload widgets typically provide.
    """
    filename = file.name
    if filename.endswith(".json"):
        # Expected: a JSON array of {"sender": ..., "message": ...} objects.
        return json.load(file)
    elif filename.endswith(".txt"):
        content = file.read().decode("utf-8")
        lines = content.strip().splitlines()
        chat_data = []
        for line in lines:
            # Expected line shape: "<YYYY-MM-DD [time]> <sender>: <message>"
            match = re.match(r"(\d{4}-\d{2}-\d{2}.*?) (.*?): (.*)", line)
            if match:
                _, sender, message = match.groups()
                chat_data.append({"sender": sender, "message": message})
        return chat_data
    else:
        # "Unsupported file format; please upload a JSON or TXT file"
        raise ValueError("不支持的文件格式,请上传 JSON 或 TXT 文件")
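
# Illustrative inputs accepted by parse_input_file (contents are hypothetical):
#
#   chat.json:
#     [{"sender": "张三", "message": "明天一起吃饭吗"},
#      {"sender": "李四", "message": "好啊,叫上王老师"}]
#
#   chat.txt (one "<date> <sender>: <message>" record per line):
#     2024-01-05 10:02 张三: 明天一起吃饭吗
#     2024-01-05 10:03 李四: 好啊,叫上王老师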


# -------------------------------
# Entity extraction
# -------------------------------
def extract_entities(text):
    """Extract the set of person names (PER entities) mentioned in the chat."""
    # Run NER line by line: BERT's 512-token limit would otherwise truncate
    # or reject a long transcript passed in as a single string.
    lines = [line for line in text.splitlines() if line.strip()]
    people = set()
    for line_results in ner_pipeline(lines):
        for r in line_results:
            if r["entity_group"] == "PER":
                people.add(r["word"])
    return list(people)


# -------------------------------
# Relation extraction (co-occurrence + BERT classification)
# -------------------------------
def extract_relations(chat_data, entities):
    """Pair up people who co-occur in a message, then label each pair with BERT."""
    relations = defaultdict(lambda: defaultdict(lambda: {"count": 0, "contexts": []}))

    # Count co-occurrences: every pair of known names appearing in the same
    # message gets one interaction, recorded symmetrically in both directions.
    for entry in chat_data:
        msg = entry["message"]
        found = [e for e in entities if e in msg]
        for i in range(len(found)):
            for j in range(i + 1, len(found)):
                e1, e2 = found[i], found[j]
                relations[e1][e2]["count"] += 1
                relations[e1][e2]["contexts"].append(msg)
                relations[e2][e1]["count"] += 1
                relations[e2][e1]["contexts"].append(msg)

    edges = []
    for e1 in relations:
        for e2 in relations[e1]:
            if e1 < e2:  # the map is symmetric; emit each pair only once
                context_text = " ".join(relations[e1][e2]["contexts"])
                label = classify_relation_bert(e1, e2, context_text)
                edges.append((e1, e2, relations[e1][e2]["count"], label))
    return edges


# -------------------------------
# Graph rendering
# -------------------------------
def draw_graph(entities, relations):
    """Render the relation graph with pyvis and return the generated HTML."""
    g = Network(height="600px", width="100%", notebook=False)
    g.barnes_hut()
    for ent in entities:
        g.add_node(ent, label=ent)
    for e1, e2, weight, label in relations:
        # Edge tooltip "<label>(互动<weight>次)" = "<label> (interacted <weight> times)"
        g.add_edge(e1, e2, value=weight, title=f"{label}(互动{weight}次)", label=label)
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
    tmp_file.close()  # close the handle so pyvis can reopen the path (required on Windows)
    # save_graph only writes the file; show() would also try to open a browser.
    g.save_graph(tmp_file.name)
    with open(tmp_file.name, "r", encoding="utf-8") as f:
        return f.read()


# -------------------------------
# Main pipeline
# -------------------------------
def analyze_chat(file):
    """Full pipeline: parse the upload, find people, label relations, draw the graph."""
    if file is None:
        return "请上传聊天文件", "", ""  # "please upload a chat file"

    try:
        content = parse_input_file(file)
    except Exception as e:
        return f"读取文件失败: {e}", "", ""  # "failed to read file: <error>"

    # Rebuild the transcript as "sender: message" lines for NER.
    text = "\n".join([entry["sender"] + ": " + entry["message"] for entry in content])
    entities = extract_entities(text)
    if not entities:
        return "未识别到任何人物实体", "", ""  # "no person entities recognized"

    relations = extract_relations(content, entities)
    if not relations:
        return "未发现人物之间的关系", "", ""  # "no relations found between people"

    graph_html = draw_graph(entities, relations)

    return str(entities), str(relations), graph_html
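

# -------------------------------
# Local smoke test (illustrative sketch; the sample path "chat.json" is an
# assumption, not part of the original pipeline)
# -------------------------------
if __name__ == "__main__":
    # Open in binary mode: parse_input_file expects a binary file object
    # with a .name attribute, mirroring what an upload widget provides.
    with open("chat.json", "rb") as f:
        ents, rels, graph_html = analyze_chat(f)
    print(ents)
    print(rels)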