File size: 830 Bytes
3f7f9d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
from transformers import pipeline
import torch

def setup_retriever_and_qa():
    """Build the RAG tokenizer, retriever, and model, and wrap them in a pipeline.

    Downloads ``facebook/rag-token-base`` artifacts from the Hugging Face hub
    (the dummy exact-match index keeps the download small for demo purposes).

    Returns:
        tuple: ``(retriever, qa_pipeline)`` — the ``RagRetriever`` instance and
        a ``text2text-generation`` pipeline backed by the RAG model, placed on
        GPU 0 when CUDA is available, otherwise on CPU.
    """
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
    )
    # Bug fix: the retriever must be attached to the model via `retriever=`,
    # otherwise RagTokenForGeneration has no document index to query at
    # generation time and the "retrieval" half of RAG never happens.
    rag_model = RagTokenForGeneration.from_pretrained(
        "facebook/rag-token-base", retriever=retriever
    )
    qa_pipeline = pipeline(
        "text2text-generation",
        model=rag_model,
        tokenizer=tokenizer,
        # pipeline() device convention: 0 = first CUDA device, -1 = CPU.
        device=0 if torch.cuda.is_available() else -1,
    )
    return retriever, qa_pipeline

def get_answer(context: str, question: str, retriever, qa_pipeline) -> str:
    """Generate an answer for *question* given *context* via *qa_pipeline*.

    The question and context are packed into a single prompt of the form
    ``"question: ... context: ..."`` and handed to the pipeline, which is
    invoked with sampling enabled (``do_sample=True``) and a 200-token cap.

    Note: *retriever* is accepted for interface compatibility but is not
    used by this function — retrieval is expected to happen inside the
    pipeline's model.

    Returns:
        str: the ``generated_text`` of the pipeline's first candidate.
    """
    prompt = f"question: {question} context: {context}"
    candidates = qa_pipeline(prompt, max_length=200, do_sample=True)
    first = candidates[0]
    return first["generated_text"]