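"""Gradio chat demo that streams tokens from a CTranslate2 model.

A worker thread runs greedy decoding with CTranslate2 and pushes token ids
through a queue-backed iterator; a Gradio generator on the UI side decodes
and displays the partial reply as it grows.
"""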
import threading
from queue import Queue
from typing import Optional

import ctranslate2
import gradio as gr
from transformers import AutoTokenizer

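# Bridges the generation thread and the Gradio UI thread: the CTranslate2
# callback put()s token ids into a queue, and the UI side iterates over them
# until the end-of-sequence id arrives.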
class TokenIteratorStreamer:
    def __init__(self, end_token_id: int, timeout: Optional[float] = None):
        self.end_token_id = end_token_id
        self.queue = Queue()
        self.timeout = timeout

    def put(self, token: int):
        self.queue.put(token, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        # Block until the next token id arrives; the end-of-sequence id
        # terminates the iteration.
        token = self.queue.get(timeout=self.timeout)
        if token == self.end_token_id:
            raise StopIteration
        return token



def generate_prompt(history):
    # Flatten the chat history into the "<human>: ... <bot>: ..." prompt
    # format; the last history entry holds the pending user message.
    prompt = ""
    for user_message, bot_message in history[:-1]:
        prompt += f"<human>: {user_message}\n<bot>: {bot_message}\n"
    prompt += f"<human>: {history[-1][0]}\n<bot>:"
    # CTranslate2 consumes token strings rather than token ids.
    return tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))

def generate(streamer, history):
    def step_result_callback(result):
        # Forward each generated token id to the UI thread. If generation
        # stops for another reason (e.g. max_decoding_length), still emit
        # the end-of-sequence id so the consuming iterator terminates.
        streamer.put(result.token_id)
        if result.is_last and result.token_id != end_token_id:
            streamer.put(end_token_id)
        print(f"step={result.step}, batch_id={result.batch_id}, token={result.token}")

    tokens = generate_prompt(history)

    # The per-step callback is only supported with greedy search (beam_size=1).
    return translator.translate_batch(
        [tokens],
        beam_size=1,
        max_decoding_length=256,
        repetition_penalty=1.8,
        callback=step_result_callback,
    )



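# Load the CTranslate2-converted model from the local "model" directory and
# the matching Hugging Face tokenizer.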
translator = ctranslate2.Translator("model", intra_threads=2)
tokenizer = AutoTokenizer.from_pretrained("DKYoon/mt5-xl-lm-adapt")
# Id of the end-of-sequence token "</s>", used as the stream sentinel.
end_token_id = tokenizer.eos_token_id


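# Minimal Gradio chat UI: record the user turn, then stream the bot reply.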
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        # Clear the textbox and append the new turn with an empty bot slot.
        return "", history + [[user_message, ""]]

    def bot(history):
        # Generate on a worker thread and re-decode the accumulated tokens
        # after each step so the chatbot updates incrementally.
        bot_message_tokens = []
        streamer = TokenIteratorStreamer(end_token_id=end_token_id)
        generation_thread = threading.Thread(target=generate, args=(streamer, history))
        generation_thread.start()

        for token in streamer:
            bot_message_tokens.append(token)
            history[-1][1] = tokenizer.decode(bot_message_tokens)
            yield history
        generation_thread.join()

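    # Submitting the message first appends the user turn, then streams the
    # bot reply into the same chat history; "Clear" resets the chatbot.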
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    
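# Queueing is required for generator-based handlers to stream partial updates.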
demo.queue()
if __name__ == "__main__":
    demo.launch()