shukdevdattaEX committed on
Commit 3dd1898 · verified · 1 Parent(s): 18242a0

Upload 2 files

Files changed (3)
  1. .gitattributes +1 -0
  2. app.py +749 -0
  3. innovativeskills.db +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+innovativeskills.db filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,749 @@
import gradio as gr
import sqlite3
import json
import pandas as pd
from openai import OpenAI
import traceback
from typing import Dict, List, Tuple, Any
import re
from datetime import datetime
import threading
import queue
import html
import sys
import os

# Force stdout to use UTF-8 encoding to handle Unicode characters
if sys.stdout.encoding != 'utf-8':
    sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1)

class DatabaseQueryAgent:
    def __init__(self, db_path: str = "innovativeskills.db"):
        self.db_path = db_path
        self.client = None

        # Available models
        self.models = {
            "llama": "meta-llama/llama-3.1-8b-instruct:free",
            "mistral": "mistralai/mistral-7b-instruct:free",
            "gemma": "google/gemma-2-9b-it:free"  # Verification model
        }

        # Initialize database connection
        self.init_db_connection()

    def init_db_connection(self):
        """Initialize database connection with UTF-8 encoding"""
        try:
            conn = sqlite3.connect(self.db_path, check_same_thread=False)
            conn.execute("PRAGMA encoding = 'UTF-8';")
            cursor = conn.cursor()

            # Load table metadata
            self.table_metadata = self.get_table_metadata(conn, cursor)
            self.column_metadata = self.get_column_metadata(conn, cursor)
            self.actual_schema = self.get_actual_schema(conn, cursor)

            conn.close()

        except Exception as e:
            print(f"Database initialization error: {e}")
            self.table_metadata = {}
            self.column_metadata = {}
            self.actual_schema = {}

    def get_db_connection(self):
        """Get a new database connection with UTF-8 encoding"""
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        conn.execute("PRAGMA encoding = 'UTF-8';")
        return conn

    def get_actual_schema(self, conn, cursor) -> Dict:
        """Get actual database schema"""
        try:
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
            tables = [row[0] for row in cursor.fetchall()]
            schema = {}
            for table in tables:
                cursor.execute(f"PRAGMA table_info({table})")
                columns = cursor.fetchall()
                try:
                    cursor.execute(f"SELECT * FROM {table} LIMIT 3")
                    sample_data = cursor.fetchall()
                except Exception:
                    sample_data = []
                try:
                    cursor.execute(f"SELECT COUNT(*) FROM {table}")
                    row_count = cursor.fetchone()[0]
                except Exception:
                    row_count = 0
                schema[table] = {
                    'columns': [{'name': col[1], 'type': col[2], 'notnull': col[3], 'pk': col[5]} for col in columns],
                    'sample_data': sample_data,
                    'row_count': row_count
                }
            return schema
        except Exception as e:
            print(f"Error getting actual schema: {e}")
            return {}

    def get_table_metadata(self, conn, cursor) -> Dict:
        """Get table metadata"""
        try:
            query = """
                SELECT table_name, domain, description, row_count
                FROM table_catalog
                WHERE table_name NOT IN ('table_catalog', 'column_catalog')
            """
            results = cursor.execute(query).fetchall()
            metadata = {}
            for table_name, domain, description, row_count in results:
                metadata[table_name] = {
                    'domain': domain,
                    'description': description,
                    'row_count': row_count
                }
            return metadata
        except Exception as e:
            print(f"Error loading table metadata: {e}")
            return {}

    def get_column_metadata(self, conn, cursor) -> Dict:
        """Get column metadata"""
        try:
            query = """
                SELECT table_name, column_name, data_type, is_foreign_key, references_table, description
                FROM column_catalog
            """
            results = cursor.execute(query).fetchall()
            metadata = {}
            for table_name, column_name, data_type, is_fk, ref_table, description in results:
                if table_name not in metadata:
                    metadata[table_name] = []
                metadata[table_name].append({
                    'name': column_name,
                    'type': data_type,
                    'is_foreign_key': bool(is_fk),
                    'references': ref_table,
                    'description': description
                })
            return metadata
        except Exception as e:
            print(f"Error loading column metadata: {e}")
            return {}

    def setup_client(self, api_key: str):
        """Setup OpenRouter client"""
        self.client = OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key,
        )

    def get_relevant_tables_for_query(self, query: str) -> str:
        """Analyze query and return relevant table info"""
        query_lower = query.lower()
        relevant_tables = []
        keywords = {
            'customer': ['customer', 'client', 'buyer', 'user'],
            'order': ['order', 'purchase', 'transaction', 'sale'],
            'product': ['product', 'item', 'inventory', 'stock'],
            'employee': ['employee', 'staff', 'worker', 'personnel'],
            'patient': ['patient', 'medical', 'health'],
            'student': ['student', 'enrollment', 'grade', 'course'],
            'supplier': ['supplier', 'vendor', 'provider'],
            'shipping': ['shipping', 'delivery', 'logistics'],
            'payment': ['payment', 'invoice', 'billing'],
            'account': ['account', 'financial', 'balance']
        }
        for concept, search_terms in keywords.items():
            if any(term in query_lower for term in search_terms):
                for table_name in self.actual_schema.keys():
                    table_lower = table_name.lower()
                    if any(term in table_lower for term in search_terms):
                        if table_name not in relevant_tables:
                            relevant_tables.append(table_name)
        if not relevant_tables:
            relevant_tables = [name for name, info in self.actual_schema.items()
                               if info['row_count'] > 10][:10]
        schema_info = ""
        for table in relevant_tables[:15]:
            if table in self.actual_schema:
                info = self.actual_schema[table]
                columns_str = ", ".join([f"{col['name']}({col['type']})" for col in info['columns']])
                schema_info += f"\nTable: {table}\n"
                schema_info += f"  Columns: {columns_str}\n"
                schema_info += f"  Rows: {info['row_count']}\n"
                if table in self.table_metadata:
                    meta = self.table_metadata[table]
                    schema_info += f"  Domain: {meta['domain']}\n"
                    schema_info += f"  Description: {meta['description']}\n"
                if info['sample_data']:
                    schema_info += f"  Sample: {info['sample_data'][0] if info['sample_data'] else 'No data'}\n"
        return schema_info

    def get_system_prompt(self, user_query: str) -> str:
        """Generate system prompt with actual schema"""
        relevant_schema = self.get_relevant_tables_for_query(user_query)
        return f"""You are an intelligent database query agent that specializes in identifying relevant tables and generating accurate SQL queries.

DATABASE SCHEMA INFORMATION:
{relevant_schema}

CRITICAL SQL RULES:
1. NEVER use reserved words as table aliases (like 'to', 'from', 'where', 'select', etc.)
2. Use descriptive aliases like 'cust', 'ord', 'prod' instead
3. Only JOIN tables if you can identify a logical relationship between them
4. If no clear JOIN relationship exists, use separate SELECT statements or UNION
5. Always use the EXACT column names shown in the schema
6. Do not assume foreign key relationships unless explicitly shown

CRITICAL: You MUST respond with ONLY a valid JSON object. No markdown, no explanations outside the JSON.

Your response must be exactly in this JSON format:
{{
    "analysis": "Brief analysis of the query and table selection reasoning",
    "identified_tables": ["table1", "table2", "table3"],
    "domains_involved": ["domain1", "domain2"],
    "sql_query": "SELECT ... FROM ... WHERE ...",
    "explanation": "Step-by-step explanation of the query logic",
    "confidence": 0.95,
    "alternative_queries": ["Alternative SQL if applicable"]
}}

IMPORTANT RULES:
1. Respond with ONLY valid JSON - no markdown formatting
2. Use ONLY the actual table names shown in the schema above
3. Use ONLY the actual column names shown in the schema above
4. Generate syntactically correct SQL queries with proper aliases
5. Focus on tables that actually exist and have relevant data
6. Include confidence scores between 0.0 and 1.0
7. Provide clear explanations
8. Ensure table names in 'identified_tables' match those used in 'sql_query'
9. Check that columns referenced in SQL actually exist in the tables
10. If no perfect match exists, choose the closest relevant tables and explain the compromise
11. Avoid reserved word aliases like 'to', 'from', 'order', 'select'

QUERY ANALYSIS GUIDELINES:
- For customer/order queries: Look for tables with customer-related or order-related names and columns
- For employee queries: Look for tables with employee, staff, or HR-related names
- For product queries: Look for tables with product, inventory, or item-related names
- Always verify column names exist before using them in SQL
- Use proper JOIN syntax when combining tables, but only if logical relationships exist
- Include appropriate WHERE clauses when filtering is implied
- If unsure about relationships, prefer simpler queries or multiple separate queries"""

    def extract_json_from_response(self, response_text: str) -> Dict:
        """Extract JSON from response text"""
        try:
            return json.loads(response_text)
        except json.JSONDecodeError:
            json_pattern = r'```json\s*(.*?)\s*```'
            json_match = re.search(json_pattern, response_text, re.DOTALL)
            if json_match:
                try:
                    return json.loads(json_match.group(1))
                except json.JSONDecodeError:
                    pass
            json_pattern = r'\{.*\}'
            json_match = re.search(json_pattern, response_text, re.DOTALL)
            if json_match:
                try:
                    return json.loads(json_match.group(0))
                except json.JSONDecodeError:
                    pass
            return self.create_fallback_response(response_text)

    def create_fallback_response(self, response_text: str) -> Dict:
        """Create a fallback response when JSON parsing fails"""
        sql_pattern = r'SELECT.*?(?:;|$)'
        sql_match = re.search(sql_pattern, response_text, re.IGNORECASE | re.DOTALL)
        sql_query = sql_match.group(0).strip(';') if sql_match else ""
        identified_tables = [table_name for table_name in self.actual_schema.keys()
                             if table_name.lower() in response_text.lower()]
        # Collect the unique domains for the identified tables
        domains_involved = []
        for table in identified_tables:
            if table in self.table_metadata:
                domain = self.table_metadata[table]['domain']
                if domain not in domains_involved:
                    domains_involved.append(domain)
        return {
            "analysis": "Fallback analysis from unparseable response",
            "identified_tables": identified_tables[:5],
            "domains_involved": domains_involved[:3],
            "sql_query": sql_query,
            "explanation": "Response could not be parsed as JSON, extracted information where possible",
            "confidence": 0.5,
            "alternative_queries": []
        }

    def validate_sql_query(self, sql_query: str, identified_tables: List[str]) -> Tuple[bool, str]:
        """Validate SQL query against schema"""
        try:
            if not sql_query.strip():
                return False, "Empty SQL query"
            for table in identified_tables:
                if table not in self.actual_schema:
                    return False, f"Table '{table}' does not exist in database"
            sql_upper = sql_query.upper()
            if not sql_upper.strip().startswith('SELECT'):
                return False, "Only SELECT queries are allowed"
            reserved_words = ['TO', 'FROM', 'WHERE', 'SELECT', 'ORDER', 'GROUP', 'HAVING', 'UNION', 'JOIN', 'ON']
            alias_pattern = r'(?:FROM|JOIN)\s+(\w+)\s+(\w+)'
            aliases = re.findall(alias_pattern, sql_query, re.IGNORECASE)
            for table, alias in aliases:
                if alias.upper() in reserved_words:
                    return False, f"Cannot use reserved word '{alias}' as table alias"
            for table in identified_tables:
                if table in sql_query:
                    table_info = self.actual_schema[table]
                    available_columns = [col['name'] for col in table_info['columns']]
                    column_patterns = [
                        rf'{re.escape(table)}\.(\w+)',
                        rf'\b(\w+)\.(\w+)',
                        rf'SELECT\s+([^FROM]+)'
                    ]
                    for pattern in column_patterns:
                        matches = re.findall(pattern, sql_query, re.IGNORECASE)
                        for match in matches:
                            if isinstance(match, tuple):
                                column = match[1] if len(match) == 2 else match[0] if match else ''
                            else:
                                column = match
                            if column.upper() in ['*', 'COUNT', 'SUM', 'AVG', 'MAX', 'MIN', 'DISTINCT']:
                                continue
                            if column and column not in available_columns and f'{table}.{column}' in sql_query:
                                return False, f"Column '{column}' does not exist in table '{table}'"
            return True, "Query validation passed"
        except Exception as e:
            return False, f"Validation error: {str(e)}"

    def call_model(self, model_key: str, prompt: str, user_query: str) -> Dict:
        """Call specific model with prompt"""
        try:
            messages = [
                {"role": "system", "content": prompt},
                {"role": "user", "content": f"Query: {user_query}\n\nRespond with ONLY a valid JSON object following the exact format specified in the system prompt."}
            ]
            completion = self.client.chat.completions.create(
                model=self.models[model_key],
                messages=messages,
                temperature=0.1,
                max_tokens=2000
            )
            response = completion.choices[0].message.content.strip()
            parsed_response = self.extract_json_from_response(response)
            sql_query = parsed_response.get('sql_query', '')
            identified_tables = parsed_response.get('identified_tables', [])
            if sql_query:
                is_valid, validation_message = self.validate_sql_query(sql_query, identified_tables)
                parsed_response['sql_validation'] = {
                    'is_valid': is_valid,
                    'message': validation_message
                }
            return {
                "success": True,
                "response": parsed_response,
                "raw_response": response,
                "model": model_key
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "model": model_key
            }

    def verify_response(self, api_key: str, original_query: str, llama_response: Dict, mistral_response: Dict) -> Dict:
        """Use Gemma to verify responses"""
        self.setup_client(api_key)
        relevant_schema = self.get_relevant_tables_for_query(original_query)
        verification_prompt = f"""You are a database query verification expert. You have access to the actual database schema and must verify responses against it.

ACTUAL DATABASE SCHEMA:
{relevant_schema}

ORIGINAL QUERY: {original_query}

LLAMA RESPONSE: {json.dumps(llama_response.get('response', {}), indent=2)}

MISTRAL RESPONSE: {json.dumps(mistral_response.get('response', {}), indent=2)}

Verify these responses against the ACTUAL schema above. Check:
1. Do the table names actually exist in the schema?
2. Do the column names actually exist in those tables?
3. Are the table selections appropriate for the query?
4. Is the SQL syntax correct?
5. Are table aliases proper (not reserved words)?

Respond with ONLY a valid JSON object:
{{
    "verification_summary": "Overall assessment based on actual schema",
    "table_selection_accuracy": "Assessment of table choices against actual schema",
    "sql_correctness": "SQL syntax and schema validation",
    "consistency_check": "Comparison between responses",
    "recommended_response": "llama, mistral, or neither",
    "confidence_score": 0.85,
    "suggested_improvements": ["improvement1", "improvement2"],
    "potential_issues": ["issue1", "issue2"],
    "schema_compliance": "Assessment of how well responses match actual schema"
}}"""
        return self.call_model("gemma", verification_prompt, "Verify the above responses against the actual database schema.")

    def execute_query_in_thread(self, sql_query: str, result_queue: queue.Queue):
        """Execute SQL query in a thread"""
        try:
            if not sql_query.strip().upper().startswith('SELECT'):
                result_queue.put((False, "Only SELECT queries are allowed"))
                return
            sql_query = sql_query.strip().rstrip(';')
            conn = self.get_db_connection()
            try:
                df = pd.read_sql_query(sql_query, conn)
                result_queue.put((True, df))
            except Exception as e:
                result_queue.put((False, str(e)))
            finally:
                conn.close()
        except Exception as e:
            result_queue.put((False, f"Query execution error: {str(e)}"))

    def execute_query(self, sql_query: str) -> Tuple[bool, Any]:
        """Execute SQL query using thread-safe approach"""
        try:
            result_queue = queue.Queue()
            thread = threading.Thread(
                target=self.execute_query_in_thread,
                args=(sql_query, result_queue)
            )
            thread.start()
            thread.join(timeout=30)
            if thread.is_alive():
                return False, "Query execution timed out"
            if not result_queue.empty():
                return result_queue.get()
            else:
                return False, "No result returned from query execution"
        except Exception as e:
            return False, f"Execution error: {str(e)}"

    def process_query(self, api_key: str, user_query: str) -> Dict:
        """Process user query"""
        if not api_key:
            return {"error": "Please provide OpenRouter API key"}
        try:
            self.setup_client(api_key)
            system_prompt = self.get_system_prompt(user_query)
            llama_result = self.call_model("llama", system_prompt, user_query)
            mistral_result = self.call_model("mistral", system_prompt, user_query)
            verification_result = self.verify_response(api_key, user_query, llama_result, mistral_result)
            execution_results = {}
            for model_name, result in [("llama", llama_result), ("mistral", mistral_result)]:
                if result.get("success") and result.get("response", {}).get("sql_query"):
                    sql_query = result["response"]["sql_query"]
                    validation_info = result["response"].get("sql_validation", {})
                    if sql_query.strip():
                        if validation_info.get("is_valid", True):
                            success, data = self.execute_query(sql_query)
                            execution_results[model_name] = {
                                "success": success,
                                "data": data.to_dict('records') if success and isinstance(data, pd.DataFrame) else str(data),
                                "row_count": len(data) if success and isinstance(data, pd.DataFrame) else 0,
                                "sql_query": sql_query,
                                "validation": validation_info
                            }
                        else:
                            execution_results[model_name] = {
                                "success": False,
                                "data": f"Query validation failed: {validation_info.get('message', 'Unknown error')}",
                                "row_count": 0,
                                "sql_query": sql_query,
                                "validation": validation_info
                            }
                    else:
                        execution_results[model_name] = {
                            "success": False,
                            "data": "No SQL query generated",
                            "row_count": 0,
                            "sql_query": "",
                            "validation": {"is_valid": False, "message": "Empty query"}
                        }
                else:
                    execution_results[model_name] = {
                        "success": False,
                        "data": "Model failed to generate response",
                        "row_count": 0,
                        "sql_query": "",
                        "validation": {"is_valid": False, "message": "Model error"}
                    }
            return {
                "llama_response": llama_result,
                "mistral_response": mistral_result,
                "verification": verification_result,
                "execution_results": execution_results,
                "timestamp": datetime.now().isoformat(),
                "schema_info": self.get_relevant_tables_for_query(user_query)
            }
        except Exception as e:
            return {"error": f"Processing error: {str(e)}", "traceback": traceback.format_exc()}

