Spaces:
Sleeping
Sleeping
Commit
·
8810e7b
1
Parent(s):
0207e75
add app.py and requirements.txt
Browse files- app.py +25 -178
- requirements.txt +0 -1
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import torch
|
2 |
-
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
|
3 |
import gradio as gr
|
4 |
import re
|
5 |
import os
|
@@ -14,21 +14,6 @@ bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
|
|
14 |
bert_ner_model = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
|
15 |
bert_ner_pipeline = pipeline("ner", model=bert_ner_model, tokenizer=bert_tokenizer, aggregation_strategy="simple")
|
16 |
|
17 |
-
# Optional ChatGLM3 backend. Loading is only attempted when CUDA is available;
# on a CPU-only host, or after any load failure, ``use_chatglm`` stays False.
chatglm_model = None
chatglm_tokenizer = None
use_chatglm = False
try:
    if not torch.cuda.is_available():
        print("⚠️ 当前为 CPU 环境,ChatGLM3 不可用,将仅使用 BERT。")
    else:
        chatglm_model_name = "THUDM/chatglm3-6b"
        chatglm_tokenizer = AutoTokenizer.from_pretrained(
            chatglm_model_name,
            trust_remote_code=True,
        )
        # Load and switch to eval mode in a single expression so that
        # ``chatglm_model`` is only rebound once the whole chain succeeded.
        chatglm_model = AutoModel.from_pretrained(
            chatglm_model_name,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.float16,
        ).eval()
        use_chatglm = True
except Exception as exc:
    # Best-effort: a failed load is reported on stdout and nothing is re-raised.
    print(f"❌ ChatGLM 加载失败: {exc}")
|
31 |
-
|
32 |
# ======================== 知识图谱结构 ========================
|
33 |
knowledge_graph = {"entities": set(), "relations": set()}
|
34 |
|
@@ -40,46 +25,21 @@ def update_knowledge_graph(entities, relations):
|
|
40 |
|
41 |
for r in relations:
|
42 |
if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
|
43 |
-
# 标准化关系方向
|
44 |
relation_tuple = (r['head'], r['tail'], r['relation'])
|
45 |
reverse_tuple = (r['tail'], r['head'], r['relation'])
|
46 |
if reverse_tuple not in knowledge_graph["relations"]:
|
47 |
knowledge_graph["relations"].add(relation_tuple)
|
48 |
-
|
|
|
49 |
def visualize_kg_text():
    """Render the module-level ``knowledge_graph`` as plain text.

    Returns:
        A newline-joined string: an entity header, one ``a (b)`` line per
        entry of ``knowledge_graph["entities"]`` (built from the entry's
        first two items), an empty line, a relation header, and one
        ``head --[relation]-> tail`` line per stored relation.
    """
    lines = ["📌 实体:"]
    for entity in knowledge_graph["entities"]:
        lines.append(f"{entity[0]} ({entity[1]})")

    lines.append("")
    lines.append("📎 关系:")
    # Each relation unpacks as (head, tail, relation), in that order.
    for head, tail, relation in knowledge_graph["relations"]:
        lines.append(f"{head} --[{relation}]-> {tail}")

    return "\n".join(lines)
