File size: 1,923 Bytes
b19af15 bad32bd b19af15 2524827 f8beb9a b19af15 9529237 f8beb9a 2524827 6854f2b f8beb9a 2524827 85eb8de f8beb9a 2524827 f8beb9a f545a57 2524827 33ca492 f8beb9a ec418c6 2524827 ec418c6 f8beb9a 72bd468 2524827 5e2ff12 f8beb9a 2524827 f8beb9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
#from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForConditionalGeneration
import torch
import torch
# One pretrained BlenderBot checkpoint backs both the tokenizer and the model, so the
# name is kept in a single place to stop the two from drifting apart.
_CHECKPOINT = "facebook/blenderbot-400M-distill"

chat_tkn = AutoTokenizer.from_pretrained(_CHECKPOINT)
mdl = BlenderbotForConditionalGeneration.from_pretrained(_CHECKPOINT)
def converse(user_input, chat_history=None):
    """Reply to *user_input* with BlenderBot and render the whole conversation as HTML.

    Args:
        user_input: The message the user just typed.
        chat_history: Earlier turns as ``(user_message, bot_reply)`` pairs, i.e. the
            second value returned by a previous call. ``None`` or an empty sequence
            starts a new conversation. (Gradio's "state" input is presumably ``None``
            on the first call, which is why there is no mutable ``[]`` default.)

    Returns:
        A ``(html, history)`` tuple: the rendered transcript, and the updated list of
        ``(user_message, bot_reply)`` pairs to be passed back in on the next call.
    """
    # Work on a copy so the caller's history object is never mutated in place.
    history = list(chat_history) if chat_history else []

    # BlenderBot is an encoder-decoder model: generate() returns only the new reply,
    # not prompt + reply as a causal LM (e.g. DialoGPT) would. The context is therefore
    # rebuilt from text every turn. The leading space on user turns and the two-space
    # separator follow the conversation layout BlenderbotTokenizer has used.
    turns = []
    for past_user, past_bot in history:
        turns.append(" " + past_user)
        turns.append(past_bot)
    turns.append(" " + user_input)
    input_ids = chat_tkn.encode("  ".join(turns))  # encode() appends </s> itself

    # The learned position embeddings cap both the encoder input and the decoded reply.
    # Keep only the most recent context tokens so long chats cannot overflow them.
    max_positions = mdl.config.max_position_embeddings
    input_ids = input_ids[-max_positions:]

    reply_ids = mdl.generate(torch.LongTensor([input_ids]), max_length=max_positions)
    # Decode the whole reply sequence at once (not token by token) and drop the
    # decoder-start / </s> markers.
    reply = chat_tkn.decode(reply_ids[0], skip_special_tokens=True).strip()

    history.append((user_input, reply))
    return _render_html(history), history


def _render_html(history):
    """Render ``(user_message, bot_reply)`` pairs as the chat transcript markup."""
    from html import escape

    # The container class must match the `.chatbox` CSS rule so the flex-column
    # layout (and the right-alignment of bot messages) actually applies.
    parts = ["<div class='chatbox'>"]
    for user_msg, bot_msg in history:
        # escape() keeps user-typed or generated text from injecting markup.
        parts.append("<div class='msg user'> {}</div>".format(escape(user_msg)))
        parts.append("<div class='msg bot'> {}</div>".format(escape(bot_msg)))
    parts.append("</div>")
    return "".join(parts)
import gradio as gr
# Stylesheet for the transcript HTML. `.chatbox` lays the messages out as a flex
# column (so it expects the transcript container to carry class `chatbox`), `.msg.bot`
# is aligned to the end of that column, and the Gradio footer is hidden.
css = """
.chatbox {display:flex;flex-direction:column}
.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.msg.user {background-color:cornflowerblue;color:white}
.msg.bot {background-color:lightgray;align-self:self-end}
.footer {display:none !important}
"""

# Text box + opaque session state in; rendered HTML + updated state out.
# `gr.Textbox` replaces the deprecated `gr.inputs.Textbox` (removed in Gradio 4).
demo = gr.Interface(
    fn=converse,
    theme="default",
    inputs=[gr.Textbox(placeholder="How are you?"), "state"],
    outputs=["html", "state"],
    css=css,
)

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch()