# chatbot/core.py
from chatbot.llm import gemini_llm # Import Gemini LLM
from chatbot.memory import memory
from chatbot.prompts import chat_prompt
from langchain.retrievers import WikipediaRetriever
from langchain.chains import ConversationalRetrievalChain
from pydantic import Field
from typing import List, Callable
from langchain.schema import BaseRetriever, Document, HumanMessage, AIMessage


def translate_to_english(text: str) -> str:
    """Use Gemini LLM to extract an English Wikipedia search keyword from the query, using recent chat context."""
    recent_messages = memory.chat_memory.messages[-3:]  # Take the 3 most recent messages
    recent_context = "\n".join([msg.content for msg in recent_messages])
    prompt = f"""
You are an assistant for Wikipedia searches.
The query may be in any language.
Extract and return only the most relevant keyword (e.g. a person's name, city, or key term) in English/international form.
Consider the recent conversation context to disambiguate references.
For queries about what happened to you, interpret 'you' as Kumiko and, based on your knowledge of the Hibike! Euphonium plot, return one of the following keywords:
- Liz and the Blue Bird (information about Nozomi and Mizore)
- Sound! Euphonium: The Movie – Welcome to the Kitauji High School Concert Band (information about Kumiko and Reina)
- Sound! Euphonium: The Movie – May the Melody Reach You! (information about Asuka and Mamiko)
- Sound! Euphonium: The Movie – Our Promise: A Brand New Day (information about Kumiko's second year)
- List of Sound! Euphonium episodes (other information)
Recent Context:
{recent_context}
Query:
{text}
Return only the keyword—no explanations.
"""
    response = gemini_llm.invoke(prompt)  # Invoke Gemini to extract the keyword
    # Chat models return an AIMessage while plain LLMs return a string; normalize to a string
    return response.content if hasattr(response, "content") else response
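

# Descriptive note: this retriever first maps the user's query to an English
# Wikipedia keyword via translate_to_english, then delegates retrieval to the
# wrapped WikipediaRetriever.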
class WikipediaTranslationRetriever(BaseRetriever):
    retriever: WikipediaRetriever = Field(..., description="The underlying Wikipedia retriever")
    translator: Callable[[str], str] = Field(..., description="Function to translate queries to English")

    def get_relevant_documents(self, query: str) -> List[Document]:
        translated_query = self.translator(query)
        print(f"🔄 Translated Query: {translated_query}")
        return self.retriever.get_relevant_documents(translated_query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        # For simplicity, the async version is not implemented.
        raise NotImplementedError("Async retrieval is not implemented.")
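

# Descriptive note: ConversationalRetrievalChain formats chat history before
# condensing the follow-up question; this helper accepts either an already
# summarized string or a list of messages.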
def custom_get_chat_history(chat_history):
    # If chat_history is a string (summary), return it as-is
    if isinstance(chat_history, str):
        return chat_history
    # If it is a list of messages, join them into a single string
    elif isinstance(chat_history, list):
        return "\n".join([msg.content for msg in chat_history])
    else:
        raise ValueError("Unsupported chat history format.")


# Create the retriever instance to be used in the qa_chain:
retriever = WikipediaTranslationRetriever(
    retriever=WikipediaRetriever(),
    translator=translate_to_english,
)

# ✅ Use ConversationalRetrievalChain
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=gemini_llm,
    retriever=retriever,
    memory=memory,
    return_source_documents=False,
    combine_docs_chain_kwargs={"prompt": chat_prompt},
    output_key="result",
)
# Override how the chain serializes chat history when condensing the follow-up question
qa_chain.get_chat_history = custom_get_chat_history


def get_chat_response(user_input: str) -> str:
    """Process user input and return a chat response using Wikipedia retrieval."""
    response = qa_chain(user_input)  # Pass the query to the retrieval-based QA chain
    # Conversation context is saved automatically by the chain's memory;
    # the manual calls below are kept for reference only.
    # memory.chat_memory.add_message(HumanMessage(content=user_input))
    # memory.chat_memory.add_message(AIMessage(content=response["result"]))
    return response["result"]
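

# A minimal usage sketch (not part of the module's original flow): assuming the
# Gemini credentials and the chatbot.memory / chatbot.prompts modules are
# configured, the chain can be exercised from the command line like this.
if __name__ == "__main__":
    try:
        while True:
            user_input = input("You: ")
            if not user_input.strip():
                break
            print(f"Kumiko: {get_chat_response(user_input)}")
    except (KeyboardInterrupt, EOFError):
        pass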