File size: 9,821 Bytes
8d1e83e
d594a38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d1e83e
d594a38
 
8d1e83e
 
 
 
 
 
 
d594a38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e9684b
8d1e83e
 
 
7e9684b
8d1e83e
 
7e9684b
8d1e83e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e9684b
d594a38
 
 
 
 
 
 
 
8d1e83e
7e9684b
8d1e83e
 
 
 
7e9684b
8d1e83e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8882a59
 
 
8d1e83e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d594a38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
"""
Module for performing retrieval-augmented generation (RAG) using LangChain.
This module provides functions to optimize search queries, retrieve relevant documents,
and generate answers to questions using the retrieved context. It leverages the LangChain
library for building the RAG pipeline.
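Functions:
- get_chat_llm(provider: str, model: str = None, temperature: float = 0.0):
Create a LangChain chat model for 'bedrock', 'openai', 'groq', 'ollama' or 'cohere'.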
- get_optimized_search_messages(query: str) -> list:
Generate optimized search messages for a given query.
- optimize_search_query(chat_llm, query: str, callbacks: list = []) -> str:
Optimize the search query using the chat language model.
- get_rag_prompt_template() -> ChatPromptTemplate:
Get the prompt template for retrieval-augmented generation (RAG).
- format_docs(docs: list) -> str:
Format the retrieved documents into a JSON string.
- multi_query_rag(chat_llm, question: str, search_query: str, vectorstore, callbacks: list = []) -> str:
Perform RAG using multiple queries to retrieve relevant documents.
- query_rag(chat_llm, question: str, search_query: str, vectorstore, callbacks: list = []) -> str:
Perform RAG using a single query to retrieve relevant documents.
"""
import os
import json
from langchain.schema import SystemMessage, HumanMessage
from langchain.prompts.chat import (
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
    ChatPromptTemplate
)
from langchain.prompts.prompt import PromptTemplate
from langchain.retrievers.multi_query import MultiQueryRetriever

from langchain_cohere.chat_models import ChatCohere
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_community.chat_models.bedrock import BedrockChat
from langchain_community.chat_models.ollama import ChatOllama

def get_chat_llm(provider, model=None, temperature=0.0):
    """
    Build a LangChain chat model for the requested provider.

    Args:
        provider (str): One of 'bedrock', 'openai', 'groq', 'ollama' or 'cohere'.
        model (str, optional): Model identifier. When None, a provider-specific
            default model name is used.
        temperature (float): Sampling temperature passed to the model.

    Returns:
        The chat model instance constructed for ``provider``.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider == 'bedrock':
        # Bedrock takes the temperature inside model_kwargs; the AWS profile
        # name is read from the CREDENTIALS_PROFILE_NAME environment variable.
        return BedrockChat(
            credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
            model_id=(
                "anthropic.claude-3-sonnet-20240229-v1:0" if model is None else model
            ),
            model_kwargs={"temperature": temperature},
        )
    if provider == 'openai':
        return ChatOpenAI(
            model_name="gpt-3.5-turbo" if model is None else model,
            temperature=temperature,
        )
    if provider == 'groq':
        return ChatGroq(
            model_name='mixtral-8x7b-32768' if model is None else model,
            temperature=temperature,
        )
    if provider == 'ollama':
        return ChatOllama(
            model='llama2' if model is None else model,
            temperature=temperature,
        )
    if provider == 'cohere':
        return ChatCohere(
            model='command-r-plus' if model is None else model,
            temperature=temperature,
        )
    raise ValueError(f"Unknown LLM provider {provider}")


def get_optimized_search_messages(query):
    """
    Generate optimized search messages for a given query.

    The system message instructs the model to rewrite a chat prompt into a
    short web-search string terminated by "**"; the human message carries the
    user's question.

    Args:
        query (str): The user's query.

    Returns:
        list: A list containing the system message and human message for optimized search.
    """
    system_message = SystemMessage(
        content="""
            I want you to act as a prompt optimizer for web search. I will provide you with a chat prompt, and your goal is to optimize it into a search string that will yield the most relevant and useful information from a search engine like Google.
            To optimize the prompt:
            Identify the key information being requested
            Arrange the keywords into a concise search string
            Keep it short, around 1 to 5 words total
            Put the most important keywords first
            
            Some tips and things to be sure to remove:
            - Remove any conversational or instructional phrases
            - Remove style instructions (example: "in the style of", "engaging", "short", "long")
            - Remove length instructions (example: essay, article, letter, blog, post, blogpost, etc)
            
            Add "**" to the end of the search string to indicate the end of the query
            
            Example:
                Question: How do I bake chocolate chip cookies from scratch?
                Search query: chocolate chip cookies recipe from scratch**
            Example:
                Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
                Search query: Marie Curie timeline**
            Example:
                Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
                Search query: geopolitics nato russia**
            Example:
                Question: Write an engaging LinkedIn post about Andrew Ng
                Search query: Andrew Ng**
            Example:
                Question: Write a short article about the solar system in the style of Carl Sagan
                Search query: solar system**
            Example:
                Question: Should I use Kubernetes? Answer in the style of Gilfoyle from the TV show Silicon Valley
                Search query: Kubernetes decision**
            Example:
                Question: Biography of Napoleon. Include a table with the major events.
                Search query: napoleon biography events**
            Example:
                Question: Write a short article on the history of the United States. Include a table with the major events.
                Search query: united states history events**
            Example:
                Question: Write a short article about the solar system in the style of donald trump
                Search query: solar system**
        """
    )
    # The trailing "Search query:" primes the model to answer with the query only.
    human_message = HumanMessage(
        content=f"""                 
            Question: {query}
            Search query: 
        """
    )
    return [system_message, human_message]


def optimize_search_query(chat_llm, query, callbacks=None):
    """
    Turn a free-form user question into a short web-search query.

    Args:
        chat_llm: Chat model used to rewrite the query; must expose ``invoke``
            returning an object with a ``content`` string.
        query (str): The user's original question.
        callbacks (list, optional): LangChain callback handlers forwarded to
            the model call. Defaults to no callbacks.

    Returns:
        str: The model output up to the first "**" end marker, with
        surrounding whitespace and double quotes removed.
    """
    # None instead of [] so no single list object is shared between calls.
    if callbacks is None:
        callbacks = []
    messages = get_optimized_search_messages(query)
    response = chat_llm.invoke(messages, config={"callbacks": callbacks})
    # The prompt asks the model to end the query with "**". Cut there first,
    # then trim, so quotes placed before the marker (e.g. '"foo"**') and stray
    # whitespace/newlines do not leak into the search string.
    search_query = response.content.split("**", 1)[0]
    return search_query.strip().strip('"').strip()


def get_rag_prompt_template():
    """
    Get the prompt template for Retrieval-Augmented Generation (RAG).

    The template takes two input variables: ``context`` (a JSON array of
    documents with "content", "title" and "link" keys) and ``query`` (the
    user's question).

    Returns:
        ChatPromptTemplate: The prompt template for RAG.
    """
    # The citation instructions only ask for fields the Context actually
    # provides (title and link); asking for author/publication, which are never
    # supplied, pushes the model to invent them.
    system_prompt = SystemMessagePromptTemplate(
        prompt=PromptTemplate(
            input_variables=[],
            template="""
                You are an expert research assistant.
                You are provided with a Context in JSON format and a Question. 
                Each JSON entry contains: content, title, link

                Use RAG to answer the Question, providing references and links to the Context material you retrieve and use in your answer:
                When generating your answer, follow these steps:
                - Select the most relevant material from the provided Context to help answer the question
                - Cite the references you use by including the title and link of each source
                - Synthesize the retrieved information into a clear, informative answer to the question
                - Format your answer in Markdown, using heading levels 2-3 as needed
                - Include a "References" section at the end with the full citations and link for each source you used
                
                If you cannot answer the question with confidence just say: "I'm not sure about the answer to be honest"
                If the provided context is not relevant to the question, just say: "The context provided is not relevant to the question"
            """
        )
    )
    human_prompt = HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            input_variables=["context", "query"],
            template="""
                Context: 
                ---------------------
                {context}
                ---------------------
                Question: {query}
                Answer:
            """
        )
    )
    return ChatPromptTemplate(
        input_variables=["context", "query"],
        messages=[system_prompt, human_prompt],
    )

def format_docs(docs):
    """
    Serialize retrieved documents into the JSON Context used by the RAG prompt.

    Args:
        docs (list): Documents exposing ``page_content`` and a ``metadata``
            dict (e.g. LangChain ``Document`` objects).

    Returns:
        str: A pretty-printed JSON array with one object per document, keyed
        "content", "title" and "link" (the link comes from metadata["source"]).
        A missing "title" or "source" entry is rendered as an empty string.
    """
    formatted_docs = [
        {
            "content": doc.page_content,
            # Metadata may lack these keys depending on the loader; one
            # incomplete document should not abort the whole answer.
            "title": doc.metadata.get("title", ""),
            "link": doc.metadata.get("source", ""),
        }
        for doc in docs
    ]
    # ensure_ascii=False keeps non-ASCII text as-is instead of \uXXXX escapes.
    return json.dumps(formatted_docs, indent=2, ensure_ascii=False)


def multi_query_rag(chat_llm, question, search_query, vectorstore, callbacks=None):
    """
    Answer a question with RAG, retrieving context through multiple LLM-generated queries.

    Args:
        chat_llm: Chat model used both to generate query variants and to answer.
        question (str): The user's question, inserted into the answer prompt.
        search_query (str): Query used as the seed for retrieval.
        vectorstore: Vector store providing ``as_retriever()``.
        callbacks (list, optional): LangChain callback handlers forwarded to
            retrieval and to the answering model call. Defaults to none.

    Returns:
        str: The model's answer text.
    """
    # None instead of [] so no single list object is shared between calls.
    if callbacks is None:
        callbacks = []
    # include_original=True retrieves for the seed query as well as the
    # LLM-generated variants.
    retriever_from_llm = MultiQueryRetriever.from_llm(
        retriever=vectorstore.as_retriever(), llm=chat_llm, include_original=True,
    )
    unique_docs = retriever_from_llm.get_relevant_documents(
        query=search_query, callbacks=callbacks, verbose=True
    )
    context = format_docs(unique_docs)
    # format_messages() keeps the system and human prompts as separate chat
    # messages; format() would flatten them into one "System: ... Human: ..."
    # string delivered to the model as a single user message.
    messages = get_rag_prompt_template().format_messages(query=question, context=context)
    response = chat_llm.invoke(messages, config={"callbacks": callbacks})
    return response.content


def query_rag(chat_llm, question, search_query, vectorstore, callbacks=None):
    """
    Answer a question with RAG, retrieving context via a single similarity search.

    Args:
        chat_llm: Chat model used to generate the answer.
        question (str): The user's question, inserted into the answer prompt.
        search_query (str): Query used for the similarity search.
        vectorstore: Vector store providing ``similarity_search``.
        callbacks (list, optional): LangChain callback handlers forwarded to
            the search and to the model call. Defaults to none.

    Returns:
        str: The model's answer text.
    """
    # None instead of [] so no single list object is shared between calls.
    if callbacks is None:
        callbacks = []
    unique_docs = vectorstore.similarity_search(search_query, k=15, callbacks=callbacks, verbose=True)
    context = format_docs(unique_docs)
    # format_messages() keeps the system and human prompts as separate chat
    # messages; format() would flatten them into one "System: ... Human: ..."
    # string delivered to the model as a single user message.
    messages = get_rag_prompt_template().format_messages(query=question, context=context)
    response = chat_llm.invoke(messages, config={"callbacks": callbacks})
    return response.content