HusnaManakkot committed on
Commit
514fc02
·
verified ·
1 Parent(s): 887c95b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -6,21 +6,28 @@ from datasets import load_dataset
6
  spider_dataset = load_dataset("spider", split='train') # Load a subset of the dataset
7
 
8
  # Extract schema information from the dataset
 
9
  table_names = set()
10
  column_names = set()
11
  for item in spider_dataset:
12
- for table in item['db']['table_names_original']:
13
- table_names.add(table)
14
- for column in item['db']['column_names_original']:
15
- column_names.add(column[1])
 
 
16
 
17
  # Load tokenizer and model
18
- tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL") # Update this to a model fine-tuned on Spider if available
19
- model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL") # Update this to a model fine-tuned on Spider if available
20
 
21
  def post_process_sql_query(sql_query):
22
  # Modify the SQL query to match the dataset's schema
23
  # This is just an example and might need to be adapted based on the dataset and model output
 
 
 
 
24
  for table_name in table_names:
25
  if "TABLE" in sql_query:
26
  sql_query = sql_query.replace("TABLE", table_name)
 
6
  spider_dataset = load_dataset("spider", split='train') # Load a subset of the dataset
7
 
8
  # Extract schema information from the dataset
9
+ db_ids = set()
10
  table_names = set()
11
  column_names = set()
12
  for item in spider_dataset:
13
+ db_ids.add(item['db_id'])
14
+ for table in item['sql']['from']['table_units']:
15
+ if isinstance(table, list):
16
+ table_names.add(table[1])
17
+ for column in item['sql']['select'][1]:
18
+ column_names.add(column[1][1])
19
 
20
  # Load tokenizer and model
21
+ tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
22
+ model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
23
 
24
  def post_process_sql_query(sql_query):
25
  # Modify the SQL query to match the dataset's schema
26
  # This is just an example and might need to be adapted based on the dataset and model output
27
+ for db_id in db_ids:
28
+ if "DB_ID" in sql_query:
29
+ sql_query = sql_query.replace("DB_ID", db_id)
30
+ break # Assuming only one database is referenced in the query
31
  for table_name in table_names:
32
  if "TABLE" in sql_query:
33
  sql_query = sql_query.replace("TABLE", table_name)