chen666-666 committed on
Commit
1a6560a
·
1 Parent(s): e256c0a

add app.py and requirements.txt

Browse files
Files changed (2) hide show
  1. app.py +68 -66
  2. requirements.txt +4 -3
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from transformers import BertTokenizer, BertModel, LlamaTokenizer, LlamaForCausalLM
3
  import gradio as gr
4
  import re
5
  import os
@@ -9,25 +9,23 @@ import chardet
9
  from pyvis.network import Network
10
  import time
11
 
12
- # 初始化模型(适配 Hugging Face Secrets)
13
  bert_model_name = "bert-base-chinese"
14
  bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
15
  bert_model = BertModel.from_pretrained(bert_model_name)
16
 
17
- llama_model_name = os.getenv("LLAMA_MODEL_NAME", "meta-llama/Llama-2-7b-chat-hf")
18
- access_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
19
-
20
- llama_tokenizer = LlamaTokenizer.from_pretrained(
21
- llama_model_name,
22
- use_auth_token=access_token # 使用旧参数名确保兼容性
23
- )
24
-
25
- llama_model = LlamaForCausalLM.from_pretrained(
26
- llama_model_name,
27
- use_auth_token=access_token,
28
- torch_dtype=torch.float16, # 添加量化配置
29
- device_map="auto"
30
  )
 
 
 
 
 
 
31
 
32
  # 知识图谱数据存储
33
  knowledge_graph = {
@@ -106,14 +104,25 @@ def visualize_kg():
106
  def ner(text, model_type="bert"):
107
  start_time = time.time()
108
  if model_type == "bert":
109
- tokenizer = bert_tokenizer
110
- model = bert_model
111
- elif model_type == "llama":
112
- tokenizer = llama_tokenizer
113
- model = llama_model
 
 
 
 
 
 
 
 
 
 
114
 
115
- name_pattern = r"([\u4e00-\u9fa5]{2,4})(?![\u7684\u5730\u5f97\u5566\u554a\u5440])"
116
- id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![\u4e00-\u9fa5])"
 
117
 
118
  entities = []
119
  occupied = set()
@@ -128,7 +137,7 @@ def ner(text, model_type="bert"):
128
  "text": match.group(1),
129
  "start": start,
130
  "end": end,
131
- "type": "PersonName"
132
  })
133
  occupied.add((start, end))
134
 
@@ -139,7 +148,7 @@ def ner(text, model_type="bert"):
139
  "text": match.group(1),
140
  "start": start,
141
  "end": end,
142
- "type": "UserID"
143
  })
144
  occupied.add((start, end))
145
 
@@ -149,15 +158,26 @@ def ner(text, model_type="bert"):
149
 
150
  def re_extract(entities, text):
151
  relations = []
152
- for i, entity1 in enumerate(entities):
153
- for j, entity2 in enumerate(entities):
154
- if i != j:
155
- relation = {
156
- "head": entity1['text'],
157
- "tail": entity2['text'],
158
- "relation": "联系"
159
- }
160
- relations.append(relation)
 
 
 
 
 
 
 
 
 
 
 
161
  return relations
162
 
163
 
@@ -177,7 +197,7 @@ def process_text(text, model_type="bert"):
177
  )
178
  kg_html = visualize_kg()
179
 
180
- return entity_output, relation_output, gr.HTML(kg_html), f"\u5904\u7406\u65f6\u95f4:{processing_time:.2f}\u79d2"
181
 
182
  except Exception as e:
183
  return f"处理出错: {str(e)}", "", gr.HTML(), ""
@@ -185,10 +205,7 @@ def process_text(text, model_type="bert"):
185
 
186
  def process_file(file, model_type="bert"):
187
  try:
188
- # 读取文件内容(适配 Hugging Face 文件系统)
189
  content_bytes = file.read()
190
-
191
- # 文件大小限制(5MB)
192
  if len(content_bytes) > 5 * 1024 * 1024:
193
  return "❌ 文件大小超过5MB限制", "", gr.HTML(), ""
194
 
@@ -226,46 +243,31 @@ css = """
226
  """
227
 
228
  with gr.Blocks(css=css) as demo:
229
- gr.Markdown("# 🚀 智能聊天记录分析系统\n**功能**: 实体识别 → 关系抽取 → 动态知识图谱")
230
 
231
  with gr.Tab("✍️ 文本分析"):
232
- gr.Markdown("### 直接输入聊天内容")
233
  input_text = gr.Textbox(label="输入内容", lines=8,
234
  placeholder="示例:张三@李四 请把需求文档_v2发送给王五")
