HusnaManakkot committed on
Commit
e1f7e24
·
verified ·
1 Parent(s): b2ea7cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -21
app.py CHANGED
@@ -1,36 +1,39 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from datasets import load_dataset
 
 
4
 
5
  # Load the Spider dataset
6
  spider_dataset = load_dataset("spider", split='train') # Load a subset of the dataset
7
 
8
- # Extract schema information from the dataset
9
- db_table_names = set()
10
- column_names = set()
11
- for item in spider_dataset:
12
- db_id = item['db_id']
13
- for table in item['table_names']:
14
- db_table_names.add((db_id, table))
15
- for column in item['column_names']:
16
- column_names.add(column[1])
17
 
18
  # Load tokenizer and model
19
  tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
20
  model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
21
 
22
- def post_process_sql_query(sql_query):
23
  # Modify the SQL query to match the dataset's schema
24
- for db_id, table_name in db_table_names:
25
- if "TABLE" in sql_query:
26
- sql_query = sql_query.replace("TABLE", table_name)
27
- break # Assuming only one table is referenced in the query
28
- for column_name in column_names:
29
- if "COLUMN" in sql_query:
30
- sql_query = sql_query.replace("COLUMN", column_name, 1)
 
 
31
  return sql_query
32
 
33
- def generate_sql_from_user_input(query):
34
  # Generate SQL for the user's query
35
  input_text = "translate English to SQL: " + query
36
  inputs = tokenizer(input_text, return_tensors="pt", padding=True)
@@ -38,13 +41,13 @@ def generate_sql_from_user_input(query):
38
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
39
 
40
  # Post-process the SQL query to match the dataset's schema
41
- sql_query = post_process_sql_query(sql_query)
42
  return sql_query
43
 
44
  # Create a Gradio interface
45
  interface = gr.Interface(
46
- fn=generate_sql_from_user_input,
47
- inputs=gr.Textbox(label="Enter your natural language query"),
48
  outputs=gr.Textbox(label="Generated SQL Query"),
49
  title="NL to SQL using Spider Dataset",
50
  description="This interface generates an SQL query from your natural language input based on the Spider dataset."
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from datasets import load_dataset
4
+ import json
5
+ import os
6
 
7
  # Load the Spider dataset
8
  spider_dataset = load_dataset("spider", split='train') # Load a subset of the dataset
9
 
10
+ # Load the database schemas
11
+ db_schemas = {}
12
+ database_dir = 'path/to/database/folder'
13
+ for filename in os.listdir(database_dir):
14
+ if filename.endswith('.json'):
15
+ with open(os.path.join(database_dir, filename), 'r') as file:
16
+ db_schema = json.load(file)
17
+ db_schemas[db_schema['db_id']] = db_schema
 
18
 
19
  # Load tokenizer and model
20
  tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
21
  model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
22
 
23
+ def post_process_sql_query(sql_query, db_id):
24
  # Modify the SQL query to match the dataset's schema
25
+ if db_id in db_schemas:
26
+ db_schema = db_schemas[db_id]
27
+ for table_name in db_schema['table_names_original']:
28
+ if "TABLE" in sql_query:
29
+ sql_query = sql_query.replace("TABLE", table_name)
30
+ break # Assuming only one table is referenced in the query
31
+ for column_name in db_schema['column_names_original']:
32
+ if "COLUMN" in sql_query:
33
+ sql_query = sql_query.replace("COLUMN", column_name[1], 1)
34
  return sql_query
35
 
36
+ def generate_sql_from_user_input(query, db_id):
37
  # Generate SQL for the user's query
38
  input_text = "translate English to SQL: " + query
39
  inputs = tokenizer(input_text, return_tensors="pt", padding=True)
 
41
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
42
 
43
  # Post-process the SQL query to match the dataset's schema
44
+ sql_query = post_process_sql_query(sql_query, db_id)
45
  return sql_query
46
 
47
  # Create a Gradio interface
48
  interface = gr.Interface(
49
+ fn=lambda query, db_id: generate_sql_from_user_input(query, db_id),
50
+ inputs=[gr.Textbox(label="Enter your natural language query"), gr.Dropdown(label="Select Database ID", choices=list(db_schemas.keys()))],
51
  outputs=gr.Textbox(label="Generated SQL Query"),
52
  title="NL to SQL using Spider Dataset",
53
  description="This interface generates an SQL query from your natural language input based on the Spider dataset."