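"""Natural-language Q&A over a SQLite database.

Combines zero-shot intent classification (Hugging Face transformers),
spaCy tokenization and word vectors, and fuzzy string matching (thefuzz)
to translate plain-English questions into SQL queries.
"""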
import sqlite3
import spacy
import re
from thefuzz import process
import numpy as np
from transformers import pipeline
# Load intent classification model
# Use Hugging Face's zero-shot pipeline for flexibility
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
# Small model for tokenization; medium model for word vectors
# (en_core_web_sm ships without static word vectors, so similarity needs en_core_web_md)
nlp = spacy.load("en_core_web_sm")
nlp_vectors = spacy.load("en_core_web_md")
# Map natural-language operator phrases to SQL operators
operator_mappings = {
    "greater than": ">",
    "less than": "<",
    "equal to": "=",
    "not equal to": "!=",
    "starts with": "LIKE",
    "ends with": "LIKE",
    "contains": "LIKE",
    "above": ">",
    "below": "<",
    "more than": ">",
    "<": "<",
    ">": ">"
}
# Connect to SQLite database
def connect_to_db(db_path):
    """Open a connection to the SQLite database at db_path."""
    return sqlite3.connect(db_path)
# Fetch database schema
def fetch_schema(conn):
    """Return {table_name: [column info dicts]} for every table in the DB."""
    cursor = conn.cursor()
    query = """
        SELECT name
        FROM sqlite_master
        WHERE type='table';
    """
    cursor.execute(query)
    tables = cursor.fetchall()
    schema = {}
    for table in tables:
        table_name = table[0]
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = cursor.fetchall()
        schema[table_name] = [
            {"name": col[1], "type": col[2], "not_null": col[3], "default": col[4], "pk": col[5]}
            for col in columns
        ]
    return schema
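# Illustrative shape of the returned schema (actual tables depend on your DB):
# {"products": [{"name": "id", "type": "INTEGER", "not_null": 0, "default": None, "pk": 1},
#               {"name": "price", "type": "REAL", "not_null": 0, "default": None, "pk": 0},
#               ...]}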
def find_ai_synonym(token_text, table_schema):
    """Return the best-matching column from table_schema based on vector similarity."""
    token_vec = nlp_vectors(token_text)[0].vector
    token_norm = np.linalg.norm(token_vec)
    if token_norm == 0:  # out-of-vocabulary token: no usable vector
        return None
    best_col = None
    best_score = 0.0
    for col in table_schema:
        col_vec = nlp_vectors(col)[0].vector
        col_norm = np.linalg.norm(col_vec)
        if col_norm == 0:
            continue
        # Cosine similarity
        score = token_vec.dot(col_vec) / (token_norm * col_norm)
        if score > best_score:
            best_score = score
            best_col = col
    # Only accept a confident match
    if best_score > 0.65:
        return best_col
    return None
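# Hypothetical example: with columns ["price", "stock", "name"], a token like
# "cost" would typically map to "price" via vector similarity:
# find_ai_synonym("cost", ["price", "stock", "name"])  # -> "price" (if score > 0.65)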
def identify_table(question, schema_tables):
    # schema_tables = ["products", "users", "orders", ...]
    table, score = process.extractOne(question, schema_tables)
    if score > 80:  # a comfortable threshold
        return table
    return None
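# Hypothetical example: identify_table("show me all products", ["products", "users"])
# fuzzy-matches the table names against the whole question and returns "products".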
def identify_columns(question, columns_for_table):
    # columns_for_table = ["id", "price", "stock", "name", ...]
    # For each token in the question, fuzzy match to columns
    matched_cols = []
    tokens = question.lower().split()
    for token in tokens:
        col, score = process.extractOne(token, columns_for_table)
        if score > 80:
            matched_cols.append(col)
    return matched_cols
def find_closest_column(token, table_schema):
    # table_schema is a list of column names, e.g. ["price", "stock", "name"]
    # process.extractOne returns (best_match, score)
    best_match, score = process.extractOne(token, table_schema)
    # You can tune this threshold as needed (e.g. 70, 80, etc.)
    if score > 90:
        return best_match
    return None
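# Note: identify_columns and find_closest_column are fuzzy-matching alternatives
# to find_ai_synonym; the condition extractor below uses the vector-based helper,
# but either fuzzy helper could be swapped in if the medium spaCy model is
# unavailable.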
# Condition extraction with NLP
def extract_conditions(question, schema, table):
    """Extract a SQL WHERE clause (without the WHERE keyword) from the question."""
    table_schema = [col["name"].lower() for col in schema.get(table, [])]
    # Detect whether the user used 'AND' / 'OR'
    # (case-insensitive, hence .lower() checks)
    use_and = " and " in question.lower()
    use_or = " or " in question.lower()
    last_column = None
    # Split on 'and' or 'or' to handle multiple conditions
    condition_parts = re.split(r'\band\b|\bor\b', question, flags=re.IGNORECASE)
    print(condition_parts)  # debug
    conditions = []
    for part in condition_parts:
        part = part.strip()
        # Use spaCy to tokenize each part
        doc = nlp(part.lower())
        tokens = [token.text for token in doc]
        # Skip the recognized table token so it won't be matched as a column
        tokens = [t for t in tokens if t != table.lower()]
        part_conditions = []
        current_part_column = None
        print(tokens)  # debug
        for token in tokens:
            # Use vector similarity to find the column this token refers to
            possible_col = find_ai_synonym(token, table_schema)
            if possible_col:
                current_part_column = possible_col
                last_column = possible_col  # remember for later parts
        # Check for a matching operator phrase in this part; try longer phrases
        # first so that e.g. "not equal to" wins over "equal to"
        for phrase, sql_operator in sorted(operator_mappings.items(), key=lambda kv: -len(kv[0])):
            if phrase in part.lower():
                # Extract the value after the phrase
                value_index = part.lower().find(phrase) + len(phrase)
                value = part[value_index:].strip().split(" ")[0]
                value = value.replace("'", "").replace('"', "").strip()
                # Special handling for LIKE operators
                if sql_operator == "LIKE":
                    if "starts with" in phrase:
                        value = f"'{value}%'"
                    elif "ends with" in phrase:
                        value = f"'%{value}'"
                    elif "contains" in phrase:
                        value = f"'%{value}%'"
                # If we did not find a new column, fall back to last_column
                column_to_use = current_part_column or last_column
                if column_to_use:
                    part_conditions.append(f"{column_to_use} {sql_operator} {value}")
                break  # take only the first (longest) matching operator
        if part_conditions:
            conditions.append(" AND ".join(part_conditions))
    # Finally, combine the parts with AND or OR, depending on the user query;
    # default to AND when there is only one part or no explicit connector
    if use_or and not use_and:
        return " OR ".join(conditions)
    return " AND ".join(conditions)
# Interpret user question using intent recognition
def interpret_question(question, schema):
    """Classify the user's intent and identify the target table."""
    # Define potential intents (the descriptions document each intent;
    # only the keys are passed to the classifier as candidate labels)
    intents = {
        "describe_table": "Provide information about the columns and structure of a table.",
        "list_table_data": "Fetch and display all data stored in a table.",
        "count_records": "Count the number of records in a table.",
        "fetch_column": "Fetch a specific column's data from a table."
    }
    # Use the zero-shot classifier to predict intent
    labels = list(intents.keys())
    result = classifier(question, labels)
    predicted_intent = result["labels"][0]
    table = identify_table(question, list(schema.keys()))
    # Rule-based fallback: if the question contains an operator phrase,
    # treat it as a filtered listing regardless of the classifier's output
    condition_keywords = list(operator_mappings.keys())
    if any(keyword in question.lower() for keyword in condition_keywords):
        predicted_intent = "list_table_data"
    return {"intent": predicted_intent, "table": table}
# Handle different intents
def handle_intent(intent_data, schema, conn, question):
    """Dispatch on the predicted intent and run the corresponding query."""
    intent = intent_data["intent"]
    table = intent_data["table"]
    if not table:
        return "I couldn't identify which table you're referring to."
    if intent == "describe_table":
        # Describe table structure
        table_schema = schema[table]
        description = [f"Table '{table}' has the following columns:"]
        for col in table_schema:
            col_details = f"- {col['name']} ({col['type']})"
            if col['not_null']:
                col_details += " [NOT NULL]"
            if col['default'] is not None:
                col_details += f" [DEFAULT: {col['default']}]"
            if col['pk']:
                col_details += " [PRIMARY KEY]"
            description.append(col_details)
        return "\n".join(description)
    elif intent == "list_table_data":
        # Check for conditions
        condition = extract_conditions(question, schema, table)
        cursor = conn.cursor()
        query = f"SELECT * FROM {table}"
        if condition:
            query += f" WHERE {condition};"
        else:
            query += ";"
        print(query)  # debug
        cursor.execute(query)
        return cursor.fetchall()
    elif intent == "count_records":
        # Count records in the table
        cursor = conn.cursor()
        cursor.execute(f"SELECT COUNT(*) FROM {table};")
        return cursor.fetchone()
    elif intent == "fetch_column":
        return "Fetching specific column data is not yet implemented."
    else:
        return "I couldn't understand your question."
# Main function
def answer_question(question, conn, schema):
    """Answer a natural-language question against the connected database."""
    intent_data = interpret_question(question, schema)
    print(intent_data)  # debug
    return handle_intent(intent_data, schema, conn, question)
# Example Usage
if __name__ == "__main__":
    db_path = "./ecommerce.db"  # Replace with your SQLite database path
    conn = connect_to_db(db_path)
    schema = fetch_schema(conn)
    print("Schema:", schema)
    while True:
        question = input("\nAsk a question about the database: ")
        if question.lower() in ["exit", "quit"]:
            break
        answer = answer_question(question, conn, schema)
        print("Answer:", answer)
    conn.close()  # release the database connection on exit