|
import openai
|
|
import sys
|
|
import re
|
|
sys.path.append('.')
|
|
from local_config import openai_key
|
|
from utils.format.txt_2_list import txt_2_list
|
|
|
|
|
|
openai.api_key = openai_key
|
|
|
|
def text_classification(src_txt, type_arr, history=[]):
|
|
history_txt = ''.join([f'输入|```{q}```输出|{a}\n' for q, a in history])
|
|
user = f"你是一个聪明而且有百年经验的文本分类器. 你的任务是从一段文本里面提取出相应的分类结果签。你的回答必须用统一的格式。文本用```符号分割。分类类型保存在一个数组里{type_arr}\n{history_txt}输入|```{src_txt}```输出|"
|
|
|
|
completion = openai.ChatCompletion.create(
|
|
model="gpt-3.5-turbo",
|
|
messages=[
|
|
{"role": "user", "content": f"{user}"},
|
|
]
|
|
)
|
|
|
|
|
|
content = completion.choices[0].message.content
|
|
|
|
result = []
|
|
for type in type_arr:
|
|
if type in content:
|
|
result.append(type)
|
|
|
|
content = content.replace(type, '')
|
|
return result
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
type_arr_txt = "天气查询、股票查询、其他"
|
|
type_arr = txt_2_list(type_arr_txt)
|
|
txts = [
|
|
'这个商品真不错',
|
|
'用着不行',
|
|
'没用过这么好的东西',
|
|
|
|
]
|
|
history = [
|
|
['这个商品真不错', ['其他']],
|
|
]
|
|
for txt in txts:
|
|
result = text_classification(txt, type_arr, history)
|
|
print(txt, result)
|
|
|