File size: 587 Bytes
44466c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from transformers import pipeline, AutoTokenizer


class gpt2:
    """Thin wrapper around a Hugging Face ``text-generation`` pipeline for GPT-2."""

    def __init__(self, device="cpu"):
        """Load the GPT-2 text-generation pipeline.

        Args:
            device: Forwarded unchanged to ``transformers.pipeline`` to choose
                where the model runs.
        """
        self.text_generation = pipeline("text-generation", model="gpt2", device=device)
        # The pipeline already loaded the "gpt2" tokenizer; reuse it rather
        # than loading the same tokenizer a second time.
        self.tokenizer = self.text_generation.tokenizer

    def generate_text(self, *args, **kwargs):
        """Run the pipeline and return every generated text as a flat list.

        All positional and keyword arguments are forwarded unchanged to the
        pipeline call (e.g. a prompt string or list of prompts, plus
        generation options such as ``max_length``).

        Returns:
            A flat ``list[str]`` of ``generated_text`` values. Candidates for
            all prompts are included, in prompt order.
        """
        results = self.text_generation(*args, **kwargs)
        texts = []
        for entry in results:
            # A batch of prompts yields one list of candidate dicts per prompt;
            # a single prompt yields the candidate dicts directly.
            candidates = entry if isinstance(entry, list) else [entry]
            texts.extend(candidate["generated_text"] for candidate in candidates)
        return texts

    def get_tokenizer(self):
        """Return the tokenizer stored on this instance."""
        return self.tokenizer

if __name__ == '__main__':
    # Use a distinct name so the ``gpt2`` class is not shadowed by its instance.
    model = gpt2()
    # Pass prompts by keyword (the pipeline's ``text_inputs`` parameter) so the
    # call works whether or not ``generate_text`` accepts positional arguments.
    print(model.generate_text(text_inputs=["Hello, how are you?", "I am fine, thank you."]))