nugentc committed on
Commit
247e692
·
1 Parent(s): 245a96a

try to avoid tokenising inputs multiple times

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -7,19 +7,18 @@ import torch
7
  import gradio as gr
8
 
9
 
10
- def chat(message, history, bot_input_ids):
11
  history = history or []
12
- bot_input_ids = bot_input_ids or []
13
  new_user_input_ids = tokenizer.encode(message+tokenizer.eos_token, return_tensors='pt')
14
  # append the new user input tokens to the chat history
 
15
  bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if chat_history_ids is not None else new_user_input_ids
16
-
17
  # generated a response while limiting the total chat history to 1000 tokens,
18
  chat_history_ids = model.generate(bot_input_ids, max_length=5000, pad_token_id=tokenizer.eos_token_id)
19
  print("The text is ", [text])
20
  # pretty print last ouput tokens from bot
21
  reponse = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
22
- history.append((message, response))
23
  return history, bot_input_ids, feedback(message)
24
 
25
 
@@ -38,8 +37,8 @@ def feedback(text):
38
 
39
  iface = gr.Interface(
40
  chat,
41
- ["text", "state", "state"],
42
- ["chatbot", "state", "state", "text"],
43
  allow_screenshot=False,
44
  allow_flagging="never",
45
  )
 
7
  import gradio as gr
8
 
9
 
10
def chat(message, history):
    """Run one chat turn against the conversational model.

    The per-turn token ids are cached inside the history tuples so the
    conversation does not have to be re-tokenised on every call (the point
    of this commit).

    Args:
        message: The new user utterance (plain text).
        history: Gradio state — list of
            (message, response, new_user_input_ids, chat_history_ids)
            tuples from previous turns, or None on the first call.

    Returns:
        (history, bot_input_ids, feedback(message)) — matching the
        Interface outputs ["chatbot", "state", "text"].
    """
    history = history or []

    # Tokenise only the NEW message; earlier turns are already tokenised
    # and cached in the history tuples.
    new_user_input_ids = tokenizer.encode(
        message + tokenizer.eos_token, return_tensors='pt'
    )

    # item[3] of the last turn is the full conversation's token ids, so it
    # alone carries the whole context — no need to re-encode old turns.
    # NOTE(review): assumes model.generate returns the prompt + reply ids
    # concatenated (DialoGPT-style) — confirm against the model in use.
    prev_ids = history[-1][3] if history else None
    bot_input_ids = (
        torch.cat([prev_ids, new_user_input_ids], dim=-1)
        if prev_ids is not None
        else new_user_input_ids
    )

    # Generate a response; max_length bounds the total token budget.
    chat_history_ids = model.generate(
        bot_input_ids, max_length=5000, pad_token_id=tokenizer.eos_token_id
    )

    # Debug trace of the incoming text (original printed an undefined
    # name `text`).
    print("The text is ", [message])

    # Decode only the newly generated tokens (everything past the prompt).
    # Fixed: original assigned `reponse` but appended `response`.
    response = tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0],
        skip_special_tokens=True,
    )

    # Cache the ids alongside the text so the next turn can reuse them.
    history.append((message, response, new_user_input_ids, chat_history_ids))
    return history, bot_input_ids, feedback(message)
23
 
24
 
 
37
 
38
# Wire the chat function into a Gradio interface: the three outputs line
# up with chat's (history, bot_input_ids, feedback(message)) return tuple.
iface = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state", "text"],
    allow_screenshot=False,
    allow_flagging="never",
)