import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, T5ForConditionalGeneration, T5Tokenizer

# DialoGPT-large drives the open-ended conversation; a T5-based checkpoint
# supplies grammar corrections for each user message.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
grammar_tokenizer = T5Tokenizer.from_pretrained('deep-learning-analytics/GrammarCorrector')
grammar_model = T5ForConditionalGeneration.from_pretrained('deep-learning-analytics/GrammarCorrector')
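# A minimal smoke test of the grammar checkpoint (a sketch, assuming it
# follows the standard T5 text-to-text interface; the example sentence is
# hypothetical):
#   ids = grammar_tokenizer(["he go to school every day"], return_tensors="pt").input_ids
#   out = grammar_model.generate(ids, max_length=64, num_beams=2)
#   print(grammar_tokenizer.decode(out[0], skip_special_tokens=True))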


def chat(message, history=None):
    # Avoid a mutable default argument; Gradio supplies the state explicitly.
    history = history if history is not None else []
    # Encode the new user turn, terminated by the EOS token DialoGPT expects.
    new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors='pt')
    if len(history) > 0:
        # Extend the running conversation ids stored with the previous turn.
        last_set_of_ids = history[-1][2]
        bot_input_ids = torch.cat([last_set_of_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    # DialoGPT is GPT-2 based, so its context is capped at 1024 tokens;
    # max_length must stay within that limit or generate() will error.
    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
    # The reply is everything generated after the prompt tokens.
    response_ids = chat_history_ids[:, bot_input_ids.shape[-1]:][0]
    response = tokenizer.decode(response_ids, skip_special_tokens=True)
    history.append((message, response, chat_history_ids))
    # The Chatbot component expects (user, bot) string pairs; the id tensor
    # stays only in the state output.
    chat_pairs = [(user, bot) for user, bot, _ in history]
    return chat_pairs, history, feedback(message)
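# Quick sanity check outside Gradio (a sketch; the message is hypothetical):
#   pairs, state, fb = chat("Hello, how is you doing?")
#   print("bot:", pairs[-1][1])
#   print("feedback:", fb)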


def feedback(text):
    # Run the user's message through the grammar-correction model.
    num_return_sequences = 1
    batch = grammar_tokenizer([text], truncation=True, padding='max_length', max_length=64, return_tensors="pt")
    # Beam search; a temperature setting has no effect without sampling, so it is omitted.
    corrections = grammar_model.generate(**batch, max_length=64, num_beams=2, num_return_sequences=num_return_sequences)
    # Decode with the grammar model's own tokenizer, not the DialoGPT one.
    corrected_text = grammar_tokenizer.decode(corrections[0], clean_up_tokenization_spaces=True, skip_special_tokens=True)
    print("The correction is: ", corrected_text)
    if corrected_text == text:
        return 'Looks good! Keep up the good work.'
    return f'\'{corrected_text}\' might be a little better'
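# For example (hypothetical output), feedback("he go to school") might return
# "'He goes to school' might be a little better", while a sentence the model
# leaves unchanged gets the encouragement message instead.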

# Wire it into a Gradio interface: a text box and state come in; the chatbot
# transcript, updated state, and grammar feedback go out. allow_screenshot
# is a pre-3.0 Gradio option, so this targets the Gradio 2.x API.
iface = gr.Interface(
    chat,
    ["text", "state"],
    ["chatbot", "state", "text"],
    allow_screenshot=False,
    allow_flagging="never",
)
iface.launch()