Update app.py
Browse files
app.py
CHANGED
@@ -8,10 +8,10 @@ mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-40
|
|
8 |
|
9 |
def converse(user_input, chat_history=[]):
|
10 |
|
11 |
-
user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt')
|
12 |
|
13 |
# create a combined tensor with chat history
|
14 |
-
bot_input_ids = torch.cat([torch.LongTensor(chat_history),
|
15 |
|
16 |
# generate a response
|
17 |
chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
|
|
|
8 |
|
9 |
def converse(user_input, chat_history=[]):
|
10 |
|
11 |
+
user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids
|
12 |
|
13 |
# create a combined tensor with chat history
|
14 |
+
bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
|
15 |
|
16 |
# generate a response
|
17 |
chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
|