nugentc commited on
Commit
245a96a
·
1 Parent(s): 10ac208

add chat agent

Browse files
Files changed (1) hide show
  1. app.py +14 -12
app.py CHANGED
@@ -7,18 +7,20 @@ import torch
7
  import gradio as gr
8
 
9
 
10
- def chat(message, history):
11
  history = history or []
12
- if message.startswith("How many"):
13
- response = random.randint(1, 10)
14
- elif message.startswith("How"):
15
- response = random.choice(["Great", "Good", "Okay", "Bad"])
16
- elif message.startswith("Where"):
17
- response = random.choice(["Here", "There", "Somewhere"])
18
- else:
19
- response = "I don't know"
 
 
20
  history.append((message, response))
21
- return history, history, feedback(message)
22
 
23
 
24
  def feedback(text):
@@ -36,8 +38,8 @@ def feedback(text):
36
 
37
  iface = gr.Interface(
38
  chat,
39
- ["text", "state"],
40
- ["chatbot", "state", "text"],
41
  allow_screenshot=False,
42
  allow_flagging="never",
43
  )
 
7
  import gradio as gr
8
 
9
 
10
+ def chat(message, history, bot_input_ids):
11
  history = history or []
12
+ bot_input_ids = bot_input_ids or []
13
+ new_user_input_ids = tokenizer.encode(message+tokenizer.eos_token, return_tensors='pt')
14
+ # append the new user input tokens to the chat history
15
+ bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if chat_history_ids is not None else new_user_input_ids
16
+
17
+ # generated a response while limiting the total chat history to 1000 tokens,
18
+ chat_history_ids = model.generate(bot_input_ids, max_length=5000, pad_token_id=tokenizer.eos_token_id)
19
+ print("The text is ", [text])
20
+ # pretty print last ouput tokens from bot
21
+ reponse = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
22
  history.append((message, response))
23
+ return history, bot_input_ids, feedback(message)
24
 
25
 
26
  def feedback(text):
 
38
 
39
  iface = gr.Interface(
40
  chat,
41
+ ["text", "state", "state"],
42
+ ["chatbot", "state", "state", "text"],
43
  allow_screenshot=False,
44
  allow_flagging="never",
45
  )