File size: 1,957 Bytes
cc28aad
b73e8d0
e7899d4
b73e8d0
20cda87
 
b73e8d0
20cda87
 
 
 
 
b73e8d0
cc28aad
20cda87
b73e8d0
cc28aad
20cda87
 
b73e8d0
20cda87
cc28aad
20cda87
b73e8d0
cc28aad
b73e8d0
cc28aad
 
 
20cda87
 
cc28aad
20cda87
 
cc28aad
 
b73e8d0
cc28aad
 
 
 
b73e8d0
 
 
 
 
20cda87
cc28aad
 
 
 
b73e8d0
20cda87
cc28aad
 
20cda87
cc28aad
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForConditionalGeneration
import torch
import gradio as gr

# Checkpoint shared by the tokenizer and the causal language model below.
MODEL_NAME = "microsoft/DialoGPT-medium"

# Loaded once at import time and reused by every call to converse().
chat_tkn = AutoTokenizer.from_pretrained(MODEL_NAME)
mdl = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Disabled alternative checkpoint, kept for reference:
#chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
#mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")

def _render_chat_html(messages):
    """Render a list of chat messages as the HTML fragment shown in the UI.

    Messages at even positions are styled as user turns and messages at odd
    positions as bot turns (prefixed with "Bot:").  Message text is
    HTML-escaped so that neither user input nor model output can inject
    markup into the page.
    """
    from html import escape

    parts = ["<div class='mybot'>"]
    for index, text in enumerate(messages):
        safe_text = escape(text)
        if index % 2:
            css_class = "bot"
            safe_text = "Bot:" + safe_text
        else:
            css_class = "user"
        parts.append("<div class='mesg {}'> {}</div>".format(css_class, safe_text))
    parts.append("</div>")
    return "".join(parts)


def converse(user_input, chat_history=None):
    """Generate the bot's reply to *user_input* and render the conversation.

    Args:
        user_input: The text typed by the user for this turn.
        chat_history: Token ids of the conversation so far, in the nested-list
            form returned by a previous call (``tensor.tolist()`` of the
            generated sequence).  ``None`` or an empty list starts a new
            conversation.

    Returns:
        A ``(html, chat_history)`` tuple: the rendered conversation and the
        updated token-id history to pass back in on the next turn.
    """
    # Terminate the user's turn with EOS so turns can be split apart later.
    user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids

    # Append the new turn to the stored history.  With no history yet, use
    # the new turn alone instead of concatenating with an empty tensor.
    if chat_history:
        bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
    else:
        bot_input_ids = user_input_ids

    # The generated sequence contains the prompt plus the reply, so it becomes
    # the new history.
    # NOTE(review): the history grows every turn while max_length stays at
    # 1000 -- confirm how long conversations should be handled.
    chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()

    messages = chat_tkn.decode(chat_history[0]).split(chat_tkn.eos_token)
    # A sequence ending in EOS leaves an empty string after the final split;
    # drop it so it is not rendered as an empty user bubble.
    if messages and not messages[-1]:
        messages.pop()

    return _render_chat_html(messages), chat_history

# Stylesheet for the chat HTML produced by converse().
# The container selector is `.mybot` because that is the class converse() puts
# on the wrapping <div>; declarations within a rule are separated by `;`.
css = """
.mybot {display:flex;flex-direction:column}
.mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
.mesg.user {background-color:lightblue;color:white}
.mesg.bot {background-color:orange;color:white;align-self:self-end}
.footer {display:none !important}
"""

text = gr.inputs.Textbox(placeholder="Lets chat")

# converse() takes (user_input, chat_history) and returns (html, chat_history);
# the "state" input/output pair is the slot that carries chat_history.
gr.Interface(
    fn=converse,
    theme="default",
    inputs=[text, "state"],
    outputs=["html", "state"],
    css=css,
).launch()