# Hugging Face Spaces page status at capture time: Sleeping
import gradio as gr | |
import sqlite3 | |
import json | |
import pandas as pd | |
from openai import OpenAI | |
import traceback | |
from typing import Dict, List, Tuple, Any | |
import re | |
from datetime import datetime | |
import threading | |
import queue | |
import html | |
import sys | |
import os | |
# Force stdout to UTF-8 so the emoji in the status messages can be printed.
# The encoding name is normalised before comparing because interpreters report
# it with varying case and punctuation ('utf-8', 'UTF-8', 'utf8').
_stdout_encoding = (getattr(sys.stdout, "encoding", None) or "").lower().replace("-", "")
if _stdout_encoding != "utf8":
    if hasattr(sys.stdout, "reconfigure"):
        # Switches the encoding in place without replacing the stream object.
        sys.stdout.reconfigure(encoding="utf-8", line_buffering=True)
    else:
        try:
            sys.stdout = open(sys.stdout.fileno(), mode="w", encoding="utf-8", buffering=1)
        except (AttributeError, OSError, ValueError):
            # stdout has no usable file descriptor (e.g. it was replaced by an
            # in-memory stream); keep the existing stream rather than crash.
            pass
class DatabaseQueryAgent:
    """Turn natural-language questions into SQL against a SQLite database.

    The agent loads the database schema (plus optional ``table_catalog`` /
    ``column_catalog`` metadata tables) once at construction time, asks two
    OpenRouter-hosted models for a SQL query, asks a third model to verify
    the answers, and executes the generated SELECT statements.
    """

    # Words that must never be used as a table alias.
    _RESERVED_ALIASES = frozenset({
        'TO', 'FROM', 'WHERE', 'SELECT', 'ORDER', 'GROUP', 'HAVING', 'UNION', 'JOIN', 'ON',
    })
    # Keywords that may legitimately follow a table name in a FROM/JOIN clause.
    # When one of these follows a table without an explicit AS it starts the
    # next clause and is not an alias.
    _CLAUSE_KEYWORDS = frozenset({
        'WHERE', 'ORDER', 'GROUP', 'HAVING', 'UNION', 'JOIN', 'ON', 'AS', 'USING', 'LIMIT',
        'INNER', 'LEFT', 'RIGHT', 'FULL', 'OUTER', 'CROSS', 'NATURAL', 'EXCEPT', 'INTERSECT',
    })
    # The word after the table name is captured in a lookahead so it is not
    # consumed: in "FROM a JOIN b x" both "a" and "b" are inspected.
    _ALIAS_PATTERN = re.compile(r'\b(?:FROM|JOIN)\s+(\w+)(?=\s+(AS\s+)?(\w+))', re.IGNORECASE)
    # Concept -> search terms used to pick tables relevant to a user query.
    _CONCEPT_KEYWORDS = {
        'customer': ['customer', 'client', 'buyer', 'user'],
        'order': ['order', 'purchase', 'transaction', 'sale'],
        'product': ['product', 'item', 'inventory', 'stock'],
        'employee': ['employee', 'staff', 'worker', 'personnel'],
        'patient': ['patient', 'medical', 'health'],
        'student': ['student', 'enrollment', 'grade', 'course'],
        'supplier': ['supplier', 'vendor', 'provider'],
        'shipping': ['shipping', 'delivery', 'logistics'],
        'payment': ['payment', 'invoice', 'billing'],
        'account': ['account', 'financial', 'balance'],
    }

    def __init__(self, db_path: str = "innovativeskills.db"):
        """Remember the database path and load schema and metadata from it."""
        self.db_path = db_path
        self.client = None
        # Available models
        self.models = {
            "llama": "meta-llama/llama-3.3-70b-instruct:free",
            "mistral": "mistralai/mistral-7b-instruct:free",
            "gemma": "google/gemma-2-9b-it:free"  # Verification model
        }
        # Initialize database connection
        self.init_db_connection()

    def init_db_connection(self):
        """Load table metadata, column metadata and the live schema.

        On any error all three attributes are reset to empty dicts. The
        connection is always closed, including when loading fails.
        """
        self.table_metadata = {}
        self.column_metadata = {}
        self.actual_schema = {}
        conn = None
        try:
            conn = self.get_db_connection()
            cursor = conn.cursor()
            # Load table metadata
            self.table_metadata = self.get_table_metadata(conn, cursor)
            self.column_metadata = self.get_column_metadata(conn, cursor)
            self.actual_schema = self.get_actual_schema(conn, cursor)
        except Exception as e:
            print(f"Database initialization error: {e}")
            self.table_metadata = {}
            self.column_metadata = {}
            self.actual_schema = {}
        finally:
            if conn is not None:
                conn.close()

    def get_db_connection(self):
        """Return a new SQLite connection to ``self.db_path``."""
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        conn.execute("PRAGMA encoding = 'UTF-8';")
        return conn

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier."""
        return '"' + str(name).replace('"', '""') + '"'

    def get_actual_schema(self, conn, cursor) -> Dict:
        """Read the live schema from ``sqlite_master``.

        Returns a dict mapping each table name to its ``columns`` (name,
        type, notnull, pk), up to three ``sample_data`` rows and its
        ``row_count``. Returns ``{}`` if the schema cannot be read.
        """
        try:
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
            tables = [row[0] for row in cursor.fetchall()]
            schema = {}
            for table in tables:
                # Quote the name so tables containing spaces or named after
                # SQL keywords do not abort the whole schema load.
                quoted = self._quote_identifier(table)
                cursor.execute(f"PRAGMA table_info({quoted})")
                columns = cursor.fetchall()
                try:
                    cursor.execute(f"SELECT * FROM {quoted} LIMIT 3")
                    sample_data = cursor.fetchall()
                except Exception:
                    sample_data = []
                try:
                    cursor.execute(f"SELECT COUNT(*) FROM {quoted}")
                    row_count = cursor.fetchone()[0]
                except Exception:
                    row_count = 0
                schema[table] = {
                    'columns': [
                        {'name': col[1], 'type': col[2], 'notnull': col[3], 'pk': col[5]}
                        for col in columns
                    ],
                    'sample_data': sample_data,
                    'row_count': row_count
                }
            return schema
        except Exception as e:
            print(f"Error getting actual schema: {e}")
            return {}

    def get_table_metadata(self, conn, cursor) -> Dict:
        """Read per-table domain, description and row count from ``table_catalog``.

        Returns ``{}`` when the catalog cannot be queried.
        """
        try:
            query = """
                SELECT table_name, domain, description, row_count
                FROM table_catalog
                WHERE table_name NOT IN ('table_catalog', 'column_catalog')
            """
            results = cursor.execute(query).fetchall()
            return {
                table_name: {
                    'domain': domain,
                    'description': description,
                    'row_count': row_count
                }
                for table_name, domain, description, row_count in results
            }
        except Exception as e:
            print(f"Error loading table metadata: {e}")
            return {}

    def get_column_metadata(self, conn, cursor) -> Dict:
        """Read per-column metadata from ``column_catalog``, grouped by table.

        Returns ``{}`` when the catalog cannot be queried.
        """
        try:
            query = """
                SELECT table_name, column_name, data_type, is_foreign_key, references_table, description
                FROM column_catalog
            """
            results = cursor.execute(query).fetchall()
            metadata = {}
            for table_name, column_name, data_type, is_fk, ref_table, description in results:
                metadata.setdefault(table_name, []).append({
                    'name': column_name,
                    'type': data_type,
                    'is_foreign_key': bool(is_fk),
                    'references': ref_table,
                    'description': description
                })
            return metadata
        except Exception as e:
            print(f"Error loading column metadata: {e}")
            return {}

    def setup_client(self, api_key: str):
        """Create the OpenAI-compatible client pointed at OpenRouter."""
        self.client = OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key,
        )

    def get_relevant_tables_for_query(self, query: str) -> str:
        """Describe the tables that look relevant to *query*.

        A table is relevant when the query mentions a search term of some
        concept and the table name contains a search term of that same
        concept. If nothing matches, the first ten tables with more than ten
        rows are used. At most fifteen tables are described.
        """
        query_lower = query.lower()
        relevant_tables = []
        for search_terms in self._CONCEPT_KEYWORDS.values():
            if not any(term in query_lower for term in search_terms):
                continue
            for table_name in self.actual_schema:
                if table_name in relevant_tables:
                    continue
                table_lower = table_name.lower()
                if any(term in table_lower for term in search_terms):
                    relevant_tables.append(table_name)
        if not relevant_tables:
            relevant_tables = [name for name, info in self.actual_schema.items()
                               if info['row_count'] > 10][:10]
        parts = []
        for table in relevant_tables[:15]:
            info = self.actual_schema[table]
            columns_str = ", ".join(f"{col['name']}({col['type']})" for col in info['columns'])
            parts.append(f"\nTable: {table}\n")
            parts.append(f" Columns: {columns_str}\n")
            parts.append(f" Rows: {info['row_count']}\n")
            meta = self.table_metadata.get(table)
            if meta is not None:
                parts.append(f" Domain: {meta['domain']}\n")
                parts.append(f" Description: {meta['description']}\n")
            if info['sample_data']:
                parts.append(f" Sample: {info['sample_data'][0]}\n")
        return "".join(parts)

    def get_system_prompt(self, user_query: str) -> str:
        """Build the system prompt, embedding the schema relevant to *user_query*."""
        relevant_schema = self.get_relevant_tables_for_query(user_query)
        return f"""You are an intelligent database query agent that specializes in identifying relevant tables and generating accurate SQL queries.