235
- model_type = gr.Radio(["bert", "llama"], label="选择模型", value="bert")
236
  analyze_btn = gr.Button("开始分析", variant="primary")
237
 
238
  with gr.Row():
239
- entity_output = gr.Textbox(label="识别的实体", lines=6, interactive=False)
240
- relation_output = gr.Textbox(label="提取的关系", lines=6, interactive=False)
241
  kg_output = gr.HTML(label="知识图谱")
242
- time_output = gr.Textbox(label="处理时间", interactive=False)
243
-
244
- analyze_btn.click(
245
- process_text,
246
- inputs=[input_text, model_type],
247
- outputs=[entity_output, relation_output, kg_output, time_output],
248
- show_progress="full"
249
- )
250
 
251
  with gr.Tab("📄 文件分析"):
252
- gr.Markdown("### 上传文件进行分析(支持 .txt, .jsonl, .json, .csv 格式)")
253
- file_input = gr.File(label="选择文件", type="file")
254
  analyze_file_btn = gr.Button("开始分析文件", variant="primary")
255
- file_entity_output = gr.Textbox(label="识别的实体", lines=6, interactive=False)
256
- file_relation_output = gr.Textbox(label="提取的关系", lines=6, interactive=False)
257
  file_kg_output = gr.HTML(label="知识图谱")
258
- file_time_output = gr.Textbox(label="处理时间", interactive=False)
259
 
260
- analyze_file_btn.click(
261
- process_file,
262
- inputs=[file_input, model_type],
263
- outputs=[file_entity_output, file_relation_output, file_kg_output, file_time_output],
264
- show_progress="full"
265
- )
266
 
267
- demo.launch(
268
- server_name="0.0.0.0",
269
- server_port=7860,
270
- debug=False
271
- )
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModel, BertTokenizer, BertModel
3
  import gradio as gr
4
  import re
5
  import os
 
9
  from pyvis.network import Network
10
  import time
11
 
12
+ # 初始化模型
13
  bert_model_name = "bert-base-chinese"
14
  bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
15
  bert_model = BertModel.from_pretrained(bert_model_name)
16
 
17
+ # 加载中文模型 ChatGLM3-6B
18
+ chatglm_model_name = "THUDM/chatglm3-6b"
19
+ chatglm_tokenizer = AutoTokenizer.from_pretrained(
20
+ chatglm_model_name,
21
+ trust_remote_code=True
 
 
 
 
 
 
 
 
22
  )
23
+ chatglm_model = AutoModel.from_pretrained(
24
+ chatglm_model_name,
25
+ trust_remote_code=True,
26
+ device_map="auto",
27
+ torch_dtype=torch.float16
28
+ ).eval()
29
 
30
  # 知识图谱数据存储
31
  knowledge_graph = {
 
104
  def ner(text, model_type="bert"):
105
  start_time = time.time()
106
  if model_type == "bert":
107
+ # BERT 中文实体识别(原逻辑保留)
108
+ name_pattern = r"([\u4e00-\u9fa5]{2,4})(?![的等地得啦啊哦])"
109
+ id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![\\u4e00-\\u9fa5])"
110
+ else:
111
+ # ChatGLM 增强实体识别
112
+ response, _ = chatglm_model.chat(
113
+ chatglm_tokenizer,
114
+ f"请从以下文本中识别所有实体,用JSON格式返回:[{text}]",
115
+ temperature=0.1
116
+ )
117
+ try:
118
+ entities = json.loads(response)
119
+ return entities, time.time() - start_time
120
+ except:
121
+ pass
122
 
123
+ # 如果模型响应失败,使用备用正则
124
+ name_pattern = r"([\\u4e00-\\u9fa5]{2,4})(?![的等地得啦啊哦])"
125
+ id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})"
126
 
127
  entities = []
128
  occupied = set()
 
137
  "text": match.group(1),
138
  "start": start,
139
  "end": end,
140
+ "type": "人名"
141
  })
142
  occupied.add((start, end))
143
 
 
148
  "text": match.group(1),
149
  "start": start,
150
  "end": end,
151
+ "type": "用户ID"
152
  })
153
  occupied.add((start, end))
154
 
 
158
 
