JenniferHJF commited on
Commit
329843e
·
verified ·
1 Parent(s): 80fca3b

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +41 -22
agent.py CHANGED
@@ -1,27 +1,46 @@
1
- from transformers import pipeline
 
2
 
3
- # Step 1: 初始化翻译模型(Qwen 微调模型)
4
- translator = pipeline("text-generation", model="JenniferHJF/qwen1.5-emoji-finetuned", max_new_tokens=64)
 
 
 
 
 
 
 
 
5
 
6
- # Step 2: 初始化多个分类模型
7
- available_models = {
8
- "Hate Speech RoBERTa": "facebook/roberta-hate-speech-dynabench",
9
- "Twitter Offensive": "cardiffnlp/twitter-roberta-base-offensive",
10
- "Chinese Sentiment": "uer/roberta-base-finetuned-chinanews-chinese"
11
- }
12
 
13
- classifier_pipes = {
14
- name: pipeline("text-classification", model=repo, truncation=True)
15
- for name, repo in available_models.items()
16
- }
 
 
 
 
 
 
 
17
 
18
- # Step 3: 主处理函数
19
- def classify_emoji_text(text, selected_model):
20
- # 翻译表情
21
- translated_output = translator(f"请将以下句子中的 emoji 和谐音表达翻译为中文:{text}", return_full_text=False)
22
- translated = translated_output[0]["generated_text"].strip()
 
 
 
 
 
23
 
24
- # 分类模型处理
25
- classifier = classifier_pipes[selected_model]
26
- result = classifier(translated)[0]
27
- return translated, result["label"], result["score"]
 
 
 
1
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch

# Step 1: load the fine-tuned emoji-translation model.
# The LLM stage rewrites emoji / homophone slang into plain Chinese text;
# the offensiveness verdict itself comes from the classifier in Step 2.
emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
emoji_tokenizer = AutoTokenizer.from_pretrained(emoji_model_id, trust_remote_code=True)
emoji_model = AutoModelForCausalLM.from_pretrained(
    emoji_model_id,
    device_map="auto",          # let accelerate place weights on available device(s)
    torch_dtype=torch.float16,  # half precision to cut memory; NOTE(review): fp16 may be slow/unsupported on CPU-only hosts — confirm target hardware
    trust_remote_code=True      # Qwen repos ship custom modeling code
)
emoji_model.eval()  # inference only: disable dropout etc.

# Step 2: load the offensive-text classifier (can be swapped for a stronger model).
# device=0 targets the first GPU when available, -1 falls back to CPU.
classifier = pipeline("text-classification", model="unitary/toxic-bert", device=0 if torch.cuda.is_available() else -1)
 
18
def classify_emoji_text(text: str):
    """
    Two-stage offensive-text detection.

    Stage 1: the fine-tuned Qwen model translates emoji and homophone
    slang in ``text`` into plain Chinese.
    Stage 2: the text-classification pipeline scores the translated
    text for offensiveness.

    Args:
        text: raw input that may contain emoji / slang.

    Returns:
        tuple: ``(translated_text, label, score)`` — the emoji-free
        rewrite, the classifier's label, and its confidence score.
    """
    # Stage-1 prompt asks the LLM to TRANSLATE, not to judge: the
    # offensiveness verdict is the stage-2 classifier's job. Asking the
    # LLM for a 冒犯/不冒犯 answer here would feed a verdict token —
    # not a translation — into the classifier.
    prompt = f"请将以下句子中的 emoji 和谐音表达翻译为通顺的中文,只输出翻译后的句子:{text}"

    inputs = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
    with torch.no_grad():
        output_ids = emoji_model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,  # greedy decoding for reproducible output
        )

    # Decode ONLY the newly generated tokens. Decoding the full sequence
    # and splitting on a prompt marker leaks the prompt (and the raw
    # input) into the "translation" handed to the classifier.
    prompt_len = inputs["input_ids"].shape[1]
    translated_text = emoji_tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()
    # Fall back to the raw input if the model produced nothing usable,
    # so the classifier always receives non-empty text.
    if not translated_text:
        translated_text = text

    # Stage 2: offensiveness classification of the cleaned-up text.
    result = classifier(translated_text)[0]
    return translated_text, result["label"], result["score"]