File size: 2,128 Bytes
e04cd14
7f7b773
e04cd14
7f7b773
 
 
 
 
 
e04cd14
7f7b773
e04cd14
7f7b773
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e04cd14
7f7b773
 
 
 
 
 
 
 
 
e04cd14
 
7f7b773
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from langchain.prompts.prompt import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms.huggingface_pipeline import HuggingFacePipeline

class HuggingFaceQuestionAnswering:
    """Answer questions with a Hugging Face text-generation model over retrieved documents.

    The wrapped ``retriever`` object must expose
    ``retriever.vector_store.db.as_retriever(search_kwargs=...)``, which
    :meth:`answer_question` calls to build a filtered retriever per question.
    NOTE(review): presumably ``db`` is a LangChain vector store — confirm.
    """

    def __init__(self, retriever, model_id: str = "bigscience/bloomz-1b7") -> None:
        """Load the generation pipeline.

        Args:
            retriever: Project retriever wrapper (see class docstring for the
                attributes read from it).
            model_id: Hugging Face model id loaded for the ``text-generation``
                task. Defaults to ``bigscience/bloomz-1b7``.
        """
        self.retriever = retriever
        self.llm = HuggingFacePipeline.from_model_id(
            model_id=model_id,
            task="text-generation",
            model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
            pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
        )
        # Rebuilt on every answer_question call, because the retriever it
        # wraps depends on that call's metadata filter.
        self.chain = None
        # Populated by initialize(); not consumed by answer_question.
        self.prompt_template = None

    def initialize(self):
        """Build the custom completion prompt and store it on ``self.prompt_template``.

        NOTE(review): answer_question does not pass this prompt to the chain,
        so RetrievalQA's default prompt is what is actually used. To apply it,
        pass ``chain_type_kwargs={"prompt": self.prompt_template}`` to
        ``RetrievalQA.from_chain_type``. This is left unwired to keep
        generation behaviour unchanged.
        """
        template = """Use the information contained in the following text: {context}. Complete the phrase: {question} """
        self.prompt_template = PromptTemplate(
            template=template,
            input_variables=["context", "question"],
        )

    @staticmethod
    def _format_sources(source_documents, width: int = 40) -> str:
        """Return one ``"<title> - <snippet>..."`` line per retrieved document.

        Titles (``metadata["paper_title"]``) and content snippets are cut to
        ``width`` characters, and newlines in the snippet become spaces. A
        document without a ``paper_title`` gets an empty title instead of
        raising. Otherwise a missing metadata field would turn a successfully
        generated answer into the error sentinel.
        """
        lines = []
        for doc in source_documents:
            title = str(doc.metadata.get("paper_title") or "")[:width]
            snippet = doc.page_content[:width].replace("\n", " ")
            lines.append(f"{title} - {snippet}...")
        return "\n".join(lines)

    def answer_question(self, question: str, filter_dict, fetch_k: int = 150):
        """Answer ``question`` using documents retrieved under ``filter_dict``.

        Prints a short summary of the retrieved source documents.

        Args:
            question: Query text, passed to the chain under the ``"query"`` key.
            filter_dict: Metadata filter forwarded as ``search_kwargs["filter"]``.
            fetch_k: Forwarded as ``search_kwargs["fetch_k"]``. How it is
                interpreted depends on the vector-store backend.

        Returns:
            The chain's result dict on success. It includes
            ``"source_documents"`` because ``return_source_documents=True``.
            If the chain raises, returns
            ``{"result": "Error generating answer.", "source_documents": []}``
            so both keys are always present.
        """
        retriever = self.retriever.vector_store.db.as_retriever(
            search_kwargs={"filter": filter_dict, "fetch_k": fetch_k}
        )

        try:
            self.chain = RetrievalQA.from_chain_type(
                self.llm, retriever=retriever, return_source_documents=True
            )
            result = self.chain({"query": question})
            docs = self._format_sources(result.get("source_documents", []))
            print(f"""
Retrieved Documents:
{docs if docs != "" else "No documents found."}""")
            return result
        except Exception as e:
            # Best-effort boundary: report the failure and return a sentinel
            # answer with the same keys as a successful result.
            print(e)
            return {"result": "Error generating answer.", "source_documents": []}