VishnuRamDebyez committed
Commit ab11098 · verified · 1 Parent(s): 0b10d8a

Update app.py

Files changed (1): app.py (+74 -69)
app.py CHANGED
@@ -14,14 +14,16 @@ from qdrant_client import QdrantClient
 from qdrant_client.http.models import Distance, VectorParams
 from qdrant_client.models import PointIdsList
 
-from langgraph.graph import MessagesState, StateGraph, END
-from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import MessagesState, StateGraph
 from langchain_core.messages import SystemMessage, HumanMessage
+from langgraph.prebuilt import ToolNode
+from langgraph.graph import END
+from langgraph.prebuilt import tools_condition
+from langgraph.checkpoint.memory import MemorySaver
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Load environment variables
 load_dotenv()
 GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
 GROQ_API_KEY = os.getenv('GROQ_API_KEY')
@@ -35,7 +37,7 @@ class QASystem:
     def __init__(self):
         self.vector_store = None
         self.graph = None
-        self.memory = MemorySaver()  # LangGraph memory saver for conversation history
+        self.memory = None
         self.embeddings = None
         self.client = None
         self.pdf_dir = "pdfss"
@@ -43,10 +45,10 @@ class QASystem:
     def load_pdf_documents(self):
         documents = []
         pdf_dir = Path(self.pdf_dir)
-
+
         if not pdf_dir.exists():
             raise FileNotFoundError(f"PDF directory not found: {self.pdf_dir}")
-
+
         for pdf_path in pdf_dir.glob("*.pdf"):
             try:
                 loader = PyPDFLoader(str(pdf_path))
@@ -55,16 +57,18 @@ class QASystem:
             except Exception as e:
                 logger.error(f"Error loading PDF {pdf_path}: {str(e)}")
 
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=1000,
+            chunk_overlap=100
+        )
         split_docs = text_splitter.split_documents(documents)
         logger.info(f"Split documents into {len(split_docs)} chunks")
         return split_docs
 
     def initialize_system(self):
         try:
-            # Qdrant setup
             self.client = QdrantClient(":memory:")
-
+
             try:
                 self.client.get_collection("pdf_data")
             except Exception:
@@ -73,88 +77,92 @@ class QASystem:
                     vectors_config=VectorParams(size=768, distance=Distance.COSINE),
                 )
                 logger.info("Created new collection: pdf_data")
-
-            # Embeddings and vector store
+
             self.embeddings = GoogleGenerativeAIEmbeddings(
-                model="models/embedding-001", google_api_key=GOOGLE_API_KEY
+                model="models/embedding-001",
+                google_api_key=GOOGLE_API_KEY
             )
-
+
             self.vector_store = QdrantVectorStore(
                 client=self.client,
                 collection_name="pdf_data",
                 embeddings=self.embeddings,
            )
-
-            # Load and add documents
+
             documents = self.load_pdf_documents()
             if documents:
-                points = self.client.scroll(collection_name="pdf_data", limit=100)[0]
-                if points:
-                    self.client.delete(
-                        collection_name="pdf_data",
-                        points_selector=PointIdsList(points=[p.id for p in points])
-                    )
+                try:
+                    points = self.client.scroll(collection_name="pdf_data", limit=100)[0]
+                    if points:
+                        self.client.delete(
+                            collection_name="pdf_data",
+                            points_selector=PointIdsList(
+                                points=[p.id for p in points]
+                            )
+                        )
+                except Exception as e:
+                    logger.error(f"Error clearing vectors: {str(e)}")
+
                 self.vector_store.add_documents(documents)
                 logger.info(f"Added {len(documents)} documents to vector store")
 
-            # LLM setup
             llm = ChatGroq(
-                model="llama3-8b-8192",
+                model="llama3-8b-8192",
                 api_key=GROQ_API_KEY,
                 temperature=0.7
             )
-
-            # Graph building
+
             graph_builder = StateGraph(MessagesState)
 
