Spaces:
Sleeping
Sleeping
Update chatbot/core.py
Browse files- chatbot/core.py +16 -15
chatbot/core.py
CHANGED
@@ -3,7 +3,9 @@ from chatbot.memory import memory
|
|
3 |
from chatbot.prompts import chat_prompt
|
4 |
from langchain.retrievers import WikipediaRetriever
|
5 |
from langchain.chains import ConversationalRetrievalChain
|
6 |
-
from
|
|
|
|
|
7 |
|
8 |
def translate_to_english(text: str) -> str:
|
9 |
"""Use Gemini LLM to translate text to English."""
|
@@ -12,26 +14,25 @@ def translate_to_english(text: str) -> str:
|
|
12 |
return response # Assuming `gemini_llm.invoke()` returns plain text
|
13 |
|
14 |
class WikipediaTranslationRetriever(BaseRetriever):
|
15 |
-
|
16 |
-
|
17 |
-
self.retriever = retriever
|
18 |
-
self.translator = translator
|
19 |
|
20 |
-
def get_relevant_documents(self, query):
|
21 |
-
translated_query = self.translator(query)
|
22 |
print(f"🔄 Translated Query: {translated_query}")
|
23 |
return self.retriever.get_relevant_documents(translated_query)
|
|
|
|
|
|
|
|
|
24 |
|
25 |
-
async def aget_relevant_documents(self, query):
|
26 |
-
# If your environment doesn't need async support, you can simply raise an error.
|
27 |
-
raise NotImplementedError("Asynchronous retrieval is not implemented.")
|
28 |
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
# ✅ Use WikipediaRetriever
|
31 |
-
wiki_retriever = WikipediaRetriever()
|
32 |
-
|
33 |
-
# ✅ Wrap with translation
|
34 |
-
retriever = WikipediaTranslationRetriever(wiki_retriever, translate_to_english)
|
35 |
|
36 |
# ✅ Use ConversationalRetrievalChain
|
37 |
qa_chain = ConversationalRetrievalChain.from_llm(
|
|
|
3 |
from chatbot.prompts import chat_prompt
|
4 |
from langchain.retrievers import WikipediaRetriever
|
5 |
from langchain.chains import ConversationalRetrievalChain
|
6 |
+
from pydantic import Field
|
7 |
+
from typing import List, Callable
|
8 |
+
from langchain.schema import BaseRetriever, Document
|
9 |
|
10 |
def translate_to_english(text: str) -> str:
|
11 |
"""Use Gemini LLM to translate text to English."""
|
|
|
14 |
return response # Assuming `gemini_llm.invoke()` returns plain text
|
15 |
|
16 |
class WikipediaTranslationRetriever(BaseRetriever):
|
17 |
+
retriever: WikipediaRetriever = Field(..., description="The underlying Wikipedia retriever")
|
18 |
+
translator: Callable[[str], str] = Field(..., description="Function to translate queries to English")
|
|
|
|
|
19 |
|
20 |
+
def get_relevant_documents(self, query: str) -> List[Document]:
|
21 |
+
translated_query = self.translator(query)
|
22 |
print(f"🔄 Translated Query: {translated_query}")
|
23 |
return self.retriever.get_relevant_documents(translated_query)
|
24 |
+
|
25 |
+
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
26 |
+
# For simplicity, we are not implementing the async version.
|
27 |
+
raise NotImplementedError("Async retrieval is not implemented.")
|
28 |
|
|
|
|
|
|
|
29 |
|
30 |
+
# Create the retriever instance to be used in your qa_chain:
|
31 |
+
retriever = WikipediaTranslationRetriever(
|
32 |
+
retriever=WikipediaRetriever(),
|
33 |
+
translator=translate_to_english
|
34 |
+
)
|
35 |
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
# ✅ Use ConversationalRetrievalChain
|
38 |
qa_chain = ConversationalRetrievalChain.from_llm(
|