# RAG_FASTAPI / app.py
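"""FastAPI service that answers questions about mountain-bicycle documentation.

Pipeline: PDFs in a local directory are loaded and chunked, embedded with
Google Generative AI embeddings, and indexed in an in-memory Qdrant
collection; a two-node LangGraph (retrieve_docs -> generate) feeds retrieved
chunks as context to a Groq-hosted Llama 3 model, exposed via a single
/query endpoint.
"""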
from fastapi import FastAPI, HTTPException
import os
from typing import List, Dict, Any
from dotenv import load_dotenv
import logging
from pathlib import Path
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Qdrant as QdrantVectorStore
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_groq import ChatGroq
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.models import PointIdsList
from langgraph.graph import MessagesState, StateGraph, END
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from langgraph.checkpoint.memory import MemorySaver
# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
GROQ_API_KEY = os.getenv('GROQ_API_KEY')
if not GOOGLE_API_KEY or not GROQ_API_KEY:
    raise ValueError("GOOGLE_API_KEY and GROQ_API_KEY must be set in environment variables")
app = FastAPI()
class QASystem:
def __init__(self):
self.vector_store = None
self.graph = None
self.memory = None
self.embeddings = None
self.client = None
self.pdf_dir = "pdfss"
self.is_initialized = False
def load_pdf_documents(self):
"""Load and process PDF documents from the pdf directory"""
documents = []
pdf_dir = Path(self.pdf_dir)
if not pdf_dir.exists():
raise FileNotFoundError(f"PDF directory not found: {self.pdf_dir}")
pdf_files = list(pdf_dir.glob("*.pdf"))
if not pdf_files:
logger.warning(f"No PDF files found in directory: {self.pdf_dir}")
return []
logger.info(f"Found {len(pdf_files)} PDF files to process")
for pdf_path in pdf_files:
try:
logger.info(f"Processing PDF: {pdf_path}")
loader = PyPDFLoader(str(pdf_path))
pdf_documents = loader.load()
# Add source information to metadata
for doc in pdf_documents:
if not hasattr(doc, 'metadata'):
doc.metadata = {}
doc.metadata['source'] = str(pdf_path.name)
documents.extend(pdf_documents)
logger.info(f"Loaded PDF: {pdf_path} - {len(pdf_documents)} pages/sections")
except Exception as e:
logger.error(f"Error loading PDF {pdf_path}: {str(e)}")
if not documents:
logger.warning("No documents were loaded from PDFs. Check the PDF directory and file formats.")
return []
# Split documents into smaller chunks for better retrieval
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200
)
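        # 1000-character chunks with 200 characters of overlap: the overlap keeps
        # sentences that straddle a chunk boundary retrievable from either side.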
split_docs = text_splitter.split_documents(documents)
logger.info(f"Split {len(documents)} documents into {len(split_docs)} chunks")
        # Log a preview of the first few chunks to verify their content
        for i, doc in enumerate(split_docs[:3]):
            logger.info(f"Sample chunk {i+1} content preview: {doc.page_content[:100]}...")
return split_docs
def initialize_system(self):
"""Initialize the RAG system with vector store and LLM"""
try:
logger.info("Initializing QA System...")
# Initialize Qdrant client
self.client = QdrantClient(":memory:")
logger.info("Qdrant client initialized (in-memory)")
# Create or get collection
try:
collection_info = self.client.get_collection("pdf_data")
logger.info(f"Using existing collection: pdf_data")
except Exception:
self.client.create_collection(
collection_name="pdf_data",
vectors_config=VectorParams(size=768, distance=Distance.COSINE),
)
logger.info("Created new collection: pdf_data")
# Initialize embeddings model
self.embeddings = GoogleGenerativeAIEmbeddings(
model="models/embedding-001",
google_api_key=GOOGLE_API_KEY
)
logger.info("Google AI Embeddings initialized")
# Initialize vector store
self.vector_store = QdrantVectorStore(
client=self.client,
collection_name="pdf_data",
embeddings=self.embeddings,
)
logger.info("Qdrant vector store initialized")
# Load documents
documents = self.load_pdf_documents()
if not documents:
logger.warning("No documents loaded. The system will continue but may not provide relevant responses.")
# Clear existing vectors if any
if documents:
try:
points = self.client.scroll(collection_name="pdf_data", limit=1000)[0]
if points:
logger.info(f"Clearing {len(points)} existing vectors from collection")
self.client.delete(
collection_name="pdf_data",
points_selector=PointIdsList(
points=[p.id for p in points]
)
)
except Exception as e:
logger.error(f"Error clearing vectors: {str(e)}")
# Add documents to vector store
logger.info(f"Adding {len(documents)} documents to vector store")
self.vector_store.add_documents(documents)
logger.info(f"Successfully added documents to vector store")
# Verify vector store has documents
            try:
                point_count = self.client.count(collection_name="pdf_data").count
                logger.info(f"Vector store contains {point_count} points")
            except Exception as e:
                logger.error(f"Error verifying vector store: {str(e)}")
# Initialize LLM
llm = ChatGroq(
model="llama3-8b-8192",
api_key=GROQ_API_KEY,
temperature=0.7
)
logger.info("Groq LLM initialized")
# Create LangGraph
graph_builder = StateGraph(MessagesState)
logger.info("Creating LangGraph for conversation flow")
            # Capture a local reference so the nested node functions can access the vector store
            vector_store_ref = self.vector_store
def retrieve_docs(state: MessagesState):
"""Node that retrieves relevant documents from the vector store"""
# Get the most recent human message
human_messages = [m for m in state["messages"] if m.type == "human"]
if not human_messages:
logger.warning("No human messages found in state")
return {"messages": state["messages"]}
user_query = human_messages[-1].content
logger.info(f"Retrieving documents for query: '{user_query}'")
                # Bail out if the vector store was never initialized
                if not vector_store_ref:
                    logger.error("Vector store not initialized")
                    return {"messages": state["messages"]}
# Query the vector store
try:
retrieved_docs = vector_store_ref.similarity_search(user_query, k=3)
if not retrieved_docs:
logger.warning(f"No documents retrieved for query: '{user_query}'")
return {"messages": state["messages"]}
# Log what was actually retrieved
for i, doc in enumerate(retrieved_docs):
source = doc.metadata.get('source', 'Unknown') if hasattr(doc, 'metadata') else 'Unknown'
content_preview = doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content
logger.info(f"Retrieved doc {i+1} from {source}, preview: {content_preview}")
# Create tool messages with more detailed content
tool_messages = []
for i, doc in enumerate(retrieved_docs):
# Include source information if available
source_info = f" (Source: {doc.metadata.get('source', 'Unknown')})" if hasattr(doc, 'metadata') else ""
tool_messages.append(
ToolMessage(
content=f"Document {i+1}{source_info}: {doc.page_content}",
tool_call_id=f"retrieval_{i}"
)
)
logger.info(f"Created {len(tool_messages)} tool messages with retrieved content")
return {"messages": state["messages"] + tool_messages}
except Exception as e:
logger.error(f"Error retrieving documents: {str(e)}")
return {"messages": state["messages"]}
# Generate response using retrieved documents
def generate(state: MessagesState):
"""Node that generates a response using the LLM and retrieved documents"""
# Extract retrieved documents (tool messages)
tool_messages = [m for m in state["messages"] if m.type == "tool"]
# Collect context from retrieved documents
if tool_messages:
context = "\n\n".join([m.content for m in tool_messages])
logger.info(f"Using context from {len(tool_messages)} retrieved documents")
else:
context = "No specific mountain bicycle documentation available for this query."
logger.warning("No relevant documents retrieved, using default context")
system_prompt = (
"You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
"Your primary role is to provide accurate technical information about mountain bicycles. "
"Always base your responses on the provided documentation. "
"If you don't find specific information in the provided context, clearly state that the information "
"is not available in the current documentation instead of making up details. "
"When responding, reference specific parts of the documentation."
f"\n\nContext from mountain bicycle documentation:\n{context}"
)
# Get all messages excluding tool messages to avoid redundancy
human_and_ai_messages = [m for m in state["messages"] if m.type != "tool"]
# Create the full message history for the LLM
messages = [SystemMessage(content=system_prompt)] + human_and_ai_messages
logger.info(f"Sending query to LLM with {len(messages)} messages")
# Generate the response
try:
response = llm.invoke(messages)
logger.info(f"LLM generated response successfully")
return {"messages": state["messages"] + [response]}
except Exception as e:
logger.error(f"Error generating response: {str(e)}")
error_message = SystemMessage(content=f"Error generating response: {str(e)}")
return {"messages": state["messages"] + [error_message]}
# Add nodes to the graph
graph_builder.add_node("retrieve_docs", retrieve_docs)
graph_builder.add_node("generate", generate)
# Set the flow of the graph
graph_builder.set_entry_point("retrieve_docs")
graph_builder.add_edge("retrieve_docs", "generate")
graph_builder.add_edge("generate", END)
# Initialize memory
self.memory = MemorySaver()
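            # MemorySaver checkpoints graph state per thread_id (supplied via the
            # config in process_query), so each thread keeps its own message history.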
self.graph = graph_builder.compile(checkpointer=self.memory)
logger.info("Graph compiled successfully")
self.is_initialized = True
return True
except Exception as e:
logger.error(f"System initialization error: {str(e)}")
self.is_initialized = False
return False
def process_query(self, query: str) -> Dict[str, Any]:
"""Process a query and return a single final response"""
try:
if not self.is_initialized:
logger.error("System not initialized. Cannot process query.")
return {
'content': "Error: QA System not initialized properly",
'type': 'error'
}
logger.info(f"Processing query: '{query}'")
# Generate a thread ID (use a more sophisticated method for production)
thread_id = "abc123"
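            # For multi-user deployments, derive this from the request instead, e.g.
            # a hypothetical per-session value such as str(uuid.uuid4()); a fixed ID
            # shares one conversation history across all callers.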
# Use invoke to get only the final result
final_state = self.graph.invoke(
{"messages": [HumanMessage(content=query)]},
config={"configurable": {"thread_id": thread_id}}
)
# Extract only the last AI message from the final state
ai_messages = [m for m in final_state["messages"] if m.type == "ai"]
if ai_messages:
logger.info("Successfully generated response")
# Return only the last AI message
return {
'content': ai_messages[-1].content,
'type': ai_messages[-1].type
}
logger.warning("No AI message generated in response")
return {
'content': "No response could be generated for your query. Please try a different question.",
'type': 'error'
}
except Exception as e:
logger.error(f"Query processing error: {str(e)}")
return {
'content': f"Error processing your query: {str(e)}",
'type': 'error'
}
# Initialize the QA system
qa_system = QASystem()
initialization_success = qa_system.initialize_system()
@app.post("/query")
async def query_api(query: str):
"""API endpoint that returns a single response for a query"""
if not qa_system.is_initialized:
raise HTTPException(status_code=500, detail="QA System not initialized properly")
response = qa_system.process_query(query)
return {"response": response}
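# Minimal local entry point (a sketch: assumes uvicorn is installed and port 8000
# is free; hosted platforms typically launch the app themselves). The endpoint
# takes `query` as a URL query parameter, e.g.:
#   curl -X POST "http://localhost:8000/query?query=How%20do%20I%20adjust%20the%20brakes"
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)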