Commit 1eeeaf5 · Parent: ee61e9e
add app.py and requirements.txt
app.py
CHANGED
@@ -4,29 +4,23 @@ import gradio as gr
 import re
 import os
 import json
-import pandas as pd
 import chardet
 from sklearn.metrics import precision_score, recall_score, f1_score
 import time
 
-#
+# ======================== 模型加载 ========================
 bert_model_name = "bert-base-chinese"
 bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
 bert_model = BertModel.from_pretrained(bert_model_name)
 
-# chatglm3 模型检测与安全加载
 chatglm_model, chatglm_tokenizer = None, None
 use_chatglm = False
-
 try:
     if torch.cuda.is_available():
         chatglm_model_name = "THUDM/chatglm3-6b"
         chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
         chatglm_model = AutoModel.from_pretrained(
-            chatglm_model_name,
-            trust_remote_code=True,
-            device_map="auto",
-            torch_dtype=torch.float16
+            chatglm_model_name, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16
         ).eval()
         use_chatglm = True
     else:
@@ -34,7 +28,7 @@ try:
 except Exception as e:
     print(f"❌ ChatGLM 加载失败: {e}")
 
-#
+# ======================== 知识图谱结构 ========================
 knowledge_graph = {"entities": set(), "relations": []}
 
 def update_knowledge_graph(entities, relations):
@@ -50,16 +44,17 @@ def visualize_kg_text():
     edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
     return "📌 实体:\n" + "\n".join(nodes) + "\n\n📎 关系:\n" + "\n".join(edges)
 
-#
+# ======================== 实体识别(NER) ========================
 def ner(text, model_type="bert"):
     start_time = time.time()
     if model_type == "chatglm" and use_chatglm:
         try:
-
+            prompt = f"请从以下文本中识别所有实体,用JSON格式返回:[{text}]"
+            response, _ = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
             entities = json.loads(response)
             return entities, time.time() - start_time
         except Exception as e:
-            print(f"❌ ChatGLM
+            print(f"❌ ChatGLM 实体识别失败:{e}")
             return [], time.time() - start_time
 
     name_pattern = r"([\u4e00-\u9fa5]{2,4})(?![的等地得啦啊哦])"
@@ -83,7 +78,7 @@ def ner(text, model_type="bert"):
 
     return entities, time.time() - start_time
 
-#
+# ======================== 关系抽取(RE) ========================
 def re_extract(entities, text):
     if len(entities) < 2:
         return []
@@ -93,11 +88,11 @@ def re_extract(entities, text):
         if use_chatglm:
             response, _ = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
             return json.loads(response)
-    except:
-
+    except Exception as e:
+        print(f"❌ ChatGLM 关系抽取失败:{e}")
     return [{"head": e1['text'], "tail": e2['text'], "relation": "相关"} for i, e1 in enumerate(entities) for e2 in entities[i+1:]]
 
-#
+# ======================== 文本分析主流程 ========================
 def process_text(text, model_type="bert"):
     entities, duration = ner(text, model_type)
     relations = re_extract(entities, text)
@@ -106,7 +101,7 @@ def process_text(text, model_type="bert"):
     ent_text = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
     rel_text = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
     kg_text = visualize_kg_text()
-    return ent_text, rel_text, kg_text, f"{duration:.2f}秒"
+    return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"
 
 def process_file(file, model_type="bert"):
     content = file.read()
@@ -116,7 +111,7 @@ def process_file(file, model_type="bert"):
     text = content.decode(encoding)
     return process_text(text, model_type)
 
-#
+# ======================== 模型评估与自动标注 ========================
 def convert_telegram_json_to_eval_format(path):
     data = json.load(open(path, encoding="utf-8"))
     result = []
@@ -152,7 +147,7 @@ def save_json(json_text):
         f.write(json_text)
     return fname
 
-#
+# ======================== Gradio 界面 ========================
 with gr.Blocks(css=".kg-graph {height: 500px;}") as demo:
     gr.Markdown("# 🤖 聊天记录实体关系识别系统")
 
@@ -162,7 +157,7 @@ with gr.Blocks(css=".kg-graph {height: 500px;}") as demo:
         btn = gr.Button("开始分析")
         out1 = gr.Textbox(label="识别实体")
         out2 = gr.Textbox(label="识别关系")
-        out3 = gr.Textbox(label="
+        out3 = gr.Textbox(label="知识图谱")
         out4 = gr.Textbox(label="耗时")
         btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
 
@@ -179,12 +174,12 @@ with gr.Blocks(css=".kg-graph {height: 500px;}") as demo:
         eval_output = gr.Textbox(label="评估结果", lines=5)
         eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m), [eval_file, eval_model], eval_output)
 
-    with gr.Tab("✏️
+    with gr.Tab("✏️ 自动标注"):
         raw_file = gr.File(label="上传 Telegram 原始 JSON")
         auto_model = gr.Radio(["bert", "chatglm"], value="bert")
-        auto_btn = gr.Button("
-        marked_texts = gr.Textbox(label="
-        download_btn = gr.Button("💾
+        auto_btn = gr.Button("自动标注")
+        marked_texts = gr.Textbox(label="标注结果", lines=20)
+        download_btn = gr.Button("💾 下载标注文件")
         auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
         download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())
 
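The hunks above only show the signature of update_knowledge_graph() and the knowledge_graph structure it fills. As a minimal sketch (an editor reconstruction, not code from this commit), a body consistent with what visualize_kg_text() expects, where relations are stored as (head, tail, relation) tuples, could look like this:

knowledge_graph = {"entities": set(), "relations": []}  # structure as defined in the diff

def update_knowledge_graph(entities, relations):
    # entities: dicts with "text"/"type"/"start"/"end"; relations: dicts with "head"/"tail"/"relation"
    for e in entities:
        knowledge_graph["entities"].add(e["text"])
    for r in relations:
        # visualize_kg_text() unpacks these as "for h, t, r in knowledge_graph['relations']"
        knowledge_graph["relations"].append((r["head"], r["tail"], r["relation"]))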
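Similarly, the regex fallback inside ner() is visible only through name_pattern and the entity fields that process_text() later reads. A minimal sketch of that fallback (the helper name regex_entities and the single "PER" label are hypothetical; the real label set is outside the hunks):

import re

name_pattern = r"([\u4e00-\u9fa5]{2,4})(?![的等地得啦啊哦])"  # pattern taken from the diff

def regex_entities(text):
    # Emit entities in the shape process_text() consumes: text / type / start / end.
    entities = []
    for m in re.finditer(name_pattern, text):
        entities.append({"text": m.group(1), "type": "PER", "start": m.start(), "end": m.end()})
    return entities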
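Finally, process_file() decodes the uploaded bytes with content.decode(encoding), but the chardet detection step that produces `encoding` sits outside the changed lines. A sketch of that step (the helper name detect_and_decode is hypothetical), assuming a UTF-8 fallback when detection returns nothing:

import chardet

def detect_and_decode(content: bytes) -> str:
    # chardet.detect() returns a dict whose "encoding" value can be None for ambiguous input.
    encoding = chardet.detect(content)["encoding"] or "utf-8"
    return content.decode(encoding, errors="replace")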