from fastapi import FastAPI, HTTPException
import os
from typing import Dict, Any
from dotenv import load_dotenv
import logging
from pathlib import Path
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Qdrant as QdrantVectorStore
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_groq import ChatGroq
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.models import PointIdsList
from langgraph.graph import END, MessagesState, StateGraph
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from langgraph.checkpoint.memory import MemorySaver

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
GROQ_API_KEY = os.getenv('GROQ_API_KEY')

if not GOOGLE_API_KEY or not GROQ_API_KEY:
    raise ValueError("API keys not set in environment variables")

app = FastAPI()


class QASystem:
    def __init__(self):
        self.vector_store = None
        self.graph = None
        self.memory = None
        self.embeddings = None
        self.client = None
        self.pdf_dir = "pdfss"
        self.is_initialized = False

    def load_pdf_documents(self):
        """Load and process PDF documents from the PDF directory."""
        documents = []
        pdf_dir = Path(self.pdf_dir)

        if not pdf_dir.exists():
            raise FileNotFoundError(f"PDF directory not found: {self.pdf_dir}")

        pdf_files = list(pdf_dir.glob("*.pdf"))
        if not pdf_files:
            logger.warning(f"No PDF files found in directory: {self.pdf_dir}")
            return []

        logger.info(f"Found {len(pdf_files)} PDF files to process")

        for pdf_path in pdf_files:
            try:
                logger.info(f"Processing PDF: {pdf_path}")
                loader = PyPDFLoader(str(pdf_path))
                pdf_documents = loader.load()

                # Add source information to metadata
                for doc in pdf_documents:
                    if not hasattr(doc, 'metadata'):
                        doc.metadata = {}
                    doc.metadata['source'] = str(pdf_path.name)

                documents.extend(pdf_documents)
                logger.info(f"Loaded PDF: {pdf_path} - {len(pdf_documents)} pages/sections")
            except Exception as e:
                logger.error(f"Error loading PDF {pdf_path}: {str(e)}")

        if not documents:
            logger.warning("No documents were loaded from PDFs. Check the PDF directory and file formats.")
            return []

        # Split documents into smaller chunks for better retrieval
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        split_docs = text_splitter.split_documents(documents)
        logger.info(f"Split {len(documents)} documents into {len(split_docs)} chunks")

        # Log a content preview of the first few chunks
        for i, doc in enumerate(split_docs[:3]):
            logger.info(f"Sample chunk {i+1} content preview: {doc.page_content[:100]}...")

        return split_docs

    def initialize_system(self):
        """Initialize the RAG system with vector store and LLM."""
        try:
            logger.info("Initializing QA System...")

            # Initialize Qdrant client
            self.client = QdrantClient(":memory:")
            logger.info("Qdrant client initialized (in-memory)")
            # Create the collection if it does not already exist
            try:
                self.client.get_collection("pdf_data")
                logger.info("Using existing collection: pdf_data")
            except Exception:
                self.client.create_collection(
                    collection_name="pdf_data",
                    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
                )
                logger.info("Created new collection: pdf_data")

            # Initialize embeddings model
            self.embeddings = GoogleGenerativeAIEmbeddings(
                model="models/embedding-001",
                google_api_key=GOOGLE_API_KEY
            )
            logger.info("Google AI Embeddings initialized")

            # Initialize vector store
            self.vector_store = QdrantVectorStore(
                client=self.client,
                collection_name="pdf_data",
                embeddings=self.embeddings,
            )
            logger.info("Qdrant vector store initialized")

            # Load documents
            documents = self.load_pdf_documents()
            if not documents:
                logger.warning("No documents loaded. The system will continue but may not provide relevant responses.")

            if documents:
                # Clear existing vectors, if any, before re-indexing
                try:
                    points = self.client.scroll(collection_name="pdf_data", limit=1000)[0]
                    if points:
                        logger.info(f"Clearing {len(points)} existing vectors from collection")
                        self.client.delete(
                            collection_name="pdf_data",
                            points_selector=PointIdsList(
                                points=[p.id for p in points]
                            )
                        )
                except Exception as e:
                    logger.error(f"Error clearing vectors: {str(e)}")

                # Add documents to vector store
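                # (add_documents embeds each chunk with the Google
                # embeddings model and upserts the resulting vectors into
                # Qdrant, so indexing time scales with the number of chunks)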
logger.info(f"Adding {len(documents)} documents to vector store") | |
self.vector_store.add_documents(documents) | |
logger.info(f"Successfully added documents to vector store") | |

            # Verify that the vector store contains points
            try:
                sample = self.client.scroll(collection_name="pdf_data", limit=1)[0]
                logger.info(f"Vector store contains points: {bool(sample)}")
            except Exception as e:
                logger.error(f"Error verifying vector store: {str(e)}")

            # Initialize LLM
            llm = ChatGroq(
                model="llama3-8b-8192",
                api_key=GROQ_API_KEY,
                temperature=0.7
            )
            logger.info("Groq LLM initialized")

            # Create the LangGraph for the conversation flow
            graph_builder = StateGraph(MessagesState)
            logger.info("Creating LangGraph for conversation flow")

            # Capture a local reference so the node closures can reach the vector store
            vector_store_ref = self.vector_store

            def retrieve_docs(state: MessagesState):
                """Node that retrieves relevant documents from the vector store."""
                # Get the most recent human message
                human_messages = [m for m in state["messages"] if m.type == "human"]
                if not human_messages:
                    logger.warning("No human messages found in state")
                    return {"messages": state["messages"]}

                user_query = human_messages[-1].content
                logger.info(f"Retrieving documents for query: '{user_query}'")

                # Check that the vector store exists
                if not vector_store_ref:
                    logger.error("Vector store not initialized or empty")
                    return {"messages": state["messages"]}

                # Query the vector store
                try:
                    retrieved_docs = vector_store_ref.similarity_search(user_query, k=3)
                    if not retrieved_docs:
                        logger.warning(f"No documents retrieved for query: '{user_query}'")
                        return {"messages": state["messages"]}

                    # Log what was actually retrieved
                    for i, doc in enumerate(retrieved_docs):
                        source = doc.metadata.get('source', 'Unknown') if hasattr(doc, 'metadata') else 'Unknown'
                        content_preview = doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content
                        logger.info(f"Retrieved doc {i+1} from {source}, preview: {content_preview}")

                    # Wrap each retrieved document in a ToolMessage
                    tool_messages = []
                    for i, doc in enumerate(retrieved_docs):
                        # Include source information if available
                        source_info = f" (Source: {doc.metadata.get('source', 'Unknown')})" if hasattr(doc, 'metadata') else ""
                        tool_messages.append(
                            ToolMessage(
                                content=f"Document {i+1}{source_info}: {doc.page_content}",
                                tool_call_id=f"retrieval_{i}"
                            )
                        )

                    logger.info(f"Created {len(tool_messages)} tool messages with retrieved content")
                    return {"messages": state["messages"] + tool_messages}
                except Exception as e:
                    logger.error(f"Error retrieving documents: {str(e)}")
                    return {"messages": state["messages"]}

            # Generate a response using the retrieved documents
            def generate(state: MessagesState):
                """Node that generates a response using the LLM and retrieved documents."""
                # Extract retrieved documents (tool messages)
                tool_messages = [m for m in state["messages"] if m.type == "tool"]

                # Collect context from the retrieved documents
                if tool_messages:
                    context = "\n\n".join([m.content for m in tool_messages])
                    logger.info(f"Using context from {len(tool_messages)} retrieved documents")
                else:
                    context = "No specific mountain bicycle documentation available for this query."
                    logger.warning("No relevant documents retrieved, using default context")

                system_prompt = (
                    "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
                    "Your primary role is to provide accurate technical information about mountain bicycles. "
                    "Always base your responses on the provided documentation. "
                    "If you don't find specific information in the provided context, clearly state that the information "
                    "is not available in the current documentation instead of making up details. "
                    "When responding, reference specific parts of the documentation."
                    f"\n\nContext from mountain bicycle documentation:\n{context}"
                )

                # Drop the tool messages from the history; their content is
                # already folded into the system prompt above
                human_and_ai_messages = [m for m in state["messages"] if m.type != "tool"]

                # Build the full message history for the LLM
                messages = [SystemMessage(content=system_prompt)] + human_and_ai_messages
                logger.info(f"Sending query to LLM with {len(messages)} messages")

                # Generate the response
                try:
                    response = llm.invoke(messages)
                    logger.info("LLM generated response successfully")
                    return {"messages": state["messages"] + [response]}
                except Exception as e:
                    logger.error(f"Error generating response: {str(e)}")
                    error_message = SystemMessage(content=f"Error generating response: {str(e)}")
                    return {"messages": state["messages"] + [error_message]}

            # Add nodes to the graph
            graph_builder.add_node("retrieve_docs", retrieve_docs)
            graph_builder.add_node("generate", generate)

            # Wire the linear flow: retrieve_docs -> generate -> END
            graph_builder.set_entry_point("retrieve_docs")
            graph_builder.add_edge("retrieve_docs", "generate")
            graph_builder.add_edge("generate", END)

            # Initialize memory
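            # (MemorySaver checkpoints graph state per thread_id, so queries
            # that reuse the same thread_id share conversation history)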
            self.memory = MemorySaver()
            self.graph = graph_builder.compile(checkpointer=self.memory)
            logger.info("Graph compiled successfully")

            self.is_initialized = True
            return True
        except Exception as e:
            logger.error(f"System initialization error: {str(e)}")
            self.is_initialized = False
            return False

    def process_query(self, query: str) -> Dict[str, Any]:
        """Process a query and return a single final response."""
        try:
            if not self.is_initialized:
                logger.error("System not initialized. Cannot process query.")
                return {
                    'content': "Error: QA System not initialized properly",
                    'type': 'error'
                }

            logger.info(f"Processing query: '{query}'")

            # Fixed thread ID: every caller shares one conversation thread.
            # Use a per-session ID in production.
            thread_id = "abc123"

            # Use invoke to get only the final result
            final_state = self.graph.invoke(
                {"messages": [HumanMessage(content=query)]},
                config={"configurable": {"thread_id": thread_id}}
            )

            # Extract only the last AI message from the final state
            ai_messages = [m for m in final_state["messages"] if m.type == "ai"]
            if ai_messages:
                logger.info("Successfully generated response")
                return {
                    'content': ai_messages[-1].content,
                    'type': ai_messages[-1].type
                }

            logger.warning("No AI message generated in response")
            return {
                'content': "No response could be generated for your query. Please try a different question.",
                'type': 'error'
            }
        except Exception as e:
            logger.error(f"Query processing error: {str(e)}")
            return {
                'content': f"Error processing your query: {str(e)}",
                'type': 'error'
            }


# Initialize the QA system once at import time
qa_system = QASystem()
initialization_success = qa_system.initialize_system()


# NOTE: the "/query" route path is an assumption; adjust it to match your
# deployment
@app.get("/query")
async def query_api(query: str):
    """API endpoint that returns a single response for a query."""
    if not qa_system.is_initialized:
        raise HTTPException(status_code=500, detail="QA System not initialized properly")

    response = qa_system.process_query(query)
    return {"response": response}