def response_to_markdown(response_dict: Dict) -> str:
    """Convert model response to Markdown"""
    if not response_dict.get("success", False):
        return f"**Error**: {response_dict.get('error', 'Unknown error')}"
    response = response_dict.get("response", {})
    markdown = "**Query Analysis Results**\n\n"
    markdown += f"- **Analysis**: {response.get('analysis', 'N/A')}\n\n"
    identified_tables = response.get('identified_tables', [])
    markdown += f"- **Identified Tables**: {', '.join(identified_tables) if identified_tables else 'None'}\n\n"
    domains_involved = response.get('domains_involved', [])
    markdown += f"- **Domains Involved**: {', '.join(domains_involved) if domains_involved else 'None'}\n\n"
    sql_query = response.get('sql_query', '')
    if sql_query:
        markdown += "- **SQL Query**:\n\n```sql\n" + sql_query + "\n```\n\n"
    else:
        markdown += "- **SQL Query**: None\n\n"
    markdown += f"- **Explanation**: {response.get('explanation', 'N/A')}\n\n"
    markdown += f"- **Confidence**: {response.get('confidence', 'N/A')}\n\n"
    alternative_queries = response.get('alternative_queries', [])
    if alternative_queries:
        markdown += "- **Alternative Queries**:\n"
        for query in alternative_queries:
            markdown += f"  - {query}\n"
    else:
        markdown += "- **Alternative Queries**: None\n"
    validation = response.get('sql_validation', {})
    if validation:
        is_valid = validation.get('is_valid', False)
        message = validation.get('message', 'N/A')
        markdown += f"\n- **SQL Validation**: {'Passed' if is_valid else 'Failed'} - {message}\n"
    return markdown

