Spaces:
Sleeping
Sleeping
Update memory
Browse files- chatbot/core.py +38 -4
chatbot/core.py
CHANGED
@@ -6,20 +6,41 @@ from langchain.chains import ConversationalRetrievalChain
|
|
6 |
from pydantic import Field
|
7 |
from typing import List, Callable
|
8 |
from langchain.schema import BaseRetriever, Document
|
|
|
9 |
|
10 |
def translate_to_english(text: str) -> str:
|
11 |
-
"""Use Gemini LLM to translate text to English."""
|
|
|
|
|
|
|
12 |
prompt = f"""
|
13 |
You are an assistant for Wikipedia searches.
|
|
|
14 |
The query may be in any language.
|
|
|
15 |
Extract and return only the most relevant keyword (e.g. a person's name, city, or key term) in English/international form.
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
Query:
|
19 |
{text}
|
|
|
|
|
20 |
"""
|
|
|
21 |
response = gemini_llm.invoke(prompt) # Invoke Gemini for translation
|
22 |
-
return response
|
|
|
23 |
|
24 |
class WikipediaTranslationRetriever(BaseRetriever):
|
25 |
retriever: WikipediaRetriever = Field(..., description="The underlying Wikipedia retriever")
|
@@ -34,6 +55,16 @@ class WikipediaTranslationRetriever(BaseRetriever):
|
|
34 |
# For simplicity, we are not implementing the async version.
|
35 |
raise NotImplementedError("Async retrieval is not implemented.")
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
# Create the retriever instance to be used in your qa_chain:
|
39 |
retriever = WikipediaTranslationRetriever(
|
@@ -52,11 +83,14 @@ qa_chain = ConversationalRetrievalChain.from_llm(
|
|
52 |
output_key="result"
|
53 |
)
|
54 |
|
|
|
|
|
55 |
def get_chat_response(user_input: str) -> str:
|
56 |
"""Process user input and return chat response using Wikipedia retrieval."""
|
57 |
response = qa_chain(user_input) # Pass query to retrieval-based QA chain
|
58 |
|
59 |
# Save conversation context
|
60 |
-
memory.
|
|
|
61 |
|
62 |
return response["result"]
|
|
|
6 |
from pydantic import Field
|
7 |
from typing import List, Callable
|
8 |
from langchain.schema import BaseRetriever, Document
|
9 |
+
from langchain.schema import HumanMessage, AIMessage
|
10 |
|
11 |
def translate_to_english(text: str) -> str:
|
12 |
+
"""Use Gemini LLM to translate text to English with recent chat context."""
|
13 |
+
recent_messages = memory.chat_memory.messages[-3:] # Lấy 3 tin nhắn gần nhất
|
14 |
+
recent_context = "\n".join([msg.content for msg in recent_messages])
|
15 |
+
|
16 |
prompt = f"""
|
17 |
You are an assistant for Wikipedia searches.
|
18 |
+
|
19 |
The query may be in any language.
|
20 |
+
|
21 |
Extract and return only the most relevant keyword (e.g. a person's name, city, or key term) in English/international form.
|
22 |
+
|
23 |
+
Consider the recent conversation context to disambiguate references.
|
24 |
+
|
25 |
+
For query about what happened to you, interpret 'you' as Kumiko, based on your knowledge of the Hibike! Euphonium plot, return one of the following keywords
|
26 |
+
- Liz and the Blue Bird (Regarding information about Nozomi and Mizore)
|
27 |
+
- Sound! Euphonium: The Movie – Welcome to the Kitauji High School Concert Band (Information about Kumiko and Reina)
|
28 |
+
- Sound! Euphonium: The Movie – May the Melody Reach You! (Information about Asuka and Mamiko)
|
29 |
+
- Sound! Euphonium: The Movie – Our Promise: A Brand New Day (Information about Kumiko's second year)
|
30 |
+
- List of Sound! Euphonium episodes (other informations)
|
31 |
+
|
32 |
+
Recent Context:
|
33 |
+
{recent_context}
|
34 |
|
35 |
Query:
|
36 |
{text}
|
37 |
+
|
38 |
+
Return only the keyword—no explanations.
|
39 |
"""
|
40 |
+
|
41 |
response = gemini_llm.invoke(prompt) # Invoke Gemini for translation
|
42 |
+
return response
|
43 |
+
|
44 |
|
45 |
class WikipediaTranslationRetriever(BaseRetriever):
|
46 |
retriever: WikipediaRetriever = Field(..., description="The underlying Wikipedia retriever")
|
|
|
55 |
# For simplicity, we are not implementing the async version.
|
56 |
raise NotImplementedError("Async retrieval is not implemented.")
|
57 |
|
58 |
+
def custom_get_chat_history(chat_history):
|
59 |
+
# Nếu chat_history là chuỗi (summary) thì trả về chuỗi đó
|
60 |
+
if isinstance(chat_history, str):
|
61 |
+
return chat_history
|
62 |
+
# Nếu là danh sách các message, chuyển thành chuỗi
|
63 |
+
elif isinstance(chat_history, list):
|
64 |
+
return "\n".join([msg.content for msg in chat_history])
|
65 |
+
else:
|
66 |
+
raise ValueError("Unsupported chat history format.")
|
67 |
+
|
68 |
|
69 |
# Create the retriever instance to be used in your qa_chain:
|
70 |
retriever = WikipediaTranslationRetriever(
|
|
|
83 |
output_key="result"
|
84 |
)
|
85 |
|
86 |
+
qa_chain.get_chat_history = custom_get_chat_history
|
87 |
+
|
88 |
def get_chat_response(user_input: str) -> str:
|
89 |
"""Process user input and return chat response using Wikipedia retrieval."""
|
90 |
response = qa_chain(user_input) # Pass query to retrieval-based QA chain
|
91 |
|
92 |
# Save conversation context
|
93 |
+
# memory.chat_memory.add_message(HumanMessage(content=user_input))
|
94 |
+
# memory.chat_memory.add_message(AIMessage(content=response["result"]))
|
95 |
|
96 |
return response["result"]
|