chen666-666 committed on
Commit f305260 · 1 Parent(s): 0bb16d9

add app.py and requirements.txt

Files changed (1): app.py (+45 -43)
app.py CHANGED
@@ -1,5 +1,5 @@
 import torch
-from transformers import AutoTokenizer, AutoModel, BertTokenizer, BertModel
+from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
 import gradio as gr
 import re
 import os
@@ -10,8 +10,9 @@ import time
 
 # ======================== Model loading ========================
 bert_model_name = "bert-base-chinese"
-bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
-bert_model = BertModel.from_pretrained(bert_model_name)
+bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
+bert_ner_model = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
+bert_ner_pipeline = pipeline("ner", model=bert_ner_model, tokenizer=bert_tokenizer, aggregation_strategy="simple")
 
 chatglm_model, chatglm_tokenizer = None, None
 use_chatglm = False
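
Note: a minimal sketch of what the new pipeline call returns, assuming the ckiplab/bert-base-chinese-ner weights can be downloaded (the sample sentence, labels and scores below are illustrative only). With aggregation_strategy="simple", sub-token predictions are merged into spans, which is why each result dict carries the word/start/end/entity_group keys that the rewritten ner() reads further down.

from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
ner_pipe = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")

print(ner_pipe("张伟在北京的阿里巴巴工作"))
# Illustrative output shape only:
# [{'entity_group': 'PERSON', 'word': '张 伟', 'score': 0.99, 'start': 0, 'end': 2}, ...]
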
@@ -50,32 +51,25 @@ def ner(text, model_type="bert"):
     if model_type == "chatglm" and use_chatglm:
         try:
             prompt = f"请从以下文本中识别所有实体,用JSON格式返回:[{text}]"
-            response, _ = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
+            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
+            if isinstance(response, tuple):
+                response = response[0]
             entities = json.loads(response)
             return entities, time.time() - start_time
         except Exception as e:
             print(f"❌ ChatGLM 实体识别失败:{e}")
             return [], time.time() - start_time
 
-    name_pattern = r"([\u4e00-\u9fa5]{2,4})(?![的等地得啦啊哦])"
-    id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![\u4e00-\u9fa5])"
-    entities, occupied = [], set()
-
-    def is_occupied(start, end):
-        return any(s <= start < e or s < end <= e for s, e in occupied)
-
-    for match in re.finditer(name_pattern, text):
-        start, end = match.start(1), match.end(1)
-        if not is_occupied(start, end):
-            entities.append({"text": match.group(1), "start": start, "end": end, "type": "人名"})
-            occupied.add((start, end))
-
-    for match in re.finditer(id_pattern, text):
-        start, end = match.start(1), match.end(1)
-        if not is_occupied(start, end):
-            entities.append({"text": match.group(1), "start": start, "end": end, "type": "用户ID"})
-            occupied.add((start, end))
-
+    # Use the fine-tuned Chinese BERT NER model
+    raw_results = bert_ner_pipeline(text)
+    entities = []
+    for r in raw_results:
+        entities.append({
+            "text": r["word"],
+            "start": r["start"],
+            "end": r["end"],
+            "type": r["entity_group"]
+        })
     return entities, time.time() - start_time
 
 # ======================== Relation extraction (RE) ========================
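
Note: the isinstance(response, tuple) guard added above reflects that most ChatGLM builds return a (reply, history) pair from chat(), while some wrappers return just the string. If the reply wraps its JSON in Markdown fences or extra prose, json.loads() will still fail; a hypothetical helper (not part of app.py) that tolerates that is sketched below.

import json
import re

def extract_json_list(reply: str):
    # Hypothetical helper, not in app.py: pull out the first [...] span so that
    # fenced or chatty ChatGLM replies still parse as a JSON entity list.
    match = re.search(r"\[.*\]", reply, re.S)
    return json.loads(match.group(0)) if match else []
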
@@ -86,7 +80,9 @@ def re_extract(entities, text):
         entity_list = [e['text'] for e in entities]
         prompt = f"分析以下实体之间的关系:{entity_list}\n文本上下文:{text}"
         if use_chatglm:
-            response, _ = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
+            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
+            if isinstance(response, tuple):
+                response = response[0]
             return json.loads(response)
     except Exception as e:
         print(f"❌ ChatGLM 关系抽取失败:{e}")
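
Note: a hedged usage sketch of how the two functions chain together, assuming the ChatGLM branch is active (use_chatglm is True); the sentence is made up, and re_extract() only needs the "text" key of each entity dict produced by ner().

sample = "张伟在阿里巴巴担任工程师"
entities, elapsed = ner(sample, model_type="chatglm")
relations = re_extract(entities, sample)  # parsed from ChatGLM's JSON reply
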
@@ -103,32 +99,38 @@ def process_text(text, model_type="bert"):
     kg_text = visualize_kg_text()
     return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"
 
-def process_file(file, model_type="bert"):
-    content = file.read()
-    if len(content) > 5 * 1024 * 1024:
-        return "❌ 文件太大", "", "", ""
-    encoding = chardet.detect(content)['encoding'] or 'utf-8'
-    text = content.decode(encoding)
-    return process_text(text, model_type)
-
 # ======================== Model evaluation & auto-annotation ========================
 def convert_telegram_json_to_eval_format(path):
     with open(path, encoding="utf-8") as f:
         data = json.load(f)
-    result = []
-    for m in data.get("messages", []):
-        if isinstance(m.get("text"), str):
-            result.append({"text": m["text"], "entities": []})
-        elif isinstance(m.get("text"), list):
-            txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
-            result.append({"text": txt, "entities": []})
-    return result
+    if isinstance(data, dict) and "text" in data:
+        return [{"text": data["text"], "entities": [
+            {"text": data["text"][e["start"]:e["end"]]} for e in data.get("entities", [])
+        ]}]
+    elif isinstance(data, list):
+        return data
+    elif isinstance(data, dict) and "messages" in data:
+        result = []
+        for m in data.get("messages", []):
+            if isinstance(m.get("text"), str):
+                result.append({"text": m["text"], "entities": []})
+            elif isinstance(m.get("text"), list):
+                txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
+                result.append({"text": txt, "entities": []})
+        return result
+    return []
 
 def evaluate_ner_model(data, model_type):
     y_true, y_pred = [], []
     for item in data:
-        gold = set(e['text'] for e in item['entities'])
-        pred, _ = ner(item['text'], model_type)
+        text = item["text"]
+        gold = set()
+        for e in item.get("entities", []):
+            if "text" in e:
+                gold.add(e["text"])
+            elif "start" in e and "end" in e:
+                gold.add(text[e["start"]:e["end"]])
+        pred, _ = ner(text, model_type)
         pred = set(e['text'] for e in pred)
         for ent in gold.union(pred):
             y_true.append(1 if ent in gold else 0)
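
Note: an illustrative input for the extended converter, assuming a Telegram Desktop JSON export whose messages[].text field is either a plain string or a list of string/dict fragments; the file name and message contents below are made up.

import json

sample_export = {
    "messages": [
        {"text": "张伟在北京工作"},
        {"text": [{"type": "mention", "text": "@zhangwei"}, " 加入了群聊"]},
    ]
}
with open("telegram_export.json", "w", encoding="utf-8") as f:
    json.dump(sample_export, f, ensure_ascii=False)

print(convert_telegram_json_to_eval_format("telegram_export.json"))
# -> [{'text': '张伟在北京工作', 'entities': []},
#     {'text': '@zhangwei 加入了群聊', 'entities': []}]
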
@@ -143,7 +145,7 @@ def auto_annotate(file, model_type):
     return json.dumps(data, ensure_ascii=False, indent=2)
 
 def save_json(json_text):
-    fname = "auto_labeled.json"
+    fname = f"auto_labeled_{int(time.time())}.json"
    with open(fname, "w", encoding="utf-8") as f:
        f.write(json_text)
    return fname
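
Note: with the timestamped name, each export now writes a distinct file instead of overwriting auto_labeled.json; a quick check (the payload is made up):

path = save_json('[{"text": "张伟", "type": "PERSON"}]')
print(path)  # e.g. auto_labeled_1712345678.json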
 