import hashlib
import logging
import os
import random
import threading
import time
from datetime import datetime

import google.generativeai as genai
from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain.schema import AIMessage, HumanMessage
from langchain_community.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()

# Configure API keys from environment variables
google_api_key = os.getenv("GOOGLE_API_KEY")
pinecone_api_key = os.getenv("PINECONE_API_KEY")
if not google_api_key or not pinecone_api_key:
    raise ValueError("Missing required API keys in environment variables")

os.environ["GOOGLE_API_KEY"] = google_api_key
os.environ["PINECONE_API_KEY"] = pinecone_api_key
genai.configure(api_key=google_api_key)
# Initialize the chat model
model = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash-8b-latest",
    temperature=0.8,
)

# Initialize the embedding model
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

# Per-user conversation history, stored as plain strings
user_histories = {}
history_lock = threading.Lock()

# Cache for responses
response_cache = {}
cache_lock = threading.Lock()

# Maximum cache size and entry time-to-live
MAX_CACHE_SIZE = 100
CACHE_TTL = 1800  # 30 minutes, in seconds
# Prompt template that combines retrieved context with conversation history
prompt = PromptTemplate(
    template="""Goal:
You are a professional tour guide assistant that helps users find information about places in Da Nang, Vietnam.
You can provide details on restaurants, cafes, hotels, attractions, and other local venues. You chat with users who are tourists visiting Da Nang.

Return Format:
Respond in friendly, natural, and concise English, like a real tour guide.

Warning:
Support users like a real tour guide, not a bot. Treat the information in the Context as your own knowledge.
Your knowledge is provided in the Context. All of the information in the Context is about Da Nang, Vietnam.
Only consider the current time the user mentions when they ask about Solana events.
If you do not have enough information to answer the user's question, reply with "I don't know. I don't have information about that".

Context:
{context}

Conversation History:
{chat_history}

User question:
{question}

Your answer:
""",
    input_variables=["context", "question", "chat_history"],
)


def get_history(user_id):
    """Get the conversation history for a specific user."""
    with history_lock:
        return user_histories.get(user_id, "")


def update_history(user_id, new_entry):
    """Update the conversation history for a user.

    new_entry should be a string containing the new conversation turn, e.g.:
    "User: {question}\nBot: {answer}\n"
    """
    with history_lock:
        current_history = user_histories.get(user_id, "")
        # Keep only the ~10 most recent interactions by retaining the last 20
        # lines (assuming 2 lines per interaction: one user, one bot)
        history_lines = current_history.split('\n')
        if len(history_lines) > 20:
            history_lines = history_lines[-20:]
            current_history = '\n'.join(history_lines)
        updated_history = current_history + new_entry + "\n"
        user_histories[user_id] = updated_history
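# Usage sketch (the values below are hypothetical): each call appends one
# "User/Bot" turn, and older turns are trimmed automatically.
#
#   update_history("user_42", "User: Hello\nBot: Hi! Welcome to Da Nang.")
#   get_history("user_42")
#   # -> "User: Hello\nBot: Hi! Welcome to Da Nang.\n"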


def string_to_message_history(history_str):
    """Convert string-based history to LangChain message history format."""
    if not history_str.strip():
        return []
    messages = []
    lines = history_str.strip().split('\n')
    i = 0
    while i < len(lines):
        line = lines[i].strip()
        if line.startswith("User:"):
            user_message = line[5:].strip()  # User message without the "User:" prefix
            messages.append(HumanMessage(content=user_message))
            # Look for a Bot response (should be the next line)
            if i + 1 < len(lines) and lines[i + 1].strip().startswith("Bot:"):
                bot_response = lines[i + 1].strip()[4:].strip()  # Bot response without the "Bot:" prefix
                messages.append(AIMessage(content=bot_response))
                i += 2  # Skip the bot line too
            else:
                i += 1
        else:
            i += 1  # Skip lines in any unexpected format
    return messages
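# A minimal usage sketch for string_to_message_history (the sample strings
# below are hypothetical, not taken from real traffic):
#
#   sample = "User: Any good cafes near My Khe Beach?\nBot: Yes, try the ones along Vo Nguyen Giap."
#   msgs = string_to_message_history(sample)
#   # -> [HumanMessage(content="Any good cafes near My Khe Beach?"),
#   #     AIMessage(content="Yes, try the ones along Vo Nguyen Giap.")]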


# Singleton pattern so the retriever is initialized only once
_retriever_instance = None
_retriever_lock = threading.Lock()


def get_chain():
    """Get the retrieval chain with the Pinecone vector store (singleton pattern)."""
    global _retriever_instance
    # Return immediately if an instance already exists
    if _retriever_instance is not None:
        return _retriever_instance
    # Thread-safe initialization
    with _retriever_lock:
        # Re-check in case another thread initialized it first
        if _retriever_instance is not None:
            return _retriever_instance
        try:
            start_time = time.time()
            pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
            # Get the vector store from the existing index
            vectorstore = PineconeVectorStore.from_existing_index(
                index_name="testbot768",
                embedding=embeddings,
                text_key="text",
            )
            _retriever_instance = vectorstore.as_retriever(search_kwargs={"k": 3})
            logger.info(f"Pinecone retriever initialized in {time.time() - start_time:.2f} seconds")
            return _retriever_instance
        except Exception as e:
            logger.error(f"Error getting vector store from Pinecone: {e}")
            # Fall back to a local vector store, or return None
            try:
                # Try to load a local FAISS index if one exists (newer
                # langchain_community versions may also require
                # allow_dangerous_deserialization=True here)
                start_time = time.time()
                vectorstore = FAISS.load_local("faiss_index", embeddings)
                _retriever_instance = vectorstore.as_retriever(search_kwargs={"k": 3})
                logger.info(f"FAISS retriever initialized in {time.time() - start_time:.2f} seconds")
                return _retriever_instance
            except Exception as faiss_error:
                logger.error(f"Error getting FAISS vector store: {faiss_error}")
                return None


def clean_cache():
    """Remove expired cache entries."""
    with cache_lock:
        current_time = time.time()
        expired_keys = [k for k, v in response_cache.items() if current_time - v['timestamp'] > CACHE_TTL]
        for key in expired_keys:
            del response_cache[key]
        # If the cache is still too large, evict the oldest entries
        if len(response_cache) > MAX_CACHE_SIZE:
            # Sort by timestamp and keep only the MAX_CACHE_SIZE newest entries
            sorted_items = sorted(response_cache.items(), key=lambda x: x[1]['timestamp'])
            items_to_remove = sorted_items[:len(sorted_items) - MAX_CACHE_SIZE]
            for key, _ in items_to_remove:
                del response_cache[key]


def generate_cache_key(request, user_id):
    """Generate a unique cache key from the request and user_id."""
    # Build a combined string to hash
    combined = f"{request.strip().lower()}:{user_id}"
    # MD5 is used here only as a cache key, not for anything security-sensitive
    return hashlib.md5(combined.encode()).hexdigest()
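# For example (hypothetical values), the key is stable across calls and
# normalizes case and surrounding whitespace:
#
#   generate_cache_key("Best banh mi in Da Nang?", "user_42")
#   generate_cache_key("  best banh mi in da nang?  ", "user_42")
#   # -> the same 32-character hex digest for both calls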


def chat(request, user_id="default_user"):
    """Process a chat request from a specific user."""
    start_time = time.time()
    # Periodically remove expired cache entries
    if random.random() < 0.1:  # 10% chance on each call
        clean_cache()
    # Build the cache key
    cache_key = generate_cache_key(request, user_id)
    # Check the cache
    with cache_lock:
        if cache_key in response_cache:
            cache_data = response_cache[cache_key]
            # Check the time-to-live
            if time.time() - cache_data['timestamp'] <= CACHE_TTL:
                logger.info(f"Cache hit for user {user_id}, request: '{request[:30]}...'")
                # Refresh the timestamp to reset the TTL
                cache_data['timestamp'] = time.time()
                # Still record the exchange in the conversation history
                new_entry = f"User: {request}\nBot: {cache_data['response']}"
                update_history(user_id, new_entry)
                return cache_data['response']
    try:
        retriever = get_chain()
        if not retriever:
            return "Error: Could not initialize retriever"
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        retrieved_docs = retriever.get_relevant_documents(request)
        context = "\n".join([doc.page_content for doc in retrieved_docs])
        # Optionally inject the current time into the context:
        # context = context + "\n(Current time: " + current_time + ")"
        response = model.invoke(
            prompt.format(
                context=context,
                question=request,
                chat_history=get_history(user_id),
            )
        )
        answer = str(response.content)
        new_entry = f"User: {request}\nBot: {answer}"
        update_history(user_id, new_entry)
        # Store the answer in the cache
        with cache_lock:
            response_cache[cache_key] = {
                'response': answer,
                'timestamp': time.time(),
            }
        logger.info(f"Total processing time: {time.time() - start_time:.2f} seconds")
        return answer
    except Exception as e:
        logger.error(f"Error in chat: {e}")
        return "I don't know how to answer that right now. Let me forward this to the admin team."


def clear_memory(user_id="default_user"):
    """Clear the conversation history for a specific user."""
    with history_lock:
        if user_id in user_histories:
            del user_histories[user_id]
            return f"Conversation history cleared for user {user_id}"
        return f"No conversation history found for user {user_id}"