-            # === TOOL NODE for context fetching from Qdrant ===
-            def retrieve_documents(state: MessagesState):
-                query = [m.content for m in state["messages"] if m.type == "human"][-1]
-                results = self.vector_store.similarity_search(query, k=4)
-                context = "\n\n".join([doc.page_content for doc in results])
-                return {"messages": [SystemMessage(content=context, name="retrieval")]}  # as tool message
-
-            # === GENERATOR NODE that uses full memory (chat history) ===
-            def generate_response(state: MessagesState):
-                # Get full history from memory
-                thread_id = state["configurable"].get("thread_id", "default")
-                history = self.memory.get_memory(thread_id).get("messages", [])
-
-                logger.info(f"[Thread {thread_id}] History: {[m.content for m in history]}")
-
-                # Add current turn messages
-                all_messages = history + state["messages"]
-
-                # Extract context from retrieved docs (tool messages)
-                retrieved_docs = [m for m in all_messages if m.type == "tool"]
+            def query_or_respond(state: MessagesState):
+                retrieved_docs = [m for m in state["messages"] if m.type == "tool"]
+
+                if retrieved_docs:
+                    context = ' '.join(m.content for m in retrieved_docs)
+                else:
+                    context = "mountain bicycle documentation knowledge"
+
+                system_prompt = (
+                    "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
+                    "Always provide accurate responses with references to provided data. "
+                    "If the user query is not technical-specific, still respond from an IETM perspective."
+                    f"\n\nContext:\n{context}"
+                )
+
+                messages = [SystemMessage(content=system_prompt)] + state["messages"]
+
+                logger.info(f"Sending to LLM: {[m.content for m in messages]}")  # Debugging log
+
+                response = llm.invoke(messages)
+                return {"messages": [response]}
+
+            def generate(state: MessagesState):
+                retrieved_docs = [m for m in reversed(state["messages"]) if m.type == "tool"][::-1]
+
                 context = ' '.join(m.content for m in retrieved_docs) if retrieved_docs else "mountain bicycle documentation knowledge"
 
-                # Compose system prompt
                 system_prompt = (
                     "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
-                    "Your responses MUST be accurate, concise (5 sentences max). "
-                    "If you don't know the answer, say 'I don't know based on available data.'\n\n"
-                    f"Context:\n{context}"
+                    "Your responses MUST be accurate, concise (5 sentences max)."
+                    f"\n\nContext:\n{context}"
                 )
 
-                final_messages = [SystemMessage(content=system_prompt)] + all_messages
-                response = llm.invoke(final_messages)
-
-                # Save updated chat to memory
-                self.memory.save_checkpoint(thread_id, {"messages": all_messages + [response]})
+                messages = [SystemMessage(content=system_prompt)] + state["messages"]
+
+                logger.info(f"Sending to LLM: {[m.content for m in messages]}")  # Debugging log
 
+                response = llm.invoke(messages)
                 return {"messages": [response]}
 
-            # Add graph nodes
-            graph_builder.add_node("retrieval", retrieve_documents)
-            graph_builder.add_node("generate", generate_response)
 
-            # Graph edges
-            graph_builder.set_entry_point("retrieval")
-            graph_builder.add_edge("retrieval", "generate")
+            graph_builder.add_node("query_or_respond", query_or_respond)
+            graph_builder.add_node("generate", generate)
+
+            graph_builder.set_entry_point("query_or_respond")
+            graph_builder.add_edge("query_or_respond", "generate")
             graph_builder.add_edge("generate", END)
-
-            # Compile graph with memory
+
+            self.memory = MemorySaver()
             self.graph = graph_builder.compile(checkpointer=self.memory)
             return True
 
@@ -162,14 +170,13 @@ class QASystem:
             logger.error(f"System initialization error: {str(e)}")
             return False
 
-    # === Query Processor with Memory ===
-    def process_query(self, query: str, user_id: str) -> List[Dict[str, str]]:
+    def process_query(self, query: str) -> List[Dict[str, str]]:
         try:
             responses = []
             for step in self.graph.stream(
                 {"messages": [HumanMessage(content=query)]},
                 stream_mode="values",
-                config={"configurable": {"thread_id": user_id}}  # thread ID for user memory
+                config={"configurable": {"thread_id": "abc123"}}
             ):
                 if step["messages"]:
                     responses.append({
@@ -181,15 +188,13 @@
             logger.error(f"Query processing error: {str(e)}")
             return [{'content': f"Query processing error: {str(e)}", 'type': 'error'}]
 
-# === Initialize QA System ===
 qa_system = QASystem()
 if qa_system.initialize_system():
     logger.info("QA System Initialized Successfully")
 else:
     raise RuntimeError("Failed to initialize QA System")
 
-# === FastAPI Route ===
 @app.post("/query")
-async def query_api(query: str, user_id: str):  # Pass user_id for session-specific memory
-    responses = qa_system.process_query(query, user_id)
+async def query_api(query: str):
+    responses = qa_system.process_query(query)
     return {"responses": responses}