nugentc commited on
Commit
5155493
·
1 Parent(s): 16098ad

simplify things

Browse files
Files changed (1) hide show
  1. app.py +6 -17
app.py CHANGED
@@ -10,25 +10,14 @@ import gradio as gr
10
  def chat(message, history):
11
  history = history if history is not None else []
12
  new_user_input_ids = tokenizer.encode(message+tokenizer.eos_token, return_tensors='pt')
13
- # append the new user input tokens to the chat history
14
- f1 = lambda x: x[2]
15
- f2 = lambda x: x[3]
16
- if len(history) > 0:
17
- bot_input_ids = torch.IntTensor([f(item) for item in history for f in (f1, f2)])
18
- else:
19
- bot_input_ids = torch.IntTensor([[]])
20
- print("bot_input_ids", bot_input_ids)
21
- print("new_user_input_ids", new_user_input_ids)
22
- print("sizr bot_input_ids", bot_input_ids.size())
23
- print("size new_user_input_id", new_user_input_ids.size())
24
- bot_input_ids = torch.cat([bot_input_ids, new_user_input_ids], dim=0) if bot_input_ids is not None else new_user_input_ids
25
- # generated a response while limiting the total chat history to 1000 tokens,
26
- chat_history_ids = model.generate(bot_input_ids, max_length=5000, pad_token_id=tokenizer.eos_token_id)
27
  # pretty print last ouput tokens from bot
28
- response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
29
  print("The response is ", [response])
30
- history.append((message, response, new_user_input_ids, chat_history_ids))
31
- return history, bot_input_ids, feedback(message)
32
 
33
 
34
  def feedback(text):
 
10
  def chat(message, history):
11
  history = history if history is not None else []
12
  new_user_input_ids = tokenizer.encode(message+tokenizer.eos_token, return_tensors='pt')
13
+ bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
14
+ history = model.generate(bot_input_ids, max_length=500, pad_token_id=tokenizer.eos_token_id).tolist()
15
+ response = tokenizer.decode(history[0]).replace("<|endoftext|>", "\n")
 
 
 
 
 
 
 
 
 
 
 
16
  # pretty print last ouput tokens from bot
17
+ # response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
18
  print("The response is ", [response])
19
+ # history.append((message, response, new_user_input_ids, chat_history_ids))
20
+ return response, feedback(message)
21
 
22
 
23
  def feedback(text):