anhkhoiphan commited on
Commit
b1c1560
·
verified ·
1 Parent(s): a0c0e21

Update chatbot/core.py

Browse files
Files changed (1) hide show
  1. chatbot/core.py +16 -15
chatbot/core.py CHANGED
@@ -3,7 +3,9 @@ from chatbot.memory import memory
3
  from chatbot.prompts import chat_prompt
4
  from langchain.retrievers import WikipediaRetriever
5
  from langchain.chains import ConversationalRetrievalChain
6
- from langchain.schema import BaseRetriever
 
 
7
 
8
  def translate_to_english(text: str) -> str:
9
  """Use Gemini LLM to translate text to English."""
@@ -12,26 +14,25 @@ def translate_to_english(text: str) -> str:
12
  return response # Assuming `gemini_llm.invoke()` returns plain text
13
 
14
  class WikipediaTranslationRetriever(BaseRetriever):
15
- """Custom Retriever that translates queries before searching Wikipedia."""
16
- def __init__(self, retriever, translator):
17
- self.retriever = retriever
18
- self.translator = translator
19
 
20
- def get_relevant_documents(self, query):
21
- translated_query = self.translator(query) # Translate query to English
22
  print(f"🔄 Translated Query: {translated_query}")
23
  return self.retriever.get_relevant_documents(translated_query)
 
 
 
 
24
 
25
- async def aget_relevant_documents(self, query):
26
- # If your environment doesn't need async support, you can simply raise an error.
27
- raise NotImplementedError("Asynchronous retrieval is not implemented.")
28
 
 
 
 
 
 
29
 
30
- # ✅ Use WikipediaRetriever
31
- wiki_retriever = WikipediaRetriever()
32
-
33
- # ✅ Wrap with translation
34
- retriever = WikipediaTranslationRetriever(wiki_retriever, translate_to_english)
35
 
36
  # ✅ Use ConversationalRetrievalChain
37
  qa_chain = ConversationalRetrievalChain.from_llm(
 
3
  from chatbot.prompts import chat_prompt
4
  from langchain.retrievers import WikipediaRetriever
5
  from langchain.chains import ConversationalRetrievalChain
6
+ from pydantic import Field
7
+ from typing import List, Callable
8
+ from langchain.schema import BaseRetriever, Document
9
 
10
  def translate_to_english(text: str) -> str:
11
  """Use Gemini LLM to translate text to English."""
 
14
  return response # Assuming `gemini_llm.invoke()` returns plain text
15
 
16
  class WikipediaTranslationRetriever(BaseRetriever):
17
+ retriever: WikipediaRetriever = Field(..., description="The underlying Wikipedia retriever")
18
+ translator: Callable[[str], str] = Field(..., description="Function to translate queries to English")
 
 
19
 
20
+ def get_relevant_documents(self, query: str) -> List[Document]:
21
+ translated_query = self.translator(query)
22
  print(f"🔄 Translated Query: {translated_query}")
23
  return self.retriever.get_relevant_documents(translated_query)
24
+
25
+ async def aget_relevant_documents(self, query: str) -> List[Document]:
26
+ # For simplicity, we are not implementing the async version.
27
+ raise NotImplementedError("Async retrieval is not implemented.")
28
 
 
 
 
29
 
30
+ # Create the retriever instance to be used in your qa_chain:
31
+ retriever = WikipediaTranslationRetriever(
32
+ retriever=WikipediaRetriever(),
33
+ translator=translate_to_english
34
+ )
35
 
 
 
 
 
 
36
 
37
  # ✅ Use ConversationalRetrievalChain
38
  qa_chain = ConversationalRetrievalChain.from_llm(