Spaces:
Sleeping
Sleeping
Commit
·
f305260
1
Parent(s):
0bb16d9
add app.py and requirements.txt
Browse files
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import torch
|
2 |
-
from transformers import AutoTokenizer,
|
3 |
import gradio as gr
|
4 |
import re
|
5 |
import os
|
@@ -10,8 +10,9 @@ import time
|
|
10 |
|
11 |
# ======================== 模型加载 ========================
|
12 |
bert_model_name = "bert-base-chinese"
|
13 |
-
bert_tokenizer =
|
14 |
-
|
|
|
15 |
|
16 |
chatglm_model, chatglm_tokenizer = None, None
|
17 |
use_chatglm = False
|
@@ -50,32 +51,25 @@ def ner(text, model_type="bert"):
|
|
50 |
if model_type == "chatglm" and use_chatglm:
|
51 |
try:
|
52 |
prompt = f"请从以下文本中识别所有实体,用JSON格式返回:[{text}]"
|
53 |
-
response
|
|
|
|
|
54 |
entities = json.loads(response)
|
55 |
return entities, time.time() - start_time
|
56 |
except Exception as e:
|
57 |
print(f"❌ ChatGLM 实体识别失败:{e}")
|
58 |
return [], time.time() - start_time
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
entities
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
entities.append({"text": match.group(1), "start": start, "end": end, "type": "人名"})
|
71 |
-
occupied.add((start, end))
|
72 |
-
|
73 |
-
for match in re.finditer(id_pattern, text):
|
74 |
-
start, end = match.start(1), match.end(1)
|
75 |
-
if not is_occupied(start, end):
|
76 |
-
entities.append({"text": match.group(1), "start": start, "end": end, "type": "用户ID"})
|
77 |
-
occupied.add((start, end))
|
78 |
-
|
79 |
return entities, time.time() - start_time
|
80 |
|
81 |
# ======================== 关系抽取(RE) ========================
|
@@ -86,7 +80,9 @@ def re_extract(entities, text):
|
|
86 |
entity_list = [e['text'] for e in entities]
|
87 |
prompt = f"分析以下实体之间的关系:{entity_list}\n文本上下文:{text}"
|
88 |
if use_chatglm:
|
89 |
-
response
|
|
|
|
|
90 |
return json.loads(response)
|
91 |
except Exception as e:
|
92 |
print(f"❌ ChatGLM 关系抽取失败:{e}")
|
@@ -103,32 +99,38 @@ def process_text(text, model_type="bert"):
|
|
103 |
kg_text = visualize_kg_text()
|
104 |
return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"
|
105 |
|
106 |
-
def process_file(file, model_type="bert"):
    """Read an uploaded file, decode it to text and run the text pipeline.

    Args:
        file: Object exposing ``read()``; the result is passed to
            ``chardet.detect`` and ``.decode``, so it is expected to be bytes.
        model_type: Forwarded unchanged to ``process_text``.

    Returns:
        Whatever ``process_text`` returns, or a 4-tuple whose first element is
        an error message when the payload exceeds 5 MiB.
    """
    content = file.read()
    if len(content) > 5 * 1024 * 1024:
        return "❌ 文件太大", "", "", ""

    # chardet may report no encoding at all; default to UTF-8 in that case.
    encoding = chardet.detect(content)['encoding'] or 'utf-8'
    try:
        text = content.decode(encoding)
    except (UnicodeDecodeError, LookupError):
        # The detector's guess can be wrong or name a codec Python does not
        # know; degrade to lossy UTF-8 instead of failing the whole request.
        text = content.decode('utf-8', errors='replace')
    return process_text(text, model_type)
