File size: 2,104 Bytes
b19af15
bad32bd
b19af15
2524827
 
9e350ba
 
 
f8beb9a
9e350ba
 
f8beb9a
2524827
 
6091881
f8beb9a
2524827
85eb8de
f8beb9a
 
2524827
 
 
f8beb9a
b9146b8
2524827
33ca492
 
f8beb9a
 
649ab52
 
5f468a8
649ab52
 
2524827
649ab52
5f468a8
dd85d27
f8beb9a
72bd468
2524827
5e2ff12
cdb2aab
5e2ff12
f8beb9a
5f468a8
649ab52
 
5f468a8
f8beb9a
 
cdb2aab
 
f8beb9a
cdb2aab
f8beb9a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForConditionalGeneration
import torch


# Hugging Face checkpoint shared by the tokenizer and the causal language model
# below; both are downloaded/loaded once, at import time.
_DIALOGPT_CHECKPOINT = "microsoft/DialoGPT-medium"
chat_tkn = AutoTokenizer.from_pretrained(_DIALOGPT_CHECKPOINT)
mdl = AutoModelForCausalLM.from_pretrained(_DIALOGPT_CHECKPOINT)


#chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
#mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")

def converse(user_input, chat_history=None):
    """Run one DialoGPT turn and render the whole conversation as HTML.

    Args:
        user_input: The text the user just typed.
        chat_history: Token ids of the conversation so far, exactly as returned
            by a previous call (a nested list produced by ``Tensor.tolist()``).
            ``None`` or an empty list starts a new conversation; Gradio's
            ``"state"`` input may pass ``None`` on the first call.

    Returns:
        A ``(html, chat_history)`` tuple: ``html`` is the rendered transcript
        and ``chat_history`` is the updated token-id history to pass back in on
        the next call.
    """
    # No shared mutable default; a missing state simply means a fresh chat.
    if chat_history is None:
        chat_history = []

    # Each turn is terminated with the EOS token; the decoded history is split
    # on that same token further down.
    user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids

    # Prepend the previous conversation so the model sees the full context.
    # With no history there is nothing to concatenate (and concatenating an
    # empty 1-D tensor with a 2-D one only works via legacy torch behavior).
    if chat_history:
        bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
    else:
        bot_input_ids = user_input_ids

    # The generated sequence (prompt + reply) becomes the new history.
    chat_history = mdl.generate(
        bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id
    ).tolist()

    # Splitting the decoded text on EOS yields alternating user / bot turns,
    # usually followed by an empty segment after the final EOS.
    turns = chat_tkn.decode(chat_history[0]).split(chat_tkn.eos_token)
    return _render_html(turns), chat_history


def _render_html(turns):
    """Render alternating user/bot turns as escaped HTML message bubbles.

    Even indices are user turns, odd indices are bot ("Alicia") turns.
    """
    from html import escape

    parts = ["<div class='mychat'>"]
    for idx, text in enumerate(turns):
        # Skip blank segments (e.g. the one after the final EOS) without
        # disturbing the user/bot alternation, which is decided by idx.
        if not text.strip():
            continue
        if idx % 2 == 0:
            cls, label = "user", ""
        else:
            cls, label = "alicia", "Alicia:"
        # Escape user/model text so it cannot inject markup into the page.
        parts.append("<div class='mesg {}'> {}{}</div>".format(cls, label, escape(text)))
    parts.append("</div>")
    return "".join(parts)

import gradio as grad

# Stylesheet injected into the Gradio page. `.mychat` lays its children out as a
# flex column; bot ("alicia") bubbles are aligned to the end of that column.
# The previous `color:white,align-self:...` used a comma instead of a semicolon,
# which made the browser discard the value, so the alignment never applied.
css = """
.mychat {display:flex;flex-direction:column}
.mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
.mesg.user {background-color:lightblue;color:white}
.mesg.alicia {background-color:orange;color:white;align-self:flex-end}
.footer {display:none !important}
"""
# NOTE(review): `grad.inputs.*` and `theme="default"` are the legacy Gradio 2.x
# API; newer Gradio releases deprecate/remove `gradio.inputs` — confirm the
# pinned Gradio version before upgrading.
text = grad.inputs.Textbox(placeholder="Lets chat")
# The "state" input/output pair round-trips converse()'s token-id history
# between calls; the "html" output displays the rendered conversation.
grad.Interface(fn=converse,
               theme="default",
               inputs=[text, "state"],
               outputs=["html", "state"],
               css=css).launch()