chen666-666 committed on
Commit 1eeeaf5 · 1 Parent(s): ee61e9e

add app.py and requirements.txt

Files changed (1)
  1. app.py +19 -24
app.py CHANGED
@@ -4,29 +4,23 @@ import gradio as gr
 import re
 import os
 import json
-import pandas as pd
 import chardet
 from sklearn.metrics import precision_score, recall_score, f1_score
 import time

-# ==== 模型初始化 ====
+# ======================== 模型加载 ========================
 bert_model_name = "bert-base-chinese"
 bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
 bert_model = BertModel.from_pretrained(bert_model_name)

-# chatglm3 模型检测与安全加载
 chatglm_model, chatglm_tokenizer = None, None
 use_chatglm = False
-
 try:
     if torch.cuda.is_available():
         chatglm_model_name = "THUDM/chatglm3-6b"
         chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
         chatglm_model = AutoModel.from_pretrained(
-            chatglm_model_name,
-            trust_remote_code=True,
-            device_map="auto",
-            torch_dtype=torch.float16
+            chatglm_model_name, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16
         ).eval()
         use_chatglm = True
     else:
@@ -34,7 +28,7 @@ try:
 except Exception as e:
     print(f"❌ ChatGLM 加载失败: {e}")

-# ==== 知识图谱数据结构 ====
+# ======================== 知识图谱结构 ========================
 knowledge_graph = {"entities": set(), "relations": []}

 def update_knowledge_graph(entities, relations):
@@ -50,16 +44,17 @@ def visualize_kg_text():
     edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
     return "📌 实体:\n" + "\n".join(nodes) + "\n\n📎 关系:\n" + "\n".join(edges)

-# ==== 实体识别函数 ====
+# ======================== 实体识别(NER) ========================
 def ner(text, model_type="bert"):
     start_time = time.time()
     if model_type == "chatglm" and use_chatglm:
         try:
-            response, _ = chatglm_model.chat(chatglm_tokenizer, f"请从以下文本中识别所有实体,用JSON格式返回:[{text}]", temperature=0.1)
+            prompt = f"请从以下文本中识别所有实体,用JSON格式返回:[{text}]"
+            response, _ = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
             entities = json.loads(response)
             return entities, time.time() - start_time
         except Exception as e:
-            print(f"❌ ChatGLM 解析失败:{e}")
+            print(f"❌ ChatGLM 实体识别失败:{e}")
             return [], time.time() - start_time

     name_pattern = r"([\u4e00-\u9fa5]{2,4})(?![的等地得啦啊哦])"
@@ -83,7 +78,7 @@ def ner(text, model_type="bert"):

     return entities, time.time() - start_time

-# ==== 关系抽取 ====
+# ======================== 关系抽取(RE) ========================
 def re_extract(entities, text):
     if len(entities) < 2:
         return []
@@ -93,11 +88,11 @@ def re_extract(entities, text):
         if use_chatglm:
             response, _ = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
             return json.loads(response)
-    except:
-        pass
+    except Exception as e:
+        print(f"❌ ChatGLM 关系抽取失败:{e}")
     return [{"head": e1['text'], "tail": e2['text'], "relation": "相关"} for i, e1 in enumerate(entities) for e2 in entities[i+1:]]

-# ==== 文本处理主流程 ====
+# ======================== 文本分析主流程 ========================
 def process_text(text, model_type="bert"):
     entities, duration = ner(text, model_type)
     relations = re_extract(entities, text)
@@ -106,7 +101,7 @@ def process_text(text, model_type="bert"):
     ent_text = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
     rel_text = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
     kg_text = visualize_kg_text()
-    return ent_text, rel_text, kg_text, f"{duration:.2f}秒"
+    return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"

 def process_file(file, model_type="bert"):
     content = file.read()
@@ -116,7 +111,7 @@ def process_file(file, model_type="bert"):
     text = content.decode(encoding)
     return process_text(text, model_type)

-# ==== 模型评估与初标注 ====
+# ======================== 模型评估与自动标注 ========================
 def convert_telegram_json_to_eval_format(path):
     data = json.load(open(path, encoding="utf-8"))
     result = []
@@ -152,7 +147,7 @@ def save_json(json_text):
     f.write(json_text)
     return fname

-# ==== Gradio 界面 ====
+# ======================== Gradio 界面 ========================
 with gr.Blocks(css=".kg-graph {height: 500px;}") as demo:
     gr.Markdown("# 🤖 聊天记录实体关系识别系统")

@@ -162,7 +157,7 @@ with gr.Blocks(css=".kg-graph {height: 500px;}") as demo:
         btn = gr.Button("开始分析")
         out1 = gr.Textbox(label="识别实体")
         out2 = gr.Textbox(label="识别关系")
-        out3 = gr.Textbox(label="知识图谱文本")
+        out3 = gr.Textbox(label="知识图谱")
         out4 = gr.Textbox(label="耗时")
         btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])

@@ -179,12 +174,12 @@ with gr.Blocks(css=".kg-graph {height: 500px;}") as demo:
         eval_output = gr.Textbox(label="评估结果", lines=5)
         eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m), [eval_file, eval_model], eval_output)

-    with gr.Tab("✏️ 实体自动标注"):
+    with gr.Tab("✏️ 自动标注"):
         raw_file = gr.File(label="上传 Telegram 原始 JSON")
         auto_model = gr.Radio(["bert", "chatglm"], value="bert")
-        auto_btn = gr.Button("自动初标")
-        marked_texts = gr.Textbox(label="自动标注结果", lines=20)
-        download_btn = gr.Button("💾 下载JSON")
+        auto_btn = gr.Button("自动标注")
+        marked_texts = gr.Textbox(label="标注结果", lines=20)
+        download_btn = gr.Button("💾 下载标注文件")
         auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
         download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())
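The commit message also mentions requirements.txt, but that file is not shown in this diff. As a rough sketch only, inferred from the imports visible in app.py (gradio, torch, transformers for BertTokenizer/AutoModel, scikit-learn, chardet) rather than from the committed file, the dependency list might look like the following; pandas is left out because this commit removes the import pandas as pd line, and no version pins are assumed:

    # requirements.txt (inferred sketch, not the committed file)
    gradio
    torch
    transformers
    scikit-learn
    chardet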