basicTransformersExample / AI /question_answering.py
ricardo-lsantos's picture
Added AI and Pages for first few examples
91e858d
raw
history blame
944 Bytes
# Author: Ricardo Lisboa Santos
# Creation date: 2024-01-10
import torch
import torch_directml
from transformers import pipeline
def getDevice(DEVICE):
    """Map a backend name to a torch device.

    Args:
        DEVICE: backend name; "cpu", "cuda" and "directml" are recognised
            (exact, case-sensitive match).

    Returns:
        The matching device object, or ``None`` for any other name.
    """
    if DEVICE == "cpu":
        return torch.device("cpu")
    if DEVICE == "cuda":
        return torch.device("cuda")
    if DEVICE == "directml":
        # torch_directml is only touched when this backend is requested.
        return torch_directml.device()
    # Unrecognised backend name: no device is selected.
    return None
def loadGenerator(device):
    """Build a Hugging Face question-answering pipeline placed on ``device``.

    Args:
        device: target device, e.g. the value returned by ``getDevice``.
            ``None`` leaves placement to the ``pipeline`` factory's default.

    Returns:
        The question-answering pipeline.
    """
    # Placement is requested through the factory's ``device`` argument, so the
    # parameter is actually honoured instead of being ignored.
    return pipeline("question-answering", device=device)
def query(generator, question, context):
    """Ask ``generator`` one question about one context passage.

    Args:
        generator: callable accepting ``question`` and ``context`` keyword
            arguments, such as a transformers question-answering pipeline.
        question: the question text.
        context: the passage the answer should come from.

    Returns:
        Whatever ``generator`` returns, unchanged.
    """
    call_kwargs = {"question": question, "context": context}
    return generator(**call_kwargs)
def clearCache(DEVICE, generator, cache_dir="cache"):
    """Save a pipeline's tokenizer and model to disk, then release cached device memory.

    Args:
        DEVICE: backend name ("cpu", "cuda" or "directml").
        generator: object exposing ``tokenizer`` and ``model`` attributes, each
            with a ``save_pretrained(path)`` method (e.g. a transformers pipeline).
        cache_dir: directory both are saved into; defaults to ``"cache"`` for
            backward compatibility.
    """
    generator.tokenizer.save_pretrained(cache_dir)
    generator.model.save_pretrained(cache_dir)
    # Only drops this function's local reference. The caller's own reference
    # still keeps the pipeline alive until the caller releases it too.
    del generator
    if DEVICE == "cuda":
        # Releases unoccupied memory held by PyTorch's CUDA caching allocator.
        torch.cuda.empty_cache()
    elif DEVICE == "directml":
        # NOTE(review): confirm the installed torch_directml provides empty_cache().
        torch_directml.empty_cache()