import os
import random
from functools import cache
from operator import itemgetter

import langsmith
from langchain.memory import ConversationBufferWindowMemory
from langchain.retrievers import EnsembleRetriever
from langchain_community.document_transformers import LongContextReorder
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_openai.chat_models import ChatOpenAI

from .prompt_template import generate_prompt_template
from .retrievers_setup import (DenseRetrieverClient, SparseRetrieverClient,
                               compression_retriever_setup)


# Helpers

def reorder_documents(docs: list[Document]) -> list[Document]:
    """Apply Long-Context Reorder: regardless of the model architecture,
    performance degrades when 10+ retrieved documents are included, so the
    most relevant documents are moved to the beginning and end of the list.

    Args:
        docs (list[Document]): List of LangChain documents.

    Returns:
        list[Document]: Reordered list of LangChain documents.
    """
    reorder = LongContextReorder()
    return reorder.transform_documents(docs)


def randomize_documents(documents: list[Document]) -> list[Document]:
    """Randomize the documents to vary the recommendations."""
    random.shuffle(documents)
    return documents


def format_practitioners_docs(docs: list[Document]) -> str:
    """Format the practitioners_db documents as a Markdown list.

    Args:
        docs (list[Document]): List of LangChain documents.

    Returns:
        str: Markdown-formatted string, one entry per practitioner.
    """
    return f"\n{'-' * 3}\n".join(
        [f"- Practitioner #{i + 1}:\n\n\t" + d.page_content
         for i, d in enumerate(docs)]
    )
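

# Quick illustration (hypothetical data) of the Markdown shape produced by
# format_practitioners_docs; entries are separated by a "---" rule:
#
#   format_practitioners_docs([Document(page_content="Jane Doe, RMT"),
#                              Document(page_content="John Roe, PT")])
#   => "- Practitioner #1:\n\n\tJane Doe, RMT\n---\n- Practitioner #2:\n\n\tJohn Roe, PT"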


def format_tall_tree_docs(docs: list[Document]) -> str:
    """Format the tall_tree_db documents as a Markdown list.

    Args:
        docs (list[Document]): List of LangChain documents.

    Returns:
        str: Markdown-formatted string, one entry per document.
    """
    return f"\n{'-' * 3}\n".join(
        [f"- No. {i + 1}:\n\n\t" + d.page_content
         for i, d in enumerate(docs)]
    )


def create_langsmith_client():
    """Create a LangSmith client for tracing the chain."""
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "talltree-ai-assistant"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
    langsmith_api_key = os.getenv("LANGCHAIN_API_KEY")
    if not langsmith_api_key:
        raise EnvironmentError(
            "Missing environment variable: LANGCHAIN_API_KEY")
    return langsmith.Client()


# Set up Runnable and Memory

def get_rag_chain(model_name: str = "gpt-4",
                  temperature: float = 0.2) -> tuple[Runnable, ConversationBufferWindowMemory]:
    """Set up the runnable chain and the chat memory.

    Args:
        model_name (str, optional): LLM model. Defaults to "gpt-4".
        temperature (float, optional): Model temperature. Defaults to 0.2.

    Returns:
        tuple[Runnable, ConversationBufferWindowMemory]: The LCEL chain and its memory.
    """
    # Set up LangSmith to trace the chain
    create_langsmith_client()

    # LLM and prompt template
    llm = ChatOpenAI(model_name=model_name,
                     temperature=temperature)
    prompt = generate_prompt_template()

    # Set up retrievers pointing to the practitioners' dataset
    embeddings_model = "text-embedding-ada-002"
    dense_retriever_client = DenseRetrieverClient(embeddings_model=embeddings_model,
                                                  collection_name="practitioners_db")

    # Qdrant vector store as a dense retriever
    practitioners_db_dense_retriever = dense_retriever_client.get_dense_retriever(search_type="similarity",
                                                                                  k=10)

    # Sparse vector retriever (SPLADE) backed by Qdrant
    collection_name = "practitioners_db_sparse_collection"
    vector_name = "sparse_vector"
    sparse_retriever_client = SparseRetrieverClient(
        collection_name=collection_name,
        vector_name=vector_name,
        splade_model_id="naver/splade-cocondenser-ensembledistil",
        k=15)
    practitioners_db_sparse_retriever = sparse_retriever_client.get_sparse_retriever()

    # Ensemble retriever for hybrid search. The sparse retriever is weighted
    # higher; it handles exact terms and acronyms (e.g., RMT) well.
    practitioners_ensemble_retriever = EnsembleRetriever(
        retrievers=[practitioners_db_dense_retriever,
                    practitioners_db_sparse_retriever],
        weights=[0.2, 0.8]
    )

    # Compression retriever for practitioners_db
    # TODO
    practitioners_db_compression_retriever = compression_retriever_setup(
        practitioners_ensemble_retriever,
        embeddings_model=embeddings_model,
        similarity_threshold=0.74
    )

    # Set up retrievers pointing to the tall_tree_db
    dense_retriever_client = DenseRetrieverClient(embeddings_model=embeddings_model,
                                                  collection_name="tall_tree_db")
    tall_tree_db_dense_retriever = dense_retriever_client.get_dense_retriever(search_type="similarity",
                                                                              k=5)

    # Compression retriever for tall_tree_db
    tall_tree_db_compression_retriever = compression_retriever_setup(
        tall_tree_db_dense_retriever,
        embeddings_model=embeddings_model,
        similarity_threshold=0.5
    )

    # Set conversation history window memory; it only keeps the last k interactions.
    memory = ConversationBufferWindowMemory(memory_key="history",
                                            return_messages=True,
                                            k=5)

    # Set up the runnable using LCEL
    setup_and_retrieval = {
        "practitioners_db": itemgetter("message")
        | practitioners_db_compression_retriever
        | format_practitioners_docs,
        "tall_tree_db": itemgetter("message")
        | tall_tree_db_compression_retriever
        | format_tall_tree_docs,
        "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
        "message": itemgetter("message"),
    }

    chain = (
        setup_and_retrieval
        | prompt
        | llm
        | StrOutputParser()
    )

    return chain, memory
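

# Minimal usage sketch (illustrative only; assumes OPENAI_API_KEY and
# LANGCHAIN_API_KEY are set and the Qdrant collections are populated).
# The chain expects a dict with a "message" key; each turn is written back
# into memory so the "history" slot reflects the conversation.
if __name__ == "__main__":
    chain, memory = get_rag_chain()
    user_message = "Do you have any RMT practitioners available this week?"  # sample question
    response = chain.invoke({"message": user_message})
    # Persist the turn so the next invocation sees it in "history".
    memory.save_context({"message": user_message}, {"output": response})
    print(response)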