Spaces:
Sleeping
Sleeping
Update chatbot/core.py
Browse files- chatbot/core.py +7 -1
chatbot/core.py
CHANGED
@@ -3,6 +3,7 @@ from chatbot.memory import memory
|
|
3 |
from chatbot.prompts import chat_prompt
|
4 |
from langchain.retrievers import WikipediaRetriever
|
5 |
from langchain.chains import ConversationalRetrievalChain
|
|
|
6 |
|
7 |
def translate_to_english(text: str) -> str:
|
8 |
"""Use Gemini LLM to translate text to English."""
|
@@ -10,7 +11,7 @@ def translate_to_english(text: str) -> str:
|
|
10 |
response = gemini_llm.invoke(prompt) # Invoke Gemini for translation
|
11 |
return response # Assuming `gemini_llm.invoke()` returns plain text
|
12 |
|
13 |
-
class WikipediaTranslationRetriever:
|
14 |
"""Custom Retriever that translates queries before searching Wikipedia."""
|
15 |
def __init__(self, retriever, translator):
|
16 |
self.retriever = retriever
|
@@ -21,6 +22,11 @@ class WikipediaTranslationRetriever:
|
|
21 |
print(f"🔄 Translated Query: {translated_query}")
|
22 |
return self.retriever.get_relevant_documents(translated_query)
|
23 |
|
|
|
|
|
|
|
|
|
|
|
24 |
# ✅ Use WikipediaRetriever
|
25 |
wiki_retriever = WikipediaRetriever()
|
26 |
|
|
|
3 |
from chatbot.prompts import chat_prompt
|
4 |
from langchain.retrievers import WikipediaRetriever
|
5 |
from langchain.chains import ConversationalRetrievalChain
|
6 |
+
from langchain.schema import BaseRetriever
|
7 |
|
8 |
def translate_to_english(text: str) -> str:
|
9 |
"""Use Gemini LLM to translate text to English."""
|
|
|
11 |
response = gemini_llm.invoke(prompt) # Invoke Gemini for translation
|
12 |
return response # Assuming `gemini_llm.invoke()` returns plain text
|
13 |
|
14 |
+
class WikipediaTranslationRetriever(BaseRetriever):
|
15 |
"""Custom Retriever that translates queries before searching Wikipedia."""
|
16 |
def __init__(self, retriever, translator):
|
17 |
self.retriever = retriever
|
|
|
22 |
print(f"🔄 Translated Query: {translated_query}")
|
23 |
return self.retriever.get_relevant_documents(translated_query)
|
24 |
|
25 |
+
async def aget_relevant_documents(self, query):
|
26 |
+
# If your environment doesn't need async support, you can simply raise an error.
|
27 |
+
raise NotImplementedError("Asynchronous retrieval is not implemented.")
|
28 |
+
|
29 |
+
|
30 |
# ✅ Use WikipediaRetriever
|
31 |
wiki_retriever = WikipediaRetriever()
|
32 |
|