File size: 2,201 Bytes
05e3517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Optional

from agents.table_selection import table_selection_agent
from agents.data_retrieval import sample_data_retrieval_agent
from agents.sql_generation import sql_generation_agent
from agents.validation import query_validation_and_optimization
from agents.execution import execution_agent
from utils.bigquery_utils import init_bigquery_connection

# Define the state schema
class SQLExecutionState(TypedDict):
    """Shared state dictionary used as the schema of the workflow graph.

    Every field except ``sql_query`` is annotated ``Optional``. Within this
    file only ``client`` is written (by ``initialize_client``); the other
    fields are presumably filled in by the agent nodes imported from the
    ``agents`` package -- confirm against those modules.
    """

    # NOTE: despite the name, this holds the user's natural-language request,
    # not SQL text.
    sql_query: str  # Natural language query
    # Set by initialize_client() from init_bigquery_connection().
    client: Optional[object]  # BigQuery client
    relevant_tables: Optional[list]  # Tables identified as relevant
    sample_data: Optional[dict]      # Sample data from relevant tables
    generated_sql: Optional[str]     # The actual SQL query (not JSON)
    # Presumably the outcome of the validation/optimization node -- the dict
    # layout is not visible here; TODO confirm in agents.validation.
    validation_result: Optional[dict]
    # Presumably the rewritten query from the same node -- TODO confirm.
    optimized_sql: Optional[str]
    # Presumably the output of the execution node -- the dict layout is not
    # visible here; TODO confirm in agents.execution.
    execution_result: Optional[dict]

def initialize_client(state: SQLExecutionState) -> SQLExecutionState:
    """Call ``init_bigquery_connection()`` and return its result under ``client``.

    The incoming ``state`` is not read. The returned dictionary contains only
    the ``client`` key.
    """
    return {"client": init_bigquery_connection()}

def create_workflow():
    """Assemble the straight-line SQL workflow graph and return it compiled.

    Six nodes are registered on a ``StateGraph`` whose schema is
    ``SQLExecutionState`` and are chained one after another from ``START``
    to ``END`` in the order listed below.
    """
    # Ordered (node name, handler) pairs; this order is also the execution order.
    stages = (
        ("Initialize Client", initialize_client),
        ("Table Selection", table_selection_agent),
        ("Sample Data Retrieval", sample_data_retrieval_agent),
        ("SQL Generation", sql_generation_agent),
        ("Query Validation & Optimization", query_validation_and_optimization),
        ("SQL Execution", execution_agent),
    )

    workflow = StateGraph(state_schema=SQLExecutionState)

    # Register every node before wiring any edge.
    for node_name, handler in stages:
        workflow.add_node(node_name, handler)

    # Link each stop to the next one: START -> first stage -> ... -> END.
    stops = [START, *(node_name for node_name, _ in stages), END]
    for source, target in zip(stops, stops[1:]):
        workflow.add_edge(source, target)

    return workflow.compile()