|
113 |
-
|
114 |
# ======================== 模型评估与自动标注 ========================
|
115 |
def convert_telegram_json_to_eval_format(path):
|
116 |
with open(path, encoding="utf-8") as f:
|
117 |
data = json.load(f)
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
def evaluate_ner_model(data, model_type):
|
128 |
y_true, y_pred = [], []
|
129 |
for item in data:
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
pred = set(e['text'] for e in pred)
|
133 |
for ent in gold.union(pred):
|
134 |
y_true.append(1 if ent in gold else 0)
|
@@ -143,7 +145,7 @@ def auto_annotate(file, model_type):
|
|
143 |
return json.dumps(data, ensure_ascii=False, indent=2)
|
144 |
|
145 |
def save_json(json_text):
|
146 |
-
fname = "
|
147 |
with open(fname, "w", encoding="utf-8") as f:
|
148 |
f.write(json_text)
|
149 |
return fname
|
|
|
1 |
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
|
3 |
import gradio as gr
|
4 |
import re
|
5 |
import os
|
|
|
10 |
|
11 |
# ======================== 模型加载 ========================
|
12 |
bert_model_name = "bert-base-chinese"
|
13 |
+
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
|
14 |
+
bert_ner_model = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
|
15 |
+
bert_ner_pipeline = pipeline("ner", model=bert_ner_model, tokenizer=bert_tokenizer, aggregation_strategy="simple")
|
16 |
|
17 |
chatglm_model, chatglm_tokenizer = None, None
|
18 |
use_chatglm = False
|
|
|
51 |
if model_type == "chatglm" and use_chatglm:
|
52 |
try:
|
53 |
prompt = f"请从以下文本中识别所有实体,用JSON格式返回:[{text}]"
|
54 |
+
response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
|
55 |
+
if isinstance(response, tuple):
|
56 |
+
response = response[0]
|
57 |
entities = json.loads(response)
|
58 |
return entities, time.time() - start_time
|
59 |
except Exception as e:
|
60 |
print(f"❌ ChatGLM 实体识别失败:{e}")
|
61 |
return [], time.time() - start_time
|
62 |
|
63 |
+
# 使用微调的 BERT 中文 NER 模型
|
64 |
+
raw_results = bert_ner_pipeline(text)
|
65 |
+
entities = []
|
66 |
+
for r in raw_results:
|
67 |
+
entities.append({
|
68 |
+
"text": r["word"],
|
69 |
+
"start": r["start"],
|
70 |
+
"end": r["end"],
|
71 |
+
"type": r["entity_group"]
|
72 |
+
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
return entities, time.time() - start_time
|
74 |
|
75 |
# ======================== 关系抽取(RE) ========================
|
|
|
80 |
entity_list = [e['text'] for e in entities]
|
81 |
prompt = f"分析以下实体之间的关系:{entity_list}\n文本上下文:{text}"
|
82 |
if use_chatglm:
|
83 |
+
response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
|
84 |
+
if isinstance(response, tuple):
|
85 |
+
response = response[0]
|
86 |
return json.loads(response)
|
87 |
except Exception as e:
|
88 |
print(f"❌ ChatGLM 关系抽取失败:{e}")
|
|
|
99 |
kg_text = visualize_kg_text()
|
100 |
return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
# ======================== 模型评估与自动标注 ========================
|
103 |
def convert_telegram_json_to_eval_format(path):
    """Load a JSON file and normalise it to ``[{"text", "entities"}, ...]``.

    Three input layouts are recognised:

    * a dict with a ``"text"`` key: one document; each entity is reduced to
      ``{"text": ...}``, taken from its ``start``/``end`` offsets when both
      are present, otherwise from its own ``"text"`` value;
    * a list: returned unchanged;
    * a dict with a ``"messages"`` key: one record per message whose
      ``"text"`` is a string or a list of fragments, each with an empty
      entity list.

    Args:
        path: Path of a UTF-8 encoded JSON file.

    Returns:
        A list of records; ``[]`` when the layout is not recognised.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, list):
        return data
    if not isinstance(data, dict):
        return []

    if "text" in data:
        text = data["text"]
        entities = []
        for e in data.get("entities") or []:
            if "start" in e and "end" in e:
                entities.append({"text": text[e["start"]:e["end"]]})
            elif "text" in e:
                # Entities that already carry their surface text have no
                # offsets to slice with; keep the text as given.
                entities.append({"text": e["text"]})
        return [{"text": text, "entities": entities}]

    if "messages" in data:
        result = []
        for m in data.get("messages") or []:
            if not isinstance(m, dict):
                continue
            body = m.get("text")
            if isinstance(body, str):
                result.append({"text": body, "entities": []})
            elif isinstance(body, list):
                # Fragments are either plain strings or dicts holding the
                # fragment under "text"; tolerate dicts without that key.
                txt = ''.join(
                    x.get("text", "") if isinstance(x, dict) else str(x)
                    for x in body
                )
                result.append({"text": txt, "entities": []})
        return result

    return []
|
122 |
|
123 |
def evaluate_ner_model(data, model_type):
|
124 |
y_true, y_pred = [], []
|
125 |
for item in data:
|
126 |
+
text = item["text"]
|
127 |
+
gold = set()
|
128 |
+
for e in item.get("entities", []):
|
129 |
+
if "text" in e:
|
130 |
+
gold.add(e["text"])
|
131 |
+
elif "start" in e and "end" in e:
|
132 |
+
gold.add(text[e["start"]:e["end"]])
|
133 |
+
pred, _ = ner(text, model_type)
|
134 |
pred = set(e['text'] for e in pred)
|
135 |
for ent in gold.union(pred):
|
136 |
y_true.append(1 if ent in gold else 0)
|
|
|
145 |
return json.dumps(data, ensure_ascii=False, indent=2)
|
146 |
|
147 |
def save_json(json_text):
    """Write *json_text* to a new UTF-8 ``.json`` file in the working directory.

    The name keeps the ``auto_labeled_<unix-seconds>`` prefix and adds a
    unique suffix, so two saves within the same second cannot overwrite each
    other.

    Args:
        json_text: The text to write.

    Returns:
        The bare file name of the file that was created, relative to the
        current working directory.
    """
    import tempfile  # local import: only this function needs it

    # mkstemp creates the file atomically under a name that did not exist.
    fd, path = tempfile.mkstemp(
        prefix=f"auto_labeled_{int(time.time())}_",
        suffix=".json",
        dir=os.getcwd(),
    )
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(json_text)
    return os.path.basename(path)
|