Spaces:
Sleeping
Sleeping
Update ki_gen/data_retriever.py
Browse files- ki_gen/data_retriever.py +44 -36
ki_gen/data_retriever.py
CHANGED
|
@@ -7,9 +7,9 @@ from random import shuffle, sample
|
|
| 7 |
from langgraph.checkpoint.sqlite import SqliteSaver
|
| 8 |
|
| 9 |
# Remove ChatGroq import
|
| 10 |
-
# from langchain_groq import ChatGroq
|
| 11 |
# Add ChatGoogleGenerativeAI import
|
| 12 |
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 13 |
import os # Add os import
|
| 14 |
|
| 15 |
from langchain_openai import ChatOpenAI
|
|
@@ -21,21 +21,20 @@ from langchain_core.prompts import ChatPromptTemplate
|
|
| 21 |
from langchain_core.pydantic_v1 import Field
|
| 22 |
from pydantic import BaseModel
|
| 23 |
|
| 24 |
-
from neo4j import GraphDatabase
|
| 25 |
|
| 26 |
from langgraph.graph import StateGraph
|
| 27 |
|
| 28 |
from llmlingua import PromptCompressor
|
| 29 |
|
| 30 |
from ki_gen.prompts import (
|
| 31 |
-
CYPHER_GENERATION_PROMPT,
|
| 32 |
CONCEPT_SELECTION_PROMPT,
|
| 33 |
BINARY_GRADER_PROMPT,
|
| 34 |
SCORE_GRADER_PROMPT,
|
| 35 |
RELEVANT_CONCEPTS_PROMPT,
|
| 36 |
)
|
| 37 |
# Import get_model which now handles Gemini
|
| 38 |
-
from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc
|
| 39 |
|
| 40 |
|
| 41 |
# ... (extract_cypher remains the same)
|
|
@@ -99,7 +98,7 @@ def get_concepts(graph: Neo4jGraph):
|
|
| 99 |
def get_related_concepts(graph: Neo4jGraph, question: str):
|
| 100 |
concepts = get_concepts(graph)
|
| 101 |
# Use get_model
|
| 102 |
-
llm = get_model()
|
| 103 |
print(f"this is the llm variable : {llm}")
|
| 104 |
def parse_answer(llm_answer : str):
|
| 105 |
try:
|
|
@@ -113,7 +112,7 @@ def get_related_concepts(graph: Neo4jGraph, question: str):
|
|
| 113 |
|
| 114 |
print(f"This is the question of the user : {question}")
|
| 115 |
print(f"This is the concepts of the user : {concepts}")
|
| 116 |
-
|
| 117 |
# Remove specific Groq error handling block
|
| 118 |
try:
|
| 119 |
related_concepts_raw = related_concepts_chain.invoke({"user_query" : question, "concepts" : '\n'.join(concepts)})
|
|
@@ -148,7 +147,7 @@ def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
|
|
| 148 |
MATCH (c:Concept {{name: "{concept}" }}) RETURN c.description
|
| 149 |
"""
|
| 150 |
concept_description = graph.query(concept_description_query)[0]['c.description']
|
| 151 |
-
concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
|
| 152 |
return concept_string
|
| 153 |
|
| 154 |
def get_global_concepts(graph: Neo4jGraph):
|
|
@@ -167,12 +166,20 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
|
|
| 167 |
"""
|
| 168 |
The node where the cypher is generated
|
| 169 |
"""
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
related_concepts = get_related_concepts(graph, question)
|
| 177 |
cyphers = []
|
| 178 |
|
|
@@ -183,15 +190,18 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
|
|
| 183 |
"question": question,
|
| 184 |
"concepts": related_concepts
|
| 185 |
})
|
| 186 |
-
|
| 187 |
# Remove specific Groq error handling block
|
| 188 |
try:
|
| 189 |
if config["configurable"].get("cypher_gen_method") == 'guided':
|
| 190 |
concept_selection_chain = get_concept_selection_chain()
|
| 191 |
print(f"Concept selection chain is : {concept_selection_chain}")
|
|
|
|
| 192 |
selected_topic = concept_selection_chain.invoke({"question" : question, "concepts": get_concepts(graph)})
|
| 193 |
print(f"Selected topic are : {selected_topic}")
|
| 194 |
-
|
|
|
|
|
|
|
| 195 |
print(f"Cyphers are : {cyphers}")
|
| 196 |
|
| 197 |
except Exception as e:
|
|
@@ -205,7 +215,7 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
|
|
| 205 |
corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships", [])]
|
| 206 |
cypher_corrector = CypherQueryCorrector(corrector_schema)
|
| 207 |
# Apply corrector only if cyphers were generated
|
| 208 |
-
if cyphers:
|
| 209 |
try:
|
| 210 |
cyphers = [cypher_corrector(cypher) for cypher in cyphers]
|
| 211 |
except Exception as corr_e:
|
|
@@ -214,9 +224,10 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
|
|
| 214 |
else:
|
| 215 |
print("Warning: Cypher validation skipped, graph or schema unavailable.")
|
| 216 |
|
| 217 |
-
|
| 218 |
return {"cyphers" : cyphers}
|
| 219 |
|
|
|
|
| 220 |
# ... (generate_cypher_from_topic, get_docs remain the same)
|
| 221 |
def generate_cypher_from_topic(selected_concept: str, plan_step: int):
|
| 222 |
"""
|
|
@@ -232,25 +243,21 @@ def generate_cypher_from_topic(selected_concept: str, plan_step: int):
|
|
| 232 |
cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
|
| 233 |
case 2:
|
| 234 |
cypher_el = "(ki:KeyIssue) RETURN ki.description"
|
| 235 |
-
return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"
|
| 236 |
|
| 237 |
def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
| 238 |
"""
|
| 239 |
This node retrieves docs from the graph using the generated cypher
|
| 240 |
"""
|
| 241 |
-
|
| 242 |
-
NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
|
| 243 |
-
NEO4J_USERNAME = "neo4j"
|
| 244 |
-
NEO4J_PASSWORD = os.getenv("neo4j_password")
|
| 245 |
-
graph = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
|
| 246 |
output = []
|
| 247 |
if graph is not None and state.get("cyphers"): # Check if cyphers exist
|
| 248 |
for cypher in state["cyphers"]:
|
| 249 |
try:
|
| 250 |
output = graph.query(cypher)
|
| 251 |
# Assuming the first successful query is sufficient
|
| 252 |
-
if output:
|
| 253 |
-
break
|
| 254 |
except Exception as e:
|
| 255 |
print(f"Failed to retrieve docs with cypher '{cypher}': {e}")
|
| 256 |
# Continue to try next cypher if one fails
|
|
@@ -264,13 +271,13 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
|
| 264 |
for key in doc:
|
| 265 |
if isinstance(doc[key], dict):
|
| 266 |
# If a value is a dict, treat it as a separate document
|
| 267 |
-
all_docs.append(doc[key])
|
| 268 |
else:
|
| 269 |
unwinded_doc.update({key: doc[key]})
|
| 270 |
# Add the unwinded parts if any keys were not dictionaries
|
| 271 |
-
if unwinded_doc:
|
| 272 |
all_docs.append(unwinded_doc)
|
| 273 |
-
|
| 274 |
filtered_docs = []
|
| 275 |
seen_docs = set() # Use a set for faster duplicate checking based on a unique identifier
|
| 276 |
|
|
@@ -278,7 +285,7 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
|
| 278 |
# Create a tuple of items to check for duplicates, assuming dicts are hashable
|
| 279 |
# If dicts contain unhashable types (like lists), convert them to strings or use a primary key
|
| 280 |
try:
|
| 281 |
-
doc_tuple = tuple(sorted(doc.items()))
|
| 282 |
if doc_tuple not in seen_docs:
|
| 283 |
filtered_docs.append(doc)
|
| 284 |
seen_docs.add(doc_tuple)
|
|
@@ -290,7 +297,7 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
|
| 290 |
filtered_docs.append(doc)
|
| 291 |
seen_docs.add(doc_str)
|
| 292 |
|
| 293 |
-
|
| 294 |
return {"docs": filtered_docs}
|
| 295 |
|
| 296 |
|
|
@@ -385,13 +392,13 @@ def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="gemini-2.0-
|
|
| 385 |
# Update default model
|
| 386 |
def eval_docs(state: DocRetrieverState, config: ConfigSchema):
|
| 387 |
"""
|
| 388 |
-
This node performs evaluation of the retrieved docs and
|
| 389 |
"""
|
| 390 |
|
| 391 |
eval_method = config["configurable"].get("eval_method") or "binary"
|
| 392 |
MAX_DOCS = config["configurable"].get("max_docs") or 15
|
| 393 |
# Update default model name
|
| 394 |
-
eval_model_name = config["configurable"].get("eval_model") or "gemini-2.0-flash"
|
| 395 |
valid_doc_scores = []
|
| 396 |
|
| 397 |
# Ensure 'docs' exists and is a list
|
|
@@ -419,7 +426,7 @@ def eval_docs(state: DocRetrieverState, config: ConfigSchema):
|
|
| 419 |
|
| 420 |
score = eval_doc(
|
| 421 |
doc=formatted_doc_str,
|
| 422 |
-
query=state["query"],
|
| 423 |
method=eval_method,
|
| 424 |
threshold=config["configurable"].get("eval_threshold") or 0.7,
|
| 425 |
eval_model=eval_model_name # Pass the eval_model name
|
|
@@ -431,7 +438,7 @@ def eval_docs(state: DocRetrieverState, config: ConfigSchema):
|
|
| 431 |
else:
|
| 432 |
print(f"Warning: Received non-numeric score ({score}) for doc {doc}, skipping.")
|
| 433 |
|
| 434 |
-
|
| 435 |
if eval_method == 'score':
|
| 436 |
# Get at most MAX_DOCS items with the highest score if score method was used
|
| 437 |
valid_docs_sorted = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True) # Sort descending
|
|
@@ -454,7 +461,7 @@ def build_data_retriever_graph(memory):
|
|
| 454 |
"""
|
| 455 |
Builds the data_retriever graph
|
| 456 |
"""
|
| 457 |
-
#with SqliteSaver.from_conn_string(":memory:") as memory :
|
| 458 |
|
| 459 |
graph_builder_doc_retriever = StateGraph(DocRetrieverState)
|
| 460 |
|
|
@@ -469,6 +476,7 @@ def build_data_retriever_graph(memory):
|
|
| 469 |
graph_builder_doc_retriever.add_edge("eval_docs", "__end__")
|
| 470 |
|
| 471 |
graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)
|
|
|
|
| 472 |
return graph_doc_retriever
|
| 473 |
|
| 474 |
# Remove Groq specific error handling function
|
|
|
|
| 7 |
from langgraph.checkpoint.sqlite import SqliteSaver
|
| 8 |
|
| 9 |
# Remove ChatGroq import
|
| 10 |
+
# from langchain_groq import ChatGroq
|
| 11 |
# Add ChatGoogleGenerativeAI import
|
| 12 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 13 |
import os # Add os import
|
| 14 |
|
| 15 |
from langchain_openai import ChatOpenAI
|
|
|
|
| 21 |
from langchain_core.pydantic_v1 import Field
|
| 22 |
from pydantic import BaseModel
|
| 23 |
|
|
|
|
| 24 |
|
| 25 |
from langgraph.graph import StateGraph
|
| 26 |
|
| 27 |
from llmlingua import PromptCompressor
|
| 28 |
|
| 29 |
from ki_gen.prompts import (
|
| 30 |
+
CYPHER_GENERATION_PROMPT,
|
| 31 |
CONCEPT_SELECTION_PROMPT,
|
| 32 |
BINARY_GRADER_PROMPT,
|
| 33 |
SCORE_GRADER_PROMPT,
|
| 34 |
RELEVANT_CONCEPTS_PROMPT,
|
| 35 |
)
|
| 36 |
# Import get_model which now handles Gemini
|
| 37 |
+
from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc
|
| 38 |
|
| 39 |
|
| 40 |
# ... (extract_cypher remains the same)
|
|
|
|
| 98 |
def get_related_concepts(graph: Neo4jGraph, question: str):
|
| 99 |
concepts = get_concepts(graph)
|
| 100 |
# Use get_model
|
| 101 |
+
llm = get_model()
|
| 102 |
print(f"this is the llm variable : {llm}")
|
| 103 |
def parse_answer(llm_answer : str):
|
| 104 |
try:
|
|
|
|
| 112 |
|
| 113 |
print(f"This is the question of the user : {question}")
|
| 114 |
print(f"This is the concepts of the user : {concepts}")
|
| 115 |
+
|
| 116 |
# Remove specific Groq error handling block
|
| 117 |
try:
|
| 118 |
related_concepts_raw = related_concepts_chain.invoke({"user_query" : question, "concepts" : '\n'.join(concepts)})
|
|
|
|
| 147 |
MATCH (c:Concept {{name: "{concept}" }}) RETURN c.description
|
| 148 |
"""
|
| 149 |
concept_description = graph.query(concept_description_query)[0]['c.description']
|
| 150 |
+
concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
|
| 151 |
return concept_string
|
| 152 |
|
| 153 |
def get_global_concepts(graph: Neo4jGraph):
|
|
|
|
| 166 |
"""
|
| 167 |
The node where the cypher is generated
|
| 168 |
"""
|
| 169 |
+
graph = config["configurable"].get("graph")
|
| 170 |
+
|
| 171 |
+
# --- Correction Applied Here ---
|
| 172 |
+
# Use .get() for safer access to 'query'
|
| 173 |
+
question = state.get('query')
|
| 174 |
+
if not question:
|
| 175 |
+
# Handle the case where query is missing
|
| 176 |
+
print("Error: 'query' key not found in state for generate_cypher node.")
|
| 177 |
+
# Return an empty list or appropriate error state
|
| 178 |
+
# This prevents the KeyError and stops processing for this branch if query is missing
|
| 179 |
+
return {"cyphers": []}
|
| 180 |
+
# --- End of Correction ---
|
| 181 |
+
|
| 182 |
+
|
| 183 |
related_concepts = get_related_concepts(graph, question)
|
| 184 |
cyphers = []
|
| 185 |
|
|
|
|
| 190 |
"question": question,
|
| 191 |
"concepts": related_concepts
|
| 192 |
})
|
| 193 |
+
|
| 194 |
# Remove specific Groq error handling block
|
| 195 |
try:
|
| 196 |
if config["configurable"].get("cypher_gen_method") == 'guided':
|
| 197 |
concept_selection_chain = get_concept_selection_chain()
|
| 198 |
print(f"Concept selection chain is : {concept_selection_chain}")
|
| 199 |
+
# Ensure 'current_plan_step' is also safely accessed if needed here, though it's used later
|
| 200 |
selected_topic = concept_selection_chain.invoke({"question" : question, "concepts": get_concepts(graph)})
|
| 201 |
print(f"Selected topic are : {selected_topic}")
|
| 202 |
+
# Safely get 'current_plan_step', defaulting to 0 if not found
|
| 203 |
+
current_plan_step = state.get('current_plan_step', 0)
|
| 204 |
+
cyphers = [generate_cypher_from_topic(selected_topic, current_plan_step)]
|
| 205 |
print(f"Cyphers are : {cyphers}")
|
| 206 |
|
| 207 |
except Exception as e:
|
|
|
|
| 215 |
corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships", [])]
|
| 216 |
cypher_corrector = CypherQueryCorrector(corrector_schema)
|
| 217 |
# Apply corrector only if cyphers were generated
|
| 218 |
+
if cyphers:
|
| 219 |
try:
|
| 220 |
cyphers = [cypher_corrector(cypher) for cypher in cyphers]
|
| 221 |
except Exception as corr_e:
|
|
|
|
| 224 |
else:
|
| 225 |
print("Warning: Cypher validation skipped, graph or schema unavailable.")
|
| 226 |
|
| 227 |
+
|
| 228 |
return {"cyphers" : cyphers}
|
| 229 |
|
| 230 |
+
|
| 231 |
# ... (generate_cypher_from_topic, get_docs remain the same)
|
| 232 |
def generate_cypher_from_topic(selected_concept: str, plan_step: int):
|
| 233 |
"""
|
|
|
|
| 243 |
cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
|
| 244 |
case 2:
|
| 245 |
cypher_el = "(ki:KeyIssue) RETURN ki.description"
|
| 246 |
+
return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"
|
| 247 |
|
| 248 |
def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
| 249 |
"""
|
| 250 |
This node retrieves docs from the graph using the generated cypher
|
| 251 |
"""
|
| 252 |
+
graph = config["configurable"].get("graph")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
output = []
|
| 254 |
if graph is not None and state.get("cyphers"): # Check if cyphers exist
|
| 255 |
for cypher in state["cyphers"]:
|
| 256 |
try:
|
| 257 |
output = graph.query(cypher)
|
| 258 |
# Assuming the first successful query is sufficient
|
| 259 |
+
if output:
|
| 260 |
+
break
|
| 261 |
except Exception as e:
|
| 262 |
print(f"Failed to retrieve docs with cypher '{cypher}': {e}")
|
| 263 |
# Continue to try next cypher if one fails
|
|
|
|
| 271 |
for key in doc:
|
| 272 |
if isinstance(doc[key], dict):
|
| 273 |
# If a value is a dict, treat it as a separate document
|
| 274 |
+
all_docs.append(doc[key])
|
| 275 |
else:
|
| 276 |
unwinded_doc.update({key: doc[key]})
|
| 277 |
# Add the unwinded parts if any keys were not dictionaries
|
| 278 |
+
if unwinded_doc:
|
| 279 |
all_docs.append(unwinded_doc)
|
| 280 |
+
|
| 281 |
filtered_docs = []
|
| 282 |
seen_docs = set() # Use a set for faster duplicate checking based on a unique identifier
|
| 283 |
|
|
|
|
| 285 |
# Create a tuple of items to check for duplicates, assuming dicts are hashable
|
| 286 |
# If dicts contain unhashable types (like lists), convert them to strings or use a primary key
|
| 287 |
try:
|
| 288 |
+
doc_tuple = tuple(sorted(doc.items()))
|
| 289 |
if doc_tuple not in seen_docs:
|
| 290 |
filtered_docs.append(doc)
|
| 291 |
seen_docs.add(doc_tuple)
|
|
|
|
| 297 |
filtered_docs.append(doc)
|
| 298 |
seen_docs.add(doc_str)
|
| 299 |
|
| 300 |
+
|
| 301 |
return {"docs": filtered_docs}
|
| 302 |
|
| 303 |
|
|
|
|
| 392 |
# Update default model
|
| 393 |
def eval_docs(state: DocRetrieverState, config: ConfigSchema):
|
| 394 |
"""
|
| 395 |
+
This node performs evaluation of the retrieved docs and
|
| 396 |
"""
|
| 397 |
|
| 398 |
eval_method = config["configurable"].get("eval_method") or "binary"
|
| 399 |
MAX_DOCS = config["configurable"].get("max_docs") or 15
|
| 400 |
# Update default model name
|
| 401 |
+
eval_model_name = config["configurable"].get("eval_model") or "gemini-2.0-flash"
|
| 402 |
valid_doc_scores = []
|
| 403 |
|
| 404 |
# Ensure 'docs' exists and is a list
|
|
|
|
| 426 |
|
| 427 |
score = eval_doc(
|
| 428 |
doc=formatted_doc_str,
|
| 429 |
+
query=state["query"], # This line assumes "query" exists in state
|
| 430 |
method=eval_method,
|
| 431 |
threshold=config["configurable"].get("eval_threshold") or 0.7,
|
| 432 |
eval_model=eval_model_name # Pass the eval_model name
|
|
|
|
| 438 |
else:
|
| 439 |
print(f"Warning: Received non-numeric score ({score}) for doc {doc}, skipping.")
|
| 440 |
|
| 441 |
+
|
| 442 |
if eval_method == 'score':
|
| 443 |
# Get at most MAX_DOCS items with the highest score if score method was used
|
| 444 |
valid_docs_sorted = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True) # Sort descending
|
|
|
|
| 461 |
"""
|
| 462 |
Builds the data_retriever graph
|
| 463 |
"""
|
| 464 |
+
#with SqliteSaver.from_conn_string(":memory:") as memory :
|
| 465 |
|
| 466 |
graph_builder_doc_retriever = StateGraph(DocRetrieverState)
|
| 467 |
|
|
|
|
| 476 |
graph_builder_doc_retriever.add_edge("eval_docs", "__end__")
|
| 477 |
|
| 478 |
graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)
|
| 479 |
+
|
| 480 |
return graph_doc_retriever
|
| 481 |
|
| 482 |
# Remove Groq specific error handling function
|