Zekun Wu
update
44466c7
raw
history blame
587 Bytes
from transformers import pipeline, AutoTokenizer
class gpt2:
    """Thin wrapper around a Hugging Face ``text-generation`` pipeline for GPT-2.

    Both the generation pipeline and a tokenizer are loaded from the
    ``"gpt2"`` checkpoint when the instance is created.
    """

    def __init__(self, device="cpu"):
        """Load the GPT-2 generation pipeline and its tokenizer.

        Args:
            device: Forwarded unchanged to ``transformers.pipeline`` to select
                where the model runs (defaults to ``"cpu"``).
        """
        self.text_generation = pipeline(
            "text-generation", model="gpt2", device=device
        )
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def generate_text(self, *args, **kwargs):
        """Run the pipeline and return the generated strings.

        Every argument is forwarded to the pipeline. The prompt(s) may be
        given positionally or as ``text_inputs=``. Generation options such as
        ``max_length`` or ``num_return_sequences`` go in ``kwargs``.

        Returns:
            A flat list of generated strings. For one prompt this is that
            prompt's generations. For a list of prompts, the generations of
            every prompt are concatenated in prompt order.
        """
        results = self.text_generation(*args, **kwargs)
        return self._extract_texts(results)

    @staticmethod
    def _extract_texts(results):
        """Flatten pipeline output into a list of its ``generated_text`` values."""
        texts = []
        for entry in results:
            # The text-generation pipeline returns a list of dicts for a single
            # prompt and a list of such lists for a batch of prompts; indexing
            # results[0] broke the first case and dropped prompts in the second.
            group = entry if isinstance(entry, list) else [entry]
            texts.extend(item["generated_text"] for item in group)
        return texts

    def get_tokenizer(self):
        """Return the tokenizer loaded in ``__init__``."""
        return self.tokenizer
if __name__ == '__main__':
    # Use a distinct name so the instance does not shadow the ``gpt2`` class.
    model = gpt2()
    # Pass the prompts by keyword: the pipeline's first parameter is
    # ``text_inputs``, and this form also works with a **kwargs-only wrapper.
    print(model.generate_text(
        text_inputs=["Hello, how are you?", "I am fine, thank you."]
    ))