from typing import Callable, List

from langchain.chains import ConversationalRetrievalChain
from langchain.retrievers import WikipediaRetriever
from langchain.schema import BaseRetriever, Document
from pydantic import Field

from chatbot.llm import gemini_llm  # Gemini LLM wrapper
from chatbot.memory import memory
from chatbot.prompts import chat_prompt
def translate_to_english(text: str) -> str:
    """Use the Gemini LLM to translate text to English."""
    prompt = f"Translate the following text to English:\n\n{text}"
    response = gemini_llm.invoke(prompt)
    # Chat models return a message object with a `.content` attribute,
    # while plain LLMs return a string; handle both.
    return response.content if hasattr(response, "content") else response
class WikipediaTranslationRetriever(BaseRetriever):
    """Retriever that translates queries to English before searching Wikipedia."""

    retriever: WikipediaRetriever = Field(..., description="The underlying Wikipedia retriever")
    translator: Callable[[str], str] = Field(..., description="Function to translate queries to English")

    def get_relevant_documents(self, query: str) -> List[Document]:
        translated_query = self.translator(query)
        print(f"🔄 Translated query: {translated_query}")
        return self.retriever.get_relevant_documents(translated_query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        # The async path is intentionally unimplemented; the chain below
        # only uses synchronous retrieval.
        raise NotImplementedError("Async retrieval is not implemented.")
# Create the retriever instance used by the QA chain below.
retriever = WikipediaTranslationRetriever(
    retriever=WikipediaRetriever(),
    translator=translate_to_english,
)
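
# A quick sanity check of the translated retrieval step, as a sketch: it
# assumes network access to Wikipedia and uses a hypothetical Spanish query.
# Uncomment to inspect what the retriever returns before wiring up the chain.
# docs = retriever.get_relevant_documents("¿Qué es la fotosíntesis?")
# print(docs[0].page_content[:200])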
# Conversational RAG chain: condenses follow-up questions using chat history,
# retrieves from Wikipedia, and answers with Gemini.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=gemini_llm,
    retriever=retriever,
    memory=memory,
    return_source_documents=False,
    combine_docs_chain_kwargs={"prompt": chat_prompt},
    output_key="result",
)
def get_chat_response(user_input: str) -> str:
    """Process user input and return a chat response using Wikipedia retrieval."""
    # The chain expects a "question" key. Because `memory` is attached to the
    # chain, the exchange is saved automatically; calling memory.save_context()
    # here as well would store each turn twice.
    response = qa_chain({"question": user_input})
    return response["result"]
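
# Minimal smoke test, assuming `gemini_llm`, `memory`, and `chat_prompt` are
# configured in the chatbot package and Wikipedia is reachable. Run the module
# directly to try one non-English query end to end; the query is hypothetical.
if __name__ == "__main__":
    sample_query = "¿Quién fue Marie Curie?"
    print(get_chat_response(sample_query))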