Spaces:
Running
Running
File size: 5,015 Bytes
bc723cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
# utils
from langchain_chroma import Chroma
from langchain_nomic.embeddings import NomicEmbeddings
from langchain_core.documents import Document
from langchain.retrievers.document_compressors import CohereRerank
#from langchain_core import CohereRerank
#from langchain_cohere import CohereRerank
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers import BM25Retriever
from langchain_groq import ChatGroq
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableMap
from langchain.schema import BaseRetriever
from qdrant_client import models
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
import os

# Load variables from a local .env file before reading any API keys.
load_dotenv()

# LangSmith tracing key (optional; None when unset).
LANGCHAIN_API_KEY = os.getenv('LANGCHAIN_API_KEY')
#Retriever
def get_retriever(n_docs=5):
    """Build a retriever over the local Chroma store.

    Combines vector similarity search (Chroma) with keyword search (BM25)
    via an EnsembleRetriever, then reranks the merged candidates with
    Cohere. Degrades to the plain vector retriever when the store is empty
    or any step of the combined setup fails.

    Args:
        n_docs: number of documents each underlying retriever returns.

    Returns:
        A ContextualCompressionRetriever on success, otherwise the bare
        vector-store retriever.
    """
    vector_database_path = "db"
    embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local")
    vectorstore = Chroma(collection_name="chromadb3",
                         persist_directory=vector_database_path,
                         embedding_function=embedding_model)
    # BUG FIX: as_retriever() does not accept a bare `k` kwarg — the result
    # count must be passed through search_kwargs or it is silently ignored.
    vs_retriever = vectorstore.as_retriever(search_kwargs={"k": n_docs})
    try:
        store_data = vectorstore.get()
        texts = store_data['documents']
        metadatas = store_data['metadatas']
        if not texts:  # empty store: BM25 would have nothing to index
            print("Warning: No documents found in vector store. Using vector retriever only.")
            return vs_retriever
        # BUG FIX: the previous intermediate Document(..., id_=...) loop used
        # a non-existent `id_` field and was redundant — index the raw
        # texts/metadatas directly, normalizing missing metadata to {}.
        keyword_retriever = BM25Retriever.from_texts(
            texts=texts,
            metadatas=[m if m else {} for m in metadatas],
        )
        keyword_retriever.k = n_docs
        # Equal-weight fusion of semantic and keyword results.
        ensemble_retriever = EnsembleRetriever(
            retrievers=[vs_retriever, keyword_retriever],
            weights=[0.5, 0.5]
        )
        # Rerank fused candidates; requires COHERE_API_KEY in the environment.
        compressor = CohereRerank(model="rerank-english-v3.0")
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=ensemble_retriever
        )
        return compression_retriever
    except Exception as e:
        # Deliberate best-effort: fall back to pure vector search rather
        # than crash the caller.
        print(f"Warning: Error creating combined retriever ({str(e)}). Using vector retriever only.")
        return vs_retriever
# System prompt template for the RAG chain.
# Placeholders: {context} — retrieved documents; {input} — the user question.
rag_prompt = """You are a medical chatbot designed to answer health-related questions.
The questions you will receive will primarily focus on medical topics and patient care.
Here is the context to use to answer the question:
{context}
Think carefully about the above context.
Now, review the user question:
{input}
Provide an answer to this question using only the above context.
Answer:"""
# Post-processing
def format_docs(docs):
    """Concatenate the page contents of *docs*, separated by blank lines."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)
#RAG chain
def get_expression_chain(retriever: BaseRetriever, model_name="llama-3.1-70b-versatile", temp=0) -> Runnable:
    """Return a RAG chain defined primarily in LangChain Expression Language.

    Args:
        retriever: retriever supplying context documents for each question.
        model_name: Groq-hosted chat model identifier.
        temp: sampling temperature (0 = most deterministic).

    Returns:
        Runnable mapping {"input": question} to the chat model's response.
    """
    def retrieve_context(input_text):
        # Fetch relevant documents and collapse them into one context string.
        docs = retriever.get_relevant_documents(input_text)
        return format_docs(docs)
    ingress = RunnableMap(
        {
            "input": lambda x: x["input"],
            "context": lambda x: retrieve_context(x["input"]),
        }
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                rag_prompt
            )
        ]
    )
    # SECURITY FIX: the Groq API key was hard-coded in source. Read it from
    # the environment instead (the leaked key must also be revoked).
    llm = ChatGroq(model=model_name, api_key=os.getenv("GROQ_API_KEY"), temperature=temp)
    chain = ingress | prompt | llm
    return chain
# Shared embedding model used by get_embeddings (local Nomic inference).
embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local")
#embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
# Generate the embedding vector for a single text string.
def get_embeddings(text):
    # NOTE(review): the standard LangChain Embeddings interface exposes
    # embed_documents/embed_query; confirm the installed langchain_nomic
    # version supports embed(texts, task_type=...) as called here.
    return embedding_model.embed([text], task_type='search_document')[0]
# Create or connect to a Qdrant collection
def create_qdrant_collection(client, collection_name):
    """Create *collection_name* in Qdrant if it does not already exist.

    Args:
        client: a qdrant_client.QdrantClient instance.
        collection_name: name of the collection to ensure exists.
    """
    # BUG FIX: get_collections().collections is a list of collection
    # descriptor objects, not names, so the old membership test
    # (`name not in collections`) was always True and the function tried to
    # recreate existing collections. Compare against the .name attributes.
    existing_names = {c.name for c in client.get_collections().collections}
    if collection_name not in existing_names:
        client.create_collection(
            collection_name=collection_name,
            # 768-dim cosine vectors, matching nomic-embed-text-v1.5 output.
            vectors_config=models.VectorParams(size=768, distance=models.Distance.COSINE)
        )