Spaces:
Sleeping
Sleeping
Update chatbot/core.py
Browse files- chatbot/core.py +30 -9
chatbot/core.py
CHANGED
@@ -1,15 +1,36 @@
|
|
|
|
|
|
|
|
1 |
from langchain.retrievers import WikipediaRetriever
|
2 |
-
from chatbot.llm import gemini_llm
|
3 |
-
from chatbot.memory import memory
|
4 |
-
from chatbot.prompts import chat_prompt
|
5 |
from langchain.chains import ConversationalRetrievalChain
|
6 |
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
qa_chain = ConversationalRetrievalChain.from_llm(
|
11 |
llm=gemini_llm,
|
12 |
-
retriever=retriever,
|
13 |
memory=memory,
|
14 |
return_source_documents=False,
|
15 |
combine_docs_chain_kwargs={"prompt": chat_prompt},
|
@@ -17,10 +38,10 @@ qa_chain = ConversationalRetrievalChain.from_llm(
|
|
17 |
)
|
18 |
|
19 |
def get_chat_response(user_input: str) -> str:
|
20 |
-
response
|
|
|
21 |
|
22 |
-
#
|
23 |
memory.save_context({"input": user_input}, {"output": response["result"]})
|
24 |
|
25 |
return response["result"]
|
26 |
-
|
|
|
1 |
+
from chatbot.llm import gemini_llm # Import Gemini LLM
|
2 |
+
from chatbot.memory import memory
|
3 |
+
from chatbot.prompts import chat_prompt
|
4 |
from langchain.retrievers import WikipediaRetriever
|
|
|
|
|
|
|
5 |
from langchain.chains import ConversationalRetrievalChain
|
6 |
|
7 |
+
def translate_to_english(text: str) -> str:
|
8 |
+
"""Use Gemini LLM to translate text to English."""
|
9 |
+
prompt = f"Translate the following text to English:\n\n{text}"
|
10 |
+
response = gemini_llm.invoke(prompt) # Invoke Gemini for translation
|
11 |
+
return response # Assuming `gemini_llm.invoke()` returns plain text
|
12 |
|
13 |
+
class WikipediaTranslationRetriever:
|
14 |
+
"""Custom Retriever that translates queries before searching Wikipedia."""
|
15 |
+
def __init__(self, retriever, translator):
|
16 |
+
self.retriever = retriever
|
17 |
+
self.translator = translator
|
18 |
+
|
19 |
+
def get_relevant_documents(self, query):
|
20 |
+
translated_query = self.translator(query) # Translate query to English
|
21 |
+
print(f"🔄 Translated Query: {translated_query}")
|
22 |
+
return self.retriever.get_relevant_documents(translated_query)
|
23 |
+
|
24 |
+
# ✅ Use WikipediaRetriever
|
25 |
+
wiki_retriever = WikipediaRetriever()
|
26 |
+
|
27 |
+
# ✅ Wrap with translation
|
28 |
+
retriever = WikipediaTranslationRetriever(wiki_retriever, translate_to_english)
|
29 |
+
|
30 |
+
# ✅ Use ConversationalRetrievalChain
|
31 |
qa_chain = ConversationalRetrievalChain.from_llm(
|
32 |
llm=gemini_llm,
|
33 |
+
retriever=retriever,
|
34 |
memory=memory,
|
35 |
return_source_documents=False,
|
36 |
combine_docs_chain_kwargs={"prompt": chat_prompt},
|
|
|
38 |
)
|
39 |
|
40 |
def get_chat_response(user_input: str) -> str:
|
41 |
+
"""Process user input and return chat response using Wikipedia retrieval."""
|
42 |
+
response = qa_chain(user_input) # Pass query to retrieval-based QA chain
|
43 |
|
44 |
+
# Save conversation context
|
45 |
memory.save_context({"input": user_input}, {"output": response["result"]})
|
46 |
|
47 |
return response["result"]
|
|