|
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration |
|
import gradio as grad |
|
|
|
# Checkpoint shared by the tokenizer and the generation model below.
_BLENDERBOT_CHECKPOINT = "facebook/blenderbot-400M-distill"

# Both objects are created once at import time; ``chat`` reuses them on
# every request.
chat_tkn = BlenderbotTokenizer.from_pretrained(_BLENDERBOT_CHECKPOINT)
mdl = BlenderbotForConditionalGeneration.from_pretrained(_BLENDERBOT_CHECKPOINT)
|
|
|
def createHistory(message):
    """Reply to *message* and record the exchange in the session history.

    The running history is a list of ``(user_message, reply)`` tuples kept
    in Gradio's session state.

    Args:
        message: Text the user typed into the input box.

    Returns:
        The bot's reply for this turn only, as produced by ``chat``.
    """
    # NOTE(review): ``get_state``/``set_state`` are from the older Gradio 2.x
    # API; confirm the pinned gradio version still provides them.
    history = grad.get_state() or []
    print(history)

    response = chat(message)
    history.append((message, response))
    grad.set_state(history)

    # NOTE(review): the stored history is never passed to ``chat``, so each
    # reply is generated without earlier turns as context.
    return response
|
|
|
def chat(input):
    """Generate a single BlenderBot reply to one user utterance.

    Args:
        input: The user's message text. The name shadows the builtin
            ``input``; it is kept so keyword callers keep working.

    Returns:
        The decoded model reply, prefixed with ``"Alicia: "``.
    """
    # The Blenderbot tokenizer already appends its end-of-sequence token
    # when encoding, so concatenating ``chat_tkn.eos_token`` onto the text
    # would hand the model a doubled ``</s>``.
    tkn_ids = chat_tkn(input, return_tensors='pt')

    chat_ids = mdl.generate(**tkn_ids)

    # Only one sequence is generated, so decode the first row.
    reply = chat_tkn.decode(chat_ids[0], skip_special_tokens=True)
    return "Alicia: {}".format(reply)
|
|
|
# Build the web UI: a text input wired to ``createHistory`` whose reply is
# shown in a 20-line "dialog" textbox, then start the server.
out = grad.Textbox(lines=20, label="dialog", placeholder="start conversation")
demo = grad.Interface(fn=createHistory, inputs="text", outputs=out)
demo.launch()
|
|
|
|