smjain commited on
Commit
2524827
·
1 Parent(s): a4bf8d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -16
app.py CHANGED
@@ -1,38 +1,40 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
2
  import torch
3
 
4
- tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
5
- model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
6
 
7
- def predict(input, history=[]):
8
- # tokenize the new input sentence
9
- new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
10
 
11
- # append the new user input tokens to the chat history
12
- bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
13
 
14
  # generate a response
15
- history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
16
- print(history)
 
17
  # convert the tokens to text, and then split the responses into lines
18
- response = tokenizer.decode(history[0]).split("<|endoftext|>")
 
19
  print("starting to print response")
20
  print(response)
21
- response.remove("")
22
-
23
 
24
  # write some HTML
25
  html = "<div class='chatbot'>"
26
  for m, msg in enumerate(response):
27
-
28
  cls = "user" if m%2 == 0 else "bot"
29
  print("value of m")
30
  print(m)
 
31
  print (msg)
32
  html += "<div class='msg {}'> {}</div>".format(cls, msg)
33
  html += "</div>"
34
  print(html)
35
- return html, history
36
 
37
  import gradio as gr
38
 
@@ -44,7 +46,7 @@ css = """
44
  .footer {display:none !important}
45
  """
46
 
47
- gr.Interface(fn=predict,
48
  theme="default",
49
  inputs=[gr.inputs.Textbox(placeholder="How are you?"), "state"],
50
  outputs=["html", "state"],
 
1
+ from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
2
+
3
+
4
  import torch
5
 
6
+ chat_tkn = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
7
+ mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
8
 
9
+ def converse(user_input, chat_history=[]):
10
+
11
+ user_input_ids = chat_tkn.encode(user_input + chat_tkn.eos_token, return_tensors='pt')
12
 
13
+ # create a combined tensor with chat history
14
+ bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
15
 
16
  # generate a response
17
+ chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
18
+ print (chat_history)
19
+
20
  # convert the tokens to text, and then split the responses into lines
21
+ response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>",skip_special_tokens=True)
22
+ #response.remove("")
23
  print("starting to print response")
24
  print(response)
 
 
25
 
26
  # write some HTML
27
  html = "<div class='chatbot'>"
28
  for m, msg in enumerate(response):
 
29
  cls = "user" if m%2 == 0 else "bot"
30
  print("value of m")
31
  print(m)
32
+ print("message")
33
  print (msg)
34
  html += "<div class='msg {}'> {}</div>".format(cls, msg)
35
  html += "</div>"
36
  print(html)
37
+ return html, chat_history
38
 
39
  import gradio as gr
40
 
 
46
  .footer {display:none !important}
47
  """
48
 
49
+ gr.Interface(fn=converse,
50
  theme="default",
51
  inputs=[gr.inputs.Textbox(placeholder="How are you?"), "state"],
52
  outputs=["html", "state"],