smjain commited on
Commit
a4bf8d6
·
1 Parent(s): ec418c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -18
app.py CHANGED
@@ -1,40 +1,38 @@
1
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration

import torch

# Load the distilled 400M-parameter BlenderBot checkpoint once at import time;
# both objects are module-level globals read by converse() below.
chat_tkn = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
8
 
9
def converse(user_input, chat_history=None):
    """Generate a BlenderBot reply and render the whole conversation as HTML.

    Args:
        user_input: the user's new message text.
        chat_history: flat list of token ids for the conversation so far;
            gradio's "state" feeds the previous return value back in here
            (``None``/empty on the first turn).

    Returns:
        (html, chat_history): the rendered conversation and the updated
        token-id history to store back into gradio state.
    """
    # Avoid the shared-mutable-default pitfall; gradio state may also pass None.
    chat_history = list(chat_history) if chat_history else []

    # Tokenize the new input sentence, terminated by the tokenizer's EOS token.
    user_input_ids = chat_tkn.encode(user_input + chat_tkn.eos_token, return_tensors='pt')

    # Append the new user tokens to the accumulated chat history.
    bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)

    # Generate a response; pad with EOS since the model has no dedicated pad token.
    chat_history = mdl.generate(
        bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id
    ).tolist()

    # Decode and split into turns on the tokenizer's own EOS token.
    # BUG FIX: the original split on the hard-coded GPT marker "<|endoftext|>",
    # which BlenderBot never emits, so the conversation was never segmented.
    response = chat_tkn.decode(chat_history[0]).split(chat_tkn.eos_token)
    # Drop empty fragments (the trailing EOS produces one).
    response = [turn for turn in response if turn]

    # Render alternating turns: even indices are the user, odd are the bot.
    html = "<div class='chatbot'>"
    for m, msg in enumerate(response):
        cls = "user" if m % 2 == 0 else "bot"
        html += "<div class='msg {}'> {}</div>".format(cls, msg)
    html += "</div>"
    return html, chat_history
38
 
39
  import gradio as gr
40
 
@@ -46,7 +44,7 @@ css = """
46
  .footer {display:none !important}
47
  """
48
 
49
- gr.Interface(fn=converse,
50
  theme="default",
51
  inputs=[gr.inputs.Textbox(placeholder="How are you?"), "state"],
52
  outputs=["html", "state"],
 
1
from transformers import AutoModelForCausalLM, AutoTokenizer

import torch

# Load the DialoGPT-medium checkpoint once at import time; both objects are
# module-level globals read by predict() below.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
6
 
7
def predict(input, history=None):
    """Generate a DialoGPT reply and render the whole conversation as HTML.

    Args:
        input: the user's new message text. NOTE(review): shadows the
            ``input`` builtin; the name is kept so existing keyword callers
            are unaffected.
        history: flat list of token ids for the conversation so far;
            gradio's "state" feeds the previous return value back in here
            (``None``/empty on the first turn).

    Returns:
        (html, history): the rendered conversation and the updated token-id
        history to store back into gradio state.
    """
    # Avoid the shared-mutable-default pitfall; gradio state may also pass None.
    history = list(history) if history else []

    # Tokenize the new input sentence, terminated by the EOS token.
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')

    # Append the new user input tokens to the chat history.
    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)

    # Generate a response; pad with EOS since DialoGPT has no dedicated pad token.
    history = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
    ).tolist()

    # Decode and split into alternating user/bot turns on the EOS marker.
    response = tokenizer.decode(history[0]).split("<|endoftext|>")
    # Drop ALL empty fragments (the trailing EOS yields one).
    # BUG FIX: the original ``response.remove("")`` removed only the first
    # empty string and raised ValueError when none was present.
    response = [turn for turn in response if turn]

    # Render alternating turns: even indices are the user, odd are the bot.
    html = "<div class='chatbot'>"
    for m, msg in enumerate(response):
        cls = "user" if m % 2 == 0 else "bot"
        html += "<div class='msg {}'> {}</div>".format(cls, msg)
    html += "</div>"
    return html, history
36
 
37
  import gradio as gr
38
 
 
44
  .footer {display:none !important}
45
  """
46
 
47
+ gr.Interface(fn=predict,
48
  theme="default",
49
  inputs=[gr.inputs.Textbox(placeholder="How are you?"), "state"],
50
  outputs=["html", "state"],