chen666-666 committed on
Commit
4affd42
·
verified ·
1 Parent(s): e4ec800

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +4 -5
utils.py CHANGED
@@ -4,7 +4,6 @@ import re
4
  from collections import defaultdict
5
  from transformers import (
6
  AutoTokenizer,
7
- AutoModelForTokenClassification,
8
  AutoModelForSequenceClassification,
9
  pipeline,
10
  )
@@ -15,14 +14,14 @@ from pyvis.network import Network
15
  # 实体识别模型(NER)
16
  # -------------------------------
17
  ner_tokenizer = AutoTokenizer.from_pretrained("ckiplab/bert-base-chinese-ner")
18
- ner_model = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
19
  ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
20
 
21
 
22
  # -------------------------------
23
- # 人物关系分类模型(BERT 分类器)
24
  # -------------------------------
25
- rel_model_name = "uer/roberta-base-finetuned-baike-chinese-relation-extraction"
26
  rel_tokenizer = AutoTokenizer.from_pretrained(rel_model_name)
27
  rel_model = AutoModelForSequenceClassification.from_pretrained(rel_model_name)
28
  rel_model.eval()
@@ -160,4 +159,4 @@ def analyze_chat(file):
160
 
161
  graph_html = draw_graph(entities, relations)
162
 
163
- return str(entities), str(relations), graph_html, "\n".join(illegal_behavior_results)
 
4
  from collections import defaultdict
5
  from transformers import (
6
  AutoTokenizer,
 
7
  AutoModelForSequenceClassification,
8
  pipeline,
9
  )
 
14
  # 实体识别模型(NER)
15
  # -------------------------------
16
  ner_tokenizer = AutoTokenizer.from_pretrained("ckiplab/bert-base-chinese-ner")
17
+ ner_model = AutoModelForSequenceClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
18
  ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
19
 
20
 
21
  # -------------------------------
22
+ # 人物关系分类模型(使用 RoBERTa)
23
  # -------------------------------
24
+ rel_model_name = "hfl/chinese-roberta-wwm-ext" # 推荐的中文 RoBERTa 模型
25
  rel_tokenizer = AutoTokenizer.from_pretrained(rel_model_name)
26
  rel_model = AutoModelForSequenceClassification.from_pretrained(rel_model_name)
27
  rel_model.eval()
 
159
 
160
  graph_html = draw_graph(entities, relations)
161
 
162
+ return str(entities), str(relations), graph_html, "\n".join(illegal_behavior_results)