Update app.py
Browse files
app.py
CHANGED
@@ -7,14 +7,14 @@ chat_tkn = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill
|
|
7 |
mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
|
8 |
|
9 |
def converse(user_input, chat_history=[]):
|
10 |
-
|
11 |
user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt')
|
12 |
|
13 |
-
#
|
14 |
-
bot_input_ids = torch.cat([torch.LongTensor(
|
15 |
|
16 |
# generate a response
|
17 |
-
|
18 |
|
19 |
# convert the tokens to text, and then split the responses into lines
|
20 |
response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")
|
@@ -27,7 +27,7 @@ def converse(user_input, chat_history=[]):
|
|
27 |
html += "<div class='msg {}'> {}</div>".format(cls, msg)
|
28 |
html += "</div>"
|
29 |
|
30 |
-
return html,
|
31 |
|
32 |
import gradio as gr
|
33 |
|
|
|
7 |
mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
|
8 |
|
9 |
def converse(user_input, chat_history=[]):
|
10 |
+
|
11 |
user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt')
|
12 |
|
13 |
+
# create a combined tensor with chat history
|
14 |
+
bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
|
15 |
|
16 |
# generate a response
|
17 |
+
chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
|
18 |
|
19 |
# convert the tokens to text, and then split the responses into lines
|
20 |
response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")
|
|
|
27 |
html += "<div class='msg {}'> {}</div>".format(cls, msg)
|
28 |
html += "</div>"
|
29 |
|
30 |
+
return html, chat_history
|
31 |
|
32 |
import gradio as gr
|
33 |
|