159
def re_extract(entities, text):
    """Extract relations between the recognised entities using ChatGLM.

    The model is asked to describe the relations between the entity texts
    in the context of ``text`` and to answer with a JSON array. When the
    reply cannot be parsed as a JSON list, every unordered pair of entities
    is linked with the generic relation "相关" instead.

    Args:
        entities: list of entity dicts; only the ``'text'`` key is read here.
        text: the original text the entities were extracted from.

    Returns:
        A list of dicts with the keys ``"head"``, ``"tail"`` and
        ``"relation"``. Empty when fewer than two entities are given.
    """
    # Local imports: the visible import section of the file does not show
    # these modules, so bring them into scope here.
    import itertools
    import json

    relations = []
    if len(entities) < 2:
        # A relation needs two endpoints; skip the (expensive) model call.
        return relations

    entity_list = [e['text'] for e in entities]
    # The reply is fed to json.loads below, so the prompt has to request
    # JSON explicitly; a free-text answer can never be parsed.
    prompt = (
        f"分析以下实体之间的关系:{entity_list}\n文本上下文:{text}\n"
        "请仅以JSON数组格式返回,数组中每个元素包含 head、tail、relation 三个字段。"
    )
    response, _ = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)

    try:
        parsed = json.loads(response)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; TypeError covers a non-string
        # reply. Anything else (e.g. KeyboardInterrupt) must propagate.
        parsed = None

    if isinstance(parsed, list):
        # Keep only well-formed entries so the result has the same shape as
        # the fallback below.
        required = {"head", "tail", "relation"}
        return [
            item for item in parsed
            if isinstance(item, dict) and required <= item.keys()
        ]

    # Fallback: connect each unordered pair once with a generic relation.
    for head, tail in itertools.combinations(entities, 2):
        relations.append({
            "head": head['text'],
            "tail": tail['text'],
            "relation": "相关"
        })

    return relations
182
 
183
 
 
197
  )
198
  kg_html = visualize_kg()
199
 
200
+ return entity_output, relation_output, gr.HTML(kg_html), f"处理时间:{processing_time:.2f}"
201
 
202
  except Exception as e:
203
  return f"处理出错: {str(e)}", "", gr.HTML(), ""
 
205
 
206
  def process_file(file, model_type="bert"):
207
  try:
 
208
  content_bytes = file.read()
 
 
209
  if len(content_bytes) > 5 * 1024 * 1024:
210
  return "❌ 文件大小超过5MB限制", "", gr.HTML(), ""
211
 
 
243
  """
244
 
245
  with gr.Blocks(css=css) as demo:
246
+ gr.Markdown("# 🚀 智能聊天记录分析系统(ChatGLM3-6B版)")
247
 
248
  with gr.Tab("✍️ 文本分析"):
 
249
  input_text = gr.Textbox(label="输入内容", lines=8,
250
  placeholder="示例:张三@李四 请把需求文档_v2发送给王五")
251
+ model_type = gr.Radio(["bert", "chatglm"], label="选择模型", value="bert")
252
  analyze_btn = gr.Button("开始分析", variant="primary")
253
 
254
  with gr.Row():
255
+ entity_output = gr.Textbox(label="识别的实体", lines=6)
256
+ relation_output = gr.Textbox(label="提取的关系", lines=6)
257
  kg_output = gr.HTML(label="知识图谱")
258
+ time_output = gr.Textbox(label="处理时间")
 
 
 
 
 
 
 
259
 
260
  with gr.Tab("📄 文件分析"):
261
+ file_input = gr.File(label="选择文件", file_types=[".txt", ".csv", ".json"])
 
262
  analyze_file_btn = gr.Button("开始分析文件", variant="primary")
263
+ file_entity_output = gr.Textbox(label="识别的实体", lines=6)
264
+ file_relation_output = gr.Textbox(label="提取的关系", lines=6)
265
  file_kg_output = gr.HTML(label="知识图谱")
266
+ file_time_output = gr.Textbox(label="处理时间")
267
 
268
+ analyze_btn.click(process_text, [input_text, model_type],
269
+ [entity_output, relation_output, kg_output, time_output])
270
+ analyze_file_btn.click(process_file, [file_input, model_type],
271
+ [file_entity_output, file_relation_output, file_kg_output, file_time_output])
 
 
272
 
273
+ demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
requirements.txt CHANGED
@@ -1,10 +1,11 @@
1
  gradio==3.50.2
2
  transformers==4.39.3
3
  torch>=2.1.0
4
- accelerate>=0.27.0
5
- sentencepiece>=0.2.0
6
  pandas>=2.0.0
7
  chardet>=5.0.0
8
  networkx>=3.0
9
  pyvis>=0.3.2
10
- python-dotenv>=1.0.0
 
 
 
 
1
  gradio==3.50.2
2
  transformers==4.39.3
3
  torch>=2.1.0
 
 
4
  pandas>=2.0.0
5
  chardet>=5.0.0
6
  networkx>=3.0
7
  pyvis>=0.3.2
8
+ python-dotenv>=1.0.0
9
+ sentencepiece>=0.2.0
10
+ cpm-kernels>=1.0.11
11
+ accelerate>=0.27.0