def verification_to_markdown(verification_dict: Dict) -> str:
    """Convert verification response to Markdown"""
    if not verification_dict.get("success", False):
        return f"**Error**: {verification_dict.get('error', 'Unknown error')}"
    response = verification_dict.get("response", {})
    markdown = "**Verification Results**\n\n"
    markdown += f"- **Verification Summary**: {response.get('verification_summary', 'N/A')}\n\n"
    markdown += f"- **Table Selection Accuracy**: {response.get('table_selection_accuracy', 'N/A')}\n\n"
    markdown += f"- **SQL Correctness**: {response.get('sql_correctness', 'N/A')}\n\n"
    markdown += f"- **Consistency Check**: {response.get('consistency_check', 'N/A')}\n\n"
    markdown += f"- **Recommended Response**: {response.get('recommended_response', 'N/A')}\n\n"
    markdown += f"- **Confidence Score**: {response.get('confidence_score', 'N/A')}\n\n"
    suggested_improvements = response.get('suggested_improvements', [])
    if suggested_improvements:
        markdown += "- **Suggested Improvements**:\n"
        for improvement in suggested_improvements:
            markdown += f"  - {improvement}\n"
    else:
        markdown += "- **Suggested Improvements**: None\n"
    potential_issues = response.get('potential_issues', [])
    if potential_issues:
        markdown += "- **Potential Issues**:\n"
        for issue in potential_issues:
            markdown += f"  - {issue}\n"
    else:
        markdown += "- **Potential Issues**: None\n"
    markdown += f"- **Schema Compliance**: {response.get('schema_compliance', 'N/A')}\n"
    return markdown

