# Last commit by nikhilkomakula (5e8a58c): "Added RAG Evaluation Code"
# import libraries
from langchain_community.llms import HuggingFaceEndpoint
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel
from langchain_core.runnables import RunnablePassthrough

# import functions
from ..indexing.build_indexes import retrieve_indexes
# instantiate base retriever
def get_base_retriever(embedding_model, k=4, search_type="mmr"):
    """
    Build a retriever on top of the persisted vector index.

    Args:
        embedding_model (str): Name of the Hugging Face embedding model whose
            index is loaded via ``retrieve_indexes``.
        k (int, optional): Number of documents returned per query. Defaults to 4.
        search_type (str, optional): Retrieval strategy forwarded to
            ``as_retriever`` ("mmr" or "similarity"). Defaults to "mmr".

    Returns:
        VectorStoreRetriever: Retriever backed by the loaded vector store.
    """
    store = retrieve_indexes(embedding_model)
    return store.as_retriever(
        search_type=search_type,
        search_kwargs=dict(k=k),
    )
# define prompt template
def create_prompt_template():
    """
    Create the prompt used by the question-answering chains.

    The template text already contains the model's chat control markers
    (``<|system|>``, ``<|user|>``, ``<|assistant|>`` and ``</s>``), so it is
    built as a plain ``PromptTemplate``. A ``ChatPromptTemplate`` would wrap the
    text in a human message. When piped into a text-completion LLM such as
    ``HuggingFaceEndpoint``, that message is rendered as ``"Human: ..."``, which
    puts a stray role prefix in front of the ``<|system|>`` marker.

    Input variables:
        context: Retrieved document text, placed inside the system turn so it
            belongs to a role (previously it sat after the system turn's ``</s>``
            and before ``<|user|>``, outside any role).
        query: The user's question, placed inside the user turn.

    Returns:
        PromptTemplate: Prompt template with ``context`` and ``query`` inputs.
    """
    prompt_template = (
        "<|system|>\n"
        "You are an AI assistant for question-answering tasks. "
        "Use the provided context to answer the question. "
        "If you don't know the answer, just say that you don't know. "
        "The generated answer should be relevant to the question being asked, "
        "short and concise. "
        "Do not be creative and do not make up the answer. "
        "Make sure the generated answer always starts with a word.\n"
        "{context}</s>\n"
        "<|user|>\n"
        "{query}</s>\n"
        # The model continues generating after the assistant marker.
        "<|assistant|>\n"
    )
    return PromptTemplate.from_template(prompt_template)
# define llm
def load_hf_llm(
    repo_id,
    max_new_tokens=512,
    temperature=0.2,
    repetition_penalty=1.1,
    do_sample=True,
):
    """
    Load a Hugging Face Endpoint for text-generation inference.

    Args:
        repo_id (str): Hugging Face model repo ID.
        max_new_tokens (int, optional): Maximum number of new tokens to
            generate. Defaults to 512.
        temperature (float, optional): Sampling temperature. Defaults to 0.2.
        repetition_penalty (float, optional): Repetition penalty passed to the
            endpoint. Defaults to 1.1.
        do_sample (bool, optional): Whether the endpoint samples instead of
            decoding greedily. Defaults to True.

    Returns:
        HuggingFaceEndpoint: Configured Hugging Face endpoint LLM.
    """
    return HuggingFaceEndpoint(
        repo_id=repo_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        # Ask for only the generated continuation, without the prompt prepended.
        return_full_text=False,
    )
# define retrieval chain
def create_qa_chain(retriever, llm):
    """
    Build the question-answering chain used at inference time.

    When invoked with a query, the chain does the following:
    - fetches documents with ``retriever``;
    - joins their ``page_content`` with blank lines into one context string;
    - fills the prompt template with that context and the query;
    - calls ``llm`` and parses the output to a plain string.

    Args:
        retriever (VectorStoreRetriever): Retriever supplying context documents.
        llm (HuggingFaceEndpoint): Language model that generates the answer.

    Returns:
        Runnable: Chain mapping a query to the generated answer string.
    """

    def format_docs(docs):
        # One blank line between consecutive retrieved chunks.
        return "\n\n".join(doc.page_content for doc in docs)

    prompt_inputs = {
        "context": retriever | format_docs,
        "query": RunnablePassthrough(),
    }
    answer_steps = create_prompt_template() | llm | StrOutputParser()
    return prompt_inputs | answer_steps
# define retrieval chain for evaluation
def create_qa_chain_eval(retriever, llm):
    """
    Build the question-answering chain used for evaluation.

    Unlike ``create_qa_chain``, the output also carries the retrieved
    documents. Invoked with a query, the chain returns a dict with:

    - ``"context"``: the documents returned by ``retriever``;
    - ``"query"``: the original query;
    - ``"result"``: the generated answer string.

    Args:
        retriever (VectorStoreRetriever): Retriever supplying context documents.
        llm (HuggingFaceEndpoint): Language model that generates the answer.

    Returns:
        Runnable: Chain mapping a query to the dict described above.
    """

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Inside the answer sub-chain only, replace the document list with a single
    # string for the prompt; the outer dict keeps the original documents.
    stringify_context = RunnablePassthrough.assign(
        context=(lambda x: format_docs(x["context"]))
    )
    generate_answer = (
        stringify_context | create_prompt_template() | llm | StrOutputParser()
    )

    retrieve_step = RunnableParallel(
        {"context": retriever, "query": RunnablePassthrough()}
    )
    return retrieve_step.assign(result=generate_answer)