chen666-666 committed
Commit 26ec260 · 1 Parent(s): 1a6560a

add app.py and requirements.txt

Files changed (2)
  1. app.py +107 -186
  2. requirements.txt +5 -4
app.py CHANGED
@@ -8,32 +8,20 @@ import pandas as pd
 import chardet
 from pyvis.network import Network
 import time

-# Initialize the models
 bert_model_name = "bert-base-chinese"
 bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
 bert_model = BertModel.from_pretrained(bert_model_name)

-# Load the Chinese ChatGLM3-6B model
 chatglm_model_name = "THUDM/chatglm3-6b"
-chatglm_tokenizer = AutoTokenizer.from_pretrained(
-    chatglm_model_name,
-    trust_remote_code=True
-)
-chatglm_model = AutoModel.from_pretrained(
-    chatglm_model_name,
-    trust_remote_code=True,
-    device_map="auto",
-    torch_dtype=torch.float16
-).eval()
-
-# Knowledge-graph data store
-knowledge_graph = {
-    "entities": set(),
-    "relations": []
-}

 def update_knowledge_graph(entities, relations):
     for e in entities:
         if isinstance(e, dict) and 'text' in e and 'type' in e:
@@ -42,232 +30,165 @@ def update_knowledge_graph(entities, relations):
         if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
             knowledge_graph["relations"].append((r['head'], r['tail'], r['relation']))

-
 def visualize_kg():
     net = Network(height="600px", width="100%", notebook=True, directed=True)
     node_map = {}
     idx = 0
     for ent in knowledge_graph["entities"]:
-        if isinstance(ent, tuple) and len(ent) == 2:
-            name, type_ = ent
-            node_map[name] = idx
-            net.add_node(idx,
-                         label=name,
-                         title=f"类型:{type_}",
-                         group=type_,
-                         font={'size': 20, 'face': 'SimHei'})
-            idx += 1
-
     seen_edges = set()
     for head, tail, relation in knowledge_graph["relations"]:
         if head in node_map and tail in node_map:
             edge_key = f"{head}-{tail}-{relation}"
             if edge_key not in seen_edges:
-                net.add_edge(node_map[head], node_map[tail],
-                             label=relation,
-                             arrows='to',
-                             font={'size': 14})
                 seen_edges.add(edge_key)
-
-    net.set_options("""
-    {
-      "nodes": {
-        "scaling": {
-          "min": 20,
-          "max": 40
-        }
-      },
-      "physics": {
-        "stabilization": {
-          "enabled": true,
-          "iterations": 200,
-          "updateInterval": 25
-        },
-        "barnesHut": {
-          "gravitationalConstant": -2000,
-          "springLength": 150
-        }
-      },
-      "interaction": {
-        "hover": true,
-        "tooltipDelay": 200
-      }
-    }
-    """)
-
-    html = net.generate_html()
-    html = html.replace('//cdnjs.cloudflare.com', 'https://cdnjs.cloudflare.com')
-    html = html.replace('//unpkg.com', 'https://unpkg.com')
     return f'<div class="kg-graph">{html}</div>'

-
 def ner(text, model_type="bert"):
     start_time = time.time()
     if model_type == "bert":
-        # BERT Chinese entity recognition (original logic kept)
        name_pattern = r"([\u4e00-\u9fa5]{2,4})(?![的等地得啦啊哦])"
-        id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![\\u4e00-\\u9fa5])"
     else:
-        # ChatGLM-enhanced entity recognition
-        response, _ = chatglm_model.chat(
-            chatglm_tokenizer,
-            f"请从以下文本中识别所有实体,用JSON格式返回:[{text}]",
-            temperature=0.1
-        )
         try:
             entities = json.loads(response)
             return entities, time.time() - start_time
         except:
-            pass
-
-        # Fall back to the regexes if the model response fails
-        name_pattern = r"([\\u4e00-\\u9fa5]{2,4})(?![的等地得啦啊哦])"
-        id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})"

     entities = []
     occupied = set()
-
     def is_occupied(start, end):
         return any(s <= start < e or s < end <= e for s, e in occupied)

     for match in re.finditer(name_pattern, text):
         start, end = match.start(1), match.end(1)
         if not is_occupied(start, end):
-            entities.append({
-                "text": match.group(1),
-                "start": start,
-                "end": end,
-                "type": "人名"
-            })
             occupied.add((start, end))

     for match in re.finditer(id_pattern, text):
         start, end = match.start(1), match.end(1)
         if not is_occupied(start, end):
-            entities.append({
-                "text": match.group(1),
-                "start": start,
-                "end": end,
-                "type": "用户ID"
-            })
             occupied.add((start, end))

-    processing_time = time.time() - start_time
-    return entities, processing_time
-

 def re_extract(entities, text):
     relations = []
     if len(entities) < 2:
         return relations
-
-    # Use ChatGLM to analyze the relations
     entity_list = [e['text'] for e in entities]
     prompt = f"分析以下实体之间的关系:{entity_list}\n文本上下文:{text}"
     response, _ = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
-
     try:
         relations = json.loads(response)
     except:
-        # Fallback: naive pairwise relation generation
         for i in range(len(entities)):
             for j in range(i + 1, len(entities)):
-                relations.append({
-                    "head": entities[i]['text'],
-                    "tail": entities[j]['text'],
-                    "relation": "相关"
-                })
-
     return relations

-
 def process_text(text, model_type="bert"):
-    try:
-        entities, processing_time = ner(text, model_type=model_type)
-        relations = re_extract(entities, text)
-        update_knowledge_graph(entities, relations)
-
-        entity_output = "\n".join(
-            f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]"
-            for e in entities
-        )
-        relation_output = "\n".join(
-            f"{r['head']} --[{r['relation']}]-> {r['tail']}"
-            for r in relations
-        )
-        kg_html = visualize_kg()
-
-        return entity_output, relation_output, gr.HTML(kg_html), f"处理时间:{processing_time:.2f}秒"
-
-    except Exception as e:
-        return f"处理出错: {str(e)}", "", gr.HTML(), ""
-

 def process_file(file, model_type="bert"):
-    try:
-        content_bytes = file.read()
-        if len(content_bytes) > 5 * 1024 * 1024:
-            return "❌ 文件大小超过5MB限制", "", gr.HTML(), ""
-
-        encoding = chardet.detect(content_bytes)['encoding'] or 'utf-8'
-        full_text = content_bytes.decode(encoding)
-        ext = os.path.splitext(file.name)[-1].lower()
-
-        if ext == ".csv":
-            df = pd.read_csv(file.name)
-            if 'text' in df.columns:
-                full_text = "\n".join(df['text'].astype(str))
-            else:
-                return "❌ CSV文件中缺少text", "", gr.HTML(), ""
-
-        return process_text(full_text, model_type)
-
-    except Exception as e:
-        return f"❌ 文件处理错误: {str(e)}", "", gr.HTML(), ""
-
-
-# Gradio UI
-css = """
-.kg-container {
-    border: 1px solid #e0e0e0;
-    border-radius: 10px;
-    padding: 20px;
-    margin: 20px 0;
-    background: white;
-    box-shadow: 0 2px 8px rgba(0,0,0,0.1);
-}
-.kg-graph {
-    width: 100%;
-    height: 600px;
-}
-"""
-
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("# 🚀 智能聊天记录分析系统(ChatGLM3-6B版)")

     with gr.Tab("✍️ 文本分析"):
-        input_text = gr.Textbox(label="输入内容", lines=8,
-                                placeholder="示例:张三@李四 请把需求文档_v2发送给王五")
-        model_type = gr.Radio(["bert", "chatglm"], label="选择模型", value="bert")
-        analyze_btn = gr.Button("开始分析", variant="primary")
-
-        with gr.Row():
-            entity_output = gr.Textbox(label="识别的实体", lines=6)
-            relation_output = gr.Textbox(label="提取的关系", lines=6)
-        kg_output = gr.HTML(label="知识图谱")
-        time_output = gr.Textbox(label="处理时间")

     with gr.Tab("📄 文件分析"):
-        file_input = gr.File(label="选择文件", file_types=[".txt", ".csv", ".json"])
-        analyze_file_btn = gr.Button("开始分析文件", variant="primary")
-        file_entity_output = gr.Textbox(label="识别的实体", lines=6)
-        file_relation_output = gr.Textbox(label="提取的关系", lines=6)
-        file_kg_output = gr.HTML(label="知识图谱")
-        file_time_output = gr.Textbox(label="处理时间")
-
-    analyze_btn.click(process_text, [input_text, model_type],
-                      [entity_output, relation_output, kg_output, time_output])
-    analyze_file_btn.click(process_file, [file_input, model_type],
-                           [file_entity_output, file_relation_output, file_kg_output, file_time_output])
-
-demo.launch(server_name="0.0.0.0", server_port=7860)
 import chardet
 from pyvis.network import Network
 import time
+from sklearn.metrics import precision_score, recall_score, f1_score

+# ==== Model initialization ====
 bert_model_name = "bert-base-chinese"
 bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
 bert_model = BertModel.from_pretrained(bert_model_name)

 chatglm_model_name = "THUDM/chatglm3-6b"
+chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
+chatglm_model = AutoModel.from_pretrained(chatglm_model_name, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16).eval()

+knowledge_graph = {"entities": set(), "relations": []}

+# ==== Core processing functions ====
 def update_knowledge_graph(entities, relations):
     for e in entities:
         if isinstance(e, dict) and 'text' in e and 'type' in e:
             knowledge_graph["entities"].add((e['text'], e['type']))
     for r in relations:
         if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
             knowledge_graph["relations"].append((r['head'], r['tail'], r['relation']))

 def visualize_kg():
     net = Network(height="600px", width="100%", notebook=True, directed=True)
     node_map = {}
     idx = 0
     for ent in knowledge_graph["entities"]:
+        name, type_ = ent
+        node_map[name] = idx
+        net.add_node(idx, label=name, title=f"类型:{type_}", group=type_, font={'size': 20, 'face': 'SimHei'})
+        idx += 1
     seen_edges = set()
     for head, tail, relation in knowledge_graph["relations"]:
         if head in node_map and tail in node_map:
             edge_key = f"{head}-{tail}-{relation}"
             if edge_key not in seen_edges:
+                net.add_edge(node_map[head], node_map[tail], label=relation, arrows='to', font={'size': 14})
                 seen_edges.add(edge_key)
+    net.set_options("""{
+        "nodes": {"scaling": {"min": 20, "max": 40}},
+        "physics": {"stabilization": {"enabled": true, "iterations": 200}, "barnesHut": {"gravitationalConstant": -2000, "springLength": 150}},
+        "interaction": {"hover": true, "tooltipDelay": 200}
+    }""")
+    html = net.generate_html().replace('//cdnjs.cloudflare.com', 'https://cdnjs.cloudflare.com').replace('//unpkg.com', 'https://unpkg.com')
     return f'<div class="kg-graph">{html}</div>'

 def ner(text, model_type="bert"):
     start_time = time.time()
     if model_type == "bert":
         name_pattern = r"([\u4e00-\u9fa5]{2,4})(?![的等地得啦啊哦])"
+        id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![\u4e00-\u9fa5])"
     else:
+        response, _ = chatglm_model.chat(chatglm_tokenizer, f"请从以下文本中识别所有实体,用JSON格式返回:[{text}]", temperature=0.1)
         try:
             entities = json.loads(response)
             return entities, time.time() - start_time
         except:
+            name_pattern = r"([\u4e00-\u9fa5]{2,4})(?![的等地得啦啊哦])"
+            id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})"

     entities = []
     occupied = set()

     def is_occupied(start, end):
         return any(s <= start < e or s < end <= e for s, e in occupied)

     for match in re.finditer(name_pattern, text):
         start, end = match.start(1), match.end(1)
         if not is_occupied(start, end):
+            entities.append({"text": match.group(1), "start": start, "end": end, "type": "人名"})
             occupied.add((start, end))

     for match in re.finditer(id_pattern, text):
         start, end = match.start(1), match.end(1)
         if not is_occupied(start, end):
+            entities.append({"text": match.group(1), "start": start, "end": end, "type": "用户ID"})
             occupied.add((start, end))

+    return entities, time.time() - start_time

 def re_extract(entities, text):
     relations = []
     if len(entities) < 2:
         return relations
     entity_list = [e['text'] for e in entities]
     prompt = f"分析以下实体之间的关系:{entity_list}\n文本上下文:{text}"
     response, _ = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
     try:
         relations = json.loads(response)
     except:
         for i in range(len(entities)):
             for j in range(i + 1, len(entities)):
+                relations.append({"head": entities[i]['text'], "tail": entities[j]['text'], "relation": "相关"})
     return relations

 def process_text(text, model_type="bert"):
+    entities, processing_time = ner(text, model_type)
+    relations = re_extract(entities, text)
+    update_knowledge_graph(entities, relations)
+    entity_output = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
+    relation_output = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
+    return entity_output, relation_output, gr.HTML(visualize_kg()), f"处理时间:{processing_time:.2f}秒"
 
 def process_file(file, model_type="bert"):
+    content_bytes = file.read()
+    if len(content_bytes) > 5 * 1024 * 1024:
+        return "❌ 文件太大", "", gr.HTML(), ""
+    encoding = chardet.detect(content_bytes)['encoding'] or 'utf-8'
+    text = content_bytes.decode(encoding)
+    return process_text(text, model_type)

+# ==== Evaluation and auto-annotation ====
+def convert_telegram_json_to_eval_format(path):
+    data = json.load(open(path, encoding="utf-8"))
+    result = []
+    for m in data.get("messages", []):
+        if isinstance(m.get("text"), str):
+            result.append({"text": m["text"], "entities": []})
+        elif isinstance(m.get("text"), list):
+            txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
+            result.append({"text": txt, "entities": []})
+    return result

+def evaluate_ner_model(data, model_type):
+    y_true, y_pred = [], []
+    for item in data:
+        gold = set(e['text'] for e in item['entities'])
+        pred, _ = ner(item['text'], model_type)
+        pred = set(e['text'] for e in pred)
+        for ent in gold.union(pred):
+            y_true.append(1 if ent in gold else 0)
+            y_pred.append(1 if ent in pred else 0)
+    return f"📊 {model_type} 实体识别评估:\nPrecision: {precision_score(y_true, y_pred):.2f}\nRecall: {recall_score(y_true, y_pred):.2f}\nF1: {f1_score(y_true, y_pred):.2f}"

+def auto_annotate(file, model_type):
+    data = convert_telegram_json_to_eval_format(file.name)
+    for item in data:
+        ents, _ = ner(item["text"], model_type)
+        item["entities"] = ents
+    return json.dumps(data, ensure_ascii=False, indent=2)

+def save_json(json_text):
+    fname = "auto_labeled.json"
+    with open(fname, "w", encoding="utf-8") as f:
+        f.write(json_text)
+    return fname

+# ==== Gradio UI ====
+css = ".kg-graph { height: 600px; }"
 with gr.Blocks(css=css) as demo:
+    gr.Markdown("# 🚀 智能聊天分析系统 + 标注评估工具")

     with gr.Tab("✍️ 文本分析"):
+        input_text = gr.Textbox(lines=6, label="输入内容")
+        model_type = gr.Radio(["bert", "chatglm"], value="bert", label="模型")
+        btn = gr.Button("开始分析")
+        out1 = gr.Textbox(label="实体")
+        out2 = gr.Textbox(label="关系")
+        out3 = gr.HTML()
+        out4 = gr.Textbox(label="耗时")
+        btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])

     with gr.Tab("📄 文件分析"):
+        file_input = gr.File(file_types=[".txt", ".json", ".csv"])
+        btn2 = gr.Button("分析文件")
+        fout1, fout2, fout3, fout4 = gr.Textbox(), gr.Textbox(), gr.HTML(), gr.Textbox()
+        btn2.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])

+    with gr.Tab("📊 模型评估"):
+        eval_file = gr.File(label="上传标注数据集")
+        eval_model = gr.Radio(["bert", "chatglm"], value="bert")
+        eval_btn = gr.Button("开始评估")
+        eval_output = gr.Textbox(label="评估结果", lines=5)
+        eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m), [eval_file, eval_model], eval_output)

+    with gr.Tab("🖍 实体标注助手"):
+        raw_file = gr.File(label="上传原始 Telegram JSON")
+        auto_model = gr.Radio(["bert", "chatglm"], value="bert")
+        auto_btn = gr.Button("自动初标")
+        marked_texts = gr.Textbox(label="初步标注结果(可下载)", lines=20)
+        download_btn = gr.Button("💾 下载JSON")
+        auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
+        download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())

+demo.launch(server_name="0.0.0.0", server_port=7860)
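
The headline addition in this commit is the entity-level evaluation path: every entity in the union of the gold and predicted sets becomes one binary sample, and precision, recall, and F1 are computed over those samples with scikit-learn. A minimal standalone sketch of that scoring logic (the gold/pred sets below are invented toy values for illustration; no model is needed):

from sklearn.metrics import precision_score, recall_score, f1_score

# Toy data: gold annotations vs. what a model might predict.
gold = {"张三", "李四", "王五"}
pred = {"张三", "李四", "需求文档"}

y_true, y_pred = [], []
for ent in gold.union(pred):  # one binary sample per distinct entity
    y_true.append(1 if ent in gold else 0)
    y_pred.append(1 if ent in pred else 0)

# 2 of 3 predictions are correct and 2 of 3 gold entities are found,
# so precision, recall, and F1 all come out to 0.67 here.
print(f"P={precision_score(y_true, y_pred):.2f} "
      f"R={recall_score(y_true, y_pred):.2f} "
      f"F1={f1_score(y_true, y_pred):.2f}")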
requirements.txt CHANGED
@@ -1,11 +1,12 @@
 gradio==3.50.2
 transformers==4.39.3
 torch>=2.1.0
-pandas>=2.0.0
-chardet>=5.0.0
 networkx>=3.0
-pyvis>=0.3.2
 python-dotenv>=1.0.0
 sentencepiece>=0.2.0
 cpm-kernels>=1.0.11
-accelerate>=0.27.0
 gradio==3.50.2
 transformers==4.39.3
 torch>=2.1.0
 networkx>=3.0
 python-dotenv>=1.0.0
 sentencepiece>=0.2.0
 cpm-kernels>=1.0.11
+accelerate>=0.27.0
+scikit-learn>=1.3.0
+chardet>=5.2.0
+pandas>=2.1.0
+pyvis>=0.3.2
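
The updated requirements.txt mixes exact pins (gradio, transformers) with lower bounds, so a quick import check is a cheap way to confirm the resolved environment before launching the Space. A hypothetical sanity snippet, not part of the commit:

# Verify the pinned stack imports together.
import chardet, gradio, pandas, pyvis, sklearn, torch, transformers

print(gradio.__version__)        # expected: 3.50.2
print(transformers.__version__)  # expected: 4.39.3
print(torch.__version__)         # expected: >= 2.1.0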