import os
import re
from typing import Annotated

from typing_extensions import TypedDict

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_community.graphs import Neo4jGraph

from langgraph.graph import StateGraph, add_messages
from langgraph.checkpoint.sqlite import SqliteSaver

from ki_gen.prompts import PLAN_GEN_PROMPT, PLAN_MODIFICATION_PROMPT
from ki_gen.data_retriever import build_data_retriever_graph
from ki_gen.data_processor import build_data_processor_graph
# get_model selects the concrete chat model (including Gemini) from a model name
from ki_gen.utils import ConfigSchema, State, HumanValidationState, DocProcessorState, DocRetrieverState, get_model
##########################################################################
######                      NODES DEFINITION                        ######
##########################################################################
def validate_node(state: State):
    """
    This node inserts the plan validation prompt.
    """
    prompt = """System : You only need to focus on Key Issues, no need to focus on solutions or stakeholders yet and your plan should be concise.
If needed, give me an updated plan to follow this instruction. If your plan already follows the instruction just say "My plan is correct"."""
    output = HumanMessage(content=prompt)
    return {"messages": [output]}
# Wrapper to call the LLM on the state's 'messages' field.
# A single generic chatbot node replaces the previous per-model chatbots; the model
# is selected through get_model and the config.
def chatbot_node(state: State, config: ConfigSchema):
    """Generic chatbot node using the main_llm from config."""
    model_name = config["configurable"].get("main_llm") or "gemini-2.0-flash"
    llm = get_model(model_name)
    try:
        # Check that messages exist and are not empty before invoking the model
        if "messages" in state and state["messages"]:
            response = llm.invoke(state["messages"])
            return {"messages": [response]}
        else:
            print("Warning: No messages found in state for chatbot_node.")
            # Nothing to respond to: return no state update
            return {}
    except Exception as e:
        print(f"Error invoking model {model_name}: {e}")
        # Surface the error in the conversation instead of crashing the graph
        return {"messages": [SystemMessage(content=f"Error during generation: {e}")]}
# Model selection is centralized in ki_gen.utils.get_model: the planner only reads
# model names (e.g. "main_llm") and flags from config["configurable"], instead of
# maintaining per-model chatbot functions (chatbot_llama, chatbot_mixtral, chatbot_openai).
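# A minimal sketch of the config shape this module assumes. The keys "main_llm" and
# "use_detailed_query" are the ones read below; "thread_id" is the usual LangGraph
# checkpointer key. The concrete values are illustrative assumptions, not defaults.
EXAMPLE_PLANNER_CONFIG = {
    "configurable": {
        "thread_id": "example-thread",
        "main_llm": "gemini-2.0-flash",
        "use_detailed_query": False,
    }
}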
def parse_plan(state: State):
    """
    This node parses the generated plan and writes it in the 'store_plan' field of the state.
    """
    # Search backwards for the message containing the plan, as its position may vary
    plan_message_content = ""
    if "messages" in state and state["messages"]:
        for msg in reversed(state["messages"]):
            if hasattr(msg, 'content') and "Plan:" in msg.content and "<END_OF_PLAN>" in msg.content:
                plan_message_content = msg.content
                break  # Found the plan

    if not plan_message_content:
        print("Error: Could not find plan message in state.")
        # Return only an empty plan; re-returning the whole state would duplicate
        # 'messages' through the add_messages reducer
        return {"store_plan": []}

    store_plan = []
    try:
        # Extract the text between "Plan:" and "<END_OF_PLAN>", then split on numbered steps
        plan_section = plan_message_content.split("Plan:")[1].split("<END_OF_PLAN>")[0]
        store_plan = [step.strip() for step in re.split(r"\n\s*\d+\.\s*", plan_section) if step.strip()]
    except Exception as e:
        print(f"Error while parsing plan: {e}")
        store_plan = []  # Reset plan on error

    return {"store_plan": store_plan}
# Helper for the detail_step node; uses get_model with the configured model name
def get_detailed_query(context: list, model: str = "gemini-2.0-flash"):
    """
    Simple helper function for the detail_step node.
    """
    llm = get_model(model)
    try:
        return llm.invoke(context)
    except Exception as e:
        print(f"Error in get_detailed_query with model {model}: {e}")
        # Return an error message rather than raising, so the graph can keep going
        return SystemMessage(content=f"Error generating detailed query: {e}")
def detail_step(state: State, config: ConfigSchema):
    """
    This node updates the 'current_plan_step' field and defines the query to be used by the data_retriever.
    """
    print("Entering detail_step")  # Debug print
    print(f"Current state keys: {state.keys()}")  # Debug print

    # Initialize current_plan_step if not present, then advance to the next step
    current_plan_step = state.get("current_plan_step", -1) + 1

    # Ensure store_plan exists and has enough steps
    store_plan = state.get("store_plan", [])
    if not store_plan or current_plan_step >= len(store_plan):
        print(f"Warning: Plan step {current_plan_step} out of bounds or plan is empty.")
        # Prevent an index error and signal the problem downstream via the query
        return {"current_plan_step": current_plan_step, "query": "Error: Plan step unavailable.", "valid_docs": []}

    plan_step_description = store_plan[current_plan_step]

    if config["configurable"].get("use_detailed_query"):
        prompt = HumanMessage(f"""Specify what additional information you need to proceed with the next step of your plan :
Step {current_plan_step + 1} : {plan_step_description}""")
        # Ask the configured LLM to turn the plan step into a detailed retrieval query
        current_messages = state.get("messages", [])
        query_message = get_detailed_query(context=current_messages + [prompt], model=config["configurable"].get("main_llm", "gemini-2.0-flash"))
        query_content = query_message.content if hasattr(query_message, 'content') else "Error: Could not get detailed query content."
        return {"messages": [prompt, query_message], "current_plan_step": current_plan_step, "query": query_content, "valid_docs": state.get("valid_docs", [])}

    # If not using a detailed query, use the plan step description directly
    return {"current_plan_step": current_plan_step, "query": plan_step_description, "valid_docs": state.get("valid_docs", [])}
def concatenate_data(state: State):
    """
    This node concatenates all the data processed by the data_processor and inserts it into the state's messages.
    """
    # Ensure valid_docs exists and current_plan_step is valid
    valid_docs_content = state.get("valid_docs", "No processed documents available.")
    current_plan_step = state.get("current_plan_step", -1)
    store_plan = state.get("store_plan", [])

    if current_plan_step < 0 or current_plan_step >= len(store_plan):
        print(f"Warning: Invalid current_plan_step ({current_plan_step}) in concatenate_data.")
        step_description = "Error: Current plan step invalid."
    else:
        step_description = store_plan[current_plan_step]

    prompt = f"""#########TECHNICAL INFORMATION ############
{str(valid_docs_content)}
########END OF TECHNICAL INFORMATION#######

Using the information provided above, proceed with step {current_plan_step + 1} of your plan :
{step_description}
"""
    return {"messages": [HumanMessage(content=prompt)]}
def human_validation(state: HumanValidationState) -> HumanValidationState:
    """
    Dummy node to interrupt before processing; serves as a hook for manual validation.
    """
    # Default to no processing steps unless they were specified elsewhere
    return {'process_steps': state.get('process_steps', [])}
def generate_ki(state: State):
    """
    This node inserts the prompt to begin Key Issues generation.
    """
    print(f"THIS IS THE STATE FOR CURRENT PLAN STEP IN GENERATE_KI : {state.get('current_plan_step')}")
    current_plan_step = state.get("current_plan_step", -1)
    store_plan = state.get("store_plan", [])

    # Check that the next step exists in the plan
    next_step_index = current_plan_step + 1
    if next_step_index < 0 or next_step_index >= len(store_plan):
        print(f"Warning: Invalid next plan step ({next_step_index}) for KI generation.")
        step_description = "Error: Plan step for KI generation unavailable."
    else:
        step_description = store_plan[next_step_index]

    prompt = f"""Using the information provided above, proceed with step {next_step_index + 1} of your plan to provide the user with NEW and INNOVATIVE Key Issues :
{step_description}"""
    return {"messages": [HumanMessage(content=prompt)]}
def detail_ki(state: State):
    """
    This node inserts the last prompt to detail the generated Key Issues.
    """
    current_plan_step = state.get("current_plan_step", -1)
    store_plan = state.get("store_plan", [])

    # Check that the step after next exists in the plan
    detail_step_index = current_plan_step + 2
    if detail_step_index < 0 or detail_step_index >= len(store_plan):
        print(f"Warning: Invalid plan step ({detail_step_index}) for KI detailing.")
        step_description = "Error: Plan step for KI detailing unavailable."
    else:
        step_description = store_plan[detail_step_index]

    prompt = f"""Using the information provided above, proceed with step {detail_step_index + 1} of your plan to provide the user with NEW and INNOVATIVE Key Issues :
{step_description}"""
    return {"messages": [HumanMessage(content=prompt)]}
##########################################################################
######                 CONDITIONAL EDGE FUNCTIONS                   ######
##########################################################################
def validate_plan(state: State):
    """
    Whether to regenerate the plan or to parse it.
    """
    # Check the last message for "My plan is correct"
    if "messages" in state and state["messages"]:
        last_message = state["messages"][-1]
        if hasattr(last_message, 'content') and "My plan is correct" in last_message.content:
            return "parse"
    # Default to regenerating the plan if the condition is not met or messages are missing
    return "validate"
def next_plan_step(state: State, config: ConfigSchema):
    """
    Proceed to the next plan step: either generate Key Issues or retrieve more data.
    """
    current_plan_step = state.get("current_plan_step", -1)
    store_plan_len = len(state.get("store_plan", []))
    # Go to KI generation once the last step of the plan (by length) has been reached
    if current_plan_step >= store_plan_len - 1:
        return "generate_key_issues"
    else:
        return "detail_step"
def detail_or_data_retriever(state: State, config: ConfigSchema):
    """
    Decide whether to detail the query or go straight to data retrieval.
    """
    if config["configurable"].get("use_detailed_query"):
        # Invoke the LLM to refine the detailed query
        return "chatbot_detail"
    else:
        # Use the plan step directly as the query
        return "data_retriever"
def retrieve_or_process(state: State):
    """
    Process the retrieved docs or keep retrieving, based on the 'human_validated' flag.
    """
    # 'human_validated' must be set externally (e.g. by the Streamlit UI or via
    # graph.update_state) while the graph is interrupted after the 'human_validation'
    # node; see the sketch after this function.
    if state.get('human_validated'):
        return "process"
    else:
        # Not validated yet: loop back to retrieve more (or wait for validation)
        return "retrieve"
def build_planner_graph(memory, config):
    """
    Builds the planner graph.
    """
    graph_builder = StateGraph(State)

    graph_doc_retriever = build_data_retriever_graph(memory)
    graph_doc_processor = build_data_processor_graph(memory)

    # All chatbot nodes share the generic chatbot_node function; the model comes from config
    graph_builder.add_node("chatbot_planner", lambda state: chatbot_node(state, config))
    graph_builder.add_node("validate", validate_node)
    graph_builder.add_node("chatbot_detail", lambda state: chatbot_node(state, config))
    graph_builder.add_node("parse", parse_plan)
    graph_builder.add_node("detail_step", lambda state: detail_step(state, config))
    graph_builder.add_node("data_retriever", graph_doc_retriever)  # Subgraph; matching state keys are mapped automatically
    graph_builder.add_node("human_validation", human_validation)
    graph_builder.add_node("data_processor", graph_doc_processor)
    graph_builder.add_node("concatenate_data", concatenate_data)
    graph_builder.add_node("chatbot_exec_step", lambda state: chatbot_node(state, config))
    graph_builder.add_node("generate_ki", generate_ki)
    graph_builder.add_node("chatbot_ki", lambda state: chatbot_node(state, config))
    graph_builder.add_node("detail_ki", detail_ki)
    graph_builder.add_node("chatbot_final", lambda state: chatbot_node(state, config))

    # Define edges
    graph_builder.add_edge("validate", "chatbot_planner")
    graph_builder.add_edge("parse", "detail_step")
    graph_builder.add_edge("chatbot_detail", "data_retriever")  # After the detailed query is generated
    graph_builder.add_edge("data_retriever", "human_validation")
    graph_builder.add_edge("data_processor", "concatenate_data")
    graph_builder.add_edge("concatenate_data", "chatbot_exec_step")
    graph_builder.add_edge("generate_ki", "chatbot_ki")
    graph_builder.add_edge("chatbot_ki", "detail_ki")
    graph_builder.add_edge("detail_ki", "chatbot_final")
    graph_builder.add_edge("chatbot_final", "__end__")

    # Define conditional edges
    graph_builder.add_conditional_edges(
        "detail_step",
        lambda state: detail_or_data_retriever(state, config),
        {"chatbot_detail": "chatbot_detail", "data_retriever": "data_retriever"},
    )
    graph_builder.add_conditional_edges(
        "human_validation",
        retrieve_or_process,
        # 'retrieve' loops back to data_retriever, 'process' moves on to data_processor
        {"retrieve": "data_retriever", "process": "data_processor"},
    )
    graph_builder.add_conditional_edges(
        "chatbot_planner",
        validate_plan,
        {"parse": "parse", "validate": "validate"},
    )
    graph_builder.add_conditional_edges(
        "chatbot_exec_step",
        lambda state: next_plan_step(state, config),
        {"generate_key_issues": "generate_ki", "detail_step": "detail_step"},
    )

    # Set entry point and compile, interrupting where human interaction is expected
    graph_builder.set_entry_point("chatbot_planner")
    graph = graph_builder.compile(
        checkpointer=memory,
        interrupt_after=["human_validation", "chatbot_final"],
    )
    return graph
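# A minimal usage sketch (an illustration, not part of the app): compile the planner
# with a SQLite checkpointer and stream until the first interrupt. It assumes the
# context-manager form of SqliteSaver.from_conn_string available in recent langgraph
# releases, that the required model API keys are set in the environment, and that
# PLAN_GEN_PROMPT is a plain string suitable as the first user message (how the real
# app seeds the conversation may differ).
if __name__ == "__main__":
    with SqliteSaver.from_conn_string(":memory:") as memory:
        graph = build_planner_graph(memory, EXAMPLE_PLANNER_CONFIG)
        # Stream node updates until the graph pauses (after 'human_validation')
        for event in graph.stream(
            {"messages": [HumanMessage(content=PLAN_GEN_PROMPT)]},
            config=EXAMPLE_PLANNER_CONFIG,
        ):
            print(event)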