import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
import gradio as gr
import re
import os
import json
import chardet
from sklearn.metrics import precision_score, recall_score, f1_score
import time
# ======================== Model loading ========================
NER_MODEL_NAME = "uer/roberta-base-finetuned-cluener2020-chinese"
bert_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
bert_ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
bert_ner_pipeline = pipeline(
"ner",
model=bert_ner_model,
tokenizer=bert_tokenizer,
aggregation_strategy="first"
)
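# aggregation_strategy="first" groups wordpiece tokens back into whole words
# and labels each group with its first token's tag, so downstream code sees
# whole entities rather than sub-token fragments.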
LABEL_MAPPING = {
"address": "LOC",
"company": "ORG",
"name": "PER",
"organization": "ORG",
"position": "TITLE"
}
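# Illustrative sketch of the mapping (label names as assumed by LABEL_MAPPING;
# example values not verified against this exact checkpoint):
#   bert_ner_pipeline("马云在杭州创办了阿里巴巴")
#   -> [{"entity_group": "name", "word": "马 云", "start": 0, "end": 2}, ...]
# which ner() below turns into {"text": "马云", "type": "PER", "start": 0, "end": 2}.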
chatglm_model, chatglm_tokenizer = None, None
use_chatglm = False
try:
chatglm_model_name = "THUDM/chatglm-6b-int4"
chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
chatglm_model = AutoModel.from_pretrained(
chatglm_model_name,
trust_remote_code=True,
device_map="cpu",
torch_dtype=torch.float32
).eval()
use_chatglm = True
print("✅ 4-bit量化版ChatGLM加载成功")
except Exception as e:
print(f"❌ ChatGLM加载失败: {e}")
# ======================== Knowledge graph structures ========================
knowledge_graph = {"entities": set(), "relations": set()}
def update_knowledge_graph(entities, relations):
for e in entities:
if isinstance(e, dict) and 'text' in e and 'type' in e:
knowledge_graph["entities"].add((e['text'], e['type']))
    # Relation dedup: relations is a set of (head, tail, relation) tuples, so
    # set membership already removes exact duplicates while preserving
    # direction (a frozenset-based key would wrongly conflate A->B with B->A).
    for r in relations:
        if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
            knowledge_graph["relations"].add((r['head'], r['tail'], r['relation']))
def visualize_kg_text():
nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
# ======================== Named entity recognition (NER) ========================
def merge_adjacent_entities(entities):
merged = []
for entity in entities:
if not merged:
merged.append(entity)
continue
last = merged[-1]
        # Merge adjacent entities of the same type
if (entity["type"] == last["type"] and
entity["start"] == last["end"] and
entity["text"] not in last["text"]):
merged[-1] = {
"text": last["text"] + entity["text"],
"type": last["type"],
"start": last["start"],
"end": entity["end"]
}
else:
merged.append(entity)
return merged
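# A minimal sketch of the merge behaviour (hypothetical fragments): two
# same-type spans that touch end-to-start are fused into one entity.
#   [{"text": "阿里", "type": "ORG", "start": 5, "end": 7},
#    {"text": "巴巴", "type": "ORG", "start": 7, "end": 9}]
#   -> [{"text": "阿里巴巴", "type": "ORG", "start": 5, "end": 9}]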
def ner(text, model_type="bert"):
start_time = time.time()
if model_type == "chatglm" and use_chatglm:
try:
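            # Prompt (kept in Chinese for the Chinese model): identify all
            # entities in the text and return them strictly as a JSON list,
            # each with text/type/start/end fields, with a one-shot example.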
prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
文本:{text}"""
response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
if isinstance(response, tuple):
response = response[0]
            # Robust JSON parsing: grab the first bracketed list in the reply
try:
json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
entities = json.loads(json_str)
                # Keep only entities that have every required field
valid_entities = []
for ent in entities:
if all(k in ent for k in ("text", "type", "start", "end")):
valid_entities.append(ent)
return valid_entities, time.time() - start_time
except Exception as e:
print(f"JSON 解析失败: {e}")
return [], time.time() - start_time
except Exception as e:
print(f"ChatGLM 调用失败:{e}")
return [], time.time() - start_time
    # Default path: the fine-tuned Chinese BERT NER model
raw_results = bert_ner_pipeline(text)
entities = []
for r in raw_results:
mapped_type = LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
entities.append({
"text": r['word'].replace(' ', ''),
"start": r['start'],
"end": r['end'],
"type": mapped_type
})
    # Merge adjacent fragments that belong to one entity
entities = merge_adjacent_entities(entities)
return entities, time.time() - start_time
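# Usage sketch (entity values are illustrative, not actual model output):
#   ents, secs = ner("马云在杭州创办了阿里巴巴", model_type="bert")
#   # ents -> [{"text": "马云", "type": "PER", "start": 0, "end": 2}, ...]
#   # secs -> wall-clock seconds the call took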
# ======================== Relation extraction (RE) ========================
def re_extract(entities, text):
    # Restrict relation extraction to person/location/organization entities
valid_entity_types = {"PER", "LOC", "ORG"}
filtered_entities = [e for e in entities if e["type"] in valid_entity_types]
if len(filtered_entities) < 2:
return []
relations = []
try:
        # Prompt (kept in Chinese): analyze relations between the listed
        # entities; return only pairs with a clear relation as a JSON list,
        # using the types 属于/位于/参与/其他.
prompt = f"""分析文本中的实体关系,返回JSON列表:
文本:{text}
实体列表:{[e['text'] for e in filtered_entities]}
要求:
1. 仅返回存在明确关系的实体对
2. 关系类型使用:属于、位于、参与、其他
3. 格式示例:[{{"head": "北京", "tail": "中国", "relation": "位于"}}]"""
if use_chatglm:
response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
if isinstance(response, tuple):
response = response[0]
            # Extract the JSON payload
try:
json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
relations = json.loads(json_str)
                # Keep only relations with all fields and a known relation type
valid_relations = []
valid_types = {"属于", "位于", "参与", "其他"}
for rel in relations:
if all(k in rel for k in ("head", "tail", "relation")) and rel["relation"] in valid_types:
valid_relations.append(rel)
return valid_relations
            except Exception as e:
                print(f"Relation parsing failed: {e}")
    except Exception as e:
        print(f"Relation extraction failed: {e}")
    # Fall back to producing no relations
    return []
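# Expected result shape (illustrative values):
#   re_extract(ents, text) -> [{"head": "阿里巴巴", "tail": "杭州", "relation": "位于"}]
# The list is empty whenever ChatGLM is unavailable or parsing fails.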
# ======================== Main text-analysis pipeline ========================
def process_text(text, model_type="bert"):
entities, duration = ner(text, model_type)
relations = re_extract(entities, text)
update_knowledge_graph(entities, relations)
ent_text = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
rel_text = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
kg_text = visualize_kg_text()
return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"
def process_file(file, model_type="bert"):
try:
with open(file.name, 'rb') as f:
content = f.read()
if len(content) > 5 * 1024 * 1024:
return "❌ 文件太大", "", "", ""
        # Detect the encoding, defaulting to utf-8
        try:
            encoding = chardet.detect(content)['encoding'] or 'utf-8'
            text = content.decode(encoding)
        except (UnicodeDecodeError, LookupError):
            # Fall back to common Chinese encodings
            for enc in ['gb18030', 'utf-16', 'big5']:
                try:
                    text = content.decode(enc)
                    break
                except UnicodeDecodeError:
                    continue
            else:
                return "❌ Failed to decode the file with any known encoding", "", "", ""
return process_text(text, model_type)
except Exception as e:
return f"❌ 文件处理错误: {str(e)}", "", "", ""
# ======================== Model evaluation and auto-annotation ========================
def convert_telegram_json_to_eval_format(path):
with open(path, encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, dict) and "text" in data:
return [{"text": data["text"], "entities": [
{"text": data["text"][e["start"]:e["end"]]} for e in data.get("entities", [])
]}]
elif isinstance(data, list):
return data
elif isinstance(data, dict) and "messages" in data:
result = []
for m in data.get("messages", []):
if isinstance(m.get("text"), str):
result.append({"text": m["text"], "entities": []})
elif isinstance(m.get("text"), list):
txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
result.append({"text": txt, "entities": []})
return result
return []
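# The Telegram export shape this handles (abridged; field names as used above):
#   {"messages": [{"text": "你好"},
#                 {"text": [{"type": "bold", "text": "公告"}, ":今晚开会"]}]}
# normalizes to:
#   [{"text": "你好", "entities": []}, {"text": "公告:今晚开会", "entities": []}]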
def evaluate_ner_model(data, model_type):
y_true, y_pred = [], []
    POS_TOLERANCE = 1  # allowed start/end offset, in characters
for item in data:
text = item["text"]
gold_entities = []
for e in item.get("entities", []):
if "text" in e and "type" in e:
                # Normalize the label to the shared tag set
norm_type = LABEL_MAPPING.get(e["type"], e["type"])
gold_entities.append({
"text": e["text"],
"type": norm_type,
"start": e.get("start", -1),
"end": e.get("end", -1)
})
pred_entities, _ = ner(text, model_type)
        # Match each prediction against the gold entities, tolerating a small
        # start/end offset; each gold entity may be matched at most once
        matched_gold = set()
        unmatched_preds = 0
        for p in pred_entities:
            matched = False
            for gi, g in enumerate(gold_entities):
                if (gi not in matched_gold and
                        p["text"] == g["text"] and
                        p["type"] == g["type"] and
                        abs(p["start"] - g["start"]) <= POS_TOLERANCE and
                        abs(p["end"] - g["end"]) <= POS_TOLERANCE):
                    matched = True
                    matched_gold.add(gi)
                    break
            if not matched:
                unmatched_preds += 1
        # Gold entities yield hits (TP) or misses (FN); spurious predictions
        # yield false positives (FP). This keeps y_true and y_pred aligned.
        y_true.extend([1] * len(gold_entities) + [0] * unmatched_preds)
        y_pred.extend([1 if gi in matched_gold else 0
                       for gi in range(len(gold_entities))] + [1] * unmatched_preds)
if not y_true:
return "⚠️ 无有效标注数据"
return (f"Precision: {precision_score(y_true, y_pred, zero_division=0):.2f}\n"
f"Recall: {recall_score(y_true, y_pred, zero_division=0):.2f}\n"
f"F1: {f1_score(y_true, y_pred, zero_division=0):.2f}")
def auto_annotate(file, model_type):
data = convert_telegram_json_to_eval_format(file.name)
for item in data:
ents, _ = ner(item["text"], model_type)
item["entities"] = ents
return json.dumps(data, ensure_ascii=False, indent=2)
def save_json(json_text):
fname = f"auto_labeled_{int(time.time())}.json"
with open(fname, "w", encoding="utf-8") as f:
f.write(json_text)
return fname
# ======================== Gradio UI ========================
with gr.Blocks(css="""
.kg-graph {height: 500px; overflow-y: auto;}
.warning {color: #ff6b6b;}
""") as demo:
    gr.Markdown("# 🤖 Chat-Log Entity and Relation Recognition System")
    with gr.Tab("📄 Text Analysis"):
        input_text = gr.Textbox(lines=6, label="Input text")
        model_type = gr.Radio(["bert", "chatglm"], value="bert", label="Model")
        btn = gr.Button("Analyze")
        out1 = gr.Textbox(label="Entities")
        out2 = gr.Textbox(label="Relations")
        out3 = gr.Textbox(label="Knowledge graph")
        out4 = gr.Textbox(label="Elapsed time")
        btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
with gr.Tab("🗂 文件分析"):
file_input = gr.File(file_types=[".txt", ".json"])
file_btn = gr.Button("上传并分析")
fout1, fout2, fout3, fout4 = gr.Textbox(), gr.Textbox(), gr.Textbox(), gr.Textbox()
file_btn.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])
with gr.Tab("📊 模型评估"):
eval_file = gr.File(label="上传标注 JSON")
eval_model = gr.Radio(["bert", "chatglm"], value="bert")
eval_btn = gr.Button("开始评估")
eval_output = gr.Textbox(label="评估结果", lines=5)
eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m),
[eval_file, eval_model], eval_output)
with gr.Tab("✏️ 自动标注"):
raw_file = gr.File(label="上传 Telegram 原始 JSON")
auto_model = gr.Radio(["bert", "chatglm"], value="bert")
auto_btn = gr.Button("自动标注")
marked_texts = gr.Textbox(label="标注结果", lines=20)
download_btn = gr.Button("💾 下载标注文件")
auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())
demo.launch(server_name="0.0.0.0", server_port=7860)