Commit 26ec260
Parent(s): 1a6560a
add app.py and requirements.txt

Files changed:
- app.py +107 -186
- requirements.txt +5 -4
app.py CHANGED
@@ -8,32 +8,20 @@ import pandas as pd
 import chardet
 from pyvis.network import Network
 import time
 
-#
 bert_model_name = "bert-base-chinese"
 bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
 bert_model = BertModel.from_pretrained(bert_model_name)
 
-# 加载中文模型 ChatGLM3-6B
 chatglm_model_name = "THUDM/chatglm3-6b"
-chatglm_tokenizer = AutoTokenizer.from_pretrained(
-    chatglm_model_name,
-    trust_remote_code=True
-)
-chatglm_model = AutoModel.from_pretrained(
-    chatglm_model_name,
-    trust_remote_code=True,
-    device_map="auto",
-    torch_dtype=torch.float16
-).eval()
-
-# 知识图谱数据存储
-knowledge_graph = {
-    "entities": set(),
-    "relations": []
-}
 
 
 def update_knowledge_graph(entities, relations):
     for e in entities:
         if isinstance(e, dict) and 'text' in e and 'type' in e:
@@ -42,232 +30,165 @@ def update_knowledge_graph(entities, relations):
         if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
             knowledge_graph["relations"].append((r['head'], r['tail'], r['relation']))
 
-
 def visualize_kg():
     net = Network(height="600px", width="100%", notebook=True, directed=True)
     node_map = {}
     idx = 0
     for ent in knowledge_graph["entities"]:
…
-                     label=name,
-                     title=f"类型:{type_}",
-                     group=type_,
-                     font={'size': 20, 'face': 'SimHei'})
-        idx += 1
-
     seen_edges = set()
     for head, tail, relation in knowledge_graph["relations"]:
         if head in node_map and tail in node_map:
             edge_key = f"{head}-{tail}-{relation}"
             if edge_key not in seen_edges:
-                net.add_edge(node_map[head], node_map[tail],
-                             label=relation,
-                             arrows='to',
-                             font={'size': 14})
                 seen_edges.add(edge_key)
-
…
-                "max": 40
-            }
-        },
-        "physics": {
-            "stabilization": {
-                "enabled": true,
-                "iterations": 200,
-                "updateInterval": 25
-            },
-            "barnesHut": {
-                "gravitationalConstant": -2000,
-                "springLength": 150
-            }
-        },
-        "interaction": {
-            "hover": true,
-            "tooltipDelay": 200
-        }
-    }
-    """)
-
-    html = net.generate_html()
-    html = html.replace('//cdnjs.cloudflare.com', 'https://cdnjs.cloudflare.com')
-    html = html.replace('//unpkg.com', 'https://unpkg.com')
     return f'<div class="kg-graph">{html}</div>'
 
-
 def ner(text, model_type="bert"):
     start_time = time.time()
     if model_type == "bert":
-        # BERT 中文实体识别(原逻辑保留)
         name_pattern = r"([\u4e00-\u9fa5]{2,4})(?![的等地得啦啊哦])"
-        id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![
     else:
-
-        response, _ = chatglm_model.chat(
-            chatglm_tokenizer,
-            f"请从以下文本中识别所有实体,用JSON格式返回:[{text}]",
-            temperature=0.1
-        )
     try:
         entities = json.loads(response)
         return entities, time.time() - start_time
     except:
-
-
-        # 如果模型响应失败,使用备用正则
-        name_pattern = r"([\\u4e00-\\u9fa5]{2,4})(?![的等地得啦啊哦])"
-        id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})"
 
     entities = []
     occupied = set()
-
     def is_occupied(start, end):
         return any(s <= start < e or s < end <= e for s, e in occupied)
 
     for match in re.finditer(name_pattern, text):
         start, end = match.start(1), match.end(1)
         if not is_occupied(start, end):
-            entities.append({
-                "text": match.group(1),
-                "start": start,
-                "end": end,
-                "type": "人名"
-            })
             occupied.add((start, end))
 
     for match in re.finditer(id_pattern, text):
         start, end = match.start(1), match.end(1)
         if not is_occupied(start, end):
-            entities.append({
-                "text": match.group(1),
-                "start": start,
-                "end": end,
-                "type": "用户ID"
-            })
             occupied.add((start, end))
 
-
-    return entities, processing_time
-
 
 def re_extract(entities, text):
     relations = []
     if len(entities) < 2:
         return relations
-
-    # 使用ChatGLM分析关系
     entity_list = [e['text'] for e in entities]
     prompt = f"分析以下实体之间的关系:{entity_list}\n文本上下文:{text}"
     response, _ = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
-
     try:
         relations = json.loads(response)
     except:
-        # 备用简单关系生成
         for i in range(len(entities)):
             for j in range(i + 1, len(entities)):
-                relations.append({
-                    "head": entities[i]['text'],
-                    "tail": entities[j]['text'],
-                    "relation": "相关"
-                })
-
     return relations
 
-
 def process_text(text, model_type="bert"):
…
-            f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]"
-            for e in entities
-        )
-        relation_output = "\n".join(
-            f"{r['head']} --[{r['relation']}]-> {r['tail']}"
-            for r in relations
-        )
-        kg_html = visualize_kg()
-
-        return entity_output, relation_output, gr.HTML(kg_html), f"处理时间:{processing_time:.2f}秒"
-
-    except Exception as e:
-        return f"处理出错: {str(e)}", "", gr.HTML(), ""
-
 
 def process_file(file, model_type="bert"):
…
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("# 🚀
 
     with gr.Tab("✍️ 文本分析"):
-        input_text = gr.Textbox(label="输入内容"
…
-        kg_output = gr.HTML(label="知识图谱")
-        time_output = gr.Textbox(label="处理时间")
 
     with gr.Tab("📄 文件分析"):
-        file_input = gr.File(
…

 import chardet
 from pyvis.network import Network
 import time
+from sklearn.metrics import precision_score, recall_score, f1_score
 
+# ==== 模型初始化 ====
 bert_model_name = "bert-base-chinese"
 bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
 bert_model = BertModel.from_pretrained(bert_model_name)
 
 chatglm_model_name = "THUDM/chatglm3-6b"
+chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
+chatglm_model = AutoModel.from_pretrained(chatglm_model_name, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16).eval()
 
+knowledge_graph = {"entities": set(), "relations": []}
 
+# ==== 核心处理函数 ====
 def update_knowledge_graph(entities, relations):
     for e in entities:
         if isinstance(e, dict) and 'text' in e and 'type' in e:
…
         if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
             knowledge_graph["relations"].append((r['head'], r['tail'], r['relation']))
 
 def visualize_kg():
     net = Network(height="600px", width="100%", notebook=True, directed=True)
     node_map = {}
     idx = 0
     for ent in knowledge_graph["entities"]:
+        name, type_ = ent
+        node_map[name] = idx
+        net.add_node(idx, label=name, title=f"类型:{type_}", group=type_, font={'size': 20, 'face': 'SimHei'})
+        idx += 1
     seen_edges = set()
     for head, tail, relation in knowledge_graph["relations"]:
         if head in node_map and tail in node_map:
             edge_key = f"{head}-{tail}-{relation}"
             if edge_key not in seen_edges:
+                net.add_edge(node_map[head], node_map[tail], label=relation, arrows='to', font={'size': 14})
                 seen_edges.add(edge_key)
+    net.set_options("""{
+        "nodes": {"scaling": {"min": 20, "max": 40}},
+        "physics": {"stabilization": {"enabled": true, "iterations": 200}, "barnesHut": {"gravitationalConstant": -2000, "springLength": 150}},
+        "interaction": {"hover": true, "tooltipDelay": 200}
+    }""")
+    html = net.generate_html().replace('//cdnjs.cloudflare.com', 'https://cdnjs.cloudflare.com').replace('//unpkg.com', 'https://unpkg.com')
     return f'<div class="kg-graph">{html}</div>'
 
 def ner(text, model_type="bert"):
     start_time = time.time()
     if model_type == "bert":
         name_pattern = r"([\u4e00-\u9fa5]{2,4})(?![的等地得啦啊哦])"
+        id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![\u4e00-\u9fa5])"
     else:
+        response, _ = chatglm_model.chat(chatglm_tokenizer, f"请从以下文本中识别所有实体,用JSON格式返回:[{text}]", temperature=0.1)
     try:
         entities = json.loads(response)
         return entities, time.time() - start_time
     except:
+        name_pattern = r"([\u4e00-\u9fa5]{2,4})(?![的等地得啦啊哦])"
+        id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})"
 
     entities = []
     occupied = set()
     def is_occupied(start, end):
         return any(s <= start < e or s < end <= e for s, e in occupied)
 
     for match in re.finditer(name_pattern, text):
         start, end = match.start(1), match.end(1)
         if not is_occupied(start, end):
+            entities.append({"text": match.group(1), "start": start, "end": end, "type": "人名"})
             occupied.add((start, end))
 
     for match in re.finditer(id_pattern, text):
         start, end = match.start(1), match.end(1)
         if not is_occupied(start, end):
+            entities.append({"text": match.group(1), "start": start, "end": end, "type": "用户ID"})
             occupied.add((start, end))
 
+    return entities, time.time() - start_time
 
 def re_extract(entities, text):
     relations = []
     if len(entities) < 2:
         return relations
     entity_list = [e['text'] for e in entities]
     prompt = f"分析以下实体之间的关系:{entity_list}\n文本上下文:{text}"
     response, _ = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
     try:
         relations = json.loads(response)
     except:
         for i in range(len(entities)):
             for j in range(i + 1, len(entities)):
+                relations.append({"head": entities[i]['text'], "tail": entities[j]['text'], "relation": "相关"})
     return relations
 
 def process_text(text, model_type="bert"):
+    entities, processing_time = ner(text, model_type)
+    relations = re_extract(entities, text)
+    update_knowledge_graph(entities, relations)
+    entity_output = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
+    relation_output = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
+    return entity_output, relation_output, gr.HTML(visualize_kg()), f"处理时间:{processing_time:.2f}秒"
 
 def process_file(file, model_type="bert"):
+    content_bytes = file.read()
+    if len(content_bytes) > 5 * 1024 * 1024:
+        return "❌ 文件太大", "", gr.HTML(), ""
+    encoding = chardet.detect(content_bytes)['encoding'] or 'utf-8'
+    text = content_bytes.decode(encoding)
+    return process_text(text, model_type)
+
+# ==== 评估功能与自动标注 ====
+def convert_telegram_json_to_eval_format(path):
+    data = json.load(open(path, encoding="utf-8"))
+    result = []
+    for m in data.get("messages", []):
+        if isinstance(m.get("text"), str):
+            result.append({"text": m["text"], "entities": []})
+        elif isinstance(m.get("text"), list):
+            txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
+            result.append({"text": txt, "entities": []})
+    return result
+
+def evaluate_ner_model(data, model_type):
+    y_true, y_pred = [], []
+    for item in data:
+        gold = set(e['text'] for e in item['entities'])
+        pred, _ = ner(item['text'], model_type)
+        pred = set(e['text'] for e in pred)
+        for ent in gold.union(pred):
+            y_true.append(1 if ent in gold else 0)
+            y_pred.append(1 if ent in pred else 0)
+    return f"📊 {model_type} 实体识别评估:\nPrecision: {precision_score(y_true,y_pred):.2f}\nRecall: {recall_score(y_true,y_pred):.2f}\nF1: {f1_score(y_true,y_pred):.2f}"
+
+def auto_annotate(file, model_type):
+    data = convert_telegram_json_to_eval_format(file.name)
+    for item in data:
+        ents, _ = ner(item["text"], model_type)
+        item["entities"] = ents
+    return json.dumps(data, ensure_ascii=False, indent=2)
+
+def save_json(json_text):
+    fname = "auto_labeled.json"
+    with open(fname, "w", encoding="utf-8") as f:
+        f.write(json_text)
+    return fname
+
+# ==== Gradio UI ====
+css = ".kg-graph { height: 600px; }"
 with gr.Blocks(css=css) as demo:
+    gr.Markdown("# 🚀 智能聊天分析系统 + 标注评估工具")
 
     with gr.Tab("✍️ 文本分析"):
+        input_text = gr.Textbox(lines=6, label="输入内容")
+        model_type = gr.Radio(["bert", "chatglm"], value="bert", label="模型")
+        btn = gr.Button("开始分析")
+        out1 = gr.Textbox(label="实体")
+        out2 = gr.Textbox(label="关系")
+        out3 = gr.HTML()
+        out4 = gr.Textbox(label="耗时")
+        btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
 
     with gr.Tab("📄 文件分析"):
+        file_input = gr.File(file_types=[".txt", ".json", ".csv"])
+        btn2 = gr.Button("分析文件")
+        fout1, fout2, fout3, fout4 = gr.Textbox(), gr.Textbox(), gr.HTML(), gr.Textbox()
+        btn2.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])
+
+    with gr.Tab("📊 模型评估"):
+        eval_file = gr.File(label="上传标注数据集")
+        eval_model = gr.Radio(["bert", "chatglm"], value="bert")
+        eval_btn = gr.Button("开始评估")
+        eval_output = gr.Textbox(label="评估结果", lines=5)
+        eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m), [eval_file, eval_model], eval_output)
+
+    with gr.Tab("🖍 实体标注助手"):
+        raw_file = gr.File(label="上传原始 Telegram JSON")
+        auto_model = gr.Radio(["bert", "chatglm"], value="bert")
+        auto_btn = gr.Button("自动初标")
+        marked_texts = gr.Textbox(label="初步标注结果(可下载)", lines=20)
+        download_btn = gr.Button("💾 下载JSON")
+        auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
+        download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())
+
+demo.launch(server_name="0.0.0.0", server_port=7860)
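
The evaluation and auto-annotation helpers added above assume a Telegram-style JSON export: a top-level "messages" list whose "text" field is either a plain string or a list of string and {"text": ...} fragments. A minimal sketch of that assumed input, with a hypothetical file name and sample values:

import json

# Hypothetical Telegram export; only the fields read by
# convert_telegram_json_to_eval_format are included.
sample_export = {
    "messages": [
        {"text": "张三 给 lisi_2024 发了消息"},
        {"text": [{"text": "王五"}, " 回复了 ", {"text": "zhao_liu"}]},
    ]
}

with open("telegram_export.json", "w", encoding="utf-8") as f:
    json.dump(sample_export, f, ensure_ascii=False)

# convert_telegram_json_to_eval_format("telegram_export.json") would return:
# [{"text": "张三 给 lisi_2024 发了消息", "entities": []},
#  {"text": "王五 回复了 zhao_liu", "entities": []}]
# Once "entities" is filled with gold labels, the list can be passed to
# evaluate_ner_model(data, "bert").
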
requirements.txt CHANGED
@@ -1,11 +1,12 @@
 gradio==3.50.2
 transformers==4.39.3
 torch>=2.1.0
-pandas>=2.0.0
-chardet>=5.0.0
 networkx>=3.0
-pyvis>=0.3.2
 python-dotenv>=1.0.0
 sentencepiece>=0.2.0
 cpm-kernels>=1.0.11
-accelerate>=0.27.0
+accelerate>=0.27.0
+scikit-learn>=1.3.0
+chardet>=5.2.0
+pandas>=2.1.0
+pyvis>=0.3.2