# utils: helpers for building the retriever, prompt, and RAG chain

from langchain_chroma import Chroma
from langchain_nomic.embeddings import NomicEmbeddings
from langchain_core.documents import Document
from langchain_cohere import CohereRerank

from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_groq import ChatGroq

from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableMap
from langchain.schema import BaseRetriever
from qdrant_client import models


from langchain_huggingface.embeddings import HuggingFaceEmbeddings

import os

load_dotenv()

LANGCHAIN_API_KEY = os.getenv('LANGCHAIN_API_KEY')

# Retriever

def get_retriever(n_docs=5):
    """Build a hybrid (vector + BM25) retriever with Cohere reranking."""
    vector_database_path = "db"
    
    embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local")
    
    vectorstore = Chroma(collection_name="chromadb3",
                        persist_directory=vector_database_path,
                        embedding_function=embedding_model)
    
    vs_retriever = vectorstore.as_retriever(search_kwargs={"k": n_docs})
    
    # Get documents from vector store
    try:
        store_data = vectorstore.get()
        texts = store_data['documents']
        metadatas = store_data['metadatas']
        
        if not texts:  # If no documents found
            print("Warning: No documents found in vector store. Using vector retriever only.")
            return vs_retriever
            
        # Create documents with explicit IDs
        documents = []
        for i, (text, metadata) in enumerate(zip(texts, metadatas)):
            doc = Document(
                page_content=text,
                metadata=metadata if metadata else {},
                id=str(i)  # explicit ID so vector and keyword results can be aligned
            )
            documents.append(doc)
        
        # Create BM25 retriever with explicit document handling
        keyword_retriever = BM25Retriever.from_texts(
            texts=[doc.page_content for doc in documents],
            metadatas=[doc.metadata for doc in documents],
            ids=[doc.id for doc in documents]
        )
        keyword_retriever.k = n_docs
        
        ensemble_retriever = EnsembleRetriever(
            retrievers=[vs_retriever, keyword_retriever],
            weights=[0.5, 0.5]
        )
        
        compressor = CohereRerank(model="rerank-english-v3.0")
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=ensemble_retriever
        )
        
        return compression_retriever
        
    except Exception as e:
        print(f"Warning: Error creating combined retriever ({str(e)}). Using vector retriever only.")
        return vs_retriever
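
# Usage sketch (illustrative, not part of the module's public API; assumes the
# Chroma store under "db" has been populated and COHERE_API_KEY is available
# for the reranker):
#
#   retriever = get_retriever(n_docs=5)
#   docs = retriever.invoke("What are the symptoms of type 2 diabetes?")
#   for doc in docs:
#       print(doc.page_content[:100])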

# Retriever prompt
rag_prompt = """You are a medical chatbot designed to answer health-related questions.
The questions you will receive will primarily focus on medical topics and patient care.
Here is the context to use to answer the question:
{context}
Think carefully about the above context.
Now, review the user question:
{input}
Provide an answer to this question using only the above context.
Answer:"""

# Post-processing
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# RAG chain
def get_expression_chain(retriever: BaseRetriever, model_name="llama-3.1-70b-versatile", temp=0) -> Runnable:
    """Return a chain defined primarily in LangChain Expression Language"""
    def retrieve_context(input_text):
        # Use the retriever to fetch relevant documents
        docs = retriever.invoke(input_text)
        return format_docs(docs)
    
    ingress = RunnableMap(
        {
            "input": lambda x: x["input"],
            "context": lambda x: retrieve_context(x["input"]),
        }
    )
    prompt = ChatPromptTemplate.from_messages([("system", rag_prompt)])
    llm = ChatGroq(model=model_name, api_key=os.getenv("GROQ_API_KEY"), temperature=temp)

    chain = ingress | prompt | llm
    return chain
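
# Usage sketch (hypothetical question; assumes GROQ_API_KEY is set, e.g. via .env):
#
#   retriever = get_retriever()
#   chain = get_expression_chain(retriever)
#   response = chain.invoke({"input": "What are common causes of chest pain?"})
#   print(response.content)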

embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local")
#embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Generate embeddings for a given text
def get_embeddings(text):
    return embedding_model.embed([text], task_type='search_document')[0]
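
# Example (illustrative):
#
#   vector = get_embeddings("shortness of breath")
#   len(vector)  # 768 dimensions for nomic-embed-text-v1.5 at its default size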


# Create a Qdrant collection if it does not already exist
def create_qdrant_collection(client, collection_name):
    # get_collections() returns CollectionDescription objects, so compare by name
    existing = [c.name for c in client.get_collections().collections]
    if collection_name not in existing:
        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(size=768, distance=models.Distance.COSINE)
        )
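
# Minimal sketch of wiring the collection helper to a client; the host, port,
# and collection name below are assumptions, not taken from this module.
if __name__ == "__main__":
    from qdrant_client import QdrantClient

    qdrant = QdrantClient(host="localhost", port=6333)  # assumed local Qdrant
    create_qdrant_collection(qdrant, "medical_docs")    # hypothetical name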