anhkhoiphan committed on
Commit
bbd820c
·
verified ·
1 Parent(s): 98e4789

Update memory

Browse files
Files changed (1) hide show
  1. chatbot/core.py +38 -4
chatbot/core.py CHANGED
@@ -6,20 +6,41 @@ from langchain.chains import ConversationalRetrievalChain
6
  from pydantic import Field
7
  from typing import List, Callable
8
  from langchain.schema import BaseRetriever, Document
 
9
 
10
  def translate_to_english(text: str) -> str:
11
- """Use Gemini LLM to translate text to English."""
 
 
 
12
  prompt = f"""
13
  You are an assistant for Wikipedia searches.
 
14
  The query may be in any language.
 
15
  Extract and return only the most relevant keyword (e.g. a person's name, city, or key term) in English/international form.
16
- Return only the keyword—no explanations.
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  Query:
19
  {text}
 
 
20
  """
 
21
  response = gemini_llm.invoke(prompt) # Invoke Gemini for translation
22
- return response # Assuming `gemini_llm.invoke()` returns plain text
 
23
 
24
  class WikipediaTranslationRetriever(BaseRetriever):
25
  retriever: WikipediaRetriever = Field(..., description="The underlying Wikipedia retriever")
@@ -34,6 +55,16 @@ class WikipediaTranslationRetriever(BaseRetriever):
34
  # For simplicity, we are not implementing the async version.
35
  raise NotImplementedError("Async retrieval is not implemented.")
36
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # Create the retriever instance to be used in your qa_chain:
39
  retriever = WikipediaTranslationRetriever(
@@ -52,11 +83,14 @@ qa_chain = ConversationalRetrievalChain.from_llm(
52
  output_key="result"
53
  )
54
 
 
 
55
  def get_chat_response(user_input: str) -> str:
56
  """Process user input and return chat response using Wikipedia retrieval."""
57
  response = qa_chain(user_input) # Pass query to retrieval-based QA chain
58
 
59
  # Save conversation context
60
- memory.save_context({"input": user_input}, {"output": response["result"]})
 
61
 
62
  return response["result"]
 
6
  from pydantic import Field
7
  from typing import List, Callable
8
  from langchain.schema import BaseRetriever, Document
9
+ from langchain.schema import HumanMessage, AIMessage
10
 
11
def translate_to_english(text: str) -> str:
    """Extract an English/international Wikipedia search keyword from *text*.

    Prompts the Gemini LLM with the query plus the last few turns of
    conversation memory so that references (e.g. pronouns like "you") can be
    disambiguated before the keyword is handed to the Wikipedia retriever.

    Args:
        text: The user query, in any language.

    Returns:
        The extracted keyword as a plain, whitespace-stripped string.
    """
    # Take the 3 most recent messages as disambiguation context.
    recent_messages = memory.chat_memory.messages[-3:]
    recent_context = "\n".join(msg.content for msg in recent_messages)

    prompt = f"""
    You are an assistant for Wikipedia searches.

    The query may be in any language.

    Extract and return only the most relevant keyword (e.g. a person's name, city, or key term) in English/international form.

    Consider the recent conversation context to disambiguate references.

    For query about what happened to you, interpret 'you' as Kumiko, based on your knowledge of the Hibike! Euphonium plot, return one of the following keywords
    - Liz and the Blue Bird (Regarding information about Nozomi and Mizore)
    - Sound! Euphonium: The Movie – Welcome to the Kitauji High School Concert Band (Information about Kumiko and Reina)
    - Sound! Euphonium: The Movie – May the Melody Reach You! (Information about Asuka and Mamiko)
    - Sound! Euphonium: The Movie – Our Promise: A Brand New Day (Information about Kumiko's second year)
    - List of Sound! Euphonium episodes (other informations)

    Recent Context:
    {recent_context}

    Query:
    {text}

    Return only the keyword—no explanations.
    """
    response = gemini_llm.invoke(prompt)  # Invoke Gemini for translation
    # Fix: chat-model `invoke` returns a message object, not `str`. Unwrap
    # `.content` when present and strip whitespace so the declared `-> str`
    # contract actually holds for downstream Wikipedia lookups.
    keyword = getattr(response, "content", response)
    return str(keyword).strip()
44
 
45
  class WikipediaTranslationRetriever(BaseRetriever):
46
  retriever: WikipediaRetriever = Field(..., description="The underlying Wikipedia retriever")
 
55
  # For simplicity, we are not implementing the async version.
56
  raise NotImplementedError("Async retrieval is not implemented.")
57
 
58
def custom_get_chat_history(chat_history):
    """Normalise LangChain chat history into a single string.

    ConversationalRetrievalChain may supply either a pre-built summary
    string or a list of history items; both are reduced to plain text.

    Args:
        chat_history: Either a summary string, or a list whose items are
            message objects (anything exposing a ``.content`` attribute),
            plain strings, or classic ``(human, ai)`` tuples/lists.

    Returns:
        The history as one newline-joined string (a summary string is
        returned unchanged).

    Raises:
        ValueError: If ``chat_history`` is neither a string nor a list.
    """
    # A summary string is already in the right shape.
    if isinstance(chat_history, str):
        return chat_history
    if isinstance(chat_history, list):
        parts = []
        for msg in chat_history:
            if hasattr(msg, "content"):
                # BaseMessage-like object (HumanMessage/AIMessage/...).
                parts.append(msg.content)
            elif isinstance(msg, (tuple, list)):
                # Classic (human, ai) exchange pair — previously crashed
                # with AttributeError; flatten both turns instead.
                parts.extend(str(turn) for turn in msg)
            else:
                parts.append(str(msg))
        return "\n".join(parts)
    raise ValueError("Unsupported chat history format.")
68
 
69
  # Create the retriever instance to be used in your qa_chain:
70
  retriever = WikipediaTranslationRetriever(
 
83
  output_key="result"
84
  )
85
 
86
+ qa_chain.get_chat_history = custom_get_chat_history
87
+
88
def get_chat_response(user_input: str) -> str:
    """Answer *user_input* via the Wikipedia retrieval QA chain.

    Args:
        user_input: The raw user query.

    Returns:
        The chain's answer text (its ``"result"`` output key).

    NOTE(review): the explicit ``memory.save_context`` call was removed in
    this revision — presumably the chain's attached memory now records the
    exchange on its own; confirm against the chain configuration.
    """
    # Delegate retrieval and answer generation to the chain.
    chain_output = qa_chain(user_input)
    return chain_output["result"]