import json

import openai

from config import PROJECT_ID, DATASET_ID
from utils.bigquery_utils import get_bigquery_schema_info

_CODE_FENCE = "```"


def _format_schema(schema_info):
    """Render each table as a markdown bullet followed by its column list."""
    return "".join(
        f"- **{DATASET_ID}.{table_name}** ({', '.join(columns)})\n"
        for table_name, columns in schema_info.items()
    )


def _build_prompt(natural_language_query, schema_text):
    """Assemble the table-selection prompt sent to the model."""
    example = f'["{DATASET_ID}.users", "{DATASET_ID}.orders"]'
    # Built from explicit lines so that source-code indentation never leaks
    # into the text the model sees.
    return (
        "Based on the following natural language query and BigQuery schema, "
        "identify the tables that would be needed to answer the query.\n"
        "\n"
        f'**Query:** "{natural_language_query}"\n'
        "\n"
        "**BigQuery Schema:**\n"
        f"{schema_text}\n"
        "Analyze the query and determine which tables contain the necessary "
        "information.\n"
        "\n"
        "IMPORTANT: Return ONLY a raw JSON array of table names without any "
        "markdown formatting, code blocks, or explanations.\n"
        "\n"
        "Example of correct response format:\n"
        f"{example}\n"
        "\n"
        "Example of INCORRECT response format:\n"
        "```json\n"
        f"{example}\n"
        "```\n"
        "\n"
        "DO NOT use code blocks, backticks, or any other formatting. "
        "Return ONLY the raw JSON array.\n"
    )


def _strip_code_fence(content):
    """Return *content* without a surrounding markdown code fence, if any."""
    text = content.strip()
    if not text.startswith(_CODE_FENCE):
        return text
    # Keep what lies between the opening fence and the closing one; when the
    # closing fence is missing, partition() yields the whole remainder.
    inner = text[len(_CODE_FENCE):].partition(_CODE_FENCE)[0].strip()
    # Drop a leading "json" language tag in any letter case. A JSON array
    # starts with "[", so this cannot eat part of a valid payload.
    if inner[:4].lower() == "json":
        inner = inner[4:].strip()
    return inner


def _is_table_list(value):
    """Return True when *value* is a list whose items are all strings."""
    return isinstance(value, list) and all(isinstance(item, str) for item in value)


def table_selection_agent(state):
    """Identify the tables relevant to a natural language query.

    Fetches the dataset schema, asks the OpenAI chat model which tables are
    needed to answer the query, and parses the model's reply as a JSON array.

    Args:
        state: Mapping with the keys ``"sql_query"`` (the natural language
            query text) and ``"client"`` (the client handed to
            ``get_bigquery_schema_info``; ``None`` is treated as a failed
            BigQuery connection).

    Returns:
        ``{"relevant_tables": [...]}`` on success. On failure,
        ``{"relevant_tables": [], "error": "<message>"}``.

    Note:
        Exceptions raised by ``get_bigquery_schema_info`` are not caught here
        and propagate to the caller.
    """
    natural_language_query = state["sql_query"]
    client = state["client"]

    if client is None:
        return {"relevant_tables": [], "error": "Failed to connect to BigQuery."}

    schema_info = get_bigquery_schema_info(client, PROJECT_ID, DATASET_ID)
    prompt = _build_prompt(natural_language_query, _format_schema(schema_info))

    # Bound before the try block so the JSON error handler can always log it,
    # even if the failure happens before a response object exists.
    content = ""
    try:
        response = openai.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
        )
        # The message content may be None; treat that like an empty reply so
        # it is reported as an invalid JSON response.
        content = response.choices[0].message.content or ""

        parsed = json.loads(_strip_code_fence(content))
        if not _is_table_list(parsed):
            # Valid JSON of the wrong shape (object, string, number, ...)
            # must not be handed on as if it were a list of table names.
            print(f"Unexpected response shape: {parsed!r}")
            return {
                "relevant_tables": [],
                "error": "OpenAI response was not a JSON array of table names",
            }

        print(f"Parsed relevant tables: {parsed}")
        return {"relevant_tables": parsed}
    except json.JSONDecodeError as e:
        print(f"JSON Decode Error: {e}")
        print(f"Response content: {content}")
        return {"relevant_tables": [], "error": "Invalid JSON response from OpenAI"}
    except Exception as e:
        # Top-level boundary for this agent step: report the failure in the
        # returned state instead of raising.
        print(f"Unexpected error: {e}")
        return {"relevant_tables": [], "error": f"Error: {str(e)}"}