Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import os | |
| import random | |
| from datetime import datetime | |
| from operator import itemgetter | |
| from typing import Sequence | |
| import langsmith | |
| from langchain.memory import ConversationBufferWindowMemory | |
| from langchain_community.document_transformers import LongContextReorder | |
| from langchain_core.documents import Document | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.runnables import Runnable, RunnableLambda | |
| from langchain_openai import ChatOpenAI | |
| from zoneinfo import ZoneInfo | |
| from rag.retrievers import RetrieversConfig | |
| from .prompt_template import generate_prompt_template | |
| # Helpers | |
def get_datetime() -> str:
    """Return the current date and time in the America/Vancouver time zone.

    Returns:
        str: The timestamp rendered as
        ``<weekday name>, <year>-<abbreviated month>-<day> <HH:MM:SS>``.
    """
    vancouver_tz = ZoneInfo("America/Vancouver")
    current_time = datetime.now(tz=vancouver_tz)
    return current_time.strftime("%A, %Y-%b-%d %H:%M:%S")
def reorder_documents(docs: list[Document]) -> Sequence[Document]:
    """Reorder documents to mitigate performance degradation with long contexts.

    Args:
        docs (list[Document]): Documents to reorder.

    Returns:
        Sequence[Document]: Whatever ``LongContextReorder.transform_documents``
        returns for ``docs``.
    """
    reorderer = LongContextReorder()
    reordered = reorderer.transform_documents(docs)
    return reordered
def randomize_documents(documents: list[Document]) -> list[Document]:
    """Return the documents in random order to vary model recommendations.

    The input is copied before shuffling, so the caller's sequence keeps its
    original order and immutable sequences (e.g. tuples) are accepted too.

    Args:
        documents (list[Document]): Documents to randomize.

    Returns:
        list[Document]: A new list holding the same documents, shuffled.
    """
    # Shuffle a copy: random.shuffle works in place and would otherwise
    # reorder the list owned by the caller.
    shuffled = list(documents)
    random.shuffle(shuffled)
    return shuffled
class DocumentFormatter:
    """Callable that renders a list of documents as one markdown string.

    Every document becomes a bullet labelled with ``prefix`` and its 1-based
    position, followed by the document's ``page_content``. Entries are
    separated by a ``---`` line.
    """

    def __init__(self, prefix: str):
        # Label printed before each document number, e.g. "No.".
        self.prefix = prefix

    def __call__(self, docs: list[Document]) -> str:
        """Format the documents as markdown.

        Args:
            docs (list[Document]): Documents exposing a ``page_content`` string.

        Returns:
            str: The formatted entries joined by ``"\\n---\\n"``; an empty
            string when ``docs`` is empty.
        """
        return "\n---\n".join(
            f"- {self.prefix} {number}:\n\n\t" + document.page_content
            for number, document in enumerate(docs, start=1)
        )
def create_langsmith_client():
    """Configure Langsmith tracing through environment variables and create a client.

    Sets LANGCHAIN_TRACING_V2, LANGCHAIN_PROJECT and LANGCHAIN_ENDPOINT in
    ``os.environ`` and then instantiates ``langsmith.Client``.

    Returns:
        langsmith.Client: A client created with default arguments.

    Raises:
        EnvironmentError: If LANGCHAIN_API_KEY is unset or empty. In that case
            no environment variable is modified.
    """
    # Validate before touching the environment so that a missing key does not
    # leave tracing switched on for the rest of the process.
    if not os.getenv("LANGCHAIN_API_KEY"):
        raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")

    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "admin-ai-assistant"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
    return langsmith.Client()
| # Set up Runnable and Memory | |
def get_runnable(
    model: str = "gpt-4o-mini", temperature: float = 0.1
) -> tuple[Runnable, ConversationBufferWindowMemory]:
    """Build the retrieval-augmented chat chain and its conversation memory.

    The chain expects a mapping with a "message" key. The chain only reads
    from the memory; nothing in this function writes conversation turns back.

    Args:
        model (str, optional): Chat model name passed to ChatOpenAI.
            Defaults to "gpt-4o-mini".
        temperature (float, optional): Model temperature. Defaults to 0.1.

    Returns:
        tuple[Runnable, ConversationBufferWindowMemory]: The chain and the
        memory object it loads its history from.
    """
    # Langsmith tracing must be configured before the chain is assembled.
    create_langsmith_client()

    # Model and prompt.
    chat_model = ChatOpenAI(model=model, temperature=temperature)
    prompt_template = generate_prompt_template()

    # Retrievers (hybrid search, per the original note on this setup).
    retrievers = RetrieversConfig()
    # Practitioners data.
    practitioners_retriever = retrievers.get_practitioners_retriever(k=10)
    # Tall Tree documents with contact information for locations and services.
    tall_tree_retriever = retrievers.get_documents_retriever(k=10)

    # Windowed conversation history: only the last k interactions are used.
    memory = ConversationBufferWindowMemory(
        k=6,
        memory_key="history",
        return_messages=True,
    )

    # Each branch below receives the same input mapping and fills one key.
    pick_message = itemgetter("message")
    practitioners_context = (
        pick_message | practitioners_retriever | DocumentFormatter("Practitioner #")
    )
    tall_tree_context = pick_message | tall_tree_retriever | DocumentFormatter("No.")
    chat_history = RunnableLambda(memory.load_memory_variables) | itemgetter("history")

    prompt_inputs = {
        "practitioners_db": practitioners_context,
        "tall_tree_db": tall_tree_context,
        "timestamp": lambda _: get_datetime(),
        "history": chat_history,
        "message": pick_message,
    }

    # LCEL pipeline: gather inputs -> prompt -> LLM -> plain string output.
    rag_chain = prompt_inputs | prompt_template | chat_model | StrOutputParser()
    return rag_chain, memory