Spaces:
Sleeping
Sleeping
import torch | |
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel | |
import gradio as gr | |
import re | |
import os | |
import json | |
import chardet | |
from sklearn.metrics import precision_score, recall_score, f1_score | |
import time | |
# ======================== Model loading ========================
# Checkpoint used for Chinese NER (the name indicates a RoBERTa model
# fine-tuned on CLUENER2020).
NER_MODEL_NAME = "uer/roberta-base-finetuned-cluener2020-chinese"

bert_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
bert_ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)

# Token-classification pipeline shared by every BERT-based NER call in this
# file; results are consumed through their entity_group/word/start/end keys.
bert_ner_pipeline = pipeline(
    task="ner",
    tokenizer=bert_tokenizer,
    model=bert_ner_model,
    aggregation_strategy="first",
)
# Translation table from the NER checkpoint's label names to the short tags
# used elsewhere in this file. Callers look labels up with
# ``LABEL_MAPPING.get(label, label)``, so unlisted labels pass through as-is.
LABEL_MAPPING = dict(
    address="LOC",
    company="ORG",
    name="PER",
    organization="ORG",
    position="TITLE",
)
# Optional ChatGLM backend. Loading is best-effort: on any failure the error
# is printed and ``use_chatglm`` stays False, which the rest of the file uses
# to decide whether the model may be called.
chatglm_model = None
chatglm_tokenizer = None
use_chatglm = False

try:
    chatglm_model_name = "THUDM/chatglm-6b-int4"
    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the model repository; keep the model name pinned to a trusted source.
    chatglm_tokenizer = AutoTokenizer.from_pretrained(
        chatglm_model_name,
        trust_remote_code=True,
    )
    chatglm_model = AutoModel.from_pretrained(
        chatglm_model_name,
        trust_remote_code=True,
        device_map="cpu",
        torch_dtype=torch.float32,
    ).eval()
    use_chatglm = True
    print("✅ 4-bit量化版ChatGLM加载成功")
except Exception as e:
    print(f"❌ ChatGLM加载失败: {e}")
# ======================== Knowledge graph ========================
# Module-level, in-memory graph shared by all requests:
#   "entities"  -> set of (text, type) tuples
#   "relations" -> set of (head, tail, relation) tuples
knowledge_graph = {"entities": set(), "relations": set()}


def update_knowledge_graph(entities, relations):
    """Merge extracted entities and relations into ``knowledge_graph``.

    Args:
        entities: Iterable of candidate entities. Only dicts that contain
            both ``"text"`` and ``"type"`` are stored; anything else is
            ignored.
        relations: Iterable of candidate relations. Only dicts that contain
            ``"head"``, ``"tail"`` and ``"relation"`` are stored.

    Relations are de-duplicated ignoring direction: a triple is skipped when
    the same set of {head, tail, relation} values is already present, either
    from an earlier call or from earlier in the same batch. Items whose
    values are unhashable (for example a list coming from malformed model
    output) are skipped instead of raising.
    """
    for ent in entities:
        if isinstance(ent, dict) and "text" in ent and "type" in ent:
            try:
                knowledge_graph["entities"].add((ent["text"], ent["type"]))
            except TypeError:
                # Unhashable text/type cannot be stored in a set.
                continue

    # Direction-insensitive keys of everything already stored. The set is
    # kept up to date while iterating so duplicates inside one batch are
    # rejected the same way as duplicates across calls.
    seen = {frozenset(rel) for rel in knowledge_graph["relations"]}
    required = ("head", "tail", "relation")
    for rel in relations:
        if not (isinstance(rel, dict) and all(k in rel for k in required)):
            continue
        triple = (rel["head"], rel["tail"], rel["relation"])
        try:
            key = frozenset(triple)
        except TypeError:
            continue
        if key in seen:
            continue
        seen.add(key)
        knowledge_graph["relations"].add(triple)
def visualize_kg_text():
    """Render the module-level ``knowledge_graph`` as a plain-text listing.

    Returns:
        A newline-joined string: an entity header followed by one
        ``text (type)`` line per entity, a blank line, then a relation
        header followed by one ``head --[relation]-> tail`` line per
        relation. Line order follows set iteration order.
    """
    lines = ["📌 实体:"]
    for ent in knowledge_graph["entities"]:
        lines.append(f"{ent[0]} ({ent[1]})")
    lines.append("")
    lines.append("📎 关系:")
    for head, tail, relation in knowledge_graph["relations"]:
        lines.append(f"{head} --[{relation}]-> {tail}")
    return "\n".join(lines)
