# Kumiko_v1 / chatbot/core.py
# Last commit: b1c1560 "Update chatbot/core.py" by anhkhoiphan (verified)
# File size: 2.2 kB
from chatbot.llm import gemini_llm # Import Gemini LLM
from chatbot.memory import memory
from chatbot.prompts import chat_prompt
from langchain.retrievers import WikipediaRetriever
from langchain.chains import ConversationalRetrievalChain
from pydantic import Field
from typing import List, Callable
from langchain.schema import BaseRetriever, Document
def translate_to_english(text: str, llm=None) -> str:
    """Translate ``text`` to English with an LLM and return plain text.

    The result is used as a Wikipedia search query. For that reason the prompt
    asks the model for the translation only, and surrounding whitespace is
    stripped from the reply.

    Args:
        text: The text to translate (any language).
        llm: Optional object exposing ``invoke(prompt)``. Defaults to the
            project's ``gemini_llm``. Lets callers swap models.

    Returns:
        The translated text as a ``str``. Empty or whitespace-only input is
        returned unchanged, without calling the model.
    """
    if not text or not text.strip():
        # Nothing to translate; avoid an unnecessary LLM call.
        return text
    model = gemini_llm if llm is None else llm
    prompt = (
        "Translate the following text to English. "
        "Reply with the translation only, without explanations or quotes:\n\n"
        f"{text}"
    )
    response = model.invoke(prompt)
    # A LangChain chat model returns a message whose text is in `.content`.
    # A plain LLM returns a str. Normalise both to a str.
    content = getattr(response, "content", response)
    if isinstance(content, list):
        # Multi-part message content: keep only the textual parts.
        content = "".join(
            part if isinstance(part, str) else str(part.get("text", ""))
            for part in content
            if isinstance(part, (str, dict))
        )
    return str(content).strip()
class WikipediaTranslationRetriever(BaseRetriever):
    """Retriever that translates the query to English before searching Wikipedia.

    It wraps a ``WikipediaRetriever``: ``translator`` is applied to the
    incoming query, and the translated string is passed to the inner retriever.
    """

    # Underlying retriever that performs the actual Wikipedia search.
    retriever: WikipediaRetriever = Field(..., description="The underlying Wikipedia retriever")
    # Callable mapping a query (any language) to an English query string.
    translator: Callable[[str], str] = Field(..., description="Function to translate queries to English")

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        """Translate ``query`` and return the wrapped retriever's documents.

        This is the hook LangChain's ``BaseRetriever`` expects subclasses to
        implement. The public ``invoke`` / ``get_relevant_documents`` entry
        points dispatch to it. No async override is defined, so the base
        class's default async hook applies (in langchain-core it runs this
        method in an executor).

        Args:
            query: The user's search query, in any language.
            run_manager: Callback manager supplied by LangChain. May be None
                when the method is called directly.

        Returns:
            The documents returned by the wrapped ``WikipediaRetriever``.
        """
        translated_query = self.translator(query)
        print(f"🔄 Translated Query: {translated_query}")
        # Propagate callbacks so the inner retrieval is traced as a child run.
        config = {"callbacks": run_manager.get_child()} if run_manager is not None else None
        return self.retriever.invoke(translated_query, config=config)
# Wikipedia retriever that first translates the user's question to English
# (via translate_to_english) and then searches with a WikipediaRetriever.
retriever = WikipediaTranslationRetriever(
    translator=translate_to_english,
    retriever=WikipediaRetriever(),
)

# Conversational retrieval chain over that retriever.
# - gemini_llm answers, and chat_prompt is used to combine retrieved documents.
# - `memory` supplies chat history, and the chain saves each turn into it.
#   NOTE(review): this presumably requires memory's keys to match the chain
#   (e.g. memory_key "chat_history"); verify against chatbot.memory.
# - The answer is exposed under the "result" key; sources are not returned.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=gemini_llm,
    retriever=retriever,
    memory=memory,
    output_key="result",
    combine_docs_chain_kwargs={"prompt": chat_prompt},
    return_source_documents=False,
)
def get_chat_response(user_input: str) -> str:
    """Answer ``user_input`` using the Wikipedia-backed conversational chain.

    Args:
        user_input: The user's message for this turn.

    Returns:
        The chain's answer text (its ``"result"`` output).

    Note:
        ``qa_chain`` is constructed with ``memory=memory``. LangChain's
        ``Chain`` therefore records this exchange in ``memory`` itself after
        each call. Saving the context again here would store every turn twice
        and inflate the chat history fed into later prompts.
    """
    # `invoke` is the supported replacement for the deprecated `Chain.__call__`.
    # A bare string is mapped to the chain's single non-memory input key,
    # exactly as `__call__` did.
    response = qa_chain.invoke(user_input)
    return response["result"]