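"""Utilities for building a FAISS-backed RetrievalQA pipeline with LangChain.

Provides a pre-configured text splitter and embedding model, a retriever
builder that accepts either raw text or a list of Documents, and a factory
for a RetrievalQA chain over that retriever.
"""
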
from typing import List, Optional, Union
from dotenv import find_dotenv, load_dotenv
from langchain.chains import RetrievalQA
from langchain.chat_models import init_chat_model
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.language_models import BaseChatModel
from langchain_core.retrievers import BaseRetriever
from langchain_huggingface.embeddings import HuggingFaceEmbeddings


def get_default_splitter() -> RecursiveCharacterTextSplitter:
    """Returns a pre-configured text splitter."""
    return RecursiveCharacterTextSplitter(
        # Using markdown headers as separators is a good strategy
        separators=["\n### ", "\n## ", "\n# ", "\n\n", "\n", " "],
        chunk_size=1000,
        chunk_overlap=200,
    )


def get_default_embeddings() -> HuggingFaceEmbeddings:
    """Returns a pre-configured embedding model."""
    return HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'}
    )


def build_retriever(
        data: Union[str, List[Document]],
        splitter: Optional[RecursiveCharacterTextSplitter] = None,
        embeddings: Optional[HuggingFaceEmbeddings] = None,
        top_k: int = 5):
    """Builds a retriever from either a raw text string or a list of documents.

    Args:
        data (Union[str, List[Document]]): The source data to build the retriever from.
        splitter (RecursiveCharacterTextSplitter, optional): The text splitter to use.
            Defaults to get_default_splitter().
        embeddings (HuggingFaceEmbeddings, optional): The embedding model to use.
            Defaults to get_default_embeddings().
        top_k (int, optional): The number of documents the retriever returns.
            Defaults to 5.

    Returns:
        A FAISS-backed retriever over the split and embedded documents.
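
    Example (illustrative sketch; downloads the embedding model on first use):
        >>> retriever = build_retriever("Some long markdown document...")
        >>> top_docs = retriever.invoke("What is this document about?")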
    """
    splitter = splitter or get_default_splitter()
    embeddings = embeddings or get_default_embeddings()
    if isinstance(data, str):
        # If the input is a raw string, split it into chunks first
        chunks = splitter.split_text(data)
        # Then convert those chunks into Document objects
        docs = [Document(page_content=chunk) for chunk in chunks]
    elif isinstance(data, list):
        # If the input is already a list of documents, split them directly
        docs = splitter.split_documents(data)
    else:
        raise TypeError(f"Unsupported data type: {type(data)}. Must be str or List[Document].")

    # Embed the chunks and index them in an in-memory FAISS vector store
    index = FAISS.from_documents(docs, embeddings)
    return index.as_retriever(search_kwargs={"k": top_k})


def create_retrieval_qa(
        retriever: BaseRetriever,
        llm: Optional[BaseChatModel] = None,
    ) -> RetrievalQA:
    """Creates a RetrievalQA instance from a given retriever and LLM.

    Args:
        retriever (BaseRetriever): The retriever to be used by the QA chain.
        llm (BaseChatModel, optional): The language model to use. If not
            provided, a default Groq-hosted model is initialized from
            environment credentials.

    Returns:
        RetrievalQA: A chain that answers queries against the retriever's
        documents and also returns the source documents it used.
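
    Example (illustrative sketch; assumes a retriever from build_retriever):
        >>> qa = create_retrieval_qa(retriever)
        >>> qa.invoke({"query": "What is this document about?"})["result"]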
    """
    if llm is None:
        # Pull API credentials (e.g., GROQ_API_KEY) from a .env file if present
        load_dotenv(find_dotenv())
        llm = init_chat_model("groq:meta-llama/llama-4-scout-17b-16e-instruct")
    return RetrievalQA.from_chain_type(
        llm=llm,
        # "stuff" places all retrieved documents into a single prompt
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
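

# Minimal end-to-end sketch of how these helpers fit together. Illustrative
# only: the sample text and question are invented, and running it assumes a
# GROQ_API_KEY in your .env plus network access to fetch the embedding model.
if __name__ == "__main__":
    sample_text = (
        "# FAISS\n\n"
        "FAISS is a library for efficient similarity search over dense "
        "vectors, developed by Meta AI Research.\n\n"
        "## Indexing\n\n"
        "Documents are embedded and stored in an index that supports fast "
        "nearest-neighbor queries."
    )
    retriever = build_retriever(sample_text, top_k=2)
    qa_chain = create_retrieval_qa(retriever)
    result = qa_chain.invoke({"query": "What is FAISS used for?"})
    print(result["result"])
    for doc in result["source_documents"]:
        print("-", doc.page_content[:80])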