def create_gradio_interface():
    """Create Gradio interface"""
    agent = DatabaseQueryAgent()
    sample_queries = [
        "Find all customers from customer tables",
        "Show me employee information from HR tables",
        "Get patient data from healthcare tables",
        "List all products with their details",
        "Find students enrolled in courses",
        "Show financial transaction records",
        "Get shipping information for deliveries",
        "Find all suppliers and their information",
        "Show retail store data",
        "Get manufacturing production records"
    ]

    def process_user_query(api_key, query):
        """Process query and return formatted results"""
        if not query.strip():
            return "Please enter a query", "", "", "", "", ""
        results = agent.process_query(api_key, query)
        if "error" in results:
            return f"**Error**: {results['error']}", "", "", "", "", ""

        # Format responses as Markdown
        llama_markdown = response_to_markdown(results.get("llama_response", {}))
        mistral_markdown = response_to_markdown(results.get("mistral_response", {}))
        verification_markdown = verification_to_markdown(results.get("verification", {}))

        # Format execution results
        exec_results = results.get("execution_results", {})
        execution_formatted = ""
        for model, result in exec_results.items():
            execution_formatted += f"\n=== {model.upper()} EXECUTION ===\n"
            execution_formatted += f"SQL Query: {result.get('sql_query', 'N/A')}\n"
            validation = result.get('validation', {})
            if validation.get('is_valid'):
                execution_formatted += "✅ Query Validation: PASSED\n"
            else:
                execution_formatted += f"❌ Query Validation: FAILED - {validation.get('message', 'Unknown error')}\n"
            if result["success"]:
                execution_formatted += f"✅ Execution: Success! Retrieved {result['row_count']} rows\n"
                if result["row_count"] > 0:
                    sample_data = result['data'][:3] if isinstance(result['data'], list) else []
                    execution_formatted += f"Sample data:\n{json.dumps(sample_data, indent=2)}\n"
                else:
                    execution_formatted += "No data returned (empty result set)\n"
            else:
                execution_formatted += f"❌ Execution Error: {result['data']}\n"
            execution_formatted += "\n"
        if not execution_formatted:
            execution_formatted = "No queries were executed. Check if valid SQL was generated."

        schema_info = results.get('schema_info', 'No schema information available')

        # Format summary as Markdown
        verification_resp = results.get('verification', {}).get('response', {})
        summary = f"""
**🔍 QUERY ANALYSIS COMPLETE**

━━━━━━━━━━━━━━━━━━━━━━━━

**📊 Models Used**: Llama 3.1 8B, Mistral 7B, Gemma 2 9B (verification)

**⏰ Processed**: {results.get('timestamp', 'N/A')}

**🎯 Verification Summary**:

{verification_resp.get('verification_summary', 'N/A')}

**💡 Recommended Model**: {verification_resp.get('recommended_response', 'N/A')}

**📈 Confidence**: {verification_resp.get('confidence_score', 'N/A')}

**🗄️ Schema Compliance**: {verification_resp.get('schema_compliance', 'N/A')}

**🗄️ Query Execution Status**:

{len(exec_results)} queries attempted
"""

        return summary, llama_markdown, mistral_markdown, verification_markdown, execution_formatted, schema_info

    with gr.Blocks(
        title="Fixed Intelligent Database Query Agent",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
            margin: 0 auto !important;
        }
        .result-box {
            background-color: #f8f9fa;
            border: 1px solid #dee2e6;
            border-radius: 8px;
            padding: 15px;
        }
        """
    ) as interface:
        gr.HTML("""
        <div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px;">
            <h1>🤖 Fixed Intelligent Database Query Agent</h1>
            <p>AI-powered agent that intelligently selects relevant tables from 100+ tables and generates optimized SQL queries</p>
            <p><strong>Database:</strong> 100 tables across 10 business domains | <strong>Models:</strong> Llama 3.1 8B + Mistral 7B + Gemma 2 9B</p>
            <p><strong>✅ FIXED:</strong> Reserved Word Aliases | Enhanced Column Validation | Better SQL Syntax Checking</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                api_key_input = gr.Textbox(
                    label="🔑 OpenRouter API Key",
                    type="password",
                    placeholder="Enter your OpenRouter API key...",
                    info="Get your free API key from openrouter.ai"
                )
                query_input = gr.Textbox(
                    label="💬 Database Query",
                    placeholder="Enter your natural language query...",
                    lines=3,
                    info="Example: 'Find all customers who placed orders in the last month'"
                )
                with gr.Row():
                    submit_btn = gr.Button("🚀 Process Query", variant="primary", size="lg")
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                gr.HTML("<h3>📝 Sample Test Queries</h3>")
                sample_dropdown = gr.Dropdown(
                    choices=sample_queries,
                    label="Quick Test Examples",
                    info="Select a sample query to test the agent"
                )

            with gr.Column(scale=2):
                summary_output = gr.Markdown(label="📊 Analysis Summary")
                with gr.Tabs():
                    with gr.Tab("🦙 Llama 3.1 8B Response"):
                        llama_output = gr.Markdown(label="Llama Response")
                    with gr.Tab("🌟 Mistral 7B Response"):
                        mistral_output = gr.Markdown(label="Mistral Response")
                    with gr.Tab("✅ Verification (Gemma 2 9B)"):
                        verification_output = gr.Markdown(label="Verification Analysis")
                    with gr.Tab("🗄️ Query Execution Results"):
                        execution_output = gr.Textbox(
                            label="Database Execution Results",
                            lines=15,
                            max_lines=20,
                            elem_classes=["result-box"]
                        )
                    with gr.Tab("📋 Database Schema"):
                        schema_output = gr.Textbox(
                            label="Relevant Database Schema",
                            lines=15,
                            max_lines=20,
                            elem_classes=["result-box"]
                        )

        submit_btn.click(
            fn=process_user_query,
            inputs=[api_key_input, query_input],
            outputs=[summary_output, llama_output, mistral_output, verification_output, execution_output, schema_output]
        )
        clear_btn.click(
            fn=lambda: ("", "", "", "", "", "", ""),
            outputs=[query_input, summary_output, llama_output, mistral_output, verification_output, execution_output, schema_output]
        )
        sample_dropdown.change(
            fn=lambda x: x,
            inputs=[sample_dropdown],
            outputs=[query_input]
        )
        gr.HTML("""
        <div style="margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 8px;">
            <h3>🎯 How to Use</h3>
            <ol>
                <li><strong>API Key:</strong> Get a free API key from <a href="https://openrouter.ai" target="_blank">openrouter.ai</a></li>
                <li><strong>Query:</strong> Enter your natural language database query</li>
                <li><strong>Process:</strong> The agent will analyze your query across 100+ tables and generate optimized SQL</li>
                <li><strong>Results:</strong> View responses from multiple AI models, verification analysis, and actual query execution results</li>
            </ol>
            <p><strong>Features:</strong></p>
            <ul>
                <li>🧠 Multi-model AI analysis (Llama, Mistral, Gemma)</li>
                <li>🔍 Intelligent table selection from 100+ tables</li>
                <li>✅ SQL validation and syntax checking</li>
                <li>🗄️ Real database query execution with results</li>
                <li>📊 Cross-model verification and comparison</li>
            </ul>
        </div>
        """)

    return interface

def main():
    """Main function to launch the application"""
    print("🚀 Starting Intelligent Database Query Agent...")
    print("📊 Loading database schema and metadata...")
    interface = create_gradio_interface()
    print("✅ Database Query Agent Ready!")
    print("🌐 Access the interface at: http://localhost:7860")
    print("🔑 Don't forget to add your OpenRouter API key!")
    interface.launch(share=True)

if __name__ == "__main__":
    main()
innovativeskills.db ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d8a720ae09bc17fc60b5cdc03f68e27652470418b7d18a7eb791aa8933421629
size 16896000