import os

import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# Configuration - use Hugging Face Spaces secrets
openai_api_key = os.getenv("OPENAI_API_KEY")
DB_DIR = "vector_db"

# Initialize the embedding model
embed_model = HuggingFaceEmbeddings(
    model_name="intfloat/e5-base",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"batch_size": 16},  # reduced batch size for HF Spaces
)

# Global variables for lazy loading
db = None
qa_chain = None


def initialize_system():
    """Initialize the RAG system - called once, when the first question is asked."""
    global db, qa_chain

    if qa_chain is not None:
        return True

    try:
        # Load the FAISS vector database from disk
        db = FAISS.load_local(
            DB_DIR,
            embeddings=embed_model,
            allow_dangerous_deserialization=True,
        )
        retriever = db.as_retriever(search_kwargs={"k": 3})

        # Initialize the LLM
        llm = ChatOpenAI(
            model_name="gpt-3.5-turbo",
            temperature=0.5,
            openai_api_key=openai_api_key,
        )

        # Conversation memory; output_key="answer" is required because the
        # chain also returns source documents alongside the answer
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )

        # Create the QA chain
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=memory,
            return_source_documents=True,
        )
        return True
    except Exception as e:
        print(f"Error initializing system: {e}")
        return False


def chat_with_rag(message, history):
    """Chat function optimized for Hugging Face Spaces."""
    # Check the API key
    if not openai_api_key:
        return ("⚠️ OpenAI API key not configured. Please set the "
                "OPENAI_API_KEY secret in your Hugging Face Space settings.")

    # Initialize the system on first use
    if not initialize_system():
        return ("❌ Failed to initialize the RAG system. Please check that "
                "the vector database is properly uploaded.")

    if not message.strip():
        return "Please enter a question about medical research."

    try:
        # Gradio passes history as a list of (user, assistant) pairs, which
        # is already the shape LangChain expects. The chain's memory also
        # tracks history, so this explicit copy only matters if memory is
        # ever removed from the chain.
        chat_history = [(user, assistant) for user, assistant in (history or [])]

        # Query the QA chain
        response = qa_chain.invoke({
            "question": message,
            "chat_history": chat_history,
        })
        return response["answer"]
    except Exception as e:
        error_msg = str(e)
        # Compare against lowercase strings, since error_msg is lowercased
        if "api key" in error_msg.lower():
            return ("⚠️ Invalid OpenAI API key. Please check your API key "
                    "in the Space settings.")
        if "rate limit" in error_msg.lower():
            return ("⚠️ Rate limit exceeded. Please wait a moment before "
                    "asking another question.")
        return f"❌ An error occurred: {error_msg}"
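
# ---------------------------------------------------------------------------
# Reference sketch (not part of the running app): one way the FAISS index in
# DB_DIR could be built offline with the same embedding model. The source
# file name, loader, and chunking parameters below are assumptions for
# illustration; run something like this once, then upload "vector_db" to the
# Space alongside app.py.
# ---------------------------------------------------------------------------
def build_vector_db(source_path="pubmed_abstracts.txt"):  # hypothetical input file
    """Build and persist the FAISS index that initialize_system() loads."""
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain_community.document_loaders import TextLoader

    docs = TextLoader(source_path, encoding="utf-8").load()
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=100  # assumed chunking parameters
    ).split_documents(docs)
    FAISS.from_documents(chunks, embed_model).save_local(DB_DIR)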

# Create the Gradio interface, optimized for HF Spaces
def create_interface():
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="PubMed RAG Chatbot",
        css="""
        .gradio-container { max-width: 800px !important; margin: auto !important; }
        /* Ensure proper contrast in both light and dark modes */
        .gr-button { transition: all 0.2s ease; }
        .gr-button:hover { transform: translateY(-1px); }
        """,
    ) as interface:
        gr.HTML("""
            <div style="text-align: center;">
                <p>Ask questions about medical research and get answers from PubMed literature.</p>
                <p>This chatbot searches through medical literature to provide research-based answers.</p>
                <p><strong>Disclaimer:</strong> This is for informational purposes only and not medical advice.</p>
            </div>
        """)
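
        # Assumed completion: the original listing is truncated at this
        # point. A standard way to finish a Blocks chat UI is a Chatbot plus
        # a Textbox wired to chat_with_rag; the layout below is a sketch,
        # not the original author's code.
        chatbot = gr.Chatbot(height=400)
        msg = gr.Textbox(
            placeholder="Ask a question about medical research...",
            show_label=False,
        )

        def respond(message, chat_history):
            answer = chat_with_rag(message, chat_history)
            return "", chat_history + [(message, answer)]

        msg.submit(respond, [msg, chatbot], [msg, chatbot])

    return interface


if __name__ == "__main__":
    create_interface().launch()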