|
from transformers import pipeline, AutoTokenizer |
|
|
|
|
|
class gpt2:
    """Thin wrapper around a Hugging Face ``text-generation`` pipeline.

    The lowercase class name is kept for backward compatibility with
    existing callers.
    """

    def __init__(self, device="cpu", model_name="gpt2"):
        """Build the generation pipeline and load a tokenizer for the same model.

        Args:
            device: Forwarded unchanged to ``transformers.pipeline`` as its
                ``device`` argument. Defaults to ``"cpu"``.
            model_name: Model identifier used for both the pipeline and the
                tokenizer. Defaults to ``"gpt2"``, the value that was
                previously hard-coded.
        """
        self.text_generation = pipeline(
            "text-generation", model=model_name, device=device
        )
        # Exposed to callers through get_tokenizer().
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def generate_text(self, *args, **kwargs):
        """Run the pipeline and return every generated string as a flat list.

        All positional and keyword arguments are forwarded unchanged to the
        pipeline. The prompt(s) may therefore be given positionally or as
        ``text_inputs=``, together with any generation options.

        For a single prompt the pipeline yields a list of candidate dicts.
        For a list of prompts it yields one such list per prompt. Both shapes
        are flattened here in prompt order, so no prompt's output is dropped
        and a single string prompt no longer crashes.

        Returns:
            list[str]: The ``"generated_text"`` value of every candidate.
        """
        results = self.text_generation(*args, **kwargs)
        texts = []
        for entry in results:
            # Batched calls nest one list per prompt; single-prompt calls don't.
            candidates = entry if isinstance(entry, list) else [entry]
            texts.extend(candidate["generated_text"] for candidate in candidates)
        return texts

    def get_tokenizer(self):
        """Return the tokenizer loaded in ``__init__``."""
        return self.tokenizer
|
|
|
if __name__ == '__main__':
    # Demo entry point: build the wrapper and print generations for two prompts.
    # A distinct local name keeps the ``gpt2`` class from being shadowed.
    generator = gpt2()
    prompts = ["Hello, how are you?", "I am fine, thank you."]
    # Prompts are passed by keyword because generate_text forwards its
    # arguments to the pipeline, whose prompt parameter is ``text_inputs``.
    print(generator.generate_text(text_inputs=prompts))
|
|