File size: 2,049 Bytes
b19af15
bad32bd
b19af15
2524827
 
9e350ba
 
 
f8beb9a
9e350ba
 
f8beb9a
2524827
 
6091881
f8beb9a
2524827
85eb8de
f8beb9a
 
2524827
 
 
f8beb9a
b9146b8
2524827
33ca492
 
f8beb9a
 
 
 
 
ec418c6
 
2524827
ec418c6
f8beb9a
 
72bd468
2524827
5e2ff12
cdb2aab
5e2ff12
f8beb9a
 
 
cdb2aab
 
f8beb9a
 
cdb2aab
 
f8beb9a
cdb2aab
f8beb9a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForConditionalGeneration
import torch


# Hugging Face checkpoint shared by the tokenizer and the model below.
_CHECKPOINT = "microsoft/DialoGPT-medium"

# Module-level objects, loaded once when this file is imported; converse()
# reads both names directly.
chat_tkn = AutoTokenizer.from_pretrained(_CHECKPOINT)
mdl = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)

# Disabled alternative checkpoint.
# NOTE(review): BlenderbotForConditionalGeneration is an encoder-decoder model,
# while converse() concatenates prior token ids with the new turn and feeds
# them to generate() as one prompt -- confirm that logic still fits before
# switching to this model.
#chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
#mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")

def converse(user_input, chat_history=None):
    """Run one DialoGPT turn and render the whole conversation as HTML.

    Args:
        user_input: The text the user just typed.
        chat_history: Token ids of the conversation so far, exactly as
            returned by the previous call (a nested list produced by
            ``Tensor.tolist()``). ``None`` or an empty list starts a new
            conversation -- Gradio's ``"state"`` input may hand in ``None``
            on the first turn.

    Returns:
        A ``(html, chat_history)`` tuple: an HTML fragment with one
        ``<div class='msg user'>`` / ``<div class='msg bot'>`` per turn
        (text HTML-escaped), and the updated token-id history to pass back
        in on the next call.
    """
    # Local import: the module-level namespace does not import ``html``.
    from html import escape

    # Encode the new user turn, terminated by EOS so turns stay delimited.
    user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids

    # Prepend the previous conversation tokens when there are any. A None /
    # empty history (first turn) just uses the new turn as the prompt.
    if chat_history:
        bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
    else:
        bot_input_ids = user_input_ids

    # Generate the reply. The result is decoded below as prompt + reply, which
    # is also what we keep as the new history. EOS doubles as the pad id.
    chat_history = mdl.generate(
        bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id
    ).tolist()

    # Decode the full conversation and split it back into turns on the same
    # EOS marker used when encoding.
    turns = chat_tkn.decode(chat_history[0]).split(chat_tkn.eos_token)

    # A reply that ends with EOS leaves one empty trailing segment; drop only
    # that one so no blank bubble is rendered and user/bot parity is kept.
    if turns and turns[-1] == "":
        turns.pop()

    # Even-indexed turns are the user's, odd-indexed ones the bot's. Text is
    # escaped because it contains raw user input (and model output).
    bubbles = "".join(
        "<div class='msg {}'> {}</div>".format("user" if i % 2 == 0 else "bot", escape(turn))
        for i, turn in enumerate(turns)
    )
    # 'chatbox' is the container class the page CSS styles as a flex column.
    markup = "<div class='chatbox'>" + bubbles + "</div>"
    return markup, chat_history

import gradio as grad

# Page styles for the chat transcript:
#   .chatbox            flex-column container for the message bubbles
#   .msg / .msg.user /  per-bubble padding and colours; bot bubbles are
#   .msg.bot            aligned to the end of the container
#   .footer             hidden (presumably Gradio's page footer)
css = """
.chatbox {display:flex;flex-direction:column}
.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.msg.user {background-color:blue;color:white}
.msg.bot {background-color:orange;align-self:self-end}
.footer {display:none !important}
"""

# Single free-text input for the user's message.
text = grad.inputs.Textbox(placeholder="Lets chat")

# Wire converse() into a UI: the textbox and a session "state" slot are its
# two inputs; it returns rendered HTML plus the updated state.
app = grad.Interface(
    fn=converse,
    theme="default",
    inputs=[text, "state"],
    outputs=["html", "state"],
    css=css,
)
app.launch()