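"""聊天记录实体关系识别系统 (chat-log entity & relation recognition):
a Gradio app that runs Chinese NER with a fine-tuned BERT model (optionally
ChatGLM-6B), LLM-based relation extraction, a simple in-memory knowledge
graph, plus NER evaluation and auto-annotation utilities."""
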
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
import gradio as gr
import re
import json
import chardet
from sklearn.metrics import precision_score, recall_score, f1_score
import time

# ======================== Model loading ========================
NER_MODEL_NAME = "uer/roberta-base-finetuned-cluener2020-chinese"
bert_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
bert_ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
bert_ner_pipeline = pipeline(
    "ner",
    model=bert_ner_model,
    tokenizer=bert_tokenizer,
    aggregation_strategy="first"
)
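# Illustrative output shape (not from a real run): the aggregated pipeline
# yields dicts such as
#   {'entity_group': 'address', 'word': '北 京', 'start': 0, 'end': 2, 'score': 0.99}
# 'word' can contain spaces between Chinese characters; ner() strips them below.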

LABEL_MAPPING = {
    "address": "LOC",
    "company": "ORG",
    "name": "PER",
    "organization": "ORG",
    "position": "TITLE"
}
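# Unmapped CLUENER2020 labels (e.g. government, game, movie, book, scene)
# fall through unchanged via LABEL_MAPPING.get(...) in ner() below.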

chatglm_model, chatglm_tokenizer = None, None
use_chatglm = False
try:
    chatglm_model_name = "THUDM/chatglm-6b-int4"
    chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
    chatglm_model = AutoModel.from_pretrained(
        chatglm_model_name,
        trust_remote_code=True,
        device_map="cpu",
        torch_dtype=torch.float32
    ).eval()
    use_chatglm = True
    print("✅ 4-bit量化版ChatGLM加载成功")
except Exception as e:
    print(f"❌ ChatGLM加载失败: {e}")

# ======================== Knowledge graph structure ========================
knowledge_graph = {"entities": set(), "relations": set()}

def update_knowledge_graph(entities, relations):
    for e in entities:
        if isinstance(e, dict) and 'text' in e and 'type' in e:
            knowledge_graph["entities"].add((e['text'], e['type']))
    # Relations are directed (head, tail, relation) tuples; adding to a set is
    # idempotent, and tuples (unlike frozensets) preserve head->tail direction.
    for r in relations:
        if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
            knowledge_graph["relations"].add((r['head'], r['tail'], r['relation']))

def visualize_kg_text():
    nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
    edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
    return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
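# Minimal usage sketch with made-up data:
#   update_knowledge_graph(
#       [{"text": "北京", "type": "LOC"}],
#       [{"head": "北京", "tail": "中国", "relation": "位于"}])
#   print(visualize_kg_text())  # 📌 实体: 北京 (LOC) ... 📎 关系: 北京 --[位于]-> 中国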

# ======================== Named entity recognition (NER) ========================
def merge_adjacent_entities(entities):
    merged = []
    for entity in entities:
        if not merged:
            merged.append(entity)
            continue

        last = merged[-1]
        # Merge same-type entities whose spans are contiguous; the substring
        # check avoids gluing on a fragment the merged text already contains.
        if (entity["type"] == last["type"] and
                entity["start"] == last["end"] and
                entity["text"] not in last["text"]):
            merged[-1] = {
                "text": last["text"] + entity["text"],
                "type": last["type"],
                "start": last["start"],
                "end": entity["end"]
            }
        else:
            merged.append(entity)
    return merged
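# Worked example (illustrative): two contiguous LOC fragments are merged:
#   [{"text": "北", "type": "LOC", "start": 0, "end": 1},
#    {"text": "京", "type": "LOC", "start": 1, "end": 2}]
#   -> [{"text": "北京", "type": "LOC", "start": 0, "end": 2}]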


def ner(text, model_type="bert"):
    start_time = time.time()
    if model_type == "chatglm" and use_chatglm:
        try:
            prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
文本:{text}"""
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
            if isinstance(response, tuple):
                response = response[0]

            # Robust JSON extraction: pull the first [...] block out of the reply
            try:
                json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
                entities = json.loads(json_str)
                # Keep only entities that carry all required fields
                valid_entities = []
                for ent in entities:
                    if all(k in ent for k in ("text", "type", "start", "end")):
                        valid_entities.append(ent)
                return valid_entities, time.time() - start_time
            except Exception as e:
                print(f"JSON 解析失败: {e}")
                return [], time.time() - start_time
        except Exception as e:
            print(f"ChatGLM 调用失败:{e}")
            return [], time.time() - start_time

    # Default path: the fine-tuned Chinese BERT NER model
    raw_results = bert_ner_pipeline(text)
    entities = []
    for r in raw_results:
        mapped_type = LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
        entities.append({
            "text": r['word'].replace(' ', ''),
            "start": r['start'],
            "end": r['end'],
            "type": mapped_type
        })

    # Merge adjacent fragments belonging to the same entity
    entities = merge_adjacent_entities(entities)
    return entities, time.time() - start_time
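# Usage sketch (hypothetical text and output):
#   ents, secs = ner("张三在北京工作")
#   # ents ≈ [{"text": "张三", "type": "PER", "start": 0, "end": 2}, ...]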


# ======================== Relation extraction (RE) ========================
def re_extract(entities, text):
    # Keep only PER/LOC/ORG entities as relation candidates
    valid_entity_types = {"PER", "LOC", "ORG"}
    filtered_entities = [e for e in entities if e["type"] in valid_entity_types]

    if len(filtered_entities) < 2:
        return []

    relations = []
    try:
        prompt = f"""分析文本中的实体关系,返回JSON列表:
文本:{text}
实体列表:{[e['text'] for e in filtered_entities]}
要求:
1. 仅返回存在明确关系的实体对
2. 关系类型使用:属于、位于、参与、其他
3. 格式示例:[{{"head": "北京", "tail": "中国", "relation": "位于"}}]"""

        if use_chatglm:
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
            if isinstance(response, tuple):
                response = response[0]

            # Extract the JSON payload
            try:
                json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
                relations = json.loads(json_str)
                # Keep only well-formed relations with an allowed relation type
                valid_relations = []
                valid_types = {"属于", "位于", "参与", "其他"}
                for rel in relations:
                    if all(k in rel for k in ("head", "tail", "relation")) and rel["relation"] in valid_types:
                        valid_relations.append(rel)
                return valid_relations
            except Exception as e:
                print(f"关系解析失败: {e}")
    except Exception as e:
        print(f"关系抽取失败: {e}")

    # No usable LLM output: return no relations rather than guessing
    return []


# ======================== Main text-analysis pipeline ========================
def process_text(text, model_type="bert"):
    entities, duration = ner(text, model_type)
    relations = re_extract(entities, text)
    update_knowledge_graph(entities, relations)

    ent_text = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
    rel_text = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
    kg_text = visualize_kg_text()
    return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"


def process_file(file, model_type="bert"):
    try:
        with open(file.name, 'rb') as f:
            content = f.read()

        if len(content) > 5 * 1024 * 1024:  # reject files larger than 5 MB
            return "❌ 文件太大", "", "", ""

        # Detect the encoding (fall back to UTF-8 if detection fails)
        try:
            encoding = chardet.detect(content)['encoding'] or 'utf-8'
            text = content.decode(encoding)
        except UnicodeDecodeError:
            # Try common Chinese encodings
            for enc in ['gb18030', 'utf-16', 'big5']:
                try:
                    text = content.decode(enc)
                    break
                except UnicodeDecodeError:
                    continue
            else:
                return "❌ 编码解析失败", "", "", ""

        return process_text(text, model_type)
    except Exception as e:
        return f"❌ 文件处理错误: {str(e)}", "", "", ""


# ======================== Model evaluation & auto-annotation ========================
def convert_telegram_json_to_eval_format(path):
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    if isinstance(data, dict) and "text" in data:
        return [{"text": data["text"], "entities": [
            {"text": data["text"][e["start"]:e["end"]],
             "type": e.get("type", ""),
             "start": e["start"], "end": e["end"]}
            for e in data.get("entities", [])
        ]}]
    elif isinstance(data, list):
        return data
    elif isinstance(data, dict) and "messages" in data:
        result = []
        for m in data.get("messages", []):
            if isinstance(m.get("text"), str):
                result.append({"text": m["text"], "entities": []})
            elif isinstance(m.get("text"), list):
                txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
                result.append({"text": txt, "entities": []})
        return result
    return []
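# Accepted input shapes (one per branch above):
#   {"text": "...", "entities": [...]}          -> single annotated document
#   [{"text": "...", "entities": [...]}, ...]   -> already in evaluation format
#   {"messages": [{"text": ...}, ...]}          -> raw Telegram export (no labels)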


def evaluate_ner_model(data, model_type):
    y_true, y_pred = [], []
    POS_TOLERANCE = 1  # allowed start/end offset (in characters) when matching spans

    def matches(p, g):
        return (p["text"] == g["text"] and
                p["type"] == g["type"] and
                abs(p["start"] - g["start"]) <= POS_TOLERANCE and
                abs(p["end"] - g["end"]) <= POS_TOLERANCE)

    for item in data:
        text = item["text"]
        gold_entities = []
        for e in item.get("entities", []):
            if "text" in e and "type" in e:
                # Normalize gold labels into the same scheme as the predictions
                norm_type = LABEL_MAPPING.get(e["type"], e["type"])
                gold_entities.append({
                    "text": e["text"],
                    "type": norm_type,
                    "start": e.get("start", -1),
                    "end": e.get("end", -1)
                })

        pred_entities, _ = ner(text, model_type)

        # Greedy one-to-one matching: each gold entity contributes one sample,
        # positive if some unused prediction matches it within the tolerance
        matched_preds = set()
        for g in gold_entities:
            hit = 0
            for i, p in enumerate(pred_entities):
                if i not in matched_preds and matches(p, g):
                    matched_preds.add(i)
                    hit = 1
                    break
            y_true.append(1)
            y_pred.append(hit)

        # Every leftover prediction is a false positive
        for i in range(len(pred_entities)):
            if i not in matched_preds:
                y_true.append(0)
                y_pred.append(1)

    if not y_true:
        return "⚠️ 无有效标注数据"

    return (f"Precision: {precision_score(y_true, y_pred, zero_division=0):.2f}\n"
            f"Recall: {recall_score(y_true, y_pred, zero_division=0):.2f}\n"
            f"F1: {f1_score(y_true, y_pred, zero_division=0):.2f}")

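# Label encoding above: (1, 1) = true positive, (1, 0) = missed gold entity,
# (0, 1) = spurious prediction, so sklearn's precision/recall/F1 reduce to the
# usual span-level NER metrics under the position tolerance.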
def auto_annotate(file, model_type):
    data = convert_telegram_json_to_eval_format(file.name)
    for item in data:
        ents, _ = ner(item["text"], model_type)
        item["entities"] = ents
    return json.dumps(data, ensure_ascii=False, indent=2)


def save_json(json_text):
    fname = f"auto_labeled_{int(time.time())}.json"
    with open(fname, "w", encoding="utf-8") as f:
        f.write(json_text)
    return fname


# ======================== Gradio UI ========================
with gr.Blocks(css="""
    .kg-graph {height: 500px; overflow-y: auto;}
    .warning {color: #ff6b6b;}
""") as demo:
    gr.Markdown("# 🤖 聊天记录实体关系识别系统")

    with gr.Tab("📄 文本分析"):
        input_text = gr.Textbox(lines=6, label="输入文本")
        model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
        btn = gr.Button("开始分析")
        out1 = gr.Textbox(label="识别实体")
        out2 = gr.Textbox(label="识别关系")
        out3 = gr.Textbox(label="知识图谱")
        out4 = gr.Textbox(label="耗时")
        btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])

    with gr.Tab("🗂 文件分析"):
        file_input = gr.File(file_types=[".txt", ".json"])
        file_btn = gr.Button("上传并分析")
        fout1, fout2, fout3, fout4 = (gr.Textbox(label="识别实体"), gr.Textbox(label="识别关系"),
                                      gr.Textbox(label="知识图谱"), gr.Textbox(label="耗时"))
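        # Reuses the 选择模型 radio defined in the 文本分析 tab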
        file_btn.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])

    with gr.Tab("📊 模型评估"):
        eval_file = gr.File(label="上传标注 JSON")
        eval_model = gr.Radio(["bert", "chatglm"], value="bert")
        eval_btn = gr.Button("开始评估")
        eval_output = gr.Textbox(label="评估结果", lines=5)
        eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m),
                       [eval_file, eval_model], eval_output)

    with gr.Tab("✏️ 自动标注"):
        raw_file = gr.File(label="上传 Telegram 原始 JSON")
        auto_model = gr.Radio(["bert", "chatglm"], value="bert")
        auto_btn = gr.Button("自动标注")
        marked_texts = gr.Textbox(label="标注结果", lines=20)
        download_btn = gr.Button("💾 下载标注文件")
        auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
        download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())

demo.launch(server_name="0.0.0.0", server_port=7860)