from fastapi import FastAPI
import os
from typing import Dict
from dotenv import load_dotenv
import logging
from pathlib import Path
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Qdrant as QdrantVectorStore
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_groq import ChatGroq
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointIdsList
from langgraph.graph import MessagesState, StateGraph, END
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from langgraph.checkpoint.memory import MemorySaver

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_dotenv()
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GOOGLE_API_KEY or not GROQ_API_KEY:
    raise ValueError("API keys not set in environment variables")

app = FastAPI()


class QASystem:
    def __init__(self):
        self.vector_store = None
        self.graph = None
        self.memory = None
        self.embeddings = None
        self.client = None
        self.pdf_dir = "pdfss"

    def load_pdf_documents(self):
        """Load every PDF in the configured directory and split it into chunks."""
        documents = []
        pdf_dir = Path(self.pdf_dir)
        if not pdf_dir.exists():
            raise FileNotFoundError(f"PDF directory not found: {self.pdf_dir}")
        for pdf_path in pdf_dir.glob("*.pdf"):
            try:
                loader = PyPDFLoader(str(pdf_path))
                documents.extend(loader.load())
                logger.info(f"Loaded PDF: {pdf_path}")
            except Exception as e:
                logger.error(f"Error loading PDF {pdf_path}: {str(e)}")
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
        )
        split_docs = text_splitter.split_documents(documents)
        logger.info(f"Split documents into {len(split_docs)} chunks")
        return split_docs

    def initialize_system(self):
        try:
            # In-memory Qdrant instance: vectors live only as long as the process
            self.client = QdrantClient(":memory:")
            try:
                self.client.get_collection("pdf_data")
            except Exception:
                # 768 dimensions matches Google's embedding-001 model
                self.client.create_collection(
                    collection_name="pdf_data",
                    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
                )
                logger.info("Created new collection: pdf_data")

            self.embeddings = GoogleGenerativeAIEmbeddings(
                model="models/embedding-001",
                google_api_key=GOOGLE_API_KEY,
            )
            self.vector_store = QdrantVectorStore(
                client=self.client,
                collection_name="pdf_data",
                embeddings=self.embeddings,
            )

            documents = self.load_pdf_documents()
            if documents:
                # Clear any existing points so re-initialization does not duplicate
                # documents (note: scroll fetches at most 100 points here)
                try:
                    points = self.client.scroll(collection_name="pdf_data", limit=100)[0]
                    if points:
                        self.client.delete(
                            collection_name="pdf_data",
                            points_selector=PointIdsList(
                                points=[p.id for p in points]
                            ),
                        )
                except Exception as e:
                    logger.error(f"Error clearing vectors: {str(e)}")
                self.vector_store.add_documents(documents)
                logger.info(f"Added {len(documents)} documents to vector store")

            llm = ChatGroq(
                model="llama3-8b-8192",
                api_key=GROQ_API_KEY,
                temperature=0.7,
            )

            graph_builder = StateGraph(MessagesState)

            # Define a retrieval node that fetches relevant docs
            def retrieve_docs(state: MessagesState):
                # Get the most recent human message
                human_messages = [m for m in state["messages"] if m.type == "human"]
                if not human_messages:
                    return {"messages": state["messages"]}
                user_query = human_messages[-1].content
                logger.info(f"Retrieving documents for query: {user_query}")
                # Query the vector store
                try:
                    retrieved_docs = self.vector_store.similarity_search(user_query, k=3)
                    # Create tool messages for each retrieved document
                    tool_messages = []
                    for i, doc in enumerate(retrieved_docs):
                        tool_messages.append(
                            ToolMessage(
                                content=f"Document {i + 1}: {doc.page_content}",
                                tool_call_id=f"retrieval_{i}",
                            )
                        )
                    logger.info(f"Retrieved {len(tool_messages)} relevant documents")
                    return {"messages": state["messages"] + tool_messages}
                except Exception as e:
                    logger.error(f"Error retrieving documents: {str(e)}")
                    return {"messages": state["messages"]}

            # Generate function that grounds the answer in the retrieved documents
            def generate(state: MessagesState):
                # Extract retrieved documents (tool messages)
                tool_messages = [m for m in state["messages"] if m.type == "tool"]
                # Collect context from retrieved documents
                if tool_messages:
                    context = "\n".join([m.content for m in tool_messages])
                    logger.info(f"Using context from {len(tool_messages)} retrieved documents")
                else:
                    context = "No specific mountain bicycle documentation available."
                    logger.info("No relevant documents retrieved, using default context")
                system_prompt = (
                    "You are an AI assistant embedded within the Interactive Electronic "
                    "Technical Manual (IETM) for Mountain Cycles. "
                    "Always provide accurate responses with references to the provided data. "
                    "If the user query is not strictly technical, still respond from an IETM perspective."
                    f"\n\nContext from mountain bicycle documentation:\n{context}"
                )
                # Keep only human and AI messages so the raw tool output is not sent twice
                human_and_ai_messages = [m for m in state["messages"] if m.type != "tool"]
                # Create the full message history for the LLM
                messages = [SystemMessage(content=system_prompt)] + human_and_ai_messages
                logger.info(f"Sending query to LLM with {len(messages)} messages")
                # Generate the response
                response = llm.invoke(messages)
                return {"messages": state["messages"] + [response]}

            # Add nodes to the graph
            graph_builder.add_node("retrieve_docs", retrieve_docs)
            graph_builder.add_node("generate", generate)

            # Set the flow of the graph: retrieve first, then generate
            graph_builder.set_entry_point("retrieve_docs")
            graph_builder.add_edge("retrieve_docs", "generate")
            graph_builder.add_edge("generate", END)

            self.memory = MemorySaver()
            self.graph = graph_builder.compile(checkpointer=self.memory)
            return True
        except Exception as e:
            logger.error(f"System initialization error: {str(e)}")
            return False

    def process_query(self, query: str) -> Dict[str, str]:
        """Process a query and return a single final response."""
        try:
            # Generate a unique thread ID per conversation in production;
            # a fixed ID is used here for simplicity
            thread_id = "abc123"
            # Use invoke instead of stream to get only the final result
            final_state = self.graph.invoke(
                {"messages": [HumanMessage(content=query)]},
                config={"configurable": {"thread_id": thread_id}},
            )
            # Extract only the last AI message from the final state
            ai_messages = [m for m in final_state["messages"] if m.type == "ai"]
            if ai_messages:
                return {
                    "content": ai_messages[-1].content,
                    "type": ai_messages[-1].type,
                }
            return {
                "content": "No response generated",
                "type": "error",
            }
        except Exception as e:
            logger.error(f"Query processing error: {str(e)}")
            return {
                "content": f"Query processing error: {str(e)}",
                "type": "error",
            }


qa_system = QASystem()
if qa_system.initialize_system():
    logger.info("QA System Initialized Successfully")
else:
    raise RuntimeError("Failed to initialize QA System")


@app.post("/query")
async def query_api(query: str):
    """API endpoint that returns a single response for a query."""
    response = qa_system.process_query(query)
    return {"response": response}
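
# --- Example usage (a sketch; the file name and port below are assumptions,
# not part of the original code) ---
# The module exposes `app`, so it can be served with uvicorn; "main" assumes
# this file is saved as main.py:
#
#   uvicorn main:app --reload
#
# The /query endpoint takes the question as a query parameter, so it can be
# exercised with httpx (or curl); the sample question is hypothetical:
#
#   import httpx
#   r = httpx.post(
#       "http://localhost:8000/query",
#       params={"query": "How do I adjust the rear derailleur?"},
#   )
#   print(r.json()["response"]["content"])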