Spaces:
Sleeping
Sleeping
Commit
·
1a6560a
1
Parent(s):
e256c0a
add app.py and requirements.txt
Browse files- app.py +68 -66
- requirements.txt +4 -3
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import torch
|
2 |
-
from transformers import
|
3 |
import gradio as gr
|
4 |
import re
|
5 |
import os
|
@@ -9,25 +9,23 @@ import chardet
|
|
9 |
from pyvis.network import Network
|
10 |
import time
|
11 |
|
12 |
-
#
|
13 |
bert_model_name = "bert-base-chinese"
|
14 |
bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
|
15 |
bert_model = BertModel.from_pretrained(bert_model_name)
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
use_auth_token=access_token # 使用旧参数名确保兼容性
|
23 |
-
)
|
24 |
-
|
25 |
-
llama_model = LlamaForCausalLM.from_pretrained(
|
26 |
-
llama_model_name,
|
27 |
-
use_auth_token=access_token,
|
28 |
-
torch_dtype=torch.float16, # 添加量化配置
|
29 |
-
device_map="auto"
|
30 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
# 知识图谱数据存储
|
33 |
knowledge_graph = {
|
@@ -106,14 +104,25 @@ def visualize_kg():
|
|
106 |
def ner(text, model_type="bert"):
|
107 |
start_time = time.time()
|
108 |
if model_type == "bert":
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
-
|
116 |
-
|
|
|
117 |
|
118 |
entities = []
|
119 |
occupied = set()
|
@@ -128,7 +137,7 @@ def ner(text, model_type="bert"):
|
|
128 |
"text": match.group(1),
|
129 |
"start": start,
|
130 |
"end": end,
|
131 |
-
"type": "
|
132 |
})
|
133 |
occupied.add((start, end))
|
134 |
|
@@ -139,7 +148,7 @@ def ner(text, model_type="bert"):
|
|
139 |
"text": match.group(1),
|
140 |
"start": start,
|
141 |
"end": end,
|
142 |
-
"type": "
|
143 |
})
|
144 |
occupied.add((start, end))
|
145 |
|
@@ -149,15 +158,26 @@ def ner(text, model_type="bert"):
|
|
149 |
|
150 |
def re_extract(entities, text):
|
151 |
relations = []
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
return relations
|
162 |
|
163 |
|
@@ -177,7 +197,7 @@ def process_text(text, model_type="bert"):
|
|
177 |
)
|
178 |
kg_html = visualize_kg()
|
179 |
|
180 |
-
return entity_output, relation_output, gr.HTML(kg_html), f"
|
181 |
|
182 |
except Exception as e:
|
183 |
return f"处理出错: {str(e)}", "", gr.HTML(), ""
|
@@ -185,10 +205,7 @@ def process_text(text, model_type="bert"):
|
|
185 |
|
186 |
def process_file(file, model_type="bert"):
|
187 |
try:
|
188 |
-
# 读取文件内容(适配 Hugging Face 文件系统)
|
189 |
content_bytes = file.read()
|
190 |
-
|
191 |
-
# 文件大小限制(5MB)
|
192 |
if len(content_bytes) > 5 * 1024 * 1024:
|
193 |
return "❌ 文件大小超过5MB限制", "", gr.HTML(), ""
|
194 |
|
@@ -226,46 +243,31 @@ css = """
|
|
226 |
"""
|
227 |
|
228 |
with gr.Blocks(css=css) as demo:
|
229 |
-
gr.Markdown("# 🚀
|
230 |
|
231 |
with gr.Tab("✍️ 文本分析"):
|
232 |
-
gr.Markdown("### 直接输入聊天内容")
|
233 |
input_text = gr.Textbox(label="输入内容", lines=8,
|
234 |
placeholder="示例:张三@李四 请把需求文档_v2发送给王五")
|
235 |
-
model_type = gr.Radio(["bert", "
|
236 |
analyze_btn = gr.Button("开始分析", variant="primary")
|
237 |
|
238 |
with gr.Row():
|
239 |
-
entity_output = gr.Textbox(label="识别的实体", lines=6
|
240 |
-
relation_output = gr.Textbox(label="提取的关系", lines=6
|
241 |
kg_output = gr.HTML(label="知识图谱")
|
242 |
-
time_output = gr.Textbox(label="处理时间"
|
243 |
-
|
244 |
-
analyze_btn.click(
|
245 |
-
process_text,
|
246 |
-
inputs=[input_text, model_type],
|
247 |
-
outputs=[entity_output, relation_output, kg_output, time_output],
|
248 |
-
show_progress="full"
|
249 |
-
)
|
250 |
|
251 |
with gr.Tab("📄 文件分析"):
|
252 |
-
gr.
|
253 |
-
file_input = gr.File(label="选择文件", type="file")
|
254 |
analyze_file_btn = gr.Button("开始分析文件", variant="primary")
|
255 |
-
file_entity_output = gr.Textbox(label="识别的实体", lines=6
|
256 |
-
file_relation_output = gr.Textbox(label="提取的关系", lines=6
|
257 |
file_kg_output = gr.HTML(label="知识图谱")
|
258 |
-
file_time_output = gr.Textbox(label="处理时间"
|
259 |
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
show_progress="full"
|
265 |
-
)
|
266 |
|
267 |
-
demo.launch(
|
268 |
-
server_name="0.0.0.0",
|
269 |
-
server_port=7860,
|
270 |
-
debug=False
|
271 |
-
)
|
|
|
1 |
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModel, BertTokenizer, BertModel
|
3 |
import gradio as gr
|
4 |
import re
|
5 |
import os
|
|
|
9 |
from pyvis.network import Network
|
10 |
import time
|
11 |
|
12 |
+
# 初始化模型
|
13 |
bert_model_name = "bert-base-chinese"
|
14 |
bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
|
15 |
bert_model = BertModel.from_pretrained(bert_model_name)
|
16 |
|
17 |
+
# 加载中文模型 ChatGLM3-6B
|
18 |
+
chatglm_model_name = "THUDM/chatglm3-6b"
|
19 |
+
chatglm_tokenizer = AutoTokenizer.from_pretrained(
|
20 |
+
chatglm_model_name,
|
21 |
+
trust_remote_code=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
)
|
23 |
+
chatglm_model = AutoModel.from_pretrained(
|
24 |
+
chatglm_model_name,
|
25 |
+
trust_remote_code=True,
|
26 |
+
device_map="auto",
|
27 |
+
torch_dtype=torch.float16
|
28 |
+
).eval()
|
29 |
|
30 |
# 知识图谱数据存储
|
31 |
knowledge_graph = {
|
|
|
104 |
def ner(text, model_type="bert"):
|
105 |
start_time = time.time()
|
106 |
if model_type == "bert":
|
107 |
+
# BERT 中文实体识别(原逻辑保留)
|
108 |
+
name_pattern = r"([\u4e00-\u9fa5]{2,4})(?![的等地得啦啊哦])"
|
109 |
+
id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![\\u4e00-\\u9fa5])"
|
110 |
+
else:
|
111 |
+
# ChatGLM 增强实体识别
|
112 |
+
response, _ = chatglm_model.chat(
|
113 |
+
chatglm_tokenizer,
|
114 |
+
f"请从以下文本中识别所有实体,用JSON格式返回:[{text}]",
|
115 |
+
temperature=0.1
|
116 |
+
)
|
117 |
+
try:
|
118 |
+
entities = json.loads(response)
|
119 |
+
return entities, time.time() - start_time
|
120 |
+
except:
|
121 |
+
pass
|
122 |
|
123 |
+
# 如果模型响应失败,使用备用正则
|
124 |
+
name_pattern = r"([\\u4e00-\\u9fa5]{2,4})(?![的等地得啦啊哦])"
|
125 |
+
id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})"
|
126 |
|
127 |
entities = []
|
128 |
occupied = set()
|
|
|
137 |
"text": match.group(1),
|
138 |
"start": start,
|
139 |
"end": end,
|
140 |
+
"type": "人名"
|
141 |
})
|
142 |
occupied.add((start, end))
|
143 |
|
|
|
148 |
"text": match.group(1),
|
149 |
"start": start,
|
150 |
"end": end,
|
151 |
+
"type": "用户ID"
|
152 |
})
|
153 |
occupied.add((start, end))
|
154 |
|
|
|
158 |
|
159 |
def re_extract(entities, text):
|
160 |
relations = []
|
161 |
+
if len(entities) < 2:
|
162 |
+
return relations
|
163 |
+
|
164 |
+
# 使用ChatGLM分析关系
|
165 |
+
entity_list = [e['text'] for e in entities]
|
166 |
+
prompt = f"分析以下实体之间的关系:{entity_list}\n文本上下文:{text}"
|
167 |
+
response, _ = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
|
168 |
+
|
169 |
+
try:
|
170 |
+
relations = json.loads(response)
|
171 |
+
except:
|
172 |
+
# 备用简单关系生成
|
173 |
+
for i in range(len(entities)):
|
174 |
+
for j in range(i + 1, len(entities)):
|
175 |
+
relations.append({
|
176 |
+
"head": entities[i]['text'],
|
177 |
+
"tail": entities[j]['text'],
|
178 |
+
"relation": "相关"
|
179 |
+
})
|
180 |
+
|
181 |
return relations
|
182 |
|
183 |
|
|
|
197 |
)
|
198 |
kg_html = visualize_kg()
|
199 |
|
200 |
+
return entity_output, relation_output, gr.HTML(kg_html), f"处理时间:{processing_time:.2f}秒"
|
201 |
|
202 |
except Exception as e:
|
203 |
return f"处理出错: {str(e)}", "", gr.HTML(), ""
|
|
|
205 |
|
206 |
def process_file(file, model_type="bert"):
|
207 |
try:
|
|
|
208 |
content_bytes = file.read()
|
|
|
|
|
209 |
if len(content_bytes) > 5 * 1024 * 1024:
|
210 |
return "❌ 文件大小超过5MB限制", "", gr.HTML(), ""
|
211 |
|
|
|
243 |
"""
|
244 |
|
245 |
with gr.Blocks(css=css) as demo:
|
246 |
+
gr.Markdown("# 🚀 智能聊天记录分析系统(ChatGLM3-6B版)")
|
247 |
|
248 |
with gr.Tab("✍️ 文本分析"):
|
|
|
249 |
input_text = gr.Textbox(label="输入内容", lines=8,
|
250 |
placeholder="示例:张三@李四 请把需求文档_v2发送给王五")
|
251 |
+
model_type = gr.Radio(["bert", "chatglm"], label="选择模型", value="bert")
|
252 |
analyze_btn = gr.Button("开始分析", variant="primary")
|
253 |
|
254 |
with gr.Row():
|
255 |
+
entity_output = gr.Textbox(label="识别的实体", lines=6)
|
256 |
+
relation_output = gr.Textbox(label="提取的关系", lines=6)
|
257 |
kg_output = gr.HTML(label="知识图谱")
|
258 |
+
time_output = gr.Textbox(label="处理时间")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
259 |
|
260 |
with gr.Tab("📄 文件分析"):
|
261 |
+
file_input = gr.File(label="选择文件", file_types=[".txt", ".csv", ".json"])
|
|
|
262 |
analyze_file_btn = gr.Button("开始分析文件", variant="primary")
|
263 |
+
file_entity_output = gr.Textbox(label="识别的实体", lines=6)
|
264 |
+
file_relation_output = gr.Textbox(label="提取的关系", lines=6)
|
265 |
file_kg_output = gr.HTML(label="知识图谱")
|
266 |
+
file_time_output = gr.Textbox(label="处理时间")
|
267 |
|
268 |
+
analyze_btn.click(process_text, [input_text, model_type],
|
269 |
+
[entity_output, relation_output, kg_output, time_output])
|
270 |
+
analyze_file_btn.click(process_file, [file_input, model_type],
|
271 |
+
[file_entity_output, file_relation_output, file_kg_output, file_time_output])
|
|
|
|
|
272 |
|
273 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
gradio==3.50.2
|
2 |
transformers==4.39.3
|
3 |
torch>=2.1.0
|
4 |
-
accelerate>=0.27.0
|
5 |
-
sentencepiece>=0.2.0
|
6 |
pandas>=2.0.0
|
7 |
chardet>=5.0.0
|
8 |
networkx>=3.0
|
9 |
pyvis>=0.3.2
|
10 |
-
python-dotenv>=1.0.0
|
|
|
|
|
|
|
|
1 |
gradio==3.50.2
|
2 |
transformers==4.39.3
|
3 |
torch>=2.1.0
|
|
|
|
|
4 |
pandas>=2.0.0
|
5 |
chardet>=5.0.0
|
6 |
networkx>=3.0
|
7 |
pyvis>=0.3.2
|
8 |
+
python-dotenv>=1.0.0
|
9 |
+
sentencepiece>=0.2.0
|
10 |
+
cpm-kernels>=1.0.11
|
11 |
+
accelerate>=0.27.0
|