|
53 |
|
|
|
54 |
# ======================== 实体识别(NER) ========================
|
55 |
-
def ner(text
|
56 |
start_time = time.time()
|
57 |
-
if model_type == "chatglm" and use_chatglm:
|
58 |
-
try:
|
59 |
-
prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
|
60 |
-
示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
|
61 |
-
文本:{text}"""
|
62 |
-
response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
|
63 |
-
if isinstance(response, tuple):
|
64 |
-
response = response[0]
|
65 |
-
|
66 |
-
# 增强 JSON 解析
|
67 |
-
try:
|
68 |
-
json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
|
69 |
-
entities = json.loads(json_str)
|
70 |
-
# 验证字段
|
71 |
-
valid_entities = []
|
72 |
-
for ent in entities:
|
73 |
-
if all(k in ent for k in ("text", "type", "start", "end")):
|
74 |
-
valid_entities.append(ent)
|
75 |
-
return valid_entities, time.time() - start_time
|
76 |
-
except Exception as e:
|
77 |
-
print(f"JSON 解析失败: {e}")
|
78 |
-
return [], time.time() - start_time
|
79 |
-
except Exception as e:
|
80 |
-
print(f"ChatGLM 调用失败:{e}")
|
81 |
-
return [], time.time() - start_time
|
82 |
-
|
83 |
# 使用微调的 BERT 中文 NER 模型
|
84 |
raw_results = bert_ner_pipeline(text)
|
85 |
entities = []
|
@@ -92,6 +52,7 @@ def ner(text, model_type="bert"):
|
|
92 |
})
|
93 |
return entities, time.time() - start_time
|
94 |
|
|
|
95 |
# ======================== 关系抽取(RE) ========================
|
96 |
def re_extract(entities, text):
|
97 |
if len(entities) < 2:
|
@@ -108,164 +69,50 @@ def re_extract(entities, text):
|
|
108 |
2. 关系类型使用:属于、位于、参与、其他
|
109 |
3. 格式示例:[{{"head": "北京", "tail": "中国", "relation": "位于"}}]"""
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
if all(k in rel for k in ("head", "tail", "relation")) and rel["relation"] in valid_types:
|
125 |
-
valid_relations.append(rel)
|
126 |
-
return valid_relations
|
127 |
-
except Exception as e:
|
128 |
-
print(f"关系解析失败: {e}")
|
129 |
except Exception as e:
|
130 |
print(f"关系抽取失败: {e}")
|
131 |
|
132 |
# 默认不生成任何关系
|
133 |
return []
|
134 |
|
|
|
135 |
# ======================== 文本分析主流程 ========================
|
136 |
-
def process_text(text,
|
137 |
-
entities, duration = ner(text
|
138 |
relations = re_extract(entities, text)
|
139 |
update_knowledge_graph(entities, relations)
|
140 |
|
141 |
ent_text = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
|
142 |
rel_text = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
|
143 |
kg_text = visualize_kg_text()
|
144 |
-
return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"
|
145 |
|
146 |
|
147 |
-
def process_file(file, model_type="bert"):
    """Decode an uploaded file and run its text through ``process_text``.

    Args:
        file: Upload object exposing the on-disk path as ``file.name``.
        model_type: Forwarded unchanged to ``process_text``.

    Returns:
        The result of ``process_text`` on success. On failure, a 4-tuple
        whose first item is a user-facing error message and whose other
        three items are empty strings.
    """
    max_bytes = 5 * 1024 * 1024
    fallback_encodings = ("gb18030", "utf-16", "big5")
    try:
        with open(file.name, "rb") as f:
            # Read one byte past the limit: enough to detect an oversized
            # upload without pulling the whole file into memory.
            content = f.read(max_bytes + 1)

        if len(content) > max_bytes:
            return "❌ 文件太大", "", "", ""

        text = None
        try:
            encoding = chardet.detect(content)["encoding"] or "utf-8"
            text = content.decode(encoding)
        except (UnicodeDecodeError, LookupError):
            # LookupError covers a detected codec name Python does not know;
            # in both cases fall back to common Chinese encodings.
            for candidate in fallback_encodings:
                try:
                    text = content.decode(candidate)
                    break
                except UnicodeDecodeError:
                    continue

        if text is None:
            return "❌ 编码解析失败", "", "", ""

        return process_text(text, model_type)
    except Exception as e:
        # UI boundary: report the failure in the first output slot instead
        # of letting the exception propagate.
        return f"❌ 文件处理错误: {str(e)}", "", "", ""
|
173 |
-
|
174 |
-
# ======================== 模型评估与自动标注 ========================
|
175 |
-
def convert_telegram_json_to_eval_format(path):
    """Load a JSON file and normalise it to a list of text/entities items.

    Three input shapes are recognised:

    * a dict with a ``"text"`` key: a single annotated document. Each entity
      is resolved to its surface text via its ``start``/``end`` offsets, and
      its ``type``, ``start`` and ``end`` are carried over when present.
    * a list: returned unchanged.
    * a dict with a ``"messages"`` key: one item with an empty entity list
      per message whose ``"text"`` is a string or a list of fragments.

    Args:
        path: Path of a UTF-8 encoded JSON file.

    Returns:
        A list of ``{"text": ..., "entities": [...]}`` dicts, or ``[]`` when
        the JSON matches none of the shapes above.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, list):
        return data
    if not isinstance(data, dict):
        return []

    if "text" in data:
        doc_text = data["text"]
        entities = []
        # ``or []`` tolerates an explicit ``"entities": null``.
        for raw in data.get("entities") or []:
            entity = {"text": doc_text[raw["start"]:raw["end"]]}
            # Keep the label and offsets: without "type" a gold entity
            # cannot be compared against a typed prediction.
            for key in ("type", "start", "end"):
                if key in raw:
                    entity[key] = raw[key]
            entities.append(entity)
        return [{"text": doc_text, "entities": entities}]

    if "messages" in data:
        result = []
        for message in data.get("messages") or []:
            if not isinstance(message, dict):
                continue
            body = message.get("text")
            if isinstance(body, str):
                result.append({"text": body, "entities": []})
            elif isinstance(body, list):
                # List bodies mix plain strings with dict fragments that
                # hold their string under "text".
                joined = "".join(
                    part.get("text", "") if isinstance(part, dict) else part
                    for part in body
                )
                result.append({"text": joined, "entities": []})
        return result

    return []
|
194 |
-
|
195 |
-
|
196 |
-
def evaluate_ner_model(data, model_type):
    """Score ``ner`` predictions against gold entities.

    Every entity is reduced to a ``text|type|start|end`` key. For each item,
    each key in the union of gold and predicted keys contributes one binary
    pair (is-gold, is-predicted); precision, recall and F1 are computed over
    all pairs of all items.

    Args:
        data: Iterable of dicts with a ``"text"`` entry and an optional
            ``"entities"`` list. Gold entities lacking ``"text"`` or
            ``"type"`` are ignored.
        model_type: Forwarded to ``ner``.

    Returns:
        A three-line ``Precision/Recall/F1`` report, or a warning string
        when there is nothing to score.
    """
    y_true, y_pred = [], []
    for item in data:
        text = item["text"]
        # NOTE(review): gold entities without offsets get -1/-1 here and so
        # can only match a prediction that also reports -1/-1; confirm the
        # annotation files always carry start/end.
        gold_keys = {
            f"{e['text']}|{e['type']}|{e.get('start', -1)}|{e.get('end', -1)}"
            for e in item.get("entities", [])
            if "text" in e and "type" in e
        }

        predicted, _ = ner(text, model_type)
        pred_keys = {
            f"{e['text']}|{e['type']}|{e['start']}|{e['end']}" for e in predicted
        }

        # Sets give O(1) membership tests; the metrics below do not depend
        # on the order in which the pairs are appended.
        for key in gold_keys | pred_keys:
            y_true.append(1 if key in gold_keys else 0)
            y_pred.append(1 if key in pred_keys else 0)

    if not y_true:
        return "⚠️ 无有效标注数据"

    # zero_division=0 keeps the 0.0 result for ill-defined metrics without
    # emitting a warning for each of them.
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    return f"Precision: {precision:.2f}\nRecall: {recall:.2f}\nF1: {f1:.2f}"
|
221 |
-
|
222 |
-
def auto_annotate(file, model_type):
    """Label every item of an uploaded JSON file with ``ner`` output.

    Args:
        file: Upload object exposing the on-disk path as ``file.name``.
        model_type: Forwarded to ``ner``.

    Returns:
        The converted items, each with its ``"entities"`` replaced by the
        predicted entities, serialised as indented, non-ASCII-escaped JSON.
    """
    records = convert_telegram_json_to_eval_format(file.name)
    for record in records:
        # ``ner`` returns (entities, elapsed); only the entities are kept.
        predicted, _elapsed = ner(record["text"], model_type)
        record["entities"] = predicted
    serialised = json.dumps(records, ensure_ascii=False, indent=2)
    return serialised
|
228 |
-
|
229 |
-
def save_json(json_text):
    """Write ``json_text`` to a timestamped file in the working directory.

    Args:
        json_text: Text to write, encoded as UTF-8.

    Returns:
        The name of the file that was written.
    """
    # Whole-second Unix timestamp keeps the name short; two calls within the
    # same second therefore target the same file.
    stamp = int(time.time())
    target = f"auto_labeled_{stamp}.json"
    with open(target, "w", encoding="utf-8") as handle:
        handle.write(json_text)
    return target
|
234 |
-
|
235 |
# ======================== Gradio 界面 ========================
|
236 |
with gr.Blocks(css=".kg-graph {height: 500px;}") as demo:
|
237 |
gr.Markdown("# 🤖 聊天记录实体关系识别系统")
|
238 |
|
239 |
with gr.Tab("📄 文本分析"):
|
240 |
input_text = gr.Textbox(lines=6, label="输入文本")
|
241 |
-
model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
|
242 |
btn = gr.Button("开始分析")
|
243 |
out1 = gr.Textbox(label="识别实体")
|
244 |
out2 = gr.Textbox(label="识别关系")
|
245 |
out3 = gr.Textbox(label="知识图谱")
|
246 |
out4 = gr.Textbox(label="耗时")
|
247 |
-
|
248 |
-
|
249 |
-
with gr.Tab("🗂 文件分析"):
|
250 |
-
file_input = gr.File(file_types=[".txt", ".json"])
|
251 |
-
file_btn = gr.Button("上传并分析")
|
252 |
-
fout1, fout2, fout3, fout4 = gr.Textbox(), gr.Textbox(), gr.Textbox(), gr.Textbox()
|
253 |
-
file_btn.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])
|
254 |
-
|
255 |
-
with gr.Tab("📊 模型评估"):
|
256 |
-
eval_file = gr.File(label="上传标注 JSON")
|
257 |
-
eval_model = gr.Radio(["bert", "chatglm"], value="bert")
|
258 |
-
eval_btn = gr.Button("开始评估")
|
259 |
-
eval_output = gr.Textbox(label="评估结果", lines=5)
|
260 |
-
eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m), [eval_file, eval_model], eval_output)
|
261 |
-
|
262 |
-
with gr.Tab("✏️ 自动标注"):
|
263 |
-
raw_file = gr.File(label="上传 Telegram 原始 JSON")
|
264 |
-
auto_model = gr.Radio(["bert", "chatglm"], value="bert")
|
265 |
-
auto_btn = gr.Button("自动标注")
|
266 |
-
marked_texts = gr.Textbox(label="标注结果", lines=20)
|
267 |
-
download_btn = gr.Button("💾 下载标注文件")
|
268 |
-
auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
|
269 |
-
download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())
|
270 |
|
271 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
1 |
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
|
3 |
import gradio as gr
|
4 |
import re
|
5 |
import os
|
|
|
14 |
bert_ner_model = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
|
15 |
bert_ner_pipeline = pipeline("ner", model=bert_ner_model, tokenizer=bert_tokenizer, aggregation_strategy="simple")
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
# ======================== 知识图谱结构 ========================
|
18 |
knowledge_graph = {"entities": set(), "relations": set()}
|
19 |
|
|
|
25 |
|
26 |
for r in relations:
|
27 |
if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
|
|
|
28 |
relation_tuple = (r['head'], r['tail'], r['relation'])
|
29 |
reverse_tuple = (r['tail'], r['head'], r['relation'])
|
30 |
if reverse_tuple not in knowledge_graph["relations"]:
|
31 |
knowledge_graph["relations"].add(relation_tuple)
|
32 |
+
|
33 |
+
|
34 |
def visualize_kg_text():
    """Render the module-level ``knowledge_graph`` as plain text.

    Returns:
        A newline-joined string: an entity header, one ``a (b)`` line per
        entry of ``knowledge_graph["entities"]`` (built from the entry's
        first two items), an empty line, a relation header, and one
        ``head --[relation]-> tail`` line per stored relation.
    """
    lines = ["📌 实体:"]
    for entity in knowledge_graph["entities"]:
        lines.append(f"{entity[0]} ({entity[1]})")

    lines.append("")
    lines.append("📎 关系:")
    # Each relation unpacks as (head, tail, relation), in that order.
    for head, tail, relation in knowledge_graph["relations"]:
        lines.append(f"{head} --[{relation}]-> {tail}")

    return "\n".join(lines)
|
38 |
|
39 |
+
|
40 |
# ======================== 实体识别(NER) ========================
|
41 |
+
def ner(text):
|
42 |
start_time = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
# 使用微调的 BERT 中文 NER 模型
|
44 |
raw_results = bert_ner_pipeline(text)
|
45 |
entities = []
|
|
|
52 |
})
|
53 |
return entities, time.time() - start_time
|
54 |
|
55 |
+
|
56 |
# ======================== 关系抽取(RE) ========================
|
57 |
def re_extract(entities, text):
|
58 |
if len(entities) < 2:
|
|
|
69 |
2. 关系类型使用:属于、位于、参与、其他
|
70 |
3. 格式示例:[{{"head": "北京", "tail": "中国", "relation": "位于"}}]"""
|
71 |
|
72 |
+
# 仅使用 BERT
|
73 |
+
response = bert_ner_pipeline(prompt)
|
74 |
+
try:
|
75 |
+
json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
|
76 |
+
relations = json.loads(json_str)
|
77 |
+
valid_relations = []
|
78 |
+
valid_types = {"属于", "位于", "参与", "其他"}
|
79 |
+
for rel in relations:
|
80 |
+
if all(k in rel for k in ("head", "tail", "relation")) and rel["relation"] in valid_types:
|
81 |
+
valid_relations.append(rel)
|
82 |
+
return valid_relations
|
83 |
+
except Exception as e:
|
84 |
+
print(f"关系解析失败: {e}")
|
|
|
|
|
|
|
|
|
|
|
85 |
except Exception as e:
|
86 |
print(f"关系抽取失败: {e}")
|
87 |
|
88 |
# 默认不生成任何关系
|
89 |
return []
|
90 |
|
91 |
+
|
92 |
# ======================== 文本分析主流程 ========================
|
93 |
+
def process_text(text, state=None):
    """Run entity and relation extraction on ``text`` and format the results.

    Args:
        text: Input text handed to ``ner`` and ``re_extract``.
        state: Opaque value passed through unchanged as the last output.

    Returns:
        A 5-tuple: entity listing, relation listing, rendered knowledge
        graph, elapsed NER time formatted with two decimals, and ``state``.
    """
    entities, duration = ner(text)
    relations = re_extract(entities, text)
    # Record this call's findings before the graph is rendered below.
    update_knowledge_graph(entities, relations)

    entity_lines = []
    for ent in entities:
        entity_lines.append(
            f"{ent['text']} ({ent['type']}) [{ent['start']}-{ent['end']}]"
        )

    relation_lines = []
    for rel in relations:
        relation_lines.append(
            f"{rel['head']} --[{rel['relation']}]-> {rel['tail']}"
        )

    return (
        "\n".join(entity_lines),
        "\n".join(relation_lines),
        visualize_kg_text(),
        f"{duration:.2f} 秒",
        state,
    )
|
102 |
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
# ======================== Gradio 界面 ========================
|
105 |
# Single-tab UI: a text box, a trigger button and four read-only result boxes.
# Components are created in display order.
with gr.Blocks(css=".kg-graph {height: 500px;}") as demo:
    gr.Markdown("# 🤖 聊天记录实体关系识别系统")

    with gr.Tab("📄 文本分析"):
        input_text = gr.Textbox(lines=6, label="输入文本")
        btn = gr.Button("开始分析")

        # Result boxes, in the order ``process_text`` returns its values.
        out1 = gr.Textbox(label="识别实体")
        out2 = gr.Textbox(label="识别关系")
        out3 = gr.Textbox(label="知识图谱")
        out4 = gr.Textbox(label="耗时")

        # ``state`` is both an input and the last output of the click
        # handler, so it is round-tripped through ``process_text``.
        state = gr.State()
        btn.click(
            fn=process_text,
            inputs=[input_text, state],
            outputs=[out1, out2, out3, out4, state],
        )

# Listen on all interfaces, port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)
|
requirements.txt
CHANGED
@@ -4,7 +4,6 @@ torch>=2.1.0
|
|
4 |
networkx>=3.0
|
5 |
python-dotenv>=1.0.0
|
6 |
sentencepiece>=0.2.0
|
7 |
-
cpm-kernels>=1.0.11
|
8 |
accelerate>=0.27.0
|
9 |
scikit-learn>=1.3.0
|
10 |
chardet>=5.2.0
|
|
|
4 |
networkx>=3.0
|
5 |
python-dotenv>=1.0.0
|
6 |
sentencepiece>=0.2.0
|
|
|
7 |
accelerate>=0.27.0
|
8 |
scikit-learn>=1.3.0
|
9 |
chardet>=5.2.0
|