Spaces:
Sleeping
Sleeping
File size: 2,201 Bytes
05e3517 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Optional
from agents.table_selection import table_selection_agent
from agents.data_retrieval import sample_data_retrieval_agent
from agents.sql_generation import sql_generation_agent
from agents.validation import query_validation_and_optimization
from agents.execution import execution_agent
from utils.bigquery_utils import init_bigquery_connection
# Define the state schema
class SQLExecutionState(TypedDict):
    """State dictionary shared by every node of the SQL workflow.

    Despite its name, ``sql_query`` holds the user's natural-language
    question, not SQL. Every other key is declared ``Optional`` and is
    expected to be filled in as the pipeline progresses.
    """

    # The user's natural-language question (not SQL text).
    sql_query: str
    # BigQuery client; set by the "Initialize Client" node.
    client: Optional[object]
    # Tables judged relevant to the question.
    relevant_tables: Optional[list]
    # Sample rows pulled from the relevant tables.
    sample_data: Optional[dict]
    # The generated SQL statement itself (plain SQL, not JSON).
    generated_sql: Optional[str]
    # Outcome of the validation step.
    validation_result: Optional[dict]
    # SQL after the optimization step.
    optimized_sql: Optional[str]
    # Result of running the query.
    execution_result: Optional[dict]
def initialize_client(state: SQLExecutionState) -> dict:
    """Create a BigQuery client and publish it into the workflow state.

    This is a graph node. It returns only the key it sets, a partial state
    update, rather than a complete ``SQLExecutionState``. The previous
    ``-> SQLExecutionState`` annotation was inaccurate because the returned
    dict lacks that TypedDict's required keys.

    Args:
        state: Current workflow state. It is not read here; the parameter
            exists because the graph passes the state to every node.

    Returns:
        A partial update of the form ``{"client": <client>}``, where the
        client is whatever ``init_bigquery_connection()`` returns.
    """
    client = init_bigquery_connection()
    return {"client": client}
def create_workflow():
    """Build the linear text-to-SQL pipeline and return the compiled graph.

    Each stage is registered as a named node on a ``StateGraph`` over
    ``SQLExecutionState``. The stages are chained one after another from
    ``START`` to ``END`` in this order: client initialization, table
    selection, sample-data retrieval, SQL generation, query validation and
    optimization, and SQL execution.
    """
    # Ordered (node name, node callable) pairs; list order is run order.
    pipeline = [
        ("Initialize Client", initialize_client),
        ("Table Selection", table_selection_agent),
        ("Sample Data Retrieval", sample_data_retrieval_agent),
        ("SQL Generation", sql_generation_agent),
        ("Query Validation & Optimization", query_validation_and_optimization),
        ("SQL Execution", execution_agent),
    ]

    workflow = StateGraph(state_schema=SQLExecutionState)

    for node_name, node_fn in pipeline:
        workflow.add_node(node_name, node_fn)

    # Wire START -> first stage -> ... -> last stage -> END as one chain.
    route = [START, *(node_name for node_name, _ in pipeline), END]
    for source, target in zip(route, route[1:]):
        workflow.add_edge(source, target)

    return workflow.compile()