File size: 587 Bytes
44466c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
from transformers import pipeline, AutoTokenizer
class gpt2:
    """Thin wrapper around a Hugging Face ``text-generation`` pipeline.

    The default checkpoint is GPT-2; another checkpoint can be selected via
    ``model_name`` without changing how the wrapper is used.
    """

    def __init__(self, device="cpu", model_name="gpt2"):
        """Build the generation pipeline.

        Args:
            device: Forwarded unchanged to ``transformers.pipeline``
                (e.g. ``"cpu"``).
            model_name: Model id or local path to load; defaults to
                ``"gpt2"`` to preserve the original behaviour.
        """
        self.text_generation = pipeline(
            "text-generation", model=model_name, device=device
        )
        # Reuse the tokenizer the pipeline already loaded rather than loading
        # it a second time; this also keeps it in sync with ``model_name``.
        self.tokenizer = self.text_generation.tokenizer

    def generate_text(self, *args, **kwargs):
        """Run the pipeline and return only the generated strings.

        All positional and keyword arguments are forwarded unchanged to the
        pipeline call. The prompt(s) may therefore be given positionally or
        as ``text_inputs=``, together with any generation options.

        Returns:
            A flat list of ``generated_text`` strings. When several prompts
            are given, results appear in prompt order, and each prompt's
            sequences are kept together.
        """
        results = self.text_generation(*args, **kwargs)
        texts = []
        for entry in results:
            # A single prompt yields a list of dicts, while a list of prompts
            # yields one list of dicts per prompt. Normalize both shapes
            # instead of assuming one (indexing ``results[0]`` broke on a
            # single prompt and dropped every prompt after the first).
            candidates = entry if isinstance(entry, list) else [entry]
            texts.extend(item["generated_text"] for item in candidates)
        return texts

    def get_tokenizer(self):
        """Return the tokenizer used by the generation pipeline."""
        return self.tokenizer
if __name__ == '__main__':
    # Demo: run the wrapper with its default settings on two prompts.
    # Use a distinct name so the class ``gpt2`` is not shadowed by an instance.
    model = gpt2()
    prompts = ["Hello, how are you?", "I am fine, thank you."]
    # Pass the prompts by keyword: the text-generation pipeline's prompt
    # parameter is ``text_inputs``, and generate_text forwards its keyword
    # arguments to the pipeline unchanged.
    print(model.generate_text(text_inputs=prompts))
|