# ======================== Named entity recognition (NER) ========================
def merge_adjacent_entities(entities):
    """Fuse consecutive entity fragments into single entities.

    An entity is merged into the previously kept one when it has the same
    ``type``, starts exactly where the previous one ends, and its text is
    not already a substring of the previous text.

    Args:
        entities: List of dicts with ``text``, ``type``, ``start`` and
            ``end`` keys, in document order.

    Returns:
        A new list; merged entries are new dicts, untouched entries are the
        original dict objects.
    """
    result = []
    for current in entities:
        previous = result[-1] if result else None
        can_merge = (
            previous is not None
            and current["type"] == previous["type"]
            and current["start"] == previous["end"]
            and current["text"] not in previous["text"]
        )
        if not can_merge:
            result.append(current)
            continue
        # Replace the last kept entity with the concatenated span.
        result[-1] = {
            "text": previous["text"] + current["text"],
            "type": previous["type"],
            "start": previous["start"],
            "end": current["end"],
        }
    return result
def _parse_chatglm_entities(response):
    """Extract the validated entity list from a raw ChatGLM reply.

    Raises:
        ValueError: if no JSON list is present or it is not valid JSON.
        TypeError: if ``response`` is not a string.
    """
    match = re.search(r'\[.*\]', response, re.DOTALL)
    if match is None:
        raise ValueError("no JSON list found in model response")
    parsed = json.loads(match.group())
    required = ("text", "type", "start", "end")
    # Keep only well-formed dict items; strings or numbers in the list are
    # dropped rather than being probed with ``in``.
    return [
        ent for ent in parsed
        if isinstance(ent, dict) and all(k in ent for k in required)
    ]


def ner(text, model_type="bert"):
    """Run named entity recognition on ``text``.

    Args:
        text: Input text. Empty, whitespace-only or non-string input yields
            no entities without calling any model.
        model_type: ``"chatglm"`` uses the ChatGLM model when it was loaded
            successfully; any other value (or an unavailable ChatGLM) uses
            the BERT pipeline.

    Returns:
        ``(entities, elapsed_seconds)`` where ``entities`` is a list of
        dicts with ``text``, ``type``, ``start`` and ``end`` keys. On a
        ChatGLM call or parse failure the list is empty and the error is
        printed.
    """
    start_time = time.time()

    if not isinstance(text, str) or not text.strip():
        return [], time.time() - start_time

    if model_type == "chatglm" and use_chatglm:
        prompt = (
            "请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:\n"
            '示例:[{"text": "北京", "type": "LOC", "start": 0, "end": 2}]\n'
            f"文本:{text}"
        )
        try:
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
            # chat() may return (reply, history); only the reply is needed.
            if isinstance(response, tuple):
                response = response[0]
        except Exception as e:
            print(f"ChatGLM 调用失败:{e}")
            return [], time.time() - start_time

        try:
            valid_entities = _parse_chatglm_entities(response)
        except (ValueError, TypeError) as e:
            print(f"JSON 解析失败: {e}")
            return [], time.time() - start_time
        return valid_entities, time.time() - start_time

    # Fine-tuned BERT Chinese NER model.
    raw_results = bert_ner_pipeline(text)
    entities = []
    for r in raw_results:
        mapped_type = LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
        entities.append({
            # Strip the spaces that appear in the pipeline's ``word`` field.
            "text": r['word'].replace(' ', ''),
            "start": r['start'],
            "end": r['end'],
            "type": mapped_type,
        })

    # Join adjacent fragments of the same type into one entity.
    entities = merge_adjacent_entities(entities)
    return entities, time.time() - start_time
