File size: 2,203 Bytes
6a383c4
 
 
f09971c
3c5f44b
b1c1560
 
 
3c5f44b
6a383c4
 
 
 
 
f09971c
a0c0e21
b1c1560
 
6a383c4
b1c1560
 
6a383c4
 
b1c1560
 
 
 
6a383c4
a0c0e21
b1c1560
 
 
 
 
a0c0e21
6a383c4
 
3c5f44b
 
6a383c4
3c5f44b
5b97f9e
3c5f44b
5b97f9e
3c5f44b
 
 
6a383c4
 
3c5f44b
6a383c4
3c5f44b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from chatbot.llm import gemini_llm  # Import Gemini LLM
from chatbot.memory import memory
from chatbot.prompts import chat_prompt
from langchain.retrievers import WikipediaRetriever
from langchain.chains import ConversationalRetrievalChain
from pydantic import Field
from typing import List, Callable
from langchain.schema import BaseRetriever, Document

def translate_to_english(text: str, llm=None) -> str:
    """Translate ``text`` to English using an LLM (Gemini by default).

    Args:
        text: Text in any language. Empty or whitespace-only input is
            returned unchanged without calling the model.
        llm: Optional object exposing ``invoke(prompt)``; defaults to
            ``gemini_llm``. Lets callers (or tests) supply another model.

    Returns:
        The translation as a plain, whitespace-stripped string. Chat models
        return a message object rather than a ``str``, so its ``content`` is
        unwrapped here.
    """
    if not text or not text.strip():
        # Nothing to translate; avoid a pointless model round-trip.
        return text
    model = gemini_llm if llm is None else llm
    # Ask for the bare translation: the result is used as a search query,
    # so any explanatory preamble from the model would pollute it.
    prompt = (
        "Translate the following text to English. "
        "Reply with only the translation.\n\n"
        f"{text}"
    )
    response = model.invoke(prompt)
    # Chat models return a message with ``.content``; plain LLMs return str.
    content = getattr(response, "content", response)
    if isinstance(content, list):
        # Multi-part message content: keep only the textual parts.
        content = "".join(
            part if isinstance(part, str) else part.get("text", "")
            for part in content
        )
    return str(content).strip()

class WikipediaTranslationRetriever(BaseRetriever):
    """Retriever that translates each query to English before searching.

    Wraps a ``WikipediaRetriever``: the incoming query is passed through
    ``translator`` and the resulting English text is what gets searched.
    """

    # The wrapped retriever that performs the actual Wikipedia search.
    retriever: WikipediaRetriever = Field(..., description="The underlying Wikipedia retriever")
    # Callable mapping a query in any language to an English query string.
    translator: Callable[[str], str] = Field(..., description="Function to translate queries to English")

    def _get_relevant_documents(self, query: str, *, run_manager) -> List[Document]:
        """Translate ``query`` to English and return the wrapped retriever's docs.

        This is the ``BaseRetriever`` hook that ``invoke`` (and the legacy
        ``get_relevant_documents``) route to. No async override is defined,
        so ``ainvoke`` uses ``BaseRetriever``'s default async fallback instead
        of raising ``NotImplementedError``.

        Args:
            query: The user's query, in any language.
            run_manager: Callback manager supplied by ``BaseRetriever``; its
                child is forwarded so the inner search is traced under this run.

        Returns:
            Documents returned by the wrapped Wikipedia retriever.
        """
        translated_query = self.translator(query)
        print(f"🔄 Translated Query: {translated_query}")
        return self.retriever.invoke(
            translated_query, config={"callbacks": run_manager.get_child()}
        )


# Module-level retriever consumed by the QA chain below: every query is
# translated to English by ``translate_to_english`` before Wikipedia is searched.
retriever = WikipediaTranslationRetriever(
    translator=translate_to_english,
    retriever=WikipediaRetriever(),
)


# Conversational retrieval chain: Gemini answers using documents from the
# translating Wikipedia retriever, formatted with ``chat_prompt``. ``memory`` is
# attached to the chain, and the answer is exposed under the "result" key.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=gemini_llm,
    retriever=retriever,
    combine_docs_chain_kwargs={"prompt": chat_prompt},
    memory=memory,
    output_key="result",
    return_source_documents=False,
)

def get_chat_response(user_input: str) -> str:
    """Answer ``user_input`` with the Wikipedia-backed conversational chain.

    ``qa_chain`` is built with ``memory`` attached, so the chain itself loads
    the chat history before answering and records this turn afterwards.
    Saving the turn again here would duplicate every exchange in the history.

    Args:
        user_input: The user's message, in any language.

    Returns:
        The generated answer (the chain's ``"result"`` output).
    """
    # "question" is ConversationalRetrievalChain's input key; "chat_history"
    # is presumably filled in by the attached memory (assumes its memory_key
    # is "chat_history" -- the previous single-string call relied on this too).
    response = qa_chain.invoke({"question": user_input})
    return response["result"]