# NOTE(review): the original lines here ("Spaces:" / "Sleeping" / "Sleeping")
# are extraction residue, not Python; kept as a comment so the module parses.
from chatbot.llm import gemini_llm # Import Gemini LLM | |
from chatbot.memory import memory | |
from chatbot.prompts import chat_prompt | |
from langchain.retrievers import WikipediaRetriever | |
from langchain.chains import ConversationalRetrievalChain | |
from pydantic import Field | |
from typing import List, Callable | |
from langchain.schema import BaseRetriever, Document | |
from langchain.schema import HumanMessage, AIMessage | |
from datetime import datetime | |
import pytz | |
def get_current_time() -> str:
    """Return the current time in Vietnam (UTC+7) as a human-readable string.

    Format: "%A, %Y-%m-%d %H:%M:%S", e.g. "Saturday, 2025-03-29 14:30:00".
    """
    # stdlib zoneinfo (3.9+) replaces the legacy pytz dependency; same zone data.
    from zoneinfo import ZoneInfo

    now = datetime.now(ZoneInfo("Asia/Ho_Chi_Minh"))  # UTC+7
    return now.strftime("%A, %Y-%m-%d %H:%M:%S")
def translate_to_english(text: str) -> str:
    """Use Gemini LLM to reduce a (possibly non-English) query to one English
    Wikipedia search keyword, using recent chat context for disambiguation.

    Args:
        text: The raw user query, in any language.

    Returns:
        The keyword string produced by the LLM (no explanations).
    """
    # Take the 3 most recent messages as disambiguation context.
    recent_messages = memory.chat_memory.messages[-3:]
    recent_context = "\n".join(msg.content for msg in recent_messages)
    prompt = f"""
    You are an assistant for Wikipedia searches. The query may be in any language.
    There are two types of query: general informational queries and queries asking about you. Determine the type and respond as follows:
    **If the query asks for general information**:
    Extract and return only the most relevant keyword (e.g. a person's name, city, or key term) in English/international form.
    If the query includes a specific time reference (date, month, or year), ensure that the translated keyword includes this time reference.
    Current time: {get_current_time()}.
    Identify the referenced time point (day, month, year, decade or century) mentioned in the query, calculate its exact value, and include it in the keyword if necessary.
    If the query refers to time but does not require absolute precision (e.g., using words like "around" or "approximately"), return a broader time range and include it in the keyword if necessary.
    For example:
    If the current year is 2025, then: "around 30 years ago" is 1990s, "about 200 years ago" is 19th century.
    **If query asks about what happened to you**:
    Interpret 'you' as Kumiko, based on your knowledge of the Hibike! Euphonium plot, return one of the following keywords:
    - Liz and the Blue Bird (Regarding information about Nozomi and Mizore)
    - Sound! Euphonium: The Movie – Welcome to the Kitauji High School Concert Band (Information about Kumiko and Reina)
    - Sound! Euphonium: The Movie – May the Melody Reach You! (Information about Asuka and Mamiko)
    - Sound! Euphonium: The Movie – Our Promise: A Brand New Day (Information about Kumiko's second year)
    - Sound! Euphonium (Information about Kumiko's other acquaintances)
    - List of Sound! Euphonium episodes (other information)
    Consider the recent conversation context to disambiguate references.
    Recent Context:
    {recent_context}
    Query:
    {text}
    Return only the keyword — no explanations.
    """
    response = gemini_llm.invoke(prompt)
    # Chat-model wrappers return an AIMessage, plain LLM wrappers return str.
    # Normalize to str so the retriever always receives a plain query string.
    return response.content if hasattr(response, "content") else response
class WikipediaTranslationRetriever(BaseRetriever):
    """Retriever that normalizes the query to an English keyword (via
    ``translator``) before delegating to the wrapped ``WikipediaRetriever``."""

    retriever: WikipediaRetriever = Field(..., description="The underlying Wikipedia retriever")
    translator: Callable[[str], str] = Field(..., description="Function to translate queries to English")

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Translate *query*, then fetch matching Wikipedia documents."""
        translated_query = self.translator(query)
        print(f"🔄 Translated Query: {translated_query}")
        return self.retriever.get_relevant_documents(translated_query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async variant: translation runs synchronously, retrieval delegates
        to the underlying retriever's async API (previously raised
        NotImplementedError)."""
        translated_query = self.translator(query)
        print(f"🔄 Translated Query: {translated_query}")
        return await self.retriever.aget_relevant_documents(translated_query)
def custom_get_chat_history(chat_history):
    """Normalize chat history into a single string.

    Accepts either a pre-summarized string (returned unchanged) or a list of
    message objects, whose ``.content`` fields are joined with newlines.

    Raises:
        ValueError: If *chat_history* is neither a string nor a list.
    """
    # Already-summarized history: pass straight through.
    if isinstance(chat_history, str):
        return chat_history
    # A list of message objects: flatten their text contents.
    if isinstance(chat_history, list):
        contents = [message.content for message in chat_history]
        return "\n".join(contents)
    raise ValueError("Unsupported chat history format.")
# Retriever instance used by the QA chain below: translates queries to an
# English keyword before hitting Wikipedia.
retriever = WikipediaTranslationRetriever(
    retriever=WikipediaRetriever(),
    translator=translate_to_english
)

# Conversational chain: condenses follow-up questions, retrieves from
# Wikipedia, and answers with the Gemini LLM using the project chat prompt.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=gemini_llm,
    retriever=retriever,
    memory=memory,
    return_source_documents=False,
    combine_docs_chain_kwargs={"prompt": chat_prompt},
    output_key="result",
    # Pass the history formatter through the factory: the chain is a pydantic
    # model, so assigning qa_chain.get_chat_history after construction can be
    # rejected or silently ignored.
    get_chat_history=custom_get_chat_history,
)
def get_chat_response(user_input: str) -> str:
    """Answer *user_input* via the Wikipedia-backed conversational chain.

    Args:
        user_input: The user's message.

    Returns:
        The chain's answer text (the chain's "result" output key).
    """
    payload = {
        "question": user_input,
        # The prompt template expects the current wall-clock time.
        "current_time": get_current_time(),
    }
    chain_output = qa_chain(payload)
    return chain_output["result"]