Spaces:
Sleeping
Sleeping
Update utils.py
Browse files
utils.py
CHANGED
@@ -4,7 +4,6 @@ import re
|
|
4 |
from collections import defaultdict
|
5 |
from transformers import (
|
6 |
AutoTokenizer,
|
7 |
-
AutoModelForTokenClassification,
|
8 |
AutoModelForSequenceClassification,
|
9 |
pipeline,
|
10 |
)
|
@@ -15,14 +14,14 @@ from pyvis.network import Network
|
|
15 |
# 实体识别模型(NER)
|
16 |
# -------------------------------
|
17 |
ner_tokenizer = AutoTokenizer.from_pretrained("ckiplab/bert-base-chinese-ner")
|
18 |
-
ner_model =
|
19 |
ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
|
20 |
|
21 |
|
22 |
# -------------------------------
|
23 |
-
#
|
24 |
# -------------------------------
|
25 |
-
rel_model_name = "
|
26 |
rel_tokenizer = AutoTokenizer.from_pretrained(rel_model_name)
|
27 |
rel_model = AutoModelForSequenceClassification.from_pretrained(rel_model_name)
|
28 |
rel_model.eval()
|
@@ -160,4 +159,4 @@ def analyze_chat(file):
|
|
160 |
|
161 |
graph_html = draw_graph(entities, relations)
|
162 |
|
163 |
-
return str(entities), str(relations), graph_html, "\n".join(illegal_behavior_results)
|
|
|
4 |
from collections import defaultdict
|
5 |
from transformers import (
|
6 |
AutoTokenizer,
|
|
|
7 |
AutoModelForSequenceClassification,
|
8 |
pipeline,
|
9 |
)
|
|
|
14 |
# 实体识别模型(NER)
|
15 |
# -------------------------------
|
16 |
ner_tokenizer = AutoTokenizer.from_pretrained("ckiplab/bert-base-chinese-ner")
|
17 |
+
ner_model = AutoModelForSequenceClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
|
18 |
ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
|
19 |
|
20 |
|
21 |
# -------------------------------
|
22 |
+
# 人物关系分类模型(使用 RoBERTa)
|
23 |
# -------------------------------
|
24 |
+
rel_model_name = "hfl/chinese-roberta-wwm-ext" # 推荐的中文 RoBERTa 模型
|
25 |
rel_tokenizer = AutoTokenizer.from_pretrained(rel_model_name)
|
26 |
rel_model = AutoModelForSequenceClassification.from_pretrained(rel_model_name)
|
27 |
rel_model.eval()
|
|
|
159 |
|
160 |
graph_html = draw_graph(entities, relations)
|
161 |
|
162 |
+
return str(entities), str(relations), graph_html, "\n".join(illegal_behavior_results)
|