smjain committed on
Commit
f53d19b
·
1 Parent(s): e4c7064

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -7,14 +7,14 @@ chat_tkn = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill
7
  mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
8
 
9
  def converse(user_input, chat_history=[]):
10
- #tkn_ids = chat_tkn(input(">> User:") + chat_tkn.eos_token, return_tensors='pt')
11
  user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt')
12
 
13
- # append the new user input tokens to the chat history
14
- bot_input_ids = torch.cat([torch.LongTensor(history), user_input_ids], dim=-1)
15
 
16
  # generate a response
17
- history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
18
 
19
  # convert the tokens to text, and then split the responses into lines
20
  response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")
@@ -27,7 +27,7 @@ def converse(user_input, chat_history=[]):
27
  html += "<div class='msg {}'> {}</div>".format(cls, msg)
28
  html += "</div>"
29
 
30
- return html, history
31
 
32
  import gradio as gr
33
 
 
7
  mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
8
 
9
  def converse(user_input, chat_history=[]):
10
+
11
  user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt')
12
 
13
+ # create a combined tensor with chat history
14
+ bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
15
 
16
  # generate a response
17
+ chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
18
 
19
  # convert the tokens to text, and then split the responses into lines
20
  response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")
 
27
  html += "<div class='msg {}'> {}</div>".format(cls, msg)
28
  html += "</div>"
29
 
30
+ return html, chat_history
31
 
32
  import gradio as gr
33