File size: 3,795 Bytes
8a2b97e
 
 
ce756cd
 
8a2b97e
ce756cd
8a2b97e
 
 
ce756cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a2b97e
ce756cd
 
 
 
 
 
 
 
 
 
 
8a2b97e
ce756cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a2b97e
 
 
ce756cd
 
 
 
8a2b97e
ce756cd
 
 
 
8a2b97e
 
 
7398b02
332c046
ce756cd
 
 
 
 
 
332c046
ce756cd
 
 
8a2b97e
ce756cd
 
 
8a2b97e
ce756cd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import requests
import json
import gradio as gr
from concurrent.futures import ThreadPoolExecutor
from sentence_transformers import util

# Hosted QA-maker endpoint. NOTE(review): not referenced by any code visible in this file.
url = 'https://souljoy-my-api.hf.space/qa_maker'
# Header sent with every JSON POST made by get_emb / get_response.
headers = {'Content-Type': 'application/json'}
# Shared worker pool used to issue embedding requests concurrently.
thread_pool_executor = ThreadPoolExecutor(16)
# Character budgets used by get_response when assembling the chat request:
# chat history may use at most `history_max_len` characters (including the question),
# and question + history + retrieved document lines at most `all_max_len`.
history_max_len = 500
all_max_len = 2000


def get_emb(text, timeout=60):
    """Fetch the embedding vector for ``text`` from the hosted embeddings API.

    Args:
        text: The string to embed.
        timeout: Seconds to wait for the service. Without a timeout ``requests``
            can block indefinitely, which would permanently tie up a worker of
            the shared thread pool. The default is a guess; tune as needed.

    Returns:
        The ``embedding`` field of the first element of the response's
        ``data`` list.

    Raises:
        requests.HTTPError: If the service answers with a non-2xx status.
        requests.Timeout: If the service does not answer within ``timeout``.
    """
    emb_url = 'https://souljoy-my-api.hf.space/embeddings'
    data = {"content": text}
    result = requests.post(url=emb_url,
                           data=json.dumps(data),
                           headers=headers,
                           timeout=timeout)
    # Fail with a clear HTTP error instead of an opaque KeyError/JSONDecodeError
    # when an error page comes back instead of the expected JSON payload.
    result.raise_for_status()
    return result.json()['data'][0]['embedding']


def doc_emb(doc: str):
    """Embed each line of ``doc`` and reveal the chat controls.

    Every newline-separated line (empty lines included) is sent to ``get_emb``
    on the shared thread pool. All requests are submitted before any result is
    awaited, and results are collected in input order.

    Returns:
        A 5-tuple: the list of lines, the list of embeddings (one per line), and
        Gradio updates making a Textbox, a Button and a Markdown visible.
    """
    lines = doc.split('\n')
    pending = [thread_pool_executor.submit(get_emb, line) for line in lines]
    # .result() re-raises the first (in line order) request failure, if any.
    embeddings = [job.result() for job in pending]
    print('\n'.join(lines))
    textbox_update = gr.Textbox.update(visible=True)
    button_update = gr.Button.update(visible=True)
    markdown_update = gr.Markdown.update(visible=True)
    return lines, embeddings, textbox_update, button_update, markdown_update


def _select_history(bot, used_len, budget):
    """Return the longest suffix of ``bot`` fitting in ``budget`` characters.

    Walks the ``[user, reply]`` pairs from newest to oldest, adding each pair's
    length to ``used_len``, and stops at the first pair that would exceed
    ``budget``.

    Returns:
        ``(history, used_len)``: the selected pairs, oldest first (empty if even
        the newest pair does not fit), and the updated character count.
    """
    start = len(bot)
    for i in reversed(range(len(bot))):
        turn_len = len(bot[i][0]) + len(bot[i][1])
        if used_len + turn_len > budget:
            break
        used_len += turn_len
        start = i
    return bot[start:], used_len


def _select_docs(doc_text_list, cos_scores, used_len, budget):
    """Greedily choose the most similar document lines that fit in ``budget``.

    Lines are visited in descending similarity (ties keep document order).
    Selection stops at the first line that would push the total past
    ``budget``. Less similar but shorter lines after it are not considered.

    Returns:
        The chosen lines, restored to their original document order and joined
        with newlines ('' when nothing fits).
    """
    ranked = sorted(enumerate(cos_scores), key=lambda pair: pair[1], reverse=True)
    print('score_index:\n', [[score, index] for index, score in ranked])
    chosen = []
    for index, _score in ranked:
        line_len = len(doc_text_list[index])
        if used_len + line_len > budget:
            break
        chosen.append(index)
        used_len += line_len
    return '\n'.join(doc_text_list[i] for i in sorted(chosen))