DATABASE SCHEMA INFORMATION:
{relevant_schema}
CRITICAL SQL RULES:
1. NEVER use reserved words as table aliases (like 'to', 'from', 'where', 'select', etc.)
2. Use descriptive aliases like 'cust', 'ord', 'prod' instead
3. Only JOIN tables if you can identify a logical relationship between them
4. If no clear JOIN relationship exists, use separate SELECT statements or UNION
5. Always use the EXACT column names shown in the schema
6. Do not assume foreign key relationships unless explicitly shown
CRITICAL: You MUST respond with ONLY a valid JSON object. No markdown, no explanations outside the JSON.
Your response must be exactly in this JSON format:
{{
"analysis": "Brief analysis of the query and table selection reasoning",
"identified_tables": ["table1", "table2", "table3"],
"domains_involved": ["domain1", "domain2"],
"sql_query": "SELECT ... FROM ... WHERE ...",
"explanation": "Step-by-step explanation of the query logic",
"confidence": 0.95,
"alternative_queries": ["Alternative SQL if applicable"]
}}
IMPORTANT RULES:
1. Respond with ONLY valid JSON - no markdown formatting
2. Use ONLY the actual table names shown in the schema above
3. Use ONLY the actual column names shown in the schema above
4. Generate syntactically correct SQL queries with proper aliases
5. Focus on tables that actually exist and have relevant data
6. Include confidence scores between 0.0 and 1.0
7. Provide clear explanations
8. Ensure table names in 'identified_tables' match those used in 'sql_query'
9. Check that columns referenced in SQL actually exist in the tables
10. If no perfect match exists, choose the closest relevant tables and explain the compromise
11. Avoid reserved word aliases like 'to', 'from', 'order', 'select'
QUERY ANALYSIS GUIDELINES:
- For customer/order queries: Look for tables with customer-related or order-related names and columns
- For employee queries: Look for tables with employee, staff, or HR-related names
- For product queries: Look for tables with product, inventory, or item-related names
- Always verify column names exist before using them in SQL
- Use proper JOIN syntax when combining tables, but only if logical relationships exist
- Include appropriate WHERE clauses when filtering is implied
- If unsure about relationships, prefer simpler queries or multiple separate queries"""

    def extract_json_from_response(self, response_text: str) -> Dict:
        """Parse a model reply into a dict.

        Tries, in order: the whole text, the contents of a ```json fence,
        and the outermost ``{...}`` span. Only a JSON *object* is accepted
        because callers use ``.get`` on the result. Falls back to
        :meth:`create_fallback_response` when nothing parses.
        """
        candidates = [response_text]
        fenced = re.search(r'```json\s*(.*?)\s*```', response_text, re.DOTALL)
        if fenced:
            candidates.append(fenced.group(1))
        braced = re.search(r'\{.*\}', response_text, re.DOTALL)
        if braced:
            candidates.append(braced.group(0))
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
        return self.create_fallback_response(response_text)

    def create_fallback_response(self, response_text: str) -> Dict:
        """Salvage a response dict from text that is not valid JSON.

        Extracts the first ``SELECT ...`` statement, the known table names
        mentioned in the text (at most five) and their distinct domains (at
        most three). Confidence is fixed at 0.5.
        """
        sql_match = re.search(r'SELECT.*?(?:;|$)', response_text, re.IGNORECASE | re.DOTALL)
        sql_query = sql_match.group(0).strip(';') if sql_match else ""
        text_lower = response_text.lower()
        identified_tables = [table_name for table_name in self.actual_schema
                             if table_name.lower() in text_lower]
        # Collect distinct domains in first-seen order. A plain loop is used
        # because a comprehension cannot refer to the list it is building.
        domains_involved = []
        for table in identified_tables:
            meta = self.table_metadata.get(table)
            if meta is not None and meta['domain'] not in domains_involved:
                domains_involved.append(meta['domain'])
        return {
            "analysis": "Fallback analysis from unparseable response",
            "identified_tables": identified_tables[:5],
            "domains_involved": domains_involved[:3],
            "sql_query": sql_query,
            "explanation": "Response could not be parsed as JSON, extracted information where possible",
            "confidence": 0.5,
            "alternative_queries": []
        }

    def validate_sql_query(self, sql_query: str, identified_tables: List[str]) -> Tuple[bool, str]:
        """Check a generated query against the loaded schema.

        Checks, in order: the query is non-empty; every identified table
        exists; the query starts with SELECT; no table alias is a reserved
        word; every ``table.column`` reference names an existing column
        (compared case-insensitively). Returns ``(is_valid, message)``; any
        unexpected error is reported as a failed validation.
        """
        try:
            if not sql_query.strip():
                return False, "Empty SQL query"
            tables = list(identified_tables or [])
            for table in tables:
                if table not in self.actual_schema:
                    return False, f"Table '{table}' does not exist in database"
            if not sql_query.strip().upper().startswith('SELECT'):
                return False, "Only SELECT queries are allowed"
            for match in self._ALIAS_PATTERN.finditer(sql_query):
                explicit_as, alias = match.group(2), match.group(3)
                alias_upper = alias.upper()
                # Without AS, a clause keyword after the table name begins
                # the next clause ("FROM t WHERE ...") and is not an alias.
                if not explicit_as and alias_upper in self._CLAUSE_KEYWORDS:
                    continue
                if alias_upper in self._RESERVED_ALIASES:
                    return False, f"Cannot use reserved word '{alias}' as table alias"
            for table in tables:
                known_columns = {col['name'].lower() for col in self.actual_schema[table]['columns']}
                qualified = re.findall(rf'\b{re.escape(table)}\.(\w+)', sql_query, re.IGNORECASE)
                for column in qualified:
                    if column.lower() not in known_columns:
                        return False, f"Column '{column}' does not exist in table '{table}'"
            return True, "Query validation passed"
        except Exception as e:
            return False, f"Validation error: {str(e)}"

    def call_model(self, model_key: str, prompt: str, user_query: str) -> Dict:
        """Send *prompt* and *user_query* to the model registered as *model_key*.

        Returns ``{"success": True, "response", "raw_response", "model"}`` on
        success; when the parsed response contains a SQL query its
        validation result is attached under ``sql_validation``. Any failure
        is returned as ``{"success": False, "error", "model"}``.
        """
        try:
            if self.client is None:
                return {
                    "success": False,
                    "error": "API client is not configured; call setup_client() first",
                    "model": model_key
                }
            messages = [
                {"role": "system", "content": prompt},
                {"role": "user", "content": f"Query: {user_query}\n\nRespond with ONLY a valid JSON object following the exact format specified in the system prompt."}
            ]
            completion = self.client.chat.completions.create(
                model=self.models[model_key],
                messages=messages,
                temperature=0.1,
                max_tokens=2000
            )
            # The message content can be None; treat that as an empty reply.
            response = (completion.choices[0].message.content or "").strip()
            parsed_response = self.extract_json_from_response(response)
            sql_query = parsed_response.get('sql_query', '')
            identified_tables = parsed_response.get('identified_tables', [])
            if sql_query:
                is_valid, validation_message = self.validate_sql_query(sql_query, identified_tables)
                parsed_response['sql_validation'] = {
                    'is_valid': is_valid,
                    'message': validation_message
                }
            return {
                "success": True,
                "response": parsed_response,
                "raw_response": response,
                "model": model_key
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "model": model_key
            }

    def verify_response(self, api_key: str, original_query: str, llama_response: Dict, mistral_response: Dict) -> Dict:
        """Ask the "gemma" model to review both candidate responses against the schema."""
        self.setup_client(api_key)
        relevant_schema = self.get_relevant_tables_for_query(original_query)
        verification_prompt = f"""You are a database query verification expert. You have access to the actual database schema and must verify responses against it.
ACTUAL DATABASE SCHEMA:
{relevant_schema}
ORIGINAL QUERY: {original_query}
LLAMA RESPONSE: {json.dumps(llama_response.get('response', {}), indent=2)}
MISTRAL RESPONSE: {json.dumps(mistral_response.get('response', {}), indent=2)}
Verify these responses against the ACTUAL schema above. Check:
1. Do the table names actually exist in the schema?
2. Do the column names actually exist in those tables?
3. Are the table selections appropriate for the query?
4. Is the SQL syntax correct?
5. Are table aliases proper (not reserved words)?
Respond with ONLY a valid JSON object:
{{
"verification_summary": "Overall assessment based on actual schema",
"table_selection_accuracy": "Assessment of table choices against actual schema",
"sql_correctness": "SQL syntax and schema validation",
"consistency_check": "Comparison between responses",
"recommended_response": "llama, mistral, or neither",
"confidence_score": 0.85,
"suggested_improvements": ["improvement1", "improvement2"],
"potential_issues": ["issue1", "issue2"],
"schema_compliance": "Assessment of how well responses match actual schema"
}}"""
        return self.call_model("gemma", verification_prompt, "Verify the above responses against the actual database schema.")

    def execute_query_in_thread(self, sql_query: str, result_queue: queue.Queue):
        """Run *sql_query* and put ``(success, DataFrame | error message)`` on *result_queue*.

        Only statements starting with SELECT are executed. The connection is
        switched to query-only mode first, because the SQL is model-generated
        and a prefix check alone is a weak guarantee against writes.
        """
        try:
            if not sql_query.strip().upper().startswith('SELECT'):
                result_queue.put((False, "Only SELECT queries are allowed"))
                return
            statement = sql_query.strip().rstrip(';')
            conn = self.get_db_connection()
            try:
                conn.execute("PRAGMA query_only = ON;")
                df = pd.read_sql_query(statement, conn)
                result_queue.put((True, df))
            except Exception as e:
                result_queue.put((False, str(e)))
            finally:
                conn.close()
        except Exception as e:
            result_queue.put((False, f"Query execution error: {str(e)}"))

    def execute_query(self, sql_query: str) -> Tuple[bool, Any]:
        """Execute *sql_query* in a worker thread with a 30-second timeout.

        Returns ``(True, DataFrame)`` on success or ``(False, message)`` on
        failure or timeout.
        """
        try:
            result_queue = queue.Queue()
            # Daemon thread: a query still running after the timeout must
            # not keep the interpreter alive at shutdown.
            worker = threading.Thread(
                target=self.execute_query_in_thread,
                args=(sql_query, result_queue),
                daemon=True
            )
            worker.start()
            worker.join(timeout=30)
            if worker.is_alive():
                return False, "Query execution timed out"
            try:
                return result_queue.get_nowait()
            except queue.Empty:
                return False, "No result returned from query execution"
        except Exception as e:
            return False, f"Execution error: {str(e)}"

    @staticmethod
    def _failed_execution(message: str, sql_query: str, validation: Dict) -> Dict:
        """Build the execution-result entry for a query that was not run successfully."""
        return {
            "success": False,
            "data": message,
            "row_count": 0,
            "sql_query": sql_query,
            "validation": validation
        }

    def _execute_model_result(self, result: Dict) -> Dict:
        """Validate and run the SQL produced by one model; return its execution entry."""
        response = result.get("response") or {}
        sql_query = response.get("sql_query") if isinstance(response, dict) else None
        if not result.get("success") or not sql_query:
            return self._failed_execution(
                "Model failed to generate response", "",
                {"is_valid": False, "message": "Model error"})
        if not isinstance(sql_query, str) or not sql_query.strip():
            return self._failed_execution(
                "No SQL query generated", "",
                {"is_valid": False, "message": "Empty query"})
        validation_info = response.get("sql_validation", {})
        if not validation_info.get("is_valid", True):
            return self._failed_execution(
                f"Query validation failed: {validation_info.get('message', 'Unknown error')}",
                sql_query, validation_info)
        success, data = self.execute_query(sql_query)
        has_frame = success and isinstance(data, pd.DataFrame)
        return {
            "success": success,
            "data": data.to_dict('records') if has_frame else str(data),
            "row_count": len(data) if has_frame else 0,
            "sql_query": sql_query,
            "validation": validation_info
        }

    def process_query(self, api_key: str, user_query: str) -> Dict:
        """Run the full pipeline for *user_query*.

        Queries the "llama" and "mistral" models, verifies both answers with
        "gemma", executes each valid SQL query, and returns all results with
        a timestamp and the relevant schema text. Returns ``{"error": ...}``
        when the API key is missing or processing fails.
        """
        if not api_key:
            return {"error": "Please provide OpenRouter API key"}
        try:
            self.setup_client(api_key)
            system_prompt = self.get_system_prompt(user_query)
            llama_result = self.call_model("llama", system_prompt, user_query)
            mistral_result = self.call_model("mistral", system_prompt, user_query)
            verification_result = self.verify_response(api_key, user_query, llama_result, mistral_result)
            execution_results = {
                "llama": self._execute_model_result(llama_result),
                "mistral": self._execute_model_result(mistral_result),
            }
            return {
                "llama_response": llama_result,
                "mistral_response": mistral_result,
                "verification": verification_result,
                "execution_results": execution_results,
                "timestamp": datetime.now().isoformat(),
                "schema_info": self.get_relevant_tables_for_query(user_query)
            }
        except Exception as e:
            return {"error": f"Processing error: {str(e)}", "traceback": traceback.format_exc()}
def response_to_markdown(response_dict: Dict) -> str:
    """Render one model's result (as returned by ``call_model``) as Markdown.

    Returns an ``**Error**`` line when the result is not marked successful.
    List-valued fields come from model output and are not guaranteed to be
    lists of strings, so they are normalised before joining.
    """
    if not response_dict.get("success", False):
        return f"**Error**: {response_dict.get('error', 'Unknown error')}"
    response = response_dict.get("response", {})
    if not isinstance(response, dict):
        response = {}

    def as_items(value):
        # A bare string must stay one item; joining it would split it into
        # characters. Non-string entries are converted so join() cannot fail.
        if not value:
            return []
        if isinstance(value, (list, tuple)):
            return [str(item) for item in value]
        return [str(value)]

    tables = as_items(response.get('identified_tables'))
    domains = as_items(response.get('domains_involved'))
    alternatives = as_items(response.get('alternative_queries'))
    sql_query = response.get('sql_query', '')

    parts = [
        "**Query Analysis Results**\n\n",
        f"- **Analysis**: {response.get('analysis', 'N/A')}\n\n",
        f"- **Identified Tables**: {', '.join(tables) if tables else 'None'}\n\n",
        f"- **Domains Involved**: {', '.join(domains) if domains else 'None'}\n\n",
    ]
    if sql_query:
        parts.append(f"- **SQL Query**:\n\n```sql\n{sql_query}\n```\n\n")
    else:
        parts.append("- **SQL Query**: None\n\n")
    parts.append(f"- **Explanation**: {response.get('explanation', 'N/A')}\n\n")
    parts.append(f"- **Confidence**: {response.get('confidence', 'N/A')}\n\n")
    if alternatives:
        parts.append("- **Alternative Queries**:\n")
        # Two-space indent so each entry nests under the bullet above.
        parts.extend(f"  - {query}\n" for query in alternatives)
    else:
        parts.append("- **Alternative Queries**: None\n")
    validation = response.get('sql_validation', {})
    if validation:
        status = 'Passed' if validation.get('is_valid', False) else 'Failed'
        parts.append(f"\n- **SQL Validation**: {status} - {validation.get('message', 'N/A')}\n")
    return "".join(parts)
def verification_to_markdown(verification_dict: Dict) -> str:
    """Render the verification model's result as Markdown.

    Returns an ``**Error**`` line when the result is not marked successful.
    List-valued fields come from model output and are normalised to lists of
    strings before rendering.
    """
    if not verification_dict.get("success", False):
        return f"**Error**: {verification_dict.get('error', 'Unknown error')}"
    response = verification_dict.get("response", {})
    if not isinstance(response, dict):
        response = {}

    def bullet_section(title, value):
        # A bare string is treated as a single item rather than iterated
        # character by character.
        if not value:
            return [f"- **{title}**: None\n"]
        items = value if isinstance(value, (list, tuple)) else [value]
        # Two-space indent so each entry nests under its heading bullet.
        return [f"- **{title}**:\n"] + [f"  - {item}\n" for item in items]

    scalar_fields = [
        ("Verification Summary", "verification_summary"),
        ("Table Selection Accuracy", "table_selection_accuracy"),
        ("SQL Correctness", "sql_correctness"),
        ("Consistency Check", "consistency_check"),
        ("Recommended Response", "recommended_response"),
        ("Confidence Score", "confidence_score"),
    ]
    parts = ["**Verification Results**\n\n"]
    parts.extend(f"- **{title}**: {response.get(key, 'N/A')}\n\n" for title, key in scalar_fields)
    parts.extend(bullet_section("Suggested Improvements", response.get('suggested_improvements')))
    parts.extend(bullet_section("Potential Issues", response.get('potential_issues')))
    parts.append(f"- **Schema Compliance**: {response.get('schema_compliance', 'N/A')}\n")
    return "".join(parts)
def _format_execution_results(exec_results: Dict) -> str:
    """Render the per-model execution results as plain text for the results tab."""
    if not exec_results:
        return "No queries were executed. Check if valid SQL was generated."
    parts = []
    for model, result in exec_results.items():
        parts.append(f"\n=== {model.upper()} EXECUTION ===\n")
        parts.append(f"SQL Query: {result.get('sql_query', 'N/A')}\n")
        validation = result.get('validation', {})
        if validation.get('is_valid'):
            parts.append("✅ Query Validation: PASSED\n")
        else:
            parts.append(f"❌ Query Validation: FAILED - {validation.get('message', 'Unknown error')}\n")
        if result.get("success"):
            row_count = result.get('row_count', 0)
            parts.append(f"✅ Execution: Success! Retrieved {row_count} rows\n")
            if row_count > 0:
                data = result.get('data')
                sample_data = data[:3] if isinstance(data, list) else []
                # default=str: rows can hold values json cannot encode
                # (e.g. bytes from BLOB columns); this text is display-only.
                rendered = json.dumps(sample_data, indent=2, default=str, ensure_ascii=False)
                parts.append(f"Sample data:\n{rendered}\n")
            else:
                parts.append("No data returned (empty result set)\n")
        else:
            parts.append(f"❌ Execution Error: {result.get('data')}\n")
        parts.append("\n")
    return "".join(parts)


def create_gradio_interface():
    """Build and return the Gradio Blocks interface for the query agent."""
    agent = DatabaseQueryAgent()
    # Display names for the models configured in the agent; kept in one place
    # so the UI labels cannot drift apart.
    llama_label = "Llama 3.3 70B"
    mistral_label = "Mistral 7B"
    gemma_label = "Gemma 2 9B"
    sample_queries = [
        "Find all customers from customer tables",
        "Show me employee information from HR tables",
        "Get patient data from healthcare tables",
        "List all products with their details",
        "Find students enrolled in courses",
        "Show financial transaction records",
        "Get shipping information for deliveries",
        "Find all suppliers and their information",
        "Show retail store data",
        "Get manufacturing production records"
    ]

    def process_user_query(api_key, query):
        """Run the agent on *query* and return the six values shown in the UI."""
        if not query or not query.strip():
            return "Please enter a query", "", "", "", "", ""
        results = agent.process_query(api_key, query)
        if "error" in results:
            return f"**Error**: {results['error']}", "", "", "", "", ""
        # Format responses as Markdown
        llama_markdown = response_to_markdown(results.get("llama_response", {}))
        mistral_markdown = response_to_markdown(results.get("mistral_response", {}))
        verification_markdown = verification_to_markdown(results.get("verification", {}))
        # Format execution results
        exec_results = results.get("execution_results", {})
        execution_formatted = _format_execution_results(exec_results)
        schema_info = results.get('schema_info', 'No schema information available')
        # Format summary as Markdown. Entries are separated by blank lines
        # because Markdown merges adjacent lines into a single paragraph.
        verification_resp = results.get('verification', {}).get('response', {})
        if not isinstance(verification_resp, dict):
            verification_resp = {}
        summary = "\n\n".join([
            "**🔍 QUERY ANALYSIS COMPLETE**",
            "━━━━━━━━━━━━━━━━━━━━━━",
            f"**📊 Models Used**: {llama_label}, {mistral_label}, {gemma_label} (verification)",
            f"**⏰ Processed**: {results.get('timestamp', 'N/A')}",
            f"**🎯 Verification Summary**:\n{verification_resp.get('verification_summary', 'N/A')}",
            f"**💡 Recommended Model**: {verification_resp.get('recommended_response', 'N/A')}",
            f"**📈 Confidence**: {verification_resp.get('confidence_score', 'N/A')}",
            f"**🗄️ Schema Compliance**: {verification_resp.get('schema_compliance', 'N/A')}",
            f"**🗃️ Query Execution Status**:\n{len(exec_results)} queries attempted",
        ])
        return summary, llama_markdown, mistral_markdown, verification_markdown, execution_formatted, schema_info

    with gr.Blocks(
        title="Fixed Intelligent Database Query Agent",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
            margin: 0 auto !important;
        }
        .result-box {
            background-color: #f8f9fa;
            border: 1px solid #dee2e6;
            border-radius: 8px;
            padding: 15px;
        }
        """
    ) as interface:
        gr.HTML(f"""
        <div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px;">
            <h1>🤖 Fixed Intelligent Database Query Agent</h1>
            <p>AI-powered agent that intelligently selects relevant tables from 100+ tables and generates optimized SQL queries</p>
            <p><strong>Database:</strong> 100 tables across 10 business domains | <strong>Models:</strong> {llama_label} + {mistral_label} + {gemma_label}</p>
            <p><strong>✅ FIXED:</strong> Reserved Word Aliases | Enhanced Column Validation | Better SQL Syntax Checking</p>
        </div>
        """)
        with gr.Row():
            with gr.Column(scale=1):
                api_key_input = gr.Textbox(
                    label="🔑 OpenRouter API Key",
                    type="password",
                    placeholder="Enter your OpenRouter API key...",
                    info="Get your free API key from openrouter.ai"
                )
                query_input = gr.Textbox(
                    label="💬 Database Query",
                    placeholder="Enter your natural language query...",
                    lines=3,
                    info="Example: 'Find all customers who placed orders in the last month'"
                )
                with gr.Row():
                    submit_btn = gr.Button("🚀 Process Query", variant="primary", size="lg")
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                gr.HTML("<h3>📝 Sample Test Queries</h3>")
                sample_dropdown = gr.Dropdown(
                    choices=sample_queries,
                    label="Quick Test Examples",
                    info="Select a sample query to test the agent"
                )
            with gr.Column(scale=2):
                summary_output = gr.Markdown(label="📋 Analysis Summary")
                with gr.Tabs():
                    with gr.Tab(f"🦙 {llama_label} Response"):
                        llama_output = gr.Markdown(label="Llama Response")
                    with gr.Tab(f"🌟 {mistral_label} Response"):
                        mistral_output = gr.Markdown(label="Mistral Response")
                    with gr.Tab(f"✅ Verification ({gemma_label})"):
                        verification_output = gr.Markdown(label="Verification Analysis")
                    with gr.Tab("🗃️ Query Execution Results"):
                        execution_output = gr.Textbox(
                            label="Database Execution Results",
                            lines=15,
                            max_lines=20,
                            elem_classes=["result-box"]
                        )
                    with gr.Tab("📊 Database Schema"):
                        schema_output = gr.Textbox(
                            label="Relevant Database Schema",
                            lines=15,
                            max_lines=20,
                            elem_classes=["result-box"]
                        )
        submit_btn.click(
            fn=process_user_query,
            inputs=[api_key_input, query_input],
            outputs=[summary_output, llama_output, mistral_output, verification_output, execution_output, schema_output]
        )
        # Clears the query box and all six outputs; the API key is kept.
        clear_btn.click(
            fn=lambda: ("", "", "", "", "", "", ""),
            outputs=[query_input, summary_output, llama_output, mistral_output, verification_output, execution_output, schema_output]
        )
        # The dropdown value can be None when the selection is cleared; never
        # push None into the query textbox.
        sample_dropdown.change(
            fn=lambda selected: selected or "",
            inputs=[sample_dropdown],
            outputs=[query_input]
        )
        gr.HTML("""
        <div style="margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 8px;">
            <h3>🎯 How to Use</h3>
            <ol>
                <li><strong>API Key:</strong> Get a free API key from <a href="https://openrouter.ai" target="_blank">openrouter.ai</a></li>
                <li><strong>Query:</strong> Enter your natural language database query</li>
                <li><strong>Process:</strong> The agent will analyze your query across 100+ tables and generate optimized SQL</li>
                <li><strong>Results:</strong> View responses from multiple AI models, verification analysis, and actual query execution results</li>
            </ol>
            <p><strong>Features:</strong></p>
            <ul>
                <li>🧠 Multi-model AI analysis (Llama, Mistral, Gemma)</li>
                <li>🔍 Intelligent table selection from 100+ tables</li>
                <li>✅ SQL validation and syntax checking</li>
                <li>🗃️ Real database query execution with results</li>
                <li>🔄 Cross-model verification and comparison</li>
            </ul>
        </div>
        """)
    return interface
def main():
    """Build the Gradio interface and launch it with a public share link."""
    print("🚀 Starting Intelligent Database Query Agent...")
    print("📊 Loading database schema and metadata...")
    interface = create_gradio_interface()
    print("✅ Database Query Agent Ready!")
    print("🌐 Access the interface at: http://localhost:7860")
    print("🔑 Don't forget to add your OpenRouter API key!")
    # share=True asks Gradio to also create a public share link.
    interface.launch(share=True)


if __name__ == "__main__":
    main()