# ======================== Relation extraction (RE) ========================
def re_extract(entities, text):
    """Ask ChatGLM for relations between the recognised entities.

    Only entities typed ``PER``, ``LOC`` or ``ORG`` take part; the same
    filtered list is used both for the "at least two entities" check and
    for the prompt sent to the model.

    Args:
        entities: List of entity dicts (``text`` and ``type`` keys are
            read). Items without those keys are ignored.
        text: The source text the entities were found in.

    Returns:
        A list of ``{"head", "tail", "relation"}`` dicts whose relation is
        one of 属于 / 位于 / 参与 / 其他. Returns ``[]`` when fewer than two
        eligible entities exist, when ChatGLM is unavailable, or when the
        call or parsing fails (the error is printed).
    """
    valid_entity_types = {"PER", "LOC", "ORG"}
    candidates = [
        e for e in entities
        if isinstance(e, dict)
        and "text" in e
        and isinstance(e.get("type"), str)
        and e["type"] in valid_entity_types
    ]
    if len(candidates) < 2:
        return []

    try:
        if use_chatglm:
            names = [e["text"] for e in candidates]
            prompt = (
                "分析文本中的实体关系,返回JSON列表:\n"
                f"文本:{text}\n"
                f"实体列表:{names}\n"
                "要求:\n"
                "1. 仅返回存在明确关系的实体对\n"
                "2. 关系类型使用:属于、位于、参与、其他\n"
                '3. 格式示例:[{"head": "北京", "tail": "中国", "relation": "位于"}]'
            )
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
            # chat() may return (reply, history); only the reply is needed.
            if isinstance(response, tuple):
                response = response[0]

            # Pull the JSON list out of the free-form reply.
            try:
                match = re.search(r'\[.*\]', response, re.DOTALL)
                if match is None:
                    raise ValueError("no JSON list found in model response")
                parsed = json.loads(match.group())
                valid_types = {"属于", "位于", "参与", "其他"}
                required = ("head", "tail", "relation")
                return [
                    rel for rel in parsed
                    if isinstance(rel, dict)
                    and all(k in rel for k in required)
                    and rel["relation"] in valid_types
                ]
            except (ValueError, TypeError) as e:
                print(f"关系解析失败: {e}")
    except Exception as e:
        print(f"关系抽取失败: {e}")

    # By default no relations are produced.
    return []
# ======================== Text analysis main flow ========================
def process_text(text, model_type="bert"):
    """Run NER and relation extraction on ``text`` and format the results.

    The extracted entities and relations are also merged into the
    module-level knowledge graph.

    Returns:
        A 4-tuple of strings: entity listing, relation listing, the rendered
        knowledge graph, and the NER duration formatted in seconds.
    """
    entities, duration = ner(text, model_type)
    relations = re_extract(entities, text)
    update_knowledge_graph(entities, relations)

    entity_lines = [
        f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities
    ]
    relation_lines = [
        f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations
    ]
    ent_text = "\n".join(entity_lines)
    rel_text = "\n".join(relation_lines)
    kg_text = visualize_kg_text()
    return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"
def process_file(file, model_type="bert"):
    """Read an uploaded text file, decode it and run ``process_text`` on it.

    Args:
        file: Either a filesystem path string or an object exposing the path
            as ``.name`` (as upload widgets commonly provide).
        model_type: Forwarded to ``process_text``.

    Returns:
        The 4-tuple produced by ``process_text``; on failure, a 4-tuple
        whose first element is an error message and whose remaining
        elements are empty strings.
    """
    max_bytes = 5 * 1024 * 1024
    try:
        path = file if isinstance(file, str) else file.name

        # Check the size on disk first so oversized uploads are never
        # loaded into memory.
        if os.path.getsize(path) > max_bytes:
            return "❌ 文件太大", "", "", ""

        with open(path, 'rb') as f:
            content = f.read()

        # Try the detected encoding first, then common Chinese encodings.
        # LookupError covers codec names chardet reports that Python does
        # not know.
        detected = chardet.detect(content)['encoding'] or 'utf-8'
        for enc in (detected, 'gb18030', 'utf-16', 'big5'):
            try:
                text = content.decode(enc)
                break
            except (UnicodeError, LookupError):
                continue
        else:
            return "❌ 编码解析失败", "", "", ""

        return process_text(text, model_type)
    except Exception as e:
        # Boundary handler: the UI expects four strings, never an exception.
        return f"❌ 文件处理错误: {str(e)}", "", "", ""
