HusnaManakkot commited on
Commit
4198861
·
verified ·
1 Parent(s): f0a792b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -9,22 +9,24 @@ spider_dataset = load_dataset("spider", split='train[:5]')
9
  tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
10
  model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
11
 
12
- def generate_sql(query):
 
 
13
  input_text = "translate English to SQL: " + query
14
  inputs = tokenizer(input_text, return_tensors="pt", padding=True)
15
  outputs = model.generate(**inputs, max_length=512)
16
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
17
- return sql_query
18
 
19
  # Create a Gradio interface
20
  interface = gr.Interface(
21
- fn=generate_sql,
22
- inputs=gr.Textbox(lines=2, placeholder="Enter your natural language query here..."),
23
- outputs="text",
24
- title="NL to SQL with T5",
25
- description="This model converts natural language queries into SQL. Enter your query!"
26
  )
27
 
28
  # Launch the app
29
  if __name__ == "__main__":
30
- interface.launch()
 
9
  tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
10
  model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
11
 
12
+ def generate_sql_from_dataset(index):
13
+ # Get the natural language query from the dataset
14
+ query = spider_dataset[index]['question']
15
  input_text = "translate English to SQL: " + query
16
  inputs = tokenizer(input_text, return_tensors="pt", padding=True)
17
  outputs = model.generate(**inputs, max_length=512)
18
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
19
+ return query, sql_query
20
 
21
  # Create a Gradio interface
22
  interface = gr.Interface(
23
+ fn=generate_sql_from_dataset,
24
+ inputs=gr.Number(label="Dataset Index (0-4)", default=0),
25
+ outputs=[gr.Textbox(label="Natural Language Query"), gr.Textbox(label="Generated SQL Query")],
26
+ title="NL to SQL with T5 using Spider Dataset",
27
+ description="This model converts natural language queries from the Spider dataset into SQL. Enter the index of the dataset entry (0-4)!"
28
  )
29
 
30
  # Launch the app
31
  if __name__ == "__main__":
32
+ interface.launch()