smjain commited on
Commit
85eb8de
·
1 Parent(s): bea15a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -8,10 +8,10 @@ mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-40
8
 
9
  def converse(user_input, chat_history=[]):
10
 
11
- user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt')
12
 
13
  # create a combined tensor with chat history
14
- bot_input_ids = torch.cat([torch.LongTensor(chat_history), **user_input_ids], dim=-1)
15
 
16
  # generate a response
17
  chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
 
8
 
9
  def converse(user_input, chat_history=[]):
10
 
11
+ user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids
12
 
13
  # create a combined tensor with chat history
14
+ bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
15
 
16
  # generate a response
17
  chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()