# Hugging Face Space app: DialoGPT chatbot with T5-based grammar-correction feedback.
# Imports grouped at the top (PEP 8); previously `torch` and `gradio` were
# imported only after the (slow) model downloads below.
import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# Conversational model: DialoGPT generates the chatbot replies.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")

# Grammar-correction model: a T5 fine-tune used to give feedback on the
# user's message (see feedback() below).
grammar_tokenizer = T5Tokenizer.from_pretrained('deep-learning-analytics/GrammarCorrector')
grammar_model = T5ForConditionalGeneration.from_pretrained('deep-learning-analytics/GrammarCorrector')
def chat(message, history=None):
    """Generate a DialoGPT reply to *message*, extending the conversation.

    Parameters:
        message: the user's new utterance.
        history: list of (message, response, chat_history_ids) tuples from
            previous turns, or None for a fresh conversation (Gradio passes
            None as the initial "state" value).

    Returns:
        (history, history, feedback_text) — matching the Gradio outputs:
        the chatbot display, the updated state, and grammar feedback.
    """
    # BUG FIX: the original used a mutable default (`history=[]`), which is
    # shared across calls and leaks conversation state between sessions; it
    # also crashed on `len(None)` when Gradio supplied None for a new session.
    history = [] if history is None else history
    new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors='pt')
    if history:
        # Continue from the full generated-ids tensor of the previous turn
        # so the model sees the whole conversation as context.
        previous_ids = history[-1][2]
        bot_input_ids = torch.cat([previous_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    chat_history_ids = model.generate(bot_input_ids, max_length=5000, pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens — everything past the prompt.
    response_ids = chat_history_ids[:, bot_input_ids.shape[-1]:][0]
    response = tokenizer.decode(response_ids, skip_special_tokens=True)
    history.append((message, response, chat_history_ids))
    return history, history, feedback(message)
def feedback(text):
    """Run *text* through the grammar-correction model and phrase the result.

    Returns a short human-readable string: praise when the corrected text
    matches the input, otherwise a gentle suggestion quoting the correction.
    """
    num_return_sequences = 1
    batch = grammar_tokenizer([text], truncation=True, padding='max_length', max_length=64, return_tensors="pt")
    # NOTE: `temperature` was dropped — it only affects sampling, and this
    # call uses pure beam search (num_beams=2, no do_sample), so the
    # original temperature=1.5 had no effect on the output.
    corrections = grammar_model.generate(**batch, max_length=64, num_beams=2, num_return_sequences=num_return_sequences)
    # BUG FIX: decode with grammar_tokenizer — the original used `tokenizer`
    # (the DialoGPT tokenizer), which maps the T5 model's token ids to the
    # wrong vocabulary and yields garbled "corrections".
    corrected_text = grammar_tokenizer.decode(corrections[0], clean_up_tokenization_spaces=True, skip_special_tokens=True)
    print("The corrections are: ", corrections)
    if corrected_text == text:
        return 'Looks good! Keep up the good work'
    # BUG FIX: the original `" ".join(corrected_text)` iterated the *string*
    # character by character, producing output like "T h e  c a t ...".
    return f'\'{corrected_text}\' might be a little better'
# Wire chat() into a Gradio UI.
#   inputs : user text box + session state (the conversation history)
#   outputs: chatbot widget, updated state, and the grammar-feedback text —
#            matching the 3-tuple returned by chat().
iface = gr.Interface(
    chat,
    ["text", "state"],
    ["chatbot", "state", "text"],
    allow_screenshot=False,
    allow_flagging="never",
)
# Blocking call: starts the web server for the Space.
iface.launch()