def get_response(msg, bot, doc_text_list, doc_embeddings):
    """Answer ``msg`` using recent chat history and the most relevant doc lines.

    Args:
        msg: The user's question.
        bot: Chat history as a list of ``[user, reply]`` pairs; the new pair is
            appended to it in place.
        doc_text_list: Document lines, as returned by ``doc_emb``.
        doc_embeddings: One embedding per entry of ``doc_text_list``.

    Returns:
        The last (at most) three chat pairs and a Gradio update hiding a
        Markdown component.

    Raises:
        requests.HTTPError / requests.Timeout: If the chat service fails.
    """
    # Start the query embedding request; history selection proceeds meanwhile.
    future = thread_pool_executor.submit(get_emb, msg)
    history, used_len = _select_history(bot, len(msg), history_max_len)
    query_embedding = future.result()
    # util.cos_sim cannot handle an empty corpus (no document submitted yet).
    if doc_embeddings:
        cos_scores = util.cos_sim(query_embedding, doc_embeddings)[0].tolist()
    else:
        cos_scores = []
    doc = _select_docs(doc_text_list, cos_scores, used_len, all_max_len)
    req_json = {'question': msg, 'history': history, 'doc': doc}
    data = {"content": json.dumps(req_json)}
    print('data:\n', req_json)
    # Generous timeout (a guess) since generation can be slow; without one a hung
    # service would block this handler forever.
    result = requests.post(url='https://souljoy-my-api.hf.space/chatpdf',
                           data=json.dumps(data),
                           headers=headers,
                           timeout=120)
    result.raise_for_status()
    res = result.json()['content']
    bot.append([msg, res])
    # Only the last three turns are returned. The chatbot component is both input
    # and output, so the history kept across turns is capped at three pairs.
    return bot[-3:], gr.Markdown.update(visible=False)


def up_file(files):
    """Log the uploaded file path(s) and show the submit button when a file is present.

    Args:
        files: Value of a ``gr.File`` component. This is ``None`` when the
            upload is cleared (the change event fires then too), a single file
            object with the default ``file_count="single"``, or a list of file
            objects in multi-file mode.

    Returns:
        A Gradio Button update that is visible only when at least one file is present.
    """
    if files is None:
        files = []
    elif not isinstance(files, (list, tuple)):
        files = [files]
    for file in files:
        print(file.name)
    return gr.Button.update(visible=bool(files))


# Build the two-column UI: document upload/submission on the left, chat on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # "Upload PDF"
            file = gr.File(file_types=['.pdf'], label='上传PDF')
            # "PDF parse result" -- the text that doc_bu submits for embedding.
            # NOTE(review): nothing in this file ever writes to `txt` or makes it
            # visible, so doc_emb always receives an empty string. The PDF upload
            # only toggles doc_bu (see up_file); PDF-to-text extraction seems missing.
            txt = gr.Textbox(label='PDF解析结果', visible=False)
            # "Submit" -- shown by up_file once a file is uploaded.
            doc_bu = gr.Button(value='提交', visible=False)
            # "Document submitted successfully" -- shown by doc_emb, hidden by get_response.
            md = gr.Markdown("""#### 文档提交成功 🙋 """, visible=False)
            # Per-session state: document lines and their embeddings (filled by doc_emb).
            doc_text_state = gr.State([])
            doc_emb_state = gr.State([])
        with gr.Column():
            chat_bot = gr.Chatbot()
            # "Message box" / "Type a message, click send" -- revealed by doc_emb.
            msg_txt = gr.Textbox(label='消息框', placeholder='输入消息,点击发送', visible=False)
            # "Send" -- revealed by doc_emb.
            chat_bu = gr.Button(value='发送', visible=False)

    # Submit: embed the document text, store it in session state, reveal the chat controls.
    doc_bu.click(doc_emb, [txt], [doc_text_state, doc_emb_state, msg_txt, chat_bu, md])
    # Send: answer the message; chat_bot is both input (history) and output.
    chat_bu.click(get_response, [msg_txt, chat_bot, doc_text_state, doc_emb_state], [chat_bot, md])
    # Upload change: log file name(s) and reveal the submit button.
    file.change(up_file, [file], [doc_bu])
if __name__ == "__main__":
    # queue() enables Gradio's request queue before launching the server.
    demo.queue().launch()
    # Alternative launch used for a specific deployment: binds to a fixed
    # address/port without a public share link.
    # demo.queue().launch(share=False, server_name='172.22.2.54', server_port=9191)