# ======================== Model evaluation and auto-annotation ========================
def convert_telegram_json_to_eval_format(path):
    """Load a JSON file and normalise it to a list of ``{"text", "entities"}``.

    Three input shapes are recognised:

    * a dict with a ``"text"`` key: one document; each entity's text is
      sliced from the document using its ``start``/``end`` offsets, and its
      ``type``, ``start`` and ``end`` are carried over so the result can be
      used as gold data for evaluation;
    * a list: returned unchanged;
    * a dict with a ``"messages"`` key: one entry per message that has
      string text or a list of text fragments, with an empty entity list.

    Anything else yields ``[]``.

    Args:
        path: Path of a UTF-8 encoded JSON file.

    Raises:
        OSError, ValueError: if the file cannot be read or parsed.
        KeyError: if a single-document entity lacks ``start`` or ``end``.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, list):
        return data
    if not isinstance(data, dict):
        return []

    if "text" in data:
        text = data["text"]
        entities = []
        for ent in data.get("entities", []):
            converted = {"text": text[ent["start"]:ent["end"]]}
            # Preserve the label and offsets when present; without "type"
            # the entity is unusable as gold data.
            converted.update(
                {k: ent[k] for k in ("type", "start", "end") if k in ent}
            )
            entities.append(converted)
        return [{"text": text, "entities": entities}]

    if "messages" in data:
        result = []
        for message in data.get("messages", []):
            if not isinstance(message, dict):
                continue
            raw = message.get("text")
            if isinstance(raw, str):
                result.append({"text": raw, "entities": []})
            elif isinstance(raw, list):
                # Fragments are either plain strings or dicts carrying the
                # fragment under "text".
                joined = ''.join(
                    str(part.get("text", "")) if isinstance(part, dict) else str(part)
                    for part in raw
                )
                result.append({"text": joined, "entities": []})
        return result

    return []
def _entity_matches(gold, pred, tolerance):
    """Return True when ``pred`` names the same entity as ``gold``.

    Text and type must be equal. Start/end offsets are compared only when
    both sides provide integers, and may differ by at most ``tolerance``.
    """
    if not isinstance(pred, dict):
        return False
    if pred.get("text") != gold["text"] or pred.get("type") != gold["type"]:
        return False
    for key in ("start", "end"):
        g_pos, p_pos = gold[key], pred.get(key)
        if isinstance(g_pos, int) and isinstance(p_pos, int):
            if abs(p_pos - g_pos) > tolerance:
                return False
    return True


def evaluate_ner_model(data, model_type):
    """Score the selected NER model against annotated data.

    Each gold entity can be matched by at most one prediction. Every
    decision contributes one aligned pair to ``y_true``/``y_pred``:
    matched gold -> (1, 1), unmatched prediction -> (0, 1), missed gold ->
    (1, 0). The two lists therefore always have equal length and the
    binary precision/recall/F1 computed from them are the entity-level
    scores.

    Args:
        data: Iterable of ``{"text": str, "entities": [...]}`` items. Gold
            entities need ``text`` and ``type``; ``start``/``end`` are
            optional.
        model_type: Forwarded to ``ner``.

    Returns:
        A three-line string with precision, recall and F1, or a warning
        string when the data contains no usable gold entity.
    """
    POS_TOLERANCE = 1  # allowed offset error, in characters
    y_true, y_pred = [], []
    gold_total = 0

    for item in data:
        text = item["text"]
        gold_entities = []
        for e in item.get("entities", []):
            if "text" in e and "type" in e:
                gold_entities.append({
                    "text": e["text"],
                    # Normalise labels to the tags the predictions use.
                    "type": LABEL_MAPPING.get(e["type"], e["type"]),
                    # None means "position unknown": offsets are not checked.
                    "start": e.get("start"),
                    "end": e.get("end"),
                })
        gold_total += len(gold_entities)

        pred_entities, _ = ner(text, model_type)

        gold_used = [False] * len(gold_entities)
        for pred in pred_entities:
            hit = next(
                (
                    i for i, gold in enumerate(gold_entities)
                    if not gold_used[i] and _entity_matches(gold, pred, POS_TOLERANCE)
                ),
                None,
            )
            if hit is None:
                # False positive.
                y_true.append(0)
                y_pred.append(1)
            else:
                # True positive.
                gold_used[hit] = True
                y_true.append(1)
                y_pred.append(1)

        # Gold entities nobody predicted are false negatives.
        missed = gold_used.count(False)
        y_true.extend([1] * missed)
        y_pred.extend([0] * missed)

    if not gold_total:
        return "⚠️ 无有效标注数据"

    return (f"Precision: {precision_score(y_true, y_pred, zero_division=0):.2f}\n"
            f"Recall: {recall_score(y_true, y_pred, zero_division=0):.2f}\n"
            f"F1: {f1_score(y_true, y_pred, zero_division=0):.2f}")
def auto_annotate(file, model_type):
    """Annotate every record of an uploaded JSON file with predicted entities.

    Args:
        file: Upload object whose ``.name`` is the path of the JSON file.
        model_type: Forwarded to ``ner``.

    Returns:
        The records, each with its ``"entities"`` replaced by the model's
        predictions, serialised as indented JSON with non-ASCII characters
        kept as-is.
    """
    records = convert_telegram_json_to_eval_format(file.name)
    for record in records:
        predicted, _elapsed = ner(record["text"], model_type)
        record["entities"] = predicted
    return json.dumps(records, ensure_ascii=False, indent=2)
def save_json(json_text):
    """Write ``json_text`` to a timestamped file in the working directory.

    Returns:
        The name of the file written, ``auto_labeled_<unix seconds>.json``.
    """
    timestamp = int(time.time())
    fname = f"auto_labeled_{timestamp}.json"
    with open(fname, "w", encoding="utf-8") as out:
        out.write(json_text)
    return fname
# ======================== Gradio interface ========================
with gr.Blocks(css="""
.kg-graph {height: 500px; overflow-y: auto;}
.warning {color: #ff6b6b;}
""") as demo:
    gr.Markdown("# 🤖 聊天记录实体关系识别系统")

    # Tab 1: analyse text typed into a textbox.
    with gr.Tab("📄 文本分析"):
        input_text = gr.Textbox(lines=6, label="输入文本")
        model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
        btn = gr.Button("开始分析")
        out1 = gr.Textbox(label="识别实体")
        out2 = gr.Textbox(label="识别关系")
        out3 = gr.Textbox(label="知识图谱")
        out4 = gr.Textbox(label="耗时")
        btn.click(
            fn=process_text,
            inputs=[input_text, model_type],
            outputs=[out1, out2, out3, out4],
        )

    # Tab 2: analyse an uploaded file. The model choice is read from the
    # radio button defined in the text-analysis tab.
    with gr.Tab("🗂 文件分析"):
        file_input = gr.File(file_types=[".txt", ".json"])
        file_btn = gr.Button("上传并分析")
        fout1 = gr.Textbox()
        fout2 = gr.Textbox()
        fout3 = gr.Textbox()
        fout4 = gr.Textbox()
        file_btn.click(
            fn=process_file,
            inputs=[file_input, model_type],
            outputs=[fout1, fout2, fout3, fout4],
        )

    # Tab 3: evaluate a model against an annotated JSON file.
    with gr.Tab("📊 模型评估"):
        eval_file = gr.File(label="上传标注 JSON")
        eval_model = gr.Radio(["bert", "chatglm"], value="bert")
        eval_btn = gr.Button("开始评估")
        eval_output = gr.Textbox(label="评估结果", lines=5)
        eval_btn.click(
            lambda uploaded, chosen: evaluate_ner_model(
                convert_telegram_json_to_eval_format(uploaded.name), chosen
            ),
            [eval_file, eval_model],
            eval_output,
        )

    # Tab 4: auto-annotate a raw export and offer the result for download.
    with gr.Tab("✏️ 自动标注"):
        raw_file = gr.File(label="上传 Telegram 原始 JSON")
        auto_model = gr.Radio(["bert", "chatglm"], value="bert")
        auto_btn = gr.Button("自动标注")
        marked_texts = gr.Textbox(label="标注结果", lines=20)
        download_btn = gr.Button("💾 下载标注文件")
        auto_btn.click(
            fn=auto_annotate,
            inputs=[raw_file, auto_model],
            outputs=marked_texts,
        )
        download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())

# Listen on all interfaces, port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)