File size: 1,957 Bytes
cc28aad b73e8d0 e7899d4 b73e8d0 20cda87 b73e8d0 20cda87 b73e8d0 cc28aad 20cda87 b73e8d0 cc28aad 20cda87 b73e8d0 20cda87 cc28aad 20cda87 b73e8d0 cc28aad b73e8d0 cc28aad 20cda87 cc28aad 20cda87 cc28aad b73e8d0 cc28aad b73e8d0 20cda87 cc28aad b73e8d0 20cda87 cc28aad 20cda87 cc28aad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForConditionalGeneration
import torch
import gradio as gr
# Checkpoint shared by the tokenizer and the causal language model below.
_CHECKPOINT = "microsoft/DialoGPT-medium"

# Both objects are loaded once at import time and reused by every request.
chat_tkn = AutoTokenizer.from_pretrained(_CHECKPOINT)
mdl = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)

# Alternative checkpoint, currently disabled:
#chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
#mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
def converse(user_input, chat_history=None):
    """Generate the bot's next reply and render the conversation as HTML.

    Args:
        user_input: Text the user just submitted.
        chat_history: Token ids returned by the previous call (the
            ``.tolist()`` form of the model output), or ``None`` / an empty
            list on the first turn.

    Returns:
        A ``(html, chat_history)`` tuple: the markup for the whole
        conversation so far, and the updated token ids to pass back in on
        the next turn.
    """
    # Imported locally because the file has no top-level ``import html``.
    from html import escape

    # Terminate the user's turn with EOS so turns can be split apart later.
    user_input_ids = chat_tkn(
        user_input + chat_tkn.eos_token, return_tensors="pt"
    ).input_ids

    # Prepend earlier turns when there are any; on the first turn the prompt
    # is just the new message (avoids concatenating with an empty tensor).
    if chat_history:
        bot_input_ids = torch.cat(
            [torch.LongTensor(chat_history), user_input_ids], dim=-1
        )
    else:
        bot_input_ids = user_input_ids

    # NOTE(review): max_length caps prompt + reply at 1000 tokens, so a long
    # enough conversation leaves no room for a reply -- consider truncating
    # the history before generating.
    chat_history = mdl.generate(
        bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id
    ).tolist()

    # Each turn ends with EOS, so splitting on it yields alternating
    # user / bot messages.
    turns = chat_tkn.decode(chat_history[0]).split(chat_tkn.eos_token)
    # A transcript that ends with EOS leaves an empty trailing piece, which
    # would otherwise be rendered as a blank user bubble.
    if turns and turns[-1] == "":
        turns.pop()

    # The wrapper class must be "mychat" to match the stylesheet's flex
    # container rule.
    parts = ["<div class='mychat'>"]
    for index, mesg in enumerate(turns):
        if index % 2 != 0:
            # Odd positions are the model's replies.
            mesg = "Bot:" + mesg
            clazz = "bot"
        else:
            clazz = "user"
        # Escape the text: it comes from the user or the model and must not
        # be interpreted as markup.
        parts.append(
            "<div class='mesg {}'> {}</div>".format(clazz, escape(mesg))
        )
    parts.append("</div>")
    return "".join(parts), chat_history
# Stylesheet passed to gr.Interface via ``css=``.
# .mychat is the flex column container; .mesg styles every bubble, with
# .user / .bot variants; .footer hides the element carrying that class.
# The .mesg.bot rule separates its declarations with ';' (a ',' there makes
# the declaration invalid, so neither the colour nor the alignment applies).
css = """
.mychat {display:flex;flex-direction:column}
.mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
.mesg.user {background-color:lightblue;color:white}
.mesg.bot {background-color:orange;color:white;align-self:self-end}
.footer {display:none !important}
"""
# Free-text box for the user's message.
text = gr.inputs.Textbox(placeholder="Lets chat")

# "state" appears in both inputs and outputs so the second value returned by
# converse is handed back to it as chat_history on the next submit.
demo = gr.Interface(
    fn=converse,
    theme="default",
    inputs=[text, "state"],
    outputs=["html", "state"],
    css=css,
)

# Start the web UI when the script runs.
demo.launch()